From c8ae1e85b33e80f0c84ccdc8b3759affe4ef3900 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sat, 17 Mar 2012 11:31:05 +1300 Subject: Hooks -> ReplaceHooks It makes more sense to specialize this, which will let me build a nicer interface for replacement hooks in mitmproxy. --- libmproxy/flow.py | 36 +++++++++++++++++++----------------- test/test_flow.py | 43 +++++++++++++++++++++++-------------------- 2 files changed, 42 insertions(+), 37 deletions(-) diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 4c6f2915..438cb9ad 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -26,27 +26,28 @@ import controller, version HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" -class Hooks: +class ReplaceHooks: def __init__(self): self.lst = [] - def add(self, patt, func): + def add(self, fpatt, rex, s): """ - Add a hook. + Add a replacement hook. - patt: A string specifying a filter pattern. - func: A callable taking the matching flow as argument. + fpatt: A string specifying a filter pattern. + rex: A regular expression. + s: The replacement string Returns True if hook was added, False if the pattern could not be parsed. """ - cpatt = filt.parse(patt) + cpatt = filt.parse(fpatt) if not cpatt: return False - self.lst.append((patt, func, cpatt)) + self.lst.append((fpatt, rex, s, cpatt)) return True - def remove(self, patt, func=None): + def remove(self, fpatt, rex, s): """ Remove a hook. @@ -54,15 +55,16 @@ class Hooks: func: Optional callable. If not specified, all hooks matching patt are removed. """ for i in range(len(self.lst)-1, -1, -1): - if func and (patt, func) == self.lst[i][:2]: - del self.lst[i] - elif not func and patt == self.lst[i][0]: + if (fpatt, rex, s) == self.lst[i][:3]: del self.lst[i] def run(self, f): - for _, func, cpatt in self.lst: + for _, rex, s, cpatt in self.lst: if cpatt(f): - func(f) + if f.response: + f.response.replace(rex, s) + else: + f.request.replace(rex, s) def clear(self): self.lst = [] @@ -1270,7 +1272,7 @@ class FlowMaster(controller.Master): self.anticache = False self.anticomp = False self.refresh_server_playback = False - self.hooks = Hooks() + self.replacehooks = ReplaceHooks() def add_event(self, e, level="info"): """ @@ -1480,7 +1482,7 @@ class FlowMaster(controller.Master): def handle_error(self, r): f = self.state.add_error(r) - self.hooks.run(f) + self.replacehooks.run(f) if f: self.run_script_hook("error", f) if self.client_playback: @@ -1490,14 +1492,14 @@ class FlowMaster(controller.Master): def handle_request(self, r): f = self.state.add_request(r) - self.hooks.run(f) + self.replacehooks.run(f) self.run_script_hook("request", f) self.process_new_request(f) return f def handle_response(self, r): f = self.state.add_response(r) - self.hooks.run(f) + self.replacehooks.run(f) if f: self.run_script_hook("response", f) if self.client_playback: diff --git a/test/test_flow.py b/test/test_flow.py index 8f7551c7..cd4de7ea 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -1036,44 +1036,47 @@ class udecoded(libpry.AutoTree): assert r.content == "foo" -class uHooks(libpry.AutoTree): +class uReplaceHooks(libpry.AutoTree): def test_add_remove(self): - f = lambda(x): None - h = flow.Hooks() - h.add("~q", f) + h = flow.ReplaceHooks() + h.add("~q", "foo", "bar") assert h.lst - - h.remove("~q", f) + h.remove("~q", "foo", "bar") assert not h.lst - h.add("~q", f) - h.add("~s", f) + h.add("~q", "foo", "bar") + h.add("~s", "foo", "bar") assert len(h.lst) == 2 - h.remove("~q", f) + h.remove("~q", "foo", "bar") assert len(h.lst) == 1 - h.remove("~q") + h.remove("~q", "foo", "bar") assert len(h.lst) == 1 - h.remove("~s") + h.clear() assert len(h.lst) == 0 - track = [] - def func(x): - track.append(x) - - h.add("~s", func) - f = tutils.tflow() + f.request.content = "foo" + h.add("~s", "foo", "bar") h.run(f) - assert not track + assert f.request.content == "foo" f = tutils.tflow_full() + f.request.content = "foo" + f.response.content = "foo" h.run(f) - assert len(track) == 1 + assert f.response.content == "bar" + assert f.request.content == "foo" + f = tutils.tflow() + h.clear() + h.add("~q", "foo", "bar") + f.request.content = "foo" + h.run(f) + assert f.request.content == "bar" tests = [ - uHooks(), + uReplaceHooks(), uStickyCookieState(), uStickyAuthState(), uServerPlaybackState(), -- cgit v1.2.3