aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2011-08-19 21:30:24 +1200
committerAldo Cortesi <aldo@nullcube.com>2011-08-19 21:30:24 +1200
commita566684e3280ebbe15dd397710ee1b26bf8bd571 (patch)
tree6099d0213e8b2778823140d2e96e827743bd59ab
parent34adc83c717d82b8c9fc7fecc690b8b570a04644 (diff)
downloadmitmproxy-a566684e3280ebbe15dd397710ee1b26bf8bd571.tar.gz
mitmproxy-a566684e3280ebbe15dd397710ee1b26bf8bd571.tar.bz2
mitmproxy-a566684e3280ebbe15dd397710ee1b26bf8bd571.zip
Move to typed netstrings for serialization.
This change is backwards incompatible with the old serialization format!
-rw-r--r--libmproxy/flow.py28
-rw-r--r--libmproxy/netstring.py528
-rw-r--r--test/test_netstring.py65
3 files changed, 400 insertions, 221 deletions
diff --git a/libmproxy/flow.py b/libmproxy/flow.py
index 214ce8f9..8ec1a6a3 100644
--- a/libmproxy/flow.py
+++ b/libmproxy/flow.py
@@ -264,7 +264,7 @@ class Request(HTTPMsg):
self.method = state["method"]
self.path = state["path"]
self.headers = Headers._from_state(state["headers"])
- self.content = base64.decodestring(state["content"])
+ self.content = state["content"]
self.timestamp = state["timestamp"]
def _get_state(self):
@@ -276,7 +276,7 @@ class Request(HTTPMsg):
method = self.method,
path = self.path,
headers = self.headers._get_state(),
- content = base64.encodestring(self.content),
+ content = self.content,
timestamp = self.timestamp,
)
@@ -290,7 +290,7 @@ class Request(HTTPMsg):
str(state["method"]),
str(state["path"]),
Headers._from_state(state["headers"]),
- base64.decodestring(state["content"]),
+ state["content"],
state["timestamp"]
)
@@ -467,7 +467,7 @@ class Response(HTTPMsg):
self.code = state["code"]
self.msg = state["msg"]
self.headers = Headers._from_state(state["headers"])
- self.content = base64.decodestring(state["content"])
+ self.content = state["content"]
self.timestamp = state["timestamp"]
def _get_state(self):
@@ -476,7 +476,7 @@ class Response(HTTPMsg):
msg = self.msg,
headers = self.headers._get_state(),
timestamp = self.timestamp,
- content = base64.encodestring(self.content)
+ content = self.content
)
@classmethod
@@ -486,7 +486,7 @@ class Response(HTTPMsg):
state["code"],
str(state["msg"]),
Headers._from_state(state["headers"]),
- base64.decodestring(state["content"]),
+ state["content"],
state["timestamp"],
)
@@ -1316,12 +1316,10 @@ class FlowMaster(controller.Master):
class FlowWriter:
def __init__(self, fo):
self.fo = fo
- self.ns = netstring.FileEncoder(fo)
def add(self, flow):
d = flow._get_state()
- s = json.dumps(d)
- self.ns.write(s)
+ netstring.dump(d, self.fo)
class FlowReadError(Exception):
@@ -1333,16 +1331,20 @@ class FlowReadError(Exception):
class FlowReader:
def __init__(self, fo):
self.fo = fo
- self.ns = netstring.decode_file(fo)
def stream(self):
"""
Yields Flow objects from the dump.
"""
+ off = 0
try:
- for i in self.ns:
- data = json.loads(i)
+ while 1:
+ data = netstring.load(self.fo)
+ off = self.fo.tell()
yield Flow._from_state(data)
- except netstring.DecoderError:
+ except ValueError, v:
+ # Error is due to EOF
+ if self.fo.tell() == off and self.fo.read() == '':
+ return
raise FlowReadError("Invalid data format.")
diff --git a/libmproxy/netstring.py b/libmproxy/netstring.py
index 669e19e3..03e38c6a 100644
--- a/libmproxy/netstring.py
+++ b/libmproxy/netstring.py
@@ -1,151 +1,393 @@
"""
- Netstring is a module for encoding and decoding netstring streams.
- See http://cr.yp.to/proto/netstrings.txt for more information on netstrings.
- Author: Will McGugan (http://www.willmcgugan.com)
+
+tnetstring: data serialization using typed netstrings
+======================================================
+
+
+This is a data serialization library. It's a lot like JSON but it uses a
+new syntax called "typed netstrings" that Zed has proposed for use in the
+Mongrel2 webserver. It's designed to be simpler and easier to implement
+than JSON, with a happy consequence of also being faster in many cases.
+
+An ordinary netstring is a blob of data prefixed with its length and postfixed
+with a sanity-checking comma. The string "hello world" encodes like this::
+
+ 11:hello world,
+
+Typed netstrings add other datatypes by replacing the comma with a type tag.
+Here's the integer 12345 encoded as a tnetstring::
+
+ 5:12345#
+
+And here's the list [12345,True,0] which mixes integers and bools::
+
+ 19:5:12345#4:true!1:0#]
+
+Simple enough? This module gives you the following functions:
+
+ :dump: dump an object as a tnetstring to a file
+ :dumps: dump an object as a tnetstring to a string
+ :load: load a tnetstring-encoded object from a file
+ :loads: load a tnetstring-encoded object from a string
+ :pop: pop a tnetstring-encoded object from the front of a string
+
+Note that since parsing a tnetstring requires reading all the data into memory
+at once, there's no efficiency gain from using the file-based versions of these
+functions. They're only here so you can use load() to read precisely one
+item from a file or socket without consuming any extra data.
+
+By default tnetstrings work only with byte strings, not unicode. If you want
+unicode strings then pass an optional encoding to the various functions,
+like so::
+
+ >>> print repr(tnetstring.loads("2:\\xce\\xb1,"))
+ '\\xce\\xb1'
+ >>>
+ >>> print repr(tnetstring.loads("2:\\xce\\xb1,","utf8"))
+ u'\u03b1'
+
"""
-from cStringIO import StringIO
-
-def header(data):
- return str(len(data))+":"
-
-
-class FileEncoder(object):
- def __init__(self, file_out):
- """"
- file_out -- A writable file object
- """
- self.file_out = file_out
-
- def write(self, data):
- """
- Encodes a netstring and writes it to the file object.
-
- data -- A string to be encoded and written
- """
- write = self.file_out.write
- write(header(data))
+__ver_major__ = 0
+__ver_minor__ = 2
+__ver_patch__ = 0
+__ver_sub__ = ""
+__version__ = "%d.%d.%d%s" % (__ver_major__,__ver_minor__,__ver_patch__,__ver_sub__)
+
+
+from collections import deque
+
+
+def dumps(value,encoding=None):
+ """dumps(object,encoding=None) -> string
+
+ This function dumps a python object as a tnetstring.
+ """
+ # This uses a deque to collect output fragments in reverse order,
+ # then joins them together at the end. It's measurably faster
+ # than creating all the intermediate strings.
+ # If you're reading this to get a handle on the tnetstring format,
+ # consider the _gdumps() function instead; it's a standard top-down
+ # generator that's simpler to understand but much less efficient.
+ q = deque()
+ _rdumpq(q,0,value,encoding)
+ return "".join(q)
+
+
+def dump(value,file,encoding=None):
+ """dump(object,file,encoding=None)
+
+ This function dumps a python object as a tnetstring and writes it to
+ the given file.
+ """
+ file.write(dumps(value,encoding))
+
+
+def _rdumpq(q,size,value,encoding=None):
+ """Dump value as a tnetstring, to a deque instance, last chunks first.
+
+ This function generates the tnetstring representation of the given value,
+ pushing chunks of the output onto the given deque instance. It pushes
+ the last chunk first, then recursively generates more chunks.
+
+ When passed in the current size of the string in the queue, it will return
+ the new size of the string in the queue.
+
+ Operating last-chunk-first makes it easy to calculate the size written
+ for recursive structures without having to build their representation as
+ a string. This is measurably faster than generating the intermediate
+ strings, especially on deeply nested structures.
+ """
+ write = q.appendleft
+ if value is None:
+ write("0:~")
+ return size + 3
+ if value is True:
+ write("4:true!")
+ return size + 7
+ if value is False:
+ write("5:false!")
+ return size + 8
+ if isinstance(value,(int,long)):
+ data = str(value)
+ ldata = len(data)
+ span = str(ldata)
+ write("#")
write(data)
- write(',')
- return self
-
-
-class DecoderError(Exception):
- PRECEDING_ZERO_IN_SIZE = 0
- MAX_SIZE_REACHED = 1
- ILLEGAL_DIGIT_IN_SIZE = 2
- ILLEGAL_DIGIT = 3
- error_text = {
- PRECEDING_ZERO_IN_SIZE: "PRECEDING_ZERO_IN_SIZE",
- MAX_SIZE_REACHED: "MAX_SIZE_REACHED",
- ILLEGAL_DIGIT_IN_SIZE: "ILLEGAL_DIGIT_IN_SIZE",
- ILLEGAL_DIGIT: "ILLEGAL_DIGIT"
- }
- def __init__(self, code, text):
- Exception.__init__(self)
- self.code = code
- self.text = text
-
- def __str__(self):
- return "%s (#%i), %s" % (DecoderError.error_text[self.code], self.code, self.text)
-
-
-class Decoder(object):
+ write(":")
+ write(span)
+ return size + 2 + len(span) + ldata
+ if isinstance(value,(float,)):
+ # Use repr() for float rather than str().
+ # It round-trips more accurately.
+ # Probably unnecessary in later python versions that
+ # use David Gay's ftoa routines.
+ data = repr(value)
+ ldata = len(data)
+ span = str(ldata)
+ write("^")
+ write(data)
+ write(":")
+ write(span)
+ return size + 2 + len(span) + ldata
+ if isinstance(value,str):
+ lvalue = len(value)
+ span = str(lvalue)
+ write(",")
+ write(value)
+ write(":")
+ write(span)
+ return size + 2 + len(span) + lvalue
+ if isinstance(value,(list,tuple,)):
+ write("]")
+ init_size = size = size + 1
+ for item in reversed(value):
+ size = _rdumpq(q,size,item,encoding)
+ span = str(size - init_size)
+ write(":")
+ write(span)
+ return size + 1 + len(span)
+ if isinstance(value,dict):
+ write("}")
+ init_size = size = size + 1
+ for (k,v) in value.iteritems():
+ size = _rdumpq(q,size,v,encoding)
+ size = _rdumpq(q,size,k,encoding)
+ span = str(size - init_size)
+ write(":")
+ write(span)
+ return size + 1 + len(span)
+ if isinstance(value,unicode):
+ if encoding is None:
+ raise ValueError("must specify encoding to dump unicode strings")
+ value = value.encode(encoding)
+ lvalue = len(value)
+ span = str(lvalue)
+ write(",")
+ write(value)
+ write(":")
+ write(span)
+ return size + 2 + len(span) + lvalue
+ raise ValueError("unserializable object")
+
+
+def _gdumps(value,encoding):
+ """Generate fragments of value dumped as a tnetstring.
+
+ This is the naive dumping algorithm, implemented as a generator so that
+ it's easy to pass to "".join() without building a new list.
+
+ This is mainly here for comparison purposes; the _rdumpq version is
+ measurably faster as it doesn't have to build intermediate strins.
"""
- A netstring decoder.
- Turns a netstring stream in to a number of discreet strings.
+ if value is None:
+ yield "0:~"
+ elif value is True:
+ yield "4:true!"
+ elif value is False:
+ yield "5:false!"
+ elif isinstance(value,(int,long)):
+ data = str(value)
+ yield str(len(data))
+ yield ":"
+ yield data
+ yield "#"
+ elif isinstance(value,(float,)):
+ data = repr(value)
+ yield str(len(data))
+ yield ":"
+ yield data
+ yield "^"
+ elif isinstance(value,(str,)):
+ yield str(len(value))
+ yield ":"
+ yield value
+ yield ","
+ elif isinstance(value,(list,tuple,)):
+ sub = []
+ for item in value:
+ sub.extend(_gdumps(item))
+ sub = "".join(sub)
+ yield str(len(sub))
+ yield ":"
+ yield sub
+ yield "]"
+ elif isinstance(value,(dict,)):
+ sub = []
+ for (k,v) in value.iteritems():
+ sub.extend(_gdumps(k))
+ sub.extend(_gdumps(v))
+ sub = "".join(sub)
+ yield str(len(sub))
+ yield ":"
+ yield sub
+ yield "}"
+ elif isinstance(value,(unicode,)):
+ if encoding is None:
+ raise ValueError("must specify encoding to dump unicode strings")
+ value = value.encode(encoding)
+ yield str(len(value))
+ yield ":"
+ yield value
+ yield ","
+ else:
+ raise ValueError("unserializable object")
+
+
+def loads(string,encoding=None):
+ """loads(string,encoding=None) -> object
+
+ This function parses a tnetstring into a python object.
"""
- def __init__(self, max_size=None):
- """
- Create a netstring-stream decoder object.
-
- max_size -- The maximum size of a netstring encoded string, after which
- a DecoderError will be throw. A value of None (the default) indicates
- that there should be no maximum string size.
- """
- self.max_size = max_size
- self.data_pos = 0
- self.string_start = 0
- self.expecting_terminator = False
- self.size_string = ""
- self.data_size = None
- self.remaining_bytes = 0
- self.data_out = StringIO()
- self.yield_data = ""
-
- def feed(self, data):
- """
- A generator that yields 0 or more strings from the given data.
-
- data -- A string containing complete or partial netstring data
- """
- self.data_pos = 0
- self.string_start = 0
- while self.data_pos < len(data):
- if self.expecting_terminator:
- c = data[self.data_pos]
- self.data_pos += 1
- if c != ',':
- raise DecoderError(DecoderError.ILLEGAL_DIGIT, "Illegal digit (%s) at end of data"%repr(c))
- yield self.yield_data
- self.yield_data = ""
- self.expecting_terminator = False
- elif self.data_size is None:
- c = data[self.data_pos]
- self.data_pos += 1
-
- if not len(self.size_string):
- self.string_start = self.data_pos-1
-
- if c in "0123456789":
- if self.size_string == '0':
- raise DecoderError(DecoderError.PRECEDING_ZERO_IN_SIZE, "Preceding zeros in size field illegal")
- self.size_string += c
- if self.max_size is not None and int(self.size_string) > self.max_size:
- raise DecoderError(DecoderError.MAX_SIZE_REACHED, "Maximum size of netstring exceeded")
-
- elif c == ":":
- if not len(self.size_string):
- raise DecoderError(DecoderError.ILLEGAL_DIGIT_IN_SIZE, "Illegal digit (%s) in size field"%repr(c))
- self.data_size = int(self.size_string)
- self.remaining_bytes = self.data_size
-
- else:
- raise DecoderError(DecoderError.ILLEGAL_DIGIT_IN_SIZE, "Illegal digit (%s) in size field"%repr(c))
-
- elif self.data_size is not None:
- get_bytes = min(self.remaining_bytes, len(data)-self.data_pos)
- chunk = data[self.data_pos:self.data_pos+get_bytes]
- whole_string = len(chunk) == self.data_size
- if not whole_string:
- self.data_out.write(chunk)
- self.data_pos += get_bytes
- self.remaining_bytes -= get_bytes
- if self.remaining_bytes == 0:
- if whole_string:
- self.yield_data = chunk
- else:
- self.yield_data = self.data_out.getvalue()
- self.data_out.reset()
- self.data_out.truncate()
- self.data_size = None
- self.size_string = ""
- self.remaining_bytes = 0
- self.expecting_terminator = True
-
-
-def decode_file(file_in, buffer_size=1024):
+ # No point duplicating effort here. In the C-extension version,
+ # loads() is measurably faster then pop() since it can avoid
+ # the overhead of building a second string.
+ return pop(string,encoding)[0]
+
+
+def load(file,encoding=None):
+ """load(file,encoding=None) -> object
+
+ This function reads a tnetstring from a file and parses it into a
+ python object. The file must support the read() method, and this
+ function promises not to read more data than necessary.
"""
- Generates 0 or more strings from a netstring file.
-
- file_in -- A readable file-like object containing netstring data
- buffer_size -- The number of bytes to attempt to read in each iteration
- (default = 1024).
+ # Read the length prefix one char at a time.
+ # Note that the netstring spec explicitly forbids padding zeros.
+ c = file.read(1)
+ if not c.isdigit():
+ raise ValueError("not a tnetstring: missing or invalid length prefix")
+ datalen = ord(c) - ord("0")
+ c = file.read(1)
+ if datalen != 0:
+ while c.isdigit():
+ datalen = (10 * datalen) + (ord(c) - ord("0"))
+ if datalen > 999999999:
+ errmsg = "not a tnetstring: absurdly large length prefix"
+ raise ValueError(errmsg)
+ c = file.read(1)
+ if c != ":":
+ raise ValueError("not a tnetstring: missing or invalid length prefix")
+ # Now we can read and parse the payload.
+ # This repeats the dispatch logic of pop() so we can avoid
+ # re-constructing the outermost tnetstring.
+ data = file.read(datalen)
+ if len(data) != datalen:
+ raise ValueError("not a tnetstring: length prefix too big")
+ type = file.read(1)
+ if type == ",":
+ if encoding is not None:
+ return data.decode(encoding)
+ return data
+ if type == "#":
+ try:
+ return int(data)
+ except ValueError:
+ raise ValueError("not a tnetstring: invalid integer literal")
+ if type == "^":
+ try:
+ return float(data)
+ except ValueError:
+ raise ValueError("not a tnetstring: invalid float literal")
+ if type == "!":
+ if data == "true":
+ return True
+ elif data == "false":
+ return False
+ else:
+ raise ValueError("not a tnetstring: invalid boolean literal")
+ if type == "~":
+ if data:
+ raise ValueError("not a tnetstring: invalid null literal")
+ return None
+ if type == "]":
+ l = []
+ while data:
+ (item,data) = pop(data,encoding)
+ l.append(item)
+ return l
+ if type == "}":
+ d = {}
+ while data:
+ (key,data) = pop(data,encoding)
+ (val,data) = pop(data,encoding)
+ d[key] = val
+ return d
+ raise ValueError("unknown type tag")
+
+
+
+def pop(string,encoding=None):
+ """pop(string,encoding=None) -> (object, remain)
+
+ This function parses a tnetstring into a python object.
+ It returns a tuple giving the parsed object and a string
+ containing any unparsed data from the end of the string.
"""
- decoder = Decoder()
- while True:
- data = file_in.read(buffer_size)
- if not len(data):
- return
- for s in decoder.feed(data):
- yield s
+ # Parse out data length, type and remaining string.
+ try:
+ (dlen,rest) = string.split(":",1)
+ dlen = int(dlen)
+ except ValueError:
+ raise ValueError("not a tnetstring: missing or invalid length prefix")
+ try:
+ (data,type,remain) = (rest[:dlen],rest[dlen],rest[dlen+1:])
+ except IndexError:
+ # This fires if len(rest) < dlen, meaning we don't need
+ # to further validate that data is the right length.
+ raise ValueError("not a tnetstring: invalid length prefix")
+ # Parse the data based on the type tag.
+ if type == ",":
+ if encoding is not None:
+ return (data.decode(encoding),remain)
+ return (data,remain)
+ if type == "#":
+ try:
+ return (int(data),remain)
+ except ValueError:
+ raise ValueError("not a tnetstring: invalid integer literal")
+ if type == "^":
+ try:
+ return (float(data),remain)
+ except ValueError:
+ raise ValueError("not a tnetstring: invalid float literal")
+ if type == "!":
+ if data == "true":
+ return (True,remain)
+ elif data == "false":
+ return (False,remain)
+ else:
+ raise ValueError("not a tnetstring: invalid boolean literal")
+ if type == "~":
+ if data:
+ raise ValueError("not a tnetstring: invalid null literal")
+ return (None,remain)
+ if type == "]":
+ l = []
+ while data:
+ (item,data) = pop(data,encoding)
+ l.append(item)
+ return (l,remain)
+ if type == "}":
+ d = {}
+ while data:
+ (key,data) = pop(data,encoding)
+ (val,data) = pop(data,encoding)
+ d[key] = val
+ return (d,remain)
+ raise ValueError("unknown type tag")
+
+
+
+# Use the c-extension version if available
+try:
+ import _tnetstring
+except ImportError:
+ pass
+else:
+ dumps = _tnetstring.dumps
+ load = _tnetstring.load
+ loads = _tnetstring.loads
+ pop = _tnetstring.pop
+
+
diff --git a/test/test_netstring.py b/test/test_netstring.py
deleted file mode 100644
index 5c4f775a..00000000
--- a/test/test_netstring.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from libmproxy import netstring
-from cStringIO import StringIO
-import libpry
-
-
-class uNetstring(libpry.AutoTree):
- def setUp(self):
- self.test_data = "Netstring module by Will McGugan"
- self.encoded_data = "9:Netstring,6:module,2:by,4:Will,7:McGugan,"
-
- def test_header(self):
- t = [ ("netstring", "9:"),
- ("Will McGugan", "12:"),
- ("", "0:") ]
- for test, result in t:
- assert netstring.header(test) == result
-
- def test_file_encoder(self):
- file_out = StringIO()
- data = self.test_data.split()
- encoder = netstring.FileEncoder(file_out)
- for s in data:
- encoder.write(s)
- encoded_data = file_out.getvalue()
- assert encoded_data == self.encoded_data
-
- def test_decode_file(self):
- data = self.test_data.split()
- for buffer_size in range(1, len(self.encoded_data)):
- file_in = StringIO(self.encoded_data[:])
- decoded_data = list(netstring.decode_file(file_in, buffer_size = buffer_size))
- assert decoded_data == data
-
- def test_decoder(self):
- encoded_data = self.encoded_data
- for step in range(1, len(encoded_data)):
- i = 0
- chunks = []
- while i < len(encoded_data):
- chunks.append(encoded_data[i:i+step])
- i += step
- decoder = netstring.Decoder()
- decoded_data = []
- for chunk in chunks:
- for s in decoder.feed(chunk):
- decoded_data.append(s)
- assert decoded_data == self.test_data.split()
-
- def test_errors(self):
- d = netstring.Decoder()
- libpry.raises("Illegal digit", list, d.feed("1:foo"))
- d = netstring.Decoder()
- libpry.raises("Preceding zero", list, d.feed("01:f"))
- d = netstring.Decoder(5)
- libpry.raises("Maximum size", list, d.feed("500:f"))
- d = netstring.Decoder()
- libpry.raises("Illegal digit", list, d.feed(":f"))
-
-
-
-
-tests = [
- uNetstring()
-]
-