From e6e839d56d86e7f7126b3b662a07f12625f3d691 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 5 Jul 2016 17:27:20 -0700 Subject: add python3 tnetstring implementation --- mitmproxy/contrib/py2/__init__.py | 0 mitmproxy/contrib/py2/tnetstring.py | 375 +++++++++++++++++++++++++++++ mitmproxy/contrib/py3/__init__.py | 0 mitmproxy/contrib/py3/tnetstring.py | 233 ++++++++++++++++++ mitmproxy/contrib/py3/tnetstring_tests.py | 133 +++++++++++ mitmproxy/contrib/tnetstring.py | 377 +----------------------------- 6 files changed, 746 insertions(+), 372 deletions(-) create mode 100644 mitmproxy/contrib/py2/__init__.py create mode 100644 mitmproxy/contrib/py2/tnetstring.py create mode 100644 mitmproxy/contrib/py3/__init__.py create mode 100644 mitmproxy/contrib/py3/tnetstring.py create mode 100644 mitmproxy/contrib/py3/tnetstring_tests.py diff --git a/mitmproxy/contrib/py2/__init__.py b/mitmproxy/contrib/py2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mitmproxy/contrib/py2/tnetstring.py b/mitmproxy/contrib/py2/tnetstring.py new file mode 100644 index 00000000..9bf20b09 --- /dev/null +++ b/mitmproxy/contrib/py2/tnetstring.py @@ -0,0 +1,375 @@ +# imported from the tnetstring project: https://github.com/rfk/tnetstring +# +# Copyright (c) 2011 Ryan Kelly +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +""" +tnetstring: data serialization using typed netstrings +====================================================== + + +This is a data serialization library. It's a lot like JSON but it uses a +new syntax called "typed netstrings" that Zed has proposed for use in the +Mongrel2 webserver. It's designed to be simpler and easier to implement +than JSON, with a happy consequence of also being faster in many cases. + +An ordinary netstring is a blob of data prefixed with its length and postfixed +with a sanity-checking comma. The string "hello world" encodes like this:: + + 11:hello world, + +Typed netstrings add other datatypes by replacing the comma with a type tag. +Here's the integer 12345 encoded as a tnetstring:: + + 5:12345# + +And here's the list [12345,True,0] which mixes integers and bools:: + + 19:5:12345#4:true!1:0#] + +Simple enough? 
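To see where the outer 19 comes from: the three elements encode to 8, 7 and 4 bytes, and the list's own length prefix counts exactly that payload:

    5:12345#    ->  8 bytes (the integer 12345)
    4:true!     ->  7 bytes (True)
    1:0#        ->  4 bytes (the integer 0)
                   --------
                    19 bytes of payload, hence the leading "19:" and the trailing "]" tag.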
This module gives you the following functions: + + :dump: dump an object as a tnetstring to a file + :dumps: dump an object as a tnetstring to a string + :load: load a tnetstring-encoded object from a file + :loads: load a tnetstring-encoded object from a string + :pop: pop a tnetstring-encoded object from the front of a string + +Note that since parsing a tnetstring requires reading all the data into memory +at once, there's no efficiency gain from using the file-based versions of these +functions. They're only here so you can use load() to read precisely one +item from a file or socket without consuming any extra data. + +By default tnetstrings work only with byte strings, not unicode. If you want +unicode strings then pass an optional encoding to the various functions, +like so:: + + >>> print(repr(tnetstring.loads("2:\\xce\\xb1,"))) + '\\xce\\xb1' + >>> + >>> print(repr(tnetstring.loads("2:\\xce\\xb1,","utf8"))) + u'\u03b1' + +""" +from collections import deque + +import six + +__ver_major__ = 0 +__ver_minor__ = 2 +__ver_patch__ = 0 +__ver_sub__ = "" +__version__ = "%d.%d.%d%s" % ( + __ver_major__, __ver_minor__, __ver_patch__, __ver_sub__) + + +def dumps(value): + """ + This function dumps a python object as a tnetstring. + """ + # This uses a deque to collect output fragments in reverse order, + # then joins them together at the end. It's measurably faster + # than creating all the intermediate strings. + # If you're reading this to get a handle on the tnetstring format, + # consider the _gdumps() function instead; it's a standard top-down + # generator that's simpler to understand but much less efficient. + q = deque() + _rdumpq(q, 0, value) + return b''.join(q) + + +def dump(value, file_handle): + """ + This function dumps a python object as a tnetstring and + writes it to the given file. + """ + file_handle.write(dumps(value)) + + +def _rdumpq(q, size, value): + """ + Dump value as a tnetstring, to a deque instance, last chunks first. + + This function generates the tnetstring representation of the given value, + pushing chunks of the output onto the given deque instance. It pushes + the last chunk first, then recursively generates more chunks. + + When passed in the current size of the string in the queue, it will return + the new size of the string in the queue. + + Operating last-chunk-first makes it easy to calculate the size written + for recursive structures without having to build their representation as + a string. This is measurably faster than generating the intermediate + strings, especially on deeply nested structures. + """ + write = q.appendleft + if value is None: + write(b'0:~') + return size + 3 + elif value is True: + write(b'4:true!') + return size + 7 + elif value is False: + write(b'5:false!') + return size + 8 + elif isinstance(value, six.integer_types): + data = str(value).encode() + ldata = len(data) + span = str(ldata).encode() + write(b'#') + write(data) + write(b':') + write(span) + return size + 2 + len(span) + ldata + elif isinstance(value, float): + # Use repr() for float rather than str(). + # It round-trips more accurately. + # Probably unnecessary in later python versions that + # use David Gay's ftoa routines. 
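To make the last-chunk-first bookkeeping concrete, here is a small worked trace, assuming _rdumpq and deque are in scope exactly as in this module; the values are illustrative only:

    q = deque()
    total = _rdumpq(q, 0, [12345])
    # appendleft builds the output right to left: b']' is pushed first,
    # then b'#', b'12345', b':', b'5' for the element, and finally the
    # outer b':' and b'8' once the 8-byte payload size is known.
    assert b''.join(q) == b'8:5:12345#]'
    assert total == 11   # length of the complete encoding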
+ data = repr(value).encode() + ldata = len(data) + span = str(ldata).encode() + write(b'^') + write(data) + write(b':') + write(span) + return size + 2 + len(span) + ldata + elif isinstance(value, bytes): + lvalue = len(value) + span = str(lvalue).encode() + write(b',') + write(value) + write(b':') + write(span) + return size + 2 + len(span) + lvalue + elif isinstance(value, (list, tuple)): + write(b']') + init_size = size = size + 1 + for item in reversed(value): + size = _rdumpq(q, size, item) + span = str(size - init_size).encode() + write(b':') + write(span) + return size + 1 + len(span) + elif isinstance(value, dict): + write(b'}') + init_size = size = size + 1 + for (k, v) in value.items(): + size = _rdumpq(q, size, v) + size = _rdumpq(q, size, k) + span = str(size - init_size).encode() + write(b':') + write(span) + return size + 1 + len(span) + else: + raise ValueError("unserializable object: {} ({})".format(value, type(value))) + + +def _gdumps(value): + """ + Generate fragments of value dumped as a tnetstring. + + This is the naive dumping algorithm, implemented as a generator so that + it's easy to pass to "".join() without building a new list. + + This is mainly here for comparison purposes; the _rdumpq version is + measurably faster as it doesn't have to build intermediate strins. + """ + if value is None: + yield b'0:~' + elif value is True: + yield b'4:true!' + elif value is False: + yield b'5:false!' + elif isinstance(value, six.integer_types): + data = str(value).encode() + yield str(len(data)).encode() + yield b':' + yield data + yield b'#' + elif isinstance(value, float): + data = repr(value).encode() + yield str(len(data)).encode() + yield b':' + yield data + yield b'^' + elif isinstance(value, bytes): + yield str(len(value)).encode() + yield b':' + yield value + yield b',' + elif isinstance(value, (list, tuple)): + sub = [] + for item in value: + sub.extend(_gdumps(item)) + sub = b''.join(sub) + yield str(len(sub)).encode() + yield b':' + yield sub + yield b']' + elif isinstance(value, (dict,)): + sub = [] + for (k, v) in value.items(): + sub.extend(_gdumps(k)) + sub.extend(_gdumps(v)) + sub = b''.join(sub) + yield str(len(sub)).encode() + yield b':' + yield sub + yield b'}' + else: + raise ValueError("unserializable object") + + +def loads(string): + """ + This function parses a tnetstring into a python object. + """ + # No point duplicating effort here. In the C-extension version, + # loads() is measurably faster then pop() since it can avoid + # the overhead of building a second string. + return pop(string)[0] + + +def load(file_handle): + """load(file) -> object + + This function reads a tnetstring from a file and parses it into a + python object. The file must support the read() method, and this + function promises not to read more data than necessary. + """ + # Read the length prefix one char at a time. + # Note that the netstring spec explicitly forbids padding zeros. + c = file_handle.read(1) + if not c.isdigit(): + raise ValueError("not a tnetstring: missing or invalid length prefix") + datalen = ord(c) - ord('0') + c = file_handle.read(1) + if datalen != 0: + while c.isdigit(): + datalen = (10 * datalen) + (ord(c) - ord('0')) + if datalen > 999999999: + errmsg = "not a tnetstring: absurdly large length prefix" + raise ValueError(errmsg) + c = file_handle.read(1) + if c != b':': + raise ValueError("not a tnetstring: missing or invalid length prefix") + # Now we can read and parse the payload. 
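The promise of not reading more data than necessary is easy to check against an in-memory stream; a short sketch using the 12345 example from the module docstring, assuming load() is in scope:

    import io
    f = io.BytesIO(b'5:12345#extra')
    assert load(f) == 12345
    assert f.read() == b'extra'   # trailing bytes are left unread by load()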
+ # This repeats the dispatch logic of pop() so we can avoid + # re-constructing the outermost tnetstring. + data = file_handle.read(datalen) + if len(data) != datalen: + raise ValueError("not a tnetstring: length prefix too big") + tns_type = file_handle.read(1) + if tns_type == b',': + return data + if tns_type == b'#': + try: + return int(data) + except ValueError: + raise ValueError("not a tnetstring: invalid integer literal") + if tns_type == b'^': + try: + return float(data) + except ValueError: + raise ValueError("not a tnetstring: invalid float literal") + if tns_type == b'!': + if data == b'true': + return True + elif data == b'false': + return False + else: + raise ValueError("not a tnetstring: invalid boolean literal") + if tns_type == b'~': + if data: + raise ValueError("not a tnetstring: invalid null literal") + return None + if tns_type == b']': + l = [] + while data: + item, data = pop(data) + l.append(item) + return l + if tns_type == b'}': + d = {} + while data: + key, data = pop(data) + val, data = pop(data) + d[key] = val + return d + raise ValueError("unknown type tag") + + +def pop(string): + """pop(string,encoding='utf_8') -> (object, remain) + + This function parses a tnetstring into a python object. + It returns a tuple giving the parsed object and a string + containing any unparsed data from the end of the string. + """ + # Parse out data length, type and remaining string. + try: + dlen, rest = string.split(b':', 1) + dlen = int(dlen) + except ValueError: + raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(string)) + try: + data, tns_type, remain = rest[:dlen], rest[dlen:dlen + 1], rest[dlen + 1:] + except IndexError: + # This fires if len(rest) < dlen, meaning we don't need + # to further validate that data is the right length. + raise ValueError("not a tnetstring: invalid length prefix: {}".format(dlen)) + # Parse the data based on the type tag. + if tns_type == b',': + return data, remain + if tns_type == b'#': + try: + return int(data), remain + except ValueError: + raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) + if tns_type == b'^': + try: + return float(data), remain + except ValueError: + raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) + if tns_type == b'!': + if data == b'true': + return True, remain + elif data == b'false': + return False, remain + else: + raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) + if tns_type == b'~': + if data: + raise ValueError("not a tnetstring: invalid null literal") + return None, remain + if tns_type == b']': + l = [] + while data: + item, data = pop(data) + l.append(item) + return (l, remain) + if tns_type == b'}': + d = {} + while data: + key, data = pop(data) + val, data = pop(data) + d[key] = val + return d, remain + raise ValueError("unknown type tag: {}".format(tns_type)) diff --git a/mitmproxy/contrib/py3/__init__.py b/mitmproxy/contrib/py3/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mitmproxy/contrib/py3/tnetstring.py b/mitmproxy/contrib/py3/tnetstring.py new file mode 100644 index 00000000..6f38a245 --- /dev/null +++ b/mitmproxy/contrib/py3/tnetstring.py @@ -0,0 +1,233 @@ +""" +tnetstring: data serialization using typed netstrings +====================================================== + +This is a custom Python 3 implementation of tnetstrings. +Compared to other implementations, the main difference +is the conversion of dictionary keys to str. 
+ +An ordinary tnetstring is a blob of data prefixed with its length and postfixed +with its type. Here are some examples: + + >>> tnetstring.dumps("hello world") + 11:hello world, + >>> tnetstring.dumps(12345) + 5:12345# + >>> tnetstring.dumps([12345, True, 0]) + 19:5:12345#4:true!1:0#] + +This module gives you the following functions: + + :dump: dump an object as a tnetstring to a file + :dumps: dump an object as a tnetstring to a string + :load: load a tnetstring-encoded object from a file + :loads: load a tnetstring-encoded object from a string + +Note that since parsing a tnetstring requires reading all the data into memory +at once, there's no efficiency gain from using the file-based versions of these +functions. They're only here so you can use load() to read precisely one +item from a file or socket without consuming any extra data. + +The tnetstrings specification explicitly states that strings are binary blobs +and forbids the use of unicode at the protocol level. +**This implementation decodes dictionary keys as surrogate-escaped ASCII**, +all other strings are returned as plain bytes. + +:Copyright: (c) 2012-2013 by Ryan Kelly . +:Copyright: (c) 2014 by Carlo Pires . +:Copyright: (c) 2016 by Maximilian Hils . + +:License: MIT +""" + +import collections +from typing import io, Union, Tuple + +TSerializable = Union[None, bool, int, float, bytes, list, tuple, dict] + + +def dumps(value: TSerializable) -> bytes: + """ + This function dumps a python object as a tnetstring. + """ + # This uses a deque to collect output fragments in reverse order, + # then joins them together at the end. It's measurably faster + # than creating all the intermediate strings. + q = collections.deque() + _rdumpq(q, 0, value) + return b''.join(q) + + +def dump(value: TSerializable, file_handle: io.BinaryIO) -> None: + """ + This function dumps a python object as a tnetstring and + writes it to the given file. + """ + file_handle.write(dumps(value)) + + +def _rdumpq(q: collections.deque, size: int, value: TSerializable) -> int: + """ + Dump value as a tnetstring, to a deque instance, last chunks first. + + This function generates the tnetstring representation of the given value, + pushing chunks of the output onto the given deque instance. It pushes + the last chunk first, then recursively generates more chunks. + + When passed in the current size of the string in the queue, it will return + the new size of the string in the queue. + + Operating last-chunk-first makes it easy to calculate the size written + for recursive structures without having to build their representation as + a string. This is measurably faster than generating the intermediate + strings, especially on deeply nested structures. + """ + write = q.appendleft + if value is None: + write(b'0:~') + return size + 3 + elif value is True: + write(b'4:true!') + return size + 7 + elif value is False: + write(b'5:false!') + return size + 8 + elif isinstance(value, int): + data = str(value).encode() + ldata = len(data) + span = str(ldata).encode() + write(b'%s:%s#' % (span, data)) + return size + 2 + len(span) + ldata + elif isinstance(value, float): + # Use repr() for float rather than str(). + # It round-trips more accurately. + # Probably unnecessary in later python versions that + # use David Gay's ftoa routines. 
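At this stage of the series the Python 3 implementation serializes bytes only; a str value falls through to the ValueError branch at the end of this function. A quick illustrative check, mirroring the unicode test added alongside this module:

    assert dumps(b"hello") == b"5:hello,"
    try:
        dumps("hello")            # str is rejected here
    except ValueError:
        pass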
+ data = repr(value).encode() + ldata = len(data) + span = str(ldata).encode() + write(b'%s:%s^' % (span, data)) + return size + 2 + len(span) + ldata + elif isinstance(value, bytes): + lvalue = len(value) + span = str(lvalue).encode() + write(b'%s:%s,' % (span, value)) + return size + 2 + len(span) + lvalue + elif isinstance(value, (list, tuple)): + write(b']') + init_size = size = size + 1 + for item in reversed(value): + size = _rdumpq(q, size, item) + span = str(size - init_size).encode() + write(b':') + write(span) + return size + 1 + len(span) + elif isinstance(value, dict): + write(b'}') + init_size = size = size + 1 + for (k, v) in value.items(): + size = _rdumpq(q, size, v) + size = _rdumpq(q, size, k) + span = str(size - init_size).encode() + write(b':') + write(span) + return size + 1 + len(span) + else: + raise ValueError("unserializable object: {} ({})".format(value, type(value))) + + +def loads(string: bytes) -> TSerializable: + """ + This function parses a tnetstring into a python object. + """ + return pop(string)[0] + + +def load(file_handle: io.BinaryIO) -> TSerializable: + """load(file) -> object + + This function reads a tnetstring from a file and parses it into a + python object. The file must support the read() method, and this + function promises not to read more data than necessary. + """ + # Read the length prefix one char at a time. + # Note that the netstring spec explicitly forbids padding zeros. + c = file_handle.read(1) + data_length = b"" + while ord(b'0') <= ord(c) <= ord(b'9'): + data_length += c + if len(data_length) > 9: + raise ValueError("not a tnetstring: absurdly large length prefix") + c = file_handle.read(1) + if c != b":": + raise ValueError("not a tnetstring: missing or invalid length prefix") + + data = file_handle.read(int(data_length)) + data_type = file_handle.read(1)[0] + + return parse(data_type, data) + + +def parse(data_type: int, data: bytes) -> TSerializable: + if data_type == ord(b','): + return data + if data_type == ord(b'#'): + try: + return int(data) + except ValueError: + raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) + if data_type == ord(b'^'): + try: + return float(data) + except ValueError: + raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) + if data_type == ord(b'!'): + if data == b'true': + return True + elif data == b'false': + return False + else: + raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) + if data_type == ord(b'~'): + if data: + raise ValueError("not a tnetstring: invalid null literal") + return None + if data_type == ord(b']'): + l = [] + while data: + item, data = pop(data) + l.append(item) + return l + if data_type == ord(b'}'): + d = {} + while data: + key, data = pop(data) + val, data = pop(data) + d[key] = val + return d + raise ValueError("unknown type tag: {}".format(data_type)) + + +def pop(data: bytes) -> Tuple[TSerializable, bytes]: + """ + This function parses a tnetstring into a python object. + It returns a tuple giving the parsed object and a string + containing any unparsed data from the end of the string. + """ + # Parse out data length, type and remaining string. 
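pop() is what drives the list and dict parsing above: each call peels one complete element off the front of the buffer and returns whatever is left, for example:

    >>> pop(b'5:12345#4:true!')
    (12345, b'4:true!')
    >>> pop(b'4:true!')
    (True, b'')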
+ try: + length, data = data.split(b':', 1) + length = int(length) + except ValueError: + raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(data)) + try: + data, data_type, remain = data[:length], data[length], data[length + 1:] + except IndexError: + # This fires if len(data) < dlen, meaning we don't need + # to further validate that data is the right length. + raise ValueError("not a tnetstring: invalid length prefix: {}".format(length)) + # Parse the data based on the type tag. + return parse(data_type, data), remain + + +__all__ = ["dump", "dumps", "load", "loads"] diff --git a/mitmproxy/contrib/py3/tnetstring_tests.py b/mitmproxy/contrib/py3/tnetstring_tests.py new file mode 100644 index 00000000..545889c8 --- /dev/null +++ b/mitmproxy/contrib/py3/tnetstring_tests.py @@ -0,0 +1,133 @@ +import unittest +import random +import math +import io +from . import tnetstring +import struct + +MAXINT = 2 ** (struct.Struct('i').size * 8 - 1) - 1 + +FORMAT_EXAMPLES = { + b'0:}': {}, + b'0:]': [], + b'51:5:hello,39:11:12345678901#4:this,4:true!0:~4:\x00\x00\x00\x00,]}': + {b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, + b'5:12345#': 12345, + b'12:this is cool,': b'this is cool', + b'0:,': b'', + b'0:~': None, + b'4:true!': True, + b'5:false!': False, + b'10:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', + b'24:5:12345#5:67890#5:xxxxx,]': [12345, 67890, b'xxxxx'], + b'18:3:0.1^3:0.2^3:0.3^]': [0.1, 0.2, 0.3], + b'243:238:233:228:223:218:213:208:203:198:193:188:183:178:173:168:163:158:153:148:143:138:133:128:123:118:113:108:103:99:95:91:87:83:79:75:71:67:63:59:55:51:47:43:39:35:31:27:23:19:15:11:hello-there,]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]': [[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[b'hello-there']]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]] +} + +def get_random_object(random=random, depth=0): + """Generate a random serializable object.""" + # The probability of generating a scalar value increases as the depth increase. + # This ensures that we bottom out eventually. 
+ if random.randint(depth,10) <= 4: + what = random.randint(0,1) + if what == 0: + n = random.randint(0,10) + l = [] + for _ in range(n): + l.append(get_random_object(random,depth+1)) + return l + if what == 1: + n = random.randint(0,10) + d = {} + for _ in range(n): + n = random.randint(0,100) + k = bytes([random.randint(32,126) for _ in range(n)]) + d[k] = get_random_object(random,depth+1) + return d + else: + what = random.randint(0,4) + if what == 0: + return None + if what == 1: + return True + if what == 2: + return False + if what == 3: + if random.randint(0,1) == 0: + return random.randint(0,MAXINT) + else: + return -1 * random.randint(0,MAXINT) + n = random.randint(0,100) + return bytes([random.randint(32,126) for _ in range(n)]) + +class Test_Format(unittest.TestCase): + def test_roundtrip_format_examples(self): + for data, expect in FORMAT_EXAMPLES.items(): + self.assertEqual(expect,tnetstring.loads(data)) + self.assertEqual(expect,tnetstring.loads(tnetstring.dumps(expect))) + self.assertEqual((expect,b''),tnetstring.pop(data)) + + def test_roundtrip_format_random(self): + for _ in range(500): + v = get_random_object() + self.assertEqual(v,tnetstring.loads(tnetstring.dumps(v))) + self.assertEqual((v,b""),tnetstring.pop(tnetstring.dumps(v))) + + def test_unicode_handling(self): + with self.assertRaises(ValueError): + tnetstring.dumps("hello") + self.assertEqual(tnetstring.dumps("hello".encode()),b"5:hello,") + self.assertEqual(type(tnetstring.loads(b"5:hello,")),bytes) + + def test_roundtrip_format_unicode(self): + for _ in range(500): + v = get_random_object() + self.assertEqual(v,tnetstring.loads(tnetstring.dumps(v))) + self.assertEqual((v,b''),tnetstring.pop(tnetstring.dumps(v))) + + def test_roundtrip_big_integer(self): + i1 = math.factorial(30000) + s = tnetstring.dumps(i1) + i2 = tnetstring.loads(s) + self.assertEqual(i1, i2) + +class Test_FileLoading(unittest.TestCase): + def test_roundtrip_file_examples(self): + for data, expect in FORMAT_EXAMPLES.items(): + s = io.BytesIO() + s.write(data) + s.write(b'OK') + s.seek(0) + self.assertEqual(expect,tnetstring.load(s)) + self.assertEqual(b'OK',s.read()) + s = io.BytesIO() + tnetstring.dump(expect,s) + s.write(b'OK') + s.seek(0) + self.assertEqual(expect,tnetstring.load(s)) + self.assertEqual(b'OK',s.read()) + + def test_roundtrip_file_random(self): + for _ in range(500): + v = get_random_object() + s = io.BytesIO() + tnetstring.dump(v,s) + s.write(b'OK') + s.seek(0) + self.assertEqual(v,tnetstring.load(s)) + self.assertEqual(b'OK',s.read()) + + def test_error_on_absurd_lengths(self): + s = io.BytesIO() + s.write(b'1000000000:pwned!,') + s.seek(0) + with self.assertRaises(ValueError): + tnetstring.load(s) + self.assertEqual(s.read(1),b':') + +def suite(): + loader = unittest.TestLoader() + suite = unittest.TestSuite() + suite.addTest(loader.loadTestsFromTestCase(Test_Format)) + suite.addTest(loader.loadTestsFromTestCase(Test_FileLoading)) + return suite \ No newline at end of file diff --git a/mitmproxy/contrib/tnetstring.py b/mitmproxy/contrib/tnetstring.py index 9bf20b09..58daec5c 100644 --- a/mitmproxy/contrib/tnetstring.py +++ b/mitmproxy/contrib/tnetstring.py @@ -1,375 +1,8 @@ -# imported from the tnetstring project: https://github.com/rfk/tnetstring -# -# Copyright (c) 2011 Ryan Kelly -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the 
rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -""" -tnetstring: data serialization using typed netstrings -====================================================== - - -This is a data serialization library. It's a lot like JSON but it uses a -new syntax called "typed netstrings" that Zed has proposed for use in the -Mongrel2 webserver. It's designed to be simpler and easier to implement -than JSON, with a happy consequence of also being faster in many cases. - -An ordinary netstring is a blob of data prefixed with its length and postfixed -with a sanity-checking comma. The string "hello world" encodes like this:: - - 11:hello world, - -Typed netstrings add other datatypes by replacing the comma with a type tag. -Here's the integer 12345 encoded as a tnetstring:: - - 5:12345# - -And here's the list [12345,True,0] which mixes integers and bools:: - - 19:5:12345#4:true!1:0#] - -Simple enough? This module gives you the following functions: - - :dump: dump an object as a tnetstring to a file - :dumps: dump an object as a tnetstring to a string - :load: load a tnetstring-encoded object from a file - :loads: load a tnetstring-encoded object from a string - :pop: pop a tnetstring-encoded object from the front of a string - -Note that since parsing a tnetstring requires reading all the data into memory -at once, there's no efficiency gain from using the file-based versions of these -functions. They're only here so you can use load() to read precisely one -item from a file or socket without consuming any extra data. - -By default tnetstrings work only with byte strings, not unicode. If you want -unicode strings then pass an optional encoding to the various functions, -like so:: - - >>> print(repr(tnetstring.loads("2:\\xce\\xb1,"))) - '\\xce\\xb1' - >>> - >>> print(repr(tnetstring.loads("2:\\xce\\xb1,","utf8"))) - u'\u03b1' - -""" -from collections import deque - import six -__ver_major__ = 0 -__ver_minor__ = 2 -__ver_patch__ = 0 -__ver_sub__ = "" -__version__ = "%d.%d.%d%s" % ( - __ver_major__, __ver_minor__, __ver_patch__, __ver_sub__) - - -def dumps(value): - """ - This function dumps a python object as a tnetstring. - """ - # This uses a deque to collect output fragments in reverse order, - # then joins them together at the end. It's measurably faster - # than creating all the intermediate strings. - # If you're reading this to get a handle on the tnetstring format, - # consider the _gdumps() function instead; it's a standard top-down - # generator that's simpler to understand but much less efficient. - q = deque() - _rdumpq(q, 0, value) - return b''.join(q) - - -def dump(value, file_handle): - """ - This function dumps a python object as a tnetstring and - writes it to the given file. 
- """ - file_handle.write(dumps(value)) - - -def _rdumpq(q, size, value): - """ - Dump value as a tnetstring, to a deque instance, last chunks first. - - This function generates the tnetstring representation of the given value, - pushing chunks of the output onto the given deque instance. It pushes - the last chunk first, then recursively generates more chunks. - - When passed in the current size of the string in the queue, it will return - the new size of the string in the queue. - - Operating last-chunk-first makes it easy to calculate the size written - for recursive structures without having to build their representation as - a string. This is measurably faster than generating the intermediate - strings, especially on deeply nested structures. - """ - write = q.appendleft - if value is None: - write(b'0:~') - return size + 3 - elif value is True: - write(b'4:true!') - return size + 7 - elif value is False: - write(b'5:false!') - return size + 8 - elif isinstance(value, six.integer_types): - data = str(value).encode() - ldata = len(data) - span = str(ldata).encode() - write(b'#') - write(data) - write(b':') - write(span) - return size + 2 + len(span) + ldata - elif isinstance(value, float): - # Use repr() for float rather than str(). - # It round-trips more accurately. - # Probably unnecessary in later python versions that - # use David Gay's ftoa routines. - data = repr(value).encode() - ldata = len(data) - span = str(ldata).encode() - write(b'^') - write(data) - write(b':') - write(span) - return size + 2 + len(span) + ldata - elif isinstance(value, bytes): - lvalue = len(value) - span = str(lvalue).encode() - write(b',') - write(value) - write(b':') - write(span) - return size + 2 + len(span) + lvalue - elif isinstance(value, (list, tuple)): - write(b']') - init_size = size = size + 1 - for item in reversed(value): - size = _rdumpq(q, size, item) - span = str(size - init_size).encode() - write(b':') - write(span) - return size + 1 + len(span) - elif isinstance(value, dict): - write(b'}') - init_size = size = size + 1 - for (k, v) in value.items(): - size = _rdumpq(q, size, v) - size = _rdumpq(q, size, k) - span = str(size - init_size).encode() - write(b':') - write(span) - return size + 1 + len(span) - else: - raise ValueError("unserializable object: {} ({})".format(value, type(value))) - - -def _gdumps(value): - """ - Generate fragments of value dumped as a tnetstring. - - This is the naive dumping algorithm, implemented as a generator so that - it's easy to pass to "".join() without building a new list. - - This is mainly here for comparison purposes; the _rdumpq version is - measurably faster as it doesn't have to build intermediate strins. - """ - if value is None: - yield b'0:~' - elif value is True: - yield b'4:true!' - elif value is False: - yield b'5:false!' 
- elif isinstance(value, six.integer_types): - data = str(value).encode() - yield str(len(data)).encode() - yield b':' - yield data - yield b'#' - elif isinstance(value, float): - data = repr(value).encode() - yield str(len(data)).encode() - yield b':' - yield data - yield b'^' - elif isinstance(value, bytes): - yield str(len(value)).encode() - yield b':' - yield value - yield b',' - elif isinstance(value, (list, tuple)): - sub = [] - for item in value: - sub.extend(_gdumps(item)) - sub = b''.join(sub) - yield str(len(sub)).encode() - yield b':' - yield sub - yield b']' - elif isinstance(value, (dict,)): - sub = [] - for (k, v) in value.items(): - sub.extend(_gdumps(k)) - sub.extend(_gdumps(v)) - sub = b''.join(sub) - yield str(len(sub)).encode() - yield b':' - yield sub - yield b'}' - else: - raise ValueError("unserializable object") - - -def loads(string): - """ - This function parses a tnetstring into a python object. - """ - # No point duplicating effort here. In the C-extension version, - # loads() is measurably faster then pop() since it can avoid - # the overhead of building a second string. - return pop(string)[0] - - -def load(file_handle): - """load(file) -> object - - This function reads a tnetstring from a file and parses it into a - python object. The file must support the read() method, and this - function promises not to read more data than necessary. - """ - # Read the length prefix one char at a time. - # Note that the netstring spec explicitly forbids padding zeros. - c = file_handle.read(1) - if not c.isdigit(): - raise ValueError("not a tnetstring: missing or invalid length prefix") - datalen = ord(c) - ord('0') - c = file_handle.read(1) - if datalen != 0: - while c.isdigit(): - datalen = (10 * datalen) + (ord(c) - ord('0')) - if datalen > 999999999: - errmsg = "not a tnetstring: absurdly large length prefix" - raise ValueError(errmsg) - c = file_handle.read(1) - if c != b':': - raise ValueError("not a tnetstring: missing or invalid length prefix") - # Now we can read and parse the payload. - # This repeats the dispatch logic of pop() so we can avoid - # re-constructing the outermost tnetstring. - data = file_handle.read(datalen) - if len(data) != datalen: - raise ValueError("not a tnetstring: length prefix too big") - tns_type = file_handle.read(1) - if tns_type == b',': - return data - if tns_type == b'#': - try: - return int(data) - except ValueError: - raise ValueError("not a tnetstring: invalid integer literal") - if tns_type == b'^': - try: - return float(data) - except ValueError: - raise ValueError("not a tnetstring: invalid float literal") - if tns_type == b'!': - if data == b'true': - return True - elif data == b'false': - return False - else: - raise ValueError("not a tnetstring: invalid boolean literal") - if tns_type == b'~': - if data: - raise ValueError("not a tnetstring: invalid null literal") - return None - if tns_type == b']': - l = [] - while data: - item, data = pop(data) - l.append(item) - return l - if tns_type == b'}': - d = {} - while data: - key, data = pop(data) - val, data = pop(data) - d[key] = val - return d - raise ValueError("unknown type tag") - - -def pop(string): - """pop(string,encoding='utf_8') -> (object, remain) +if six.PY2: + from .py2.tnetstring import load, loads, dump, dumps +else: + from .py3.tnetstring import load, loads, dump, dumps - This function parses a tnetstring into a python object. - It returns a tuple giving the parsed object and a string - containing any unparsed data from the end of the string. 
- """ - # Parse out data length, type and remaining string. - try: - dlen, rest = string.split(b':', 1) - dlen = int(dlen) - except ValueError: - raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(string)) - try: - data, tns_type, remain = rest[:dlen], rest[dlen:dlen + 1], rest[dlen + 1:] - except IndexError: - # This fires if len(rest) < dlen, meaning we don't need - # to further validate that data is the right length. - raise ValueError("not a tnetstring: invalid length prefix: {}".format(dlen)) - # Parse the data based on the type tag. - if tns_type == b',': - return data, remain - if tns_type == b'#': - try: - return int(data), remain - except ValueError: - raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) - if tns_type == b'^': - try: - return float(data), remain - except ValueError: - raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) - if tns_type == b'!': - if data == b'true': - return True, remain - elif data == b'false': - return False, remain - else: - raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) - if tns_type == b'~': - if data: - raise ValueError("not a tnetstring: invalid null literal") - return None, remain - if tns_type == b']': - l = [] - while data: - item, data = pop(data) - l.append(item) - return (l, remain) - if tns_type == b'}': - d = {} - while data: - key, data = pop(data) - val, data = pop(data) - d[key] = val - return d, remain - raise ValueError("unknown type tag: {}".format(tns_type)) +__all__ = ["load", "loads", "dump", "dumps"] -- cgit v1.2.3 From 684b4b5130aa9cc75322dd270172b263615d39dc Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 5 Jul 2016 18:48:45 -0700 Subject: tnetstring: keys are str on py3. migrate flow.io_compat --- mitmproxy/contrib/py3/tnetstring.py | 8 ++++++-- mitmproxy/contrib/py3/tnetstring_tests.py | 4 ++-- mitmproxy/contrib/tnetstring.py | 6 +++--- mitmproxy/flow/io.py | 9 +++++---- mitmproxy/flow/io_compat.py | 3 +-- test/mitmproxy/test_contrib_tnetstring.py | 4 ++-- tox.ini | 2 +- 7 files changed, 20 insertions(+), 16 deletions(-) diff --git a/mitmproxy/contrib/py3/tnetstring.py b/mitmproxy/contrib/py3/tnetstring.py index 6f38a245..6998fc82 100644 --- a/mitmproxy/contrib/py3/tnetstring.py +++ b/mitmproxy/contrib/py3/tnetstring.py @@ -126,6 +126,8 @@ def _rdumpq(q: collections.deque, size: int, value: TSerializable) -> int: write(b'}') init_size = size = size + 1 for (k, v) in value.items(): + if isinstance(k, str): + k = k.encode("ascii", "strict") size = _rdumpq(q, size, v) size = _rdumpq(q, size, k) span = str(size - init_size).encode() @@ -154,7 +156,7 @@ def load(file_handle: io.BinaryIO) -> TSerializable: # Note that the netstring spec explicitly forbids padding zeros. 
c = file_handle.read(1) data_length = b"" - while ord(b'0') <= ord(c) <= ord(b'9'): + while c.isdigit(): data_length += c if len(data_length) > 9: raise ValueError("not a tnetstring: absurdly large length prefix") @@ -202,6 +204,8 @@ def parse(data_type: int, data: bytes) -> TSerializable: d = {} while data: key, data = pop(data) + if isinstance(key, bytes): + key = key.decode("ascii", "strict") val, data = pop(data) d[key] = val return d @@ -230,4 +234,4 @@ def pop(data: bytes) -> Tuple[TSerializable, bytes]: return parse(data_type, data), remain -__all__ = ["dump", "dumps", "load", "loads"] +__all__ = ["dump", "dumps", "load", "loads", "pop"] diff --git a/mitmproxy/contrib/py3/tnetstring_tests.py b/mitmproxy/contrib/py3/tnetstring_tests.py index 545889c8..4ee184d5 100644 --- a/mitmproxy/contrib/py3/tnetstring_tests.py +++ b/mitmproxy/contrib/py3/tnetstring_tests.py @@ -11,7 +11,7 @@ FORMAT_EXAMPLES = { b'0:}': {}, b'0:]': [], b'51:5:hello,39:11:12345678901#4:this,4:true!0:~4:\x00\x00\x00\x00,]}': - {b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, + {'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, b'5:12345#': 12345, b'12:this is cool,': b'this is cool', b'0:,': b'', @@ -41,7 +41,7 @@ def get_random_object(random=random, depth=0): d = {} for _ in range(n): n = random.randint(0,100) - k = bytes([random.randint(32,126) for _ in range(n)]) + k = str([random.randint(32,126) for _ in range(n)]) d[k] = get_random_object(random,depth+1) return d else: diff --git a/mitmproxy/contrib/tnetstring.py b/mitmproxy/contrib/tnetstring.py index 58daec5c..1ebaba21 100644 --- a/mitmproxy/contrib/tnetstring.py +++ b/mitmproxy/contrib/tnetstring.py @@ -1,8 +1,8 @@ import six if six.PY2: - from .py2.tnetstring import load, loads, dump, dumps + from .py2.tnetstring import load, loads, dump, dumps, pop else: - from .py3.tnetstring import load, loads, dump, dumps + from .py3.tnetstring import load, loads, dump, dumps, pop -__all__ = ["load", "loads", "dump", "dumps"] +__all__ = ["load", "loads", "dump", "dumps", "pop"] diff --git a/mitmproxy/flow/io.py b/mitmproxy/flow/io.py index 671ddf43..e5716940 100644 --- a/mitmproxy/flow/io.py +++ b/mitmproxy/flow/io.py @@ -44,12 +44,13 @@ class FlowReader: raise exceptions.FlowReadException(str(e)) if can_tell: off = self.fo.tell() - if data["type"] not in models.FLOW_TYPES: - raise exceptions.FlowReadException("Unknown flow type: {}".format(data["type"])) - yield models.FLOW_TYPES[data["type"]].from_state(data) + data_type = data["type"].decode() + if data_type not in models.FLOW_TYPES: + raise exceptions.FlowReadException("Unknown flow type: {}".format(data_type)) + yield models.FLOW_TYPES[data_type].from_state(data) except ValueError: # Error is due to EOF - if can_tell and self.fo.tell() == off and self.fo.read() == '': + if can_tell and self.fo.tell() == off and self.fo.read() == b'': return raise exceptions.FlowReadException("Invalid data format.") diff --git a/mitmproxy/flow/io_compat.py b/mitmproxy/flow/io_compat.py index 1023e87f..55971f5e 100644 --- a/mitmproxy/flow/io_compat.py +++ b/mitmproxy/flow/io_compat.py @@ -9,6 +9,7 @@ from netlib import version def convert_013_014(data): data["request"]["first_line_format"] = data["request"].pop("form_in") data["request"]["http_version"] = "HTTP/" + ".".join(str(x) for x in data["request"].pop("httpversion")) + data["response"]["http_version"] = "HTTP/" + ".".join(str(x) for x in data["response"].pop("httpversion")) data["response"]["status_code"] = data["response"].pop("code") 
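The reader change above follows from the new key handling: after this commit, dictionary keys round-trip as str on Python 3 while values stay bytes, so a value such as the flow type still needs an explicit decode. An illustrative round trip (the "type" key here is a made-up example, not real flow data):

    >>> tnetstring.loads(tnetstring.dumps({"type": b"http"}))
    {'type': b'http'}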
data["response"]["body"] = data["response"].pop("content") data["server_conn"].pop("state") @@ -26,8 +27,6 @@ def convert_015_016(data): for m in ("request", "response"): if "body" in data[m]: data[m]["content"] = data[m].pop("body") - if "httpversion" in data[m]: - data[m]["http_version"] = data[m].pop("httpversion") if "msg" in data["response"]: data["response"]["reason"] = data["response"].pop("msg") data["request"].pop("form_out", None) diff --git a/test/mitmproxy/test_contrib_tnetstring.py b/test/mitmproxy/test_contrib_tnetstring.py index 17654ad9..8ae35a25 100644 --- a/test/mitmproxy/test_contrib_tnetstring.py +++ b/test/mitmproxy/test_contrib_tnetstring.py @@ -12,7 +12,7 @@ FORMAT_EXAMPLES = { b'0:}': {}, b'0:]': [], b'51:5:hello,39:11:12345678901#4:this,4:true!0:~4:\x00\x00\x00\x00,]}': - {b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, + {'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, b'5:12345#': 12345, b'12:this is cool,': b'this is cool', b'0:,': b'', @@ -43,7 +43,7 @@ def get_random_object(random=random, depth=0): d = {} for _ in range(n): n = random.randint(0, 100) - k = bytes([random.randint(32, 126) for _ in range(n)]) + k = str([random.randint(32, 126) for _ in range(n)]) d[k] = get_random_object(random, depth + 1) return d else: diff --git a/tox.ini b/tox.ini index a7b5e7d3..251609a5 100644 --- a/tox.ini +++ b/tox.ini @@ -16,7 +16,7 @@ commands = [testenv:py35] setenv = - TESTS = test/netlib test/pathod/ test/mitmproxy/script test/mitmproxy/test_contentview.py test/mitmproxy/test_custom_contentview.py test/mitmproxy/test_app.py test/mitmproxy/test_controller.py test/mitmproxy/test_fuzzing.py test/mitmproxy/test_script.py test/mitmproxy/test_web_app.py test/mitmproxy/test_utils.py test/mitmproxy/test_stateobject.py test/mitmproxy/test_cmdline.py test/mitmproxy/test_contrib_tnetstring.py test/mitmproxy/test_proxy.py test/mitmproxy/test_protocol_http1.py test/mitmproxy/test_platform_pf.py test/mitmproxy/test_server.py test/mitmproxy/test_filt.py test/mitmproxy/test_flow_export.py test/mitmproxy/test_web_master.py + TESTS = test/netlib test/pathod/ test/mitmproxy/script test/mitmproxy/test_contentview.py test/mitmproxy/test_custom_contentview.py test/mitmproxy/test_app.py test/mitmproxy/test_controller.py test/mitmproxy/test_fuzzing.py test/mitmproxy/test_script.py test/mitmproxy/test_web_app.py test/mitmproxy/test_utils.py test/mitmproxy/test_stateobject.py test/mitmproxy/test_cmdline.py test/mitmproxy/test_contrib_tnetstring.py test/mitmproxy/test_proxy.py test/mitmproxy/test_protocol_http1.py test/mitmproxy/test_platform_pf.py test/mitmproxy/test_server.py test/mitmproxy/test_filt.py test/mitmproxy/test_flow_export.py test/mitmproxy/test_web_master.py test/mitmproxy/test_flow_format_compat.py HOME = {envtmpdir} [testenv:docs] -- cgit v1.2.3 From 48ee3a553e30b36c16bfbe1674d3313605dff661 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 5 Jul 2016 19:25:56 -0700 Subject: add tnetstring unicode type --- mitmproxy/contrib/py2/__init__.py | 0 mitmproxy/contrib/py2/tnetstring.py | 375 ------------------------------ mitmproxy/contrib/py3/__init__.py | 0 mitmproxy/contrib/py3/tnetstring.py | 237 ------------------- mitmproxy/contrib/py3/tnetstring_tests.py | 133 ----------- mitmproxy/contrib/tnetstring.py | 256 +++++++++++++++++++- 6 files changed, 251 insertions(+), 750 deletions(-) delete mode 100644 mitmproxy/contrib/py2/__init__.py delete mode 100644 mitmproxy/contrib/py2/tnetstring.py delete mode 100644 
mitmproxy/contrib/py3/__init__.py delete mode 100644 mitmproxy/contrib/py3/tnetstring.py delete mode 100644 mitmproxy/contrib/py3/tnetstring_tests.py diff --git a/mitmproxy/contrib/py2/__init__.py b/mitmproxy/contrib/py2/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/mitmproxy/contrib/py2/tnetstring.py b/mitmproxy/contrib/py2/tnetstring.py deleted file mode 100644 index 9bf20b09..00000000 --- a/mitmproxy/contrib/py2/tnetstring.py +++ /dev/null @@ -1,375 +0,0 @@ -# imported from the tnetstring project: https://github.com/rfk/tnetstring -# -# Copyright (c) 2011 Ryan Kelly -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. -""" -tnetstring: data serialization using typed netstrings -====================================================== - - -This is a data serialization library. It's a lot like JSON but it uses a -new syntax called "typed netstrings" that Zed has proposed for use in the -Mongrel2 webserver. It's designed to be simpler and easier to implement -than JSON, with a happy consequence of also being faster in many cases. - -An ordinary netstring is a blob of data prefixed with its length and postfixed -with a sanity-checking comma. The string "hello world" encodes like this:: - - 11:hello world, - -Typed netstrings add other datatypes by replacing the comma with a type tag. -Here's the integer 12345 encoded as a tnetstring:: - - 5:12345# - -And here's the list [12345,True,0] which mixes integers and bools:: - - 19:5:12345#4:true!1:0#] - -Simple enough? This module gives you the following functions: - - :dump: dump an object as a tnetstring to a file - :dumps: dump an object as a tnetstring to a string - :load: load a tnetstring-encoded object from a file - :loads: load a tnetstring-encoded object from a string - :pop: pop a tnetstring-encoded object from the front of a string - -Note that since parsing a tnetstring requires reading all the data into memory -at once, there's no efficiency gain from using the file-based versions of these -functions. They're only here so you can use load() to read precisely one -item from a file or socket without consuming any extra data. - -By default tnetstrings work only with byte strings, not unicode. 
If you want -unicode strings then pass an optional encoding to the various functions, -like so:: - - >>> print(repr(tnetstring.loads("2:\\xce\\xb1,"))) - '\\xce\\xb1' - >>> - >>> print(repr(tnetstring.loads("2:\\xce\\xb1,","utf8"))) - u'\u03b1' - -""" -from collections import deque - -import six - -__ver_major__ = 0 -__ver_minor__ = 2 -__ver_patch__ = 0 -__ver_sub__ = "" -__version__ = "%d.%d.%d%s" % ( - __ver_major__, __ver_minor__, __ver_patch__, __ver_sub__) - - -def dumps(value): - """ - This function dumps a python object as a tnetstring. - """ - # This uses a deque to collect output fragments in reverse order, - # then joins them together at the end. It's measurably faster - # than creating all the intermediate strings. - # If you're reading this to get a handle on the tnetstring format, - # consider the _gdumps() function instead; it's a standard top-down - # generator that's simpler to understand but much less efficient. - q = deque() - _rdumpq(q, 0, value) - return b''.join(q) - - -def dump(value, file_handle): - """ - This function dumps a python object as a tnetstring and - writes it to the given file. - """ - file_handle.write(dumps(value)) - - -def _rdumpq(q, size, value): - """ - Dump value as a tnetstring, to a deque instance, last chunks first. - - This function generates the tnetstring representation of the given value, - pushing chunks of the output onto the given deque instance. It pushes - the last chunk first, then recursively generates more chunks. - - When passed in the current size of the string in the queue, it will return - the new size of the string in the queue. - - Operating last-chunk-first makes it easy to calculate the size written - for recursive structures without having to build their representation as - a string. This is measurably faster than generating the intermediate - strings, especially on deeply nested structures. - """ - write = q.appendleft - if value is None: - write(b'0:~') - return size + 3 - elif value is True: - write(b'4:true!') - return size + 7 - elif value is False: - write(b'5:false!') - return size + 8 - elif isinstance(value, six.integer_types): - data = str(value).encode() - ldata = len(data) - span = str(ldata).encode() - write(b'#') - write(data) - write(b':') - write(span) - return size + 2 + len(span) + ldata - elif isinstance(value, float): - # Use repr() for float rather than str(). - # It round-trips more accurately. - # Probably unnecessary in later python versions that - # use David Gay's ftoa routines. - data = repr(value).encode() - ldata = len(data) - span = str(ldata).encode() - write(b'^') - write(data) - write(b':') - write(span) - return size + 2 + len(span) + ldata - elif isinstance(value, bytes): - lvalue = len(value) - span = str(lvalue).encode() - write(b',') - write(value) - write(b':') - write(span) - return size + 2 + len(span) + lvalue - elif isinstance(value, (list, tuple)): - write(b']') - init_size = size = size + 1 - for item in reversed(value): - size = _rdumpq(q, size, item) - span = str(size - init_size).encode() - write(b':') - write(span) - return size + 1 + len(span) - elif isinstance(value, dict): - write(b'}') - init_size = size = size + 1 - for (k, v) in value.items(): - size = _rdumpq(q, size, v) - size = _rdumpq(q, size, k) - span = str(size - init_size).encode() - write(b':') - write(span) - return size + 1 + len(span) - else: - raise ValueError("unserializable object: {} ({})".format(value, type(value))) - - -def _gdumps(value): - """ - Generate fragments of value dumped as a tnetstring. 
- - This is the naive dumping algorithm, implemented as a generator so that - it's easy to pass to "".join() without building a new list. - - This is mainly here for comparison purposes; the _rdumpq version is - measurably faster as it doesn't have to build intermediate strins. - """ - if value is None: - yield b'0:~' - elif value is True: - yield b'4:true!' - elif value is False: - yield b'5:false!' - elif isinstance(value, six.integer_types): - data = str(value).encode() - yield str(len(data)).encode() - yield b':' - yield data - yield b'#' - elif isinstance(value, float): - data = repr(value).encode() - yield str(len(data)).encode() - yield b':' - yield data - yield b'^' - elif isinstance(value, bytes): - yield str(len(value)).encode() - yield b':' - yield value - yield b',' - elif isinstance(value, (list, tuple)): - sub = [] - for item in value: - sub.extend(_gdumps(item)) - sub = b''.join(sub) - yield str(len(sub)).encode() - yield b':' - yield sub - yield b']' - elif isinstance(value, (dict,)): - sub = [] - for (k, v) in value.items(): - sub.extend(_gdumps(k)) - sub.extend(_gdumps(v)) - sub = b''.join(sub) - yield str(len(sub)).encode() - yield b':' - yield sub - yield b'}' - else: - raise ValueError("unserializable object") - - -def loads(string): - """ - This function parses a tnetstring into a python object. - """ - # No point duplicating effort here. In the C-extension version, - # loads() is measurably faster then pop() since it can avoid - # the overhead of building a second string. - return pop(string)[0] - - -def load(file_handle): - """load(file) -> object - - This function reads a tnetstring from a file and parses it into a - python object. The file must support the read() method, and this - function promises not to read more data than necessary. - """ - # Read the length prefix one char at a time. - # Note that the netstring spec explicitly forbids padding zeros. - c = file_handle.read(1) - if not c.isdigit(): - raise ValueError("not a tnetstring: missing or invalid length prefix") - datalen = ord(c) - ord('0') - c = file_handle.read(1) - if datalen != 0: - while c.isdigit(): - datalen = (10 * datalen) + (ord(c) - ord('0')) - if datalen > 999999999: - errmsg = "not a tnetstring: absurdly large length prefix" - raise ValueError(errmsg) - c = file_handle.read(1) - if c != b':': - raise ValueError("not a tnetstring: missing or invalid length prefix") - # Now we can read and parse the payload. - # This repeats the dispatch logic of pop() so we can avoid - # re-constructing the outermost tnetstring. 
- data = file_handle.read(datalen) - if len(data) != datalen: - raise ValueError("not a tnetstring: length prefix too big") - tns_type = file_handle.read(1) - if tns_type == b',': - return data - if tns_type == b'#': - try: - return int(data) - except ValueError: - raise ValueError("not a tnetstring: invalid integer literal") - if tns_type == b'^': - try: - return float(data) - except ValueError: - raise ValueError("not a tnetstring: invalid float literal") - if tns_type == b'!': - if data == b'true': - return True - elif data == b'false': - return False - else: - raise ValueError("not a tnetstring: invalid boolean literal") - if tns_type == b'~': - if data: - raise ValueError("not a tnetstring: invalid null literal") - return None - if tns_type == b']': - l = [] - while data: - item, data = pop(data) - l.append(item) - return l - if tns_type == b'}': - d = {} - while data: - key, data = pop(data) - val, data = pop(data) - d[key] = val - return d - raise ValueError("unknown type tag") - - -def pop(string): - """pop(string,encoding='utf_8') -> (object, remain) - - This function parses a tnetstring into a python object. - It returns a tuple giving the parsed object and a string - containing any unparsed data from the end of the string. - """ - # Parse out data length, type and remaining string. - try: - dlen, rest = string.split(b':', 1) - dlen = int(dlen) - except ValueError: - raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(string)) - try: - data, tns_type, remain = rest[:dlen], rest[dlen:dlen + 1], rest[dlen + 1:] - except IndexError: - # This fires if len(rest) < dlen, meaning we don't need - # to further validate that data is the right length. - raise ValueError("not a tnetstring: invalid length prefix: {}".format(dlen)) - # Parse the data based on the type tag. - if tns_type == b',': - return data, remain - if tns_type == b'#': - try: - return int(data), remain - except ValueError: - raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) - if tns_type == b'^': - try: - return float(data), remain - except ValueError: - raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) - if tns_type == b'!': - if data == b'true': - return True, remain - elif data == b'false': - return False, remain - else: - raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) - if tns_type == b'~': - if data: - raise ValueError("not a tnetstring: invalid null literal") - return None, remain - if tns_type == b']': - l = [] - while data: - item, data = pop(data) - l.append(item) - return (l, remain) - if tns_type == b'}': - d = {} - while data: - key, data = pop(data) - val, data = pop(data) - d[key] = val - return d, remain - raise ValueError("unknown type tag: {}".format(tns_type)) diff --git a/mitmproxy/contrib/py3/__init__.py b/mitmproxy/contrib/py3/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/mitmproxy/contrib/py3/tnetstring.py b/mitmproxy/contrib/py3/tnetstring.py deleted file mode 100644 index 6998fc82..00000000 --- a/mitmproxy/contrib/py3/tnetstring.py +++ /dev/null @@ -1,237 +0,0 @@ -""" -tnetstring: data serialization using typed netstrings -====================================================== - -This is a custom Python 3 implementation of tnetstrings. -Compared to other implementations, the main difference -is the conversion of dictionary keys to str. - -An ordinary tnetstring is a blob of data prefixed with its length and postfixed -with its type. 
Here are some examples: - - >>> tnetstring.dumps("hello world") - 11:hello world, - >>> tnetstring.dumps(12345) - 5:12345# - >>> tnetstring.dumps([12345, True, 0]) - 19:5:12345#4:true!1:0#] - -This module gives you the following functions: - - :dump: dump an object as a tnetstring to a file - :dumps: dump an object as a tnetstring to a string - :load: load a tnetstring-encoded object from a file - :loads: load a tnetstring-encoded object from a string - -Note that since parsing a tnetstring requires reading all the data into memory -at once, there's no efficiency gain from using the file-based versions of these -functions. They're only here so you can use load() to read precisely one -item from a file or socket without consuming any extra data. - -The tnetstrings specification explicitly states that strings are binary blobs -and forbids the use of unicode at the protocol level. -**This implementation decodes dictionary keys as surrogate-escaped ASCII**, -all other strings are returned as plain bytes. - -:Copyright: (c) 2012-2013 by Ryan Kelly . -:Copyright: (c) 2014 by Carlo Pires . -:Copyright: (c) 2016 by Maximilian Hils . - -:License: MIT -""" - -import collections -from typing import io, Union, Tuple - -TSerializable = Union[None, bool, int, float, bytes, list, tuple, dict] - - -def dumps(value: TSerializable) -> bytes: - """ - This function dumps a python object as a tnetstring. - """ - # This uses a deque to collect output fragments in reverse order, - # then joins them together at the end. It's measurably faster - # than creating all the intermediate strings. - q = collections.deque() - _rdumpq(q, 0, value) - return b''.join(q) - - -def dump(value: TSerializable, file_handle: io.BinaryIO) -> None: - """ - This function dumps a python object as a tnetstring and - writes it to the given file. - """ - file_handle.write(dumps(value)) - - -def _rdumpq(q: collections.deque, size: int, value: TSerializable) -> int: - """ - Dump value as a tnetstring, to a deque instance, last chunks first. - - This function generates the tnetstring representation of the given value, - pushing chunks of the output onto the given deque instance. It pushes - the last chunk first, then recursively generates more chunks. - - When passed in the current size of the string in the queue, it will return - the new size of the string in the queue. - - Operating last-chunk-first makes it easy to calculate the size written - for recursive structures without having to build their representation as - a string. This is measurably faster than generating the intermediate - strings, especially on deeply nested structures. - """ - write = q.appendleft - if value is None: - write(b'0:~') - return size + 3 - elif value is True: - write(b'4:true!') - return size + 7 - elif value is False: - write(b'5:false!') - return size + 8 - elif isinstance(value, int): - data = str(value).encode() - ldata = len(data) - span = str(ldata).encode() - write(b'%s:%s#' % (span, data)) - return size + 2 + len(span) + ldata - elif isinstance(value, float): - # Use repr() for float rather than str(). - # It round-trips more accurately. - # Probably unnecessary in later python versions that - # use David Gay's ftoa routines. 
- data = repr(value).encode() - ldata = len(data) - span = str(ldata).encode() - write(b'%s:%s^' % (span, data)) - return size + 2 + len(span) + ldata - elif isinstance(value, bytes): - lvalue = len(value) - span = str(lvalue).encode() - write(b'%s:%s,' % (span, value)) - return size + 2 + len(span) + lvalue - elif isinstance(value, (list, tuple)): - write(b']') - init_size = size = size + 1 - for item in reversed(value): - size = _rdumpq(q, size, item) - span = str(size - init_size).encode() - write(b':') - write(span) - return size + 1 + len(span) - elif isinstance(value, dict): - write(b'}') - init_size = size = size + 1 - for (k, v) in value.items(): - if isinstance(k, str): - k = k.encode("ascii", "strict") - size = _rdumpq(q, size, v) - size = _rdumpq(q, size, k) - span = str(size - init_size).encode() - write(b':') - write(span) - return size + 1 + len(span) - else: - raise ValueError("unserializable object: {} ({})".format(value, type(value))) - - -def loads(string: bytes) -> TSerializable: - """ - This function parses a tnetstring into a python object. - """ - return pop(string)[0] - - -def load(file_handle: io.BinaryIO) -> TSerializable: - """load(file) -> object - - This function reads a tnetstring from a file and parses it into a - python object. The file must support the read() method, and this - function promises not to read more data than necessary. - """ - # Read the length prefix one char at a time. - # Note that the netstring spec explicitly forbids padding zeros. - c = file_handle.read(1) - data_length = b"" - while c.isdigit(): - data_length += c - if len(data_length) > 9: - raise ValueError("not a tnetstring: absurdly large length prefix") - c = file_handle.read(1) - if c != b":": - raise ValueError("not a tnetstring: missing or invalid length prefix") - - data = file_handle.read(int(data_length)) - data_type = file_handle.read(1)[0] - - return parse(data_type, data) - - -def parse(data_type: int, data: bytes) -> TSerializable: - if data_type == ord(b','): - return data - if data_type == ord(b'#'): - try: - return int(data) - except ValueError: - raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) - if data_type == ord(b'^'): - try: - return float(data) - except ValueError: - raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) - if data_type == ord(b'!'): - if data == b'true': - return True - elif data == b'false': - return False - else: - raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) - if data_type == ord(b'~'): - if data: - raise ValueError("not a tnetstring: invalid null literal") - return None - if data_type == ord(b']'): - l = [] - while data: - item, data = pop(data) - l.append(item) - return l - if data_type == ord(b'}'): - d = {} - while data: - key, data = pop(data) - if isinstance(key, bytes): - key = key.decode("ascii", "strict") - val, data = pop(data) - d[key] = val - return d - raise ValueError("unknown type tag: {}".format(data_type)) - - -def pop(data: bytes) -> Tuple[TSerializable, bytes]: - """ - This function parses a tnetstring into a python object. - It returns a tuple giving the parsed object and a string - containing any unparsed data from the end of the string. - """ - # Parse out data length, type and remaining string. 
- try: - length, data = data.split(b':', 1) - length = int(length) - except ValueError: - raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(data)) - try: - data, data_type, remain = data[:length], data[length], data[length + 1:] - except IndexError: - # This fires if len(data) < dlen, meaning we don't need - # to further validate that data is the right length. - raise ValueError("not a tnetstring: invalid length prefix: {}".format(length)) - # Parse the data based on the type tag. - return parse(data_type, data), remain - - -__all__ = ["dump", "dumps", "load", "loads", "pop"] diff --git a/mitmproxy/contrib/py3/tnetstring_tests.py b/mitmproxy/contrib/py3/tnetstring_tests.py deleted file mode 100644 index 4ee184d5..00000000 --- a/mitmproxy/contrib/py3/tnetstring_tests.py +++ /dev/null @@ -1,133 +0,0 @@ -import unittest -import random -import math -import io -from . import tnetstring -import struct - -MAXINT = 2 ** (struct.Struct('i').size * 8 - 1) - 1 - -FORMAT_EXAMPLES = { - b'0:}': {}, - b'0:]': [], - b'51:5:hello,39:11:12345678901#4:this,4:true!0:~4:\x00\x00\x00\x00,]}': - {'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, - b'5:12345#': 12345, - b'12:this is cool,': b'this is cool', - b'0:,': b'', - b'0:~': None, - b'4:true!': True, - b'5:false!': False, - b'10:\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00,': b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00', - b'24:5:12345#5:67890#5:xxxxx,]': [12345, 67890, b'xxxxx'], - b'18:3:0.1^3:0.2^3:0.3^]': [0.1, 0.2, 0.3], - b'243:238:233:228:223:218:213:208:203:198:193:188:183:178:173:168:163:158:153:148:143:138:133:128:123:118:113:108:103:99:95:91:87:83:79:75:71:67:63:59:55:51:47:43:39:35:31:27:23:19:15:11:hello-there,]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]': [[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[[b'hello-there']]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]]] -} - -def get_random_object(random=random, depth=0): - """Generate a random serializable object.""" - # The probability of generating a scalar value increases as the depth increase. - # This ensures that we bottom out eventually. 
- if random.randint(depth,10) <= 4: - what = random.randint(0,1) - if what == 0: - n = random.randint(0,10) - l = [] - for _ in range(n): - l.append(get_random_object(random,depth+1)) - return l - if what == 1: - n = random.randint(0,10) - d = {} - for _ in range(n): - n = random.randint(0,100) - k = str([random.randint(32,126) for _ in range(n)]) - d[k] = get_random_object(random,depth+1) - return d - else: - what = random.randint(0,4) - if what == 0: - return None - if what == 1: - return True - if what == 2: - return False - if what == 3: - if random.randint(0,1) == 0: - return random.randint(0,MAXINT) - else: - return -1 * random.randint(0,MAXINT) - n = random.randint(0,100) - return bytes([random.randint(32,126) for _ in range(n)]) - -class Test_Format(unittest.TestCase): - def test_roundtrip_format_examples(self): - for data, expect in FORMAT_EXAMPLES.items(): - self.assertEqual(expect,tnetstring.loads(data)) - self.assertEqual(expect,tnetstring.loads(tnetstring.dumps(expect))) - self.assertEqual((expect,b''),tnetstring.pop(data)) - - def test_roundtrip_format_random(self): - for _ in range(500): - v = get_random_object() - self.assertEqual(v,tnetstring.loads(tnetstring.dumps(v))) - self.assertEqual((v,b""),tnetstring.pop(tnetstring.dumps(v))) - - def test_unicode_handling(self): - with self.assertRaises(ValueError): - tnetstring.dumps("hello") - self.assertEqual(tnetstring.dumps("hello".encode()),b"5:hello,") - self.assertEqual(type(tnetstring.loads(b"5:hello,")),bytes) - - def test_roundtrip_format_unicode(self): - for _ in range(500): - v = get_random_object() - self.assertEqual(v,tnetstring.loads(tnetstring.dumps(v))) - self.assertEqual((v,b''),tnetstring.pop(tnetstring.dumps(v))) - - def test_roundtrip_big_integer(self): - i1 = math.factorial(30000) - s = tnetstring.dumps(i1) - i2 = tnetstring.loads(s) - self.assertEqual(i1, i2) - -class Test_FileLoading(unittest.TestCase): - def test_roundtrip_file_examples(self): - for data, expect in FORMAT_EXAMPLES.items(): - s = io.BytesIO() - s.write(data) - s.write(b'OK') - s.seek(0) - self.assertEqual(expect,tnetstring.load(s)) - self.assertEqual(b'OK',s.read()) - s = io.BytesIO() - tnetstring.dump(expect,s) - s.write(b'OK') - s.seek(0) - self.assertEqual(expect,tnetstring.load(s)) - self.assertEqual(b'OK',s.read()) - - def test_roundtrip_file_random(self): - for _ in range(500): - v = get_random_object() - s = io.BytesIO() - tnetstring.dump(v,s) - s.write(b'OK') - s.seek(0) - self.assertEqual(v,tnetstring.load(s)) - self.assertEqual(b'OK',s.read()) - - def test_error_on_absurd_lengths(self): - s = io.BytesIO() - s.write(b'1000000000:pwned!,') - s.seek(0) - with self.assertRaises(ValueError): - tnetstring.load(s) - self.assertEqual(s.read(1),b':') - -def suite(): - loader = unittest.TestLoader() - suite = unittest.TestSuite() - suite.addTest(loader.loadTestsFromTestCase(Test_Format)) - suite.addTest(loader.loadTestsFromTestCase(Test_FileLoading)) - return suite \ No newline at end of file diff --git a/mitmproxy/contrib/tnetstring.py b/mitmproxy/contrib/tnetstring.py index 1ebaba21..5fc26b45 100644 --- a/mitmproxy/contrib/tnetstring.py +++ b/mitmproxy/contrib/tnetstring.py @@ -1,8 +1,254 @@ +""" +tnetstring: data serialization using typed netstrings +====================================================== + +This is a custom Python 3 implementation of tnetstrings. +Compared to other implementations, the main difference +is that this implementation supports a custom unicode datatype. 
+ +An ordinary tnetstring is a blob of data prefixed with its length and postfixed +with its type. Here are some examples: + + >>> tnetstring.dumps("hello world") + 11:hello world, + >>> tnetstring.dumps(12345) + 5:12345# + >>> tnetstring.dumps([12345, True, 0]) + 19:5:12345#4:true!1:0#] + +This module gives you the following functions: + + :dump: dump an object as a tnetstring to a file + :dumps: dump an object as a tnetstring to a string + :load: load a tnetstring-encoded object from a file + :loads: load a tnetstring-encoded object from a string + +Note that since parsing a tnetstring requires reading all the data into memory +at once, there's no efficiency gain from using the file-based versions of these +functions. They're only here so you can use load() to read precisely one +item from a file or socket without consuming any extra data. + +The tnetstrings specification explicitly states that strings are binary blobs +and forbids the use of unicode at the protocol level. +**This implementation decodes dictionary keys as surrogate-escaped ASCII**, +all other strings are returned as plain bytes. + +:Copyright: (c) 2012-2013 by Ryan Kelly . +:Copyright: (c) 2014 by Carlo Pires . +:Copyright: (c) 2016 by Maximilian Hils . + +:License: MIT +""" + +import collections import six +from typing import io, Union, Tuple # noqa + +TSerializable = Union[None, bool, int, float, bytes, list, tuple, dict] + + +def dumps(value): + # type: (TSerializable) -> bytes + """ + This function dumps a python object as a tnetstring. + """ + # This uses a deque to collect output fragments in reverse order, + # then joins them together at the end. It's measurably faster + # than creating all the intermediate strings. + q = collections.deque() + _rdumpq(q, 0, value) + return b''.join(q) + + +def dump(value, file_handle): + # type: (TSerializable, io.BinaryIO) -> None + """ + This function dumps a python object as a tnetstring and + writes it to the given file. + """ + file_handle.write(dumps(value)) + + +def _rdumpq(q, size, value): + # type: (collections.deque, int, TSerializable) -> int + """ + Dump value as a tnetstring, to a deque instance, last chunks first. + + This function generates the tnetstring representation of the given value, + pushing chunks of the output onto the given deque instance. It pushes + the last chunk first, then recursively generates more chunks. + + When passed in the current size of the string in the queue, it will return + the new size of the string in the queue. + + Operating last-chunk-first makes it easy to calculate the size written + for recursive structures without having to build their representation as + a string. This is measurably faster than generating the intermediate + strings, especially on deeply nested structures. + """ + write = q.appendleft + if value is None: + write(b'0:~') + return size + 3 + elif value is True: + write(b'4:true!') + return size + 7 + elif value is False: + write(b'5:false!') + return size + 8 + elif isinstance(value, int): + data = str(value).encode() + ldata = len(data) + span = str(ldata).encode() + write(b'%s:%s#' % (span, data)) + return size + 2 + len(span) + ldata + elif isinstance(value, float): + # Use repr() for float rather than str(). + # It round-trips more accurately. + # Probably unnecessary in later python versions that + # use David Gay's ftoa routines. 
+ data = repr(value).encode() + ldata = len(data) + span = str(ldata).encode() + write(b'%s:%s^' % (span, data)) + return size + 2 + len(span) + ldata + elif isinstance(value, bytes): + data = value + ldata = len(data) + span = str(ldata).encode() + write(b'%s:%s,' % (span, data)) + return size + 2 + len(span) + ldata + elif isinstance(value, six.text_type): + data = value.encode() + ldata = len(data) + span = str(ldata).encode() + write(b'%s:%s;' % (span, data)) + return size + 2 + len(span) + ldata + elif isinstance(value, (list, tuple)): + write(b']') + init_size = size = size + 1 + for item in reversed(value): + size = _rdumpq(q, size, item) + span = str(size - init_size).encode() + write(b':') + write(span) + return size + 1 + len(span) + elif isinstance(value, dict): + write(b'}') + init_size = size = size + 1 + for (k, v) in value.items(): + if isinstance(k, str): + k = k.encode("ascii", "strict") + size = _rdumpq(q, size, v) + size = _rdumpq(q, size, k) + span = str(size - init_size).encode() + write(b':') + write(span) + return size + 1 + len(span) + else: + raise ValueError("unserializable object: {} ({})".format(value, type(value))) + + +def loads(string): + # type: (bytes) -> TSerializable + """ + This function parses a tnetstring into a python object. + """ + return pop(string)[0] + + +def load(file_handle): + # type: (io.BinaryIO) -> TSerializable + """load(file) -> object + + This function reads a tnetstring from a file and parses it into a + python object. The file must support the read() method, and this + function promises not to read more data than necessary. + """ + # Read the length prefix one char at a time. + # Note that the netstring spec explicitly forbids padding zeros. + c = file_handle.read(1) + data_length = b"" + while c.isdigit(): + data_length += c + if len(data_length) > 9: + raise ValueError("not a tnetstring: absurdly large length prefix") + c = file_handle.read(1) + if c != b":": + raise ValueError("not a tnetstring: missing or invalid length prefix") + + data = file_handle.read(int(data_length)) + data_type = file_handle.read(1)[0] + + return parse(data_type, data) + + +def parse(data_type, data): + # type: (int, bytes) -> TSerializable + if data_type == ord(b','): + return data + if data_type == ord(b';'): + return data.decode() + if data_type == ord(b'#'): + try: + return int(data) + except ValueError: + raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) + if data_type == ord(b'^'): + try: + return float(data) + except ValueError: + raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) + if data_type == ord(b'!'): + if data == b'true': + return True + elif data == b'false': + return False + else: + raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) + if data_type == ord(b'~'): + if data: + raise ValueError("not a tnetstring: invalid null literal") + return None + if data_type == ord(b']'): + l = [] + while data: + item, data = pop(data) + l.append(item) + return l + if data_type == ord(b'}'): + d = {} + while data: + key, data = pop(data) + if isinstance(key, bytes): + key = key.decode("ascii", "strict") + val, data = pop(data) + d[key] = val + return d + raise ValueError("unknown type tag: {}".format(data_type)) + + +def pop(data): + # type: (bytes) -> Tuple[TSerializable, bytes] + """ + This function parses a tnetstring into a python object. + It returns a tuple giving the parsed object and a string + containing any unparsed data from the end of the string. 
+ """ + # Parse out data length, type and remaining string. + try: + length, data = data.split(b':', 1) + length = int(length) + except ValueError: + raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(data)) + try: + data, data_type, remain = data[:length], data[length], data[length + 1:] + except IndexError: + # This fires if len(data) < dlen, meaning we don't need + # to further validate that data is the right length. + raise ValueError("not a tnetstring: invalid length prefix: {}".format(length)) + # Parse the data based on the type tag. + return parse(data_type, data), remain -if six.PY2: - from .py2.tnetstring import load, loads, dump, dumps, pop -else: - from .py3.tnetstring import load, loads, dump, dumps, pop -__all__ = ["load", "loads", "dump", "dumps", "pop"] +__all__ = ["dump", "dumps", "load", "loads", "pop"] -- cgit v1.2.3 From d406bee988dc01126cfbdfc938b561e10b518610 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Tue, 5 Jul 2016 20:28:13 -0700 Subject: tnetstring3: adapt to unicode support --- mitmproxy/contrib/tnetstring.py | 14 +++--- mitmproxy/flow/io.py | 7 ++- mitmproxy/flow/io_compat.py | 76 +++++++++++++++++++++---------- mitmproxy/models/connections.py | 2 +- test/mitmproxy/test_contrib_tnetstring.py | 8 +--- 5 files changed, 66 insertions(+), 41 deletions(-) diff --git a/mitmproxy/contrib/tnetstring.py b/mitmproxy/contrib/tnetstring.py index 5fc26b45..0383f98e 100644 --- a/mitmproxy/contrib/tnetstring.py +++ b/mitmproxy/contrib/tnetstring.py @@ -96,7 +96,7 @@ def _rdumpq(q, size, value): elif value is False: write(b'5:false!') return size + 8 - elif isinstance(value, int): + elif isinstance(value, six.integer_types): data = str(value).encode() ldata = len(data) span = str(ldata).encode() @@ -119,7 +119,7 @@ def _rdumpq(q, size, value): write(b'%s:%s,' % (span, data)) return size + 2 + len(span) + ldata elif isinstance(value, six.text_type): - data = value.encode() + data = value.encode("utf8") ldata = len(data) span = str(ldata).encode() write(b'%s:%s;' % (span, data)) @@ -137,8 +137,6 @@ def _rdumpq(q, size, value): write(b'}') init_size = size = size + 1 for (k, v) in value.items(): - if isinstance(k, str): - k = k.encode("ascii", "strict") size = _rdumpq(q, size, v) size = _rdumpq(q, size, k) span = str(size - init_size).encode() @@ -184,13 +182,17 @@ def load(file_handle): def parse(data_type, data): + if six.PY2: + data_type = ord(data_type) # type: (int, bytes) -> TSerializable if data_type == ord(b','): return data if data_type == ord(b';'): - return data.decode() + return data.decode("utf8") if data_type == ord(b'#'): try: + if six.PY2: + return long(data) return int(data) except ValueError: raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) @@ -220,8 +222,6 @@ def parse(data_type, data): d = {} while data: key, data = pop(data) - if isinstance(key, bytes): - key = key.decode("ascii", "strict") val, data = pop(data) d[key] = val return d diff --git a/mitmproxy/flow/io.py b/mitmproxy/flow/io.py index e5716940..276d7a5b 100644 --- a/mitmproxy/flow/io.py +++ b/mitmproxy/flow/io.py @@ -44,10 +44,9 @@ class FlowReader: raise exceptions.FlowReadException(str(e)) if can_tell: off = self.fo.tell() - data_type = data["type"].decode() - if data_type not in models.FLOW_TYPES: - raise exceptions.FlowReadException("Unknown flow type: {}".format(data_type)) - yield models.FLOW_TYPES[data_type].from_state(data) + if data["type"] not in models.FLOW_TYPES: + raise exceptions.FlowReadException("Unknown flow type: 
{}".format(data["type"])) + yield models.FLOW_TYPES[data["type"]].from_state(data) except ValueError: # Error is due to EOF if can_tell and self.fo.tell() == off and self.fo.read() == b'': diff --git a/mitmproxy/flow/io_compat.py b/mitmproxy/flow/io_compat.py index 55971f5e..1e67dde4 100644 --- a/mitmproxy/flow/io_compat.py +++ b/mitmproxy/flow/io_compat.py @@ -3,44 +3,74 @@ This module handles the import of mitmproxy flows generated by old versions. """ from __future__ import absolute_import, print_function, division +import six + from netlib import version def convert_013_014(data): - data["request"]["first_line_format"] = data["request"].pop("form_in") - data["request"]["http_version"] = "HTTP/" + ".".join(str(x) for x in data["request"].pop("httpversion")) - data["response"]["http_version"] = "HTTP/" + ".".join(str(x) for x in data["response"].pop("httpversion")) - data["response"]["status_code"] = data["response"].pop("code") - data["response"]["body"] = data["response"].pop("content") - data["server_conn"].pop("state") - data["server_conn"]["via"] = None - data["version"] = (0, 14) + data[b"request"][b"first_line_format"] = data[b"request"].pop(b"form_in") + data[b"request"][b"http_version"] = b"HTTP/" + ".".join(str(x) for x in data[b"request"].pop(b"httpversion")).encode() + data[b"response"][b"http_version"] = b"HTTP/" + ".".join(str(x) for x in data[b"response"].pop(b"httpversion")).encode() + data[b"response"][b"status_code"] = data[b"response"].pop(b"code") + data[b"response"][b"body"] = data[b"response"].pop(b"content") + data[b"server_conn"].pop(b"state") + data[b"server_conn"][b"via"] = None + data[b"version"] = (0, 14) return data def convert_014_015(data): - data["version"] = (0, 15) + data[b"version"] = (0, 15) return data def convert_015_016(data): - for m in ("request", "response"): - if "body" in data[m]: - data[m]["content"] = data[m].pop("body") - if "msg" in data["response"]: - data["response"]["reason"] = data["response"].pop("msg") - data["request"].pop("form_out", None) - data["version"] = (0, 16) + for m in (b"request", b"response"): + if b"body" in data[m]: + data[m][b"content"] = data[m].pop(b"body") + if b"msg" in data[b"response"]: + data[b"response"][b"reason"] = data[b"response"].pop(b"msg") + data[b"request"].pop(b"form_out", None) + data[b"version"] = (0, 16) return data def convert_016_017(data): - data["server_conn"]["peer_address"] = None - data["version"] = (0, 17) + data[b"server_conn"][b"peer_address"] = None + data[b"version"] = (0, 17) return data def convert_017_018(data): + if not six.PY2: + # Python 2 versions of mitmproxy did not support serializing unicode. 
+ def dict_keys_to_str(o): + if isinstance(o, dict): + return {k.decode(): dict_keys_to_str(v) for k, v in o.items()} + else: + return o + data = dict_keys_to_str(data) + + def dict_vals_to_str(o, decode): + for k, v in decode.items(): + if not o or k not in o: + continue + if v is True: + o[k] = o[k].decode() + else: + dict_vals_to_str(o[k], v) + dict_vals_to_str(data, { + "type": True, + "id": True, + "request": { + "first_line_format": True + }, + "error": { + "msg": True + } + }) + data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address") data["version"] = (0, 18) return data @@ -57,13 +87,13 @@ converters = { def migrate_flow(flow_data): while True: - flow_version = tuple(flow_data["version"][:2]) - if flow_version == version.IVERSION[:2]: + flow_version = tuple(flow_data.get(b"version", flow_data.get("version"))) + if flow_version[:2] == version.IVERSION[:2]: break - elif flow_version in converters: - flow_data = converters[flow_version](flow_data) + elif flow_version[:2] in converters: + flow_data = converters[flow_version[:2]](flow_data) else: - v = ".".join(str(i) for i in flow_data["version"]) + v = ".".join(str(i) for i in flow_version) raise ValueError( "{} cannot read files serialized with version {}.".format(version.MITMPROXY, v) ) diff --git a/mitmproxy/models/connections.py b/mitmproxy/models/connections.py index d71379bc..3e1a0928 100644 --- a/mitmproxy/models/connections.py +++ b/mitmproxy/models/connections.py @@ -162,7 +162,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): source_address=tcp.Address, ssl_established=bool, cert=certutils.SSLCert, - sni=str, + sni=bytes, timestamp_start=float, timestamp_tcp_setup=float, timestamp_ssl_setup=float, diff --git a/test/mitmproxy/test_contrib_tnetstring.py b/test/mitmproxy/test_contrib_tnetstring.py index 8ae35a25..908cec27 100644 --- a/test/mitmproxy/test_contrib_tnetstring.py +++ b/test/mitmproxy/test_contrib_tnetstring.py @@ -15,7 +15,9 @@ FORMAT_EXAMPLES = { {'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, b'5:12345#': 12345, b'12:this is cool,': b'this is cool', + b'19:this is unicode \xe2\x98\x85;': u'this is unicode \u2605', b'0:,': b'', + b'0:;': u'', b'0:~': None, b'4:true!': True, b'5:false!': False, @@ -78,12 +80,6 @@ class Test_Format(unittest.TestCase): self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v))) - def test_unicode_handling(self): - with self.assertRaises(ValueError): - tnetstring.dumps(u"hello") - self.assertEqual(tnetstring.dumps(u"hello".encode()), b"5:hello,") - self.assertEqual(type(tnetstring.loads(b"5:hello,")), bytes) - def test_roundtrip_format_unicode(self): for _ in range(500): v = get_random_object() -- cgit v1.2.3 From 2c37ebfc7215649cc633047c0b036de66d847af1 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 6 Jul 2016 13:24:50 -0700 Subject: fix dump file cross compat between python versions --- mitmproxy/contrib/tnetstring.py | 10 ++++- mitmproxy/flow/io_compat.py | 83 ++++++++++++++++++++++++++--------------- 2 files changed, 61 insertions(+), 32 deletions(-) diff --git a/mitmproxy/contrib/tnetstring.py b/mitmproxy/contrib/tnetstring.py index 0383f98e..d99a83f9 100644 --- a/mitmproxy/contrib/tnetstring.py +++ b/mitmproxy/contrib/tnetstring.py @@ -116,13 +116,19 @@ def _rdumpq(q, size, value): data = value ldata = len(data) span = str(ldata).encode() - write(b'%s:%s,' % (span, data)) + write(b',') + write(data) + write(b':') + write(span) return size + 2 + 
len(span) + ldata elif isinstance(value, six.text_type): data = value.encode("utf8") ldata = len(data) span = str(ldata).encode() - write(b'%s:%s;' % (span, data)) + write(b';') + write(data) + write(b':') + write(span) return size + 2 + len(span) + ldata elif isinstance(value, (list, tuple)): write(b']') diff --git a/mitmproxy/flow/io_compat.py b/mitmproxy/flow/io_compat.py index 1e67dde4..ec825f71 100644 --- a/mitmproxy/flow/io_compat.py +++ b/mitmproxy/flow/io_compat.py @@ -5,13 +5,15 @@ from __future__ import absolute_import, print_function, division import six -from netlib import version +from netlib import version, strutils def convert_013_014(data): data[b"request"][b"first_line_format"] = data[b"request"].pop(b"form_in") - data[b"request"][b"http_version"] = b"HTTP/" + ".".join(str(x) for x in data[b"request"].pop(b"httpversion")).encode() - data[b"response"][b"http_version"] = b"HTTP/" + ".".join(str(x) for x in data[b"response"].pop(b"httpversion")).encode() + data[b"request"][b"http_version"] = b"HTTP/" + ".".join( + str(x) for x in data[b"request"].pop(b"httpversion")).encode() + data[b"response"][b"http_version"] = b"HTTP/" + ".".join( + str(x) for x in data[b"response"].pop(b"httpversion")).encode() data[b"response"][b"status_code"] = data[b"response"].pop(b"code") data[b"response"][b"body"] = data[b"response"].pop(b"content") data[b"server_conn"].pop(b"state") @@ -43,39 +45,57 @@ def convert_016_017(data): def convert_017_018(data): - if not six.PY2: - # Python 2 versions of mitmproxy did not support serializing unicode. - def dict_keys_to_str(o): - if isinstance(o, dict): - return {k.decode(): dict_keys_to_str(v) for k, v in o.items()} - else: - return o - data = dict_keys_to_str(data) - - def dict_vals_to_str(o, decode): - for k, v in decode.items(): - if not o or k not in o: - continue - if v is True: - o[k] = o[k].decode() - else: - dict_vals_to_str(o[k], v) - dict_vals_to_str(data, { - "type": True, - "id": True, - "request": { - "first_line_format": True - }, - "error": { - "msg": True - } - }) + # convert_unicode needs to be called for every dual release and the first py3-only release + data = convert_unicode(data) data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address") data["version"] = (0, 18) return data +def _convert_dict_keys(o): + # type: (Any) -> Any + if isinstance(o, dict): + return {strutils.native(k): _convert_dict_keys(v) for k, v in o.items()} + else: + return o + + +def _convert_dict_vals(o, values_to_convert): + # type: (dict, dict) -> dict + for k, v in values_to_convert.items(): + if not o or k not in o: + continue + if v is True: + o[k] = strutils.native(o[k]) + else: + _convert_dict_vals(o[k], v) + return o + + +def convert_unicode(data): + # type: (dict) -> dict + """ + The Python 2 version of mitmproxy serializes everything as bytes. + This method converts between Python 3 and Python 2 dumpfiles. + """ + if not six.PY2: + data = _convert_dict_keys(data) + data = _convert_dict_vals( + data, { + "type": True, + "id": True, + "request": { + "first_line_format": True + }, + "error": { + "msg": True + } + } + ) + return data + + converters = { (0, 13): convert_013_014, (0, 14): convert_014_015, @@ -97,4 +117,7 @@ def migrate_flow(flow_data): raise ValueError( "{} cannot read files serialized with version {}.".format(version.MITMPROXY, v) ) + # TODO: This should finally be moved in the converter for the first py3-only release. + # It's here so that a py2 0.18 dump can be read by py3 0.18 and vice versa. 
+ flow_data = convert_unicode(flow_data) return flow_data -- cgit v1.2.3 From 8287ce7e6dcf31e65519629bb064044a44de46d1 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 6 Jul 2016 16:36:04 -0700 Subject: fix tests --- test/mitmproxy/test_contrib_tnetstring.py | 2 +- test/mitmproxy/tutils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/mitmproxy/test_contrib_tnetstring.py b/test/mitmproxy/test_contrib_tnetstring.py index 908cec27..05c4a7c9 100644 --- a/test/mitmproxy/test_contrib_tnetstring.py +++ b/test/mitmproxy/test_contrib_tnetstring.py @@ -12,7 +12,7 @@ FORMAT_EXAMPLES = { b'0:}': {}, b'0:]': [], b'51:5:hello,39:11:12345678901#4:this,4:true!0:~4:\x00\x00\x00\x00,]}': - {'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, + {b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, b'5:12345#': 12345, b'12:this is cool,': b'this is cool', b'19:this is unicode \xe2\x98\x85;': u'this is unicode \u2605', diff --git a/test/mitmproxy/tutils.py b/test/mitmproxy/tutils.py index d0a09035..5aade60c 100644 --- a/test/mitmproxy/tutils.py +++ b/test/mitmproxy/tutils.py @@ -130,7 +130,7 @@ def tserver_conn(): timestamp_ssl_setup=3, timestamp_end=4, ssl_established=False, - sni="address", + sni=b"address", via=None )) c.reply = controller.DummyReply() -- cgit v1.2.3 From 64a867973d5bac136c2e1c3c11c457d6b04d6649 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Wed, 6 Jul 2016 21:03:17 -0700 Subject: sni is now str, not bytes --- mitmproxy/models/connections.py | 7 ++++--- mitmproxy/models/flow.py | 16 ++++++---------- mitmproxy/protocol/tls.py | 13 ++++++++----- netlib/tcp.py | 4 ++-- netlib/utils.py | 4 +--- pathod/pathod.py | 5 ++++- test/mitmproxy/test_server.py | 6 +++--- test/mitmproxy/tutils.py | 2 +- test/netlib/test_tcp.py | 26 +++++++++++++------------- test/pathod/test_pathoc.py | 4 ++-- 10 files changed, 44 insertions(+), 43 deletions(-) diff --git a/mitmproxy/models/connections.py b/mitmproxy/models/connections.py index 3e1a0928..570e89a9 100644 --- a/mitmproxy/models/connections.py +++ b/mitmproxy/models/connections.py @@ -8,7 +8,6 @@ import six from mitmproxy import stateobject from netlib import certutils -from netlib import strutils from netlib import tcp @@ -162,7 +161,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): source_address=tcp.Address, ssl_established=bool, cert=certutils.SSLCert, - sni=bytes, + sni=str, timestamp_start=float, timestamp_tcp_setup=float, timestamp_ssl_setup=float, @@ -206,6 +205,8 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): self.wfile.flush() def establish_ssl(self, clientcerts, sni, **kwargs): + if sni and not isinstance(sni, six.string_types): + raise ValueError("sni must be str, not " + type(sni).__name__) clientcert = None if clientcerts: if os.path.isfile(clientcerts): @@ -217,7 +218,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): if os.path.exists(path): clientcert = path - self.convert_to_ssl(cert=clientcert, sni=strutils.always_bytes(sni), **kwargs) + self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs) self.sni = sni self.timestamp_ssl_setup = time.time() diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py index 0e4f80cb..f4993b7a 100644 --- a/mitmproxy/models/flow.py +++ b/mitmproxy/models/flow.py @@ -9,6 +9,7 @@ from mitmproxy.models.connections import ClientConnection from mitmproxy.models.connections import ServerConnection from netlib import version +from typing import Optional # noqa class 
Error(stateobject.StateObject): @@ -70,18 +71,13 @@ class Flow(stateobject.StateObject): def __init__(self, type, client_conn, server_conn, live=None): self.type = type self.id = str(uuid.uuid4()) - self.client_conn = client_conn - """@type: ClientConnection""" - self.server_conn = server_conn - """@type: ServerConnection""" + self.client_conn = client_conn # type: ClientConnection + self.server_conn = server_conn # type: ServerConnection self.live = live - """@type: LiveConnection""" - self.error = None - """@type: Error""" - self.intercepted = False - """@type: bool""" - self._backup = None + self.error = None # type: Error + self.intercepted = False # type: bool + self._backup = None # type: Optional[Flow] self.reply = None _stateobject_attributes = dict( diff --git a/mitmproxy/protocol/tls.py b/mitmproxy/protocol/tls.py index 9f883b2b..8ef34493 100644 --- a/mitmproxy/protocol/tls.py +++ b/mitmproxy/protocol/tls.py @@ -10,6 +10,7 @@ import netlib.exceptions from mitmproxy import exceptions from mitmproxy.contrib.tls import _constructs from mitmproxy.protocol import base +from netlib import utils # taken from https://testssl.sh/openssl-rfc.mappping.html @@ -274,10 +275,11 @@ class TlsClientHello(object): is_valid_sni_extension = ( extension.type == 0x00 and len(extension.server_names) == 1 and - extension.server_names[0].type == 0 + extension.server_names[0].type == 0 and + utils.is_valid_host(extension.server_names[0].name) ) if is_valid_sni_extension: - return extension.server_names[0].name + return extension.server_names[0].name.decode("idna") @property def alpn_protocols(self): @@ -403,13 +405,14 @@ class TlsLayer(base.Layer): self._establish_tls_with_server() def set_server_tls(self, server_tls, sni=None): + # type: (bool, Union[six.text_type, None, False]) -> None """ Set the TLS settings for the next server connection that will be established. This function will not alter an existing connection. Args: server_tls: Shall we establish TLS with the server? - sni: ``bytes`` for a custom SNI value, + sni: ``str`` for a custom SNI value, ``None`` for the client SNI value, ``False`` if no SNI value should be sent. """ @@ -602,9 +605,9 @@ class TlsLayer(base.Layer): host = upstream_cert.cn.decode("utf8").encode("idna") # Also add SNI values. if self._client_hello.sni: - sans.add(self._client_hello.sni) + sans.add(self._client_hello.sni.encode("idna")) if self._custom_server_sni: - sans.add(self._custom_server_sni) + sans.add(self._custom_server_sni.encode("idna")) # RFC 2818: If a subjectAltName extension of type dNSName is present, that MUST be used as the identity. # In other words, the Common Name is irrelevant then. 
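(The hunks in this commit all enforce one convention: inside mitmproxy an SNI value is plain text, and IDNA-encoded bytes appear only at the TLS boundary, e.g. in set_tlsext_host_name(), handle_sni() and the certificate SANs. A minimal sketch of that round trip follows; the hostname is an illustrative example, not taken from the patch:

    # SNI stays str inside mitmproxy; IDNA bytes exist only at the OpenSSL edge.
    sni = u"m\u00fcnchen.example"             # internal representation (text)
    wire = sni.encode("idna")                 # what set_tlsext_host_name() receives
    assert wire == b"xn--mnchen-3ya.example"  # punycode form sent on the wire
    assert wire.decode("idna") == sni         # handle_sni() decodes it back to str
)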
diff --git a/netlib/tcp.py b/netlib/tcp.py index 69dafc1f..cf099edd 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -676,7 +676,7 @@ class TCPClient(_Connection): self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni - self.connection.set_tlsext_host_name(sni) + self.connection.set_tlsext_host_name(sni.encode("idna")) self.connection.set_connect_state() try: self.connection.do_handshake() @@ -705,7 +705,7 @@ class TCPClient(_Connection): if self.cert.cn: crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]] if sni: - hostname = sni.decode("ascii", "strict") + hostname = sni else: hostname = "no-hostname" ssl_match_hostname.match_hostname(crt, hostname) diff --git a/netlib/utils.py b/netlib/utils.py index 79340cbd..23c16dc3 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -73,11 +73,9 @@ _label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(? bool """ Checks if a hostname is valid. - - Args: - host (bytes): The hostname """ try: host.decode("idna") diff --git a/pathod/pathod.py b/pathod/pathod.py index 3df86aae..7087cba6 100644 --- a/pathod/pathod.py +++ b/pathod/pathod.py @@ -89,7 +89,10 @@ class PathodHandler(tcp.BaseHandler): self.http2_framedump = http2_framedump def handle_sni(self, connection): - self.sni = connection.get_servername() + sni = connection.get_servername() + if sni: + sni = sni.decode("idna") + self.sni = sni def http_serve_crafted(self, crafted, logctx): error, crafted = self.server.check_policy( diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index 1bbef975..0ab7624e 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -100,10 +100,10 @@ class CommonMixin: if not self.ssl: return - f = self.pathod("304", sni=b"testserver.com") + f = self.pathod("304", sni="testserver.com") assert f.status_code == 304 log = self.server.last_log() - assert log["request"]["sni"] == b"testserver.com" + assert log["request"]["sni"] == "testserver.com" class TcpMixin: @@ -498,7 +498,7 @@ class TestHttps2Http(tservers.ReverseProxyTest): assert p.request("get:'/p/200'").status_code == 200 def test_sni(self): - p = self.pathoc(ssl=True, sni=b"example.com") + p = self.pathoc(ssl=True, sni="example.com") assert p.request("get:'/p/200'").status_code == 200 assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog) diff --git a/test/mitmproxy/tutils.py b/test/mitmproxy/tutils.py index 5aade60c..d0a09035 100644 --- a/test/mitmproxy/tutils.py +++ b/test/mitmproxy/tutils.py @@ -130,7 +130,7 @@ def tserver_conn(): timestamp_ssl_setup=3, timestamp_end=4, ssl_established=False, - sni=b"address", + sni="address", via=None )) c.reply = controller.DummyReply() diff --git a/test/netlib/test_tcp.py b/test/netlib/test_tcp.py index 590bcc01..273427d5 100644 --- a/test/netlib/test_tcp.py +++ b/test/netlib/test_tcp.py @@ -169,7 +169,7 @@ class TestServerSSL(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL) + c.convert_to_ssl(sni="foo.com", options=SSL.OP_ALL) testval = b"echo!\n" c.wfile.write(testval) c.wfile.flush() @@ -179,7 +179,7 @@ class TestServerSSL(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): assert not c.get_current_cipher() - c.convert_to_ssl(sni=b"foo.com") + c.convert_to_ssl(sni="foo.com") ret = c.get_current_cipher() assert ret assert "AES" in ret[0] @@ -195,7 +195,7 @@ class TestSSLv3Only(tservers.ServerTestBase): def 
test_failure(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com") + tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com") class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): @@ -238,7 +238,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): with c.connect(): with tutils.raises(InvalidCertificateException): c.convert_to_ssl( - sni=b"example.mitmproxy.org", + sni="example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") ) @@ -272,7 +272,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): with c.connect(): with tutils.raises(InvalidCertificateException): c.convert_to_ssl( - sni=b"mitmproxy.org", + sni="mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") ) @@ -291,7 +291,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): c.convert_to_ssl( - sni=b"example.mitmproxy.org", + sni="example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") ) @@ -307,7 +307,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): c.convert_to_ssl( - sni=b"example.mitmproxy.org", + sni="example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_path=tutils.test_data.path("data/verificationcerts/") ) @@ -371,8 +371,8 @@ class TestSNI(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com") - assert c.sni == b"foo.com" + c.convert_to_ssl(sni="foo.com") + assert c.sni == "foo.com" assert c.rfile.readline() == b"foo.com" @@ -385,7 +385,7 @@ class TestServerCipherList(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com") + c.convert_to_ssl(sni="foo.com") assert c.rfile.readline() == b"['RC4-SHA']" @@ -405,7 +405,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com") + c.convert_to_ssl(sni="foo.com") assert b"RC4-SHA" in c.rfile.readline() @@ -418,7 +418,7 @@ class TestServerCipherListError(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com") + tutils.raises("handshake error", c.convert_to_ssl, sni="foo.com") class TestClientCipherListError(tservers.ServerTestBase): @@ -433,7 +433,7 @@ class TestClientCipherListError(tservers.ServerTestBase): tutils.raises( "cipher specification", c.convert_to_ssl, - sni=b"foo.com", + sni="foo.com", cipher_list="bogus" ) diff --git a/test/pathod/test_pathoc.py b/test/pathod/test_pathoc.py index 28f9f0f8..361a863b 100644 --- a/test/pathod/test_pathoc.py +++ b/test/pathod/test_pathoc.py @@ -54,10 +54,10 @@ class TestDaemonSSL(PathocTestDaemon): def test_sni(self): self.tval( ["get:/p/200"], - sni=b"foobar.com" + sni="foobar.com" ) log = self.d.log() - assert log[0]["request"]["sni"] == b"foobar.com" + assert log[0]["request"]["sni"] == "foobar.com" def test_showssl(self): assert "certificate chain" in 
self.tval(["get:/p/200"], showssl=True) -- cgit v1.2.3
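Taken together, the series leaves a single mitmproxy/contrib/tnetstring.py that works on both Python 2 and 3: byte strings keep the standard ',' tag, text is serialized as UTF-8 with the custom ';' tag, and dictionary keys round-trip unchanged. A rough interactive sketch of the resulting behaviour, mirroring the FORMAT_EXAMPLES used in the tests and assuming the module is importable as mitmproxy.contrib.tnetstring (Python 3 shown; exact reprs differ slightly on Python 2):

    >>> from mitmproxy.contrib import tnetstring
    >>> tnetstring.dumps(b"hello world")
    b'11:hello world,'
    >>> tnetstring.dumps(u"this is unicode \u2605")
    b'19:this is unicode \xe2\x98\x85;'
    >>> tnetstring.loads(b"19:this is unicode \xe2\x98\x85;")
    'this is unicode ★'
    >>> tnetstring.pop(b"5:12345#4:true!")
    (12345, b'4:true!')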