diff options
-rw-r--r-- | libmproxy/flow.py | 21 | ||||
-rw-r--r-- | libmproxy/proxy.py | 41 | ||||
-rw-r--r-- | test/test_proxy.py | 31 |
3 files changed, 87 insertions, 6 deletions
diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 03b8b309..3520cc93 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -99,11 +99,26 @@ class Flow: def load_state(self, state): self._backup = state["backup"] - self.request = proxy.Request.from_state(state["request"]) + if self.request: + self.request.load_state(state["request"]) + else: + self.request = proxy.Request.from_state(state["request"]) + if state["response"]: - self.response = proxy.Response.from_state(self.request, state["response"]) + if self.response: + self.response.load_state(state["response"]) + else: + self.response = proxy.Response.from_state(self.request, state["response"]) + else: + self.response = None + if state["error"]: - self.error = proxy.Error.from_state(state["error"]) + if self.error: + self.error.load_state(state["error"]) + else: + self.error = proxy.Error.from_state(state["error"]) + else: + self.error = None @classmethod def from_state(klass, state): diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 88c62b25..4ab19694 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -148,6 +148,23 @@ class Request(controller.Msg): def is_cached(self): return False + def load_state(self, state): + if state["client_conn"]: + if self.client_conn: + self.client_conn.load_state(state["client_conn"]) + else: + self.client_conn = ClientConnect.from_state(state["client_conn"]) + else: + self.client_conn = None + self.host = state["host"] + self.port = state["port"] + self.scheme = state["scheme"] + self.method = state["method"] + self.path = state["path"] + self.headers = utils.Headers.from_state(state["headers"]) + self.content = base64.decodestring(state["content"]) + self.timestamp = state["timestamp"] + def get_state(self): return dict( client_conn = self.client_conn.get_state() if self.client_conn else None, @@ -164,7 +181,7 @@ class Request(controller.Msg): @classmethod def from_state(klass, state): return klass( - ClientConnect.from_state(state["client_conn"]) if state["client_conn"] else None, + ClientConnect.from_state(state["client_conn"]), state["host"], state["port"], state["scheme"], @@ -249,6 +266,13 @@ class Response(controller.Msg): self.cached = False controller.Msg.__init__(self) + def load_state(self, state): + self.code = state["code"] + self.msg = state["msg"] + self.headers = utils.Headers.from_state(state["headers"]) + self.content = base64.decodestring(state["content"]) + self.timestamp = state["timestamp"] + def get_state(self): return dict( code = self.code, @@ -325,12 +349,21 @@ class ClientConnect(controller.Msg): self.close = False controller.Msg.__init__(self) + def __eq__(self, other): + return self.get_state() == other.get_state() + + def load_state(self, state): + self.address = state + def get_state(self): return list(self.address) if self.address else None @classmethod def from_state(klass, state): - return klass(state) + if state: + return klass(state) + else: + return None def copy(self): return copy.copy(self) @@ -342,6 +375,10 @@ class Error(controller.Msg): self.timestamp = timestamp or time.time() controller.Msg.__init__(self) + def load_state(self, state): + self.msg = state["msg"] + self.timestamp = state["timestamp"] + def copy(self): return copy.copy(self) diff --git a/test/test_proxy.py b/test/test_proxy.py index cb2528fd..0b40164e 100644 --- a/test/test_proxy.py +++ b/test/test_proxy.py @@ -234,6 +234,11 @@ class uRequest(libpry.AutoTree): state = r.get_state() assert proxy.Request.from_state(state) == r + r2 = proxy.Request(c, "testing", 20, "http", "PUT", "/foo", h, "test") + assert not r == r2 + r.load_state(r2.get_state()) + assert r == r2 + class uResponse(libpry.AutoTree): def test_simple(self): @@ -256,6 +261,11 @@ class uResponse(libpry.AutoTree): state = resp.get_state() assert proxy.Response.from_state(req, state) == resp + resp2 = proxy.Response(req, 220, "foo", h.copy(), "test") + assert not resp == resp2 + resp.load_state(resp2.get_state()) + assert resp == resp2 + class uError(libpry.AutoTree): def test_getset_state(self): @@ -265,6 +275,12 @@ class uError(libpry.AutoTree): assert e.copy() + e2 = proxy.Error(None, "bar") + assert not e == e2 + e.load_state(e2.get_state()) + assert e == e2 + + class uProxyError(libpry.AutoTree): def test_simple(self): @@ -272,6 +288,18 @@ class uProxyError(libpry.AutoTree): assert repr(p) +class uClientConnect(libpry.AutoTree): + def test_state(self): + c = proxy.ClientConnect(("a", 22)) + assert proxy.ClientConnect.from_state(c.get_state()) == c + + c2 = proxy.ClientConnect(("a", 25)) + assert not c == c2 + + c.load_state(c2.get_state()) + assert c == c2 + + tests = [ uProxyError(), @@ -281,8 +309,9 @@ tests = [ u_parse_request_line(), u_parse_url(), uError(), + uClientConnect(), _TestServers(), [ uSanity(), uProxy(), - ] + ], ] |