From 5da27a9905302a5e43fdf4db8a7b7b784544bed2 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 19 Feb 2011 17:00:24 +1300 Subject: Refactor Flow primitives to remove HTTP1.0 assumption. This is a big patch removing the assumption that there's one connection per Request/Response pair. It touches pretty much every part of mitmproxy, so expect glitches until everything is ironed out. --- libmproxy/console.py | 20 ++++---- libmproxy/dump.py | 14 ++---- libmproxy/flow.py | 73 +++++++++++++---------------- libmproxy/proxy.py | 43 ++++++++++------- test/test_console.py | 32 ++++--------- test/test_dump.py | 2 +- test/test_filt.py | 2 +- test/test_flow.py | 129 ++++++++++++++++++++------------------------------- test/test_proxy.py | 10 ++-- test/utils.py | 6 +-- 10 files changed, 139 insertions(+), 192 deletions(-) diff --git a/libmproxy/console.py b/libmproxy/console.py index 895974d2..815eebdf 100644 --- a/libmproxy/console.py +++ b/libmproxy/console.py @@ -66,7 +66,7 @@ def format_flow(f, focus, extended=False, padding=2): f.request.url(), ), ] - if f.response or f.error or f.is_replay(): + if f.response or f.error or f.request.is_replay(): tsr = f.response or f.error if extended and tsr: ts = ("highlight", utils.format_timestamp(tsr.timestamp) + " ") @@ -77,7 +77,7 @@ def format_flow(f, focus, extended=False, padding=2): txt.append(("text", ts)) txt.append(" "*(padding+2)) met = "" - if f.is_replay(): + if f.request.is_replay(): txt.append(("method", "[replay] ")) elif f.modified(): txt.append(("method", "[edited] ")) @@ -715,17 +715,13 @@ class ConsoleState(flow.State): self.last_script = "" self.last_saveload = "" - def add_browserconnect(self, f): - flow.State.add_browserconnect(self, f) + def add_request(self, req): + f = flow.State.add_request(self, req) if self.focus is None: self.set_focus(0) else: self.set_focus(self.focus + 1) - - def add_request(self, req): - if self.focus is None: - self.set_focus(0) - return flow.State.add_request(self, req) + return f def add_response(self, resp): if self.store is not None: @@ -1305,7 +1301,7 @@ class ConsoleMaster(flow.FlowMaster): def process_flow(self, f, r): if f.match(self.state.beep): urwid.curses_display.curses.beep() - if f.match(self.state.intercept) and not f.is_replay(): + if f.match(self.state.intercept) and not f.request.is_replay(): f.intercept() else: r.ack() @@ -1313,8 +1309,8 @@ class ConsoleMaster(flow.FlowMaster): self.refresh_connection(f) # Handlers - def handle_clientconnection(self, r): - f = flow.FlowMaster.handle_clientconnection(self, r) + def handle_clientconnect(self, r): + f = flow.FlowMaster.handle_clientconnect(self, r) if f: self.sync_list_view() diff --git a/libmproxy/dump.py b/libmproxy/dump.py index 372e6ef6..f6a7ae7e 100644 --- a/libmproxy/dump.py +++ b/libmproxy/dump.py @@ -38,14 +38,6 @@ class DumpMaster(flow.FlowMaster): except IOError, v: raise DumpError(v.strerror) - def handle_clientconnection(self, r): - flow.FlowMaster.handle_clientconnection(self, r) - r.ack() - - def handle_error(self, r): - flow.FlowMaster.handle_error(self, r) - r.ack() - def _runscript(self, f, script): try: ret = f.run_script(script) @@ -80,12 +72,12 @@ class DumpMaster(flow.FlowMaster): return sz = utils.pretty_size(len(f.response.content)) if self.o.verbosity == 1: - print >> self.outfile, f.client_conn.address[0], + print >> self.outfile, f.request.client_conn.address[0], print >> self.outfile, f.request.short() print >> self.outfile, " <<", print >> self.outfile, f.response.short(), sz elif self.o.verbosity == 2: - print >> self.outfile, f.client_conn.address[0], + print >> self.outfile, f.request.client_conn.address[0], print >> self.outfile, f.request.short() print >> self.outfile, self.indent(4, f.request.headers) print >> self.outfile @@ -93,7 +85,7 @@ class DumpMaster(flow.FlowMaster): print >> self.outfile, self.indent(4, f.response.headers) print >> self.outfile, "\n" elif self.o.verbosity == 3: - print >> self.outfile, f.client_conn.address[0], + print >> self.outfile, f.request.client_conn.address[0], print >> self.outfile, f.request.short() print >> self.outfile, self.indent(4, f.request.headers) if utils.isBin(f.request.content): diff --git a/libmproxy/flow.py b/libmproxy/flow.py index cea0ca1c..d9df7a1a 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -32,9 +32,9 @@ class ReplayThread(threading.Thread): class Flow: - def __init__(self, client_conn): - self.client_conn = client_conn - self.request, self.response, self.error = None, None, None + def __init__(self, request): + self.request = request + self.response, self.error = None, None self.intercepting = False self._backup = None @@ -90,7 +90,6 @@ class Flow: request = self.request.get_state() if self.request else None, response = self.response.get_state() if self.response else None, error = self.error.get_state() if self.error else None, - client_conn = self.client_conn.get_state() ) if nobackup: d["backup"] = None @@ -99,10 +98,8 @@ class Flow: return d def load_state(self, state): - self.client_conn = proxy.ClientConnection.from_state(state["client_conn"]) self._backup = state["backup"] - if state["request"]: - self.request = proxy.Request.from_state(self.client_conn, state["request"]) + self.request = proxy.Request.from_state(state["request"]) if state["response"]: self.response = proxy.Response.from_state(self.request, state["response"]) if state["error"]: @@ -141,9 +138,6 @@ class Flow: return pattern(self.request) return False - def is_replay(self): - return self.client_conn.is_replay() - def kill(self): if self.request and not self.request.acked: self.request.ack(None) @@ -165,35 +159,43 @@ class Flow: class State: def __init__(self): + self.client_connections = [] self.flow_map = {} self.flow_list = [] + # These are compiled filt expressions: self.limit = None self.intercept = None - def add_browserconnect(self, f): + def clientconnect(self, cc): + if not isinstance(cc, proxy.ClientConnect): + assert False + self.client_connections.append(cc) + + def clientdisconnect(self, dc): """ Start a browser connection. """ - self.flow_list.insert(0, f) - self.flow_map[f.client_conn] = f + self.client_connections.remove(dc.client_conn) def add_request(self, req): """ Add a request to the state. Returns the matching flow. """ - f = self.flow_map.get(req.client_conn) - if not f: - f = Flow(req.client_conn) - self.add_browserconnect(f) - f.request = req + if not isinstance(req, proxy.Request): + assert False + f = Flow(req) + self.flow_list.insert(0, f) + self.flow_map[req] = f return f def add_response(self, resp): """ Add a response to the state. Returns the matching flow. """ - f = self.flow_map.get(resp.request.client_conn) + if not isinstance(resp, proxy.Response): + assert False + f = self.flow_map.get(resp.request) if not f: return False f.response = resp @@ -204,7 +206,7 @@ class State: Add an error response to the state. Returns the matching flow, or None if there isn't one. """ - f = self.flow_map.get(err.client_conn) + f = self.flow_map.get(err.flow.request) if not f: return None f.error = err @@ -213,7 +215,7 @@ class State: def load_flows(self, flows): self.flow_list.extend(flows) for i in flows: - self.flow_map[i.client_conn] = i + self.flow_map[i.request] = i def set_limit(self, limit): """ @@ -229,27 +231,17 @@ class State: return tuple(self.flow_list[:]) def get_client_conn(self, itm): - if isinstance(itm, proxy.ClientConnection): + if isinstance(itm, proxy.ClientConnect): return itm elif hasattr(itm, "client_conn"): return itm.client_conn elif hasattr(itm, "request"): return itm.request.client_conn - def lookup(self, itm): - """ - Checks for matching client_conn, using a Flow, Replay Connection, - ClientConnection, Request, Response or Error object. Returns None - if not found. - """ - client_conn = self.get_client_conn(itm) - return self.flow_map.get(client_conn) - def delete_flow(self, f): if not f.intercepting: - c = self.get_client_conn(f) - if c in self.flow_map: - del self.flow_map[c] + if f.request in self.flow_map: + del self.flow_map[f.request] self.flow_list.remove(f) return True return False @@ -280,7 +272,7 @@ class State: if f.request: f.backup() conn = self.get_client_conn(f) - f.client_conn.set_replay() + f.request.set_replay() if f.request.content: f.request.headers["content-length"] = [str(len(f.request.content))] f.response = None @@ -295,12 +287,13 @@ class FlowMaster(controller.Master): controller.Master.__init__(self, server) self.state = state - # Handlers - def handle_clientconnection(self, r): - f = Flow(r) - self.state.add_browserconnect(f) + def handle_clientconnect(self, r): + self.state.clientconnect(r) + r.ack() + + def handle_clientdisconnect(self, r): + self.state.clientdisconnect(r) r.ack() - return f def handle_error(self, r): f = self.state.add_error(r) diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 54ff2ec3..1c4d4d71 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -136,11 +136,21 @@ class Request(controller.Msg): self.close = False controller.Msg.__init__(self) + def set_replay(self): + self.client_conn = None + + def is_replay(self): + if self.client_conn: + return False + else: + return True + def is_cached(self): return False def get_state(self): return dict( + client_conn = self.client_conn.get_state(), host = self.host, port = self.port, scheme = self.scheme, @@ -152,9 +162,9 @@ class Request(controller.Msg): ) @classmethod - def from_state(klass, client_conn, state): + def from_state(klass, state): return klass( - client_conn, + ClientConnect.from_state(state["client_conn"]), state["host"], state["port"], state["scheme"], @@ -165,6 +175,9 @@ class Request(controller.Msg): state["timestamp"] ) + def __hash__(self): + return id(self) + def __eq__(self, other): return self.get_state() == other.get_state() @@ -296,7 +309,13 @@ class Response(controller.Msg): return self.FMT%data -class ClientConnection(controller.Msg): +class ClientDisconnect(controller.Msg): + def __init__(self, client_conn): + controller.Msg.__init__(self) + self.client_conn = client_conn + + +class ClientConnect(controller.Msg): def __init__(self, address): """ address is an (address, port) tuple, or None if this connection has @@ -313,22 +332,13 @@ class ClientConnection(controller.Msg): def from_state(klass, state): return klass(state) - def set_replay(self): - self.address = None - - def is_replay(self): - if self.address: - return False - else: - return True - def copy(self): return copy.copy(self) class Error(controller.Msg): - def __init__(self, client_conn, msg, timestamp=None): - self.client_conn, self.msg = client_conn, msg + def __init__(self, flow, msg, timestamp=None): + self.flow, self.msg = flow, msg self.timestamp = timestamp or time.time() controller.Msg.__init__(self) @@ -453,11 +463,12 @@ class ProxyHandler(SocketServer.StreamRequestHandler): SocketServer.StreamRequestHandler.__init__(self, request, client_address, server) def handle(self): - cc = ClientConnection(self.client_address) + cc = ClientConnect(self.client_address) cc.send(self.mqueue) while not cc.close: self.handle_request(cc) - cc = cc.copy() + cd = ClientDisconnect(cc) + cd.send(self.mqueue) self.finish() def handle_request(self, cc): diff --git a/test/test_console.py b/test/test_console.py index 93312824..6baab4ba 100644 --- a/test/test_console.py +++ b/test/test_console.py @@ -10,11 +10,9 @@ class uState(libpry.AutoTree): connect -> request -> response """ - bc = proxy.ClientConnection(("address", 22)) c = console.ConsoleState() - f = flow.Flow(bc) - c.add_browserconnect(f) - assert c.lookup(bc) + f = self._add_request(c) + assert f.request in c.flow_map assert c.get_focus() == (f, 0) def test_focus(self): @@ -24,18 +22,14 @@ class uState(libpry.AutoTree): connect -> request -> response """ c = console.ConsoleState() + f = self._add_request(c) - bc = proxy.ClientConnection(("address", 22)) - f = flow.Flow(bc) - c.add_browserconnect(f) assert c.get_focus() == (f, 0) assert c.get_from_pos(0) == (f, 0) assert c.get_from_pos(1) == (None, None) assert c.get_next(0) == (None, None) - bc2 = proxy.ClientConnection(("address", 22)) - f2 = flow.Flow(bc2) - c.add_browserconnect(f2) + f2 = self._add_request(c) assert c.get_focus() == (f, 1) assert c.get_next(0) == (f, 1) assert c.get_prev(1) == (f2, 0) @@ -52,25 +46,14 @@ class uState(libpry.AutoTree): assert c.get_focus() == (None, None) def _add_request(self, state): - f = utils.tflow() - state.add_browserconnect(f) - q = utils.treq(f.client_conn) - state.add_request(q) - return f + r = utils.treq() + return state.add_request(r) def _add_response(self, state): f = self._add_request(state) r = utils.tresp(f.request) state.add_response(r) - def test_add_request(self): - c = console.ConsoleState() - f = utils.tflow() - c.add_browserconnect(f) - q = utils.treq(f.client_conn) - c.focus = None - assert c.add_request(q) - def test_add_response(self): c = console.ConsoleState() f = self._add_request(c) @@ -118,11 +101,12 @@ class uformat_flow(libpry.AutoTree): assert ('method', '[edited] ') in console.format_flow(f, True) assert ('method', '[edited] ') in console.format_flow(f, True, True) - f.client_conn = proxy.ClientConnection(None) + f.request.set_replay() assert ('method', '[replay] ') in console.format_flow(f, True) assert ('method', '[replay] ') in console.format_flow(f, True, True) + class uPathCompleter(libpry.AutoTree): def test_lookup_construction(self): c = console._PathCompleter() diff --git a/test/test_dump.py b/test/test_dump.py index 978bf138..7b223645 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -10,7 +10,7 @@ class uDumpMaster(libpry.AutoTree): req = utils.treq() cc = req.client_conn resp = utils.tresp(req) - m.handle_clientconnection(cc) + m.handle_clientconnect(cc) m.handle_request(req) m.handle_response(resp) diff --git a/test/test_filt.py b/test/test_filt.py index 6f8579d3..791b9b39 100644 --- a/test/test_filt.py +++ b/test/test_filt.py @@ -72,7 +72,7 @@ class uParsing(libpry.AutoTree): class uMatching(libpry.AutoTree): def req(self): - conn = proxy.ClientConnection(("one", 2222)) + conn = proxy.ClientConnect(("one", 2222)) headers = utils.Headers() headers["header"] = ["qvalue"] return proxy.Request( diff --git a/test/test_flow.py b/test/test_flow.py index 35d336e8..3998943c 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -46,38 +46,6 @@ class uFlow(libpry.AutoTree): state = f.get_state() assert f == flow.Flow.from_state(state) - def test_simple(self): - f = utils.tflow() - assert console.format_flow(f, True) - assert console.format_flow(f, False) - - f.request = utils.treq() - assert console.format_flow(f, True) - assert console.format_flow(f, False) - - f.response = utils.tresp() - f.response.headers["content-type"] = ["text/html"] - assert console.format_flow(f, True) - assert console.format_flow(f, False) - f.response.code = 404 - assert console.format_flow(f, True) - assert console.format_flow(f, False) - - assert console.format_flow(f, True) - assert console.format_flow(f, False) - - f.client_conn.set_replay() - assert console.format_flow(f, True) - assert console.format_flow(f, False) - - f.response = None - assert console.format_flow(f, True) - assert console.format_flow(f, False) - - f.error = proxy.Error(200, "test") - assert console.format_flow(f, True) - assert console.format_flow(f, False) - def test_kill(self): f = utils.tflow() f.request = utils.treq() @@ -115,10 +83,10 @@ class uFlow(libpry.AutoTree): class uState(libpry.AutoTree): def test_backup(self): - bc = proxy.ClientConnection(("address", 22)) + bc = proxy.ClientConnect(("address", 22)) c = flow.State() - f = flow.Flow(bc) - c.add_browserconnect(f) + req = utils.treq() + f = c.add_request(req) f.backup() c.revert(f) @@ -129,92 +97,98 @@ class uState(libpry.AutoTree): connect -> request -> response """ - bc = proxy.ClientConnection(("address", 22)) + bc = proxy.ClientConnect(("address", 22)) c = flow.State() - f = flow.Flow(bc) - c.add_browserconnect(f) - assert c.lookup(bc) + c.clientconnect(bc) + assert len(c.client_connections) == 1 req = utils.treq(bc) - assert c.add_request(req) + f = c.add_request(req) + assert f assert len(c.flow_list) == 1 - assert c.lookup(req) + assert c.flow_map.get(req) newreq = utils.treq() assert c.add_request(newreq) - assert c.lookup(newreq) + assert c.flow_map.get(newreq) resp = utils.tresp(req) assert c.add_response(resp) assert len(c.flow_list) == 2 - assert c.lookup(resp) + assert c.flow_map.get(resp.request) newresp = utils.tresp() assert not c.add_response(newresp) - assert not c.lookup(newresp) + assert not c.flow_map.get(newresp.request) + + dc = proxy.ClientDisconnect(bc) + c.clientdisconnect(dc) + assert not c.client_connections def test_err(self): - bc = proxy.ClientConnection(("address", 22)) + bc = proxy.ClientConnect(("address", 22)) c = flow.State() - f = flow.Flow(bc) - c.add_browserconnect(f) - e = proxy.Error(bc, "message") + req = utils.treq() + f = c.add_request(req) + e = proxy.Error(f, "message") assert c.add_error(e) - e = proxy.Error(proxy.ClientConnection(("address", 22)), "message") + e = proxy.Error(utils.tflow(), "message") assert not c.add_error(e) def test_view(self): c = flow.State() - f = utils.tflow() - c.add_browserconnect(f) + req = utils.treq() + c.clientconnect(req.client_conn) + assert len(c.view) == 0 + + f = c.add_request(req) assert len(c.view) == 1 - c.set_limit(filt.parse("~q")) + + c.set_limit(filt.parse("~s")) assert len(c.view) == 0 + resp = utils.tresp(req) + c.add_response(resp) + assert len(c.view) == 1 c.set_limit(None) + assert len(c.view) == 1 - - f = utils.tflow() - req = utils.treq(f.client_conn) - c.add_browserconnect(f) + req = utils.treq() + c.clientconnect(req.client_conn) c.add_request(req) assert len(c.view) == 2 c.set_limit(filt.parse("~q")) assert len(c.view) == 1 c.set_limit(filt.parse("~s")) - assert len(c.view) == 0 + assert len(c.view) == 1 def _add_request(self, state): - f = utils.tflow() - state.add_browserconnect(f) - q = utils.treq(f.client_conn) - state.add_request(q) + req = utils.treq() + f = state.add_request(req) return f def _add_response(self, state): - f = self._add_request(state) - r = utils.tresp(f.request) - state.add_response(r) + req = utils.treq() + f = state.add_request(req) + resp = utils.tresp(req) + state.add_response(resp) def _add_error(self, state): - f = utils.tflow() - f.error = proxy.Error(None, "msg") - state.add_browserconnect(f) - q = utils.treq(f.client_conn) - state.add_request(q) + req = utils.treq() + f = state.add_request(req) + f.error = proxy.Error(f, "msg") def test_kill_flow(self): c = flow.State() - f = utils.tflow() - c.add_browserconnect(f) + req = utils.treq() + f = c.add_request(req) c.kill_flow(f) assert not c.flow_list def test_clear(self): c = flow.State() - f = utils.tflow() - c.add_browserconnect(f) + f = self._add_request(c) f.intercepting = True c.clear() @@ -265,15 +239,12 @@ class uFlowMaster(libpry.AutoTree): def test_one(self): s = flow.State() f = flow.FlowMaster(None, s) - req = utils.treq() - f.handle_clientconnection(req.client_conn) - assert len(s.flow_list) == 1 + f.handle_request(req) assert len(s.flow_list) == 1 - f.handle_request(req) - resp = utils.tresp() - resp.request = req + + resp = utils.tresp(req) f.handle_response(resp) assert len(s.flow_list) == 1 diff --git a/test/test_proxy.py b/test/test_proxy.py index 7bb608ae..ba9d9bfa 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -213,7 +213,7 @@ class uRequest(libpry.AutoTree): def test_simple(self): h = utils.Headers() h["test"] = ["test"] - c = proxy.ClientConnection(("addr", 2222)) + c = proxy.ClientConnect(("addr", 2222)) r = proxy.Request(c, "host", 22, "https", "GET", "/", h, "content") u = r.url() assert r.set_url(u) @@ -225,17 +225,17 @@ class uRequest(libpry.AutoTree): def test_getset_state(self): h = utils.Headers() h["test"] = ["test"] - c = proxy.ClientConnection(("addr", 2222)) + c = proxy.ClientConnect(("addr", 2222)) r = proxy.Request(c, "host", 22, "https", "GET", "/", h, "content") state = r.get_state() - assert proxy.Request.from_state(c, state) == r + assert proxy.Request.from_state(state) == r class uResponse(libpry.AutoTree): def test_simple(self): h = utils.Headers() h["test"] = ["test"] - c = proxy.ClientConnection(("addr", 2222)) + c = proxy.ClientConnect(("addr", 2222)) req = proxy.Request(c, "host", 22, "https", "GET", "/", h, "content") resp = proxy.Response(req, 200, "msg", h.copy(), "content") assert resp.short() @@ -244,7 +244,7 @@ class uResponse(libpry.AutoTree): def test_getset_state(self): h = utils.Headers() h["test"] = ["test"] - c = proxy.ClientConnection(("addr", 2222)) + c = proxy.ClientConnect(("addr", 2222)) r = proxy.Request(c, "host", 22, "https", "GET", "/", h, "content") req = proxy.Request(c, "host", 22, "https", "GET", "/", h, "content") resp = proxy.Response(req, 200, "msg", h.copy(), "content") diff --git a/test/utils.py b/test/utils.py index b63c48dc..b1dc46d4 100644 --- a/test/utils.py +++ b/test/utils.py @@ -2,7 +2,7 @@ from libmproxy import proxy, utils, filt, flow def treq(conn=None): if not conn: - conn = proxy.ClientConnection(("address", 22)) + conn = proxy.ClientConnect(("address", 22)) headers = utils.Headers() headers["header"] = ["qvalue"] return proxy.Request(conn, "host", 80, "http", "GET", "/path", headers, "content") @@ -17,6 +17,6 @@ def tresp(req=None): def tflow(): - bc = proxy.ClientConnection(("address", 22)) - return flow.Flow(bc) + r = treq() + return flow.Flow(r) -- cgit v1.2.3