diff options
| -rw-r--r-- | libmproxy/console.py | 60 | ||||
| -rw-r--r-- | libmproxy/flow.py | 27 | ||||
| -rw-r--r-- | test/test_console.py | 4 | ||||
| -rw-r--r-- | test/test_flow.py | 6 | 
4 files changed, 63 insertions, 34 deletions
| diff --git a/libmproxy/console.py b/libmproxy/console.py index 13c31435..6aa7ff9c 100644 --- a/libmproxy/console.py +++ b/libmproxy/console.py @@ -165,7 +165,12 @@ class ConnectionItem(WWrap):              self.state.revert(self.flow)              self.master.sync_list_view()          elif key == "w": -            self.master.prompt("Save this flow: ", self.master.save_one_flow, self.flow) +            self.master.prompt( +                "Save this flow: ", +                self.state.last_saveload, +                self.master.save_one_flow, +                self.flow +            )          elif key == "z":              self.master.kill_connection(self.flow)          elif key == "enter": @@ -384,6 +389,7 @@ class ConnectionView(WWrap):      def save_body(self, path):          if not path:              return +        self.state.last_saveload = path          if self.state.view_flow_mode == VIEW_FLOW_REQUEST:              c = self.flow.request          else: @@ -516,7 +522,12 @@ class ConnectionView(WWrap):              self.state.revert(self.flow)              self.master.refresh_connection(self.flow)          elif key == "w": -            self.master.prompt("Save this flow: ", self.master.save_one_flow, self.flow) +            self.master.prompt( +                "Save this flow: ", +                self.state.last_saveload, +                self.master.save_one_flow, +                self.flow +            )          elif key == "v":              if self.state.view_flow_mode == VIEW_FLOW_REQUEST:                  conn = self.flow.request @@ -528,9 +539,17 @@ class ConnectionView(WWrap):                  self.master.spawn_external_viewer(conn.content, t)          elif key == "b":              if self.state.view_flow_mode == VIEW_FLOW_REQUEST: -                self.master.prompt("Save request body: ", self.save_body) +                self.master.prompt( +                    "Save request body: ", +                    self.state.last_saveload, +                    self.save_body +                )              else: -                self.master.prompt("Save response body: ", self.save_body) +                self.master.prompt( +                    "Save response body: ", +                    self.state.last_saveload, +                    self.save_body +                )          elif key == " ":              self.master.view_next_flow(self.flow)          elif key == "|": @@ -632,7 +651,7 @@ class ActionBar(WWrap):          self.w = PathEdit(prompt, text)      def prompt(self, prompt, text = ""): -        self.w = urwid.Edit(prompt, text) +        self.w = urwid.Edit(prompt, text or "")      def message(self, message):          self.w = urwid.Text(message) @@ -1100,8 +1119,8 @@ class ConsoleMaster(flow.FlowMaster):          self.view.set_focus("footer")          self.prompting = (callback, args) -    def prompt(self, prompt, callback, *args): -        self.statusbar.prompt(prompt) +    def prompt(self, prompt, text, callback, *args): +        self.statusbar.prompt(prompt, text)          self.view.set_focus("footer")          self.prompting = (callback, args) @@ -1129,7 +1148,7 @@ class ConsoleMaster(flow.FlowMaster):          prompt.extend(mkup)          prompt.append(")? ")          self.onekey = "".join([i[1] for i in keys]) -        self.prompt(prompt, callback) +        self.prompt(prompt, "", callback)      def prompt_done(self):          self.prompting = False @@ -1153,21 +1172,10 @@ class ConsoleMaster(flow.FlowMaster):          self.state.accept_all()      def set_limit(self, txt): -        if txt: -            f = filt.parse(txt) -            if not f: -                return "Invalid filter expression." -            self.state.set_limit(f) -        else: -            self.state.set_limit(None) +        return self.state.set_limit(txt)      def set_intercept(self, txt): -        if txt: -            self.state.intercept = filt.parse(txt) -            if not self.state.intercept: -                return "Invalid filter expression." -        else: -            self.state.intercept = None +        return self.state.set_intercept(txt)      def set_beep(self, txt):          if txt: @@ -1214,15 +1222,19 @@ class ConsoleMaster(flow.FlowMaster):                          if k == "?":                              self.view_help()                          elif k == "l": -                            self.prompt("Limit: ", self.set_limit) +                            self.prompt("Limit: ", self.state.limit_txt, self.set_limit)                              self.sync_list_view()                              k = None                          elif k == "i": -                            self.prompt("Intercept: ", self.set_intercept) +                            self.prompt( +                                "Intercept: ", +                                self.state.intercept_txt, +                                self.set_intercept +                            )                              self.sync_list_view()                              k = None                          elif k == "B": -                            self.prompt("Beep: ", self.set_beep) +                            self.prompt("Beep: ", "", self.set_beep)                              k = None                          elif k == "j":                              k = "down" diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 4bcbbb97..7444b400 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -316,6 +316,7 @@ class State:          # These are compiled filt expressions:          self.limit = None          self.intercept = None +        self.limit_txt = None      def flow_count(self):          return len(self.flow_map) @@ -371,11 +372,27 @@ class State:          for i in flows:              self.flow_map[i.request] = i -    def set_limit(self, limit): -        """ -            Limit is a compiled filter expression, or None. -        """ -        self.limit = limit +    def set_limit(self, txt): +        if txt: +            f = filt.parse(txt) +            if not f: +                return "Invalid filter expression." +            self.limit = f +            self.limit_txt = txt +        else: +            self.limit = None +            self.limit_txt = None + +    def set_intercept(self, txt): +        if txt: +            f = filt.parse(txt) +            if not f: +                return "Invalid filter expression." +            self.intercept = f +            self.intercept_txt = txt +        else: +            self.intercept = None +            self.intercept_txt = None      @property      def view(self): diff --git a/test/test_console.py b/test/test_console.py index 034f8ea1..ffcb31f7 100644 --- a/test/test_console.py +++ b/test/test_console.py @@ -69,7 +69,7 @@ class uState(libpry.AutoTree):          self._add_response(c)          self._add_request(c)          self._add_response(c) -        c.set_limit(filt.parse("~q")) +        assert not c.set_limit("~q")          assert len(c.view) == 3          assert c.focus == 2 @@ -158,7 +158,7 @@ class uPathCompleter(libpry.AutoTree):  class uOptions(libpry.AutoTree):      def test_all(self): -        assert console.Options(beep=True) +        assert console.Options(kill=True) diff --git a/test/test_flow.py b/test/test_flow.py index d2cb85dc..f29e1e80 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -275,7 +275,7 @@ class uState(libpry.AutoTree):          f = c.add_request(req)          assert len(c.view) == 1 -        c.set_limit(filt.parse("~s")) +        c.set_limit("~s")          assert len(c.view) == 0          resp = tutils.tresp(req)          c.add_response(resp) @@ -287,9 +287,9 @@ class uState(libpry.AutoTree):          c.clientconnect(req.client_conn)          c.add_request(req)          assert len(c.view) == 2 -        c.set_limit(filt.parse("~q")) +        c.set_limit("~q")          assert len(c.view) == 1 -        c.set_limit(filt.parse("~s")) +        c.set_limit("~s")          assert len(c.view) == 1      def _add_request(self, state): | 
