diff options
| -rw-r--r-- | libmproxy/console/__init__.py | 19 | ||||
| -rw-r--r-- | libmproxy/console/grideditor.py | 11 | ||||
| -rw-r--r-- | libmproxy/flow.py | 73 | ||||
| -rw-r--r-- | test/test_flow.py | 67 | 
4 files changed, 148 insertions, 22 deletions
| diff --git a/libmproxy/console/__init__.py b/libmproxy/console/__init__.py index 9bac5773..e835340e 100644 --- a/libmproxy/console/__init__.py +++ b/libmproxy/console/__init__.py @@ -124,6 +124,10 @@ class StatusBar(common.WWrap):      def get_status(self):          r = [] +        if self.master.setheaders.count(): +            r.append("[") +            r.append(("heading_key", "H")) +            r.append("eaders]")          if self.master.replacehooks.count():              r.append("[")              r.append(("heading_key", "R")) @@ -762,11 +766,6 @@ class ConsoleMaster(flow.FlowMaster):          else:              self.view_flowlist() -    def set_replace(self, r): -        self.replacehooks.clear() -        for i in r: -            self.replacehooks.add(*i) -      def loop(self):          changed = True          try: @@ -815,6 +814,14 @@ class ConsoleMaster(flow.FlowMaster):                                          ),                                          self.stop_client_playback_prompt,                                      ) +                            elif k == "H": +                                self.view_grideditor( +                                    grideditor.SetHeadersEditor( +                                        self, +                                        self.setheaders.get_specs(), +                                        self.setheaders.set +                                    ) +                                )                              elif k == "i":                                  self.prompt(                                      "Intercept filter: ", @@ -853,7 +860,7 @@ class ConsoleMaster(flow.FlowMaster):                                      grideditor.ReplaceEditor(                                          self,                                          self.replacehooks.get_specs(), -                                        self.set_replace +                                        self.replacehooks.set                                      )                                  )                              elif k == "s": diff --git a/libmproxy/console/grideditor.py b/libmproxy/console/grideditor.py index 51002e77..d62cb206 100644 --- a/libmproxy/console/grideditor.py +++ b/libmproxy/console/grideditor.py @@ -373,3 +373,14 @@ class ReplaceEditor(GridEditor):                  return True          return False + +class SetHeadersEditor(GridEditor): +    title = "Editing header set patterns" +    columns = 3 +    headings = ("Filter", "Header", "Value") +    def is_error(self, col, val): +        if col == 0: +            if not filt.parse(val): +                return True +        return False + diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 2d016a14..6076d202 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -35,6 +35,11 @@ class ReplaceHooks:      def __init__(self):          self.lst = [] +    def set(self, r): +        self.clear() +        for i in r: +            self.add(*i) +      def add(self, fpatt, rex, s):          """              Add a replacement hook. @@ -56,17 +61,6 @@ class ReplaceHooks:          self.lst.append((fpatt, rex, s, cpatt))          return True -    def remove(self, fpatt, rex, s): -        """ -            Remove a hook. - -            patt: A string specifying a filter pattern. -            func: Optional callable. If not specified, all hooks matching patt are removed. -        """ -        for i in range(len(self.lst)-1, -1, -1): -            if (fpatt, rex, s) == self.lst[i][:3]: -                del self.lst[i] -      def get_specs(self):          """              Retrieve the hook specifcations. Returns a list of (fpatt, rex, s) tuples. @@ -88,6 +82,59 @@ class ReplaceHooks:          self.lst = [] +class SetHeaders: +    def __init__(self): +        self.lst = [] + +    def set(self, r): +        self.clear() +        for i in r: +            self.add(*i) + +    def add(self, fpatt, header, value): +        """ +            Add a set header hook. + +            fpatt: String specifying a filter pattern. +            header: Header name. +            value: Header value string + +            Returns True if hook was added, False if the pattern could not be +            parsed. +        """ +        cpatt = filt.parse(fpatt) +        if not cpatt: +            return False +        self.lst.append((fpatt, header, value, cpatt)) +        return True + +    def get_specs(self): +        """ +            Retrieve the hook specifcations. Returns a list of (fpatt, rex, s) tuples. +        """ +        return [i[:3] for i in self.lst] + +    def count(self): +        return len(self.lst) + +    def clear(self): +        self.lst = [] + +    def run(self, f): +        for _, header, value, cpatt in self.lst: +            if cpatt(f): +                if f.response: +                    del f.response.headers[header] +                else: +                    del f.request.headers[header] +        for _, header, value, cpatt in self.lst: +            if cpatt(f): +                if f.response: +                    f.response.headers.add(header, value) +                else: +                    f.request.headers.add(header, value) + +  class ScriptContext:      def __init__(self, master):          self._master = master @@ -1220,6 +1267,7 @@ class FlowMaster(controller.Master):          self.anticomp = False          self.refresh_server_playback = False          self.replacehooks = ReplaceHooks() +        self.setheaders = SetHeaders()          self.stream = None @@ -1426,6 +1474,7 @@ class FlowMaster(controller.Master):      def handle_error(self, r):          f = self.state.add_error(r)          self.replacehooks.run(f) +        self.setheaders.run(f)          if f:              self.run_script_hook("error", f)          if self.client_playback: @@ -1436,6 +1485,7 @@ class FlowMaster(controller.Master):      def handle_request(self, r):          f = self.state.add_request(r)          self.replacehooks.run(f) +        self.setheaders.run(f)          self.run_script_hook("request", f)          self.process_new_request(f)          return f @@ -1444,6 +1494,7 @@ class FlowMaster(controller.Master):          f = self.state.add_response(r)          if f:              self.replacehooks.run(f) +            self.setheaders.run(f)              self.run_script_hook("response", f)              if self.client_playback:                  self.client_playback.clear(f) diff --git a/test/test_flow.py b/test/test_flow.py index eccd11f4..47a09360 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -1017,7 +1017,16 @@ def test_replacehooks():      h = flow.ReplaceHooks()      h.add("~q", "foo", "bar")      assert h.lst -    h.remove("~q", "foo", "bar") + +    h.set( +        [ +            (".*", "one", "two"), +            (".*", "three", "four"), +        ] +    ) +    assert h.count() == 2 + +    h.clear()      assert not h.lst      h.add("~q", "foo", "bar") @@ -1026,10 +1035,6 @@ def test_replacehooks():      v = h.get_specs()      assert v == [('~q', 'foo', 'bar'), ('~s', 'foo', 'bar')]      assert h.count() == 2 -    h.remove("~q", "foo", "bar") -    assert h.count() == 1 -    h.remove("~q", "foo", "bar") -    assert h.count() == 1      h.clear()      assert h.count() == 0 @@ -1056,3 +1061,55 @@ def test_replacehooks():      assert not h.add("~", "foo", "bar")      assert not h.add("foo", "*", "bar") + +def test_setheaders(): +    h = flow.SetHeaders() +    h.add("~q", "foo", "bar") +    assert h.lst + +    h.set( +        [ +            (".*", "one", "two"), +            (".*", "three", "four"), +        ] +    ) +    assert h.count() == 2 + +    h.clear() +    assert not h.lst + +    h.add("~q", "foo", "bar") +    h.add("~s", "foo", "bar") + +    v = h.get_specs() +    assert v == [('~q', 'foo', 'bar'), ('~s', 'foo', 'bar')] +    assert h.count() == 2 +    h.clear() +    assert h.count() == 0 + +    f = tutils.tflow() +    f.request.content = "foo" +    h.add("~s", "foo", "bar") +    h.run(f) +    assert f.request.content == "foo" + + +    h.clear() +    h.add("~s", "one", "two") +    h.add("~s", "one", "three") +    f = tutils.tflow_full() +    f.request.headers["one"] = ["xxx"] +    f.response.headers["one"] = ["xxx"] +    h.run(f) +    assert f.request.headers["one"] == ["xxx"] +    assert f.response.headers["one"] == ["two", "three"] + +    h.clear() +    h.add("~q", "one", "two") +    h.add("~q", "one", "three") +    f = tutils.tflow() +    f.request.headers["one"] = ["xxx"] +    h.run(f) +    assert f.request.headers["one"] == ["two", "three"] + +    assert not h.add("~", "foo", "bar") | 
