diff options
-rw-r--r-- | examples/modify_form.py | 4 | ||||
-rw-r--r-- | examples/modify_querystring.py | 4 | ||||
-rw-r--r-- | examples/redirect_requests.py | 6 | ||||
-rw-r--r-- | libmproxy/console/common.py | 2 | ||||
-rw-r--r-- | libmproxy/console/flowview.py | 20 | ||||
-rw-r--r-- | libmproxy/dump.py | 2 | ||||
-rw-r--r-- | libmproxy/filt.py | 4 | ||||
-rw-r--r-- | libmproxy/flow.py | 12 | ||||
-rw-r--r-- | libmproxy/protocol/http.py | 234 | ||||
-rw-r--r-- | libmproxy/protocol/primitives.py | 53 | ||||
-rw-r--r-- | libmproxy/proxy/connection.py | 5 | ||||
-rw-r--r-- | libmproxy/proxy/primitives.py | 13 | ||||
-rw-r--r-- | libmproxy/proxy/server.py | 24 | ||||
-rw-r--r-- | test/test_flow.py | 102 | ||||
-rw-r--r-- | test/test_protocol_http.py | 12 | ||||
-rw-r--r-- | test/test_proxy.py | 4 | ||||
-rw-r--r-- | test/test_server.py | 11 |
17 files changed, 259 insertions, 253 deletions
diff --git a/examples/modify_form.py b/examples/modify_form.py index 2d839aed..cb12ee0f 100644 --- a/examples/modify_form.py +++ b/examples/modify_form.py @@ -1,8 +1,8 @@ def request(context, flow): if "application/x-www-form-urlencoded" in flow.request.headers["content-type"]: - frm = flow.request.get_form_urlencoded() + frm = flow.request.form_urlencoded frm["mitmproxy"] = ["rocks"] - flow.request.set_form_urlencoded(frm) + flow.request.form_urlencoded = frm diff --git a/examples/modify_querystring.py b/examples/modify_querystring.py index b1abcc1e..7e3a068a 100644 --- a/examples/modify_querystring.py +++ b/examples/modify_querystring.py @@ -1,7 +1,7 @@ def request(context, flow): - q = flow.request.get_query() + q = flow.request.query if q: q["mitmproxy"] = ["rocks"] - flow.request.set_query(q) + flow.request.query = q diff --git a/examples/redirect_requests.py b/examples/redirect_requests.py index 530da200..cc642039 100644 --- a/examples/redirect_requests.py +++ b/examples/redirect_requests.py @@ -7,12 +7,12 @@ This example shows two ways to redirect flows to other destinations. def request(context, flow): - if flow.request.get_host(hostheader=True).endswith("example.com"): + if flow.request.pretty_host(hostheader=True).endswith("example.com"): resp = HTTPResponse( [1, 1], 200, "OK", ODictCaseless([["Content-Type", "text/html"]]), "helloworld") flow.reply(resp) - if flow.request.get_host(hostheader=True).endswith("example.org"): + if flow.request.pretty_host(hostheader=True).endswith("example.org"): flow.request.host = "mitmproxy.org" - flow.request.headers["Host"] = ["mitmproxy.org"] + flow.request.update_host_header() diff --git a/libmproxy/console/common.py b/libmproxy/console/common.py index 5cb3dd2a..104b7216 100644 --- a/libmproxy/console/common.py +++ b/libmproxy/console/common.py @@ -177,7 +177,7 @@ def format_flow(f, focus, extended=False, hostheader=False, padding=2): req_timestamp = f.request.timestamp_start, req_is_replay = f.request.is_replay, req_method = f.request.method, - req_url = f.request.get_url(hostheader=hostheader), + req_url = f.request.pretty_url(hostheader=hostheader), err_msg = f.error.msg if f.error else None, resp_code = f.response.code if f.response else None, diff --git a/libmproxy/console/flowview.py b/libmproxy/console/flowview.py index 3c63ac29..9063c3e1 100644 --- a/libmproxy/console/flowview.py +++ b/libmproxy/console/flowview.py @@ -528,7 +528,9 @@ class FlowView(common.WWrap): def set_url(self, url): request = self.flow.request - if not request.set_url(str(url)): + try: + request.url = str(url) + except ValueError: return "Invalid URL." self.master.refresh_flow(self.flow) @@ -552,17 +554,17 @@ class FlowView(common.WWrap): conn.headers = flow.ODictCaseless(lst) def set_query(self, lst, conn): - conn.set_query(flow.ODict(lst)) + conn.query = flow.ODict(lst) def set_path_components(self, lst, conn): - conn.set_path_components([i[0] for i in lst]) + conn.path_components = [i[0] for i in lst] def set_form(self, lst, conn): - conn.set_form_urlencoded(flow.ODict(lst)) + conn.form_urlencoded = flow.ODict(lst) def edit_form(self, conn): self.master.view_grideditor( - grideditor.URLEncodedFormEditor(self.master, conn.get_form_urlencoded().lst, self.set_form, conn) + grideditor.URLEncodedFormEditor(self.master, conn.form_urlencoded.lst, self.set_form, conn) ) def edit_form_confirm(self, key, conn): @@ -587,7 +589,7 @@ class FlowView(common.WWrap): c = self.master.spawn_editor(conn.content or "") conn.content = c.rstrip("\n") # what? elif part == "f": - if not conn.get_form_urlencoded() and conn.content: + if not conn.form_urlencoded and conn.content: self.master.prompt_onekey( "Existing body is not a URL-encoded form. Clear and edit?", [ @@ -602,13 +604,13 @@ class FlowView(common.WWrap): elif part == "h": self.master.view_grideditor(grideditor.HeaderEditor(self.master, conn.headers.lst, self.set_headers, conn)) elif part == "p": - p = conn.get_path_components() + p = conn.path_components p = [[i] for i in p] self.master.view_grideditor(grideditor.PathEditor(self.master, p, self.set_path_components, conn)) elif part == "q": - self.master.view_grideditor(grideditor.QueryEditor(self.master, conn.get_query().lst, self.set_query, conn)) + self.master.view_grideditor(grideditor.QueryEditor(self.master, conn.query.lst, self.set_query, conn)) elif part == "u" and self.state.view_flow_mode == common.VIEW_FLOW_REQUEST: - self.master.prompt_edit("URL", conn.get_url(), self.set_url) + self.master.prompt_edit("URL", conn.url, self.set_url) elif part == "m" and self.state.view_flow_mode == common.VIEW_FLOW_REQUEST: self.master.prompt_onekey("Method", self.method_options, self.edit_method) elif part == "c" and self.state.view_flow_mode == common.VIEW_FLOW_RESPONSE: diff --git a/libmproxy/dump.py b/libmproxy/dump.py index 8ecd56e7..72ab58a3 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -55,7 +55,7 @@ def str_request(f, showhost): c = f.client_conn.address.host else: c = "[replay]" - r = "%s %s %s"%(c, f.request.method, f.request.get_url(showhost, f)) + r = "%s %s %s"%(c, f.request.method, f.request.pretty_url(showhost)) if f.request.stickycookie: r = "[stickycookie] " + r return r diff --git a/libmproxy/filt.py b/libmproxy/filt.py index 925dbfbb..7d2bd737 100644 --- a/libmproxy/filt.py +++ b/libmproxy/filt.py @@ -208,7 +208,7 @@ class FDomain(_Rex): code = "d" help = "Domain" def __call__(self, f): - return bool(re.search(self.expr, f.request.get_host(False, f), re.IGNORECASE)) + return bool(re.search(self.expr, f.request.host, re.IGNORECASE)) class FUrl(_Rex): @@ -222,7 +222,7 @@ class FUrl(_Rex): return klass(*toks) def __call__(self, f): - return re.search(self.expr, f.request.get_url(False, f)) + return re.search(self.expr, f.request.url) class _Int(_Action): diff --git a/libmproxy/flow.py b/libmproxy/flow.py index df72878f..eeb53e81 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -259,8 +259,8 @@ class StickyCookieState: Returns a (domain, port, path) tuple. """ return ( - m["domain"] or f.request.get_host(False, f), - f.request.get_port(f), + m["domain"] or f.request.host, + f.request.port, m["path"] or "/" ) @@ -278,7 +278,7 @@ class StickyCookieState: c = Cookie.SimpleCookie(str(i)) m = c.values()[0] k = self.ckey(m, f) - if self.domain_match(f.request.get_host(False, f), k[0]): + if self.domain_match(f.request.host, k[0]): self.jar[self.ckey(m, f)] = m def handle_request(self, f): @@ -286,8 +286,8 @@ class StickyCookieState: if f.match(self.flt): for i in self.jar.keys(): match = [ - self.domain_match(f.request.get_host(False, f), i[0]), - f.request.get_port(f) == i[1], + self.domain_match(f.request.host, i[0]), + f.request.port == i[1], f.request.path.startswith(i[2]) ] if all(match): @@ -306,7 +306,7 @@ class StickyAuthState: self.hosts = {} def handle_request(self, f): - host = f.request.get_host(False, f) + host = f.request.host if "authorization" in f.request.headers: self.hosts[host] = f.request.headers["authorization"] elif f.match(self.flt): diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 9699c78a..253192dd 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -26,7 +26,7 @@ def get_line(fp): return line -def send_connect_request(conn, host, port): +def send_connect_request(conn, host, port, update_state=True): upstream_request = HTTPRequest("authority", "CONNECT", None, host, port, None, (1, 1), ODictCaseless(), "") conn.send(upstream_request._assemble()) @@ -36,6 +36,12 @@ def send_connect_request(conn, host, port): "Cannot establish SSL " + "connection with upstream proxy: \r\n" + str(resp._assemble())) + if update_state: + conn.state.append(("http", { + "state": "connect", + "host": host, + "port": port} + )) return resp @@ -405,7 +411,14 @@ class HTTPRequest(HTTPMessage): e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0] )] - def get_form_urlencoded(self): + def update_host_header(self): + """ + Update the host header to reflect the current target. + """ + self.headers["Host"] = [self.host] + + @property + def form_urlencoded(self): """ Retrieves the URL-encoded form data, returning an ODict object. Returns an empty ODict if there is no data or the content-type @@ -415,7 +428,8 @@ class HTTPRequest(HTTPMessage): return ODict(utils.urldecode(self.content)) return ODict([]) - def set_form_urlencoded(self, odict): + @form_urlencoded.setter + def form_urlencoded(self, odict): """ Sets the body to the URL-encoded form data, and adds the appropriate content-type header. Note that this will destory the @@ -426,16 +440,18 @@ class HTTPRequest(HTTPMessage): self.headers["Content-Type"] = [HDR_FORM_URLENCODED] self.content = utils.urlencode(odict.lst) - def get_path_components(self, f): + @property + def path_components(self): """ Returns the path components of the URL as a list of strings. Components are unquoted. """ - _, _, path, _, _, _ = urlparse.urlparse(self.get_url(False, f)) + _, _, path, _, _, _ = urlparse.urlparse(self.url) return [urllib.unquote(i) for i in path.split("/") if i] - def set_path_components(self, lst, f): + @path_components.setter + def path_components(self, lst): """ Takes a list of strings, and sets the path component of the URL. @@ -443,32 +459,34 @@ class HTTPRequest(HTTPMessage): """ lst = [urllib.quote(i, safe="") for i in lst] path = "/" + "/".join(lst) - scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.get_url(False, f)) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment]), f) + scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.url) + self.url = urlparse.urlunparse([scheme, netloc, path, params, query, fragment]) - def get_query(self, f): + @property + def query(self): """ Gets the request query string. Returns an ODict object. """ - _, _, _, _, query, _ = urlparse.urlparse(self.get_url(False, f)) + _, _, _, _, query, _ = urlparse.urlparse(self.url) if query: return ODict(utils.urldecode(query)) return ODict([]) - def set_query(self, odict, f): + @query.setter + def query(self, odict): """ Takes an ODict object, and sets the request query string. """ - scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.get_url(False, f)) + scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.url) query = utils.urlencode(odict.lst) - self.set_url(urlparse.urlunparse([scheme, netloc, path, params, query, fragment]), f) + self.url = urlparse.urlunparse([scheme, netloc, path, params, query, fragment]) - def get_host(self, hostheader, flow): + def pretty_host(self, hostheader): """ Heuristic to get the host of the request. - Note that get_host() does not always return the TCP destination of the request, - e.g. on a transparently intercepted request to an unrelated HTTP proxy. + Note that pretty_host() does not always return the TCP destination of the request, + e.g. if an upstream proxy is in place If hostheader is set to True, the Host: header will be used as additional (and preferred) data source. This is handy in transparent mode, where only the ip of the destination is known, but not the @@ -478,54 +496,27 @@ class HTTPRequest(HTTPMessage): if hostheader: host = self.headers.get_first("host") if not host: - if self.host: - host = self.host - else: - for s in flow.server_conn.state: - if s[0] == "http" and s[1]["state"] == "connect": - host = s[1]["host"] - break - if not host: - host = flow.server_conn.address.host + host = self.host host = host.encode("idna") return host - def get_scheme(self, flow): - """ - Returns the request port, either from the request itself or from the flow's server connection - """ - if self.scheme: - return self.scheme - if self.form_out == "authority": # On SSLed connections, the original CONNECT request is still unencrypted. - return "http" - return "https" if flow.server_conn.ssl_established else "http" - - def get_port(self, flow): - """ - Returns the request port, either from the request itself or from the flow's server connection - """ - if self.port: - return self.port - for s in flow.server_conn.state: - if s[0] == "http" and s[1].get("state") == "connect": - return s[1]["port"] - return flow.server_conn.address.port + def pretty_url(self, hostheader): + if self.form_out == "authority": # upstream proxy mode + return "%s:%s" % (self.pretty_host(hostheader), self.port) + return utils.unparse_url(self.scheme, + self.pretty_host(hostheader), + self.port, + self.path).encode('ascii') - def get_url(self, hostheader, flow): + @property + def url(self): """ Returns a URL string, constructed from the Request's URL components. - - If hostheader is True, we use the value specified in the request - Host header to construct the URL. """ - if self.form_out == "authority": # upstream proxy mode - return "%s:%s" % (self.get_host(hostheader, flow), self.get_port(flow)) - return utils.unparse_url(self.get_scheme(flow), - self.get_host(hostheader, flow), - self.get_port(flow), - self.path).encode('ascii') + return self.pretty_url(False) - def set_url(self, url, flow): + @url.setter + def url(self, url): """ Parses a URL specification, and updates the Request's information accordingly. @@ -534,32 +525,11 @@ class HTTPRequest(HTTPMessage): """ parts = http.parse_url(url) if not parts: - return False - scheme, host, port, path = parts - is_ssl = (True if scheme == "https" else False) - - self.path = path + raise ValueError("Invalid URL: %s" % url) + self.scheme, self.host, self.port, self.path = parts - if host != self.get_host(False, flow) or port != self.get_port(flow): - if flow.live: - flow.live.change_server((host, port), ssl=is_ssl) - else: - # There's not live server connection, we're just changing the attributes here. - flow.server_conn = ServerConnection((host, port), - proxy.AddressPriority.MANUALLY_CHANGED) - flow.server_conn.ssl_established = is_ssl - - # If this is an absolute request, replace the attributes on the request object as well. - if self.host: - self.host = host - if self.port: - self.port = port - if self.scheme: - self.scheme = scheme - - return True - - def get_cookies(self): + @property + def cookies(self): cookie_headers = self.headers.get("cookie") if not cookie_headers: return None @@ -756,7 +726,8 @@ class HTTPResponse(HTTPMessage): if c: self.headers["set-cookie"] = c - def get_cookies(self): + @property + def cookies(self): cookie_headers = self.headers.get("set-cookie") if not cookie_headers: return None @@ -816,7 +787,7 @@ class HTTPFlow(Flow): s = "<HTTPFlow" for a in ("request", "response", "error", "client_conn", "server_conn"): if getattr(self, a, False): - s += "\r\n %s = {flow.%s}" % (a,a) + s += "\r\n %s = {flow.%s}" % (a, a) s += ">" return s.format(flow=self) @@ -951,8 +922,7 @@ class HTTPHandler(ProtocolHandler): # sent through to the Master. flow.request = req request_reply = self.c.channel.ask("request", flow) - self.determine_server_address(flow, flow.request) # The inline script may have changed request.host - flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow + self.process_server_address(flow) # The inline script may have changed request.host if request_reply is None or request_reply == KILL: return False @@ -1049,7 +1019,7 @@ class HTTPHandler(ProtocolHandler): def handle_server_reconnect(self, state): if state["state"] == "connect": - send_connect_request(self.c.server_conn, state["host"], state["port"]) + send_connect_request(self.c.server_conn, state["host"], state["port"], update_state=False) else: # pragma: nocover raise RuntimeError("Unknown State: %s" % state["state"]) @@ -1115,14 +1085,30 @@ class HTTPHandler(ProtocolHandler): if not self.skip_authentication: self.authenticate(request) + # Determine .scheme, .host and .port attributes + # For absolute-form requests, they are directly given in the request. + # For authority-form requests, we only need to determine the request scheme. + # For relative-form requests, we need to determine host and port as well. + if not request.scheme: + request.scheme = "https" if flow.server_conn and flow.server_conn.ssl_established else "http" + if not request.host: + # Host/Port Complication: In upstream mode, use the server we CONNECTed to, + # not the upstream proxy. + if flow.server_conn: + for s in flow.server_conn.state: + if s[0] == "http" and s[1]["state"] == "connect": + request.host, request.port = s[1]["host"], s[1]["port"] + if not request.host and flow.server_conn: + request.host, request.port = flow.server_conn.address.host, flow.server_conn.address.port + + # Now we can process the request. if request.form_in == "authority": if self.c.client_conn.ssl_established: raise http.HttpError(400, "Must not CONNECT on already encrypted connection") if self.expected_form_in == "absolute": - if not self.c.config.get_upstream_server: - self.c.set_server_address((request.host, request.port), - proxy.AddressPriority.FROM_PROTOCOL) + if not self.c.config.get_upstream_server: # Regular mode + self.c.set_server_address((request.host, request.port)) flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow self.c.establish_server_connection() self.c.client_conn.send( @@ -1141,24 +1127,63 @@ class HTTPHandler(ProtocolHandler): self.ssl_upgrade() self.skip_authentication = True return True - else: + else: # upstream proxy mode return None + else: + pass # CONNECT should never occur if we don't expect absolute-form requests + elif request.form_in == self.expected_form_in: + + request.form_out = self.expected_form_out + if request.form_in == "absolute": if request.scheme != "http": raise http.HttpError(400, "Invalid request scheme: %s" % request.scheme) - self.determine_server_address(flow, request) - request.form_out = self.expected_form_out + if request.form_out == "relative": + self.c.set_server_address((request.host, request.port)) + flow.server_conn = self.c.server_conn + + return None raise http.HttpError(400, "Invalid HTTP request form (expected: %s, got: %s)" % (self.expected_form_in, request.form_in)) - def determine_server_address(self, flow, request): - if request.form_in == "absolute": - self.c.set_server_address((request.host, request.port), - proxy.AddressPriority.FROM_PROTOCOL) - flow.server_conn = self.c.server_conn # Update server_conn attribute on the flow + def process_server_address(self, flow): + # Depending on the proxy mode, server handling is entirely different + # We provide a mostly unified API to the user, which needs to be unfiddled here + # ( See also: https://github.com/mitmproxy/mitmproxy/issues/337 ) + address = netlib.tcp.Address((flow.request.host, flow.request.port)) + + ssl = (flow.request.scheme == "https") + + if self.c.config.http_form_in == self.c.config.http_form_out == "absolute": # Upstream Proxy mode + + # The connection to the upstream proxy may have a state we may need to take into account. + connected_to = None + for s in flow.server_conn.state: + if s[0] == "http" and s[1]["state"] == "connect": + connected_to = tcp.Address((s[1]["host"], s[1]["port"])) + + # We need to reconnect if the current flow either requires a (possibly impossible) + # change to the connection state, e.g. the host has changed but we already CONNECTed somewhere else. + needs_server_change = ( + ssl != self.c.server_conn.ssl_established + or + (connected_to and address != connected_to) # HTTP proxying is "stateless", CONNECT isn't. + ) + + if needs_server_change: + # force create new connection to the proxy server to reset state + self.live.change_server(self.c.server_conn.address, force=True) + if ssl: + send_connect_request(self.c.server_conn, address.host, address.port) + self.c.establish_ssl(server=True) + else: + # If we're not in upstream mode, we just want to update the host and possibly establish TLS. + self.live.change_server(address, ssl=ssl) # this is a no op if the addresses match. + + flow.server_conn = self.c.server_conn def authenticate(self, request): if self.c.config.authenticator: @@ -1184,7 +1209,9 @@ class RequestReplayThread(threading.Thread): r.form_out = self.config.http_form_out server_address, server_ssl = False, False - if self.config.get_upstream_server: + # If the flow is live, r.host is already the correct upstream server unless modified by a script. + # If modified by a script, we probably want to keep the modified destination. + if self.config.get_upstream_server and not self.flow.live: try: # this will fail in transparent mode upstream_info = self.config.get_upstream_server(self.flow.client_conn) @@ -1193,17 +1220,16 @@ class RequestReplayThread(threading.Thread): except proxy.ProxyError: pass if not server_address: - server_address = (r.get_host(False, self.flow), r.get_port(self.flow)) + server_address = (r.host, r.port) - server = ServerConnection(server_address, None) + server = ServerConnection(server_address) server.connect() - if server_ssl or r.get_scheme(self.flow) == "https": + if server_ssl or r.scheme == "https": if self.config.http_form_out == "absolute": # form_out == absolute -> forward mode -> send CONNECT - send_connect_request(server, r.get_host(), r.get_port()) + send_connect_request(server, r.host, r.port) r.form_out = "relative" - server.establish_ssl(self.config.clientcerts, - self.flow.server_conn.sni) + server.establish_ssl(self.config.clientcerts, sni=r.host) server.send(r._assemble()) self.flow.response = HTTPResponse.from_stream(server.rfile, r.method, body_size_limit=self.config.body_size_limit) diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py index a84b4061..416e6880 100644 --- a/libmproxy/protocol/primitives.py +++ b/libmproxy/protocol/primitives.py @@ -2,7 +2,6 @@ from __future__ import absolute_import import copy import netlib.tcp from .. import stateobject, utils, version -from ..proxy.primitives import AddressPriority from ..proxy.connection import ClientConnection, ServerConnection @@ -59,7 +58,7 @@ class Flow(stateobject.SimpleStateObject): """@type: ClientConnection""" self.server_conn = server_conn """@type: ServerConnection""" - self.live = live # Used by flow.request.set_url to change the server address + self.live = live """@type: LiveConnection""" self.error = None @@ -153,44 +152,48 @@ class LiveConnection(object): without requiring the expose the ConnectionHandler. """ def __init__(self, c): - self._c = c + self.c = c + self._backup_server_conn = None """@type: libmproxy.proxy.server.ConnectionHandler""" - def change_server(self, address, ssl, persistent_change=False): + def change_server(self, address, ssl=False, force=False, persistent_change=False): address = netlib.tcp.Address.wrap(address) - if address != self._c.server_conn.address: + if force or address != self.c.server_conn.address or ssl != self.c.server_conn.ssl_established: - self._c.log("Change server connection: %s:%s -> %s:%s" % ( - self._c.server_conn.address.host, - self._c.server_conn.address.port, + self.c.log("Change server connection: %s:%s -> %s:%s [persistent: %s]" % ( + self.c.server_conn.address.host, + self.c.server_conn.address.port, address.host, - address.port + address.port, + persistent_change ), "debug") - if not hasattr(self, "_backup_server_conn"): - self._backup_server_conn = self._c.server_conn - self._c.server_conn = None + if self._backup_server_conn: + self._backup_server_conn = self.c.server_conn + self.c.server_conn = None else: # This is at least the second temporary change. We can kill the current connection. - self._c.del_server_connection() + self.c.del_server_connection() - self._c.set_server_address(address, AddressPriority.MANUALLY_CHANGED) - self._c.establish_server_connection(ask=False) + self.c.set_server_address(address) + self.c.establish_server_connection(ask=False) if ssl: - self._c.establish_ssl(server=True) - if hasattr(self, "_backup_server_conn") and persistent_change: - del self._backup_server_conn + self.c.establish_ssl(server=True) + if persistent_change: + self._backup_server_conn = None def restore_server(self): - if not hasattr(self, "_backup_server_conn"): + # TODO: Similar to _backup_server_conn, introduce _cache_server_conn, which keeps the changed connection open + # This may be beneficial if a user is rewriting all requests from http to https or similar. + if not self._backup_server_conn: return - self._c.log("Restore original server connection: %s:%s -> %s:%s" % ( - self._c.server_conn.address.host, - self._c.server_conn.address.port, + self.c.log("Restore original server connection: %s:%s -> %s:%s" % ( + self.c.server_conn.address.host, + self.c.server_conn.address.port, self._backup_server_conn.address.host, self._backup_server_conn.address.port ), "debug") - self._c.del_server_connection() - self._c.server_conn = self._backup_server_conn - del self._backup_server_conn
\ No newline at end of file + self.c.del_server_connection() + self.c.server_conn = self._backup_server_conn + self._backup_server_conn = None
\ No newline at end of file diff --git a/libmproxy/proxy/connection.py b/libmproxy/proxy/connection.py index d99ffa9b..5c421557 100644 --- a/libmproxy/proxy/connection.py +++ b/libmproxy/proxy/connection.py @@ -72,9 +72,8 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): - def __init__(self, address, priority): + def __init__(self, address): tcp.TCPClient.__init__(self, address) - self.priority = priority self.state = [] # a list containing (conntype, state) tuples self.peername = None @@ -131,7 +130,7 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): @classmethod def _from_state(cls, state): - f = cls(tuple(), None) + f = cls(tuple()) f._load_state(state) return f diff --git a/libmproxy/proxy/primitives.py b/libmproxy/proxy/primitives.py index e09f23e4..8c674381 100644 --- a/libmproxy/proxy/primitives.py +++ b/libmproxy/proxy/primitives.py @@ -45,19 +45,6 @@ class TransparentUpstreamServerResolver(UpstreamServerResolver): return [ssl, ssl] + list(dst) -class AddressPriority(object): - """ - Enum that signifies the priority of the given address when choosing the destination host. - Higher is better (None < i) - """ - MANUALLY_CHANGED = 3 - """user changed the target address in the ui""" - FROM_SETTINGS = 2 - """upstream server from arguments (reverse proxy, upstream proxy or from transparent resolver)""" - FROM_PROTOCOL = 1 - """derived from protocol (e.g. absolute-form http requests)""" - - class Log: def __init__(self, msg, level="info"): self.msg = msg diff --git a/libmproxy/proxy/server.py b/libmproxy/proxy/server.py index 092eae54..58e386ab 100644 --- a/libmproxy/proxy/server.py +++ b/libmproxy/proxy/server.py @@ -5,7 +5,7 @@ import socket from OpenSSL import SSL from netlib import tcp -from .primitives import ProxyServerError, Log, ProxyError, AddressPriority +from .primitives import ProxyServerError, Log, ProxyError from .connection import ClientConnection, ServerConnection from ..protocol.handle import protocol_handler from .. import version @@ -76,7 +76,7 @@ class ConnectionHandler: client_ssl, server_ssl = False, False if self.config.get_upstream_server: upstream_info = self.config.get_upstream_server(self.client_conn.connection) - self.set_server_address(upstream_info[2:], AddressPriority.FROM_SETTINGS) + self.set_server_address(upstream_info[2:]) client_ssl, server_ssl = upstream_info[:2] if self.check_ignore_address(self.server_conn.address): self.log("Ignore host: %s:%s" % self.server_conn.address(), "info") @@ -129,27 +129,22 @@ class ConnectionHandler: else: return False - def set_server_address(self, address, priority): + def set_server_address(self, address): """ Sets a new server address with the given priority. Does not re-establish either connection or SSL handshake. """ address = tcp.Address.wrap(address) - if self.server_conn: - if self.server_conn.priority > priority: - self.log("Attempt to change server address, " - "but priority is too low (is: %s, got: %s)" % ( - self.server_conn.priority, priority), "debug") - return - if self.server_conn.address == address: - self.server_conn.priority = priority # Possibly increase priority - return + # Don't reconnect to the same destination. + if self.server_conn and self.server_conn.address == address: + return + if self.server_conn: self.del_server_connection() self.log("Set new server address: %s:%s" % (address.host, address.port), "debug") - self.server_conn = ServerConnection(address, priority) + self.server_conn = ServerConnection(address) def establish_server_connection(self, ask=True): """ @@ -212,12 +207,11 @@ class ConnectionHandler: def server_reconnect(self): address = self.server_conn.address had_ssl = self.server_conn.ssl_established - priority = self.server_conn.priority state = self.server_conn.state sni = self.sni self.log("(server reconnect follows)", "debug") self.del_server_connection() - self.set_server_address(address, priority) + self.set_server_address(address) self.establish_server_connection() for s in state: diff --git a/test/test_flow.py b/test/test_flow.py index 6e9464e7..4bc2391e 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -753,10 +753,10 @@ class TestRequest: def test_simple(self): f = tutils.tflow() r = f.request - u = r.get_url(False, f) - assert r.set_url(u, f) - assert not r.set_url("", f) - assert r.get_url(False, f) == u + u = r.url + r.url = u + tutils.raises(ValueError, setattr, r, "url", "") + assert r.url == u assert r._assemble() assert r.size() == len(r._assemble()) @@ -771,83 +771,81 @@ class TestRequest: tutils.raises("Cannot assemble flow with CONTENT_MISSING", r._assemble) def test_get_url(self): - f = tutils.tflow() - r = f.request + r = tutils.treq() - assert r.get_url(False, f) == "http://address:22/path" + assert r.url == "http://address:22/path" r.scheme = "https" - assert r.get_url(False, f) == "https://address:22/path" + assert r.url == "https://address:22/path" r.host = "host" r.port = 42 - assert r.get_url(False, f) == "https://host:42/path" + assert r.url == "https://host:42/path" r.host = "address" r.port = 22 - assert r.get_url(False, f) == "https://address:22/path" + assert r.url== "https://address:22/path" - assert r.get_url(True, f) == "https://address:22/path" + assert r.pretty_url(True) == "https://address:22/path" r.headers["Host"] = ["foo.com"] - assert r.get_url(False, f) == "https://address:22/path" - assert r.get_url(True, f) == "https://foo.com:22/path" + assert r.pretty_url(False) == "https://address:22/path" + assert r.pretty_url(True) == "https://foo.com:22/path" def test_path_components(self): - f = tutils.tflow() - r = f.request + r = tutils.treq() r.path = "/" - assert r.get_path_components(f) == [] + assert r.path_components == [] r.path = "/foo/bar" - assert r.get_path_components(f) == ["foo", "bar"] + assert r.path_components == ["foo", "bar"] q = flow.ODict() q["test"] = ["123"] - r.set_query(q, f) - assert r.get_path_components(f) == ["foo", "bar"] - - r.set_path_components([], f) - assert r.get_path_components(f) == [] - r.set_path_components(["foo"], f) - assert r.get_path_components(f) == ["foo"] - r.set_path_components(["/oo"], f) - assert r.get_path_components(f) == ["/oo"] + r.query = q + assert r.path_components == ["foo", "bar"] + + r.path_components = [] + assert r.path_components == [] + r.path_components = ["foo"] + assert r.path_components == ["foo"] + r.path_components = ["/oo"] + assert r.path_components == ["/oo"] assert "%2F" in r.path def test_getset_form_urlencoded(self): d = flow.ODict([("one", "two"), ("three", "four")]) r = tutils.treq(content=utils.urlencode(d.lst)) r.headers["content-type"] = [protocol.http.HDR_FORM_URLENCODED] - assert r.get_form_urlencoded() == d + assert r.form_urlencoded == d d = flow.ODict([("x", "y")]) - r.set_form_urlencoded(d) - assert r.get_form_urlencoded() == d + r.form_urlencoded = d + assert r.form_urlencoded == d r.headers["content-type"] = ["foo"] - assert not r.get_form_urlencoded() + assert not r.form_urlencoded def test_getset_query(self): h = flow.ODictCaseless() - f = tutils.tflow() - f.request.path = "/foo?x=y&a=b" - q = f.request.get_query(f) + r = tutils.treq() + r.path = "/foo?x=y&a=b" + q = r.query assert q.lst == [("x", "y"), ("a", "b")] - f.request.path = "/" - q = f.request.get_query(f) + r.path = "/" + q = r.query assert not q - f.request.path = "/?adsfa" - q = f.request.get_query(f) + r.path = "/?adsfa" + q = r.query assert q.lst == [("adsfa", "")] - f.request.path = "/foo?x=y&a=b" - assert f.request.get_query(f) - f.request.set_query(flow.ODict([]), f) - assert not f.request.get_query(f) + r.path = "/foo?x=y&a=b" + assert r.query + r.query = flow.ODict([]) + assert not r.query qv = flow.ODict([("a", "b"), ("c", "d")]) - f.request.set_query(qv, f) - assert f.request.get_query(f) == qv + r.query = qv + assert r.query == qv def test_anticache(self): h = flow.ODictCaseless() @@ -918,14 +916,14 @@ class TestRequest: h = flow.ODictCaseless() r = tutils.treq() r.headers = h - assert r.get_cookies() is None + assert r.cookies is None def test_get_cookies_single(self): h = flow.ODictCaseless() h["Cookie"] = ["cookiename=cookievalue"] r = tutils.treq() r.headers = h - result = r.get_cookies() + result = r.cookies assert len(result)==1 assert result['cookiename']==('cookievalue',{}) @@ -934,7 +932,7 @@ class TestRequest: h["Cookie"] = ["cookiename=cookievalue;othercookiename=othercookievalue"] r = tutils.treq() r.headers = h - result = r.get_cookies() + result = r.cookies assert len(result)==2 assert result['cookiename']==('cookievalue',{}) assert result['othercookiename']==('othercookievalue',{}) @@ -944,7 +942,7 @@ class TestRequest: h["Cookie"] = ["cookiename=coo=kievalue;othercookiename=othercookievalue"] r = tutils.treq() r.headers = h - result = r.get_cookies() + result = r.cookies assert len(result)==2 assert result['cookiename']==('coo=kievalue',{}) assert result['othercookiename']==('othercookievalue',{}) @@ -1054,14 +1052,14 @@ class TestResponse: h = flow.ODictCaseless() resp = tutils.tresp() resp.headers = h - assert not resp.get_cookies() + assert not resp.cookies def test_get_cookies_simple(self): h = flow.ODictCaseless() h["Set-Cookie"] = ["cookiename=cookievalue"] resp = tutils.tresp() resp.headers = h - result = resp.get_cookies() + result = resp.cookies assert len(result)==1 assert "cookiename" in result assert result["cookiename"] == ("cookievalue", {}) @@ -1071,7 +1069,7 @@ class TestResponse: h["Set-Cookie"] = ["cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly"] resp = tutils.tresp() resp.headers = h - result = resp.get_cookies() + result = resp.cookies assert len(result)==1 assert "cookiename" in result assert result["cookiename"][0] == "cookievalue" @@ -1086,7 +1084,7 @@ class TestResponse: h["Set-Cookie"] = ["cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/"] resp = tutils.tresp() resp.headers = h - result = resp.get_cookies() + result = resp.cookies assert len(result)==1 assert "cookiename" in result assert result["cookiename"][0] == "" @@ -1097,7 +1095,7 @@ class TestResponse: h["Set-Cookie"] = ["cookiename=cookievalue","othercookie=othervalue"] resp = tutils.tresp() resp.headers = h - result = resp.get_cookies() + result = resp.cookies assert len(result)==2 assert "cookiename" in result assert result["cookiename"] == ("cookievalue", {}) diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py index c2ff7b44..c76fa192 100644 --- a/test/test_protocol_http.py +++ b/test/test_protocol_http.py @@ -58,12 +58,12 @@ class TestHTTPRequest: tutils.raises("Invalid request form", r._assemble, "antiauthority") def test_set_url(self): - f = tutils.tflow(req=tutils.treq_absolute()) - f.request.set_url("https://otheraddress:42/ORLY", f) - assert f.request.scheme == "https" - assert f.request.host == "otheraddress" - assert f.request.port == 42 - assert f.request.path == "/ORLY" + r = tutils.treq_absolute() + r.url = "https://otheraddress:42/ORLY" + assert r.scheme == "https" + assert r.host == "otheraddress" + assert r.port == 42 + assert r.path == "/ORLY" class TestHTTPResponse: diff --git a/test/test_proxy.py b/test/test_proxy.py index 91e4954f..ad2bb2d7 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -23,7 +23,7 @@ class TestServerConnection: self.d.shutdown() def test_simple(self): - sc = ServerConnection((self.d.IFACE, self.d.port), None) + sc = ServerConnection((self.d.IFACE, self.d.port)) sc.connect() f = tutils.tflow() f.server_conn = sc @@ -35,7 +35,7 @@ class TestServerConnection: sc.finish() def test_terminate_error(self): - sc = ServerConnection((self.d.IFACE, self.d.port), None) + sc = ServerConnection((self.d.IFACE, self.d.port)) sc.connect() sc.connection = mock.Mock() sc.connection.recv = mock.Mock(return_value=False) diff --git a/test/test_server.py b/test/test_server.py index 6c0829f4..21d01f5a 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -330,17 +330,14 @@ class MasterRedirectRequest(tservers.TestMaster): def handle_request(self, f): request = f.request if request.path == "/p/201": - url = request.get_url(False, f) + url = request.url new = "http://127.0.0.1:%s/p/201" % self.redirect_port - request.set_url(new, f) - request.set_url(new, f) + request.url = new f.live.change_server(("127.0.0.1", self.redirect_port), False) - request.set_url(url, f) + request.url = url tutils.raises("SSL handshake error", f.live.change_server, ("127.0.0.1", self.redirect_port), True) - request.set_url(new, f) - request.set_url(url, f) - request.set_url(new, f) + request.url = new tservers.TestMaster.handle_request(self, f) def handle_response(self, f): |