From 9042d3f3b9ab96020dd314cdb9faeaae3947544c Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 3 Aug 2011 22:48:40 +1200 Subject: Clean up interfaces by making some methods pseudo-private. --- libmproxy/flow.py | 99 ++++++++++++++++++++++++++++--------------------------- test/test_flow.py | 46 +++++++++++++------------- 2 files changed, 74 insertions(+), 71 deletions(-) diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 1afab895..1decb7d5 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -62,11 +62,11 @@ class Headers: def add(self, key, value): self.lst.append([key, str(value)]) - def get_state(self): + def _get_state(self): return [tuple(i) for i in self.lst] @classmethod - def from_state(klass, state): + def _from_state(klass, state): return klass([list(i) for i in state]) def copy(self): @@ -85,7 +85,10 @@ class Headers: def match_re(self, expr): """ - Match the regular expression against each header (key, value) pair. + Match the regular expression against each header. For each (key, + value) pair a string of the following format is matched against: + + "key: value" """ for k, v in self.lst: s = "%s: %s"%(k, v) @@ -211,12 +214,12 @@ class Request(HTTPMsg): else: return True - def load_state(self, state): + def _load_state(self, state): if state["client_conn"]: if self.client_conn: - self.client_conn.load_state(state["client_conn"]) + self.client_conn._load_state(state["client_conn"]) else: - self.client_conn = ClientConnect.from_state(state["client_conn"]) + self.client_conn = ClientConnect._from_state(state["client_conn"]) else: self.client_conn = None self.host = state["host"] @@ -224,33 +227,33 @@ class Request(HTTPMsg): self.scheme = state["scheme"] self.method = state["method"] self.path = state["path"] - self.headers = Headers.from_state(state["headers"]) + self.headers = Headers._from_state(state["headers"]) self.content = base64.decodestring(state["content"]) self.timestamp = state["timestamp"] - def get_state(self): + def _get_state(self): return dict( - client_conn = self.client_conn.get_state() if self.client_conn else None, + client_conn = self.client_conn._get_state() if self.client_conn else None, host = self.host, port = self.port, scheme = self.scheme, method = self.method, path = self.path, - headers = self.headers.get_state(), + headers = self.headers._get_state(), content = base64.encodestring(self.content), timestamp = self.timestamp, ) @classmethod - def from_state(klass, state): + def _from_state(klass, state): return klass( - ClientConnect.from_state(state["client_conn"]), + ClientConnect._from_state(state["client_conn"]), str(state["host"]), state["port"], str(state["scheme"]), str(state["method"]), str(state["path"]), - Headers.from_state(state["headers"]), + Headers._from_state(state["headers"]), base64.decodestring(state["content"]), state["timestamp"] ) @@ -259,7 +262,7 @@ class Request(HTTPMsg): return id(self) def __eq__(self, other): - return self.get_state() == other.get_state() + return self._get_state() == other._get_state() def copy(self): c = copy.copy(self) @@ -395,35 +398,35 @@ class Response(HTTPMsg): def is_replay(self): return self.replay - def load_state(self, state): + def _load_state(self, state): self.code = state["code"] self.msg = state["msg"] - self.headers = Headers.from_state(state["headers"]) + self.headers = Headers._from_state(state["headers"]) self.content = base64.decodestring(state["content"]) self.timestamp = state["timestamp"] - def get_state(self): + def _get_state(self): return dict( code = self.code, msg = self.msg, - headers = self.headers.get_state(), + headers = self.headers._get_state(), timestamp = self.timestamp, content = base64.encodestring(self.content) ) @classmethod - def from_state(klass, request, state): + def _from_state(klass, request, state): return klass( request, state["code"], str(state["msg"]), - Headers.from_state(state["headers"]), + Headers._from_state(state["headers"]), base64.decodestring(state["content"]), state["timestamp"], ) def __eq__(self, other): - return self.get_state() == other.get_state() + return self._get_state() == other._get_state() def copy(self): c = copy.copy(self) @@ -484,16 +487,16 @@ class ClientConnect(controller.Msg): controller.Msg.__init__(self) def __eq__(self, other): - return self.get_state() == other.get_state() + return self._get_state() == other._get_state() - def load_state(self, state): + def _load_state(self, state): self.address = state - def get_state(self): + def _get_state(self): return list(self.address) if self.address else None @classmethod - def from_state(klass, state): + def _from_state(klass, state): if state: return klass(state) else: @@ -509,21 +512,21 @@ class Error(controller.Msg): self.timestamp = timestamp or utils.timestamp() controller.Msg.__init__(self) - def load_state(self, state): + def _load_state(self, state): self.msg = state["msg"] self.timestamp = state["timestamp"] def copy(self): return copy.copy(self) - def get_state(self): + def _get_state(self): return dict( msg = self.msg, timestamp = self.timestamp, ) @classmethod - def from_state(klass, state): + def _from_state(klass, state): return klass( None, state["msg"], @@ -531,7 +534,7 @@ class Error(controller.Msg): ) def __eq__(self, other): - return self.get_state() == other.get_state() + return self._get_state() == other._get_state() def replace(self, pattern, repl, *args, **kwargs): """ @@ -708,9 +711,9 @@ class Flow: self._backup = None @classmethod - def from_state(klass, state): + def _from_state(klass, state): f = klass(None) - f.load_state(state) + f._load_state(state) return f @classmethod @@ -719,13 +722,13 @@ class Flow: data = json.loads(data) except Exception: return None - return klass.from_state(data) + return klass._from_state(data) - def get_state(self, nobackup=False): + def _get_state(self, nobackup=False): d = dict( - request = self.request.get_state() if self.request else None, - response = self.response.get_state() if self.response else None, - error = self.error.get_state() if self.error else None, + request = self.request._get_state() if self.request else None, + response = self.response._get_state() if self.response else None, + error = self.error._get_state() if self.error else None, version = version.IVERSION ) if nobackup: @@ -734,26 +737,26 @@ class Flow: d["backup"] = self._backup return d - def load_state(self, state): + def _load_state(self, state): self._backup = state["backup"] if self.request: - self.request.load_state(state["request"]) + self.request._load_state(state["request"]) else: - self.request = Request.from_state(state["request"]) + self.request = Request._from_state(state["request"]) if state["response"]: if self.response: - self.response.load_state(state["response"]) + self.response._load_state(state["response"]) else: - self.response = Response.from_state(self.request, state["response"]) + self.response = Response._from_state(self.request, state["response"]) else: self.response = None if state["error"]: if self.error: - self.error.load_state(state["error"]) + self.error._load_state(state["error"]) else: - self.error = Error.from_state(state["error"]) + self.error = Error._from_state(state["error"]) else: self.error = None @@ -766,11 +769,11 @@ class Flow: return False def backup(self): - self._backup = self.get_state(nobackup=True) + self._backup = self._get_state(nobackup=True) def revert(self): if self._backup: - self.load_state(self._backup) + self._load_state(self._backup) self._backup = None def match(self, pattern): @@ -1041,7 +1044,7 @@ class FlowMaster(controller.Master): rflow = self.server_playback.next_flow(flow) if not rflow: return None - response = Response.from_state(flow.request, rflow.response.get_state()) + response = Response._from_state(flow.request, rflow.response._get_state()) response.set_replay() flow.response = response if self.refresh_server_playback: @@ -1178,7 +1181,7 @@ class FlowWriter: self.ns = netstring.FileEncoder(fo) def add(self, flow): - d = flow.get_state() + d = flow._get_state() s = json.dumps(d) self.ns.write(s) @@ -1201,7 +1204,7 @@ class FlowReader: try: for i in self.ns: data = json.loads(i) - yield Flow.from_state(data) + yield Flow._from_state(data) except netstring.DecoderError: raise FlowReadError("Invalid data format.") diff --git a/test/test_flow.py b/test/test_flow.py index f61e8de5..0cfcbc3c 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -153,19 +153,19 @@ class uFlow(libpry.AutoTree): def test_getset_state(self): f = tutils.tflow() f.response = tutils.tresp(f.request) - state = f.get_state() - assert f.get_state() == flow.Flow.from_state(state).get_state() + state = f._get_state() + assert f._get_state() == flow.Flow._from_state(state)._get_state() f.response = None f.error = flow.Error(f.request, "error") - state = f.get_state() - assert f.get_state() == flow.Flow.from_state(state).get_state() + state = f._get_state() + assert f._get_state() == flow.Flow._from_state(state)._get_state() f2 = tutils.tflow() f2.error = flow.Error(f.request, "e2") assert not f == f2 - f.load_state(f2.get_state()) - assert f.get_state() == f2.get_state() + f._load_state(f2._get_state()) + assert f._get_state() == f2._get_state() def test_kill(self): s = flow.State() @@ -410,7 +410,7 @@ class uSerialize(libpry.AutoTree): assert len(l) == 1 f2 = l[0] - assert f2.get_state() == f.get_state() + assert f2._get_state() == f._get_state() assert f2.request.assemble() == f.request.assemble() def test_load_flows(self): @@ -594,20 +594,20 @@ class uRequest(libpry.AutoTree): h["test"] = ["test"] c = flow.ClientConnect(("addr", 2222)) r = flow.Request(c, "host", 22, "https", "GET", "/", h, "content") - state = r.get_state() - assert flow.Request.from_state(state) == r + state = r._get_state() + assert flow.Request._from_state(state) == r r.client_conn = None - state = r.get_state() - assert flow.Request.from_state(state) == r + state = r._get_state() + assert flow.Request._from_state(state) == r r2 = flow.Request(c, "testing", 20, "http", "PUT", "/foo", h, "test") assert not r == r2 - r.load_state(r2.get_state()) + r._load_state(r2._get_state()) assert r == r2 r2.client_conn = None - r.load_state(r2.get_state()) + r._load_state(r2._get_state()) assert not r.client_conn def test_replace(self): @@ -694,12 +694,12 @@ class uResponse(libpry.AutoTree): req = flow.Request(c, "host", 22, "https", "GET", "/", h, "content") resp = flow.Response(req, 200, "msg", h.copy(), "content") - state = resp.get_state() - assert flow.Response.from_state(req, state) == resp + state = resp._get_state() + assert flow.Response._from_state(req, state) == resp resp2 = flow.Response(req, 220, "foo", h.copy(), "test") assert not resp == resp2 - resp.load_state(resp2.get_state()) + resp._load_state(resp2._get_state()) assert resp == resp2 def test_replace(self): @@ -739,14 +739,14 @@ class uResponse(libpry.AutoTree): class uError(libpry.AutoTree): def test_getset_state(self): e = flow.Error(None, "Error") - state = e.get_state() - assert flow.Error.from_state(state) == e + state = e._get_state() + assert flow.Error._from_state(state) == e assert e.copy() e2 = flow.Error(None, "bar") assert not e == e2 - e.load_state(e2.get_state()) + e._load_state(e2._get_state()) assert e == e2 @@ -762,12 +762,12 @@ class uError(libpry.AutoTree): class uClientConnect(libpry.AutoTree): def test_state(self): c = flow.ClientConnect(("a", 22)) - assert flow.ClientConnect.from_state(c.get_state()) == c + assert flow.ClientConnect._from_state(c._get_state()) == c c2 = flow.ClientConnect(("a", 25)) assert not c == c2 - c.load_state(c2.get_state()) + c._load_state(c2._get_state()) assert c == c2 c3 = c.copy() @@ -851,8 +851,8 @@ class uHeaders(libpry.AutoTree): self.hd.add("foo", 1) self.hd.add("foo", 2) self.hd.add("bar", 3) - state = self.hd.get_state() - nd = flow.Headers.from_state(state) + state = self.hd._get_state() + nd = flow.Headers._from_state(state) assert nd == self.hd def test_copy(self): -- cgit v1.2.3