aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2011-07-14 15:59:27 +1200
committerAldo Cortesi <aldo@nullcube.com>2011-07-14 16:01:54 +1200
commit1c9e7b982a7d3fe76f5bd03e53b21b9d450f4607 (patch)
treed8280341e1f23bce216c934398bf0079d85de444
parentb6e1bf63c3bb49e7515807e6b36dc3116b565f67 (diff)
downloadmitmproxy-1c9e7b982a7d3fe76f5bd03e53b21b9d450f4607.tar.gz
mitmproxy-1c9e7b982a7d3fe76f5bd03e53b21b9d450f4607.tar.bz2
mitmproxy-1c9e7b982a7d3fe76f5bd03e53b21b9d450f4607.zip
Rewrite Headers object to preserve order and case.
-rw-r--r--libmproxy/console.py10
-rw-r--r--libmproxy/filt.py2
-rw-r--r--libmproxy/flow.py7
-rw-r--r--libmproxy/proxy.py17
-rw-r--r--libmproxy/utils.py154
-rw-r--r--test/test_server.py4
-rw-r--r--test/test_utils.py114
7 files changed, 100 insertions, 208 deletions
diff --git a/libmproxy/console.py b/libmproxy/console.py
index d99dd8ac..ffe37fc3 100644
--- a/libmproxy/console.py
+++ b/libmproxy/console.py
@@ -93,7 +93,7 @@ def format_flow(f, focus, extended=False, padding=2):
txt.append(("goodcode", str(f.response.code)))
else:
txt.append(("error", str(f.response.code)))
- t = f.response.headers.get("content-type")
+ t = f.response.headers["content-type"]
if t:
t = t[0].split(";")[0]
txt.append(("text", " %s"%t))
@@ -295,7 +295,11 @@ class ConnectionView(WWrap):
def _conn_text(self, conn, viewmode):
if conn:
- return self.master._cached_conn_text(conn.content, tuple(conn.headers.itemPairs()), viewmode)
+ return self.master._cached_conn_text(
+ conn.content,
+ tuple([tuple(i) for i in conn.headers.lst]),
+ viewmode
+ )
else:
return urwid.ListBox([])
@@ -485,7 +489,7 @@ class ConnectionView(WWrap):
else:
conn = self.flow.response
if conn.content:
- t = conn.headers.get("content-type", [None])
+ t = conn.headers["content-type"] or [None]
t = t[0]
self.master.spawn_external_viewer(conn.content, t)
elif key == "b":
diff --git a/libmproxy/filt.py b/libmproxy/filt.py
index 31c43581..40cf7358 100644
--- a/libmproxy/filt.py
+++ b/libmproxy/filt.py
@@ -79,7 +79,7 @@ class _Rex(_Action):
raise ValueError, "Cannot compile expression."
def _check_content_type(expr, o):
- val = o.headers.get("content-type")
+ val = o.headers["content-type"]
if val and re.search(expr, val[0]):
return True
return False
diff --git a/libmproxy/flow.py b/libmproxy/flow.py
index be77af33..d29b8e2d 100644
--- a/libmproxy/flow.py
+++ b/libmproxy/flow.py
@@ -101,7 +101,7 @@ class ServerPlaybackState:
if self.headers:
hdrs = []
for i in self.headers:
- v = r.headers.get(i, [])
+ v = r.headers[i]
# Slightly subtle: we need to convert everything to strings
# to prevent a mismatch between unicode/non-unicode.
v = [str(x) for x in v]
@@ -139,7 +139,7 @@ class StickyCookieState:
)
def handle_response(self, f):
- for i in f.response.headers.get("set-cookie", []):
+ for i in f.response.headers["set-cookie"]:
# FIXME: We now know that Cookie.py screws up some cookies with
# valid RFC 822/1123 datetime specifications for expiry. Sigh.
c = Cookie.SimpleCookie(i)
@@ -158,9 +158,10 @@ class StickyCookieState:
f.request.path.startswith(i[2])
]
if all(match):
- l = f.request.headers.setdefault("cookie", [])
+ l = f.request.headers["cookie"]
f.request.stickycookie = True
l.append(self.jar[i].output(header="").strip())
+ f.request.headers["cookie"] = l
class StickyAuthState:
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py
index 362d622d..690df9f4 100644
--- a/libmproxy/proxy.py
+++ b/libmproxy/proxy.py
@@ -56,11 +56,11 @@ def read_chunked(fp):
def read_http_body(rfile, connection, headers, all):
- if headers.has_key('transfer-encoding'):
+ if 'transfer-encoding' in headers:
if not ",".join(headers["transfer-encoding"]) == "chunked":
raise IOError('Invalid transfer-encoding')
content = read_chunked(rfile)
- elif headers.has_key("content-length"):
+ elif "content-length" in headers:
content = rfile.read(int(headers["content-length"][0]))
elif all:
content = rfile.read()
@@ -152,8 +152,7 @@ class Request(controller.Msg):
"if-none-match",
]
for i in delheaders:
- if i in self.headers:
- del self.headers[i]
+ del self.headers[i]
def set_replay(self):
self.client_conn = None
@@ -251,7 +250,7 @@ class Request(controller.Msg):
utils.try_del(headers, 'connection')
utils.try_del(headers, 'content-length')
utils.try_del(headers, 'transfer-encoding')
- if not headers.has_key('host'):
+ if not 'host' in headers:
headers["host"] = [self.hostport()]
content = self.content
if content is not None:
@@ -321,7 +320,7 @@ class Response(controller.Msg):
new = mktime_tz(d) + delta
self.headers[i] = [formatdate(new)]
c = []
- for i in self.headers.get("set-cookie", []):
+ for i in self.headers["set-cookie"]:
c.append(self._refresh_cookie(i, delta))
if c:
self.headers["set-cookie"] = c
@@ -656,7 +655,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
scheme = "https"
headers = utils.Headers()
headers.read(self.rfile)
- if host is None and headers.has_key("host"):
+ if host is None and "host" in headers:
netloc = headers["host"][0]
if ':' in netloc:
host, port = string.split(netloc, ':')
@@ -670,7 +669,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
port = int(port)
if host is None:
raise ProxyError(400, 'Invalid request: %s'%request)
- if headers.has_key('expect'):
+ if "expect" in headers:
expect = ",".join(headers['expect'])
if expect == "100-continue" and httpminor >= 1:
self.wfile.write('HTTP/1.1 100 Continue\r\n')
@@ -681,7 +680,7 @@ class ProxyHandler(SocketServer.StreamRequestHandler):
raise ProxyError(417, 'Unmet expect: %s'%expect)
if httpminor == 0:
client_conn.close = True
- if headers.has_key('connection'):
+ if "connection" in headers:
for value in ",".join(headers['connection']).split(","):
value = value.strip()
if value == "close":
diff --git a/libmproxy/utils.py b/libmproxy/utils.py
index 8ac1f547..38fc6107 100644
--- a/libmproxy/utils.py
+++ b/libmproxy/utils.py
@@ -12,7 +12,7 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
-import re, os, subprocess, datetime, textwrap, errno, sys, time, functools
+import re, os, subprocess, datetime, textwrap, errno, sys, time, functools, copy
import json
CERT_SLEEP_TIME = 1
@@ -164,10 +164,6 @@ def isSequenceLike(anobj):
return 1
-def _caseless(s):
- return s.lower()
-
-
def try_del(dict, key):
try:
del dict[key]
@@ -175,108 +171,72 @@ def try_del(dict, key):
pass
-class MultiDict:
- """
- Simple wrapper around a dictionary to make holding multiple objects per
- key easier.
-
- Note that this class assumes that keys are strings.
-
- Keys have no order, but the order in which values are added to a key is
- preserved.
- """
- # This ridiculous bit of subterfuge is needed to prevent the class from
- # treating this as a bound method.
- _helper = (str,)
- def __init__(self):
- self._d = dict()
-
- def copy(self):
- m = self.__class__()
- m._d = self._d.copy()
- return m
-
- def clear(self):
- return self._d.clear()
-
- def get(self, key, d=None):
- key = self._helper[0](key)
- return self._d.get(key, d)
+class Headers:
+ def __init__(self, lst=None):
+ if lst:
+ self.lst = lst
+ else:
+ self.lst = []
- def __contains__(self, key):
- key = self._helper[0](key)
- return self._d.__contains__(key)
+ def _kconv(self, s):
+ return s.lower()
def __eq__(self, other):
- return dict(self) == dict(other)
-
- def __delitem__(self, key):
- self._d.__delitem__(key)
-
- def __getitem__(self, key):
- key = self._helper[0](key)
- return self._d.__getitem__(key)
-
- def __setitem__(self, key, value):
- if not isSequenceLike(value):
- raise ValueError, "Cannot insert non-sequence."
- key = self._helper[0](key)
- return self._d.__setitem__(key, value)
-
- def has_key(self, key):
- key = self._helper[0](key)
- return self._d.has_key(key)
-
- def setdefault(self, key, default=None):
- key = self._helper[0](key)
- return self._d.setdefault(key, default)
-
- def keys(self):
- return self._d.keys()
-
- def extend(self, key, value):
- if not self.has_key(key):
- self[key] = []
- self[key].extend(value)
+ return self.lst == other.lst
+
+ def __getitem__(self, k):
+ ret = []
+ k = self._kconv(k)
+ for i in self.lst:
+ if self._kconv(i[0]) == k:
+ ret.append(i[1])
+ return ret
+
+ def _filter_lst(self, k, lst):
+ new = []
+ for i in lst:
+ if self._kconv(i[0]) != k:
+ new.append(i)
+ return new
+
+ def __setitem__(self, k, hdrs):
+ k = self._kconv(k)
+ first = None
+ new = self._filter_lst(k, self.lst)
+ for i in hdrs:
+ new.append((k, i))
+ self.lst = new
+
+ def __delitem__(self, k):
+ self.lst = self._filter_lst(k, self.lst)
+
+ def __contains__(self, k):
+ for i in self.lst:
+ if self._kconv(i[0]) == k:
+ return True
+ return False
- def append(self, key, value):
- self.extend(key, [value])
-
- def itemPairs(self):
- """
- Yield all possible pairs of items.
- """
- for i in self.keys():
- for j in self[i]:
- yield (i, j)
+ def add(self, key, value):
+ self.lst.append([key, str(value)])
def get_state(self):
- return list(self.itemPairs())
+ return [tuple(i) for i in self.lst]
@classmethod
def from_state(klass, state):
- md = klass()
- for i in state:
- md.append(*i)
- return md
-
+ return klass([list(i) for i in state])
-class Headers(MultiDict):
- """
- A dictionary-like class for keeping track of HTTP headers.
+ def copy(self):
+ lst = copy.deepcopy(self.lst)
+ return Headers(lst)
- It is case insensitive, and __repr__ formats the headers correcty for
- output to the server.
- """
- _helper = (_caseless,)
def __repr__(self):
"""
Returns a string containing a formatted header string.
"""
headerElements = []
- for key in sorted(self.keys()):
- for val in self[key]:
- headerElements.append(key + ": " + val)
+ for itm in self.lst:
+ headerElements.append(itm[0] + ": " + itm[1])
headerElements.append("")
return "\r\n".join(headerElements)
@@ -284,7 +244,7 @@ class Headers(MultiDict):
"""
Match the regular expression against each header (key, value) pair.
"""
- for k, v in self.itemPairs():
+ for k, v in self.lst:
s = "%s: %s"%(k, v)
if re.search(expr, s):
return True
@@ -295,6 +255,7 @@ class Headers(MultiDict):
Read a set of headers from a file pointer. Stop once a blank line
is reached.
"""
+ ret = []
name = ''
while 1:
line = fp.readline()
@@ -302,18 +263,15 @@ class Headers(MultiDict):
break
if line[0] in ' \t':
# continued header
- self[name][-1] = self[name][-1] + '\r\n ' + line.strip()
+ ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
else:
i = line.find(':')
# We're being liberal in what we accept, here.
if i > 0:
name = line[:i]
value = line[i+1:].strip()
- if self.has_key(name):
- # merge value
- self.append(name, value)
- else:
- self[name] = [value]
+ ret.append([name, value])
+ self.lst = ret
def pretty_size(size):
diff --git a/test/test_server.py b/test/test_server.py
index e9b61165..1e3c1df4 100644
--- a/test/test_server.py
+++ b/test/test_server.py
@@ -44,7 +44,7 @@ class uProxy(tutils.ProxTest):
l = self.log()
assert l[0].address
- assert l[1].headers.has_key("host")
+ assert "host" in l[1].headers
assert l[2].code == 200
def test_https(self):
@@ -55,7 +55,7 @@ class uProxy(tutils.ProxTest):
l = self.log()
assert l[0].address
- assert l[1].headers.has_key("host")
+ assert "host" in l[1].headers
assert l[2].code == 200
# Disable these two for now: they take a long time.
diff --git a/test/test_utils.py b/test/test_utils.py
index b64db918..d5957872 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -43,88 +43,6 @@ class uData(libpry.AutoTree):
libpry.raises("does not exist", utils.data.path, "nonexistent")
-class uMultiDict(libpry.AutoTree):
- def setUp(self):
- self.md = utils.MultiDict()
-
- def test_setget(self):
- assert not self.md.has_key("foo")
- self.md.append("foo", 1)
- assert self.md["foo"] == [1]
- assert self.md.has_key("foo")
-
- def test_del(self):
- self.md.append("foo", 1)
- del self.md["foo"]
- assert not self.md.has_key("foo")
-
- def test_extend(self):
- self.md.append("foo", 1)
- self.md.extend("foo", [2, 3])
- assert self.md["foo"] == [1, 2, 3]
-
- def test_extend_err(self):
- self.md.append("foo", 1)
- libpry.raises("not iterable", self.md.extend, "foo", 2)
-
- def test_get(self):
- self.md.append("foo", 1)
- self.md.append("foo", 2)
- assert self.md.get("foo") == [1, 2]
- assert self.md.get("bar") == None
-
- def test_caseSensitivity(self):
- self.md._helper = (utils._caseless,)
- self.md["foo"] = [1]
- self.md.append("FOO", 2)
- assert self.md["foo"] == [1, 2]
- assert self.md["FOO"] == [1, 2]
- assert self.md.has_key("FoO")
-
- def test_dict(self):
- self.md.append("foo", 1)
- self.md.append("foo", 2)
- self.md["bar"] = [3]
- assert self.md == self.md
- assert dict(self.md) == self.md
-
- def test_copy(self):
- self.md["foo"] = [1, 2]
- self.md["bar"] = [3, 4]
- md2 = self.md.copy()
- assert md2 == self.md
- assert id(md2) != id(self.md)
-
- def test_clear(self):
- self.md["foo"] = [1, 2]
- self.md["bar"] = [3, 4]
- self.md.clear()
- assert not self.md.keys()
-
- def test_setitem(self):
- libpry.raises(ValueError, self.md.__setitem__, "foo", "bar")
- self.md["foo"] = ["bar"]
- assert self.md["foo"] == ["bar"]
-
- def test_itemPairs(self):
- self.md.append("foo", 1)
- self.md.append("foo", 2)
- self.md.append("bar", 3)
- l = list(self.md.itemPairs())
- assert len(l) == 3
- assert ("foo", 1) in l
- assert ("foo", 2) in l
- assert ("bar", 3) in l
-
- def test_getset_state(self):
- self.md.append("foo", 1)
- self.md.append("foo", 2)
- self.md.append("bar", 3)
- state = self.md.get_state()
- nd = utils.MultiDict.from_state(state)
- assert nd == self.md
-
-
class uHeaders(libpry.AutoTree):
def setUp(self):
self.hd = utils.Headers()
@@ -168,9 +86,9 @@ class uHeaders(libpry.AutoTree):
assert self.hd["header"] == ['one\r\n two']
def test_dictToHeader1(self):
- self.hd.append("one", "uno")
- self.hd.append("two", "due")
- self.hd.append("two", "tre")
+ self.hd.add("one", "uno")
+ self.hd.add("two", "due")
+ self.hd.add("two", "tre")
expected = [
"one: uno\r\n",
"two: due\r\n",
@@ -191,21 +109,34 @@ class uHeaders(libpry.AutoTree):
def test_match_re(self):
h = utils.Headers()
- h.append("one", "uno")
- h.append("two", "due")
- h.append("two", "tre")
+ h.add("one", "uno")
+ h.add("two", "due")
+ h.add("two", "tre")
assert h.match_re("uno")
assert h.match_re("two: due")
assert not h.match_re("nonono")
def test_getset_state(self):
- self.hd.append("foo", 1)
- self.hd.append("foo", 2)
- self.hd.append("bar", 3)
+ self.hd.add("foo", 1)
+ self.hd.add("foo", 2)
+ self.hd.add("bar", 3)
state = self.hd.get_state()
nd = utils.Headers.from_state(state)
assert nd == self.hd
+ def test_copy(self):
+ self.hd.add("foo", 1)
+ self.hd.add("foo", 2)
+ self.hd.add("bar", 3)
+ assert self.hd == self.hd.copy()
+
+ def test_del(self):
+ self.hd.add("foo", 1)
+ self.hd.add("Foo", 2)
+ self.hd.add("bar", 3)
+ del self.hd["foo"]
+ assert len(self.hd.lst) == 1
+
class uisStringLike(libpry.AutoTree):
def test_all(self):
@@ -371,7 +302,6 @@ tests = [
upretty_size(),
uisStringLike(),
uisSequenceLike(),
- uMultiDict(),
uHeaders(),
uData(),
upretty_xmlish(),