aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--libmproxy/flow.py21
-rw-r--r--libmproxy/proxy.py41
-rw-r--r--test/test_proxy.py31
3 files changed, 87 insertions, 6 deletions
diff --git a/libmproxy/flow.py b/libmproxy/flow.py
index 03b8b309..3520cc93 100644
--- a/libmproxy/flow.py
+++ b/libmproxy/flow.py
@@ -99,11 +99,26 @@ class Flow:
def load_state(self, state):
self._backup = state["backup"]
- self.request = proxy.Request.from_state(state["request"])
+ if self.request:
+ self.request.load_state(state["request"])
+ else:
+ self.request = proxy.Request.from_state(state["request"])
+
if state["response"]:
- self.response = proxy.Response.from_state(self.request, state["response"])
+ if self.response:
+ self.response.load_state(state["response"])
+ else:
+ self.response = proxy.Response.from_state(self.request, state["response"])
+ else:
+ self.response = None
+
if state["error"]:
- self.error = proxy.Error.from_state(state["error"])
+ if self.error:
+ self.error.load_state(state["error"])
+ else:
+ self.error = proxy.Error.from_state(state["error"])
+ else:
+ self.error = None
@classmethod
def from_state(klass, state):
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py
index 88c62b25..4ab19694 100644
--- a/libmproxy/proxy.py
+++ b/libmproxy/proxy.py
@@ -148,6 +148,23 @@ class Request(controller.Msg):
def is_cached(self):
return False
+ def load_state(self, state):
+ if state["client_conn"]:
+ if self.client_conn:
+ self.client_conn.load_state(state["client_conn"])
+ else:
+ self.client_conn = ClientConnect.from_state(state["client_conn"])
+ else:
+ self.client_conn = None
+ self.host = state["host"]
+ self.port = state["port"]
+ self.scheme = state["scheme"]
+ self.method = state["method"]
+ self.path = state["path"]
+ self.headers = utils.Headers.from_state(state["headers"])
+ self.content = base64.decodestring(state["content"])
+ self.timestamp = state["timestamp"]
+
def get_state(self):
return dict(
client_conn = self.client_conn.get_state() if self.client_conn else None,
@@ -164,7 +181,7 @@ class Request(controller.Msg):
@classmethod
def from_state(klass, state):
return klass(
- ClientConnect.from_state(state["client_conn"]) if state["client_conn"] else None,
+ ClientConnect.from_state(state["client_conn"]),
state["host"],
state["port"],
state["scheme"],
@@ -249,6 +266,13 @@ class Response(controller.Msg):
self.cached = False
controller.Msg.__init__(self)
+ def load_state(self, state):
+ self.code = state["code"]
+ self.msg = state["msg"]
+ self.headers = utils.Headers.from_state(state["headers"])
+ self.content = base64.decodestring(state["content"])
+ self.timestamp = state["timestamp"]
+
def get_state(self):
return dict(
code = self.code,
@@ -325,12 +349,21 @@ class ClientConnect(controller.Msg):
self.close = False
controller.Msg.__init__(self)
+ def __eq__(self, other):
+ return self.get_state() == other.get_state()
+
+ def load_state(self, state):
+ self.address = state
+
def get_state(self):
return list(self.address) if self.address else None
@classmethod
def from_state(klass, state):
- return klass(state)
+ if state:
+ return klass(state)
+ else:
+ return None
def copy(self):
return copy.copy(self)
@@ -342,6 +375,10 @@ class Error(controller.Msg):
self.timestamp = timestamp or time.time()
controller.Msg.__init__(self)
+ def load_state(self, state):
+ self.msg = state["msg"]
+ self.timestamp = state["timestamp"]
+
def copy(self):
return copy.copy(self)
diff --git a/test/test_proxy.py b/test/test_proxy.py
index cb2528fd..0b40164e 100644
--- a/test/test_proxy.py
+++ b/test/test_proxy.py
@@ -234,6 +234,11 @@ class uRequest(libpry.AutoTree):
state = r.get_state()
assert proxy.Request.from_state(state) == r
+ r2 = proxy.Request(c, "testing", 20, "http", "PUT", "/foo", h, "test")
+ assert not r == r2
+ r.load_state(r2.get_state())
+ assert r == r2
+
class uResponse(libpry.AutoTree):
def test_simple(self):
@@ -256,6 +261,11 @@ class uResponse(libpry.AutoTree):
state = resp.get_state()
assert proxy.Response.from_state(req, state) == resp
+ resp2 = proxy.Response(req, 220, "foo", h.copy(), "test")
+ assert not resp == resp2
+ resp.load_state(resp2.get_state())
+ assert resp == resp2
+
class uError(libpry.AutoTree):
def test_getset_state(self):
@@ -265,6 +275,12 @@ class uError(libpry.AutoTree):
assert e.copy()
+ e2 = proxy.Error(None, "bar")
+ assert not e == e2
+ e.load_state(e2.get_state())
+ assert e == e2
+
+
class uProxyError(libpry.AutoTree):
def test_simple(self):
@@ -272,6 +288,18 @@ class uProxyError(libpry.AutoTree):
assert repr(p)
+class uClientConnect(libpry.AutoTree):
+ def test_state(self):
+ c = proxy.ClientConnect(("a", 22))
+ assert proxy.ClientConnect.from_state(c.get_state()) == c
+
+ c2 = proxy.ClientConnect(("a", 25))
+ assert not c == c2
+
+ c.load_state(c2.get_state())
+ assert c == c2
+
+
tests = [
uProxyError(),
@@ -281,8 +309,9 @@ tests = [
u_parse_request_line(),
u_parse_url(),
uError(),
+ uClientConnect(),
_TestServers(), [
uSanity(),
uProxy(),
- ]
+ ],
]