about summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r-- mitmproxy/contrib/tnetstring.py (renamed from mitmproxy/tnetstring.py) 280
-rw-r--r-- mitmproxy/flow/io.py 2
-rw-r--r-- test/mitmproxy/test_contrib_tnetstring.py 141
-rw-r--r-- test/mitmproxy/test_flow.py 3
-rw-r--r-- tox.ini 2
5 files changed, 273 insertions, 155 deletions
diff --git a/mitmproxy/tnetstring.py b/mitmproxy/contrib/tnetstring.py
index f40e8ad8..9bf20b09 100644
--- a/mitmproxy/tnetstring.py
+++ b/mitmproxy/contrib/tnetstring.py
@@ -79,9 +79,8 @@ __version__ = "%d.%d.%d%s" % (
__ver_major__, __ver_minor__, __ver_patch__, __ver_sub__)
-def dumps(value, encoding=None):
- """dumps(object,encoding=None) -> string
-
+def dumps(value):
+ """
This function dumps a python object as a tnetstring.
"""
# This uses a deque to collect output fragments in reverse order,
@@ -91,22 +90,21 @@ def dumps(value, encoding=None):
# consider the _gdumps() function instead; it's a standard top-down
# generator that's simpler to understand but much less efficient.
q = deque()
- _rdumpq(q, 0, value, encoding)
- return "".join(q)
-
+ _rdumpq(q, 0, value)
+ return b''.join(q)
-def dump(value, file, encoding=None):
- """dump(object,file,encoding=None)
- This function dumps a python object as a tnetstring and writes it to
- the given file.
+def dump(value, file_handle):
"""
- file.write(dumps(value, encoding))
- file.flush()
+ This function dumps a python object as a tnetstring and
+ writes it to the given file.
+ """
+ file_handle.write(dumps(value))
-def _rdumpq(q, size, value, encoding=None):
- """Dump value as a tnetstring, to a deque instance, last chunks first.
+def _rdumpq(q, size, value):
+ """
+ Dump value as a tnetstring, to a deque instance, last chunks first.
This function generates the tnetstring representation of the given value,
pushing chunks of the output onto the given deque instance. It pushes
@@ -122,79 +120,70 @@ def _rdumpq(q, size, value, encoding=None):
"""
write = q.appendleft
if value is None:
- write("0:~")
+ write(b'0:~')
return size + 3
- if value is True:
- write("4:true!")
+ elif value is True:
+ write(b'4:true!')
return size + 7
- if value is False:
- write("5:false!")
+ elif value is False:
+ write(b'5:false!')
return size + 8
- if isinstance(value, six.integer_types):
- data = str(value)
+ elif isinstance(value, six.integer_types):
+ data = str(value).encode()
ldata = len(data)
- span = str(ldata)
- write("#")
+ span = str(ldata).encode()
+ write(b'#')
write(data)
- write(":")
+ write(b':')
write(span)
return size + 2 + len(span) + ldata
- if isinstance(value, (float,)):
+ elif isinstance(value, float):
# Use repr() for float rather than str().
# It round-trips more accurately.
# Probably unnecessary in later python versions that
# use David Gay's ftoa routines.
- data = repr(value)
+ data = repr(value).encode()
ldata = len(data)
- span = str(ldata)
- write("^")
+ span = str(ldata).encode()
+ write(b'^')
write(data)
- write(":")
+ write(b':')
write(span)
return size + 2 + len(span) + ldata
- if isinstance(value, str):
+ elif isinstance(value, bytes):
lvalue = len(value)
- span = str(lvalue)
- write(",")
+ span = str(lvalue).encode()
+ write(b',')
write(value)
- write(":")
+ write(b':')
write(span)
return size + 2 + len(span) + lvalue
- if isinstance(value, (list, tuple,)):
- write("]")
+ elif isinstance(value, (list, tuple)):
+ write(b']')
init_size = size = size + 1
for item in reversed(value):
- size = _rdumpq(q, size, item, encoding)
- span = str(size - init_size)
- write(":")
+ size = _rdumpq(q, size, item)
+ span = str(size - init_size).encode()
+ write(b':')
write(span)
return size + 1 + len(span)
- if isinstance(value, dict):
- write("}")
+ elif isinstance(value, dict):
+ write(b'}')
init_size = size = size + 1
- for (k, v) in six.iteritems(value):
- size = _rdumpq(q, size, v, encoding)
- size = _rdumpq(q, size, k, encoding)
- span = str(size - init_size)
- write(":")
+ for (k, v) in value.items():
+ size = _rdumpq(q, size, v)
+ size = _rdumpq(q, size, k)
+ span = str(size - init_size).encode()
+ write(b':')
write(span)
return size + 1 + len(span)
- if isinstance(value, unicode):
- if encoding is None:
- raise ValueError("must specify encoding to dump unicode strings")
- value = value.encode(encoding)
- lvalue = len(value)
- span = str(lvalue)
- write(",")
- write(value)
- write(":")
- write(span)
- return size + 2 + len(span) + lvalue
- raise ValueError("unserializable object")
+ else:
+ raise ValueError("unserializable object: {} ({})".format(value, type(value)))
-def _gdumps(value, encoding):
- """Generate fragments of value dumped as a tnetstring.
+def _gdumps(value):
+ """
+ Generate fragments of value dumped as a tnetstring.
This is the naive dumping algorithm, implemented as a generator so that
it's easy to pass to "".join() without building a new list.
@@ -203,72 +192,63 @@ def _gdumps(value, encoding):
measurably faster as it doesn't have to build intermediate strings.
"""
if value is None:
- yield "0:~"
+ yield b'0:~'
elif value is True:
- yield "4:true!"
+ yield b'4:true!'
elif value is False:
- yield "5:false!"
+ yield b'5:false!'
elif isinstance(value, six.integer_types):
- data = str(value)
- yield str(len(data))
- yield ":"
+ data = str(value).encode()
+ yield str(len(data)).encode()
+ yield b':'
yield data
- yield "#"
- elif isinstance(value, (float,)):
- data = repr(value)
- yield str(len(data))
- yield ":"
+ yield b'#'
+ elif isinstance(value, float):
+ data = repr(value).encode()
+ yield str(len(data)).encode()
+ yield b':'
yield data
- yield "^"
- elif isinstance(value, (str,)):
- yield str(len(value))
- yield ":"
+ yield b'^'
+ elif isinstance(value, bytes):
+ yield str(len(value)).encode()
+ yield b':'
yield value
- yield ","
- elif isinstance(value, (list, tuple,)):
+ yield b','
+ elif isinstance(value, (list, tuple)):
sub = []
for item in value:
sub.extend(_gdumps(item))
- sub = "".join(sub)
- yield str(len(sub))
- yield ":"
+ sub = b''.join(sub)
+ yield str(len(sub)).encode()
+ yield b':'
yield sub
- yield "]"
+ yield b']'
elif isinstance(value, (dict,)):
sub = []
- for (k, v) in six.iteritems(value):
+ for (k, v) in value.items():
sub.extend(_gdumps(k))
sub.extend(_gdumps(v))
- sub = "".join(sub)
- yield str(len(sub))
- yield ":"
+ sub = b''.join(sub)
+ yield str(len(sub)).encode()
+ yield b':'
yield sub
- yield "}"
- elif isinstance(value, (unicode,)):
- if encoding is None:
- raise ValueError("must specify encoding to dump unicode strings")
- value = value.encode(encoding)
- yield str(len(value))
- yield ":"
- yield value
- yield ","
+ yield b'}'
else:
raise ValueError("unserializable object")
-def loads(string, encoding=None):
- """loads(string,encoding=None) -> object
-
+def loads(string):
+ """
This function parses a tnetstring into a python object.
"""
# No point duplicating effort here. In the C-extension version,
# loads() is measurably faster than pop() since it can avoid
# the overhead of building a second string.
- return pop(string, encoding)[0]
+ return pop(string)[0]
-def load(file, encoding=None):
- """load(file,encoding=None) -> object
+def load(file_handle):
+ """load(file_handle) -> object
This function reads a tnetstring from a file and parses it into a
python object. The file must support the read() method, and this
@@ -276,70 +256,68 @@ def load(file, encoding=None):
"""
# Read the length prefix one char at a time.
# Note that the netstring spec explicitly forbids padding zeros.
- c = file.read(1)
+ c = file_handle.read(1)
if not c.isdigit():
raise ValueError("not a tnetstring: missing or invalid length prefix")
- datalen = ord(c) - ord("0")
- c = file.read(1)
+ datalen = ord(c) - ord('0')
+ c = file_handle.read(1)
if datalen != 0:
while c.isdigit():
- datalen = (10 * datalen) + (ord(c) - ord("0"))
+ datalen = (10 * datalen) + (ord(c) - ord('0'))
if datalen > 999999999:
errmsg = "not a tnetstring: absurdly large length prefix"
raise ValueError(errmsg)
- c = file.read(1)
- if c != ":":
+ c = file_handle.read(1)
+ if c != b':':
raise ValueError("not a tnetstring: missing or invalid length prefix")
# Now we can read and parse the payload.
# This repeats the dispatch logic of pop() so we can avoid
# re-constructing the outermost tnetstring.
- data = file.read(datalen)
+ data = file_handle.read(datalen)
if len(data) != datalen:
raise ValueError("not a tnetstring: length prefix too big")
- type = file.read(1)
- if type == ",":
- if encoding is not None:
- return data.decode(encoding)
+ tns_type = file_handle.read(1)
+ if tns_type == b',':
return data
- if type == "#":
+ if tns_type == b'#':
try:
return int(data)
except ValueError:
raise ValueError("not a tnetstring: invalid integer literal")
- if type == "^":
+ if tns_type == b'^':
try:
return float(data)
except ValueError:
raise ValueError("not a tnetstring: invalid float literal")
- if type == "!":
- if data == "true":
+ if tns_type == b'!':
+ if data == b'true':
return True
- elif data == "false":
+ elif data == b'false':
return False
else:
raise ValueError("not a tnetstring: invalid boolean literal")
- if type == "~":
+ if tns_type == b'~':
if data:
raise ValueError("not a tnetstring: invalid null literal")
return None
- if type == "]":
+ if tns_type == b']':
l = []
while data:
- (item, data) = pop(data, encoding)
+ item, data = pop(data)
l.append(item)
return l
- if type == "}":
+ if tns_type == b'}':
d = {}
while data:
- (key, data) = pop(data, encoding)
- (val, data) = pop(data, encoding)
+ key, data = pop(data)
+ val, data = pop(data)
d[key] = val
return d
raise ValueError("unknown type tag")
-def pop(string, encoding=None):
- """pop(string,encoding=None) -> (object, remain)
+def pop(string):
+ """pop(string) -> (object, remain)
This function parses a tnetstring into a python object.
It returns a tuple giving the parsed object and a string
@@ -347,53 +325,51 @@ def pop(string, encoding=None):
"""
# Parse out data length, type and remaining string.
try:
- (dlen, rest) = string.split(":", 1)
+ dlen, rest = string.split(b':', 1)
dlen = int(dlen)
except ValueError:
- raise ValueError("not a tnetstring: missing or invalid length prefix")
+ raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(string))
try:
- (data, type, remain) = (rest[:dlen], rest[dlen], rest[dlen + 1:])
+ data, tns_type, remain = rest[:dlen], rest[dlen:dlen + 1], rest[dlen + 1:]
except IndexError:
# This fires if len(rest) < dlen, meaning we don't need
# to further validate that data is the right length.
- raise ValueError("not a tnetstring: invalid length prefix")
+ raise ValueError("not a tnetstring: invalid length prefix: {}".format(dlen))
# Parse the data based on the type tag.
- if type == ",":
- if encoding is not None:
- return (data.decode(encoding), remain)
- return (data, remain)
- if type == "#":
+ if tns_type == b',':
+ return data, remain
+ if tns_type == b'#':
try:
- return (int(data), remain)
+ return int(data), remain
except ValueError:
- raise ValueError("not a tnetstring: invalid integer literal")
- if type == "^":
+ raise ValueError("not a tnetstring: invalid integer literal: {}".format(data))
+ if tns_type == b'^':
try:
- return (float(data), remain)
+ return float(data), remain
except ValueError:
- raise ValueError("not a tnetstring: invalid float literal")
- if type == "!":
- if data == "true":
- return (True, remain)
- elif data == "false":
- return (False, remain)
+ raise ValueError("not a tnetstring: invalid float literal: {}".format(data))
+ if tns_type == b'!':
+ if data == b'true':
+ return True, remain
+ elif data == b'false':
+ return False, remain
else:
- raise ValueError("not a tnetstring: invalid boolean literal")
- if type == "~":
+ raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data))
+ if tns_type == b'~':
if data:
raise ValueError("not a tnetstring: invalid null literal")
- return (None, remain)
- if type == "]":
+ return None, remain
+ if tns_type == b']':
l = []
while data:
- (item, data) = pop(data, encoding)
+ item, data = pop(data)
l.append(item)
return (l, remain)
- if type == "}":
+ if tns_type == b'}':
d = {}
while data:
- (key, data) = pop(data, encoding)
- (val, data) = pop(data, encoding)
+ key, data = pop(data)
+ val, data = pop(data)
d[key] = val
- return (d, remain)
- raise ValueError("unknown type tag")
+ return d, remain
+ raise ValueError("unknown type tag: {}".format(tns_type))
diff --git a/mitmproxy/flow/io.py b/mitmproxy/flow/io.py
index cd3d9986..671ddf43 100644
--- a/mitmproxy/flow/io.py
+++ b/mitmproxy/flow/io.py
@@ -4,7 +4,7 @@ import os
from mitmproxy import exceptions
from mitmproxy import models
-from mitmproxy import tnetstring
+from mitmproxy.contrib import tnetstring
from mitmproxy.flow import io_compat
diff --git a/test/mitmproxy/test_contrib_tnetstring.py b/test/mitmproxy/test_contrib_tnetstring.py
new file mode 100644
index 00000000..17654ad9
--- /dev/null
+++ b/test/mitmproxy/test_contrib_tnetstring.py
@@ -0,0 +1,141 @@
+import unittest
+import random
+import math
+import io
+import struct
+
+from mitmproxy.contrib import tnetstring
+
+MAXINT = 2 ** (struct.Struct('i').size * 8 - 1) - 1
+
+FORMAT_EXAMPLES = {
+ b'0:}': {},
+ b'0:]': [],
+ b'51:5:hello,39:11:12345678901#4:this,4:true!0:~4:\x00\x00\x00\x00,]}':
+ {b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']},
+ b'5:12345#': 12345,
+ b'12:this is cool,': b'this is cool',
+ b'0:,': b'',
+ b'0:~': None,
+ b'4:true!': True,
+ b'5:false!': False,
+ b'10:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00',
+ b'24:5:12345#5:67890#5:xxxxx,]': [12345, 67890, b'xxxxx'],
+ b'18:3:0.1^3:0.2^3:0.3^]': [0.1, 0.2, 0.3],
+ b'243:238:233:228:223:218:213:208:203:198:193:188:183:178:173:168:163:158:153:148:143:138:133:128:123:118:113:108:103:99:95:91:87:83:79:75:71:67:63:59:55:51:47:43:39:35:31:27:23:19:15:11:hello-there,]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]': [[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[b'hello-there']]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]] # noqa
+}
+
+
+def get_random_object(random=random, depth=0):
+ """Generate a random serializable object."""
+ # The probability of generating a scalar value increases as the depth increases.
+ # This ensures that we bottom out eventually.
+ if random.randint(depth, 10) <= 4:
+ what = random.randint(0, 1)
+ if what == 0:
+ n = random.randint(0, 10)
+ l = []
+ for _ in range(n):
+ l.append(get_random_object(random, depth + 1))
+ return l
+ if what == 1:
+ n = random.randint(0, 10)
+ d = {}
+ for _ in range(n):
+ n = random.randint(0, 100)
+ k = bytes([random.randint(32, 126) for _ in range(n)])
+ d[k] = get_random_object(random, depth + 1)
+ return d
+ else:
+ what = random.randint(0, 4)
+ if what == 0:
+ return None
+ if what == 1:
+ return True
+ if what == 2:
+ return False
+ if what == 3:
+ if random.randint(0, 1) == 0:
+ return random.randint(0, MAXINT)
+ else:
+ return -1 * random.randint(0, MAXINT)
+ n = random.randint(0, 100)
+ return bytes([random.randint(32, 126) for _ in range(n)])
+
+
+class Test_Format(unittest.TestCase):
+
+ def test_roundtrip_format_examples(self):
+ for data, expect in FORMAT_EXAMPLES.items():
+ self.assertEqual(expect, tnetstring.loads(data))
+ self.assertEqual(
+ expect, tnetstring.loads(tnetstring.dumps(expect)))
+ self.assertEqual((expect, b''), tnetstring.pop(data))
+
+ def test_roundtrip_format_random(self):
+ for _ in range(500):
+ v = get_random_object()
+ self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v)))
+ self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v)))
+
+ def test_unicode_handling(self):
+ with self.assertRaises(ValueError):
+ tnetstring.dumps(u"hello")
+ self.assertEqual(tnetstring.dumps(u"hello".encode()), b"5:hello,")
+ self.assertEqual(type(tnetstring.loads(b"5:hello,")), bytes)
+
+ def test_roundtrip_format_unicode(self):
+ for _ in range(500):
+ v = get_random_object()
+ self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v)))
+ self.assertEqual((v, b''), tnetstring.pop(tnetstring.dumps(v)))
+
+ def test_roundtrip_big_integer(self):
+ i1 = math.factorial(30000)
+ s = tnetstring.dumps(i1)
+ i2 = tnetstring.loads(s)
+ self.assertEqual(i1, i2)
+
+
+class Test_FileLoading(unittest.TestCase):
+
+ def test_roundtrip_file_examples(self):
+ for data, expect in FORMAT_EXAMPLES.items():
+ s = io.BytesIO()
+ s.write(data)
+ s.write(b'OK')
+ s.seek(0)
+ self.assertEqual(expect, tnetstring.load(s))
+ self.assertEqual(b'OK', s.read())
+ s = io.BytesIO()
+ tnetstring.dump(expect, s)
+ s.write(b'OK')
+ s.seek(0)
+ self.assertEqual(expect, tnetstring.load(s))
+ self.assertEqual(b'OK', s.read())
+
+ def test_roundtrip_file_random(self):
+ for _ in range(500):
+ v = get_random_object()
+ s = io.BytesIO()
+ tnetstring.dump(v, s)
+ s.write(b'OK')
+ s.seek(0)
+ self.assertEqual(v, tnetstring.load(s))
+ self.assertEqual(b'OK', s.read())
+
+ def test_error_on_absurd_lengths(self):
+ s = io.BytesIO()
+ s.write(b'1000000000:pwned!,')
+ s.seek(0)
+ with self.assertRaises(ValueError):
+ tnetstring.load(s)
+ self.assertEqual(s.read(1), b':')
+
+
+def suite():
+ loader = unittest.TestLoader()
+ suite = unittest.TestSuite()
+ suite.addTest(loader.loadTestsFromTestCase(Test_Format))
+ suite.addTest(loader.loadTestsFromTestCase(Test_FileLoading))
+ return suite
diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py
index af8256c4..9eaab9aa 100644
--- a/test/mitmproxy/test_flow.py
+++ b/test/mitmproxy/test_flow.py
@@ -5,7 +5,8 @@ import mock
import netlib.utils
from netlib.http import Headers
-from mitmproxy import filt, controller, tnetstring, flow
+from mitmproxy import filt, controller, flow
+from mitmproxy.contrib import tnetstring
from mitmproxy.exceptions import FlowReadException, ScriptException
from mitmproxy.models import Error
from mitmproxy.models import Flow
diff --git a/tox.ini b/tox.ini
index 4837d5b5..3abd6e4c 100644
--- a/tox.ini
+++ b/tox.ini
@@ -7,7 +7,7 @@ deps =
codecov>=2.0.5
passenv = CI TRAVIS_BUILD_ID TRAVIS TRAVIS_BRANCH TRAVIS_JOB_NUMBER TRAVIS_PULL_REQUEST TRAVIS_JOB_ID TRAVIS_REPO_SLUG TRAVIS_COMMIT
setenv =
- PY3TESTS = test/netlib test/pathod/ test/mitmproxy/script test/mitmproxy/test_contentview.py test/mitmproxy/test_custom_contentview.py test/mitmproxy/test_app.py test/mitmproxy/test_controller.py test/mitmproxy/test_fuzzing.py test/mitmproxy/test_script.py test/mitmproxy/test_web_app.py test/mitmproxy/test_utils.py test/mitmproxy/test_stateobject.py test/mitmproxy/test_cmdline.py
+ PY3TESTS = test/netlib test/pathod/ test/mitmproxy/script test/mitmproxy/test_contentview.py test/mitmproxy/test_custom_contentview.py test/mitmproxy/test_app.py test/mitmproxy/test_controller.py test/mitmproxy/test_fuzzing.py test/mitmproxy/test_script.py test/mitmproxy/test_web_app.py test/mitmproxy/test_utils.py test/mitmproxy/test_stateobject.py test/mitmproxy/test_cmdline.py test/mitmproxy/test_contrib_tnetstring.py
[testenv:py27]
commands =