From d51b8cab0c0d1352865155865dfd258f66103ffe Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 16 Mar 2012 11:12:56 +1300 Subject: Add a decoded context manager. This simplifies a common chore when modifying traffic - decoding the object, modifying it, then re-encoding it with the same encoding afterwards. You can now simply say: with flow.decoded(request): request.content = "bar" --- libmproxy/flow.py | 28 ++++++++++++++++++++++++++++ test/test_flow.py | 24 ++++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 450fef30..dbc0c109 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -199,6 +199,34 @@ class ODictCaseless(ODict): return s.lower() +class decoded(object): + """ + + A context manager that decodes a request, response or error, and then + re-encodes it with the same encoding after execution of the block. + + Example: + + with decoded(request): + request.content = request.content.replace("foo", "bar") + """ + def __init__(self, o): + self.o = o + ce = o.headers["content-encoding"] + if ce and ce[0] in encoding.ENCODINGS: + self.ce = ce[0] + else: + self.ce = None + + def __enter__(self): + if self.ce: + self.o.decode() + + def __exit__(self, type, value, tb): + if self.ce: + self.o.encode(self.ce) + + class HTTPMsg(controller.Msg): def decode(self): """ diff --git a/test/test_flow.py b/test/test_flow.py index b6818960..56303881 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -996,6 +996,29 @@ class uODictCaseless(libpry.AutoTree): assert len(self.od) == 1 +class udecoded(libpry.AutoTree): + def test_del(self): + r = tutils.treq() + assert r.content == "content" + assert not r.headers["content-encoding"] + r.encode("gzip") + assert r.headers["content-encoding"] + assert r.content != "content" + with flow.decoded(r): + assert not r.headers["content-encoding"] + assert r.content == "content" + assert r.headers["content-encoding"] + assert r.content != "content" + + with flow.decoded(r): + r.content = "foo" + + assert r.content != "foo" + r.decode() + assert r.content == "foo" + + + tests = [ uStickyCookieState(), @@ -1012,4 +1035,5 @@ tests = [ uClientConnect(), uODict(), uODictCaseless(), + udecoded() ] -- cgit v1.2.3