From 57c653be5f8a6fe0d1785421faa6513ebd3d48c0 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 3 Aug 2011 22:38:23 +1200 Subject: Move all HTTP objects to flow.py That's Request, Response, ClientConnect, ClientDisconnect, Error, and Headers. --- libmproxy/flow.py | 568 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 545 insertions(+), 23 deletions(-) (limited to 'libmproxy/flow.py') diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 1043cb21..1afab895 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -2,8 +2,10 @@ This module provides more sophisticated flow tracking. These match requests with their responses, and provide filtering and interception facilities. """ -import subprocess, sys, json, hashlib, Cookie, cookielib -import proxy, threading, netstring, filt, script +import subprocess, sys, json, hashlib, Cookie, cookielib, base64, copy, re +import time +import netstring, filt, script, utils, encoding, proxy +from email.utils import parsedate_tz, formatdate, mktime_tz import controller, version class RunException(Exception): @@ -13,22 +15,542 @@ class RunException(Exception): self.errout = errout -# begin nocover -class RequestReplayThread(threading.Thread): - def __init__(self, flow, masterq): - self.flow, self.masterq = flow, masterq - threading.Thread.__init__(self) +class Headers: + def __init__(self, lst=None): + if lst: + self.lst = lst + else: + self.lst = [] + + def _kconv(self, s): + return s.lower() + + def __eq__(self, other): + return self.lst == other.lst + + def __getitem__(self, k): + ret = [] + k = self._kconv(k) + for i in self.lst: + if self._kconv(i[0]) == k: + ret.append(i[1]) + return ret + + def _filter_lst(self, k, lst): + new = [] + for i in lst: + if self._kconv(i[0]) != k: + new.append(i) + return new + + def __setitem__(self, k, hdrs): + k = self._kconv(k) + new = self._filter_lst(k, self.lst) + for i in hdrs: + new.append((k, i)) + self.lst = new + + def __delitem__(self, k): + self.lst = self._filter_lst(k, self.lst) + + def __contains__(self, k): + for i in self.lst: + if self._kconv(i[0]) == k: + return True + return False + + def add(self, key, value): + self.lst.append([key, str(value)]) + + def get_state(self): + return [tuple(i) for i in self.lst] + + @classmethod + def from_state(klass, state): + return klass([list(i) for i in state]) + + def copy(self): + lst = copy.deepcopy(self.lst) + return Headers(lst) + + def __repr__(self): + """ + Returns a string containing a formatted header string. + """ + headerElements = [] + for itm in self.lst: + headerElements.append(itm[0] + ": " + itm[1]) + headerElements.append("") + return "\r\n".join(headerElements) + + def match_re(self, expr): + """ + Match the regular expression against each header (key, value) pair. + """ + for k, v in self.lst: + s = "%s: %s"%(k, v) + if re.search(expr, s): + return True + return False + + def read(self, fp): + """ + Read a set of headers from a file pointer. Stop once a blank line + is reached. + """ + ret = [] + name = '' + while 1: + line = fp.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() + else: + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i+1:].strip() + ret.append([name, value]) + self.lst = ret + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both header keys + and values. Returns the number of replacements made. + """ + nlst, count = [], 0 + for i in self.lst: + k, c = re.subn(pattern, repl, i[0], *args, **kwargs) + count += c + v, c = re.subn(pattern, repl, i[1], *args, **kwargs) + count += c + nlst.append([k, v]) + self.lst = nlst + return count + + +class HTTPMsg(controller.Msg): + def decode(self): + """ + Alters Response object, decoding its content based on the current + Content-Encoding header and changing Content-Encoding header to + 'identity'. + """ + ce = self.headers["content-encoding"] + if not ce or ce[0] not in encoding.ENCODINGS: + return + self.content = encoding.decode( + ce[0], + self.content + ) + del self.headers["content-encoding"] + + def encode(self, e): + """ + Alters Response object, encoding its content with the specified + coding. This method should only be called on Responses with + Content-Encoding headers of 'identity'. + """ + self.content = encoding.encode(e, self.content) + self.headers["content-encoding"] = [e] + + +class Request(HTTPMsg): + FMT = '%s %s HTTP/1.1\r\n%s\r\n%s' + FMT_PROXY = '%s %s://%s:%s%s HTTP/1.1\r\n%s\r\n%s' + def __init__(self, client_conn, host, port, scheme, method, path, headers, content, timestamp=None): + self.client_conn = client_conn + self.host, self.port, self.scheme = host, port, scheme + self.method, self.path, self.headers, self.content = method, path, headers, content + self.timestamp = timestamp or utils.timestamp() + self.close = False + controller.Msg.__init__(self) + + # Have this request's cookies been modified by sticky cookies or auth? + self.stickycookie = False + self.stickyauth = False + + def anticache(self): + """ + Modifies this request to remove headers that might produce a cached + response. That is, we remove ETags and If-Modified-Since headers. + """ + delheaders = [ + "if-modified-since", + "if-none-match", + ] + for i in delheaders: + del self.headers[i] + + def anticomp(self): + """ + Modifies this request to remove headers that will compress the + resource's data. + """ + self.headers["accept-encoding"] = ["identity"] + + def constrain_encoding(self): + """ + Limits the permissible Accept-Encoding values, based on what we can + decode appropriately. + """ + if self.headers["accept-encoding"]: + self.headers["accept-encoding"] = [', '.join([ + e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0] + ])] + + def set_replay(self): + self.client_conn = None + + def is_replay(self): + if self.client_conn: + return False + else: + return True + + def load_state(self, state): + if state["client_conn"]: + if self.client_conn: + self.client_conn.load_state(state["client_conn"]) + else: + self.client_conn = ClientConnect.from_state(state["client_conn"]) + else: + self.client_conn = None + self.host = state["host"] + self.port = state["port"] + self.scheme = state["scheme"] + self.method = state["method"] + self.path = state["path"] + self.headers = Headers.from_state(state["headers"]) + self.content = base64.decodestring(state["content"]) + self.timestamp = state["timestamp"] + + def get_state(self): + return dict( + client_conn = self.client_conn.get_state() if self.client_conn else None, + host = self.host, + port = self.port, + scheme = self.scheme, + method = self.method, + path = self.path, + headers = self.headers.get_state(), + content = base64.encodestring(self.content), + timestamp = self.timestamp, + ) + + @classmethod + def from_state(klass, state): + return klass( + ClientConnect.from_state(state["client_conn"]), + str(state["host"]), + state["port"], + str(state["scheme"]), + str(state["method"]), + str(state["path"]), + Headers.from_state(state["headers"]), + base64.decodestring(state["content"]), + state["timestamp"] + ) + + def __hash__(self): + return id(self) + + def __eq__(self, other): + return self.get_state() == other.get_state() + + def copy(self): + c = copy.copy(self) + c.headers = self.headers.copy() + return c + + def hostport(self): + if (self.port, self.scheme) in [(80, "http"), (443, "https")]: + host = self.host + else: + host = "%s:%s"%(self.host, self.port) + return host + + def url(self): + return "%s://%s%s"%(self.scheme, self.hostport(), self.path) + + def set_url(self, url): + parts = utils.parse_url(url) + if not parts: + return False + self.scheme, self.host, self.port, self.path = parts + return True + + def is_response(self): + return False + + def assemble(self, _proxy = False): + """ + Assembles the request for transmission to the server. We make some + modifications to make sure interception works properly. + """ + headers = self.headers.copy() + utils.del_all( + headers, + [ + 'proxy-connection', + 'keep-alive', + 'connection', + 'content-length', + 'transfer-encoding' + ] + ) + if not 'host' in headers: + headers["host"] = [self.hostport()] + content = self.content + if content is not None: + headers["content-length"] = [str(len(content))] + else: + content = "" + if self.close: + headers["connection"] = ["close"] + if not _proxy: + return self.FMT % (self.method, self.path, str(headers), content) + else: + return self.FMT_PROXY % (self.method, self.scheme, self.host, self.port, self.path, str(headers), content) + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both the headers + and the body of the request. Returns the number of replacements + made. + """ + self.content, c = re.subn(pattern, repl, self.content, *args, **kwargs) + self.path, pc = re.subn(pattern, repl, self.path, *args, **kwargs) + c += pc + c += self.headers.replace(pattern, repl, *args, **kwargs) + return c + + +class Response(HTTPMsg): + FMT = '%s\r\n%s\r\n%s' + def __init__(self, request, code, msg, headers, content, timestamp=None): + self.request = request + self.code, self.msg = code, msg + self.headers, self.content = headers, content + self.timestamp = timestamp or utils.timestamp() + controller.Msg.__init__(self) + self.replay = False + + def _refresh_cookie(self, c, delta): + """ + Takes a cookie string c and a time delta in seconds, and returns + a refreshed cookie string. + """ + c = Cookie.SimpleCookie(str(c)) + for i in c.values(): + if "expires" in i: + d = parsedate_tz(i["expires"]) + if d: + d = mktime_tz(d) + delta + i["expires"] = formatdate(d) + else: + # This can happen when the expires tag is invalid. + # reddit.com sends a an expires tag like this: "Thu, 31 Dec + # 2037 23:59:59 GMT", which is valid RFC 1123, but not + # strictly correct according tot he cookie spec. Browsers + # appear to parse this tolerantly - maybe we should too. + # For now, we just ignore this. + del i["expires"] + return c.output(header="").strip() + + def refresh(self, now=None): + """ + This fairly complex and heuristic function refreshes a server + response for replay. + + - It adjusts date, expires and last-modified headers. + - It adjusts cookie expiration. + """ + if not now: + now = time.time() + delta = now - self.timestamp + refresh_headers = [ + "date", + "expires", + "last-modified", + ] + for i in refresh_headers: + if i in self.headers: + d = parsedate_tz(self.headers[i][0]) + if d: + new = mktime_tz(d) + delta + self.headers[i] = [formatdate(new)] + c = [] + for i in self.headers["set-cookie"]: + c.append(self._refresh_cookie(i, delta)) + if c: + self.headers["set-cookie"] = c + + def set_replay(self): + self.replay = True + + def is_replay(self): + return self.replay + + def load_state(self, state): + self.code = state["code"] + self.msg = state["msg"] + self.headers = Headers.from_state(state["headers"]) + self.content = base64.decodestring(state["content"]) + self.timestamp = state["timestamp"] + + def get_state(self): + return dict( + code = self.code, + msg = self.msg, + headers = self.headers.get_state(), + timestamp = self.timestamp, + content = base64.encodestring(self.content) + ) + + @classmethod + def from_state(klass, request, state): + return klass( + request, + state["code"], + str(state["msg"]), + Headers.from_state(state["headers"]), + base64.decodestring(state["content"]), + state["timestamp"], + ) + + def __eq__(self, other): + return self.get_state() == other.get_state() + + def copy(self): + c = copy.copy(self) + c.headers = self.headers.copy() + return c + + def is_response(self): + return True + + def assemble(self): + """ + Assembles the response for transmission to the client. We make some + modifications to make sure interception works properly. + """ + headers = self.headers.copy() + utils.del_all( + headers, + ['proxy-connection', 'connection', 'keep-alive', 'transfer-encoding'] + ) + content = self.content + if content is not None: + headers["content-length"] = [str(len(content))] + else: + content = "" + if self.request.client_conn.close: + headers["connection"] = ["close"] + proto = "HTTP/1.1 %s %s"%(self.code, str(self.msg)) + data = (proto, str(headers), content) + return self.FMT%data + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both the headers + and the body of the response. Returns the number of replacements + made. + """ + self.content, c = re.subn(pattern, repl, self.content, *args, **kwargs) + c += self.headers.replace(pattern, repl, *args, **kwargs) + return c + + +class ClientDisconnect(controller.Msg): + def __init__(self, client_conn): + controller.Msg.__init__(self) + self.client_conn = client_conn + + +class ClientConnect(controller.Msg): + def __init__(self, address): + """ + address is an (address, port) tuple, or None if this connection has + been replayed from within mitmproxy. + """ + self.address = address + self.close = False + self.requestcount = 0 + self.connection_error = None + controller.Msg.__init__(self) + + def __eq__(self, other): + return self.get_state() == other.get_state() + + def load_state(self, state): + self.address = state + + def get_state(self): + return list(self.address) if self.address else None + + @classmethod + def from_state(klass, state): + if state: + return klass(state) + else: + return None + + def copy(self): + return copy.copy(self) + + +class Error(controller.Msg): + def __init__(self, request, msg, timestamp=None): + self.request, self.msg = request, msg + self.timestamp = timestamp or utils.timestamp() + controller.Msg.__init__(self) + + def load_state(self, state): + self.msg = state["msg"] + self.timestamp = state["timestamp"] + + def copy(self): + return copy.copy(self) + + def get_state(self): + return dict( + msg = self.msg, + timestamp = self.timestamp, + ) + + @classmethod + def from_state(klass, state): + return klass( + None, + state["msg"], + state["timestamp"], + ) + + def __eq__(self, other): + return self.get_state() == other.get_state() + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both the headers + and the body of the request. Returns the number of replacements + made. + """ + self.msg, c = re.subn(pattern, repl, self.msg, *args, **kwargs) + return c + + + + + + + + + - def run(self): - try: - server = proxy.ServerConnection(self.flow.request) - server.send_request(self.flow.request) - response = server.read_response() - response.send(self.masterq) - except proxy.ProxyError, v: - err = proxy.Error(self.flow.request, v.msg) - err.send(self.masterq) -# end nocover class ClientPlaybackState: @@ -217,13 +739,13 @@ class Flow: if self.request: self.request.load_state(state["request"]) else: - self.request = proxy.Request.from_state(state["request"]) + self.request = Request.from_state(state["request"]) if state["response"]: if self.response: self.response.load_state(state["response"]) else: - self.response = proxy.Response.from_state(self.request, state["response"]) + self.response = Response.from_state(self.request, state["response"]) else: self.response = None @@ -231,7 +753,7 @@ class Flow: if self.error: self.error.load_state(state["error"]) else: - self.error = proxy.Error.from_state(state["error"]) + self.error = Error.from_state(state["error"]) else: self.error = None @@ -261,7 +783,7 @@ class Flow: return True def kill(self, master): - self.error = proxy.Error(self.request, "Connection killed") + self.error = Error(self.request, "Connection killed") if self.request and not self.request.acked: self.request.ack(None) elif self.response and not self.response.acked: @@ -519,7 +1041,7 @@ class FlowMaster(controller.Master): rflow = self.server_playback.next_flow(flow) if not rflow: return None - response = proxy.Response.from_state(flow.request, rflow.response.get_state()) + response = Response.from_state(flow.request, rflow.response.get_state()) response.set_replay() flow.response = response if self.refresh_server_playback: @@ -594,7 +1116,7 @@ class FlowMaster(controller.Master): f.response = None f.error = None self.process_new_request(f) - rt = RequestReplayThread(f, self.masterq) + rt = proxy.RequestReplayThread(f, self.masterq) rt.start() #end nocover -- cgit v1.2.3