From b0cfeff06d9dd99a16dfae19c5df3c73c5864fb9 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 3 Sep 2014 16:57:56 +0200 Subject: fix #341 - work on flows instead of request/response internally. --- examples/flowbasic | 12 +- examples/proxapp | 12 +- libmproxy/app.py | 8 +- libmproxy/console/__init__.py | 26 +-- libmproxy/dump.py | 32 ++-- libmproxy/filt.py | 4 +- libmproxy/flow.py | 106 ++++++------ libmproxy/protocol/http.py | 102 +++++------ libmproxy/protocol/primitives.py | 23 +-- libmproxy/script.py | 11 +- test/test_console.py | 12 +- test/test_console_common.py | 2 +- test/test_dump.py | 41 ++--- test/test_flow.py | 353 ++++++++++++++++++--------------------- test/test_protocol_http.py | 37 ++-- test/test_proxy.py | 10 +- test/test_script.py | 14 +- test/test_server.py | 57 ++++--- test/tservers.py | 12 +- test/tutils.py | 120 ++++++------- 20 files changed, 463 insertions(+), 531 deletions(-) diff --git a/examples/flowbasic b/examples/flowbasic index b8184262..8dbe2f28 100755 --- a/examples/flowbasic +++ b/examples/flowbasic @@ -16,16 +16,16 @@ class MyMaster(flow.FlowMaster): except KeyboardInterrupt: self.shutdown() - def handle_request(self, r): - f = flow.FlowMaster.handle_request(self, r) + def handle_request(self, f): + f = flow.FlowMaster.handle_request(self, f) if f: - r.reply() + f.reply() return f - def handle_response(self, r): - f = flow.FlowMaster.handle_response(self, r) + def handle_response(self, f): + f = flow.FlowMaster.handle_response(self, f) if f: - r.reply() + f.reply() print f return f diff --git a/examples/proxapp b/examples/proxapp index 3a94cd55..9f299d25 100755 --- a/examples/proxapp +++ b/examples/proxapp @@ -20,16 +20,16 @@ class MyMaster(flow.FlowMaster): except KeyboardInterrupt: self.shutdown() - def handle_request(self, r): - f = flow.FlowMaster.handle_request(self, r) + def handle_request(self, f): + f = flow.FlowMaster.handle_request(self, f) if f: - r.reply() + f.reply() return f - def handle_response(self, r): - f = flow.FlowMaster.handle_response(self, r) + def handle_response(self, f): + f = flow.FlowMaster.handle_response(self, f) if f: - r.reply() + f.reply() print f return f diff --git a/libmproxy/app.py b/libmproxy/app.py index 9941d6ea..ed7ec72a 100644 --- a/libmproxy/app.py +++ b/libmproxy/app.py @@ -1,7 +1,7 @@ from __future__ import absolute_import import flask -import os.path, os -from . import proxy +import os +from .proxy import config mapp = flask.Flask(__name__) mapp.debug = True @@ -18,12 +18,12 @@ def index(): @mapp.route("/cert/pem") def certs_pem(): - p = os.path.join(master().server.config.confdir, proxy.config.CONF_BASENAME + "-ca-cert.pem") + p = os.path.join(master().server.config.confdir, config.CONF_BASENAME + "-ca-cert.pem") return flask.Response(open(p, "rb").read(), mimetype='application/x-x509-ca-cert') @mapp.route("/cert/p12") def certs_p12(): - p = os.path.join(master().server.config.confdir, proxy.config.CONF_BASENAME + "-ca-cert.p12") + p = os.path.join(master().server.config.confdir, config.CONF_BASENAME + "-ca-cert.p12") return flask.Response(open(p, "rb").read(), mimetype='application/x-pkcs12') diff --git a/libmproxy/console/__init__.py b/libmproxy/console/__init__.py index 1325aae5..a5920915 100644 --- a/libmproxy/console/__init__.py +++ b/libmproxy/console/__init__.py @@ -268,8 +268,8 @@ class ConsoleState(flow.State): d = self.flowsettings.get(flow, {}) return d.get(key, default) - def add_request(self, req): - f = flow.State.add_request(self, req) + def add_request(self, f): + flow.State.add_request(self, f) if self.focus is None: self.set_focus(0) elif self.follow_focus: @@ -996,11 +996,11 @@ class ConsoleMaster(flow.FlowMaster): if hasattr(self.statusbar, "refresh_flow"): self.statusbar.refresh_flow(c) - def process_flow(self, f, r): + def process_flow(self, f): if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay: f.intercept() else: - r.reply() + f.reply() self.sync_list_view() self.refresh_flow(f) @@ -1022,20 +1022,20 @@ class ConsoleMaster(flow.FlowMaster): self.eventlist.set_focus(len(self.eventlist)-1) # Handlers - def handle_error(self, r): - f = flow.FlowMaster.handle_error(self, r) + def handle_error(self, f): + f = flow.FlowMaster.handle_error(self, f) if f: - self.process_flow(f, r) + self.process_flow(f) return f - def handle_request(self, r): - f = flow.FlowMaster.handle_request(self, r) + def handle_request(self, f): + f = flow.FlowMaster.handle_request(self, f) if f: - self.process_flow(f, r) + self.process_flow(f) return f - def handle_response(self, r): - f = flow.FlowMaster.handle_response(self, r) + def handle_response(self, f): + f = flow.FlowMaster.handle_response(self, f) if f: - self.process_flow(f, r) + self.process_flow(f) return f diff --git a/libmproxy/dump.py b/libmproxy/dump.py index aeb34cc3..8ecd56e7 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -50,13 +50,13 @@ def str_response(resp): return r -def str_request(req, showhost): - if req.flow.client_conn: - c = req.flow.client_conn.address.host +def str_request(f, showhost): + if f.client_conn: + c = f.client_conn.address.host else: c = "[replay]" - r = "%s %s %s"%(c, req.method, req.get_url(showhost)) - if req.stickycookie: + r = "%s %s %s"%(c, f.request.method, f.request.get_url(showhost, f)) + if f.request.stickycookie: r = "[stickycookie] " + r return r @@ -185,16 +185,16 @@ class DumpMaster(flow.FlowMaster): result = " << %s"%f.error.msg if self.o.flow_detail == 1: - print >> self.outfile, str_request(f.request, self.showhost) + print >> self.outfile, str_request(f, self.showhost) print >> self.outfile, result elif self.o.flow_detail == 2: - print >> self.outfile, str_request(f.request, self.showhost) + print >> self.outfile, str_request(f, self.showhost) print >> self.outfile, self.indent(4, f.request.headers) print >> self.outfile print >> self.outfile, result print >> self.outfile, "\n" elif self.o.flow_detail >= 3: - print >> self.outfile, str_request(f.request, self.showhost) + print >> self.outfile, str_request(f, self.showhost) print >> self.outfile, self.indent(4, f.request.headers) if utils.isBin(f.request.content): print >> self.outfile, self.indent(4, netlib.utils.hexdump(f.request.content)) @@ -206,21 +206,21 @@ class DumpMaster(flow.FlowMaster): if self.o.flow_detail: self.outfile.flush() - def handle_request(self, r): - f = flow.FlowMaster.handle_request(self, r) + def handle_request(self, f): + flow.FlowMaster.handle_request(self, f) if f: - r.reply() + f.reply() return f - def handle_response(self, msg): - f = flow.FlowMaster.handle_response(self, msg) + def handle_response(self, f): + flow.FlowMaster.handle_response(self, f) if f: - msg.reply() + f.reply() self._process_flow(f) return f - def handle_error(self, msg): - f = flow.FlowMaster.handle_error(self, msg) + def handle_error(self, f): + flow.FlowMaster.handle_error(self, f) if f: self._process_flow(f) return f diff --git a/libmproxy/filt.py b/libmproxy/filt.py index e17ed735..925dbfbb 100644 --- a/libmproxy/filt.py +++ b/libmproxy/filt.py @@ -208,7 +208,7 @@ class FDomain(_Rex): code = "d" help = "Domain" def __call__(self, f): - return bool(re.search(self.expr, f.request.get_host(), re.IGNORECASE)) + return bool(re.search(self.expr, f.request.get_host(False, f), re.IGNORECASE)) class FUrl(_Rex): @@ -222,7 +222,7 @@ class FUrl(_Rex): return klass(*toks) def __call__(self, f): - return re.search(self.expr, f.request.get_url()) + return re.search(self.expr, f.request.get_url(False, f)) class _Int(_Action): diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 2540435e..eb183d9f 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -34,11 +34,11 @@ class AppRegistry: """ Returns an WSGIAdaptor instance if request matches an app, or None. """ - if (request.get_host(), request.get_port()) in self.apps: - return self.apps[(request.get_host(), request.get_port())] + if (request.host, request.port) in self.apps: + return self.apps[(request.host, request.port)] if "host" in request.headers: host = request.headers["host"][0] - return self.apps.get((host, request.get_port()), None) + return self.apps.get((host, request.port), None) class ReplaceHooks: @@ -185,11 +185,11 @@ class ClientPlaybackState: n = self.flows.pop(0) n.request.reply = controller.DummyReply() n.client_conn = None - self.current = master.handle_request(n.request) + self.current = master.handle_request(n) if not testing and not self.current.response: - master.replay_request(self.current) # pragma: no cover + master.replay_request(self.current) # pragma: no cover elif self.current.response: - master.handle_response(self.current.response) + master.handle_response(self.current) class ServerPlaybackState: @@ -260,8 +260,8 @@ class StickyCookieState: Returns a (domain, port, path) tuple. """ return ( - m["domain"] or f.request.get_host(), - f.request.get_port(), + m["domain"] or f.request.get_host(False, f), + f.request.get_port(f), m["path"] or "/" ) @@ -279,7 +279,7 @@ class StickyCookieState: c = Cookie.SimpleCookie(str(i)) m = c.values()[0] k = self.ckey(m, f) - if self.domain_match(f.request.get_host(), k[0]): + if self.domain_match(f.request.get_host(False, f), k[0]): self.jar[self.ckey(m, f)] = m def handle_request(self, f): @@ -287,8 +287,8 @@ class StickyCookieState: if f.match(self.flt): for i in self.jar.keys(): match = [ - self.domain_match(f.request.get_host(), i[0]), - f.request.get_port() == i[1], + self.domain_match(f.request.get_host(False, f), i[0]), + f.request.get_port(f) == i[1], f.request.path.startswith(i[2]) ] if all(match): @@ -307,7 +307,7 @@ class StickyAuthState: self.hosts = {} def handle_request(self, f): - host = f.request.get_host() + host = f.request.get_host(False, f) if "authorization" in f.request.headers: self.hosts[host] = f.request.headers["authorization"] elif f.match(self.flt): @@ -342,33 +342,30 @@ class State(object): c += 1 return c - def add_request(self, req): + def add_request(self, flow): """ Add a request to the state. Returns the matching flow. """ - f = req.flow - self._flow_list.append(f) - if f.match(self._limit): - self.view.append(f) - return f + self._flow_list.append(flow) + if flow.match(self._limit): + self.view.append(flow) + return flow - def add_response(self, resp): + def add_response(self, f): """ Add a response to the state. Returns the matching flow. """ - f = resp.flow if not f: return False if f.match(self._limit) and not f in self.view: self.view.append(f) return f - def add_error(self, err): + def add_error(self, f): """ Add an error response to the state. Returns the matching flow, or None if there isn't one. """ - f = err.flow if not f: return None if f.match(self._limit) and not f in self.view: @@ -586,7 +583,7 @@ class FlowMaster(controller.Master): response.is_replay = True if self.refresh_server_playback: response.refresh() - flow.request.reply(response) + flow.reply(response) if self.server_playback.count() == 0: self.stop_server_playback() return True @@ -612,16 +609,14 @@ class FlowMaster(controller.Master): """ Loads a flow, and returns a new flow object. """ + f.reply = controller.DummyReply() if f.request: - f.request.reply = controller.DummyReply() - fr = self.handle_request(f.request) + self.handle_request(f) if f.response: - f.response.reply = controller.DummyReply() - self.handle_response(f.response) + self.handle_response(f) if f.error: - f.error.reply = controller.DummyReply() - self.handle_error(f.error) - return fr + self.handle_error(f) + return f def load_flows(self, fr): """ @@ -647,7 +642,7 @@ class FlowMaster(controller.Master): if self.kill_nonreplay: f.kill(self) else: - f.request.reply() + f.reply() def process_new_response(self, f): if self.stickycookie_state: @@ -694,54 +689,49 @@ class FlowMaster(controller.Master): self.run_script_hook("serverconnect", sc) sc.reply() - def handle_error(self, r): - f = self.state.add_error(r) - if f: - self.run_script_hook("error", f) + def handle_error(self, f): + self.state.add_error(f) + self.run_script_hook("error", f) if self.client_playback: self.client_playback.clear(f) - r.reply() + f.reply() return f - def handle_request(self, r): - if r.flow.live: - app = self.apps.get(r) + def handle_request(self, f): + if f.live: + app = self.apps.get(f.request) if app: - err = app.serve(r, r.flow.client_conn.wfile, **{"mitmproxy.master": self}) + err = app.serve(f, f.client_conn.wfile, **{"mitmproxy.master": self}) if err: self.add_event("Error in wsgi app. %s"%err, "error") - r.reply(protocol.KILL) + f.reply(protocol.KILL) return - f = self.state.add_request(r) + self.state.add_request(f) self.replacehooks.run(f) self.setheaders.run(f) self.run_script_hook("request", f) self.process_new_request(f) return f - def handle_responseheaders(self, resp): - f = resp.flow + def handle_responseheaders(self, f): self.run_script_hook("responseheaders", f) if self.stream_large_bodies: self.stream_large_bodies.run(f, False) - resp.reply() + f.reply() return f - def handle_response(self, r): - f = self.state.add_response(r) - if f: - self.replacehooks.run(f) - self.setheaders.run(f) - self.run_script_hook("response", f) - if self.client_playback: - self.client_playback.clear(f) - self.process_new_response(f) - if self.stream: - self.stream.add(f) - else: - r.reply() + def handle_response(self, f): + self.state.add_response(f) + self.replacehooks.run(f) + self.setheaders.run(f) + self.run_script_hook("response", f) + if self.client_playback: + self.client_playback.clear(f) + self.process_new_response(f) + if self.stream: + self.stream.add(f) return f def shutdown(self): diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 658c08ed..3f9eecb3 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -77,9 +77,6 @@ class HTTPMessage(stateobject.SimpleStateObject): self.timestamp_start = timestamp_start if timestamp_start is not None else utils.timestamp() self.timestamp_end = timestamp_end if timestamp_end is not None else utils.timestamp() - self.flow = None # will usually be set by the flow backref mixin - """@type: HTTPFlow""" - _stateobject_attributes = dict( httpversion=tuple, headers=ODictCaseless, @@ -346,10 +343,10 @@ class HTTPRequest(HTTPMessage): del headers[k] if headers["Upgrade"] == ["h2c"]: # Suppress HTTP2 https://http2.github.io/http2-spec/index.html#discover-http del headers["Upgrade"] - if not 'host' in headers: + if not 'host' in headers and self.scheme and self.host and self.port: headers["Host"] = [utils.hostport(self.scheme, - self.host or self.flow.server_conn.address.host, - self.port or self.flow.server_conn.address.port)] + self.host, + self.port)] if self.content: headers["Content-Length"] = [str(len(self.content))] @@ -429,16 +426,16 @@ class HTTPRequest(HTTPMessage): self.headers["Content-Type"] = [HDR_FORM_URLENCODED] self.content = utils.urlencode(odict.lst) - def get_path_components(self): + def get_path_components(self, f): """ Returns the path components of the URL as a list of strings. Components are unquoted. """ - _, _, path, _, _, _ = urlparse.urlparse(self.get_url()) + _, _, path, _, _, _ = urlparse.urlparse(self.get_url(False, f)) return [urllib.unquote(i) for i in path.split("/") if i] - def set_path_components(self, lst): + def set_path_components(self, lst, f): """ Takes a list of strings, and sets the path component of the URL. @@ -446,27 +443,27 @@ class HTTPRequest(HTTPMessage): """ lst = [urllib.quote(i, safe="") for i in lst] path = "/" + "/".join(lst) - scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.get_url()) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) + scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.get_url(False, f)) + self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment]), f) - def get_query(self): + def get_query(self, f): """ Gets the request query string. Returns an ODict object. """ - _, _, _, _, query, _ = urlparse.urlparse(self.get_url()) + _, _, _, _, query, _ = urlparse.urlparse(self.get_url(False, f)) if query: return ODict(utils.urldecode(query)) return ODict([]) - def set_query(self, odict): + def set_query(self, odict, f): """ Takes an ODict object, and sets the request query string. """ - scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.get_url()) + scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.get_url(False, f)) query = utils.urlencode(odict.lst) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment])) + self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment]), f) - def get_host(self, hostheader=False): + def get_host(self, hostheader, flow): """ Heuristic to get the host of the request. @@ -484,16 +481,16 @@ class HTTPRequest(HTTPMessage): if self.host: host = self.host else: - for s in self.flow.server_conn.state: + for s in flow.server_conn.state: if s[0] == "http" and s[1]["state"] == "connect": host = s[1]["host"] break if not host: - host = self.flow.server_conn.address.host + host = flow.server_conn.address.host host = host.encode("idna") return host - def get_scheme(self): + def get_scheme(self, flow): """ Returns the request port, either from the request itself or from the flow's server connection """ @@ -501,20 +498,20 @@ class HTTPRequest(HTTPMessage): return self.scheme if self.form_out == "authority": # On SSLed connections, the original CONNECT request is still unencrypted. return "http" - return "https" if self.flow.server_conn.ssl_established else "http" + return "https" if flow.server_conn.ssl_established else "http" - def get_port(self): + def get_port(self, flow): """ Returns the request port, either from the request itself or from the flow's server connection """ if self.port: return self.port - for s in self.flow.server_conn.state: + for s in flow.server_conn.state: if s[0] == "http" and s[1].get("state") == "connect": return s[1]["port"] - return self.flow.server_conn.address.port + return flow.server_conn.address.port - def get_url(self, hostheader=False): + def get_url(self, hostheader, flow): """ Returns a URL string, constructed from the Request's URL components. @@ -522,13 +519,13 @@ class HTTPRequest(HTTPMessage): Host header to construct the URL. """ if self.form_out == "authority": # upstream proxy mode - return "%s:%s" % (self.get_host(hostheader), self.get_port()) - return utils.unparse_url(self.get_scheme(), - self.get_host(hostheader), - self.get_port(), + return "%s:%s" % (self.get_host(hostheader, flow), self.get_port(flow)) + return utils.unparse_url(self.get_scheme(flow), + self.get_host(hostheader, flow), + self.get_port(flow), self.path).encode('ascii') - def set_url(self, url): + def set_url(self, url, flow): """ Parses a URL specification, and updates the Request's information accordingly. @@ -543,14 +540,14 @@ class HTTPRequest(HTTPMessage): self.path = path - if host != self.get_host() or port != self.get_port(): - if self.flow.live: - self.flow.live.change_server((host, port), ssl=is_ssl) + if host != self.get_host(False, flow) or port != self.get_port(flow): + if flow.live: + flow.live.change_server((host, port), ssl=is_ssl) else: # There's not live server connection, we're just changing the attributes here. - self.flow.server_conn = ServerConnection((host, port), + flow.server_conn = ServerConnection((host, port), proxy.AddressPriority.MANUALLY_CHANGED) - self.flow.server_conn.ssl_established = is_ssl + flow.server_conn.ssl_established = is_ssl # If this is an absolute request, replace the attributes on the request object as well. if self.host: @@ -802,8 +799,6 @@ class HTTPFlow(Flow): self.intercepting = False # FIXME: Should that rather be an attribute of Flow? - _backrefattr = Flow._backrefattr + ("request", "response") - _stateobject_attributes = Flow._stateobject_attributes.copy() _stateobject_attributes.update( request=HTTPRequest, @@ -855,13 +850,10 @@ class HTTPFlow(Flow): Kill this request. """ self.error = Error("Connection killed") - self.error.reply = controller.DummyReply() - if self.request and not self.request.reply.acked: - self.request.reply(KILL) - elif self.response and not self.response.reply.acked: - self.response.reply(KILL) - master.handle_error(self.error) self.intercepting = False + self.reply(KILL) + self.reply = controller.DummyReply() + master.handle_error(self) def intercept(self): """ @@ -874,12 +866,8 @@ class HTTPFlow(Flow): """ Continue with the flow - called after an intercept(). """ - if self.request: - if not self.request.reply.acked: - self.request.reply() - elif self.response and not self.response.reply.acked: - self.response.reply() - self.intercepting = False + self.intercepting = False + self.reply() def replace(self, pattern, repl, *args, **kwargs): """ @@ -961,7 +949,7 @@ class HTTPHandler(ProtocolHandler): # in an Error object that has an attached request that has not been # sent through to the Master. flow.request = req - request_reply = self.c.channel.ask("request", flow.request) + request_reply = self.c.channel.ask("request", flow) self.determine_server_address(flow, flow.request) # The inline script may have changed request.host flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow @@ -976,7 +964,7 @@ class HTTPHandler(ProtocolHandler): flow.response = self.get_response_from_server(flow.request, include_body=False) # call the appropriate script hook - this is an opportunity for an inline script to set flow.stream = True - self.c.channel.ask("responseheaders", flow.response) + self.c.channel.ask("responseheaders", flow) # now get the rest of the request body, if body still needs to be read but not streaming this response if flow.response.stream: @@ -991,7 +979,7 @@ class HTTPHandler(ProtocolHandler): flow.server_conn = self.c.server_conn self.c.log("response", "debug", [flow.response._assemble_first_line()]) - response_reply = self.c.channel.ask("response", flow.response) + response_reply = self.c.channel.ask("response", flow) if response_reply is None or response_reply == KILL: return False @@ -1079,7 +1067,7 @@ class HTTPHandler(ProtocolHandler): # TODO: no flows without request or with both request and response at the moment. if flow.request and not flow.response: flow.error = Error(message) - self.c.channel.ask("error", flow.error) + self.c.channel.ask("error", flow) try: code = getattr(error, "code", 502) @@ -1204,12 +1192,12 @@ class RequestReplayThread(threading.Thread): except proxy.ProxyError: pass if not server_address: - server_address = (r.get_host(), r.get_port()) + server_address = (r.get_host(False, self.flow), r.get_port(self.flow)) server = ServerConnection(server_address, None) server.connect() - if server_ssl or r.get_scheme() == "https": + if server_ssl or r.get_scheme(self.flow) == "https": if self.config.http_form_out == "absolute": # form_out == absolute -> forward mode -> send CONNECT send_connect_request(server, r.get_host(), r.get_port()) r.form_out = "relative" @@ -1218,9 +1206,9 @@ class RequestReplayThread(threading.Thread): server.send(r._assemble()) self.flow.response = HTTPResponse.from_stream(server.rfile, r.method, body_size_limit=self.config.body_size_limit) - self.channel.ask("response", self.flow.response) + self.channel.ask("response", self.flow) except (proxy.ProxyError, http.HttpError, tcp.NetLibError), v: self.flow.error = Error(repr(v)) - self.channel.ask("error", self.flow.error) + self.channel.ask("error", self.flow) finally: r.form_out = form_out_backup \ No newline at end of file diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py index a227d904..a84b4061 100644 --- a/libmproxy/protocol/primitives.py +++ b/libmproxy/protocol/primitives.py @@ -9,24 +9,6 @@ from ..proxy.connection import ClientConnection, ServerConnection KILL = 0 # const for killed requests -class BackreferenceMixin(object): - """ - If an attribute from the _backrefattr tuple is set, - this mixin sets a reference back on the attribute object. - Example: - e = Error() - f = Flow() - f.error = e - assert f is e.flow - """ - _backrefattr = tuple() - - def __setattr__(self, key, value): - super(BackreferenceMixin, self).__setattr__(key, value) - if key in self._backrefattr and value is not None: - setattr(value, self._backrefname, self) - - class Error(stateobject.SimpleStateObject): """ An Error. @@ -70,7 +52,7 @@ class Error(stateobject.SimpleStateObject): return c -class Flow(stateobject.SimpleStateObject, BackreferenceMixin): +class Flow(stateobject.SimpleStateObject): def __init__(self, conntype, client_conn, server_conn, live=None): self.conntype = conntype self.client_conn = client_conn @@ -84,9 +66,6 @@ class Flow(stateobject.SimpleStateObject, BackreferenceMixin): """@type: Error""" self._backup = None - _backrefattr = ("error",) - _backrefname = "flow" - _stateobject_attributes = dict( error=Error, client_conn=ClientConnection, diff --git a/libmproxy/script.py b/libmproxy/script.py index e582c4e8..706d84d5 100644 --- a/libmproxy/script.py +++ b/libmproxy/script.py @@ -125,13 +125,8 @@ def _handle_concurrent_reply(fn, o, *args, **kwargs): def concurrent(fn): - if fn.func_name in ["request", "response", "error"]: - def _concurrent(ctx, flow): - r = getattr(flow, fn.func_name) - _handle_concurrent_reply(fn, r, ctx, flow) - return _concurrent - elif fn.func_name in ["clientconnect", "serverconnect", "clientdisconnect"]: - def _concurrent(ctx, conn): - _handle_concurrent_reply(fn, conn, ctx, conn) + if fn.func_name in ("request", "response", "error", "clientconnect", "serverconnect", "clientdisconnect"): + def _concurrent(ctx, obj): + _handle_concurrent_reply(fn, obj, ctx, obj) return _concurrent raise NotImplementedError("Concurrent decorator not supported for this method.") diff --git a/test/test_console.py b/test/test_console.py index 0c5b4591..3b6c941d 100644 --- a/test/test_console.py +++ b/test/test_console.py @@ -51,20 +51,20 @@ class TestConsoleState: assert c.get_focus() == (None, None) def _add_request(self, state): - r = tutils.treq() - return state.add_request(r) + f = tutils.tflow() + return state.add_request(f) def _add_response(self, state): f = self._add_request(state) - r = tutils.tresp(f.request) - state.add_response(r) + f.response = tutils.tresp() + state.add_response(f) def test_add_response(self): c = console.ConsoleState() f = self._add_request(c) - r = tutils.tresp(f.request) + f.response = tutils.tresp() c.focus = None - c.add_response(r) + c.add_response(f) def test_focus_view(self): c = console.ConsoleState() diff --git a/test/test_console_common.py b/test/test_console_common.py index d798e4dc..1949dad5 100644 --- a/test/test_console_common.py +++ b/test/test_console_common.py @@ -9,7 +9,7 @@ import tutils def test_format_flow(): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) assert common.format_flow(f, True) assert common.format_flow(f, True, hostheader=True) assert common.format_flow(f, True, extended=True) diff --git a/test/test_dump.py b/test/test_dump.py index 6f70450f..fd93cc03 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -10,31 +10,27 @@ def test_strfuncs(): t.is_replay = True dump.str_response(t) - t = tutils.treq() - t.flow.client_conn = None - t.stickycookie = True - assert "stickycookie" in dump.str_request(t, False) - assert "stickycookie" in dump.str_request(t, True) - assert "replay" in dump.str_request(t, False) - assert "replay" in dump.str_request(t, True) + f = tutils.tflow() + f.client_conn = None + f.request.stickycookie = True + assert "stickycookie" in dump.str_request(f, False) + assert "stickycookie" in dump.str_request(f, True) + assert "replay" in dump.str_request(f, False) + assert "replay" in dump.str_request(f, True) class TestDumpMaster: def _cycle(self, m, content): - req = tutils.treq(content=content) + f = tutils.tflow(req=tutils.treq(content)) l = Log("connect") l.reply = mock.MagicMock() m.handle_log(l) - cc = req.flow.client_conn - cc.reply = mock.MagicMock() - m.handle_clientconnect(cc) - sc = proxy.connection.ServerConnection((req.get_host(), req.get_port()), None) - sc.reply = mock.MagicMock() - m.handle_serverconnect(sc) - m.handle_request(req) - resp = tutils.tresp(req, content=content) - f = m.handle_response(resp) - m.handle_clientdisconnect(cc) + m.handle_clientconnect(f.client_conn) + m.handle_serverconnect(f.server_conn) + m.handle_request(f) + f.response = tutils.tresp(content) + f = m.handle_response(f) + m.handle_clientdisconnect(f.client_conn) return f def _dummy_cycle(self, n, filt, content, **options): @@ -49,8 +45,7 @@ class TestDumpMaster: def _flowfile(self, path): f = open(path, "wb") fw = flow.FlowWriter(f) - t = tutils.tflow_full() - t.response = tutils.tresp(t.request) + t = tutils.tflow(resp=True) fw.add(t) f.close() @@ -58,9 +53,9 @@ class TestDumpMaster: cs = StringIO() o = dump.Options(flow_detail=1) m = dump.DumpMaster(None, o, None, outfile=cs) - f = tutils.tflow_err() - m.handle_request(f.request) - assert m.handle_error(f.error) + f = tutils.tflow(err=True) + m.handle_request(f) + assert m.handle_error(f) assert "error" in cs.getvalue() def test_replay(self): diff --git a/test/test_flow.py b/test/test_flow.py index 88e7b9d7..6e9464e7 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -14,7 +14,8 @@ def test_app_registry(): ar.add("foo", "domain", 80) r = tutils.treq() - r.set_url("http://domain:80/") + r.host = "domain" + r.port = 80 assert ar.get(r) r.port = 81 @@ -32,8 +33,7 @@ def test_app_registry(): class TestStickyCookieState: def _response(self, cookie, host): s = flow.StickyCookieState(filt.parse(".*")) - f = tutils.tflow_full() - f.server_conn.address = tcp.Address((host, 80)) + f = tutils.tflow(req=tutils.treq(host=host, port=80), resp=True) f.response.headers["Set-Cookie"] = [cookie] s.handle_response(f) return s, f @@ -66,12 +66,12 @@ class TestStickyCookieState: class TestStickyAuthState: def test_handle_response(self): s = flow.StickyAuthState(filt.parse(".*")) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.headers["authorization"] = ["foo"] s.handle_request(f) assert "address" in s.hosts - f = tutils.tflow_full() + f = tutils.tflow(resp=True) s.handle_request(f) assert f.request.headers["authorization"] == ["foo"] @@ -123,24 +123,24 @@ class TestServerPlaybackState: def test_headers(self): s = flow.ServerPlaybackState(["foo"], [], False, False) - r = tutils.tflow_full() + r = tutils.tflow(resp=True) r.request.headers["foo"] = ["bar"] - r2 = tutils.tflow_full() + r2 = tutils.tflow(resp=True) assert not s._hash(r) == s._hash(r2) r2.request.headers["foo"] = ["bar"] assert s._hash(r) == s._hash(r2) r2.request.headers["oink"] = ["bar"] assert s._hash(r) == s._hash(r2) - r = tutils.tflow_full() - r2 = tutils.tflow_full() + r = tutils.tflow(resp=True) + r2 = tutils.tflow(resp=True) assert s._hash(r) == s._hash(r2) def test_load(self): - r = tutils.tflow_full() + r = tutils.tflow(resp=True) r.request.headers["key"] = ["one"] - r2 = tutils.tflow_full() + r2 = tutils.tflow(resp=True) r2.request.headers["key"] = ["two"] s = flow.ServerPlaybackState(None, [r, r2], False, False) @@ -158,10 +158,10 @@ class TestServerPlaybackState: assert not s.next_flow(r) def test_load_with_nopop(self): - r = tutils.tflow_full() + r = tutils.tflow(resp=True) r.request.headers["key"] = ["one"] - r2 = tutils.tflow_full() + r2 = tutils.tflow(resp=True) r2.request.headers["key"] = ["two"] s = flow.ServerPlaybackState(None, [r, r2], False, True) @@ -173,7 +173,7 @@ class TestServerPlaybackState: class TestFlow: def test_copy(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) a0 = f._get_state() f2 = f.copy() a = f._get_state() @@ -188,7 +188,7 @@ class TestFlow: assert f.response == f2.response assert not f.response is f2.response - f = tutils.tflow_err() + f = tutils.tflow(err=True) f2 = f.copy() assert not f is f2 assert not f.request is f2.request @@ -198,12 +198,12 @@ class TestFlow: assert not f.error is f2.error def test_match(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) assert not f.match("~b test") assert f.match(None) assert not f.match("~b test") - f = tutils.tflow_err() + f = tutils.tflow(err=True) assert f.match("~e") tutils.raises(ValueError, f.match, "~") @@ -220,14 +220,14 @@ class TestFlow: assert f.request.content == "foo" def test_backup_idempotence(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.backup() f.revert() f.backup() f.revert() def test_getset_state(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) state = f._get_state() assert f._get_state() == protocol.http.HTTPFlow._from_state(state)._get_state() @@ -248,55 +248,42 @@ class TestFlow: s = flow.State() fm = flow.FlowMaster(None, s) f = tutils.tflow() - f.request = tutils.treq() f.intercept() - assert not f.request.reply.acked + assert not f.reply.acked f.kill(fm) - assert f.request.reply.acked - f.intercept() - f.response = tutils.tresp() - f.request.reply() - assert not f.response.reply.acked - f.kill(fm) - assert f.response.reply.acked + assert f.reply.acked def test_killall(self): s = flow.State() fm = flow.FlowMaster(None, s) - r = tutils.treq() - fm.handle_request(r) + f = tutils.tflow() + fm.handle_request(f) - r = tutils.treq() - fm.handle_request(r) + f = tutils.tflow() + fm.handle_request(f) for i in s.view: - assert not i.request.reply.acked + assert not i.reply.acked s.killall(fm) for i in s.view: - assert i.request.reply.acked + assert i.reply.acked def test_accept_intercept(self): f = tutils.tflow() - f.request = tutils.treq() - f.intercept() - assert not f.request.reply.acked - f.accept_intercept() - assert f.request.reply.acked - f.response = tutils.tresp() + f.intercept() - f.request.reply() - assert not f.response.reply.acked + assert not f.reply.acked f.accept_intercept() - assert f.response.reply.acked + assert f.reply.acked def test_replace_unicode(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.response.content = "\xc2foo" f.replace("foo", u"bar") def test_replace(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.headers["foo"] = ["foo"] f.request.content = "afoob" @@ -311,7 +298,7 @@ class TestFlow: assert f.response.content == "abarb" def test_replace_encoded(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.content = "afoob" f.request.encode("gzip") f.response.content = "afoob" @@ -332,9 +319,8 @@ class TestFlow: class TestState: def test_backup(self): c = flow.State() - req = tutils.treq() - f = c.add_request(req) - + f = tutils.tflow() + c.add_request(f) f.backup() c.revert(f) @@ -344,72 +330,66 @@ class TestState: connect -> request -> response """ - bc = tutils.tclient_conn() c = flow.State() - - req = tutils.treq(bc) - f = c.add_request(req) + f = tutils.tflow() + c.add_request(f) assert f assert c.flow_count() == 1 assert c.active_flow_count() == 1 - newreq = tutils.treq() - assert c.add_request(newreq) + newf = tutils.tflow() + assert c.add_request(newf) assert c.active_flow_count() == 2 - resp = tutils.tresp(req) - assert c.add_response(resp) + f.response = tutils.tresp() + assert c.add_response(f) assert c.flow_count() == 2 assert c.active_flow_count() == 1 - unseen_resp = tutils.tresp() - unseen_resp.flow = None - assert not c.add_response(unseen_resp) + _ = tutils.tresp() + assert not c.add_response(None) assert c.active_flow_count() == 1 - resp = tutils.tresp(newreq) - assert c.add_response(resp) + newf.response = tutils.tresp() + assert c.add_response(newf) assert c.active_flow_count() == 0 def test_err(self): c = flow.State() - req = tutils.treq() - f = c.add_request(req) + f = tutils.tflow() + c.add_request(f) f.error = Error("message") - assert c.add_error(f.error) - - e = Error("message") - assert not c.add_error(e) + assert c.add_error(f) c = flow.State() - req = tutils.treq() - f = c.add_request(req) - e = tutils.terr() + f = tutils.tflow() + c.add_request(f) c.set_limit("~e") assert not c.view - assert c.add_error(e) + f.error = tutils.terr() + assert c.add_error(f) assert c.view def test_set_limit(self): c = flow.State() - req = tutils.treq() + f = tutils.tflow() assert len(c.view) == 0 - c.add_request(req) + c.add_request(f) assert len(c.view) == 1 c.set_limit("~s") assert c.limit_txt == "~s" assert len(c.view) == 0 - resp = tutils.tresp(req) - c.add_response(resp) + f.response = tutils.tresp() + c.add_response(f) assert len(c.view) == 1 c.set_limit(None) assert len(c.view) == 1 - req = tutils.treq() - c.add_request(req) + f = tutils.tflow() + c.add_request(f) assert len(c.view) == 2 c.set_limit("~q") assert len(c.view) == 1 @@ -427,20 +407,19 @@ class TestState: assert c.intercept_txt == None def _add_request(self, state): - req = tutils.treq() - f = state.add_request(req) + f = tutils.tflow() + state.add_request(f) return f def _add_response(self, state): - req = tutils.treq() - state.add_request(req) - resp = tutils.tresp(req) - state.add_response(resp) + f = tutils.tflow() + state.add_request(f) + f.response = tutils.tresp() + state.add_response(f) def _add_error(self, state): - req = tutils.treq() - f = state.add_request(req) - f.error = Error("msg") + f = tutils.tflow(err=True) + state.add_request(f) def test_clear(self): c = flow.State() @@ -479,10 +458,10 @@ class TestSerialize: sio = StringIO() w = flow.FlowWriter(sio) for i in range(3): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) w.add(f) for i in range(3): - f = tutils.tflow_err() + f = tutils.tflow(err=True) w.add(f) sio.seek(0) @@ -516,11 +495,11 @@ class TestSerialize: fl = filt.parse("~c 200") w = flow.FilteredFlowWriter(sio, fl) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.response.code = 200 w.add(f) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.response.code = 201 w.add(f) @@ -565,7 +544,7 @@ class TestFlowMaster: def test_replay(self): s = flow.State() fm = flow.FlowMaster(None, s) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.content = CONTENT_MISSING assert "missing" in fm.replay_request(f) @@ -576,48 +555,44 @@ class TestFlowMaster: s = flow.State() fm = flow.FlowMaster(None, s) assert not fm.load_script(tutils.test_data.path("scripts/reqerr.py")) - req = tutils.treq() - fm.handle_clientconnect(req.flow.client_conn) - assert fm.handle_request(req) + f = tutils.tflow() + fm.handle_clientconnect(f.client_conn) + assert fm.handle_request(f) def test_script(self): s = flow.State() fm = flow.FlowMaster(None, s) assert not fm.load_script(tutils.test_data.path("scripts/all.py")) - req = tutils.treq() - fm.handle_clientconnect(req.flow.client_conn) + f = tutils.tflow(resp=True) + + fm.handle_clientconnect(f.client_conn) assert fm.scripts[0].ns["log"][-1] == "clientconnect" - sc = ServerConnection((req.get_host(), req.get_port()), None) - sc.reply = controller.DummyReply() - fm.handle_serverconnect(sc) + fm.handle_serverconnect(f.server_conn) assert fm.scripts[0].ns["log"][-1] == "serverconnect" - f = fm.handle_request(req) + fm.handle_request(f) assert fm.scripts[0].ns["log"][-1] == "request" - resp = tutils.tresp(req) - fm.handle_response(resp) + fm.handle_response(f) assert fm.scripts[0].ns["log"][-1] == "response" #load second script assert not fm.load_script(tutils.test_data.path("scripts/all.py")) assert len(fm.scripts) == 2 - fm.handle_clientdisconnect(sc) + fm.handle_clientdisconnect(f.server_conn) assert fm.scripts[0].ns["log"][-1] == "clientdisconnect" assert fm.scripts[1].ns["log"][-1] == "clientdisconnect" - #unload first script fm.unload_scripts() assert len(fm.scripts) == 0 - assert not fm.load_script(tutils.test_data.path("scripts/all.py")) - err = tutils.terr() - err.reply = controller.DummyReply() - fm.handle_error(err) + + f.error = tutils.terr() + fm.handle_error(f) assert fm.scripts[0].ns["log"][-1] == "error" def test_duplicate_flow(self): s = flow.State() fm = flow.FlowMaster(None, s) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f = fm.load_flow(f) assert s.flow_count() == 1 f2 = fm.duplicate_flow(f) @@ -630,25 +605,22 @@ class TestFlowMaster: fm = flow.FlowMaster(None, s) fm.anticache = True fm.anticomp = True - req = tutils.treq() - fm.handle_clientconnect(req.flow.client_conn) - - f = fm.handle_request(req) + f = tutils.tflow(req=None) + fm.handle_clientconnect(f.client_conn) + f.request = tutils.treq() + fm.handle_request(f) assert s.flow_count() == 1 - resp = tutils.tresp(req) - fm.handle_response(resp) + f.response = tutils.tresp() + fm.handle_response(f) + assert not fm.handle_response(None) assert s.flow_count() == 1 - rx = tutils.tresp() - rx.flow = None - assert not fm.handle_response(rx) - - fm.handle_clientdisconnect(req.flow.client_conn) + fm.handle_clientdisconnect(f.client_conn) f.error = Error("msg") f.error.reply = controller.DummyReply() - fm.handle_error(f.error) + fm.handle_error(f) fm.load_script(tutils.test_data.path("scripts/a.py")) fm.shutdown() @@ -656,8 +628,8 @@ class TestFlowMaster: def test_client_playback(self): s = flow.State() - f = tutils.tflow_full() - pb = [tutils.tflow_full(), f] + f = tutils.tflow(resp=True) + pb = [tutils.tflow(resp=True), f] fm = flow.FlowMaster(None, s) assert not fm.start_server_playback(pb, False, [], False, False) assert not fm.start_client_playback(pb, False) @@ -668,8 +640,7 @@ class TestFlowMaster: assert fm.state.flow_count() f.error = Error("error") - f.error.reply = controller.DummyReply() - fm.handle_error(f.error) + fm.handle_error(f) def test_server_playback(self): s = flow.State() @@ -723,15 +694,15 @@ class TestFlowMaster: assert not fm.stickycookie_state fm.set_stickycookie(".*") - tf = tutils.tflow_full() - tf.response.headers["set-cookie"] = ["foo=bar"] - fm.handle_request(tf.request) - fm.handle_response(tf.response) + f = tutils.tflow(resp=True) + f.response.headers["set-cookie"] = ["foo=bar"] + fm.handle_request(f) + fm.handle_response(f) assert fm.stickycookie_state.jar - assert not "cookie" in tf.request.headers - tf = tf.copy() - fm.handle_request(tf.request) - assert tf.request.headers["cookie"] == ["foo=bar"] + assert not "cookie" in f.request.headers + f = f.copy() + fm.handle_request(f) + assert f.request.headers["cookie"] == ["foo=bar"] def test_stickyauth(self): s = flow.State() @@ -743,14 +714,14 @@ class TestFlowMaster: assert not fm.stickyauth_state fm.set_stickyauth(".*") - tf = tutils.tflow_full() - tf.request.headers["authorization"] = ["foo"] - fm.handle_request(tf.request) + f = tutils.tflow(resp=True) + f.request.headers["authorization"] = ["foo"] + fm.handle_request(f) - f = tutils.tflow_full() + f = tutils.tflow(resp=True) assert fm.stickyauth_state.hosts assert not "authorization" in f.request.headers - fm.handle_request(f.request) + fm.handle_request(f) assert f.request.headers["authorization"] == ["foo"] def test_stream(self): @@ -762,29 +733,30 @@ class TestFlowMaster: s = flow.State() fm = flow.FlowMaster(None, s) - tf = tutils.tflow_full() + f = tutils.tflow(resp=True) fm.start_stream(file(p, "ab"), None) - fm.handle_request(tf.request) - fm.handle_response(tf.response) + fm.handle_request(f) + fm.handle_response(f) fm.stop_stream() assert r()[0].response - tf = tutils.tflow() + f = tutils.tflow() fm.start_stream(file(p, "ab"), None) - fm.handle_request(tf.request) + fm.handle_request(f) fm.shutdown() assert not r()[1].response class TestRequest: def test_simple(self): - r = tutils.treq() - u = r.get_url() - assert r.set_url(u) - assert not r.set_url("") - assert r.get_url() == u + f = tutils.tflow() + r = f.request + u = r.get_url(False, f) + assert r.set_url(u, f) + assert not r.set_url("", f) + assert r.get_url(False, f) == u assert r._assemble() assert r.size() == len(r._assemble()) @@ -799,42 +771,45 @@ class TestRequest: tutils.raises("Cannot assemble flow with CONTENT_MISSING", r._assemble) def test_get_url(self): - r = tutils.tflow().request + f = tutils.tflow() + r = f.request - assert r.get_url() == "http://address:22/path" + assert r.get_url(False, f) == "http://address:22/path" - r.flow.server_conn.ssl_established = True - assert r.get_url() == "https://address:22/path" + r.scheme = "https" + assert r.get_url(False, f) == "https://address:22/path" - r.flow.server_conn.address = tcp.Address(("host", 42)) - assert r.get_url() == "https://host:42/path" + r.host = "host" + r.port = 42 + assert r.get_url(False, f) == "https://host:42/path" r.host = "address" r.port = 22 - assert r.get_url() == "https://address:22/path" + assert r.get_url(False, f) == "https://address:22/path" - assert r.get_url(hostheader=True) == "https://address:22/path" + assert r.get_url(True, f) == "https://address:22/path" r.headers["Host"] = ["foo.com"] - assert r.get_url() == "https://address:22/path" - assert r.get_url(hostheader=True) == "https://foo.com:22/path" + assert r.get_url(False, f) == "https://address:22/path" + assert r.get_url(True, f) == "https://foo.com:22/path" def test_path_components(self): - r = tutils.treq() + f = tutils.tflow() + r = f.request r.path = "/" - assert r.get_path_components() == [] + assert r.get_path_components(f) == [] r.path = "/foo/bar" - assert r.get_path_components() == ["foo", "bar"] + assert r.get_path_components(f) == ["foo", "bar"] q = flow.ODict() q["test"] = ["123"] - r.set_query(q) - assert r.get_path_components() == ["foo", "bar"] - - r.set_path_components([]) - assert r.get_path_components() == [] - r.set_path_components(["foo"]) - assert r.get_path_components() == ["foo"] - r.set_path_components(["/oo"]) - assert r.get_path_components() == ["/oo"] + r.set_query(q, f) + assert r.get_path_components(f) == ["foo", "bar"] + + r.set_path_components([], f) + assert r.get_path_components(f) == [] + r.set_path_components(["foo"], f) + assert r.get_path_components(f) == ["foo"] + r.set_path_components(["/oo"], f) + assert r.get_path_components(f) == ["/oo"] assert "%2F" in r.path def test_getset_form_urlencoded(self): @@ -853,26 +828,26 @@ class TestRequest: def test_getset_query(self): h = flow.ODictCaseless() - r = tutils.treq() - r.path = "/foo?x=y&a=b" - q = r.get_query() + f = tutils.tflow() + f.request.path = "/foo?x=y&a=b" + q = f.request.get_query(f) assert q.lst == [("x", "y"), ("a", "b")] - r.path = "/" - q = r.get_query() + f.request.path = "/" + q = f.request.get_query(f) assert not q - r.path = "/?adsfa" - q = r.get_query() + f.request.path = "/?adsfa" + q = f.request.get_query(f) assert q.lst == [("adsfa", "")] - r.path = "/foo?x=y&a=b" - assert r.get_query() - r.set_query(flow.ODict([])) - assert not r.get_query() + f.request.path = "/foo?x=y&a=b" + assert f.request.get_query(f) + f.request.set_query(flow.ODict([]), f) + assert not f.request.get_query(f) qv = flow.ODict([("a", "b"), ("c", "d")]) - r.set_query(qv) - assert r.get_query() == qv + f.request.set_query(qv, f) + assert f.request.get_query(f) == qv def test_anticache(self): h = flow.ODictCaseless() @@ -979,8 +954,8 @@ class TestRequest: h["headername"] = ["headervalue"] r = tutils.treq() r.headers = h - result = len(r._assemble_headers()) - assert result == 62 + raw = r._assemble_headers() + assert len(raw) == 62 def test_get_content_type(self): h = flow.ODictCaseless() @@ -991,7 +966,7 @@ class TestRequest: class TestResponse: def test_simple(self): - f = tutils.tflow_full() + f = tutils.tflow(resp=True) resp = f.response assert resp._assemble() assert resp.size() == len(resp._assemble()) @@ -1227,7 +1202,7 @@ def test_replacehooks(): h.run(f) assert f.request.content == "foo" - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.content = "foo" f.response.content = "foo" h.run(f) @@ -1280,7 +1255,7 @@ def test_setheaders(): h.clear() h.add("~s", "one", "two") h.add("~s", "one", "three") - f = tutils.tflow_full() + f = tutils.tflow(resp=True) f.request.headers["one"] = ["xxx"] f.response.headers["one"] = ["xxx"] h.run(f) diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py index 3b922c06..c2ff7b44 100644 --- a/test/test_protocol_http.py +++ b/test/test_protocol_http.py @@ -26,10 +26,12 @@ def test_stripped_chunked_encoding_no_content(): class TestHTTPRequest: def test_asterisk_form(self): s = StringIO("OPTIONS * HTTP/1.1") - f = tutils.tflow_noreq() + f = tutils.tflow(req=None) f.request = HTTPRequest.from_stream(s) assert f.request.form_in == "relative" - x = f.request._assemble() + f.request.host = f.server_conn.address.host + f.request.port = f.server_conn.address.port + f.request.scheme = "http" assert f.request._assemble() == "OPTIONS * HTTP/1.1\r\nHost: address:22\r\n\r\n" def test_origin_form(self): @@ -41,6 +43,7 @@ class TestHTTPRequest: tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) s = StringIO("CONNECT address:22 HTTP/1.1") r = HTTPRequest.from_stream(s) + r.scheme, r.host, r.port = "http", "address", 22 assert r._assemble() == "CONNECT address:22 HTTP/1.1\r\nHost: address:22\r\n\r\n" def test_absolute_form(self): @@ -55,12 +58,12 @@ class TestHTTPRequest: tutils.raises("Invalid request form", r._assemble, "antiauthority") def test_set_url(self): - r = tutils.treq_absolute() - r.set_url("https://otheraddress:42/ORLY") - assert r.scheme == "https" - assert r.host == "otheraddress" - assert r.port == 42 - assert r.path == "/ORLY" + f = tutils.tflow(req=tutils.treq_absolute()) + f.request.set_url("https://otheraddress:42/ORLY", f) + assert f.request.scheme == "https" + assert f.request.host == "otheraddress" + assert f.request.port == 42 + assert f.request.path == "/ORLY" class TestHTTPResponse: @@ -130,10 +133,10 @@ class TestProxyChainingSSL(tservers.HTTPChainProxyTest): """ https://github.com/mitmproxy/mitmproxy/issues/313 """ - def handle_request(r): - r.httpversion = (1,0) - del r.headers["Content-Length"] - r.reply() + def handle_request(f): + f.request.httpversion = (1, 0) + del f.request.headers["Content-Length"] + f.reply() _handle_request = self.chain[0].tmaster.handle_request self.chain[0].tmaster.handle_request = handle_request try: @@ -159,13 +162,13 @@ class TestProxyChainingSSLReconnect(tservers.HTTPChainProxyTest): def kill_requests(master, attr, exclude): k = [0] # variable scope workaround: put into array _func = getattr(master, attr) - def handler(r): + def handler(f): k[0] += 1 if not (k[0] in exclude): - r.flow.client_conn.finish() - r.flow.error = Error("terminated") - r.reply(KILL) - return _func(r) + f.client_conn.finish() + f.error = Error("terminated") + f.reply(KILL) + return _func(f) setattr(master, attr, handler) kill_requests(self.proxy.tmaster, "handle_request", diff --git a/test/test_proxy.py b/test/test_proxy.py index 2ff01acc..91e4954f 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -25,11 +25,11 @@ class TestServerConnection: def test_simple(self): sc = ServerConnection((self.d.IFACE, self.d.port), None) sc.connect() - r = tutils.treq() - r.flow.server_conn = sc - r.path = "/p/200:da" - sc.send(r._assemble()) - assert http.read_response(sc.rfile, r.method, 1000) + f = tutils.tflow() + f.server_conn = sc + f.request.path = "/p/200:da" + sc.send(f.request._assemble()) + assert http.read_response(sc.rfile, f.request.method, 1000) assert self.d.last_log() sc.finish() diff --git a/test/test_script.py b/test/test_script.py index 587c52d6..7c421fde 100644 --- a/test/test_script.py +++ b/test/test_script.py @@ -29,8 +29,8 @@ class TestScript: s = flow.State() fm = flow.FlowMaster(None, s) fm.load_script(tutils.test_data.path("scripts/duplicate_flow.py")) - r = tutils.treq() - fm.handle_request(r) + f = tutils.tflow() + fm.handle_request(f) assert fm.state.flow_count() == 2 assert not fm.state.view[0].request.is_replay assert fm.state.view[1].request.is_replay @@ -65,12 +65,12 @@ class TestScript: fm.load_script(tutils.test_data.path("scripts/concurrent_decorator.py")) with mock.patch("libmproxy.controller.DummyReply.__call__") as m: - r1, r2 = tutils.treq(), tutils.treq() + f1, f2 = tutils.tflow(), tutils.tflow() t_start = time.time() - fm.handle_request(r1) - r1.reply() - fm.handle_request(r2) - r2.reply() + fm.handle_request(f1) + f1.reply() + fm.handle_request(f2) + f2.reply() # Two instantiations assert m.call_count == 0 # No calls yet. diff --git a/test/test_server.py b/test/test_server.py index a570f10f..48527547 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -327,29 +327,32 @@ class TestProxySSL(tservers.HTTPProxTest): # tests that the ssl timestamp is present when ssl is used f = self.pathod("304:b@10k") assert f.status_code == 304 - first_request = self.master.state.view[0].request - assert first_request.flow.server_conn.timestamp_ssl_setup + first_flow = self.master.state.view[0] + assert first_flow.server_conn.timestamp_ssl_setup class MasterRedirectRequest(tservers.TestMaster): - def handle_request(self, request): + redirect_port = None # Set by TestRedirectRequest + + def handle_request(self, f): + request = f.request if request.path == "/p/201": - url = request.get_url() + url = request.get_url(False, f) new = "http://127.0.0.1:%s/p/201" % self.redirect_port - request.set_url(new) - request.set_url(new) - request.flow.live.change_server(("127.0.0.1", self.redirect_port), False) - request.set_url(url) - tutils.raises("SSL handshake error", request.flow.live.change_server, ("127.0.0.1", self.redirect_port), True) - request.set_url(new) - request.set_url(url) - request.set_url(new) - tservers.TestMaster.handle_request(self, request) + request.set_url(new, f) + request.set_url(new, f) + f.live.change_server(("127.0.0.1", self.redirect_port), False) + request.set_url(url, f) + tutils.raises("SSL handshake error", f.live.change_server, ("127.0.0.1", self.redirect_port), True) + request.set_url(new, f) + request.set_url(url, f) + request.set_url(new, f) + tservers.TestMaster.handle_request(self, f) - def handle_response(self, response): - response.content = str(response.flow.client_conn.address.port) - tservers.TestMaster.handle_response(self, response) + def handle_response(self, f): + f.response.content = str(f.client_conn.address.port) + tservers.TestMaster.handle_response(self, f) class TestRedirectRequest(tservers.HTTPProxTest): @@ -388,9 +391,9 @@ class MasterStreamRequest(tservers.TestMaster): """ Enables the stream flag on the flow for all requests """ - def handle_responseheaders(self, r): - r.stream = True - r.reply() + def handle_responseheaders(self, f): + f.response.stream = True + f.reply() class TestStreamRequest(tservers.HTTPProxTest): masterclass = MasterStreamRequest @@ -441,9 +444,9 @@ class TestStreamRequest(tservers.HTTPProxTest): class MasterFakeResponse(tservers.TestMaster): - def handle_request(self, m): + def handle_request(self, f): resp = tutils.tresp() - m.reply(resp) + f.reply(resp) class TestFakeResponse(tservers.HTTPProxTest): @@ -454,8 +457,8 @@ class TestFakeResponse(tservers.HTTPProxTest): class MasterKillRequest(tservers.TestMaster): - def handle_request(self, m): - m.reply(KILL) + def handle_request(self, f): + f.reply(KILL) class TestKillRequest(tservers.HTTPProxTest): @@ -467,8 +470,8 @@ class TestKillRequest(tservers.HTTPProxTest): class MasterKillResponse(tservers.TestMaster): - def handle_response(self, m): - m.reply(KILL) + def handle_response(self, f): + f.reply(KILL) class TestKillResponse(tservers.HTTPProxTest): @@ -491,10 +494,10 @@ class TestTransparentResolveError(tservers.TransparentProxTest): class MasterIncomplete(tservers.TestMaster): - def handle_request(self, m): + def handle_request(self, f): resp = tutils.tresp() resp.content = CONTENT_MISSING - m.reply(resp) + f.reply(resp) class TestIncompleteResponse(tservers.HTTPProxTest): diff --git a/test/tservers.py b/test/tservers.py index a12a440e..9f2abbe1 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -36,13 +36,13 @@ class TestMaster(flow.FlowMaster): self.apps.add(errapp, "errapp", 80) self.clear_log() - def handle_request(self, m): - flow.FlowMaster.handle_request(self, m) - m.reply() + def handle_request(self, f): + flow.FlowMaster.handle_request(self, f) + f.reply() - def handle_response(self, m): - flow.FlowMaster.handle_response(self, m) - m.reply() + def handle_response(self, f): + flow.FlowMaster.handle_response(self, f) + f.reply() def clear_log(self): self.log = [] diff --git a/test/tutils.py b/test/tutils.py index dc049adb..84a9bba0 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -21,7 +21,38 @@ def SkipWindows(fn): return fn +def tflow(client_conn=True, server_conn=True, req=True, resp=None, err=None): + """ + @type client_conn: bool | None | libmproxy.proxy.connection.ClientConnection + @type server_conn: bool | None | libmproxy.proxy.connection.ServerConnection + @type req: bool | None | libmproxy.protocol.http.HTTPRequest + @type resp: bool | None | libmproxy.protocol.http.HTTPResponse + @type err: bool | None | libmproxy.protocol.primitives.Error + @return: bool | None | libmproxy.protocol.http.HTTPFlow + """ + if client_conn is True: + client_conn = tclient_conn() + if server_conn is True: + server_conn = tserver_conn() + if req is True: + req = treq() + if resp is True: + resp = tresp() + if err is True: + err = terr() + + f = http.HTTPFlow(client_conn, server_conn) + f.request = req + f.response = resp + f.error = err + f.reply = controller.DummyReply() + return f + + def tclient_conn(): + """ + @return: libmproxy.proxy.connection.ClientConnection + """ c = ClientConnection._from_state(dict( address=dict(address=("address", 22), use_ipv6=True), clientcert=None @@ -31,6 +62,9 @@ def tclient_conn(): def tserver_conn(): + """ + @return: libmproxy.proxy.connection.ServerConnection + """ c = ServerConnection._from_state(dict( address=dict(address=("address", 22), use_ipv6=True), state=[], @@ -41,75 +75,46 @@ def tserver_conn(): return c -def treq_absolute(conn=None, content="content"): - r = treq(conn, content) +def treq(content="content", scheme="http", host="address", port=22): + """ + @return: libmproxy.protocol.http.HTTPRequest + """ + headers = flow.ODictCaseless() + headers["header"] = ["qvalue"] + req = http.HTTPRequest("relative", "GET", scheme, host, port, "/path", (1, 1), headers, content, + None, None, None) + return req + +def treq_absolute(content="content"): + """ + @return: libmproxy.protocol.http.HTTPRequest + """ + r = treq(content) r.form_in = r.form_out = "absolute" r.host = "address" r.port = 22 r.scheme = "http" return r -def treq(conn=None, content="content"): - if not conn: - conn = tclient_conn() - server_conn = tserver_conn() - headers = flow.ODictCaseless() - headers["header"] = ["qvalue"] - f = http.HTTPFlow(conn, server_conn) - f.request = http.HTTPRequest("relative", "GET", None, None, None, "/path", (1, 1), headers, content, - None, None, None) - f.request.reply = controller.DummyReply() - return f.request - - -def tresp(req=None, content="message"): - if not req: - req = treq() - f = req.flow +def tresp(content="message"): + """ + @return: libmproxy.protocol.http.HTTPResponse + """ headers = flow.ODictCaseless() headers["header_response"] = ["svalue"] - cert = certutils.SSLCert.from_der(file(test_data.path("data/dercert"), "rb").read()) - f.server_conn = ServerConnection._from_state(dict( - address=dict(address=("address", 22), use_ipv6=True), - state=[], - source_address=None, - cert=cert.to_pem())) - f.response = http.HTTPResponse((1, 1), 200, "OK", headers, content, time(), time()) - f.response.reply = controller.DummyReply() - return f.response + resp = http.HTTPResponse((1, 1), 200, "OK", headers, content, time(), time()) + return resp -def terr(req=None): - if not req: - req = treq() - f = req.flow - f.error = Error("error") - f.error.reply = controller.DummyReply() - return f.error - -def tflow_noreq(): - f = tflow() - f.request = None - return f -def tflow(req=None): - if not req: - req = treq() - return req.flow - - -def tflow_full(): - f = tflow() - f.response = tresp(f.request) - return f - - -def tflow_err(): - f = tflow() - f.error = terr(f.request) - return f +def terr(content="error"): + """ + @return: libmproxy.protocol.primitives.Error + """ + err = Error(content) + return err def tflowview(request_contents=None): m = Mock() @@ -117,8 +122,7 @@ def tflowview(request_contents=None): if request_contents == None: flow = tflow() else: - req = treq(None, request_contents) - flow = tflow(req) + flow = tflow(req=treq(request_contents)) fv = FlowView(m, cs, flow) return fv -- cgit v1.2.3 From cd43c5ba9c2981aeffee354cbcb574b6f5e435ba Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 3 Sep 2014 20:12:30 +0200 Subject: simplify server changes for inline scripts --- libmproxy/protocol/http.py | 95 ++++++++++++++++++++++++++++++++-------- libmproxy/protocol/primitives.py | 51 +++++++++++---------- libmproxy/proxy/connection.py | 5 +-- libmproxy/proxy/primitives.py | 13 ------ libmproxy/proxy/server.py | 24 ++++------ 5 files changed, 115 insertions(+), 73 deletions(-) diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 3f9eecb3..7577e0d3 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -26,7 +26,7 @@ def get_line(fp): return line -def send_connect_request(conn, host, port): +def send_connect_request(conn, host, port, update_state=True): upstream_request = HTTPRequest("authority", "CONNECT", None, host, port, None, (1, 1), ODictCaseless(), "") conn.send(upstream_request._assemble()) @@ -36,6 +36,12 @@ def send_connect_request(conn, host, port): "Cannot establish SSL " + "connection with upstream proxy: \r\n" + str(resp._assemble())) + if update_state: + conn.state.append(("http", { + "state": "connect", + "host": host, + "port": port} + )) return resp @@ -545,8 +551,7 @@ class HTTPRequest(HTTPMessage): flow.live.change_server((host, port), ssl=is_ssl) else: # There's not live server connection, we're just changing the attributes here. - flow.server_conn = ServerConnection((host, port), - proxy.AddressPriority.MANUALLY_CHANGED) + flow.server_conn = ServerConnection((host, port)) flow.server_conn.ssl_established = is_ssl # If this is an absolute request, replace the attributes on the request object as well. @@ -815,7 +820,7 @@ class HTTPFlow(Flow): s = " %s:%s" % ( - self._c.server_conn.address.host, - self._c.server_conn.address.port, + self.c.log("Change server connection: %s:%s -> %s:%s [persistent: %s]" % ( + self.c.server_conn.address.host, + self.c.server_conn.address.port, address.host, - address.port + address.port, + persistent_change ), "debug") - if not hasattr(self, "_backup_server_conn"): - self._backup_server_conn = self._c.server_conn - self._c.server_conn = None + if self._backup_server_conn: + self._backup_server_conn = self.c.server_conn + self.c.server_conn = None else: # This is at least the second temporary change. We can kill the current connection. - self._c.del_server_connection() + self.c.del_server_connection() - self._c.set_server_address(address, AddressPriority.MANUALLY_CHANGED) - self._c.establish_server_connection(ask=False) + self.c.set_server_address(address) + self.c.establish_server_connection(ask=False) if ssl: - self._c.establish_ssl(server=True) - if hasattr(self, "_backup_server_conn") and persistent_change: - del self._backup_server_conn + self.c.establish_ssl(server=True) + if persistent_change: + self._backup_server_conn = None def restore_server(self): - if not hasattr(self, "_backup_server_conn"): + # TODO: Similar to _backup_server_conn, introduce _cache_server_conn, which keeps the changed connection open + # This may be beneficial if a user is rewriting all requests from http to https or similar. + if not self._backup_server_conn: return - self._c.log("Restore original server connection: %s:%s -> %s:%s" % ( - self._c.server_conn.address.host, - self._c.server_conn.address.port, + self.c.log("Restore original server connection: %s:%s -> %s:%s" % ( + self.c.server_conn.address.host, + self.c.server_conn.address.port, self._backup_server_conn.address.host, self._backup_server_conn.address.port ), "debug") - self._c.del_server_connection() - self._c.server_conn = self._backup_server_conn - del self._backup_server_conn \ No newline at end of file + self.c.del_server_connection() + self.c.server_conn = self._backup_server_conn + self._backup_server_conn = None \ No newline at end of file diff --git a/libmproxy/proxy/connection.py b/libmproxy/proxy/connection.py index d99ffa9b..5c421557 100644 --- a/libmproxy/proxy/connection.py +++ b/libmproxy/proxy/connection.py @@ -72,9 +72,8 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): - def __init__(self, address, priority): + def __init__(self, address): tcp.TCPClient.__init__(self, address) - self.priority = priority self.state = [] # a list containing (conntype, state) tuples self.peername = None @@ -131,7 +130,7 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): @classmethod def _from_state(cls, state): - f = cls(tuple(), None) + f = cls(tuple()) f._load_state(state) return f diff --git a/libmproxy/proxy/primitives.py b/libmproxy/proxy/primitives.py index e09f23e4..8c674381 100644 --- a/libmproxy/proxy/primitives.py +++ b/libmproxy/proxy/primitives.py @@ -45,19 +45,6 @@ class TransparentUpstreamServerResolver(UpstreamServerResolver): return [ssl, ssl] + list(dst) -class AddressPriority(object): - """ - Enum that signifies the priority of the given address when choosing the destination host. - Higher is better (None < i) - """ - MANUALLY_CHANGED = 3 - """user changed the target address in the ui""" - FROM_SETTINGS = 2 - """upstream server from arguments (reverse proxy, upstream proxy or from transparent resolver)""" - FROM_PROTOCOL = 1 - """derived from protocol (e.g. absolute-form http requests)""" - - class Log: def __init__(self, msg, level="info"): self.msg = msg diff --git a/libmproxy/proxy/server.py b/libmproxy/proxy/server.py index 092eae54..58e386ab 100644 --- a/libmproxy/proxy/server.py +++ b/libmproxy/proxy/server.py @@ -5,7 +5,7 @@ import socket from OpenSSL import SSL from netlib import tcp -from .primitives import ProxyServerError, Log, ProxyError, AddressPriority +from .primitives import ProxyServerError, Log, ProxyError from .connection import ClientConnection, ServerConnection from ..protocol.handle import protocol_handler from .. import version @@ -76,7 +76,7 @@ class ConnectionHandler: client_ssl, server_ssl = False, False if self.config.get_upstream_server: upstream_info = self.config.get_upstream_server(self.client_conn.connection) - self.set_server_address(upstream_info[2:], AddressPriority.FROM_SETTINGS) + self.set_server_address(upstream_info[2:]) client_ssl, server_ssl = upstream_info[:2] if self.check_ignore_address(self.server_conn.address): self.log("Ignore host: %s:%s" % self.server_conn.address(), "info") @@ -129,27 +129,22 @@ class ConnectionHandler: else: return False - def set_server_address(self, address, priority): + def set_server_address(self, address): """ Sets a new server address with the given priority. Does not re-establish either connection or SSL handshake. """ address = tcp.Address.wrap(address) - if self.server_conn: - if self.server_conn.priority > priority: - self.log("Attempt to change server address, " - "but priority is too low (is: %s, got: %s)" % ( - self.server_conn.priority, priority), "debug") - return - if self.server_conn.address == address: - self.server_conn.priority = priority # Possibly increase priority - return + # Don't reconnect to the same destination. + if self.server_conn and self.server_conn.address == address: + return + if self.server_conn: self.del_server_connection() self.log("Set new server address: %s:%s" % (address.host, address.port), "debug") - self.server_conn = ServerConnection(address, priority) + self.server_conn = ServerConnection(address) def establish_server_connection(self, ask=True): """ @@ -212,12 +207,11 @@ class ConnectionHandler: def server_reconnect(self): address = self.server_conn.address had_ssl = self.server_conn.ssl_established - priority = self.server_conn.priority state = self.server_conn.state sni = self.sni self.log("(server reconnect follows)", "debug") self.del_server_connection() - self.set_server_address(address, priority) + self.set_server_address(address) self.establish_server_connection() for s in state: -- cgit v1.2.3 From 2f44b26b4cd014e03dd62a125d79af9b81663a93 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 3 Sep 2014 23:44:54 +0200 Subject: improve HTTPRequest syntax --- examples/modify_form.py | 4 +- examples/modify_querystring.py | 4 +- examples/redirect_requests.py | 6 +- libmproxy/console/common.py | 2 +- libmproxy/console/flowview.py | 20 ++--- libmproxy/dump.py | 2 +- libmproxy/filt.py | 4 +- libmproxy/flow.py | 12 +-- libmproxy/protocol/http.py | 153 +++++++++++++++------------------------ libmproxy/protocol/primitives.py | 2 +- test/test_flow.py | 102 +++++++++++++------------- test/test_protocol_http.py | 12 +-- test/test_proxy.py | 4 +- test/test_server.py | 11 +-- 14 files changed, 151 insertions(+), 187 deletions(-) diff --git a/examples/modify_form.py b/examples/modify_form.py index 2d839aed..cb12ee0f 100644 --- a/examples/modify_form.py +++ b/examples/modify_form.py @@ -1,8 +1,8 @@ def request(context, flow): if "application/x-www-form-urlencoded" in flow.request.headers["content-type"]: - frm = flow.request.get_form_urlencoded() + frm = flow.request.form_urlencoded frm["mitmproxy"] = ["rocks"] - flow.request.set_form_urlencoded(frm) + flow.request.form_urlencoded = frm diff --git a/examples/modify_querystring.py b/examples/modify_querystring.py index b1abcc1e..7e3a068a 100644 --- a/examples/modify_querystring.py +++ b/examples/modify_querystring.py @@ -1,7 +1,7 @@ def request(context, flow): - q = flow.request.get_query() + q = flow.request.query if q: q["mitmproxy"] = ["rocks"] - flow.request.set_query(q) + flow.request.query = q diff --git a/examples/redirect_requests.py b/examples/redirect_requests.py index a9a7e795..b57df2b2 100644 --- a/examples/redirect_requests.py +++ b/examples/redirect_requests.py @@ -7,12 +7,12 @@ This example shows two ways to redirect flows to other destinations. def request(context, flow): - if flow.request.get_host(hostheader=True).endswith("example.com"): + if flow.request.pretty_host(hostheader=True).endswith("example.com"): resp = HTTPResponse( [1, 1], 200, "OK", ODictCaseless([["Content-Type", "text/html"]]), "helloworld") flow.request.reply(resp) - if flow.request.get_host(hostheader=True).endswith("example.org"): + if flow.request.pretty_host(hostheader=True).endswith("example.org"): flow.request.host = "mitmproxy.org" - flow.request.headers["Host"] = ["mitmproxy.org"] + flow.request.update_host_header() diff --git a/libmproxy/console/common.py b/libmproxy/console/common.py index a8440f79..e2caac3b 100644 --- a/libmproxy/console/common.py +++ b/libmproxy/console/common.py @@ -177,7 +177,7 @@ def format_flow(f, focus, extended=False, hostheader=False, padding=2): req_is_replay = f.request.is_replay, req_method = f.request.method, req_acked = f.request.reply.acked, - req_url = f.request.get_url(hostheader=hostheader), + req_url = f.request.pretty_url(hostheader=hostheader), err_msg = f.error.msg if f.error else None, resp_code = f.response.code if f.response else None, diff --git a/libmproxy/console/flowview.py b/libmproxy/console/flowview.py index 4aaf8944..356d8d99 100644 --- a/libmproxy/console/flowview.py +++ b/libmproxy/console/flowview.py @@ -528,7 +528,9 @@ class FlowView(common.WWrap): def set_url(self, url): request = self.flow.request - if not request.set_url(str(url)): + try: + request.url = str(url) + except ValueError: return "Invalid URL." self.master.refresh_flow(self.flow) @@ -552,17 +554,17 @@ class FlowView(common.WWrap): conn.headers = flow.ODictCaseless(lst) def set_query(self, lst, conn): - conn.set_query(flow.ODict(lst)) + conn.query = flow.ODict(lst) def set_path_components(self, lst, conn): - conn.set_path_components([i[0] for i in lst]) + conn.path_components = [i[0] for i in lst] def set_form(self, lst, conn): - conn.set_form_urlencoded(flow.ODict(lst)) + conn.form_urlencoded = flow.ODict(lst) def edit_form(self, conn): self.master.view_grideditor( - grideditor.URLEncodedFormEditor(self.master, conn.get_form_urlencoded().lst, self.set_form, conn) + grideditor.URLEncodedFormEditor(self.master, conn.form_urlencoded.lst, self.set_form, conn) ) def edit_form_confirm(self, key, conn): @@ -587,7 +589,7 @@ class FlowView(common.WWrap): c = self.master.spawn_editor(conn.content or "") conn.content = c.rstrip("\n") # what? elif part == "f": - if not conn.get_form_urlencoded() and conn.content: + if not conn.form_urlencoded and conn.content: self.master.prompt_onekey( "Existing body is not a URL-encoded form. Clear and edit?", [ @@ -602,13 +604,13 @@ class FlowView(common.WWrap): elif part == "h": self.master.view_grideditor(grideditor.HeaderEditor(self.master, conn.headers.lst, self.set_headers, conn)) elif part == "p": - p = conn.get_path_components() + p = conn.path_components p = [[i] for i in p] self.master.view_grideditor(grideditor.PathEditor(self.master, p, self.set_path_components, conn)) elif part == "q": - self.master.view_grideditor(grideditor.QueryEditor(self.master, conn.get_query().lst, self.set_query, conn)) + self.master.view_grideditor(grideditor.QueryEditor(self.master, conn.query.lst, self.set_query, conn)) elif part == "u" and self.state.view_flow_mode == common.VIEW_FLOW_REQUEST: - self.master.prompt_edit("URL", conn.get_url(), self.set_url) + self.master.prompt_edit("URL", conn.url, self.set_url) elif part == "m" and self.state.view_flow_mode == common.VIEW_FLOW_REQUEST: self.master.prompt_onekey("Method", self.method_options, self.edit_method) elif part == "c" and self.state.view_flow_mode == common.VIEW_FLOW_RESPONSE: diff --git a/libmproxy/dump.py b/libmproxy/dump.py index 8ecd56e7..72ab58a3 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -55,7 +55,7 @@ def str_request(f, showhost): c = f.client_conn.address.host else: c = "[replay]" - r = "%s %s %s"%(c, f.request.method, f.request.get_url(showhost, f)) + r = "%s %s %s"%(c, f.request.method, f.request.pretty_url(showhost)) if f.request.stickycookie: r = "[stickycookie] " + r return r diff --git a/libmproxy/filt.py b/libmproxy/filt.py index 925dbfbb..7d2bd737 100644 --- a/libmproxy/filt.py +++ b/libmproxy/filt.py @@ -208,7 +208,7 @@ class FDomain(_Rex): code = "d" help = "Domain" def __call__(self, f): - return bool(re.search(self.expr, f.request.get_host(False, f), re.IGNORECASE)) + return bool(re.search(self.expr, f.request.host, re.IGNORECASE)) class FUrl(_Rex): @@ -222,7 +222,7 @@ class FUrl(_Rex): return klass(*toks) def __call__(self, f): - return re.search(self.expr, f.request.get_url(False, f)) + return re.search(self.expr, f.request.url) class _Int(_Action): diff --git a/libmproxy/flow.py b/libmproxy/flow.py index eb183d9f..9115ec9d 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -260,8 +260,8 @@ class StickyCookieState: Returns a (domain, port, path) tuple. """ return ( - m["domain"] or f.request.get_host(False, f), - f.request.get_port(f), + m["domain"] or f.request.host, + f.request.port, m["path"] or "/" ) @@ -279,7 +279,7 @@ class StickyCookieState: c = Cookie.SimpleCookie(str(i)) m = c.values()[0] k = self.ckey(m, f) - if self.domain_match(f.request.get_host(False, f), k[0]): + if self.domain_match(f.request.host, k[0]): self.jar[self.ckey(m, f)] = m def handle_request(self, f): @@ -287,8 +287,8 @@ class StickyCookieState: if f.match(self.flt): for i in self.jar.keys(): match = [ - self.domain_match(f.request.get_host(False, f), i[0]), - f.request.get_port(f) == i[1], + self.domain_match(f.request.host, i[0]), + f.request.port == i[1], f.request.path.startswith(i[2]) ] if all(match): @@ -307,7 +307,7 @@ class StickyAuthState: self.hosts = {} def handle_request(self, f): - host = f.request.get_host(False, f) + host = f.request.host if "authorization" in f.request.headers: self.hosts[host] = f.request.headers["authorization"] elif f.match(self.flt): diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 7577e0d3..90d8ff16 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -411,7 +411,14 @@ class HTTPRequest(HTTPMessage): e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0] )] - def get_form_urlencoded(self): + def update_host_header(self): + """ + Update the host header to reflect the current target. + """ + self.headers["Host"] = [self.host] + + @property + def form_urlencoded(self): """ Retrieves the URL-encoded form data, returning an ODict object. Returns an empty ODict if there is no data or the content-type @@ -421,7 +428,8 @@ class HTTPRequest(HTTPMessage): return ODict(utils.urldecode(self.content)) return ODict([]) - def set_form_urlencoded(self, odict): + @form_urlencoded.setter + def form_urlencoded(self, odict): """ Sets the body to the URL-encoded form data, and adds the appropriate content-type header. Note that this will destory the @@ -432,16 +440,18 @@ class HTTPRequest(HTTPMessage): self.headers["Content-Type"] = [HDR_FORM_URLENCODED] self.content = utils.urlencode(odict.lst) - def get_path_components(self, f): + @property + def path_components(self): """ Returns the path components of the URL as a list of strings. Components are unquoted. """ - _, _, path, _, _, _ = urlparse.urlparse(self.get_url(False, f)) + _, _, path, _, _, _ = urlparse.urlparse(self.url) return [urllib.unquote(i) for i in path.split("/") if i] - def set_path_components(self, lst, f): + @path_components.setter + def path_components(self, lst): """ Takes a list of strings, and sets the path component of the URL. @@ -449,32 +459,34 @@ class HTTPRequest(HTTPMessage): """ lst = [urllib.quote(i, safe="") for i in lst] path = "/" + "/".join(lst) - scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.get_url(False, f)) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment]), f) + scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.url) + self.url = urlparse.urlunparse([scheme, netloc, path, params, query, fragment]) - def get_query(self, f): + @property + def query(self): """ Gets the request query string. Returns an ODict object. """ - _, _, _, _, query, _ = urlparse.urlparse(self.get_url(False, f)) + _, _, _, _, query, _ = urlparse.urlparse(self.url) if query: return ODict(utils.urldecode(query)) return ODict([]) - def set_query(self, odict, f): + @query.setter + def query(self, odict): """ Takes an ODict object, and sets the request query string. """ - scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.get_url(False, f)) + scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.url) query = utils.urlencode(odict.lst) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment]), f) + self.url = urlparse.urlunparse([scheme, netloc, path, params, query, fragment]) - def get_host(self, hostheader, flow): + def pretty_host(self, hostheader): """ Heuristic to get the host of the request. - Note that get_host() does not always return the TCP destination of the request, - e.g. on a transparently intercepted request to an unrelated HTTP proxy. + Note that pretty_host() does not always return the TCP destination of the request, + e.g. if an upstream proxy is in place If hostheader is set to True, the Host: header will be used as additional (and preferred) data source. This is handy in transparent mode, where only the ip of the destination is known, but not the @@ -484,54 +496,27 @@ class HTTPRequest(HTTPMessage): if hostheader: host = self.headers.get_first("host") if not host: - if self.host: - host = self.host - else: - for s in flow.server_conn.state: - if s[0] == "http" and s[1]["state"] == "connect": - host = s[1]["host"] - break - if not host: - host = flow.server_conn.address.host + host = self.host host = host.encode("idna") return host - def get_scheme(self, flow): - """ - Returns the request port, either from the request itself or from the flow's server connection - """ - if self.scheme: - return self.scheme - if self.form_out == "authority": # On SSLed connections, the original CONNECT request is still unencrypted. - return "http" - return "https" if flow.server_conn.ssl_established else "http" - - def get_port(self, flow): - """ - Returns the request port, either from the request itself or from the flow's server connection - """ - if self.port: - return self.port - for s in flow.server_conn.state: - if s[0] == "http" and s[1].get("state") == "connect": - return s[1]["port"] - return flow.server_conn.address.port + def pretty_url(self, hostheader): + if self.form_out == "authority": # upstream proxy mode + return "%s:%s" % (self.pretty_host(hostheader), self.port) + return utils.unparse_url(self.scheme, + self.pretty_host(hostheader), + self.port, + self.path).encode('ascii') - def get_url(self, hostheader, flow): + @property + def url(self): """ Returns a URL string, constructed from the Request's URL components. - - If hostheader is True, we use the value specified in the request - Host header to construct the URL. """ - if self.form_out == "authority": # upstream proxy mode - return "%s:%s" % (self.get_host(hostheader, flow), self.get_port(flow)) - return utils.unparse_url(self.get_scheme(flow), - self.get_host(hostheader, flow), - self.get_port(flow), - self.path).encode('ascii') + return self.pretty_url(False) - def set_url(self, url, flow): + @url.setter + def url(self, url): """ Parses a URL specification, and updates the Request's information accordingly. @@ -540,31 +525,11 @@ class HTTPRequest(HTTPMessage): """ parts = http.parse_url(url) if not parts: - return False - scheme, host, port, path = parts - is_ssl = (True if scheme == "https" else False) - - self.path = path - - if host != self.get_host(False, flow) or port != self.get_port(flow): - if flow.live: - flow.live.change_server((host, port), ssl=is_ssl) - else: - # There's not live server connection, we're just changing the attributes here. - flow.server_conn = ServerConnection((host, port)) - flow.server_conn.ssl_established = is_ssl - - # If this is an absolute request, replace the attributes on the request object as well. - if self.host: - self.host = host - if self.port: - self.port = port - if self.scheme: - self.scheme = scheme + raise ValueError("Invalid URL: %s" % url) + self.scheme, self.host, self.port, self.path = parts - return True - - def get_cookies(self): + @property + def cookies(self): cookie_headers = self.headers.get("cookie") if not cookie_headers: return None @@ -760,7 +725,8 @@ class HTTPResponse(HTTPMessage): if c: self.headers["set-cookie"] = c - def get_cookies(self): + @property + def cookies(self): cookie_headers = self.headers.get("set-cookie") if not cookie_headers: return None @@ -1127,12 +1093,12 @@ class HTTPHandler(ProtocolHandler): if not request.host: # Host/Port Complication: In upstream mode, use the server we CONNECTed to, # not the upstream proxy. - for s in flow.server_conn.state: - if s[0] == "http" and s[1]["state"] == "connect": - request.host, request.port = s[1]["host"], s[1]["port"] - if not request.host: - request.host = flow.server_conn.address.host - request.port = flow.server_conn.address.port + if flow.server_conn: + for s in flow.server_conn.state: + if s[0] == "http" and s[1]["state"] == "connect": + request.host, request.port = s[1]["host"], s[1]["port"] + if not request.host and flow.server_conn: + request.host, request.port = flow.server_conn.address.host, flow.server_conn.address.port # Now we can process the request. if request.form_in == "authority": @@ -1242,7 +1208,9 @@ class RequestReplayThread(threading.Thread): r.form_out = self.config.http_form_out server_address, server_ssl = False, False - if self.config.get_upstream_server: + # If the flow is live, r.host is already the correct upstream server unless modified by a script. + # If modified by a script, we probably want to keep the modified destination. + if self.config.get_upstream_server and not self.flow.live: try: # this will fail in transparent mode upstream_info = self.config.get_upstream_server(self.flow.client_conn) @@ -1251,17 +1219,16 @@ class RequestReplayThread(threading.Thread): except proxy.ProxyError: pass if not server_address: - server_address = (r.get_host(False, self.flow), r.get_port(self.flow)) + server_address = (r.host, r.port) - server = ServerConnection(server_address, None) + server = ServerConnection(server_address) server.connect() - if server_ssl or r.get_scheme(self.flow) == "https": + if server_ssl or r.scheme == "https": if self.config.http_form_out == "absolute": # form_out == absolute -> forward mode -> send CONNECT - send_connect_request(server, r.get_host(), r.get_port()) + send_connect_request(server, r.host, r.port) r.form_out = "relative" - server.establish_ssl(self.config.clientcerts, - self.flow.server_conn.sni) + server.establish_ssl(self.config.clientcerts, sni=r.host) server.send(r._assemble()) self.flow.response = HTTPResponse.from_stream(server.rfile, r.method, body_size_limit=self.config.body_size_limit) diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py index ef5c87fb..416e6880 100644 --- a/libmproxy/protocol/primitives.py +++ b/libmproxy/protocol/primitives.py @@ -58,7 +58,7 @@ class Flow(stateobject.SimpleStateObject): """@type: ClientConnection""" self.server_conn = server_conn """@type: ServerConnection""" - self.live = live # Used by flow.request.set_url to change the server address + self.live = live """@type: LiveConnection""" self.error = None diff --git a/test/test_flow.py b/test/test_flow.py index 6e9464e7..4bc2391e 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -753,10 +753,10 @@ class TestRequest: def test_simple(self): f = tutils.tflow() r = f.request - u = r.get_url(False, f) - assert r.set_url(u, f) - assert not r.set_url("", f) - assert r.get_url(False, f) == u + u = r.url + r.url = u + tutils.raises(ValueError, setattr, r, "url", "") + assert r.url == u assert r._assemble() assert r.size() == len(r._assemble()) @@ -771,83 +771,81 @@ class TestRequest: tutils.raises("Cannot assemble flow with CONTENT_MISSING", r._assemble) def test_get_url(self): - f = tutils.tflow() - r = f.request + r = tutils.treq() - assert r.get_url(False, f) == "http://address:22/path" + assert r.url == "http://address:22/path" r.scheme = "https" - assert r.get_url(False, f) == "https://address:22/path" + assert r.url == "https://address:22/path" r.host = "host" r.port = 42 - assert r.get_url(False, f) == "https://host:42/path" + assert r.url == "https://host:42/path" r.host = "address" r.port = 22 - assert r.get_url(False, f) == "https://address:22/path" + assert r.url== "https://address:22/path" - assert r.get_url(True, f) == "https://address:22/path" + assert r.pretty_url(True) == "https://address:22/path" r.headers["Host"] = ["foo.com"] - assert r.get_url(False, f) == "https://address:22/path" - assert r.get_url(True, f) == "https://foo.com:22/path" + assert r.pretty_url(False) == "https://address:22/path" + assert r.pretty_url(True) == "https://foo.com:22/path" def test_path_components(self): - f = tutils.tflow() - r = f.request + r = tutils.treq() r.path = "/" - assert r.get_path_components(f) == [] + assert r.path_components == [] r.path = "/foo/bar" - assert r.get_path_components(f) == ["foo", "bar"] + assert r.path_components == ["foo", "bar"] q = flow.ODict() q["test"] = ["123"] - r.set_query(q, f) - assert r.get_path_components(f) == ["foo", "bar"] - - r.set_path_components([], f) - assert r.get_path_components(f) == [] - r.set_path_components(["foo"], f) - assert r.get_path_components(f) == ["foo"] - r.set_path_components(["/oo"], f) - assert r.get_path_components(f) == ["/oo"] + r.query = q + assert r.path_components == ["foo", "bar"] + + r.path_components = [] + assert r.path_components == [] + r.path_components = ["foo"] + assert r.path_components == ["foo"] + r.path_components = ["/oo"] + assert r.path_components == ["/oo"] assert "%2F" in r.path def test_getset_form_urlencoded(self): d = flow.ODict([("one", "two"), ("three", "four")]) r = tutils.treq(content=utils.urlencode(d.lst)) r.headers["content-type"] = [protocol.http.HDR_FORM_URLENCODED] - assert r.get_form_urlencoded() == d + assert r.form_urlencoded == d d = flow.ODict([("x", "y")]) - r.set_form_urlencoded(d) - assert r.get_form_urlencoded() == d + r.form_urlencoded = d + assert r.form_urlencoded == d r.headers["content-type"] = ["foo"] - assert not r.get_form_urlencoded() + assert not r.form_urlencoded def test_getset_query(self): h = flow.ODictCaseless() - f = tutils.tflow() - f.request.path = "/foo?x=y&a=b" - q = f.request.get_query(f) + r = tutils.treq() + r.path = "/foo?x=y&a=b" + q = r.query assert q.lst == [("x", "y"), ("a", "b")] - f.request.path = "/" - q = f.request.get_query(f) + r.path = "/" + q = r.query assert not q - f.request.path = "/?adsfa" - q = f.request.get_query(f) + r.path = "/?adsfa" + q = r.query assert q.lst == [("adsfa", "")] - f.request.path = "/foo?x=y&a=b" - assert f.request.get_query(f) - f.request.set_query(flow.ODict([]), f) - assert not f.request.get_query(f) + r.path = "/foo?x=y&a=b" + assert r.query + r.query = flow.ODict([]) + assert not r.query qv = flow.ODict([("a", "b"), ("c", "d")]) - f.request.set_query(qv, f) - assert f.request.get_query(f) == qv + r.query = qv + assert r.query == qv def test_anticache(self): h = flow.ODictCaseless() @@ -918,14 +916,14 @@ class TestRequest: h = flow.ODictCaseless() r = tutils.treq() r.headers = h - assert r.get_cookies() is None + assert r.cookies is None def test_get_cookies_single(self): h = flow.ODictCaseless() h["Cookie"] = ["cookiename=cookievalue"] r = tutils.treq() r.headers = h - result = r.get_cookies() + result = r.cookies assert len(result)==1 assert result['cookiename']==('cookievalue',{}) @@ -934,7 +932,7 @@ class TestRequest: h["Cookie"] = ["cookiename=cookievalue;othercookiename=othercookievalue"] r = tutils.treq() r.headers = h - result = r.get_cookies() + result = r.cookies assert len(result)==2 assert result['cookiename']==('cookievalue',{}) assert result['othercookiename']==('othercookievalue',{}) @@ -944,7 +942,7 @@ class TestRequest: h["Cookie"] = ["cookiename=coo=kievalue;othercookiename=othercookievalue"] r = tutils.treq() r.headers = h - result = r.get_cookies() + result = r.cookies assert len(result)==2 assert result['cookiename']==('coo=kievalue',{}) assert result['othercookiename']==('othercookievalue',{}) @@ -1054,14 +1052,14 @@ class TestResponse: h = flow.ODictCaseless() resp = tutils.tresp() resp.headers = h - assert not resp.get_cookies() + assert not resp.cookies def test_get_cookies_simple(self): h = flow.ODictCaseless() h["Set-Cookie"] = ["cookiename=cookievalue"] resp = tutils.tresp() resp.headers = h - result = resp.get_cookies() + result = resp.cookies assert len(result)==1 assert "cookiename" in result assert result["cookiename"] == ("cookievalue", {}) @@ -1071,7 +1069,7 @@ class TestResponse: h["Set-Cookie"] = ["cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly"] resp = tutils.tresp() resp.headers = h - result = resp.get_cookies() + result = resp.cookies assert len(result)==1 assert "cookiename" in result assert result["cookiename"][0] == "cookievalue" @@ -1086,7 +1084,7 @@ class TestResponse: h["Set-Cookie"] = ["cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/"] resp = tutils.tresp() resp.headers = h - result = resp.get_cookies() + result = resp.cookies assert len(result)==1 assert "cookiename" in result assert result["cookiename"][0] == "" @@ -1097,7 +1095,7 @@ class TestResponse: h["Set-Cookie"] = ["cookiename=cookievalue","othercookie=othervalue"] resp = tutils.tresp() resp.headers = h - result = resp.get_cookies() + result = resp.cookies assert len(result)==2 assert "cookiename" in result assert result["cookiename"] == ("cookievalue", {}) diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py index c2ff7b44..c76fa192 100644 --- a/test/test_protocol_http.py +++ b/test/test_protocol_http.py @@ -58,12 +58,12 @@ class TestHTTPRequest: tutils.raises("Invalid request form", r._assemble, "antiauthority") def test_set_url(self): - f = tutils.tflow(req=tutils.treq_absolute()) - f.request.set_url("https://otheraddress:42/ORLY", f) - assert f.request.scheme == "https" - assert f.request.host == "otheraddress" - assert f.request.port == 42 - assert f.request.path == "/ORLY" + r = tutils.treq_absolute() + r.url = "https://otheraddress:42/ORLY" + assert r.scheme == "https" + assert r.host == "otheraddress" + assert r.port == 42 + assert r.path == "/ORLY" class TestHTTPResponse: diff --git a/test/test_proxy.py b/test/test_proxy.py index 91e4954f..ad2bb2d7 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -23,7 +23,7 @@ class TestServerConnection: self.d.shutdown() def test_simple(self): - sc = ServerConnection((self.d.IFACE, self.d.port), None) + sc = ServerConnection((self.d.IFACE, self.d.port)) sc.connect() f = tutils.tflow() f.server_conn = sc @@ -35,7 +35,7 @@ class TestServerConnection: sc.finish() def test_terminate_error(self): - sc = ServerConnection((self.d.IFACE, self.d.port), None) + sc = ServerConnection((self.d.IFACE, self.d.port)) sc.connect() sc.connection = mock.Mock() sc.connection.recv = mock.Mock(return_value=False) diff --git a/test/test_server.py b/test/test_server.py index 48527547..4b8c796c 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -337,17 +337,14 @@ class MasterRedirectRequest(tservers.TestMaster): def handle_request(self, f): request = f.request if request.path == "/p/201": - url = request.get_url(False, f) + url = request.url new = "http://127.0.0.1:%s/p/201" % self.redirect_port - request.set_url(new, f) - request.set_url(new, f) + request.url = new f.live.change_server(("127.0.0.1", self.redirect_port), False) - request.set_url(url, f) + request.url = url tutils.raises("SSL handshake error", f.live.change_server, ("127.0.0.1", self.redirect_port), True) - request.set_url(new, f) - request.set_url(url, f) - request.set_url(new, f) + request.url = new tservers.TestMaster.handle_request(self, f) def handle_response(self, f): -- cgit v1.2.3 From 649e63ff3c868397f493e1dabdc1c63d572aedd8 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 4 Sep 2014 00:10:01 +0200 Subject: fix some leftovers --- examples/redirect_requests.py | 2 +- libmproxy/console/common.py | 7 +++---- libmproxy/console/flowview.py | 4 ++-- libmproxy/flow.py | 3 +-- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/redirect_requests.py b/examples/redirect_requests.py index a9a7e795..530da200 100644 --- a/examples/redirect_requests.py +++ b/examples/redirect_requests.py @@ -12,7 +12,7 @@ def request(context, flow): [1, 1], 200, "OK", ODictCaseless([["Content-Type", "text/html"]]), "helloworld") - flow.request.reply(resp) + flow.reply(resp) if flow.request.get_host(hostheader=True).endswith("example.org"): flow.request.host = "mitmproxy.org" flow.request.headers["Host"] = ["mitmproxy.org"] diff --git a/libmproxy/console/common.py b/libmproxy/console/common.py index a8440f79..5cb3dd2a 100644 --- a/libmproxy/console/common.py +++ b/libmproxy/console/common.py @@ -108,7 +108,7 @@ def raw_format_flow(f, focus, extended, padding): preamble = sum(i[1] for i in req) + len(req) -1 - if f["intercepting"] and not f["req_acked"]: + if f["intercepting"] and not f["acked"]: uc = "intercept" elif f["resp_code"] or f["err_msg"]: uc = "text" @@ -138,7 +138,7 @@ def raw_format_flow(f, focus, extended, padding): if f["resp_is_replay"]: resp.append(fcol(SYMBOL_REPLAY, "replay")) resp.append(fcol(f["resp_code"], ccol)) - if f["intercepting"] and f["resp_code"] and not f["resp_acked"]: + if f["intercepting"] and f["resp_code"] and not f["acked"]: rc = "intercept" else: rc = "text" @@ -172,11 +172,11 @@ flowcache = FlowCache() def format_flow(f, focus, extended=False, hostheader=False, padding=2): d = dict( intercepting = f.intercepting, + acked = f.reply.acked, req_timestamp = f.request.timestamp_start, req_is_replay = f.request.is_replay, req_method = f.request.method, - req_acked = f.request.reply.acked, req_url = f.request.get_url(hostheader=hostheader), err_msg = f.error.msg if f.error else None, @@ -197,7 +197,6 @@ def format_flow(f, focus, extended=False, hostheader=False, padding=2): d.update(dict( resp_code = f.response.code, resp_is_replay = f.response.is_replay, - resp_acked = f.response.reply.acked, resp_clen = contentdesc, resp_rate = "{0}/s".format(rate), )) diff --git a/libmproxy/console/flowview.py b/libmproxy/console/flowview.py index 4aaf8944..3c63ac29 100644 --- a/libmproxy/console/flowview.py +++ b/libmproxy/console/flowview.py @@ -233,7 +233,7 @@ class FlowView(common.WWrap): def wrap_body(self, active, body): parts = [] - if self.flow.intercepting and not self.flow.request.reply.acked: + if self.flow.intercepting and not self.flow.reply.acked and not self.flow.response: qt = "Request intercepted" else: qt = "Request" @@ -242,7 +242,7 @@ class FlowView(common.WWrap): else: parts.append(self._tab(qt, "heading_inactive")) - if self.flow.intercepting and self.flow.response and not self.flow.response.reply.acked: + if self.flow.intercepting and not self.flow.reply.acked and self.flow.response: st = "Response intercepted" else: st = "Response" diff --git a/libmproxy/flow.py b/libmproxy/flow.py index eb183d9f..df72878f 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -183,8 +183,7 @@ class ClientPlaybackState: """ if self.flows and not self.current: n = self.flows.pop(0) - n.request.reply = controller.DummyReply() - n.client_conn = None + n.reply = controller.DummyReply() self.current = master.handle_request(n) if not testing and not self.current.response: master.replay_request(self.current) # pragma: no cover -- cgit v1.2.3 From f4d4332472c7fa68014996a1d55b37911d1515f9 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 4 Sep 2014 14:46:25 +0200 Subject: coverage++ --- libmproxy/protocol/tcp.py | 4 ++-- test/test_protocol_http.py | 31 +++++++++++++++++++++++++++++++ test/test_protocol_tcp.py | 2 ++ 3 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 test/test_protocol_tcp.py diff --git a/libmproxy/protocol/tcp.py b/libmproxy/protocol/tcp.py index 57a48ab9..990c502a 100644 --- a/libmproxy/protocol/tcp.py +++ b/libmproxy/protocol/tcp.py @@ -59,11 +59,11 @@ class TCPHandler(ProtocolHandler): # if one of the peers is over SSL, we need to send bytes/strings if not src.ssl_established: # only ssl to dst, i.e. we revc'd into buf but need bytes/string now. contents = buf[:size].tobytes() - # self.c.log("%s %s\r\n%s" % (direction, dst_str, cleanBin(contents)), "debug") + self.c.log("%s %s\r\n%s" % (direction, dst_str, cleanBin(contents)), "debug") dst.connection.send(contents) else: # socket.socket.send supports raw bytearrays/memoryviews - # self.c.log("%s %s\r\n%s" % (direction, dst_str, cleanBin(buf.tobytes())), "debug") + self.c.log("%s %s\r\n%s" % (direction, dst_str, cleanBin(buf.tobytes())), "debug") dst.connection.send(buf[:size]) except socket.error as e: self.c.log("TCP connection closed unexpectedly.", "debug") diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py index c76fa192..bcbdd5d0 100644 --- a/test/test_protocol_http.py +++ b/test/test_protocol_http.py @@ -37,6 +37,19 @@ class TestHTTPRequest: def test_origin_form(self): s = StringIO("GET /foo\xff HTTP/1.1") tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) + s = StringIO("GET /foo HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: h2c") + r = HTTPRequest.from_stream(s) + assert r.headers["Upgrade"] == ["h2c"] + + raw = r._assemble_headers() + assert "Upgrade" not in raw + assert "Host" not in raw + + r.url = "http://example.com/foo" + + raw = r._assemble_headers() + assert "Host" in raw + def test_authority_form(self): s = StringIO("CONNECT oops-no-port.com HTTP/1.1") @@ -45,6 +58,7 @@ class TestHTTPRequest: r = HTTPRequest.from_stream(s) r.scheme, r.host, r.port = "http", "address", 22 assert r._assemble() == "CONNECT address:22 HTTP/1.1\r\nHost: address:22\r\n\r\n" + assert r.pretty_url(False) == "address:22" def test_absolute_form(self): s = StringIO("GET oops-no-protocol.com HTTP/1.1") @@ -65,6 +79,10 @@ class TestHTTPRequest: assert r.port == 42 assert r.path == "/ORLY" + def test_repr(self): + r = tutils.treq() + assert repr(r) + class TestHTTPResponse: def test_read_from_stringio(self): @@ -86,6 +104,19 @@ class TestHTTPResponse: assert r.content == "" tutils.raises("Invalid server response: 'content", HTTPResponse.from_stream, s, "GET") + def test_repr(self): + r = tutils.tresp() + assert "unknown content type" in repr(r) + r.headers["content-type"] = ["foo"] + assert "foo" in repr(r) + assert repr(tutils.tresp(content=CONTENT_MISSING)) + + +class TestHTTPFlow(object): + def test_repr(self): + f = tutils.tflow(resp=True, err=True) + assert repr(f) + class TestInvalidRequests(tservers.HTTPProxTest): ssl = True diff --git a/test/test_protocol_tcp.py b/test/test_protocol_tcp.py new file mode 100644 index 00000000..7236ee67 --- /dev/null +++ b/test/test_protocol_tcp.py @@ -0,0 +1,2 @@ +class TestTcp: + pass \ No newline at end of file -- cgit v1.2.3 From 795e19f6b7803f18a3bf5e8111493ed54a3d2e00 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 4 Sep 2014 16:37:50 +0200 Subject: coverage++ --- libmproxy/protocol/primitives.py | 2 +- libmproxy/proxy/connection.py | 13 ++++--------- test/test_protocol_tcp.py | 23 +++++++++++++++++++++-- test/test_proxy.py | 2 +- test/test_server.py | 6 ++++-- test/tservers.py | 23 +++++++++++++++++++++++ 6 files changed, 54 insertions(+), 15 deletions(-) diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py index 416e6880..ee1199fc 100644 --- a/libmproxy/protocol/primitives.py +++ b/libmproxy/protocol/primitives.py @@ -168,7 +168,7 @@ class LiveConnection(object): persistent_change ), "debug") - if self._backup_server_conn: + if not self._backup_server_conn: self._backup_server_conn = self.c.server_conn self.c.server_conn = None else: # This is at least the second temporary change. We can kill the current connection. diff --git a/libmproxy/proxy/connection.py b/libmproxy/proxy/connection.py index 5c421557..de8e20d8 100644 --- a/libmproxy/proxy/connection.py +++ b/libmproxy/proxy/connection.py @@ -76,8 +76,6 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): tcp.TCPClient.__init__(self, address) self.state = [] # a list containing (conntype, state) tuples - self.peername = None - self.sockname = None self.timestamp_start = None self.timestamp_end = None self.timestamp_tcp_setup = None @@ -98,8 +96,6 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): _stateobject_attributes = dict( state=list, - peername=tuple, - sockname=tuple, timestamp_start=float, timestamp_end=float, timestamp_tcp_setup=float, @@ -114,9 +110,10 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): def _get_state(self): d = super(ServerConnection, self)._get_state() d.update( - address={"address": self.address(), "use_ipv6": self.address.use_ipv6}, - source_address= {"address": self.source_address(), - "use_ipv6": self.source_address.use_ipv6} if self.source_address else None, + address={"address": self.address(), + "use_ipv6": self.address.use_ipv6}, + source_address= ({"address": self.source_address(), + "use_ipv6": self.source_address.use_ipv6} if self.source_address else None), cert=self.cert.to_pem() if self.cert else None ) return d @@ -140,8 +137,6 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): def connect(self): self.timestamp_start = utils.timestamp() tcp.TCPClient.connect(self) - self.peername = self.connection.getpeername() - self.sockname = self.connection.getsockname() self.timestamp_tcp_setup = utils.timestamp() def send(self, message): diff --git a/test/test_protocol_tcp.py b/test/test_protocol_tcp.py index 7236ee67..8b6bb68d 100644 --- a/test/test_protocol_tcp.py +++ b/test/test_protocol_tcp.py @@ -1,2 +1,21 @@ -class TestTcp: - pass \ No newline at end of file +import tservers +from netlib.certutils import SSLCert + +class TestTcp(tservers.IgnoreProxTest): + ignore = [] + + def test_simple(self): + # i = ignore (tcp passthrough), n = normal + pi, pn = self.pathocs() + i = pi.request("get:'/p/304'") + i2 = pi.request("get:'/p/304'") + n = pn.request("get:'/p/304'") + + assert i.status_code == i2.status_code == n.status_code == 304 + + i_cert = SSLCert(i.sslinfo.certchain[0]) + i2_cert = SSLCert(i2.sslinfo.certchain[0]) + n_cert = SSLCert(n.sslinfo.certchain[0]) + + assert i_cert == i2_cert + assert not i_cert == n_cert \ No newline at end of file diff --git a/test/test_proxy.py b/test/test_proxy.py index ad2bb2d7..f762e610 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -100,7 +100,7 @@ class TestProcessProxyOptions: class TestProxyServer: - @tutils.SkipWindows # binding to 0.0.0.0:1 works without special permissions on Windows + @tutils.SkipWindows # binding to 0.0.0.0:1 works without special permissions on Windows def test_err(self): parser = argparse.ArgumentParser() cmdline.common_options(parser) diff --git a/test/test_server.py b/test/test_server.py index 21d01f5a..ed5133cb 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -312,7 +312,7 @@ class TestProxy(tservers.HTTPProxTest): f = self.pathod("200:b@100") assert f.status_code == 200 f = self.master.state.view[0] - assert f.server_conn.peername == ("127.0.0.1", self.server.port) + assert f.server_conn.address == ("127.0.0.1", self.server.port) class TestProxySSL(tservers.HTTPProxTest): ssl=True @@ -342,6 +342,7 @@ class MasterRedirectRequest(tservers.TestMaster): def handle_response(self, f): f.response.content = str(f.client_conn.address.port) + f.response.headers["server-conn-id"] = [str(f.server_conn.source_address.port)] tservers.TestMaster.handle_response(self, f) @@ -374,7 +375,8 @@ class TestRedirectRequest(tservers.HTTPProxTest): assert self.server.last_log() assert not self.server2.last_log() - assert r3.content == r2.content == r1.content + assert r1.content == r2.content == r3.content + assert r1.headers.get_first("server-conn-id") == r3.headers.get_first("server-conn-id") # Make sure that we actually use the same connection in this test case class MasterStreamRequest(tservers.TestMaster): diff --git a/test/tservers.py b/test/tservers.py index 9f2abbe1..91743903 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -263,6 +263,29 @@ class ReverseProxTest(ProxTestBase): return p.request(q) +class IgnoreProxTest(ProxTestBase): + ssl = True + + @classmethod + def get_proxy_config(cls): + d = super(IgnoreProxTest, cls).get_proxy_config() + d["ignore"] = [".+:%s" % cls.server.port] # ignore by port + return d + + def pathoc_raw(self): + return libpathod.pathoc.Pathoc(("127.0.0.1", self.proxy.port), ssl=self.ssl) + + def pathocs(self): + """ + Returns a (pathod_ignore, pathoc_normal) tuple. + """ + p_ignore = self.pathoc_raw() + p_ignore.connect(("127.0.0.1", self.server.port)) + p_normal = self.pathoc_raw() + p_normal.connect(("127.0.0.1", self.server2.port)) + return p_ignore, p_normal + + class ChainProxTest(ProxTestBase): """ Chain n instances of mitmproxy in a row - because we can. -- cgit v1.2.3 From b23a1aa4a4dd9f09fc199d03f546a8fafc8b27b8 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Thu, 4 Sep 2014 19:08:54 +0200 Subject: much tests. so tcp. very wow. --- libmproxy/flow.py | 3 ++- libmproxy/stateobject.py | 3 +++ test/test_protocol_http.py | 3 +++ test/test_protocol_tcp.py | 21 --------------------- test/test_proxy.py | 9 +++++++++ test/test_server.py | 47 +++++++++++++++++++++++++++++++++++++++++----- test/tservers.py | 23 ----------------------- 7 files changed, 59 insertions(+), 50 deletions(-) delete mode 100644 test/test_protocol_tcp.py diff --git a/libmproxy/flow.py b/libmproxy/flow.py index eeb53e81..086710bc 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -612,6 +612,7 @@ class FlowMaster(controller.Master): if f.request: self.handle_request(f) if f.response: + self.handle_responseheaders(f) self.handle_response(f) if f.error: self.handle_error(f) @@ -668,7 +669,7 @@ class FlowMaster(controller.Master): self.masterq, self.should_exit ) - rt.start() # pragma: no cover + rt.start() # pragma: no cover if block: rt.join() diff --git a/libmproxy/stateobject.py b/libmproxy/stateobject.py index 3437b90e..6fb73c24 100644 --- a/libmproxy/stateobject.py +++ b/libmproxy/stateobject.py @@ -21,6 +21,9 @@ class StateObject(object): except AttributeError: # we may compare with something that's not a StateObject return False + def __ne__(self, other): + return not self.__eq__(other) + class SimpleStateObject(StateObject): """ diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py index bcbdd5d0..3ca590f1 100644 --- a/test/test_protocol_http.py +++ b/test/test_protocol_http.py @@ -49,6 +49,9 @@ class TestHTTPRequest: raw = r._assemble_headers() assert "Host" in raw + assert not "Host" in r.headers + r.update_host_header() + assert "Host" in r.headers def test_authority_form(self): diff --git a/test/test_protocol_tcp.py b/test/test_protocol_tcp.py deleted file mode 100644 index 8b6bb68d..00000000 --- a/test/test_protocol_tcp.py +++ /dev/null @@ -1,21 +0,0 @@ -import tservers -from netlib.certutils import SSLCert - -class TestTcp(tservers.IgnoreProxTest): - ignore = [] - - def test_simple(self): - # i = ignore (tcp passthrough), n = normal - pi, pn = self.pathocs() - i = pi.request("get:'/p/304'") - i2 = pi.request("get:'/p/304'") - n = pn.request("get:'/p/304'") - - assert i.status_code == i2.status_code == n.status_code == 304 - - i_cert = SSLCert(i.sslinfo.certchain[0]) - i2_cert = SSLCert(i2.sslinfo.certchain[0]) - n_cert = SSLCert(n.sslinfo.certchain[0]) - - assert i_cert == i2_cert - assert not i_cert == n_cert \ No newline at end of file diff --git a/test/test_proxy.py b/test/test_proxy.py index f762e610..e65841f4 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -42,6 +42,15 @@ class TestServerConnection: sc.connection.flush = mock.Mock(side_effect=tcp.NetLibDisconnect) sc.finish() + def test_repr(self): + sc = tutils.tserver_conn() + assert "address:22" in repr(sc) + assert "ssl" not in repr(sc) + sc.ssl_established = True + assert "ssl" in repr(sc) + sc.sni = "foo" + assert "foo" in repr(sc) + class TestProcessProxyOptions: def p(self, *args): diff --git a/test/test_server.py b/test/test_server.py index ed5133cb..a3fff0f1 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -1,7 +1,9 @@ import socket, time import mock +from libmproxy.proxy.config import ProxyConfig from netlib import tcp, http_auth, http from libpathod import pathoc, pathod +from netlib.certutils import SSLCert import tutils, tservers from libmproxy import flow from libmproxy.protocol import KILL @@ -55,6 +57,42 @@ class CommonMixin: line = t.rfile.readline() assert ("Bad Request" in line) or ("Bad Gateway" in line) +class TcpMixin: + def _ignore_on(self): + conf = ProxyConfig(ignore=[".+:%s" % self.server.port]) + self.config.ignore.append(conf.ignore[0]) + + def _ignore_off(self): + self.config.ignore.pop() + + def test_ignore(self): + spec = '304:h"Alternate-Protocol"="mitmproxy-will-remove-this"' + n = self.pathod(spec) + self._ignore_on() + i = self.pathod(spec) + i2 = self.pathod(spec) + self._ignore_off() + + assert i.status_code == i2.status_code == n.status_code == 304 + assert "Alternate-Protocol" in i.headers + assert "Alternate-Protocol" in i2.headers + assert "Alternate-Protocol" not in n.headers + + # Test that we get the original SSL cert + if self.ssl: + i_cert = SSLCert(i.sslinfo.certchain[0]) + i2_cert = SSLCert(i2.sslinfo.certchain[0]) + n_cert = SSLCert(n.sslinfo.certchain[0]) + + assert i_cert == i2_cert + assert i_cert != n_cert + + # Test Non-HTTP traffic + spec = "200:i0,@100:d0" # this results in just 100 random bytes + assert self.pathod(spec).status_code == 502 # mitmproxy responds with bad gateway + self._ignore_on() + tutils.raises("invalid server response", self.pathod, spec) # pathoc tries to parse answer as HTTP + self._ignore_off() class AppMixin: @@ -64,7 +102,6 @@ class AppMixin: assert "mitmproxy" in ret.content - class TestHTTP(tservers.HTTPProxTest, CommonMixin, AppMixin): def test_app_err(self): p = self.pathoc() @@ -175,7 +212,7 @@ class TestHTTPConnectSSLError(tservers.HTTPProxTest): tutils.raises("502 - Bad Gateway", p.http_connect, dst) -class TestHTTPS(tservers.HTTPProxTest, CommonMixin): +class TestHTTPS(tservers.HTTPProxTest, CommonMixin, TcpMixin): ssl = True ssloptions = pathod.SSLOptions(request_client_cert=True) clientcerts = True @@ -217,15 +254,15 @@ class TestHTTPSNoCommonName(tservers.HTTPProxTest): assert f.sslinfo.certchain[0].get_subject().CN == "127.0.0.1" -class TestReverse(tservers.ReverseProxTest, CommonMixin): +class TestReverse(tservers.ReverseProxTest, CommonMixin, TcpMixin): reverse = True -class TestTransparent(tservers.TransparentProxTest, CommonMixin): +class TestTransparent(tservers.TransparentProxTest, CommonMixin, TcpMixin): ssl = False -class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin): +class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin, TcpMixin): ssl = True def test_sni(self): f = self.pathod("304", sni="testserver.com") diff --git a/test/tservers.py b/test/tservers.py index 91743903..9f2abbe1 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -263,29 +263,6 @@ class ReverseProxTest(ProxTestBase): return p.request(q) -class IgnoreProxTest(ProxTestBase): - ssl = True - - @classmethod - def get_proxy_config(cls): - d = super(IgnoreProxTest, cls).get_proxy_config() - d["ignore"] = [".+:%s" % cls.server.port] # ignore by port - return d - - def pathoc_raw(self): - return libpathod.pathoc.Pathoc(("127.0.0.1", self.proxy.port), ssl=self.ssl) - - def pathocs(self): - """ - Returns a (pathod_ignore, pathoc_normal) tuple. - """ - p_ignore = self.pathoc_raw() - p_ignore.connect(("127.0.0.1", self.server.port)) - p_normal = self.pathoc_raw() - p_normal.connect(("127.0.0.1", self.server2.port)) - return p_ignore, p_normal - - class ChainProxTest(ProxTestBase): """ Chain n instances of mitmproxy in a row - because we can. -- cgit v1.2.3 From a7a3b5703adff7de12fb479a90ea2628465a4486 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 5 Sep 2014 00:18:17 +0200 Subject: change replay_request behaviour, refs #346; test upstream proxy mode --- libmproxy/protocol/http.py | 43 +++++------ test/test_protocol_http.py | 119 +----------------------------- test/test_script.py | 2 +- test/test_server.py | 175 +++++++++++++++++++++++++++++++++++++++------ test/tservers.py | 107 ++++++++++++--------------- 5 files changed, 223 insertions(+), 223 deletions(-) diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 253192dd..90ee127c 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -1203,33 +1203,28 @@ class RequestReplayThread(threading.Thread): threading.Thread.__init__(self) def run(self): + r = self.flow.request + form_out_backup = r.form_out try: - r = self.flow.request - form_out_backup = r.form_out - - r.form_out = self.config.http_form_out - server_address, server_ssl = False, False - # If the flow is live, r.host is already the correct upstream server unless modified by a script. - # If modified by a script, we probably want to keep the modified destination. - if self.config.get_upstream_server and not self.flow.live: - try: - # this will fail in transparent mode - upstream_info = self.config.get_upstream_server(self.flow.client_conn) - server_ssl = upstream_info[1] - server_address = upstream_info[2:] - except proxy.ProxyError: - pass - if not server_address: - server_address = (r.host, r.port) - - server = ServerConnection(server_address) - server.connect() - - if server_ssl or r.scheme == "https": - if self.config.http_form_out == "absolute": # form_out == absolute -> forward mode -> send CONNECT + # In all modes, we directly connect to the server displayed + if self.config.http_form_out == "absolute": # form_out == absolute -> forward mode + server_address = self.config.get_upstream_server(self.flow.client_conn)[2:] + server = ServerConnection(server_address) + server.connect() + if r.scheme == "https": send_connect_request(server, r.host, r.port) + server.establish_ssl(self.config.clientcerts, sni=r.host) r.form_out = "relative" - server.establish_ssl(self.config.clientcerts, sni=r.host) + else: + r.form_out = "absolute" + else: + server_address = (r.host, r.port) + server = ServerConnection(server_address) + server.connect() + if r.scheme == "https": + server.establish_ssl(self.config.clientcerts, sni=r.host) + r.form_out = "relative" + server.send(r._assemble()) self.flow.response = HTTPResponse.from_stream(server.rfile, r.method, body_size_limit=self.config.body_size_limit) diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py index 3ca590f1..41019672 100644 --- a/test/test_protocol_http.py +++ b/test/test_protocol_http.py @@ -1,5 +1,4 @@ from libmproxy.protocol.http import * -from libmproxy.protocol import KILL from cStringIO import StringIO import tutils, tservers @@ -134,120 +133,4 @@ class TestInvalidRequests(tservers.HTTPProxTest): p.connect() r = p.request("get:/p/200") assert r.status_code == 400 - assert "Invalid HTTP request form" in r.content - - -class TestProxyChaining(tservers.HTTPChainProxyTest): - def test_all(self): - self.chain[1].tmaster.replacehooks.add("~q", "foo", "bar") # replace in request - self.chain[0].tmaster.replacehooks.add("~q", "foo", "oh noes!") - self.proxy.tmaster.replacehooks.add("~q", "bar", "baz") - self.chain[0].tmaster.replacehooks.add("~s", "baz", "ORLY") # replace in response - - p = self.pathoc() - req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase) - assert req.content == "ORLY" - assert req.status_code == 418 - -class TestProxyChainingSSL(tservers.HTTPChainProxyTest): - ssl = True - def test_simple(self): - p = self.pathoc() - req = p.request("get:'/p/418:b\"content\"'") - assert req.content == "content" - assert req.status_code == 418 - - assert self.chain[1].tmaster.state.flow_count() == 2 # CONNECT from pathoc to chain[0], - # request from pathoc to chain[0] - assert self.chain[0].tmaster.state.flow_count() == 2 # CONNECT from chain[1] to proxy, - # request from chain[1] to proxy - assert self.proxy.tmaster.state.flow_count() == 1 # request from chain[0] (regular proxy doesn't store CONNECTs) - - def test_closing_connect_response(self): - """ - https://github.com/mitmproxy/mitmproxy/issues/313 - """ - def handle_request(f): - f.request.httpversion = (1, 0) - del f.request.headers["Content-Length"] - f.reply() - _handle_request = self.chain[0].tmaster.handle_request - self.chain[0].tmaster.handle_request = handle_request - try: - assert self.pathoc().request("get:/p/418").status_code == 418 - finally: - self.chain[0].tmaster.handle_request = _handle_request - - def test_sni(self): - p = self.pathoc(sni="foo.com") - req = p.request("get:'/p/418:b\"content\"'") - assert req.content == "content" - assert req.status_code == 418 - -class TestProxyChainingSSLReconnect(tservers.HTTPChainProxyTest): - ssl = True - - def test_reconnect(self): - """ - Tests proper functionality of ConnectionHandler.server_reconnect mock. - If we have a disconnect on a secure connection that's transparently proxified to - an upstream http proxy, we need to send the CONNECT request again. - """ - def kill_requests(master, attr, exclude): - k = [0] # variable scope workaround: put into array - _func = getattr(master, attr) - def handler(f): - k[0] += 1 - if not (k[0] in exclude): - f.client_conn.finish() - f.error = Error("terminated") - f.reply(KILL) - return _func(f) - setattr(master, attr, handler) - - kill_requests(self.proxy.tmaster, "handle_request", - exclude=[ - # fail first request - 2, # allow second request - ]) - - kill_requests(self.chain[0].tmaster, "handle_request", - exclude=[ - 1, # CONNECT - # fail first request - 3, # reCONNECT - 4, # request - ]) - - p = self.pathoc() - req = p.request("get:'/p/418:b\"content\"'") - assert self.chain[1].tmaster.state.flow_count() == 2 # CONNECT and request - assert self.chain[0].tmaster.state.flow_count() == 4 # CONNECT, failing request, - # reCONNECT, request - assert self.proxy.tmaster.state.flow_count() == 2 # failing request, request - # (doesn't store (repeated) CONNECTs from chain[0] - # as it is a regular proxy) - assert req.content == "content" - assert req.status_code == 418 - - assert not self.proxy.tmaster.state._flow_list[0].response # killed - assert self.proxy.tmaster.state._flow_list[1].response - - assert self.chain[1].tmaster.state._flow_list[0].request.form_in == "authority" - assert self.chain[1].tmaster.state._flow_list[1].request.form_in == "relative" - - assert self.chain[0].tmaster.state._flow_list[0].request.form_in == "authority" - assert self.chain[0].tmaster.state._flow_list[1].request.form_in == "relative" - assert self.chain[0].tmaster.state._flow_list[2].request.form_in == "authority" - assert self.chain[0].tmaster.state._flow_list[3].request.form_in == "relative" - - assert self.proxy.tmaster.state._flow_list[0].request.form_in == "relative" - assert self.proxy.tmaster.state._flow_list[1].request.form_in == "relative" - - req = p.request("get:'/p/418:b\"content2\"'") - - assert req.status_code == 502 - assert self.chain[1].tmaster.state.flow_count() == 3 # + new request - assert self.chain[0].tmaster.state.flow_count() == 6 # + new request, repeated CONNECT from chain[1] - # (both terminated) - assert self.proxy.tmaster.state.flow_count() == 2 # nothing happened here + assert "Invalid HTTP request form" in r.content \ No newline at end of file diff --git a/test/test_script.py b/test/test_script.py index 7c421fde..aed7def1 100644 --- a/test/test_script.py +++ b/test/test_script.py @@ -99,7 +99,7 @@ class TestScript: d = Dummy() assert s.run(hook, d)[0] d.reply() - while (time.time() - t_start) < 5 and m.call_count <= 5: + while (time.time() - t_start) < 20 and m.call_count <= 5: if m.call_count == 5: return time.sleep(0.001) diff --git a/test/test_server.py b/test/test_server.py index a3fff0f1..a6d591ed 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -1,12 +1,10 @@ import socket, time -import mock from libmproxy.proxy.config import ProxyConfig from netlib import tcp, http_auth, http from libpathod import pathoc, pathod from netlib.certutils import SSLCert import tutils, tservers -from libmproxy import flow -from libmproxy.protocol import KILL +from libmproxy.protocol import KILL, Error from libmproxy.protocol.http import CONTENT_MISSING """ @@ -23,8 +21,11 @@ class CommonMixin: def test_replay(self): assert self.pathod("304").status_code == 304 - assert len(self.master.state.view) == 1 - l = self.master.state.view[0] + if isinstance(self, tservers.HTTPUpstreamProxTest) and self.ssl: + assert len(self.master.state.view) == 2 + else: + assert len(self.master.state.view) == 1 + l = self.master.state.view[-1] assert l.response.code == 304 l.request.path = "/p/305" rt = self.master.replay_request(l, block=True) @@ -33,18 +34,28 @@ class CommonMixin: # Disconnect error l.request.path = "/p/305:d0" rt = self.master.replay_request(l, block=True) - assert l.error + assert not rt + if isinstance(self, tservers.HTTPUpstreamProxTest): + assert l.response.code == 502 + else: + assert l.error # Port error l.request.port = 1 - self.master.replay_request(l, block=True) - assert l.error + # In upstream mode, we get a 502 response from the upstream proxy server. + # In upstream mode with ssl, the replay will fail as we cannot establish SSL with the upstream proxy. + rt = self.master.replay_request(l, block=True) + assert not rt + if isinstance(self, tservers.HTTPUpstreamProxTest) and not self.ssl: + assert l.response.code == 502 + else: + assert l.error def test_http(self): f = self.pathod("304") assert f.status_code == 304 - l = self.master.state.view[0] + l = self.master.state.view[-1] # In Upstream mode with SSL, we may already have a previous CONNECT request. assert l.client_conn.address assert "host" in l.request.headers assert l.response.code == 304 @@ -57,6 +68,15 @@ class CommonMixin: line = t.rfile.readline() assert ("Bad Request" in line) or ("Bad Gateway" in line) + def test_sni(self): + if not self.ssl: + return + + f = self.pathod("304", sni="testserver.com") + assert f.status_code == 304 + log = self.server.last_log() + assert log["request"]["sni"] == "testserver.com" + class TcpMixin: def _ignore_on(self): conf = ProxyConfig(ignore=[".+:%s" % self.server.port]) @@ -221,12 +241,6 @@ class TestHTTPS(tservers.HTTPProxTest, CommonMixin, TcpMixin): assert f.status_code == 304 assert self.server.last_log()["request"]["clientcert"]["keyinfo"] - def test_sni(self): - f = self.pathod("304", sni="testserver.com") - assert f.status_code == 304 - l = self.server.last_log() - assert self.server.last_log()["request"]["sni"] == "testserver.com" - def test_error_post_connect(self): p = self.pathoc() assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400 @@ -264,11 +278,6 @@ class TestTransparent(tservers.TransparentProxTest, CommonMixin, TcpMixin): class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin, TcpMixin): ssl = True - def test_sni(self): - f = self.pathod("304", sni="testserver.com") - assert f.status_code == 304 - l = self.server.last_log() - assert l["request"]["sni"] == "testserver.com" def test_sslerr(self): p = pathoc.Pathoc(("localhost", self.proxy.port)) @@ -538,6 +547,132 @@ class TestIncompleteResponse(tservers.HTTPProxTest): class TestCertForward(tservers.HTTPProxTest): certforward = True ssl = True + def test_app_err(self): tutils.raises("handshake error", self.pathod, "200:b@100") + +class TestUpstreamProxy(tservers.HTTPUpstreamProxTest, CommonMixin, AppMixin): + ssl = False + + def test_order(self): + self.proxy.tmaster.replacehooks.add("~q", "foo", "bar") # replace in request + self.chain[0].tmaster.replacehooks.add("~q", "bar", "baz") + self.chain[1].tmaster.replacehooks.add("~q", "foo", "oh noes!") + self.chain[0].tmaster.replacehooks.add("~s", "baz", "ORLY") # replace in response + + p = self.pathoc() + req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase) + assert req.content == "ORLY" + assert req.status_code == 418 + + +class TestUpstreamProxySSL(tservers.HTTPUpstreamProxTest, CommonMixin, TcpMixin): + ssl = True + + def _ignore_on(self): + super(TestUpstreamProxySSL, self)._ignore_on() + conf = ProxyConfig(ignore=[".+:%s" % self.server.port]) + for proxy in self.chain: + proxy.tmaster.server.config.ignore.append(conf.ignore[0]) + + def _ignore_off(self): + super(TestUpstreamProxySSL, self)._ignore_off() + for proxy in self.chain: + proxy.tmaster.server.config.ignore.pop() + + def test_simple(self): + p = self.pathoc() + req = p.request("get:'/p/418:b\"content\"'") + assert req.content == "content" + assert req.status_code == 418 + + assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT from pathoc to chain[0], + # request from pathoc to chain[0] + assert self.chain[0].tmaster.state.flow_count() == 2 # CONNECT from proxy to chain[1], + # request from proxy to chain[1] + assert self.chain[1].tmaster.state.flow_count() == 1 # request from chain[0] (regular proxy doesn't store CONNECTs) + + def test_closing_connect_response(self): + """ + https://github.com/mitmproxy/mitmproxy/issues/313 + """ + def handle_request(f): + f.request.httpversion = (1, 0) + del f.request.headers["Content-Length"] + f.reply() + _handle_request = self.chain[0].tmaster.handle_request + self.chain[0].tmaster.handle_request = handle_request + try: + assert self.pathoc().request("get:/p/418").status_code == 418 + finally: + self.chain[0].tmaster.handle_request = _handle_request + + +class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest): + ssl = True + + def test_reconnect(self): + """ + Tests proper functionality of ConnectionHandler.server_reconnect mock. + If we have a disconnect on a secure connection that's transparently proxified to + an upstream http proxy, we need to send the CONNECT request again. + """ + def kill_requests(master, attr, exclude): + k = [0] # variable scope workaround: put into array + _func = getattr(master, attr) + def handler(f): + k[0] += 1 + if not (k[0] in exclude): + f.client_conn.finish() + f.error = Error("terminated") + f.reply(KILL) + return _func(f) + setattr(master, attr, handler) + + kill_requests(self.chain[1].tmaster, "handle_request", + exclude=[ + # fail first request + 2, # allow second request + ]) + + kill_requests(self.chain[0].tmaster, "handle_request", + exclude=[ + 1, # CONNECT + # fail first request + 3, # reCONNECT + 4, # request + ]) + + p = self.pathoc() + req = p.request("get:'/p/418:b\"content\"'") + assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request + assert self.chain[0].tmaster.state.flow_count() == 4 # CONNECT, failing request, + # reCONNECT, request + assert self.chain[1].tmaster.state.flow_count() == 2 # failing request, request + # (doesn't store (repeated) CONNECTs from chain[0] + # as it is a regular proxy) + assert req.content == "content" + assert req.status_code == 418 + + assert not self.chain[1].tmaster.state._flow_list[0].response # killed + assert self.chain[1].tmaster.state._flow_list[1].response + + assert self.proxy.tmaster.state._flow_list[0].request.form_in == "authority" + assert self.proxy.tmaster.state._flow_list[1].request.form_in == "relative" + + assert self.chain[0].tmaster.state._flow_list[0].request.form_in == "authority" + assert self.chain[0].tmaster.state._flow_list[1].request.form_in == "relative" + assert self.chain[0].tmaster.state._flow_list[2].request.form_in == "authority" + assert self.chain[0].tmaster.state._flow_list[3].request.form_in == "relative" + + assert self.chain[1].tmaster.state._flow_list[0].request.form_in == "relative" + assert self.chain[1].tmaster.state._flow_list[1].request.form_in == "relative" + + req = p.request("get:'/p/418:b\"content2\"'") + + assert req.status_code == 502 + assert self.proxy.tmaster.state.flow_count() == 3 # + new request + assert self.chain[0].tmaster.state.flow_count() == 6 # + new request, repeated CONNECT from chain[1] + # (both terminated) + assert self.chain[1].tmaster.state.flow_count() == 2 # nothing happened here diff --git a/test/tservers.py b/test/tservers.py index 9f2abbe1..8a2e72a4 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -84,29 +84,19 @@ class ProxTestBase(object): masterclass = TestMaster externalapp = False certforward = False + @classmethod def setupAll(cls): cls.server = libpathod.test.Daemon(ssl=cls.ssl, ssloptions=cls.ssloptions) cls.server2 = libpathod.test.Daemon(ssl=cls.ssl, ssloptions=cls.ssloptions) - pconf = cls.get_proxy_config() - cls.confdir = os.path.join(tempfile.gettempdir(), "mitmproxy") - cls.config = ProxyConfig( - no_upstream_cert = cls.no_upstream_cert, - confdir = cls.confdir, - authenticator = cls.authenticator, - certforward = cls.certforward, - ssl_ports=([cls.server.port, cls.server2.port] if cls.ssl else []), - **pconf - ) + + cls.config = ProxyConfig(**cls.get_proxy_config()) + tmaster = cls.masterclass(cls.config) tmaster.start_app(APP_HOST, APP_PORT, cls.externalapp) cls.proxy = ProxyThread(tmaster) cls.proxy.start() - @property - def master(cls): - return cls.proxy.tmaster - @classmethod def teardownAll(cls): shutil.rmtree(cls.confdir) @@ -121,24 +111,20 @@ class ProxTestBase(object): self.server2.clear_log() @property - def scheme(self): - return "https" if self.ssl else "http" - - @property - def proxies(self): - """ - The URL base for the server instance. - """ - return ( - (self.scheme, ("127.0.0.1", self.proxy.port)) - ) + def master(self): + return self.proxy.tmaster @classmethod def get_proxy_config(cls): - d = dict() - if cls.clientcerts: - d["clientcerts"] = tutils.test_data.path("data/clientcert") - return d + cls.confdir = os.path.join(tempfile.gettempdir(), "mitmproxy") + return dict( + no_upstream_cert = cls.no_upstream_cert, + confdir = cls.confdir, + authenticator = cls.authenticator, + certforward = cls.certforward, + ssl_ports=([cls.server.port, cls.server2.port] if cls.ssl else []), + clientcerts = tutils.test_data.path("data/clientcert") if cls.clientcerts else None + ) class HTTPProxTest(ProxTestBase): @@ -265,49 +251,50 @@ class ReverseProxTest(ProxTestBase): class ChainProxTest(ProxTestBase): """ - Chain n instances of mitmproxy in a row - because we can. + Chain three instances of mitmproxy in a row to test upstream mode. + Proxy order is cls.proxy -> cls.chain[0] -> cls.chain[1] + cls.proxy and cls.chain[0] are in upstream mode, + cls.chain[1] is in regular mode. """ + chain = None n = 2 - chain_config = [lambda port, sslports: ProxyConfig( - upstream_server= (False, False, "127.0.0.1", port), - http_form_in = "absolute", - http_form_out = "absolute", - ssl_ports=sslports - )] * n + @classmethod def setupAll(cls): - super(ChainProxTest, cls).setupAll() cls.chain = [] - for i in range(cls.n): - sslports = [cls.server.port, cls.server2.port] - config = cls.chain_config[i](cls.proxy.port if i == 0 else cls.chain[-1].port, - sslports) + super(ChainProxTest, cls).setupAll() + for _ in range(cls.n): + config = ProxyConfig(**cls.get_proxy_config()) tmaster = cls.masterclass(config) - tmaster.start_app(APP_HOST, APP_PORT, cls.externalapp) - cls.chain.append(ProxyThread(tmaster)) - cls.chain[-1].start() + proxy = ProxyThread(tmaster) + proxy.start() + cls.chain.insert(0, proxy) + + # Patch the orginal proxy to upstream mode + cls.config = cls.proxy.tmaster.config = cls.proxy.tmaster.server.config = ProxyConfig(**cls.get_proxy_config()) + @classmethod def teardownAll(cls): super(ChainProxTest, cls).teardownAll() - for p in cls.chain: - p.tmaster.shutdown() + for proxy in cls.chain: + proxy.shutdown() def setUp(self): super(ChainProxTest, self).setUp() - for p in self.chain: - p.tmaster.clear_log() - p.tmaster.state.clear() + for proxy in self.chain: + proxy.tmaster.clear_log() + proxy.tmaster.state.clear() + @classmethod + def get_proxy_config(cls): + d = super(ChainProxTest, cls).get_proxy_config() + if cls.chain: # First proxy is in normal mode. + d.update( + mode="upstream", + upstream_server=(False, False, "127.0.0.1", cls.chain[0].port) + ) + return d -class HTTPChainProxyTest(ChainProxTest): - def pathoc(self, sni=None): - """ - Returns a connected Pathoc instance. - """ - p = libpathod.pathoc.Pathoc(("localhost", self.chain[-1].port), ssl=self.ssl, sni=sni) - if self.ssl: - p.connect(("127.0.0.1", self.server.port)) - else: - p.connect() - return p +class HTTPUpstreamProxTest(ChainProxTest, HTTPProxTest): + pass \ No newline at end of file -- cgit v1.2.3 From f2570c773aa18e4ac236b1cf7f43acfb4ca080dd Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 5 Sep 2014 15:05:44 +0200 Subject: iframe injector example: use inline script --- examples/iframe_injector | 50 --------------------------------------------- examples/iframe_injector.py | 18 ++++++++++++++++ test/test_examples.py | 2 ++ 3 files changed, 20 insertions(+), 50 deletions(-) delete mode 100755 examples/iframe_injector create mode 100644 examples/iframe_injector.py diff --git a/examples/iframe_injector b/examples/iframe_injector deleted file mode 100755 index 8b1e02f1..00000000 --- a/examples/iframe_injector +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python -""" - Zap encoding in requests and inject iframe after body tag in html responses. - Usage: - iframe_injector http://someurl/somefile.html -""" -from libmproxy import controller, proxy -import os -import sys - - -class InjectingMaster(controller.Master): - def __init__(self, server, iframe_url): - controller.Master.__init__(self, server) - self._iframe_url = iframe_url - - def run(self): - try: - return controller.Master.run(self) - except KeyboardInterrupt: - self.shutdown() - - def handle_request(self, msg): - if 'Accept-Encoding' in msg.headers: - msg.headers["Accept-Encoding"] = 'none' - msg.reply() - - def handle_response(self, msg): - if msg.content: - c = msg.replace('', '' % self._iframe_url) - if c > 0: - print 'Iframe injected!' - msg.reply() - - -def main(argv): - if len(argv) != 2: - print "Usage: %s IFRAME_URL" % argv[0] - sys.exit(1) - iframe_url = argv[1] - config = proxy.ProxyConfig( - cacert = os.path.expanduser("~/.mitmproxy/mitmproxy-ca.pem") - ) - server = proxy.ProxyServer(config, 8080) - print 'Starting proxy...' - m = InjectingMaster(server, iframe_url) - m.run() - -if __name__ == '__main__': - main(sys.argv) diff --git a/examples/iframe_injector.py b/examples/iframe_injector.py new file mode 100644 index 00000000..7042dbab --- /dev/null +++ b/examples/iframe_injector.py @@ -0,0 +1,18 @@ +# Usage: mitmdump -s "iframe_injector.py url" +# (this script works best with --anticache) +from libmproxy.protocol.http import decoded + + +def start(ctx, argv): + if len(argv) != 2: + raise ValueError('Usage: -s "iframe_injector.py url"') + ctx.iframe_url = argv[1] + + +def handle_response(ctx, flow): + with decoded(flow.response): # Remove content encoding (gzip, ...) + c = flow.response.replace( + '', + '' % ctx.iframe_url) + if c > 0: + ctx.log("Iframe injected!") \ No newline at end of file diff --git a/test/test_examples.py b/test/test_examples.py index d18b5862..d557080e 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -12,6 +12,8 @@ def test_load_scripts(): tmaster = tservers.TestMaster(config.ProxyConfig()) for f in scripts: + if "iframe_injector" in f: + f += " foo" # one argument required if "modify_response_body" in f: f += " foo bar" # two arguments required script.Script(f, tmaster) # Loads the script file. \ No newline at end of file -- cgit v1.2.3 From 2a6337343a14f7f72c28d8bf5f24220f6d9ca6d0 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 5 Sep 2014 15:16:20 +0200 Subject: update docs, mostly revert 2f44b26b4cd014e03dd62a125d79af9b81663a93 --- doc-src/scripting/inlinescripts.html | 53 +++++++++++---------- examples/add_header.py | 4 +- examples/dup_and_replay.py | 6 +-- examples/flowbasic | 9 ++-- examples/modify_form.py | 10 ++-- examples/modify_querystring.py | 7 ++- examples/nonblocking.py | 3 +- examples/redirect_requests.py | 4 +- examples/stickycookies | 8 ++-- examples/stub.py | 6 +-- examples/upsidedownternet.py | 2 +- libmproxy/console/flowview.py | 14 +++--- libmproxy/protocol/http.py | 89 +++++++++++++++++------------------- libmproxy/protocol/primitives.py | 21 ++++++--- libmproxy/proxy/__init__.py | 3 +- libmproxy/proxy/config.py | 6 +-- libmproxy/proxy/server.py | 1 - test/test_flow.py | 86 +++++++++++++++++----------------- test/test_protocol_http.py | 8 ++-- test/test_proxy.py | 2 +- 20 files changed, 176 insertions(+), 166 deletions(-) diff --git a/doc-src/scripting/inlinescripts.html b/doc-src/scripting/inlinescripts.html index 738f8dc3..eef4e440 100644 --- a/doc-src/scripting/inlinescripts.html +++ b/doc-src/scripting/inlinescripts.html @@ -29,46 +29,45 @@ The new header will be added to all responses passing through the proxy. Called once on startup, before any other events. -### clientconnect(ScriptContext, ClientConnect) +### clientconnect(ScriptContext, ConnectionHandler) Called when a client initiates a connection to the proxy. Note that a connection can correspond to multiple HTTP requests. - -### serverconnect(ScriptContext, ServerConnection) +### serverconnect(ScriptContext, ConnectionHandler) Called when the proxy initiates a connection to the target server. Note that a connection can correspond to multiple HTTP requests. -### request(ScriptContext, Flow) +### request(ScriptContext, HTTPFlow) -Called when a client request has been received. The __Flow__ object is +Called when a client request has been received. The __HTTPFlow__ object is guaranteed to have a non-None __request__ attribute. -### responseheaders(ScriptContext, Flow) +### responseheaders(ScriptContext, HTTPFlow) Called when the headers of a server response have been received. This will always be called before the response hook. -The __Flow__ object is guaranteed to have non-None __request__ and -__response__ attributes. __response.content__ will not be valid, +The __HTTPFlow__ object is guaranteed to have non-None __request__ and +__response__ attributes. __response.content__ will be None, as the response body has not been read yet. -### response(ScriptContext, Flow) +### response(ScriptContext, HTTPFlow) -Called when a server response has been received. The __Flow__ object is +Called when a server response has been received. The __HTTPFlow__ object is guaranteed to have non-None __request__ and __response__ attributes. Note that if response streaming is enabled for this response, __response.content__ will not contain the response body. -### error(ScriptContext, Flow) +### error(ScriptContext, HTTPFlow) Called when a flow error has occurred, e.g. invalid server responses, or interrupted connections. This is distinct from a valid server HTTP error -response, which is simply a response with an HTTP error code. The __Flow__ +response, which is simply a response with an HTTP error code. The __HTTPFlow__ object is guaranteed to have non-None __request__ and __error__ attributes. -### clientdisconnect(ScriptContext, ClientDisconnect) +### clientdisconnect(ScriptContext, ConnectionHandler) Called when a client disconnects from the proxy. @@ -96,22 +95,10 @@ The main classes you will deal with in writing mitmproxy scripts are: libmproxy.proxy.connection.ServerConnection Describes a server connection. - - libmproxy.protocol.primitives.Error - A communications error. - libmproxy.protocol.http.HTTPFlow A collection of objects representing a single HTTP transaction. - - libmproxy.flow.ODict - - A dictionary-like object for managing sets of key/value data. There - is also a variant called CaselessODict that ignores key case for some - calls (used mainly for headers). - - libmproxy.protocol.http.HTTPResponse An HTTP response. @@ -120,10 +107,22 @@ The main classes you will deal with in writing mitmproxy scripts are: libmproxy.protocol.http.HTTPRequest An HTTP request. + + libmproxy.protocol.primitives.Error + A communications error. + libmproxy.script.ScriptContext A handle for interacting with mitmproxy's from within scripts. + + libmproxy.flow.ODict + + A dictionary-like object for managing sets of key/value data. There + is also a variant called CaselessODict that ignores key case for some + calls (used mainly for headers). + + libmproxy.certutils.SSLCert Exposes information SSL certificates. @@ -161,9 +160,9 @@ flows from a file (see the "scripted data transformation" example on the one-shot script on a single flow through the _|_ (pipe) shortcut in mitmproxy. In this case, there are no client connections, and the events are run in the -following order: __start__, __request__, __response__, __error__, __done__. If +following order: __start__, __request__, __responseheaders__, __response__, __error__, __done__. If the flow doesn't have a __response__ or __error__ associated with it, the -matching event will be skipped. +matching events will be skipped. ## Spaces in the script path By default, spaces are interpreted as separator between the inline script and its arguments (e.g. -s "foo.py diff --git a/examples/add_header.py b/examples/add_header.py index 0c0593d1..b9c8c1c6 100644 --- a/examples/add_header.py +++ b/examples/add_header.py @@ -1,2 +1,2 @@ -def response(context, flow): - flow.response.headers["newheader"] = ["foo"] +def response(ctx, flow): + flow.response.headers["newheader"] = ["foo"] \ No newline at end of file diff --git a/examples/dup_and_replay.py b/examples/dup_and_replay.py index 9c58d3a4..b38c2b7e 100644 --- a/examples/dup_and_replay.py +++ b/examples/dup_and_replay.py @@ -1,4 +1,4 @@ def request(ctx, flow): - f = ctx.duplicate_flow(flow) - f.request.path = "/changed" - ctx.replay_request(f) + f = ctx.duplicate_flow(flow) + f.request.path = "/changed" + ctx.replay_request(f) \ No newline at end of file diff --git a/examples/flowbasic b/examples/flowbasic index 8dbe2f28..2b44be3f 100755 --- a/examples/flowbasic +++ b/examples/flowbasic @@ -3,11 +3,14 @@ This example shows how to build a proxy based on mitmproxy's Flow primitives. + Heads Up: In the majority of cases, you want to use inline scripts. + Note that request and response messages are not automatically replied to, so we need to implement handlers to do this. """ import os -from libmproxy import proxy, flow +from libmproxy import flow, proxy +from libmproxy.proxy.server import ProxyServer class MyMaster(flow.FlowMaster): def run(self): @@ -31,9 +34,9 @@ class MyMaster(flow.FlowMaster): config = proxy.ProxyConfig( - cacert = os.path.expanduser("~/.mitmproxy/mitmproxy-ca.pem") + ca_file = os.path.expanduser("~/.mitmproxy/mitmproxy-ca.pem") ) state = flow.State() -server = proxy.ProxyServer(config, 8080) +server = ProxyServer(config, 8080) m = MyMaster(server, state) m.run() diff --git a/examples/modify_form.py b/examples/modify_form.py index cb12ee0f..6d651b19 100644 --- a/examples/modify_form.py +++ b/examples/modify_form.py @@ -1,8 +1,6 @@ -def request(context, flow): +def request(ctx, flow): if "application/x-www-form-urlencoded" in flow.request.headers["content-type"]: - frm = flow.request.form_urlencoded - frm["mitmproxy"] = ["rocks"] - flow.request.form_urlencoded = frm - - + form = flow.request.get_form_urlencoded() + form["mitmproxy"] = ["rocks"] + flow.request.set_form_urlencoded(form) \ No newline at end of file diff --git a/examples/modify_querystring.py b/examples/modify_querystring.py index 7e3a068a..56fbbb32 100644 --- a/examples/modify_querystring.py +++ b/examples/modify_querystring.py @@ -1,7 +1,6 @@ -def request(context, flow): - q = flow.request.query +def request(ctx, flow): + q = flow.request.get_query() if q: q["mitmproxy"] = ["rocks"] - flow.request.query = q - + flow.request.set_query(q) \ No newline at end of file diff --git a/examples/nonblocking.py b/examples/nonblocking.py index 9a131b32..1396742a 100644 --- a/examples/nonblocking.py +++ b/examples/nonblocking.py @@ -1,8 +1,9 @@ import time from libmproxy.script import concurrent + @concurrent def request(context, flow): print "handle request: %s%s" % (flow.request.host, flow.request.path) time.sleep(5) - print "start request: %s%s" % (flow.request.host, flow.request.path) + print "start request: %s%s" % (flow.request.host, flow.request.path) \ No newline at end of file diff --git a/examples/redirect_requests.py b/examples/redirect_requests.py index cc642039..c5561839 100644 --- a/examples/redirect_requests.py +++ b/examples/redirect_requests.py @@ -6,7 +6,9 @@ This example shows two ways to redirect flows to other destinations. """ -def request(context, flow): +def request(ctx, flow): + # pretty_host(hostheader=True) takes the Host: header of the request into account, + # which is useful in transparent mode where we usually only have the IP otherwise. if flow.request.pretty_host(hostheader=True).endswith("example.com"): resp = HTTPResponse( [1, 1], 200, "OK", diff --git a/examples/stickycookies b/examples/stickycookies index 17cd6019..2aab31d6 100755 --- a/examples/stickycookies +++ b/examples/stickycookies @@ -5,8 +5,10 @@ implement functionality similar to the "sticky cookies" option. This is at a lower level than the Flow mechanism, so we're dealing directly with request and response objects. """ -from libmproxy import controller, proxy import os +from libmproxy import controller, proxy +from libmproxy.proxy.server import ProxyServer + class StickyMaster(controller.Master): def __init__(self, server): @@ -35,8 +37,8 @@ class StickyMaster(controller.Master): config = proxy.ProxyConfig( - cacert = os.path.expanduser("~/.mitmproxy/mitmproxy-ca.pem") + ca_file = os.path.expanduser("~/.mitmproxy/mitmproxy-ca.pem") ) -server = proxy.ProxyServer(config, 8080) +server = ProxyServer(config, 8080) m = StickyMaster(server) m.run() diff --git a/examples/stub.py b/examples/stub.py index 0cf67db7..5976dd76 100644 --- a/examples/stub.py +++ b/examples/stub.py @@ -7,14 +7,14 @@ def start(ctx, argv): """ ctx.log("start") -def clientconnect(ctx, client_connect): +def clientconnect(ctx, conn_handler): """ Called when a client initiates a connection to the proxy. Note that a connection can correspond to multiple HTTP requests """ ctx.log("clientconnect") -def serverconnect(ctx, server_connection): +def serverconnect(ctx, conn_handler): """ Called when the proxy initiates a connection to the target server. Note that a connection can correspond to multiple HTTP requests @@ -50,7 +50,7 @@ def error(ctx, flow): """ ctx.log("error") -def clientdisconnect(ctx, client_disconnect): +def clientdisconnect(ctx, conn_handler): """ Called when a client disconnects from the proxy. """ diff --git a/examples/upsidedownternet.py b/examples/upsidedownternet.py index 181a40c2..a52b6d30 100644 --- a/examples/upsidedownternet.py +++ b/examples/upsidedownternet.py @@ -1,7 +1,7 @@ import cStringIO from PIL import Image -def response(context, flow): +def response(ctx, flow): if flow.response.headers["content-type"] == ["image/png"]: s = cStringIO.StringIO(flow.response.content) img = Image.open(s).rotate(180) diff --git a/libmproxy/console/flowview.py b/libmproxy/console/flowview.py index 9063c3e1..014d44c0 100644 --- a/libmproxy/console/flowview.py +++ b/libmproxy/console/flowview.py @@ -554,17 +554,17 @@ class FlowView(common.WWrap): conn.headers = flow.ODictCaseless(lst) def set_query(self, lst, conn): - conn.query = flow.ODict(lst) + conn.set_query(flow.ODict(lst)) def set_path_components(self, lst, conn): - conn.path_components = [i[0] for i in lst] + conn.set_path_components([i[0] for i in lst]) def set_form(self, lst, conn): - conn.form_urlencoded = flow.ODict(lst) + conn.set_form_urlencoded(flow.ODict(lst)) def edit_form(self, conn): self.master.view_grideditor( - grideditor.URLEncodedFormEditor(self.master, conn.form_urlencoded.lst, self.set_form, conn) + grideditor.URLEncodedFormEditor(self.master, conn.get_form_urlencoded().lst, self.set_form, conn) ) def edit_form_confirm(self, key, conn): @@ -589,7 +589,7 @@ class FlowView(common.WWrap): c = self.master.spawn_editor(conn.content or "") conn.content = c.rstrip("\n") # what? elif part == "f": - if not conn.form_urlencoded and conn.content: + if not conn.get_form_urlencoded() and conn.content: self.master.prompt_onekey( "Existing body is not a URL-encoded form. Clear and edit?", [ @@ -604,11 +604,11 @@ class FlowView(common.WWrap): elif part == "h": self.master.view_grideditor(grideditor.HeaderEditor(self.master, conn.headers.lst, self.set_headers, conn)) elif part == "p": - p = conn.path_components + p = conn.get_path_components() p = [[i] for i in p] self.master.view_grideditor(grideditor.PathEditor(self.master, p, self.set_path_components, conn)) elif part == "q": - self.master.view_grideditor(grideditor.QueryEditor(self.master, conn.query.lst, self.set_query, conn)) + self.master.view_grideditor(grideditor.QueryEditor(self.master, conn.get_query().lst, self.set_query, conn)) elif part == "u" and self.state.view_flow_mode == common.VIEW_FLOW_REQUEST: self.master.prompt_edit("URL", conn.url, self.set_url) elif part == "m" and self.state.view_flow_mode == common.VIEW_FLOW_REQUEST: diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 90ee127c..9593c3cb 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -29,13 +29,13 @@ def get_line(fp): def send_connect_request(conn, host, port, update_state=True): upstream_request = HTTPRequest("authority", "CONNECT", None, host, port, None, (1, 1), ODictCaseless(), "") - conn.send(upstream_request._assemble()) + conn.send(upstream_request.assemble()) resp = HTTPResponse.from_stream(conn.rfile, upstream_request.method) if resp.code != 200: raise proxy.ProxyError(resp.code, "Cannot establish SSL " + "connection with upstream proxy: \r\n" + - str(resp._assemble())) + str(resp.assemble())) if update_state: conn.state.append(("http", { "state": "connect", @@ -73,6 +73,9 @@ class decoded(object): class HTTPMessage(stateobject.SimpleStateObject): + """ + Base class for HTTPRequest and HTTPResponse + """ def __init__(self, httpversion, headers, content, timestamp_start=None, timestamp_end=None): self.httpversion = httpversion @@ -162,31 +165,31 @@ class HTTPMessage(stateobject.SimpleStateObject): """ Parse an HTTP message from a file stream """ - raise NotImplementedError # pragma: nocover + raise NotImplementedError() # pragma: nocover def _assemble_first_line(self): """ Returns the assembled request/response line """ - raise NotImplementedError # pragma: nocover + raise NotImplementedError() # pragma: nocover def _assemble_headers(self): """ Returns the assembled headers """ - raise NotImplementedError # pragma: nocover + raise NotImplementedError() # pragma: nocover def _assemble_head(self): """ Returns the assembled request/response line plus headers """ - raise NotImplementedError # pragma: nocover + raise NotImplementedError() # pragma: nocover - def _assemble(self): + def assemble(self): """ Returns the assembled request/response """ - raise NotImplementedError # pragma: nocover + raise NotImplementedError() # pragma: nocover class HTTPRequest(HTTPMessage): @@ -195,7 +198,17 @@ class HTTPRequest(HTTPMessage): Exposes the following attributes: - flow: Flow object the request belongs to + method: HTTP method + + scheme: URL scheme (http/https) (absolute-form only) + + host: Host portion of the URL (absolute-form and authority-form only) + + port: Destination port (absolute-form and authority-form only) + + path: Path portion of the URL (not present in authority-form) + + httpversion: HTTP version tuple, e.g. (1,1) headers: ODictCaseless object @@ -211,18 +224,6 @@ class HTTPRequest(HTTPMessage): form_out: The request form which mitmproxy has send out to the destination - method: HTTP method - - scheme: URL scheme (http/https) (absolute-form only) - - host: Host portion of the URL (absolute-form and authority-form only) - - port: Destination port (absolute-form and authority-form only) - - path: Path portion of the URL (not present in authority-form) - - httpversion: HTTP version tuple - timestamp_start: Timestamp indicating when request transmission started timestamp_end: Timestamp indicating when request transmission ended @@ -364,7 +365,7 @@ class HTTPRequest(HTTPMessage): def _assemble_head(self, form=None): return "%s\r\n%s\r\n" % (self._assemble_first_line(form), self._assemble_headers()) - def _assemble(self, form=None): + def assemble(self, form=None): """ Assembles the request for transmission to the server. We make some modifications to make sure interception works properly. @@ -417,8 +418,7 @@ class HTTPRequest(HTTPMessage): """ self.headers["Host"] = [self.host] - @property - def form_urlencoded(self): + def get_form_urlencoded(self): """ Retrieves the URL-encoded form data, returning an ODict object. Returns an empty ODict if there is no data or the content-type @@ -428,8 +428,7 @@ class HTTPRequest(HTTPMessage): return ODict(utils.urldecode(self.content)) return ODict([]) - @form_urlencoded.setter - def form_urlencoded(self, odict): + def set_form_urlencoded(self, odict): """ Sets the body to the URL-encoded form data, and adds the appropriate content-type header. Note that this will destory the @@ -440,8 +439,7 @@ class HTTPRequest(HTTPMessage): self.headers["Content-Type"] = [HDR_FORM_URLENCODED] self.content = utils.urlencode(odict.lst) - @property - def path_components(self): + def get_path_components(self): """ Returns the path components of the URL as a list of strings. @@ -450,8 +448,7 @@ class HTTPRequest(HTTPMessage): _, _, path, _, _, _ = urlparse.urlparse(self.url) return [urllib.unquote(i) for i in path.split("/") if i] - @path_components.setter - def path_components(self, lst): + def set_path_components(self, lst): """ Takes a list of strings, and sets the path component of the URL. @@ -462,8 +459,7 @@ class HTTPRequest(HTTPMessage): scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.url) self.url = urlparse.urlunparse([scheme, netloc, path, params, query, fragment]) - @property - def query(self): + def get_query(self): """ Gets the request query string. Returns an ODict object. """ @@ -472,8 +468,7 @@ class HTTPRequest(HTTPMessage): return ODict(utils.urldecode(query)) return ODict([]) - @query.setter - def query(self, odict): + def set_query(self, odict): """ Takes an ODict object, and sets the request query string. """ @@ -528,8 +523,7 @@ class HTTPRequest(HTTPMessage): raise ValueError("Invalid URL: %s" % url) self.scheme, self.host, self.port, self.path = parts - @property - def cookies(self): + def get_cookies(self): cookie_headers = self.headers.get("cookie") if not cookie_headers: return None @@ -560,7 +554,7 @@ class HTTPResponse(HTTPMessage): Exposes the following attributes: - flow: Flow object the request belongs to + httpversion: HTTP version tuple, e.g. (1,1) code: HTTP response code @@ -572,8 +566,6 @@ class HTTPResponse(HTTPMessage): is content associated, but not present. CONTENT_MISSING evaluates to False to make checking for the presence of content natural. - httpversion: HTTP version tuple - timestamp_start: Timestamp indicating when request transmission started timestamp_end: Timestamp indicating when request transmission ended @@ -661,7 +653,7 @@ class HTTPResponse(HTTPMessage): return '%s\r\n%s\r\n' % ( self._assemble_first_line(), self._assemble_headers(preserve_transfer_encoding=preserve_transfer_encoding)) - def _assemble(self): + def assemble(self): """ Assembles the response for transmission to the client. We make some modifications to make sure interception works properly. @@ -726,8 +718,7 @@ class HTTPResponse(HTTPMessage): if c: self.headers["set-cookie"] = c - @property - def cookies(self): + def get_cookies(self): cookie_headers = self.headers.get("set-cookie") if not cookie_headers: return None @@ -745,12 +736,14 @@ class HTTPResponse(HTTPMessage): class HTTPFlow(Flow): """ - A Flow is a collection of objects representing a single HTTP + A HTTPFlow is a collection of objects representing a single HTTP transaction. The main attributes are: request: HTTPRequest object response: HTTPResponse object error: Error object + server_conn: ServerConnection object + client_conn: ClientConnection object Note that it's possible for a Flow to have both a response and an error object. This might happen, for instance, when a response was received @@ -866,6 +859,10 @@ class HttpAuthenticationError(Exception): class HTTPHandler(ProtocolHandler): + """ + HTTPHandler implements mitmproxys understanding of the HTTP protocol. + + """ def __init__(self, c): super(HTTPHandler, self).__init__(c) self.expected_form_in = c.config.http_form_in @@ -878,7 +875,7 @@ class HTTPHandler(ProtocolHandler): def get_response_from_server(self, request, include_body=True): self.c.establish_server_connection() - request_raw = request._assemble() + request_raw = request.assemble() for i in range(2): try: @@ -957,7 +954,7 @@ class HTTPHandler(ProtocolHandler): if not flow.response.stream: # no streaming: # we already received the full response from the server and can send it to the client straight away. - self.c.client_conn.send(flow.response._assemble()) + self.c.client_conn.send(flow.response.assemble()) else: # streaming: # First send the body and then transfer the response incrementally: @@ -1225,7 +1222,7 @@ class RequestReplayThread(threading.Thread): server.establish_ssl(self.config.clientcerts, sni=r.host) r.form_out = "relative" - server.send(r._assemble()) + server.send(r.assemble()) self.flow.response = HTTPResponse.from_stream(server.rfile, r.method, body_size_limit=self.config.body_size_limit) self.channel.ask("response", self.flow) diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py index ee1199fc..ecad9d9e 100644 --- a/libmproxy/protocol/primitives.py +++ b/libmproxy/protocol/primitives.py @@ -12,9 +12,9 @@ class Error(stateobject.SimpleStateObject): """ An Error. - This is distinct from an HTTP error response (say, a code 500), which - is represented by a normal Response object. This class is responsible - for indicating errors that fall outside of normal HTTP communications, + This is distinct from an protocol error response (say, a HTTP code 500), which + is represented by a normal HTTPResponse object. This class is responsible + for indicating errors that fall outside of normal protocol communications, like interrupted connections, timeouts, protocol errors. Exposes the following attributes: @@ -52,6 +52,10 @@ class Error(stateobject.SimpleStateObject): class Flow(stateobject.SimpleStateObject): + """ + A Flow is a collection of objects representing a single transaction. + This class is usually subclassed for each protocol, e.g. HTTPFlow. + """ def __init__(self, conntype, client_conn, server_conn, live=None): self.conntype = conntype self.client_conn = client_conn @@ -117,6 +121,10 @@ class Flow(stateobject.SimpleStateObject): class ProtocolHandler(object): + """ + A ProtocolHandler implements an application-layer protocol, e.g. HTTP. + See: libmproxy.protocol.http.HTTPHandler + """ def __init__(self, c): self.c = c """@type: libmproxy.proxy.server.ConnectionHandler""" @@ -148,13 +156,14 @@ class ProtocolHandler(object): class LiveConnection(object): """ - This facade allows protocol handlers to interface with a live connection, - without requiring the expose the ConnectionHandler. + This facade allows interested parties (FlowMaster, inline scripts) to interface with a live connection, + without requiring to expose the internals of the ConnectionHandler. """ def __init__(self, c): self.c = c - self._backup_server_conn = None """@type: libmproxy.proxy.server.ConnectionHandler""" + self._backup_server_conn = None + """@type: libmproxy.proxy.connection.ServerConnection""" def change_server(self, address, ssl=False, force=False, persistent_change=False): address = netlib.tcp.Address.wrap(address) diff --git a/libmproxy/proxy/__init__.py b/libmproxy/proxy/__init__.py index f5d6a2d0..e4c20030 100644 --- a/libmproxy/proxy/__init__.py +++ b/libmproxy/proxy/__init__.py @@ -1 +1,2 @@ -from .primitives import * \ No newline at end of file +from .primitives import * +from .config import ProxyConfig diff --git a/libmproxy/proxy/config.py b/libmproxy/proxy/config.py index 6d4c078b..ea815c69 100644 --- a/libmproxy/proxy/config.py +++ b/libmproxy/proxy/config.py @@ -1,8 +1,8 @@ from __future__ import absolute_import import os -from .. import utils, platform import re from netlib import http_auth, certutils +from .. import utils, platform from .primitives import ConstUpstreamServerResolver, TransparentUpstreamServerResolver TRANSPARENT_SSL_PORTS = [443, 8443] @@ -11,7 +11,7 @@ CONF_DIR = "~/.mitmproxy" class ProxyConfig: - def __init__(self, confdir=CONF_DIR, clientcerts=None, + def __init__(self, confdir=CONF_DIR, ca_file=None, clientcerts=None, no_upstream_cert=False, body_size_limit=None, mode=None, upstream_server=None, http_form_in=None, http_form_out=None, authenticator=None, ignore=[], @@ -44,7 +44,7 @@ class ProxyConfig: self.ignore = [re.compile(i, re.IGNORECASE) for i in ignore] self.authenticator = authenticator self.confdir = os.path.expanduser(confdir) - self.ca_file = os.path.join(self.confdir, CONF_BASENAME + "-ca.pem") + self.ca_file = ca_file or os.path.join(self.confdir, CONF_BASENAME + "-ca.pem") self.certstore = certutils.CertStore.from_store(self.confdir, CONF_BASENAME) for spec, cert in certs: self.certstore.add_cert_file(spec, cert) diff --git a/libmproxy/proxy/server.py b/libmproxy/proxy/server.py index 58e386ab..aacc908e 100644 --- a/libmproxy/proxy/server.py +++ b/libmproxy/proxy/server.py @@ -1,5 +1,4 @@ from __future__ import absolute_import -import re import socket from OpenSSL import SSL diff --git a/test/test_flow.py b/test/test_flow.py index 4bc2391e..914138c9 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -481,7 +481,7 @@ class TestSerialize: f2 = l[0] assert f2._get_state() == f._get_state() - assert f2.request._assemble() == f.request._assemble() + assert f2.request.assemble() == f.request.assemble() def test_load_flows(self): r = self._treader() @@ -757,18 +757,18 @@ class TestRequest: r.url = u tutils.raises(ValueError, setattr, r, "url", "") assert r.url == u - assert r._assemble() - assert r.size() == len(r._assemble()) + assert r.assemble() + assert r.size() == len(r.assemble()) r2 = r.copy() assert r == r2 r.content = None - assert r._assemble() - assert r.size() == len(r._assemble()) + assert r.assemble() + assert r.size() == len(r.assemble()) r.content = CONTENT_MISSING - tutils.raises("Cannot assemble flow with CONTENT_MISSING", r._assemble) + tutils.raises("Cannot assemble flow with CONTENT_MISSING", r.assemble) def test_get_url(self): r = tutils.treq() @@ -794,58 +794,58 @@ class TestRequest: def test_path_components(self): r = tutils.treq() r.path = "/" - assert r.path_components == [] + assert r.get_path_components() == [] r.path = "/foo/bar" - assert r.path_components == ["foo", "bar"] + assert r.get_path_components() == ["foo", "bar"] q = flow.ODict() q["test"] = ["123"] - r.query = q - assert r.path_components == ["foo", "bar"] - - r.path_components = [] - assert r.path_components == [] - r.path_components = ["foo"] - assert r.path_components == ["foo"] - r.path_components = ["/oo"] - assert r.path_components == ["/oo"] + r.set_query(q) + assert r.get_path_components() == ["foo", "bar"] + + r.set_path_components([]) + assert r.get_path_components() == [] + r.set_path_components(["foo"]) + assert r.get_path_components() == ["foo"] + r.set_path_components(["/oo"]) + assert r.get_path_components() == ["/oo"] assert "%2F" in r.path def test_getset_form_urlencoded(self): d = flow.ODict([("one", "two"), ("three", "four")]) r = tutils.treq(content=utils.urlencode(d.lst)) r.headers["content-type"] = [protocol.http.HDR_FORM_URLENCODED] - assert r.form_urlencoded == d + assert r.get_form_urlencoded() == d d = flow.ODict([("x", "y")]) - r.form_urlencoded = d - assert r.form_urlencoded == d + r.set_form_urlencoded(d) + assert r.get_form_urlencoded() == d r.headers["content-type"] = ["foo"] - assert not r.form_urlencoded + assert not r.get_form_urlencoded() def test_getset_query(self): h = flow.ODictCaseless() r = tutils.treq() r.path = "/foo?x=y&a=b" - q = r.query + q = r.get_query() assert q.lst == [("x", "y"), ("a", "b")] r.path = "/" - q = r.query + q = r.get_query() assert not q r.path = "/?adsfa" - q = r.query + q = r.get_query() assert q.lst == [("adsfa", "")] r.path = "/foo?x=y&a=b" - assert r.query - r.query = flow.ODict([]) - assert not r.query + assert r.get_query() + r.set_query(flow.ODict([])) + assert not r.get_query() qv = flow.ODict([("a", "b"), ("c", "d")]) - r.query = qv - assert r.query == qv + r.set_query(qv) + assert r.get_query() == qv def test_anticache(self): h = flow.ODictCaseless() @@ -916,14 +916,14 @@ class TestRequest: h = flow.ODictCaseless() r = tutils.treq() r.headers = h - assert r.cookies is None + assert r.get_cookies() is None def test_get_cookies_single(self): h = flow.ODictCaseless() h["Cookie"] = ["cookiename=cookievalue"] r = tutils.treq() r.headers = h - result = r.cookies + result = r.get_cookies() assert len(result)==1 assert result['cookiename']==('cookievalue',{}) @@ -932,7 +932,7 @@ class TestRequest: h["Cookie"] = ["cookiename=cookievalue;othercookiename=othercookievalue"] r = tutils.treq() r.headers = h - result = r.cookies + result = r.get_cookies() assert len(result)==2 assert result['cookiename']==('cookievalue',{}) assert result['othercookiename']==('othercookievalue',{}) @@ -942,7 +942,7 @@ class TestRequest: h["Cookie"] = ["cookiename=coo=kievalue;othercookiename=othercookievalue"] r = tutils.treq() r.headers = h - result = r.cookies + result = r.get_cookies() assert len(result)==2 assert result['cookiename']==('coo=kievalue',{}) assert result['othercookiename']==('othercookievalue',{}) @@ -966,18 +966,18 @@ class TestResponse: def test_simple(self): f = tutils.tflow(resp=True) resp = f.response - assert resp._assemble() - assert resp.size() == len(resp._assemble()) + assert resp.assemble() + assert resp.size() == len(resp.assemble()) resp2 = resp.copy() assert resp2 == resp resp.content = None - assert resp._assemble() - assert resp.size() == len(resp._assemble()) + assert resp.assemble() + assert resp.size() == len(resp.assemble()) resp.content = CONTENT_MISSING - tutils.raises("Cannot assemble flow with CONTENT_MISSING", resp._assemble) + tutils.raises("Cannot assemble flow with CONTENT_MISSING", resp.assemble) def test_refresh(self): r = tutils.tresp() @@ -1052,14 +1052,14 @@ class TestResponse: h = flow.ODictCaseless() resp = tutils.tresp() resp.headers = h - assert not resp.cookies + assert not resp.get_cookies() def test_get_cookies_simple(self): h = flow.ODictCaseless() h["Set-Cookie"] = ["cookiename=cookievalue"] resp = tutils.tresp() resp.headers = h - result = resp.cookies + result = resp.get_cookies() assert len(result)==1 assert "cookiename" in result assert result["cookiename"] == ("cookievalue", {}) @@ -1069,7 +1069,7 @@ class TestResponse: h["Set-Cookie"] = ["cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly"] resp = tutils.tresp() resp.headers = h - result = resp.cookies + result = resp.get_cookies() assert len(result)==1 assert "cookiename" in result assert result["cookiename"][0] == "cookievalue" @@ -1084,7 +1084,7 @@ class TestResponse: h["Set-Cookie"] = ["cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/"] resp = tutils.tresp() resp.headers = h - result = resp.cookies + result = resp.get_cookies() assert len(result)==1 assert "cookiename" in result assert result["cookiename"][0] == "" @@ -1095,7 +1095,7 @@ class TestResponse: h["Set-Cookie"] = ["cookiename=cookievalue","othercookie=othervalue"] resp = tutils.tresp() resp.headers = h - result = resp.cookies + result = resp.get_cookies() assert len(result)==2 assert "cookiename" in result assert result["cookiename"] == ("cookievalue", {}) diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py index 41019672..ea6cf3fd 100644 --- a/test/test_protocol_http.py +++ b/test/test_protocol_http.py @@ -31,7 +31,7 @@ class TestHTTPRequest: f.request.host = f.server_conn.address.host f.request.port = f.server_conn.address.port f.request.scheme = "http" - assert f.request._assemble() == "OPTIONS * HTTP/1.1\r\nHost: address:22\r\n\r\n" + assert f.request.assemble() == "OPTIONS * HTTP/1.1\r\nHost: address:22\r\n\r\n" def test_origin_form(self): s = StringIO("GET /foo\xff HTTP/1.1") @@ -59,7 +59,7 @@ class TestHTTPRequest: s = StringIO("CONNECT address:22 HTTP/1.1") r = HTTPRequest.from_stream(s) r.scheme, r.host, r.port = "http", "address", 22 - assert r._assemble() == "CONNECT address:22 HTTP/1.1\r\nHost: address:22\r\n\r\n" + assert r.assemble() == "CONNECT address:22 HTTP/1.1\r\nHost: address:22\r\n\r\n" assert r.pretty_url(False) == "address:22" def test_absolute_form(self): @@ -67,11 +67,11 @@ class TestHTTPRequest: tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) s = StringIO("GET http://address:22/ HTTP/1.1") r = HTTPRequest.from_stream(s) - assert r._assemble() == "GET http://address:22/ HTTP/1.1\r\nHost: address:22\r\n\r\n" + assert r.assemble() == "GET http://address:22/ HTTP/1.1\r\nHost: address:22\r\n\r\n" def test_assemble_unknown_form(self): r = tutils.treq() - tutils.raises("Invalid request form", r._assemble, "antiauthority") + tutils.raises("Invalid request form", r.assemble, "antiauthority") def test_set_url(self): r = tutils.treq_absolute() diff --git a/test/test_proxy.py b/test/test_proxy.py index e65841f4..073e76b5 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -28,7 +28,7 @@ class TestServerConnection: f = tutils.tflow() f.server_conn = sc f.request.path = "/p/200:da" - sc.send(f.request._assemble()) + sc.send(f.request.assemble()) assert http.read_response(sc.rfile, f.request.method, 1000) assert self.d.last_log() -- cgit v1.2.3 From 32e1ed212da8095530abad4ac41f10ee8a599c74 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 5 Sep 2014 19:39:05 +0200 Subject: streamline HTTPHandler.handle_flow() --- libmproxy/protocol/http.py | 191 +++++++++++++++++++++++---------------------- 1 file changed, 96 insertions(+), 95 deletions(-) diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 9593c3cb..c67cb471 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -160,13 +160,6 @@ class HTTPMessage(stateobject.SimpleStateObject): c += self.headers.replace(pattern, repl, *args, **kwargs) return c - @classmethod - def from_stream(cls, rfile, include_body=True, body_size_limit=None): - """ - Parse an HTTP message from a file stream - """ - raise NotImplementedError() # pragma: nocover - def _assemble_first_line(self): """ Returns the assembled request/response line @@ -644,7 +637,8 @@ class HTTPResponse(HTTPMessage): if self.content: headers["Content-Length"] = [str(len(self.content))] - elif not preserve_transfer_encoding and 'Transfer-Encoding' in self.headers: # add content-length for chuncked transfer-encoding with no content + # add content-length for chuncked transfer-encoding with no content + elif not preserve_transfer_encoding and 'Transfer-Encoding' in self.headers: headers["Content-Length"] = ["0"] return str(headers) @@ -873,19 +867,21 @@ class HTTPHandler(ProtocolHandler): while self.handle_flow(): pass - def get_response_from_server(self, request, include_body=True): + def get_response_from_server(self, flow): self.c.establish_server_connection() - request_raw = request.assemble() + request_raw = flow.request.assemble() - for i in range(2): + for attempt in (0, 1): try: self.c.server_conn.send(request_raw) - res = HTTPResponse.from_stream(self.c.server_conn.rfile, request.method, - body_size_limit=self.c.config.body_size_limit, include_body=include_body) - return res + # Only get the headers at first... + flow.response = HTTPResponse.from_stream(self.c.server_conn.rfile, flow.request.method, + body_size_limit=self.c.config.body_size_limit, + include_body=False) + break except (tcp.NetLibDisconnect, http.HttpErrorConnClosed), v: self.c.log("error in server communication: %s" % repr(v), level="debug") - if i < 1: + if attempt == 0: # In any case, we try to reconnect at least once. # This is necessary because it might be possible that we already initiated an upstream connection # after clientconnect that has already been expired, e.g consider the following event log: @@ -899,13 +895,24 @@ class HTTPHandler(ProtocolHandler): else: raise + # call the appropriate script hook - this is an opportunity for an inline script to set flow.stream = True + self.c.channel.ask("responseheaders", flow) + + # now get the rest of the request body, if body still needs to be read but not streaming this response + if flow.response.stream: + flow.response.content = CONTENT_MISSING + else: + flow.response.content = http.read_http_body(self.c.server_conn.rfile, flow.response.headers, + self.c.config.body_size_limit, + flow.request.method, flow.response.code, False) + def handle_flow(self): flow = HTTPFlow(self.c.client_conn, self.c.server_conn, self.live) try: try: req = HTTPRequest.from_stream(self.c.client_conn.rfile, body_size_limit=self.c.config.body_size_limit) - except tcp.NetLibDisconnect: # specifically ignore disconnects that happen before/between requests. + except tcp.NetLibDisconnect: # don't throw an error for disconnects that happen before/between requests. return False self.c.log("request", "debug", [req._assemble_first_line(req.form_in)]) ret = self.process_request(flow, req) @@ -927,20 +934,7 @@ class HTTPHandler(ProtocolHandler): if isinstance(request_reply, HTTPResponse): flow.response = request_reply else: - - # read initially in "stream" mode, so we can get the headers separately - flow.response = self.get_response_from_server(flow.request, include_body=False) - - # call the appropriate script hook - this is an opportunity for an inline script to set flow.stream = True - self.c.channel.ask("responseheaders", flow) - - # now get the rest of the request body, if body still needs to be read but not streaming this response - if flow.response.stream: - flow.response.content = CONTENT_MISSING - else: - flow.response.content = http.read_http_body(self.c.server_conn.rfile, flow.response.headers, - self.c.config.body_size_limit, - flow.request.method, flow.response.code, False) + self.get_response_from_server(flow) # no further manipulation of self.c.server_conn beyond this point # we can safely set it as the final attribute value here. @@ -951,67 +945,34 @@ class HTTPHandler(ProtocolHandler): if response_reply is None or response_reply == KILL: return False - if not flow.response.stream: - # no streaming: - # we already received the full response from the server and can send it to the client straight away. - self.c.client_conn.send(flow.response.assemble()) - else: - # streaming: - # First send the body and then transfer the response incrementally: - h = flow.response._assemble_head(preserve_transfer_encoding=True) - self.c.client_conn.send(h) - for chunk in http.read_http_body_chunked(self.c.server_conn.rfile, - flow.response.headers, - self.c.config.body_size_limit, flow.request.method, - flow.response.code, False, 4096): - for part in chunk: - self.c.client_conn.wfile.write(part) - self.c.client_conn.wfile.flush() - flow.response.timestamp_end = utils.timestamp() - - flow.timestamp_end = utils.timestamp() + self.send_response_to_client(flow) - close_connection = ( - http.connection_close(flow.request.httpversion, flow.request.headers) or - http.connection_close(flow.response.httpversion, flow.response.headers) or - http.expected_http_body_size(flow.response.headers, False, flow.request.method, - flow.response.code) == -1) - if close_connection: - if flow.request.form_in == "authority" and flow.response.code == 200: - # Workaround for https://github.com/mitmproxy/mitmproxy/issues/313: - # Some proxies (e.g. Charles) send a CONNECT response with HTTP/1.0 and no Content-Length header - pass - else: - return False + if self.check_close_connection(flow): + return False # We sent a CONNECT request to an upstream proxy. if flow.request.form_in == "authority" and flow.response.code == 200: - # TODO: Eventually add headers (space/usefulness tradeoff) - # Make sure to add state info before the actual upgrade happens. - # During the upgrade, we may receive an SNI indication from the client, + # TODO: Possibly add headers (memory consumption/usefulness tradeoff) + # Make sure to add state info before the actual processing of the CONNECT request happens. + # During an SSL upgrade, we may receive an SNI indication from the client, # which resets the upstream connection. If this is the case, we must # already re-issue the CONNECT request at this point. self.c.server_conn.state.append(("http", {"state": "connect", "host": flow.request.host, "port": flow.request.port})) - - if self.c.check_ignore_address((flow.request.host, flow.request.port)): - self.c.log("Ignore host: %s:%s" % self.c.server_conn.address(), "info") - TCPHandler(self.c).handle_messages() + if not self.process_connect_request((flow.request.host, flow.request.port)): return False - else: - if flow.request.port in self.c.config.ssl_ports: - self.ssl_upgrade() - self.skip_authentication = True # If the user has changed the target server on this connection, # restore the original target server flow.live.restore_server() - flow.live = None - return True + return True # Next flow please. except (HttpAuthenticationError, http.HttpError, proxy.ProxyError, tcp.NetLibError), e: self.handle_error(e, flow) + finally: + flow.timestamp_end = utils.timestamp() + flow.live = None # Connection is not live anymore. return False def handle_server_reconnect(self, state): @@ -1060,16 +1021,6 @@ class HTTPHandler(ProtocolHandler): self.c.client_conn.wfile.write(html_content) self.c.client_conn.wfile.flush() - def ssl_upgrade(self): - """ - Upgrade the connection to SSL after an authority (CONNECT) request has been made. - """ - self.c.log("Received CONNECT request. Upgrading to SSL...", "debug") - self.expected_form_in = "relative" - self.expected_form_out = "relative" - self.c.establish_ssl(server=True, client=True) - self.c.log("Upgrade to SSL completed.", "debug") - def process_request(self, flow, request): """ @returns: @@ -1114,16 +1065,7 @@ class HTTPHandler(ProtocolHandler): ('Proxy-agent: %s\r\n' % self.c.server_version) + '\r\n' ) - - if self.c.check_ignore_address(self.c.server_conn.address): - self.c.log("Ignore host: %s:%s" % self.c.server_conn.address(), "info") - TCPHandler(self.c).handle_messages() - return False - else: - if self.c.server_conn.address.port in self.c.config.ssl_ports: - self.ssl_upgrade() - self.skip_authentication = True - return True + return self.process_connect_request(self.c.server_conn.address) else: # upstream proxy mode return None else: @@ -1140,7 +1082,6 @@ class HTTPHandler(ProtocolHandler): self.c.set_server_address((request.host, request.port)) flow.server_conn = self.c.server_conn - return None raise http.HttpError(400, "Invalid HTTP request form (expected: %s, got: %s)" % @@ -1182,6 +1123,66 @@ class HTTPHandler(ProtocolHandler): flow.server_conn = self.c.server_conn + def send_response_to_client(self, flow): + if not flow.response.stream: + # no streaming: + # we already received the full response from the server and can send it to the client straight away. + self.c.client_conn.send(flow.response.assemble()) + else: + # streaming: + # First send the body and then transfer the response incrementally: + h = flow.response._assemble_head(preserve_transfer_encoding=True) + self.c.client_conn.send(h) + for chunk in http.read_http_body_chunked(self.c.server_conn.rfile, + flow.response.headers, + self.c.config.body_size_limit, flow.request.method, + flow.response.code, False, 4096): + for part in chunk: + self.c.client_conn.wfile.write(part) + self.c.client_conn.wfile.flush() + flow.response.timestamp_end = utils.timestamp() + + def check_close_connection(self, flow): + """ + Checks if the connection should be closed depending on the HTTP semantics. Returns True, if so. + """ + close_connection = ( + http.connection_close(flow.request.httpversion, flow.request.headers) or + http.connection_close(flow.response.httpversion, flow.response.headers) or + http.expected_http_body_size(flow.response.headers, False, flow.request.method, + flow.response.code) == -1) + if close_connection: + if flow.request.form_in == "authority" and flow.response.code == 200: + # Workaround for https://github.com/mitmproxy/mitmproxy/issues/313: + # Some proxies (e.g. Charles) send a CONNECT response with HTTP/1.0 and no Content-Length header + pass + else: + return True + return False + + def process_connect_request(self, address): + """ + Process a CONNECT request. + Returns True if the CONNECT request has been processed successfully. + Returns False, if the connection should be closed immediately. + """ + address = tcp.Address.wrap(address) + if self.c.check_ignore_address(address): + self.c.log("Ignore host: %s:%s" % address(), "info") + TCPHandler(self.c).handle_messages() + return False + else: + self.expected_form_in = "relative" + self.expected_form_out = "relative" + self.skip_authentication = True + + if address.port in self.c.config.ssl_ports: + self.c.log("Received CONNECT request to SSL port. Upgrading to SSL...", "debug") + self.c.establish_ssl(server=True, client=True) + self.c.log("Upgrade to SSL completed.", "debug") + + return True + def authenticate(self, request): if self.c.config.authenticator: if self.c.config.authenticator.authenticate(request.headers): -- cgit v1.2.3 From ccb61829175b6ecb15cc753c5d134fe7b445b2ef Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 6 Sep 2014 12:39:23 +0200 Subject: fix race condition with the concurrent decorator --- libmproxy/script.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/libmproxy/script.py b/libmproxy/script.py index 706d84d5..f5fb6b41 100644 --- a/libmproxy/script.py +++ b/libmproxy/script.py @@ -108,15 +108,28 @@ class Script: return (False, None) +class ReplyProxy(object): + def __init__(self, original_reply): + self._ignore_calls = 1 + self.lock = threading.Lock() + self.original_reply = original_reply + + def __call__(self, *args, **kwargs): + with self.lock: + if self._ignore_calls > 0: + self._ignore_calls -= 1 + return + self.original_reply(*args, **kwargs) + + def __getattr__ (self, k): + return getattr(self.original_reply, k) + + def _handle_concurrent_reply(fn, o, *args, **kwargs): # Make first call to o.reply a no op - original_reply = o.reply - def restore_original_reply(): - o.reply = original_reply - if hasattr(original_reply, "q"): - restore_original_reply.q = original_reply.q - o.reply = restore_original_reply + reply_proxy = ReplyProxy(o.reply) + o.reply = reply_proxy def run(): fn(*args, **kwargs) -- cgit v1.2.3