diff options
-rw-r--r-- | mitmproxy/contrib/tnetstring.py (renamed from mitmproxy/tnetstring.py) | 280 | ||||
-rw-r--r-- | mitmproxy/flow/io.py | 2 | ||||
-rw-r--r-- | test/mitmproxy/test_contrib_tnetstring.py | 141 | ||||
-rw-r--r-- | test/mitmproxy/test_flow.py | 3 | ||||
-rw-r--r-- | tox.ini | 2 |
5 files changed, 273 insertions, 155 deletions
diff --git a/mitmproxy/tnetstring.py b/mitmproxy/contrib/tnetstring.py index f40e8ad8..9bf20b09 100644 --- a/mitmproxy/tnetstring.py +++ b/mitmproxy/contrib/tnetstring.py @@ -79,9 +79,8 @@ __version__ = "%d.%d.%d%s" % ( __ver_major__, __ver_minor__, __ver_patch__, __ver_sub__) -def dumps(value, encoding=None): - """dumps(object,encoding=None) -> string - +def dumps(value): + """ This function dumps a python object as a tnetstring. """ # This uses a deque to collect output fragments in reverse order, @@ -91,22 +90,21 @@ def dumps(value, encoding=None): # consider the _gdumps() function instead; it's a standard top-down # generator that's simpler to understand but much less efficient. q = deque() - _rdumpq(q, 0, value, encoding) - return "".join(q) - + _rdumpq(q, 0, value) + return b''.join(q) -def dump(value, file, encoding=None): - """dump(object,file,encoding=None) - This function dumps a python object as a tnetstring and writes it to - the given file. +def dump(value, file_handle): """ - file.write(dumps(value, encoding)) - file.flush() + This function dumps a python object as a tnetstring and + writes it to the given file. + """ + file_handle.write(dumps(value)) -def _rdumpq(q, size, value, encoding=None): - """Dump value as a tnetstring, to a deque instance, last chunks first. +def _rdumpq(q, size, value): + """ + Dump value as a tnetstring, to a deque instance, last chunks first. This function generates the tnetstring representation of the given value, pushing chunks of the output onto the given deque instance. It pushes @@ -122,79 +120,70 @@ def _rdumpq(q, size, value, encoding=None): """ write = q.appendleft if value is None: - write("0:~") + write(b'0:~') return size + 3 - if value is True: - write("4:true!") + elif value is True: + write(b'4:true!') return size + 7 - if value is False: - write("5:false!") + elif value is False: + write(b'5:false!') return size + 8 - if isinstance(value, six.integer_types): - data = str(value) + elif isinstance(value, six.integer_types): + data = str(value).encode() ldata = len(data) - span = str(ldata) - write("#") + span = str(ldata).encode() + write(b'#') write(data) - write(":") + write(b':') write(span) return size + 2 + len(span) + ldata - if isinstance(value, (float,)): + elif isinstance(value, float): # Use repr() for float rather than str(). # It round-trips more accurately. # Probably unnecessary in later python versions that # use David Gay's ftoa routines. - data = repr(value) + data = repr(value).encode() ldata = len(data) - span = str(ldata) - write("^") + span = str(ldata).encode() + write(b'^') write(data) - write(":") + write(b':') write(span) return size + 2 + len(span) + ldata - if isinstance(value, str): + elif isinstance(value, bytes): lvalue = len(value) - span = str(lvalue) - write(",") + span = str(lvalue).encode() + write(b',') write(value) - write(":") + write(b':') write(span) return size + 2 + len(span) + lvalue - if isinstance(value, (list, tuple,)): - write("]") + elif isinstance(value, (list, tuple)): + write(b']') init_size = size = size + 1 for item in reversed(value): - size = _rdumpq(q, size, item, encoding) - span = str(size - init_size) - write(":") + size = _rdumpq(q, size, item) + span = str(size - init_size).encode() + write(b':') write(span) return size + 1 + len(span) - if isinstance(value, dict): - write("}") + elif isinstance(value, dict): + write(b'}') init_size = size = size + 1 - for (k, v) in six.iteritems(value): - size = _rdumpq(q, size, v, encoding) - size = _rdumpq(q, size, k, encoding) - span = str(size - init_size) - write(":") + for (k, v) in value.items(): + size = _rdumpq(q, size, v) + size = _rdumpq(q, size, k) + span = str(size - init_size).encode() + write(b':') write(span) return size + 1 + len(span) - if isinstance(value, unicode): - if encoding is None: - raise ValueError("must specify encoding to dump unicode strings") - value = value.encode(encoding) - lvalue = len(value) - span = str(lvalue) - write(",") - write(value) - write(":") - write(span) - return size + 2 + len(span) + lvalue - raise ValueError("unserializable object") + else: + raise ValueError("unserializable object: {} ({})".format(value, type(value))) -def _gdumps(value, encoding): - """Generate fragments of value dumped as a tnetstring. +def _gdumps(value): + """ + Generate fragments of value dumped as a tnetstring. This is the naive dumping algorithm, implemented as a generator so that it's easy to pass to "".join() without building a new list. @@ -203,72 +192,63 @@ def _gdumps(value, encoding): measurably faster as it doesn't have to build intermediate strins. """ if value is None: - yield "0:~" + yield b'0:~' elif value is True: - yield "4:true!" + yield b'4:true!' elif value is False: - yield "5:false!" + yield b'5:false!' elif isinstance(value, six.integer_types): - data = str(value) - yield str(len(data)) - yield ":" + data = str(value).encode() + yield str(len(data)).encode() + yield b':' yield data - yield "#" - elif isinstance(value, (float,)): - data = repr(value) - yield str(len(data)) - yield ":" + yield b'#' + elif isinstance(value, float): + data = repr(value).encode() + yield str(len(data)).encode() + yield b':' yield data - yield "^" - elif isinstance(value, (str,)): - yield str(len(value)) - yield ":" + yield b'^' + elif isinstance(value, bytes): + yield str(len(value)).encode() + yield b':' yield value - yield "," - elif isinstance(value, (list, tuple,)): + yield b',' + elif isinstance(value, (list, tuple)): sub = [] for item in value: sub.extend(_gdumps(item)) - sub = "".join(sub) - yield str(len(sub)) - yield ":" + sub = b''.join(sub) + yield str(len(sub)).encode() + yield b':' yield sub - yield "]" + yield b']' elif isinstance(value, (dict,)): sub = [] - for (k, v) in six.iteritems(value): + for (k, v) in value.items(): sub.extend(_gdumps(k)) sub.extend(_gdumps(v)) - sub = "".join(sub) - yield str(len(sub)) - yield ":" + sub = b''.join(sub) + yield str(len(sub)).encode() + yield b':' yield sub - yield "}" - elif isinstance(value, (unicode,)): - if encoding is None: - raise ValueError("must specify encoding to dump unicode strings") - value = value.encode(encoding) - yield str(len(value)) - yield ":" - yield value - yield "," + yield b'}' else: raise ValueError("unserializable object") -def loads(string, encoding=None): - """loads(string,encoding=None) -> object - +def loads(string): + """ This function parses a tnetstring into a python object. """ # No point duplicating effort here. In the C-extension version, # loads() is measurably faster then pop() since it can avoid # the overhead of building a second string. - return pop(string, encoding)[0] + return pop(string)[0] -def load(file, encoding=None): - """load(file,encoding=None) -> object +def load(file_handle): + """load(file) -> object This function reads a tnetstring from a file and parses it into a python object. The file must support the read() method, and this @@ -276,70 +256,68 @@ def load(file, encoding=None): """ # Read the length prefix one char at a time. # Note that the netstring spec explicitly forbids padding zeros. - c = file.read(1) + c = file_handle.read(1) if not c.isdigit(): raise ValueError("not a tnetstring: missing or invalid length prefix") - datalen = ord(c) - ord("0") - c = file.read(1) + datalen = ord(c) - ord('0') + c = file_handle.read(1) if datalen != 0: while c.isdigit(): - datalen = (10 * datalen) + (ord(c) - ord("0")) + datalen = (10 * datalen) + (ord(c) - ord('0')) if datalen > 999999999: errmsg = "not a tnetstring: absurdly large length prefix" raise ValueError(errmsg) - c = file.read(1) - if c != ":": + c = file_handle.read(1) + if c != b':': raise ValueError("not a tnetstring: missing or invalid length prefix") # Now we can read and parse the payload. # This repeats the dispatch logic of pop() so we can avoid # re-constructing the outermost tnetstring. - data = file.read(datalen) + data = file_handle.read(datalen) if len(data) != datalen: raise ValueError("not a tnetstring: length prefix too big") - type = file.read(1) - if type == ",": - if encoding is not None: - return data.decode(encoding) + tns_type = file_handle.read(1) + if tns_type == b',': return data - if type == "#": + if tns_type == b'#': try: return int(data) except ValueError: raise ValueError("not a tnetstring: invalid integer literal") - if type == "^": + if tns_type == b'^': try: return float(data) except ValueError: raise ValueError("not a tnetstring: invalid float literal") - if type == "!": - if data == "true": + if tns_type == b'!': + if data == b'true': return True - elif data == "false": + elif data == b'false': return False else: raise ValueError("not a tnetstring: invalid boolean literal") - if type == "~": + if tns_type == b'~': if data: raise ValueError("not a tnetstring: invalid null literal") return None - if type == "]": + if tns_type == b']': l = [] while data: - (item, data) = pop(data, encoding) + item, data = pop(data) l.append(item) return l - if type == "}": + if tns_type == b'}': d = {} while data: - (key, data) = pop(data, encoding) - (val, data) = pop(data, encoding) + key, data = pop(data) + val, data = pop(data) d[key] = val return d raise ValueError("unknown type tag") -def pop(string, encoding=None): - """pop(string,encoding=None) -> (object, remain) +def pop(string): + """pop(string,encoding='utf_8') -> (object, remain) This function parses a tnetstring into a python object. It returns a tuple giving the parsed object and a string @@ -347,53 +325,51 @@ def pop(string, encoding=None): """ # Parse out data length, type and remaining string. try: - (dlen, rest) = string.split(":", 1) + dlen, rest = string.split(b':', 1) dlen = int(dlen) except ValueError: - raise ValueError("not a tnetstring: missing or invalid length prefix") + raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(string)) try: - (data, type, remain) = (rest[:dlen], rest[dlen], rest[dlen + 1:]) + data, tns_type, remain = rest[:dlen], rest[dlen:dlen + 1], rest[dlen + 1:] except IndexError: # This fires if len(rest) < dlen, meaning we don't need # to further validate that data is the right length. - raise ValueError("not a tnetstring: invalid length prefix") + raise ValueError("not a tnetstring: invalid length prefix: {}".format(dlen)) # Parse the data based on the type tag. - if type == ",": - if encoding is not None: - return (data.decode(encoding), remain) - return (data, remain) - if type == "#": + if tns_type == b',': + return data, remain + if tns_type == b'#': try: - return (int(data), remain) + return int(data), remain except ValueError: - raise ValueError("not a tnetstring: invalid integer literal") - if type == "^": + raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) + if tns_type == b'^': try: - return (float(data), remain) + return float(data), remain except ValueError: - raise ValueError("not a tnetstring: invalid float literal") - if type == "!": - if data == "true": - return (True, remain) - elif data == "false": - return (False, remain) + raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) + if tns_type == b'!': + if data == b'true': + return True, remain + elif data == b'false': + return False, remain else: - raise ValueError("not a tnetstring: invalid boolean literal") - if type == "~": + raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) + if tns_type == b'~': if data: raise ValueError("not a tnetstring: invalid null literal") - return (None, remain) - if type == "]": + return None, remain + if tns_type == b']': l = [] while data: - (item, data) = pop(data, encoding) + item, data = pop(data) l.append(item) return (l, remain) - if type == "}": + if tns_type == b'}': d = {} while data: - (key, data) = pop(data, encoding) - (val, data) = pop(data, encoding) + key, data = pop(data) + val, data = pop(data) d[key] = val - return (d, remain) - raise ValueError("unknown type tag") + return d, remain + raise ValueError("unknown type tag: {}".format(tns_type)) diff --git a/mitmproxy/flow/io.py b/mitmproxy/flow/io.py index cd3d9986..671ddf43 100644 --- a/mitmproxy/flow/io.py +++ b/mitmproxy/flow/io.py @@ -4,7 +4,7 @@ import os from mitmproxy import exceptions from mitmproxy import models -from mitmproxy import tnetstring +from mitmproxy.contrib import tnetstring from mitmproxy.flow import io_compat diff --git a/test/mitmproxy/test_contrib_tnetstring.py b/test/mitmproxy/test_contrib_tnetstring.py new file mode 100644 index 00000000..17654ad9 --- /dev/null +++ b/test/mitmproxy/test_contrib_tnetstring.py @@ -0,0 +1,141 @@ +import unittest +import random +import math +import io +import struct + +from mitmproxy.contrib import tnetstring + +MAXINT = 2 ** (struct.Struct('i').size * 8 - 1) - 1 + +FORMAT_EXAMPLES = { + b'0:}': {}, + b'0:]': [], + b'51:5:hello,39:11:12345678901#4:this,4:true!0:~4:\x00\x00\x00\x00,]}': + {b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, + b'5:12345#': 12345, + b'12:this is cool,': b'this is cool', + b'0:,': b'', + b'0:~': None, + b'4:true!': True, + b'5:false!': False, + b'10:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + b'24:5:12345#5:67890#5:xxxxx,]': [12345, 67890, b'xxxxx'], + b'18:3:0.1^3:0.2^3:0.3^]': [0.1, 0.2, 0.3], + b'243:238:233:228:223:218:213:208:203:198:193:188:183:178:173:168:163:158:153:148:143:138:133:128:123:118:113:108:103:99:95:91:87:83:79:75:71:67:63:59:55:51:47:43:39:35:31:27:23:19:15:11:hello-there,]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]': [[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[b'hello-there']]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]] # noqa +} + + +def get_random_object(random=random, depth=0): + """Generate a random serializable object.""" + # The probability of generating a scalar value increases as the depth increase. + # This ensures that we bottom out eventually. + if random.randint(depth, 10) <= 4: + what = random.randint(0, 1) + if what == 0: + n = random.randint(0, 10) + l = [] + for _ in range(n): + l.append(get_random_object(random, depth + 1)) + return l + if what == 1: + n = random.randint(0, 10) + d = {} + for _ in range(n): + n = random.randint(0, 100) + k = bytes([random.randint(32, 126) for _ in range(n)]) + d[k] = get_random_object(random, depth + 1) + return d + else: + what = random.randint(0, 4) + if what == 0: + return None + if what == 1: + return True + if what == 2: + return False + if what == 3: + if random.randint(0, 1) == 0: + return random.randint(0, MAXINT) + else: + return -1 * random.randint(0, MAXINT) + n = random.randint(0, 100) + return bytes([random.randint(32, 126) for _ in range(n)]) + + +class Test_Format(unittest.TestCase): + + def test_roundtrip_format_examples(self): + for data, expect in FORMAT_EXAMPLES.items(): + self.assertEqual(expect, tnetstring.loads(data)) + self.assertEqual( + expect, tnetstring.loads(tnetstring.dumps(expect))) + self.assertEqual((expect, b''), tnetstring.pop(data)) + + def test_roundtrip_format_random(self): + for _ in range(500): + v = get_random_object() + self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) + self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v))) + + def test_unicode_handling(self): + with self.assertRaises(ValueError): + tnetstring.dumps(u"hello") + self.assertEqual(tnetstring.dumps(u"hello".encode()), b"5:hello,") + self.assertEqual(type(tnetstring.loads(b"5:hello,")), bytes) + + def test_roundtrip_format_unicode(self): + for _ in range(500): + v = get_random_object() + self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) + self.assertEqual((v, b''), tnetstring.pop(tnetstring.dumps(v))) + + def test_roundtrip_big_integer(self): + i1 = math.factorial(30000) + s = tnetstring.dumps(i1) + i2 = tnetstring.loads(s) + self.assertEqual(i1, i2) + + +class Test_FileLoading(unittest.TestCase): + + def test_roundtrip_file_examples(self): + for data, expect in FORMAT_EXAMPLES.items(): + s = io.BytesIO() + s.write(data) + s.write(b'OK') + s.seek(0) + self.assertEqual(expect, tnetstring.load(s)) + self.assertEqual(b'OK', s.read()) + s = io.BytesIO() + tnetstring.dump(expect, s) + s.write(b'OK') + s.seek(0) + self.assertEqual(expect, tnetstring.load(s)) + self.assertEqual(b'OK', s.read()) + + def test_roundtrip_file_random(self): + for _ in range(500): + v = get_random_object() + s = io.BytesIO() + tnetstring.dump(v, s) + s.write(b'OK') + s.seek(0) + self.assertEqual(v, tnetstring.load(s)) + self.assertEqual(b'OK', s.read()) + + def test_error_on_absurd_lengths(self): + s = io.BytesIO() + s.write(b'1000000000:pwned!,') + s.seek(0) + with self.assertRaises(ValueError): + tnetstring.load(s) + self.assertEqual(s.read(1), b':') + + +def suite(): + loader = unittest.TestLoader() + suite = unittest.TestSuite() + suite.addTest(loader.loadTestsFromTestCase(Test_Format)) + suite.addTest(loader.loadTestsFromTestCase(Test_FileLoading)) + return suite diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index af8256c4..9eaab9aa 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -5,7 +5,8 @@ import mock import netlib.utils from netlib.http import Headers -from mitmproxy import filt, controller, tnetstring, flow +from mitmproxy import filt, controller, flow +from mitmproxy.contrib import tnetstring from mitmproxy.exceptions import FlowReadException, ScriptException from mitmproxy.models import Error from mitmproxy.models import Flow @@ -7,7 +7,7 @@ deps = codecov>=2.0.5 passenv = CI TRAVIS_BUILD_ID TRAVIS TRAVIS_BRANCH TRAVIS_JOB_NUMBER TRAVIS_PULL_REQUEST TRAVIS_JOB_ID TRAVIS_REPO_SLUG TRAVIS_COMMIT setenv = - PY3TESTS = test/netlib test/pathod/ test/mitmproxy/script test/mitmproxy/test_contentview.py test/mitmproxy/test_custom_contentview.py test/mitmproxy/test_app.py test/mitmproxy/test_controller.py test/mitmproxy/test_fuzzing.py test/mitmproxy/test_script.py test/mitmproxy/test_web_app.py test/mitmproxy/test_utils.py test/mitmproxy/test_stateobject.py test/mitmproxy/test_cmdline.py + PY3TESTS = test/netlib test/pathod/ test/mitmproxy/script test/mitmproxy/test_contentview.py test/mitmproxy/test_custom_contentview.py test/mitmproxy/test_app.py test/mitmproxy/test_controller.py test/mitmproxy/test_fuzzing.py test/mitmproxy/test_script.py test/mitmproxy/test_web_app.py test/mitmproxy/test_utils.py test/mitmproxy/test_stateobject.py test/mitmproxy/test_cmdline.py test/mitmproxy/test_contrib_tnetstring.py [testenv:py27] commands = |