diff options
| -rw-r--r-- | mitmproxy/contrib/tnetstring.py (renamed from mitmproxy/tnetstring.py) | 280 | ||||
| -rw-r--r-- | mitmproxy/flow/io.py | 2 | ||||
| -rw-r--r-- | test/mitmproxy/test_contrib_tnetstring.py | 141 | ||||
| -rw-r--r-- | test/mitmproxy/test_flow.py | 3 | ||||
| -rw-r--r-- | tox.ini | 2 | 
5 files changed, 273 insertions, 155 deletions
| diff --git a/mitmproxy/tnetstring.py b/mitmproxy/contrib/tnetstring.py index f40e8ad8..9bf20b09 100644 --- a/mitmproxy/tnetstring.py +++ b/mitmproxy/contrib/tnetstring.py @@ -79,9 +79,8 @@ __version__ = "%d.%d.%d%s" % (      __ver_major__, __ver_minor__, __ver_patch__, __ver_sub__) -def dumps(value, encoding=None): -    """dumps(object,encoding=None) -> string - +def dumps(value): +    """      This function dumps a python object as a tnetstring.      """      #  This uses a deque to collect output fragments in reverse order, @@ -91,22 +90,21 @@ def dumps(value, encoding=None):      #  consider the _gdumps() function instead; it's a standard top-down      #  generator that's simpler to understand but much less efficient.      q = deque() -    _rdumpq(q, 0, value, encoding) -    return "".join(q) - +    _rdumpq(q, 0, value) +    return b''.join(q) -def dump(value, file, encoding=None): -    """dump(object,file,encoding=None) -    This function dumps a python object as a tnetstring and writes it to -    the given file. +def dump(value, file_handle):      """ -    file.write(dumps(value, encoding)) -    file.flush() +    This function dumps a python object as a tnetstring and +    writes it to the given file. +    """ +    file_handle.write(dumps(value)) -def _rdumpq(q, size, value, encoding=None): -    """Dump value as a tnetstring, to a deque instance, last chunks first. +def _rdumpq(q, size, value): +    """ +    Dump value as a tnetstring, to a deque instance, last chunks first.      This function generates the tnetstring representation of the given value,      pushing chunks of the output onto the given deque instance.  It pushes @@ -122,79 +120,70 @@ def _rdumpq(q, size, value, encoding=None):      """      write = q.appendleft      if value is None: -        write("0:~") +        write(b'0:~')          return size + 3 -    if value is True: -        write("4:true!") +    elif value is True: +        write(b'4:true!')          return size + 7 -    if value is False: -        write("5:false!") +    elif value is False: +        write(b'5:false!')          return size + 8 -    if isinstance(value, six.integer_types): -        data = str(value) +    elif isinstance(value, six.integer_types): +        data = str(value).encode()          ldata = len(data) -        span = str(ldata) -        write("#") +        span = str(ldata).encode() +        write(b'#')          write(data) -        write(":") +        write(b':')          write(span)          return size + 2 + len(span) + ldata -    if isinstance(value, (float,)): +    elif isinstance(value, float):          #  Use repr() for float rather than str().          #  It round-trips more accurately.          #  Probably unnecessary in later python versions that          #  use David Gay's ftoa routines. -        data = repr(value) +        data = repr(value).encode()          ldata = len(data) -        span = str(ldata) -        write("^") +        span = str(ldata).encode() +        write(b'^')          write(data) -        write(":") +        write(b':')          write(span)          return size + 2 + len(span) + ldata -    if isinstance(value, str): +    elif isinstance(value, bytes):          lvalue = len(value) -        span = str(lvalue) -        write(",") +        span = str(lvalue).encode() +        write(b',')          write(value) -        write(":") +        write(b':')          write(span)          return size + 2 + len(span) + lvalue -    if isinstance(value, (list, tuple,)): -        write("]") +    elif isinstance(value, (list, tuple)): +        write(b']')          init_size = size = size + 1          for item in reversed(value): -            size = _rdumpq(q, size, item, encoding) -        span = str(size - init_size) -        write(":") +            size = _rdumpq(q, size, item) +        span = str(size - init_size).encode() +        write(b':')          write(span)          return size + 1 + len(span) -    if isinstance(value, dict): -        write("}") +    elif isinstance(value, dict): +        write(b'}')          init_size = size = size + 1 -        for (k, v) in six.iteritems(value): -            size = _rdumpq(q, size, v, encoding) -            size = _rdumpq(q, size, k, encoding) -        span = str(size - init_size) -        write(":") +        for (k, v) in value.items(): +            size = _rdumpq(q, size, v) +            size = _rdumpq(q, size, k) +        span = str(size - init_size).encode() +        write(b':')          write(span)          return size + 1 + len(span) -    if isinstance(value, unicode): -        if encoding is None: -            raise ValueError("must specify encoding to dump unicode strings") -        value = value.encode(encoding) -        lvalue = len(value) -        span = str(lvalue) -        write(",") -        write(value) -        write(":") -        write(span) -        return size + 2 + len(span) + lvalue -    raise ValueError("unserializable object") +    else: +        raise ValueError("unserializable object: {} ({})".format(value, type(value))) -def _gdumps(value, encoding): -    """Generate fragments of value dumped as a tnetstring. +def _gdumps(value): +    """ +    Generate fragments of value dumped as a tnetstring.      This is the naive dumping algorithm, implemented as a generator so that      it's easy to pass to "".join() without building a new list. @@ -203,72 +192,63 @@ def _gdumps(value, encoding):      measurably faster as it doesn't have to build intermediate strins.      """      if value is None: -        yield "0:~" +        yield b'0:~'      elif value is True: -        yield "4:true!" +        yield b'4:true!'      elif value is False: -        yield "5:false!" +        yield b'5:false!'      elif isinstance(value, six.integer_types): -        data = str(value) -        yield str(len(data)) -        yield ":" +        data = str(value).encode() +        yield str(len(data)).encode() +        yield b':'          yield data -        yield "#" -    elif isinstance(value, (float,)): -        data = repr(value) -        yield str(len(data)) -        yield ":" +        yield b'#' +    elif isinstance(value, float): +        data = repr(value).encode() +        yield str(len(data)).encode() +        yield b':'          yield data -        yield "^" -    elif isinstance(value, (str,)): -        yield str(len(value)) -        yield ":" +        yield b'^' +    elif isinstance(value, bytes): +        yield str(len(value)).encode() +        yield b':'          yield value -        yield "," -    elif isinstance(value, (list, tuple,)): +        yield b',' +    elif isinstance(value, (list, tuple)):          sub = []          for item in value:              sub.extend(_gdumps(item)) -        sub = "".join(sub) -        yield str(len(sub)) -        yield ":" +        sub = b''.join(sub) +        yield str(len(sub)).encode() +        yield b':'          yield sub -        yield "]" +        yield b']'      elif isinstance(value, (dict,)):          sub = [] -        for (k, v) in six.iteritems(value): +        for (k, v) in value.items():              sub.extend(_gdumps(k))              sub.extend(_gdumps(v)) -        sub = "".join(sub) -        yield str(len(sub)) -        yield ":" +        sub = b''.join(sub) +        yield str(len(sub)).encode() +        yield b':'          yield sub -        yield "}" -    elif isinstance(value, (unicode,)): -        if encoding is None: -            raise ValueError("must specify encoding to dump unicode strings") -        value = value.encode(encoding) -        yield str(len(value)) -        yield ":" -        yield value -        yield "," +        yield b'}'      else:          raise ValueError("unserializable object") -def loads(string, encoding=None): -    """loads(string,encoding=None) -> object - +def loads(string): +    """      This function parses a tnetstring into a python object.      """      #  No point duplicating effort here.  In the C-extension version,      #  loads() is measurably faster then pop() since it can avoid      #  the overhead of building a second string. -    return pop(string, encoding)[0] +    return pop(string)[0] -def load(file, encoding=None): -    """load(file,encoding=None) -> object +def load(file_handle): +    """load(file) -> object      This function reads a tnetstring from a file and parses it into a      python object.  The file must support the read() method, and this @@ -276,70 +256,68 @@ def load(file, encoding=None):      """      #  Read the length prefix one char at a time.      #  Note that the netstring spec explicitly forbids padding zeros. -    c = file.read(1) +    c = file_handle.read(1)      if not c.isdigit():          raise ValueError("not a tnetstring: missing or invalid length prefix") -    datalen = ord(c) - ord("0") -    c = file.read(1) +    datalen = ord(c) - ord('0') +    c = file_handle.read(1)      if datalen != 0:          while c.isdigit(): -            datalen = (10 * datalen) + (ord(c) - ord("0")) +            datalen = (10 * datalen) + (ord(c) - ord('0'))              if datalen > 999999999:                  errmsg = "not a tnetstring: absurdly large length prefix"                  raise ValueError(errmsg) -            c = file.read(1) -    if c != ":": +            c = file_handle.read(1) +    if c != b':':          raise ValueError("not a tnetstring: missing or invalid length prefix")      #  Now we can read and parse the payload.      #  This repeats the dispatch logic of pop() so we can avoid      #  re-constructing the outermost tnetstring. -    data = file.read(datalen) +    data = file_handle.read(datalen)      if len(data) != datalen:          raise ValueError("not a tnetstring: length prefix too big") -    type = file.read(1) -    if type == ",": -        if encoding is not None: -            return data.decode(encoding) +    tns_type = file_handle.read(1) +    if tns_type == b',':          return data -    if type == "#": +    if tns_type == b'#':          try:              return int(data)          except ValueError:              raise ValueError("not a tnetstring: invalid integer literal") -    if type == "^": +    if tns_type == b'^':          try:              return float(data)          except ValueError:              raise ValueError("not a tnetstring: invalid float literal") -    if type == "!": -        if data == "true": +    if tns_type == b'!': +        if data == b'true':              return True -        elif data == "false": +        elif data == b'false':              return False          else:              raise ValueError("not a tnetstring: invalid boolean literal") -    if type == "~": +    if tns_type == b'~':          if data:              raise ValueError("not a tnetstring: invalid null literal")          return None -    if type == "]": +    if tns_type == b']':          l = []          while data: -            (item, data) = pop(data, encoding) +            item, data = pop(data)              l.append(item)          return l -    if type == "}": +    if tns_type == b'}':          d = {}          while data: -            (key, data) = pop(data, encoding) -            (val, data) = pop(data, encoding) +            key, data = pop(data) +            val, data = pop(data)              d[key] = val          return d      raise ValueError("unknown type tag") -def pop(string, encoding=None): -    """pop(string,encoding=None) -> (object, remain) +def pop(string): +    """pop(string,encoding='utf_8') -> (object, remain)      This function parses a tnetstring into a python object.      It returns a tuple giving the parsed object and a string @@ -347,53 +325,51 @@ def pop(string, encoding=None):      """      #  Parse out data length, type and remaining string.      try: -        (dlen, rest) = string.split(":", 1) +        dlen, rest = string.split(b':', 1)          dlen = int(dlen)      except ValueError: -        raise ValueError("not a tnetstring: missing or invalid length prefix") +        raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(string))      try: -        (data, type, remain) = (rest[:dlen], rest[dlen], rest[dlen + 1:]) +        data, tns_type, remain = rest[:dlen], rest[dlen:dlen + 1], rest[dlen + 1:]      except IndexError:          #  This fires if len(rest) < dlen, meaning we don't need          #  to further validate that data is the right length. -        raise ValueError("not a tnetstring: invalid length prefix") +        raise ValueError("not a tnetstring: invalid length prefix: {}".format(dlen))      #  Parse the data based on the type tag. -    if type == ",": -        if encoding is not None: -            return (data.decode(encoding), remain) -        return (data, remain) -    if type == "#": +    if tns_type == b',': +        return data, remain +    if tns_type == b'#':          try: -            return (int(data), remain) +            return int(data), remain          except ValueError: -            raise ValueError("not a tnetstring: invalid integer literal") -    if type == "^": +            raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) +    if tns_type == b'^':          try: -            return (float(data), remain) +            return float(data), remain          except ValueError: -            raise ValueError("not a tnetstring: invalid float literal") -    if type == "!": -        if data == "true": -            return (True, remain) -        elif data == "false": -            return (False, remain) +            raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) +    if tns_type == b'!': +        if data == b'true': +            return True, remain +        elif data == b'false': +            return False, remain          else: -            raise ValueError("not a tnetstring: invalid boolean literal") -    if type == "~": +            raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) +    if tns_type == b'~':          if data:              raise ValueError("not a tnetstring: invalid null literal") -        return (None, remain) -    if type == "]": +        return None, remain +    if tns_type == b']':          l = []          while data: -            (item, data) = pop(data, encoding) +            item, data = pop(data)              l.append(item)          return (l, remain) -    if type == "}": +    if tns_type == b'}':          d = {}          while data: -            (key, data) = pop(data, encoding) -            (val, data) = pop(data, encoding) +            key, data = pop(data) +            val, data = pop(data)              d[key] = val -        return (d, remain) -    raise ValueError("unknown type tag") +        return d, remain +    raise ValueError("unknown type tag: {}".format(tns_type)) diff --git a/mitmproxy/flow/io.py b/mitmproxy/flow/io.py index cd3d9986..671ddf43 100644 --- a/mitmproxy/flow/io.py +++ b/mitmproxy/flow/io.py @@ -4,7 +4,7 @@ import os  from mitmproxy import exceptions  from mitmproxy import models -from mitmproxy import tnetstring +from mitmproxy.contrib import tnetstring  from mitmproxy.flow import io_compat diff --git a/test/mitmproxy/test_contrib_tnetstring.py b/test/mitmproxy/test_contrib_tnetstring.py new file mode 100644 index 00000000..17654ad9 --- /dev/null +++ b/test/mitmproxy/test_contrib_tnetstring.py @@ -0,0 +1,141 @@ +import unittest +import random +import math +import io +import struct + +from mitmproxy.contrib import tnetstring + +MAXINT = 2 ** (struct.Struct('i').size * 8 - 1) - 1 + +FORMAT_EXAMPLES = { +    b'0:}': {}, +    b'0:]': [], +    b'51:5:hello,39:11:12345678901#4:this,4:true!0:~4:\x00\x00\x00\x00,]}': +    {b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, +    b'5:12345#': 12345, +    b'12:this is cool,': b'this is cool', +    b'0:,': b'', +    b'0:~': None, +    b'4:true!': True, +    b'5:false!': False, +    b'10:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', +    b'24:5:12345#5:67890#5:xxxxx,]': [12345, 67890, b'xxxxx'], +    b'18:3:0.1^3:0.2^3:0.3^]': [0.1, 0.2, 0.3], +    b'243:238:233:228:223:218:213:208:203:198:193:188:183:178:173:168:163:158:153:148:143:138:133:128:123:118:113:108:103:99:95:91:87:83:79:75:71:67:63:59:55:51:47:43:39:35:31:27:23:19:15:11:hello-there,]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]': [[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[b'hello-there']]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]  # noqa +} + + +def get_random_object(random=random, depth=0): +    """Generate a random serializable object.""" +    #  The probability of generating a scalar value increases as the depth increase. +    #  This ensures that we bottom out eventually. +    if random.randint(depth, 10) <= 4: +        what = random.randint(0, 1) +        if what == 0: +            n = random.randint(0, 10) +            l = [] +            for _ in range(n): +                l.append(get_random_object(random, depth + 1)) +            return l +        if what == 1: +            n = random.randint(0, 10) +            d = {} +            for _ in range(n): +                n = random.randint(0, 100) +                k = bytes([random.randint(32, 126) for _ in range(n)]) +                d[k] = get_random_object(random, depth + 1) +            return d +    else: +        what = random.randint(0, 4) +        if what == 0: +            return None +        if what == 1: +            return True +        if what == 2: +            return False +        if what == 3: +            if random.randint(0, 1) == 0: +                return random.randint(0, MAXINT) +            else: +                return -1 * random.randint(0, MAXINT) +        n = random.randint(0, 100) +        return bytes([random.randint(32, 126) for _ in range(n)]) + + +class Test_Format(unittest.TestCase): + +    def test_roundtrip_format_examples(self): +        for data, expect in FORMAT_EXAMPLES.items(): +            self.assertEqual(expect, tnetstring.loads(data)) +            self.assertEqual( +                expect, tnetstring.loads(tnetstring.dumps(expect))) +            self.assertEqual((expect, b''), tnetstring.pop(data)) + +    def test_roundtrip_format_random(self): +        for _ in range(500): +            v = get_random_object() +            self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) +            self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v))) + +    def test_unicode_handling(self): +        with self.assertRaises(ValueError): +            tnetstring.dumps(u"hello") +        self.assertEqual(tnetstring.dumps(u"hello".encode()), b"5:hello,") +        self.assertEqual(type(tnetstring.loads(b"5:hello,")), bytes) + +    def test_roundtrip_format_unicode(self): +        for _ in range(500): +            v = get_random_object() +            self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) +            self.assertEqual((v, b''), tnetstring.pop(tnetstring.dumps(v))) + +    def test_roundtrip_big_integer(self): +        i1 = math.factorial(30000) +        s = tnetstring.dumps(i1) +        i2 = tnetstring.loads(s) +        self.assertEqual(i1, i2) + + +class Test_FileLoading(unittest.TestCase): + +    def test_roundtrip_file_examples(self): +        for data, expect in FORMAT_EXAMPLES.items(): +            s = io.BytesIO() +            s.write(data) +            s.write(b'OK') +            s.seek(0) +            self.assertEqual(expect, tnetstring.load(s)) +            self.assertEqual(b'OK', s.read()) +            s = io.BytesIO() +            tnetstring.dump(expect, s) +            s.write(b'OK') +            s.seek(0) +            self.assertEqual(expect, tnetstring.load(s)) +            self.assertEqual(b'OK', s.read()) + +    def test_roundtrip_file_random(self): +        for _ in range(500): +            v = get_random_object() +            s = io.BytesIO() +            tnetstring.dump(v, s) +            s.write(b'OK') +            s.seek(0) +            self.assertEqual(v, tnetstring.load(s)) +            self.assertEqual(b'OK', s.read()) + +    def test_error_on_absurd_lengths(self): +        s = io.BytesIO() +        s.write(b'1000000000:pwned!,') +        s.seek(0) +        with self.assertRaises(ValueError): +            tnetstring.load(s) +        self.assertEqual(s.read(1), b':') + + +def suite(): +    loader = unittest.TestLoader() +    suite = unittest.TestSuite() +    suite.addTest(loader.loadTestsFromTestCase(Test_Format)) +    suite.addTest(loader.loadTestsFromTestCase(Test_FileLoading)) +    return suite diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index af8256c4..9eaab9aa 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -5,7 +5,8 @@ import mock  import netlib.utils  from netlib.http import Headers -from mitmproxy import filt, controller, tnetstring, flow +from mitmproxy import filt, controller, flow +from mitmproxy.contrib import tnetstring  from mitmproxy.exceptions import FlowReadException, ScriptException  from mitmproxy.models import Error  from mitmproxy.models import Flow @@ -7,7 +7,7 @@ deps =    codecov>=2.0.5  passenv = CI TRAVIS_BUILD_ID TRAVIS TRAVIS_BRANCH TRAVIS_JOB_NUMBER TRAVIS_PULL_REQUEST TRAVIS_JOB_ID TRAVIS_REPO_SLUG TRAVIS_COMMIT  setenv = -    PY3TESTS =  test/netlib test/pathod/ test/mitmproxy/script test/mitmproxy/test_contentview.py test/mitmproxy/test_custom_contentview.py test/mitmproxy/test_app.py test/mitmproxy/test_controller.py test/mitmproxy/test_fuzzing.py test/mitmproxy/test_script.py test/mitmproxy/test_web_app.py test/mitmproxy/test_utils.py test/mitmproxy/test_stateobject.py test/mitmproxy/test_cmdline.py +    PY3TESTS =  test/netlib test/pathod/ test/mitmproxy/script test/mitmproxy/test_contentview.py test/mitmproxy/test_custom_contentview.py test/mitmproxy/test_app.py test/mitmproxy/test_controller.py test/mitmproxy/test_fuzzing.py test/mitmproxy/test_script.py test/mitmproxy/test_web_app.py test/mitmproxy/test_utils.py test/mitmproxy/test_stateobject.py test/mitmproxy/test_cmdline.py test/mitmproxy/test_contrib_tnetstring.py  [testenv:py27]  commands = | 
