aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mitmproxy/contrib/tnetstring.py291
-rw-r--r--mitmproxy/flow/io.py2
-rw-r--r--mitmproxy/flow/io_compat.py102
-rw-r--r--mitmproxy/models/connections.py5
-rw-r--r--mitmproxy/models/flow.py16
-rw-r--r--mitmproxy/protocol/tls.py13
-rw-r--r--netlib/tcp.py4
-rw-r--r--netlib/utils.py4
-rw-r--r--pathod/pathod.py5
-rw-r--r--test/mitmproxy/test_contrib_tnetstring.py10
-rw-r--r--test/mitmproxy/test_server.py6
-rw-r--r--test/netlib/test_tcp.py26
-rw-r--r--test/pathod/test_pathoc.py4
-rw-r--r--tox.ini2
14 files changed, 212 insertions, 278 deletions
diff --git a/mitmproxy/contrib/tnetstring.py b/mitmproxy/contrib/tnetstring.py
index 9bf20b09..d99a83f9 100644
--- a/mitmproxy/contrib/tnetstring.py
+++ b/mitmproxy/contrib/tnetstring.py
@@ -1,100 +1,67 @@
-# imported from the tnetstring project: https://github.com/rfk/tnetstring
-#
-# Copyright (c) 2011 Ryan Kelly
-#
-# Permission is hereby granted, free of charge, to any person obtaining a copy
-# of this software and associated documentation files (the "Software"), to deal
-# in the Software without restriction, including without limitation the rights
-# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-# copies of the Software, and to permit persons to whom the Software is
-# furnished to do so, subject to the following conditions:
-#
-# The above copyright notice and this permission notice shall be included in
-# all copies or substantial portions of the Software.
-#
-# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-# THE SOFTWARE.
"""
tnetstring: data serialization using typed netstrings
======================================================
+This is a custom Python 3 implementation of tnetstrings.
+Compared to other implementations, the main difference
+is that this implementation supports a custom unicode datatype.
-This is a data serialization library. It's a lot like JSON but it uses a
-new syntax called "typed netstrings" that Zed has proposed for use in the
-Mongrel2 webserver. It's designed to be simpler and easier to implement
-than JSON, with a happy consequence of also being faster in many cases.
-
-An ordinary netstring is a blob of data prefixed with its length and postfixed
-with a sanity-checking comma. The string "hello world" encodes like this::
+An ordinary tnetstring is a blob of data prefixed with its length and postfixed
+with its type. Here are some examples:
+ >>> tnetstring.dumps("hello world")
11:hello world,
-
-Typed netstrings add other datatypes by replacing the comma with a type tag.
-Here's the integer 12345 encoded as a tnetstring::
-
+ >>> tnetstring.dumps(12345)
5:12345#
-
-And here's the list [12345,True,0] which mixes integers and bools::
-
+ >>> tnetstring.dumps([12345, True, 0])
19:5:12345#4:true!1:0#]
-Simple enough? This module gives you the following functions:
+This module gives you the following functions:
:dump: dump an object as a tnetstring to a file
:dumps: dump an object as a tnetstring to a string
:load: load a tnetstring-encoded object from a file
:loads: load a tnetstring-encoded object from a string
- :pop: pop a tnetstring-encoded object from the front of a string
Note that since parsing a tnetstring requires reading all the data into memory
at once, there's no efficiency gain from using the file-based versions of these
functions. They're only here so you can use load() to read precisely one
item from a file or socket without consuming any extra data.
-By default tnetstrings work only with byte strings, not unicode. If you want
-unicode strings then pass an optional encoding to the various functions,
-like so::
+The tnetstrings specification explicitly states that strings are binary blobs
+and forbids the use of unicode at the protocol level.
+**This implementation decodes dictionary keys as surrogate-escaped ASCII**,
+all other strings are returned as plain bytes.
- >>> print(repr(tnetstring.loads("2:\\xce\\xb1,")))
- '\\xce\\xb1'
- >>>
- >>> print(repr(tnetstring.loads("2:\\xce\\xb1,","utf8")))
- u'\u03b1'
+:Copyright: (c) 2012-2013 by Ryan Kelly <ryan@rfk.id.au>.
+:Copyright: (c) 2014 by Carlo Pires <carlopires@gmail.com>.
+:Copyright: (c) 2016 by Maximilian Hils <tnetstring3@maximilianhils.com>.
+:License: MIT
"""
-from collections import deque
+import collections
import six
+from typing import io, Union, Tuple # noqa
-__ver_major__ = 0
-__ver_minor__ = 2
-__ver_patch__ = 0
-__ver_sub__ = ""
-__version__ = "%d.%d.%d%s" % (
- __ver_major__, __ver_minor__, __ver_patch__, __ver_sub__)
+TSerializable = Union[None, bool, int, float, bytes, list, tuple, dict]
def dumps(value):
+ # type: (TSerializable) -> bytes
"""
This function dumps a python object as a tnetstring.
"""
# This uses a deque to collect output fragments in reverse order,
# then joins them together at the end. It's measurably faster
# than creating all the intermediate strings.
- # If you're reading this to get a handle on the tnetstring format,
- # consider the _gdumps() function instead; it's a standard top-down
- # generator that's simpler to understand but much less efficient.
- q = deque()
+ q = collections.deque()
_rdumpq(q, 0, value)
return b''.join(q)
def dump(value, file_handle):
+ # type: (TSerializable, io.BinaryIO) -> None
"""
This function dumps a python object as a tnetstring and
writes it to the given file.
@@ -103,6 +70,7 @@ def dump(value, file_handle):
def _rdumpq(q, size, value):
+ # type: (collections.deque, int, TSerializable) -> int
"""
Dump value as a tnetstring, to a deque instance, last chunks first.
@@ -132,10 +100,7 @@ def _rdumpq(q, size, value):
data = str(value).encode()
ldata = len(data)
span = str(ldata).encode()
- write(b'#')
- write(data)
- write(b':')
- write(span)
+ write(b'%s:%s#' % (span, data))
return size + 2 + len(span) + ldata
elif isinstance(value, float):
# Use repr() for float rather than str().
@@ -145,19 +110,26 @@ def _rdumpq(q, size, value):
data = repr(value).encode()
ldata = len(data)
span = str(ldata).encode()
- write(b'^')
+ write(b'%s:%s^' % (span, data))
+ return size + 2 + len(span) + ldata
+ elif isinstance(value, bytes):
+ data = value
+ ldata = len(data)
+ span = str(ldata).encode()
+ write(b',')
write(data)
write(b':')
write(span)
return size + 2 + len(span) + ldata
- elif isinstance(value, bytes):
- lvalue = len(value)
- span = str(lvalue).encode()
- write(b',')
- write(value)
+ elif isinstance(value, six.text_type):
+ data = value.encode("utf8")
+ ldata = len(data)
+ span = str(ldata).encode()
+ write(b';')
+ write(data)
write(b':')
write(span)
- return size + 2 + len(span) + lvalue
+ return size + 2 + len(span) + ldata
elif isinstance(value, (list, tuple)):
write(b']')
init_size = size = size + 1
@@ -181,73 +153,16 @@ def _rdumpq(q, size, value):
raise ValueError("unserializable object: {} ({})".format(value, type(value)))
-def _gdumps(value):
- """
- Generate fragments of value dumped as a tnetstring.
-
- This is the naive dumping algorithm, implemented as a generator so that
- it's easy to pass to "".join() without building a new list.
-
- This is mainly here for comparison purposes; the _rdumpq version is
- measurably faster as it doesn't have to build intermediate strins.
- """
- if value is None:
- yield b'0:~'
- elif value is True:
- yield b'4:true!'
- elif value is False:
- yield b'5:false!'
- elif isinstance(value, six.integer_types):
- data = str(value).encode()
- yield str(len(data)).encode()
- yield b':'
- yield data
- yield b'#'
- elif isinstance(value, float):
- data = repr(value).encode()
- yield str(len(data)).encode()
- yield b':'
- yield data
- yield b'^'
- elif isinstance(value, bytes):
- yield str(len(value)).encode()
- yield b':'
- yield value
- yield b','
- elif isinstance(value, (list, tuple)):
- sub = []
- for item in value:
- sub.extend(_gdumps(item))
- sub = b''.join(sub)
- yield str(len(sub)).encode()
- yield b':'
- yield sub
- yield b']'
- elif isinstance(value, (dict,)):
- sub = []
- for (k, v) in value.items():
- sub.extend(_gdumps(k))
- sub.extend(_gdumps(v))
- sub = b''.join(sub)
- yield str(len(sub)).encode()
- yield b':'
- yield sub
- yield b'}'
- else:
- raise ValueError("unserializable object")
-
-
def loads(string):
+ # type: (bytes) -> TSerializable
"""
This function parses a tnetstring into a python object.
"""
- # No point duplicating effort here. In the C-extension version,
- # loads() is measurably faster then pop() since it can avoid
- # the overhead of building a second string.
return pop(string)[0]
def load(file_handle):
+ # type: (io.BinaryIO) -> TSerializable
"""load(file) -> object
This function reads a tnetstring from a file and parses it into a
@@ -257,119 +172,89 @@ def load(file_handle):
# Read the length prefix one char at a time.
# Note that the netstring spec explicitly forbids padding zeros.
c = file_handle.read(1)
- if not c.isdigit():
- raise ValueError("not a tnetstring: missing or invalid length prefix")
- datalen = ord(c) - ord('0')
- c = file_handle.read(1)
- if datalen != 0:
- while c.isdigit():
- datalen = (10 * datalen) + (ord(c) - ord('0'))
- if datalen > 999999999:
- errmsg = "not a tnetstring: absurdly large length prefix"
- raise ValueError(errmsg)
- c = file_handle.read(1)
- if c != b':':
+ data_length = b""
+ while c.isdigit():
+ data_length += c
+ if len(data_length) > 9:
+ raise ValueError("not a tnetstring: absurdly large length prefix")
+ c = file_handle.read(1)
+ if c != b":":
raise ValueError("not a tnetstring: missing or invalid length prefix")
- # Now we can read and parse the payload.
- # This repeats the dispatch logic of pop() so we can avoid
- # re-constructing the outermost tnetstring.
- data = file_handle.read(datalen)
- if len(data) != datalen:
- raise ValueError("not a tnetstring: length prefix too big")
- tns_type = file_handle.read(1)
- if tns_type == b',':
+
+ data = file_handle.read(int(data_length))
+ data_type = file_handle.read(1)[0]
+
+ return parse(data_type, data)
+
+
+def parse(data_type, data):
+ if six.PY2:
+ data_type = ord(data_type)
+ # type: (int, bytes) -> TSerializable
+ if data_type == ord(b','):
return data
- if tns_type == b'#':
+ if data_type == ord(b';'):
+ return data.decode("utf8")
+ if data_type == ord(b'#'):
try:
+ if six.PY2:
+ return long(data)
return int(data)
except ValueError:
- raise ValueError("not a tnetstring: invalid integer literal")
- if tns_type == b'^':
+ raise ValueError("not a tnetstring: invalid integer literal: {}".format(data))
+ if data_type == ord(b'^'):
try:
return float(data)
except ValueError:
- raise ValueError("not a tnetstring: invalid float literal")
- if tns_type == b'!':
+ raise ValueError("not a tnetstring: invalid float literal: {}".format(data))
+ if data_type == ord(b'!'):
if data == b'true':
return True
elif data == b'false':
return False
else:
- raise ValueError("not a tnetstring: invalid boolean literal")
- if tns_type == b'~':
+ raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data))
+ if data_type == ord(b'~'):
if data:
raise ValueError("not a tnetstring: invalid null literal")
return None
- if tns_type == b']':
+ if data_type == ord(b']'):
l = []
while data:
item, data = pop(data)
l.append(item)
return l
- if tns_type == b'}':
+ if data_type == ord(b'}'):
d = {}
while data:
key, data = pop(data)
val, data = pop(data)
d[key] = val
return d
- raise ValueError("unknown type tag")
-
+ raise ValueError("unknown type tag: {}".format(data_type))
-def pop(string):
- """pop(string,encoding='utf_8') -> (object, remain)
+def pop(data):
+ # type: (bytes) -> Tuple[TSerializable, bytes]
+ """
This function parses a tnetstring into a python object.
It returns a tuple giving the parsed object and a string
containing any unparsed data from the end of the string.
"""
# Parse out data length, type and remaining string.
try:
- dlen, rest = string.split(b':', 1)
- dlen = int(dlen)
+ length, data = data.split(b':', 1)
+ length = int(length)
except ValueError:
- raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(string))
+ raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(data))
try:
- data, tns_type, remain = rest[:dlen], rest[dlen:dlen + 1], rest[dlen + 1:]
+ data, data_type, remain = data[:length], data[length], data[length + 1:]
except IndexError:
- # This fires if len(rest) < dlen, meaning we don't need
+ # This fires if len(data) < dlen, meaning we don't need
# to further validate that data is the right length.
- raise ValueError("not a tnetstring: invalid length prefix: {}".format(dlen))
- # Parse the data based on the type tag.
- if tns_type == b',':
- return data, remain
- if tns_type == b'#':
- try:
- return int(data), remain
- except ValueError:
- raise ValueError("not a tnetstring: invalid integer literal: {}".format(data))
- if tns_type == b'^':
- try:
- return float(data), remain
- except ValueError:
- raise ValueError("not a tnetstring: invalid float literal: {}".format(data))
- if tns_type == b'!':
- if data == b'true':
- return True, remain
- elif data == b'false':
- return False, remain
- else:
- raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data))
- if tns_type == b'~':
- if data:
- raise ValueError("not a tnetstring: invalid null literal")
- return None, remain
- if tns_type == b']':
- l = []
- while data:
- item, data = pop(data)
- l.append(item)
- return (l, remain)
- if tns_type == b'}':
- d = {}
- while data:
- key, data = pop(data)
- val, data = pop(data)
- d[key] = val
- return d, remain
- raise ValueError("unknown type tag: {}".format(tns_type))
+ raise ValueError("not a tnetstring: invalid length prefix: {}".format(length))
+ # Parse the data based on the type tag.
+ return parse(data_type, data), remain
+
+
+__all__ = ["dump", "dumps", "load", "loads", "pop"]
diff --git a/mitmproxy/flow/io.py b/mitmproxy/flow/io.py
index 671ddf43..276d7a5b 100644
--- a/mitmproxy/flow/io.py
+++ b/mitmproxy/flow/io.py
@@ -49,7 +49,7 @@ class FlowReader:
yield models.FLOW_TYPES[data["type"]].from_state(data)
except ValueError:
# Error is due to EOF
- if can_tell and self.fo.tell() == off and self.fo.read() == '':
+ if can_tell and self.fo.tell() == off and self.fo.read() == b'':
return
raise exceptions.FlowReadException("Invalid data format.")
diff --git a/mitmproxy/flow/io_compat.py b/mitmproxy/flow/io_compat.py
index 1023e87f..ec825f71 100644
--- a/mitmproxy/flow/io_compat.py
+++ b/mitmproxy/flow/io_compat.py
@@ -3,50 +3,99 @@ This module handles the import of mitmproxy flows generated by old versions.
"""
from __future__ import absolute_import, print_function, division
-from netlib import version
+import six
+
+from netlib import version, strutils
def convert_013_014(data):
- data["request"]["first_line_format"] = data["request"].pop("form_in")
- data["request"]["http_version"] = "HTTP/" + ".".join(str(x) for x in data["request"].pop("httpversion"))
- data["response"]["status_code"] = data["response"].pop("code")
- data["response"]["body"] = data["response"].pop("content")
- data["server_conn"].pop("state")
- data["server_conn"]["via"] = None
- data["version"] = (0, 14)
+ data[b"request"][b"first_line_format"] = data[b"request"].pop(b"form_in")
+ data[b"request"][b"http_version"] = b"HTTP/" + ".".join(
+ str(x) for x in data[b"request"].pop(b"httpversion")).encode()
+ data[b"response"][b"http_version"] = b"HTTP/" + ".".join(
+ str(x) for x in data[b"response"].pop(b"httpversion")).encode()
+ data[b"response"][b"status_code"] = data[b"response"].pop(b"code")
+ data[b"response"][b"body"] = data[b"response"].pop(b"content")
+ data[b"server_conn"].pop(b"state")
+ data[b"server_conn"][b"via"] = None
+ data[b"version"] = (0, 14)
return data
def convert_014_015(data):
- data["version"] = (0, 15)
+ data[b"version"] = (0, 15)
return data
def convert_015_016(data):
- for m in ("request", "response"):
- if "body" in data[m]:
- data[m]["content"] = data[m].pop("body")
- if "httpversion" in data[m]:
- data[m]["http_version"] = data[m].pop("httpversion")
- if "msg" in data["response"]:
- data["response"]["reason"] = data["response"].pop("msg")
- data["request"].pop("form_out", None)
- data["version"] = (0, 16)
+ for m in (b"request", b"response"):
+ if b"body" in data[m]:
+ data[m][b"content"] = data[m].pop(b"body")
+ if b"msg" in data[b"response"]:
+ data[b"response"][b"reason"] = data[b"response"].pop(b"msg")
+ data[b"request"].pop(b"form_out", None)
+ data[b"version"] = (0, 16)
return data
def convert_016_017(data):
- data["server_conn"]["peer_address"] = None
- data["version"] = (0, 17)
+ data[b"server_conn"][b"peer_address"] = None
+ data[b"version"] = (0, 17)
return data
def convert_017_018(data):
+ # convert_unicode needs to be called for every dual release and the first py3-only release
+ data = convert_unicode(data)
+
data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address")
data["version"] = (0, 18)
return data
+def _convert_dict_keys(o):
+ # type: (Any) -> Any
+ if isinstance(o, dict):
+ return {strutils.native(k): _convert_dict_keys(v) for k, v in o.items()}
+ else:
+ return o
+
+
+def _convert_dict_vals(o, values_to_convert):
+ # type: (dict, dict) -> dict
+ for k, v in values_to_convert.items():
+ if not o or k not in o:
+ continue
+ if v is True:
+ o[k] = strutils.native(o[k])
+ else:
+ _convert_dict_vals(o[k], v)
+ return o
+
+
+def convert_unicode(data):
+ # type: (dict) -> dict
+ """
+ The Python 2 version of mitmproxy serializes everything as bytes.
+ This method converts between Python 3 and Python 2 dumpfiles.
+ """
+ if not six.PY2:
+ data = _convert_dict_keys(data)
+ data = _convert_dict_vals(
+ data, {
+ "type": True,
+ "id": True,
+ "request": {
+ "first_line_format": True
+ },
+ "error": {
+ "msg": True
+ }
+ }
+ )
+ return data
+
+
converters = {
(0, 13): convert_013_014,
(0, 14): convert_014_015,
@@ -58,14 +107,17 @@ converters = {
def migrate_flow(flow_data):
while True:
- flow_version = tuple(flow_data["version"][:2])
- if flow_version == version.IVERSION[:2]:
+ flow_version = tuple(flow_data.get(b"version", flow_data.get("version")))
+ if flow_version[:2] == version.IVERSION[:2]:
break
- elif flow_version in converters:
- flow_data = converters[flow_version](flow_data)
+ elif flow_version[:2] in converters:
+ flow_data = converters[flow_version[:2]](flow_data)
else:
- v = ".".join(str(i) for i in flow_data["version"])
+ v = ".".join(str(i) for i in flow_version)
raise ValueError(
"{} cannot read files serialized with version {}.".format(version.MITMPROXY, v)
)
+ # TODO: This should finally be moved in the converter for the first py3-only release.
+ # It's here so that a py2 0.18 dump can be read by py3 0.18 and vice versa.
+ flow_data = convert_unicode(flow_data)
return flow_data
diff --git a/mitmproxy/models/connections.py b/mitmproxy/models/connections.py
index d71379bc..570e89a9 100644
--- a/mitmproxy/models/connections.py
+++ b/mitmproxy/models/connections.py
@@ -8,7 +8,6 @@ import six
from mitmproxy import stateobject
from netlib import certutils
-from netlib import strutils
from netlib import tcp
@@ -206,6 +205,8 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
self.wfile.flush()
def establish_ssl(self, clientcerts, sni, **kwargs):
+ if sni and not isinstance(sni, six.string_types):
+ raise ValueError("sni must be str, not " + type(sni).__name__)
clientcert = None
if clientcerts:
if os.path.isfile(clientcerts):
@@ -217,7 +218,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
if os.path.exists(path):
clientcert = path
- self.convert_to_ssl(cert=clientcert, sni=strutils.always_bytes(sni), **kwargs)
+ self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs)
self.sni = sni
self.timestamp_ssl_setup = time.time()
diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py
index 0e4f80cb..f4993b7a 100644
--- a/mitmproxy/models/flow.py
+++ b/mitmproxy/models/flow.py
@@ -9,6 +9,7 @@ from mitmproxy.models.connections import ClientConnection
from mitmproxy.models.connections import ServerConnection
from netlib import version
+from typing import Optional # noqa
class Error(stateobject.StateObject):
@@ -70,18 +71,13 @@ class Flow(stateobject.StateObject):
def __init__(self, type, client_conn, server_conn, live=None):
self.type = type
self.id = str(uuid.uuid4())
- self.client_conn = client_conn
- """@type: ClientConnection"""
- self.server_conn = server_conn
- """@type: ServerConnection"""
+ self.client_conn = client_conn # type: ClientConnection
+ self.server_conn = server_conn # type: ServerConnection
self.live = live
- """@type: LiveConnection"""
- self.error = None
- """@type: Error"""
- self.intercepted = False
- """@type: bool"""
- self._backup = None
+ self.error = None # type: Error
+ self.intercepted = False # type: bool
+ self._backup = None # type: Optional[Flow]
self.reply = None
_stateobject_attributes = dict(
diff --git a/mitmproxy/protocol/tls.py b/mitmproxy/protocol/tls.py
index 9f883b2b..8ef34493 100644
--- a/mitmproxy/protocol/tls.py
+++ b/mitmproxy/protocol/tls.py
@@ -10,6 +10,7 @@ import netlib.exceptions
from mitmproxy import exceptions
from mitmproxy.contrib.tls import _constructs
from mitmproxy.protocol import base
+from netlib import utils
# taken from https://testssl.sh/openssl-rfc.mappping.html
@@ -274,10 +275,11 @@ class TlsClientHello(object):
is_valid_sni_extension = (
extension.type == 0x00 and
len(extension.server_names) == 1 and
- extension.server_names[0].type == 0
+ extension.server_names[0].type == 0 and
+ utils.is_valid_host(extension.server_names[0].name)
)
if is_valid_sni_extension:
- return extension.server_names[0].name
+ return extension.server_names[0].name.decode("idna")
@property
def alpn_protocols(self):
@@ -403,13 +405,14 @@ class TlsLayer(base.Layer):
self._establish_tls_with_server()
def set_server_tls(self, server_tls, sni=None):
+ # type: (bool, Union[six.text_type, None, False]) -> None
"""
Set the TLS settings for the next server connection that will be established.
This function will not alter an existing connection.
Args:
server_tls: Shall we establish TLS with the server?
- sni: ``bytes`` for a custom SNI value,
+ sni: ``str`` for a custom SNI value,
``None`` for the client SNI value,
``False`` if no SNI value should be sent.
"""
@@ -602,9 +605,9 @@ class TlsLayer(base.Layer):
host = upstream_cert.cn.decode("utf8").encode("idna")
# Also add SNI values.
if self._client_hello.sni:
- sans.add(self._client_hello.sni)
+ sans.add(self._client_hello.sni.encode("idna"))
if self._custom_server_sni:
- sans.add(self._custom_server_sni)
+ sans.add(self._custom_server_sni.encode("idna"))
# RFC 2818: If a subjectAltName extension of type dNSName is present, that MUST be used as the identity.
# In other words, the Common Name is irrelevant then.
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 69dafc1f..cf099edd 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -676,7 +676,7 @@ class TCPClient(_Connection):
self.connection = SSL.Connection(context, self.connection)
if sni:
self.sni = sni
- self.connection.set_tlsext_host_name(sni)
+ self.connection.set_tlsext_host_name(sni.encode("idna"))
self.connection.set_connect_state()
try:
self.connection.do_handshake()
@@ -705,7 +705,7 @@ class TCPClient(_Connection):
if self.cert.cn:
crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]]
if sni:
- hostname = sni.decode("ascii", "strict")
+ hostname = sni
else:
hostname = "no-hostname"
ssl_match_hostname.match_hostname(crt, hostname)
diff --git a/netlib/utils.py b/netlib/utils.py
index 79340cbd..23c16dc3 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -73,11 +73,9 @@ _label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
def is_valid_host(host):
+ # type: (bytes) -> bool
"""
Checks if a hostname is valid.
-
- Args:
- host (bytes): The hostname
"""
try:
host.decode("idna")
diff --git a/pathod/pathod.py b/pathod/pathod.py
index 3df86aae..7087cba6 100644
--- a/pathod/pathod.py
+++ b/pathod/pathod.py
@@ -89,7 +89,10 @@ class PathodHandler(tcp.BaseHandler):
self.http2_framedump = http2_framedump
def handle_sni(self, connection):
- self.sni = connection.get_servername()
+ sni = connection.get_servername()
+ if sni:
+ sni = sni.decode("idna")
+ self.sni = sni
def http_serve_crafted(self, crafted, logctx):
error, crafted = self.server.check_policy(
diff --git a/test/mitmproxy/test_contrib_tnetstring.py b/test/mitmproxy/test_contrib_tnetstring.py
index 17654ad9..05c4a7c9 100644
--- a/test/mitmproxy/test_contrib_tnetstring.py
+++ b/test/mitmproxy/test_contrib_tnetstring.py
@@ -15,7 +15,9 @@ FORMAT_EXAMPLES = {
{b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']},
b'5:12345#': 12345,
b'12:this is cool,': b'this is cool',
+ b'19:this is unicode \xe2\x98\x85;': u'this is unicode \u2605',
b'0:,': b'',
+ b'0:;': u'',
b'0:~': None,
b'4:true!': True,
b'5:false!': False,
@@ -43,7 +45,7 @@ def get_random_object(random=random, depth=0):
d = {}
for _ in range(n):
n = random.randint(0, 100)
- k = bytes([random.randint(32, 126) for _ in range(n)])
+ k = str([random.randint(32, 126) for _ in range(n)])
d[k] = get_random_object(random, depth + 1)
return d
else:
@@ -78,12 +80,6 @@ class Test_Format(unittest.TestCase):
self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v)))
self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v)))
- def test_unicode_handling(self):
- with self.assertRaises(ValueError):
- tnetstring.dumps(u"hello")
- self.assertEqual(tnetstring.dumps(u"hello".encode()), b"5:hello,")
- self.assertEqual(type(tnetstring.loads(b"5:hello,")), bytes)
-
def test_roundtrip_format_unicode(self):
for _ in range(500):
v = get_random_object()
diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py
index 1bbef975..0ab7624e 100644
--- a/test/mitmproxy/test_server.py
+++ b/test/mitmproxy/test_server.py
@@ -100,10 +100,10 @@ class CommonMixin:
if not self.ssl:
return
- f = self.pathod("304", sni=b"testserver.com")
+ f = self.pathod("304", sni="testserver.com")
assert f.status_code == 304
log = self.server.last_log()
- assert log["request"]["sni"] == b"testserver.com"
+ assert log["request"]["sni"] == "testserver.com"
class TcpMixin:
@@ -498,7 +498,7 @@ class TestHttps2Http(tservers.ReverseProxyTest):
assert p.request("get:'/p/200'").status_code == 200
def test_sni(self):
- p = self.pathoc(ssl=True, sni=b"example.com")
+ p = self.pathoc(ssl=True, sni="example.com")
assert p.request("get:'/p/200'").status_code == 200
assert all("Error in handle_sni" not in msg for msg in self.proxy.tlog)
diff --git a/test/netlib/test_tcp.py b/test/netlib/test_tcp.py
index 590bcc01..273427d5 100644
--- a/test/netlib/test_tcp.py
+++ b/test/netlib/test_tcp.py
@@ -169,7 +169,7 @@ class TestServerSSL(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL)
+ c.convert_to_ssl(sni="foo.com", options=SSL.OP_ALL)
testval = b"echo!\n"
c.wfile.write(testval)
c.wfile.flush()
@@ -179,7 +179,7 @@ class TestServerSSL(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
assert not c.get_current_cipher()
- c.convert_to_ssl(sni=b"foo.com")
+ c.convert_to_ssl(sni="foo.com")
ret = c.get_current_cipher()
assert ret
assert "AES" in ret[0]
@@ -195,7 +195,7 @@ class TestSSLv3Only(tservers.ServerTestBase):
def test_failure(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com")
+ tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com")
class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
@@ -238,7 +238,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
with c.connect():
with tutils.raises(InvalidCertificateException):
c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
+ sni="example.mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
)
@@ -272,7 +272,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):
with c.connect():
with tutils.raises(InvalidCertificateException):
c.convert_to_ssl(
- sni=b"mitmproxy.org",
+ sni="mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
)
@@ -291,7 +291,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
+ sni="example.mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
)
@@ -307,7 +307,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
+ sni="example.mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
ca_path=tutils.test_data.path("data/verificationcerts/")
)
@@ -371,8 +371,8 @@ class TestSNI(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- c.convert_to_ssl(sni=b"foo.com")
- assert c.sni == b"foo.com"
+ c.convert_to_ssl(sni="foo.com")
+ assert c.sni == "foo.com"
assert c.rfile.readline() == b"foo.com"
@@ -385,7 +385,7 @@ class TestServerCipherList(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- c.convert_to_ssl(sni=b"foo.com")
+ c.convert_to_ssl(sni="foo.com")
assert c.rfile.readline() == b"['RC4-SHA']"
@@ -405,7 +405,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- c.convert_to_ssl(sni=b"foo.com")
+ c.convert_to_ssl(sni="foo.com")
assert b"RC4-SHA" in c.rfile.readline()
@@ -418,7 +418,7 @@ class TestServerCipherListError(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com")
+ tutils.raises("handshake error", c.convert_to_ssl, sni="foo.com")
class TestClientCipherListError(tservers.ServerTestBase):
@@ -433,7 +433,7 @@ class TestClientCipherListError(tservers.ServerTestBase):
tutils.raises(
"cipher specification",
c.convert_to_ssl,
- sni=b"foo.com",
+ sni="foo.com",
cipher_list="bogus"
)
diff --git a/test/pathod/test_pathoc.py b/test/pathod/test_pathoc.py
index 28f9f0f8..361a863b 100644
--- a/test/pathod/test_pathoc.py
+++ b/test/pathod/test_pathoc.py
@@ -54,10 +54,10 @@ class TestDaemonSSL(PathocTestDaemon):
def test_sni(self):
self.tval(
["get:/p/200"],
- sni=b"foobar.com"
+ sni="foobar.com"
)
log = self.d.log()
- assert log[0]["request"]["sni"] == b"foobar.com"
+ assert log[0]["request"]["sni"] == "foobar.com"
def test_showssl(self):
assert "certificate chain" in self.tval(["get:/p/200"], showssl=True)
diff --git a/tox.ini b/tox.ini
index a7b5e7d3..251609a5 100644
--- a/tox.ini
+++ b/tox.ini
@@ -16,7 +16,7 @@ commands =
[testenv:py35]
setenv =
- TESTS = test/netlib test/pathod/ test/mitmproxy/script test/mitmproxy/test_contentview.py test/mitmproxy/test_custom_contentview.py test/mitmproxy/test_app.py test/mitmproxy/test_controller.py test/mitmproxy/test_fuzzing.py test/mitmproxy/test_script.py test/mitmproxy/test_web_app.py test/mitmproxy/test_utils.py test/mitmproxy/test_stateobject.py test/mitmproxy/test_cmdline.py test/mitmproxy/test_contrib_tnetstring.py test/mitmproxy/test_proxy.py test/mitmproxy/test_protocol_http1.py test/mitmproxy/test_platform_pf.py test/mitmproxy/test_server.py test/mitmproxy/test_filt.py test/mitmproxy/test_flow_export.py test/mitmproxy/test_web_master.py
+ TESTS = test/netlib test/pathod/ test/mitmproxy/script test/mitmproxy/test_contentview.py test/mitmproxy/test_custom_contentview.py test/mitmproxy/test_app.py test/mitmproxy/test_controller.py test/mitmproxy/test_fuzzing.py test/mitmproxy/test_script.py test/mitmproxy/test_web_app.py test/mitmproxy/test_utils.py test/mitmproxy/test_stateobject.py test/mitmproxy/test_cmdline.py test/mitmproxy/test_contrib_tnetstring.py test/mitmproxy/test_proxy.py test/mitmproxy/test_protocol_http1.py test/mitmproxy/test_platform_pf.py test/mitmproxy/test_server.py test/mitmproxy/test_filt.py test/mitmproxy/test_flow_export.py test/mitmproxy/test_web_master.py test/mitmproxy/test_flow_format_compat.py
HOME = {envtmpdir}
[testenv:docs]