diff options
57 files changed, 720 insertions, 655 deletions
@@ -23,3 +23,4 @@ sslkeylogfile.log .python-version coverage.xml web/coverage/ +.mypy_cache/ diff --git a/examples/simple/custom_contentview.py b/examples/simple/custom_contentview.py index 71f92575..b958bdce 100644 --- a/examples/simple/custom_contentview.py +++ b/examples/simple/custom_contentview.py @@ -3,10 +3,6 @@ This example shows how one can add a custom contentview to mitmproxy. The content view API is explained in the mitmproxy.contentviews module. """ from mitmproxy import contentviews -import typing - - -CVIEWSWAPCASE = typing.Tuple[str, typing.Iterable[typing.List[typing.Tuple[str, typing.AnyStr]]]] class ViewSwapCase(contentviews.View): @@ -17,7 +13,7 @@ class ViewSwapCase(contentviews.View): prompt = ("swap case text", "z") content_types = ["text/plain"] - def __call__(self, data: typing.AnyStr, **metadata) -> CVIEWSWAPCASE: + def __call__(self, data, **metadata) -> contentviews.TViewResult: return "case-swapped text", contentviews.format_text(data.swapcase()) diff --git a/examples/simple/io_read_dumpfile.py b/examples/simple/io_read_dumpfile.py index ea544cc4..534f357b 100644 --- a/examples/simple/io_read_dumpfile.py +++ b/examples/simple/io_read_dumpfile.py @@ -1,6 +1,4 @@ #!/usr/bin/env python - -# type: ignore # # Simple script showing how to read a mitmproxy dump file # diff --git a/examples/simple/io_write_dumpfile.py b/examples/simple/io_write_dumpfile.py index cf7c4f52..be6e4121 100644 --- a/examples/simple/io_write_dumpfile.py +++ b/examples/simple/io_write_dumpfile.py @@ -13,15 +13,15 @@ import typing # noqa class Writer: def __init__(self, path: str) -> None: - if path == "-": - f = sys.stdout # type: typing.IO[typing.Any] - else: - f = open(path, "wb") - self.w = io.FlowWriter(f) + self.f = open(path, "wb") # type: typing.IO[bytes] + self.w = io.FlowWriter(self.f) def response(self, flow: http.HTTPFlow) -> None: if random.choice([True, False]): self.w.add(flow) + def done(self): + self.f.close() + addons = [Writer(sys.argv[1])] diff --git a/mitmproxy/addons/view.py b/mitmproxy/addons/view.py index 13a17c56..aa3e11ed 100644 --- a/mitmproxy/addons/view.py +++ b/mitmproxy/addons/view.py @@ -339,11 +339,12 @@ class View(collections.Sequence): """ Load flows into the view, without processing them with addons. """ - for i in io.FlowReader(open(path, "rb")).stream(): - # Do this to get a new ID, so we can load the same file N times and - # get new flows each time. It would be more efficient to just have a - # .newid() method or something. - self.add([i.copy()]) + with open(path, "rb") as f: + for i in io.FlowReader(f).stream(): + # Do this to get a new ID, so we can load the same file N times and + # get new flows each time. It would be more efficient to just have a + # .newid() method or something. + self.add([i.copy()]) @command.command("view.go") def go(self, dst: int) -> None: diff --git a/mitmproxy/contentviews/__init__.py b/mitmproxy/contentviews/__init__.py index f57b27c7..a1866851 100644 --- a/mitmproxy/contentviews/__init__.py +++ b/mitmproxy/contentviews/__init__.py @@ -25,7 +25,7 @@ from . import ( auto, raw, hex, json, xml_html, html_outline, wbxml, javascript, css, urlencoded, multipart, image, query, protobuf ) -from .base import View, VIEW_CUTOFF, KEY_MAX, format_text, format_dict +from .base import View, VIEW_CUTOFF, KEY_MAX, format_text, format_dict, TViewResult views = [] # type: List[View] content_types_map = {} # type: Dict[str, List[View]] @@ -178,7 +178,7 @@ add(query.ViewQuery()) add(protobuf.ViewProtobuf()) __all__ = [ - "View", "VIEW_CUTOFF", "KEY_MAX", "format_text", "format_dict", + "View", "VIEW_CUTOFF", "KEY_MAX", "format_text", "format_dict", "TViewResult", "get", "get_by_shortcut", "add", "remove", "get_content_view", "get_message_content_view", ] diff --git a/mitmproxy/contentviews/base.py b/mitmproxy/contentviews/base.py index 0de4f786..97740eea 100644 --- a/mitmproxy/contentviews/base.py +++ b/mitmproxy/contentviews/base.py @@ -1,20 +1,21 @@ # Default view cutoff *in lines* - -from typing import Iterable, AnyStr, List -from typing import Mapping -from typing import Tuple +import typing VIEW_CUTOFF = 512 KEY_MAX = 30 +TTextType = typing.Union[str, bytes] # FIXME: This should be either bytes or str ultimately. +TViewLine = typing.List[typing.Tuple[str, TTextType]] +TViewResult = typing.Tuple[str, typing.Iterator[TViewLine]] + class View: name = None # type: str - prompt = None # type: Tuple[str,str] - content_types = [] # type: List[str] + prompt = None # type: typing.Tuple[str,str] + content_types = [] # type: typing.List[str] - def __call__(self, data: bytes, **metadata): + def __call__(self, data: bytes, **metadata) -> TViewResult: """ Transform raw data into human-readable output. @@ -38,8 +39,8 @@ class View: def format_dict( - d: Mapping[AnyStr, AnyStr] -) -> Iterable[List[Tuple[str, AnyStr]]]: + d: typing.Mapping[TTextType, TTextType] +) -> typing.Iterator[TViewLine]: """ Helper function that transforms the given dictionary into a list of ("key", key ) @@ -49,7 +50,10 @@ def format_dict( max_key_len = max(len(k) for k in d.keys()) max_key_len = min(max_key_len, KEY_MAX) for key, value in d.items(): - key += b":" if isinstance(key, bytes) else u":" + if isinstance(key, bytes): + key += b":" + else: + key += ":" key = key.ljust(max_key_len + 2) yield [ ("header", key), @@ -57,7 +61,7 @@ def format_dict( ] -def format_text(text: AnyStr) -> Iterable[List[Tuple[str, AnyStr]]]: +def format_text(text: TTextType) -> typing.Iterator[TViewLine]: """ Helper function that transforms bytes into the view output format. """ diff --git a/mitmproxy/contrib/kaitaistruct/make.sh b/mitmproxy/contrib/kaitaistruct/make.sh index 218d5198..9ef68886 100755 --- a/mitmproxy/contrib/kaitaistruct/make.sh +++ b/mitmproxy/contrib/kaitaistruct/make.sh @@ -6,5 +6,6 @@ wget -N https://raw.githubusercontent.com/kaitai-io/kaitai_struct_formats/master wget -N https://raw.githubusercontent.com/kaitai-io/kaitai_struct_formats/master/image/gif.ksy wget -N https://raw.githubusercontent.com/kaitai-io/kaitai_struct_formats/master/image/jpeg.ksy wget -N https://raw.githubusercontent.com/kaitai-io/kaitai_struct_formats/master/image/png.ksy +wget -N https://raw.githubusercontent.com/mitmproxy/mitmproxy/master/mitmproxy/contrib/tls_client_hello.py kaitai-struct-compiler --target python --opaque-types=true *.ksy diff --git a/mitmproxy/contrib/kaitaistruct/tls_client_hello.py b/mitmproxy/contrib/kaitaistruct/tls_client_hello.py new file mode 100644 index 00000000..6aff9b14 --- /dev/null +++ b/mitmproxy/contrib/kaitaistruct/tls_client_hello.py @@ -0,0 +1,146 @@ +# This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild + +import array +import struct +import zlib +from enum import Enum +from pkg_resources import parse_version + +from kaitaistruct import __version__ as ks_version, KaitaiStruct, KaitaiStream, BytesIO + +if parse_version(ks_version) < parse_version('0.7'): + raise Exception("Incompatible Kaitai Struct Python API: 0.7 or later is required, but you have %s" % (ks_version)) + + +class TlsClientHello(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.version = self._root.Version(self._io, self, self._root) + self.random = self._root.Random(self._io, self, self._root) + self.session_id = self._root.SessionId(self._io, self, self._root) + self.cipher_suites = self._root.CipherSuites(self._io, self, self._root) + self.compression_methods = self._root.CompressionMethods(self._io, self, self._root) + if self._io.is_eof() == True: + self.extensions = [None] * (0) + for i in range(0): + self.extensions[i] = self._io.read_bytes(0) + + if self._io.is_eof() == False: + self.extensions = self._root.Extensions(self._io, self, self._root) + + class ServerName(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.name_type = self._io.read_u1() + self.length = self._io.read_u2be() + self.host_name = self._io.read_bytes(self.length) + + class Random(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.gmt_unix_time = self._io.read_u4be() + self.random = self._io.read_bytes(28) + + class SessionId(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.len = self._io.read_u1() + self.sid = self._io.read_bytes(self.len) + + class Sni(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.list_length = self._io.read_u2be() + self.server_names = [] + while not self._io.is_eof(): + self.server_names.append(self._root.ServerName(self._io, self, self._root)) + + class CipherSuites(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.len = self._io.read_u2be() + self.cipher_suites = [None] * (self.len // 2) + for i in range(self.len // 2): + self.cipher_suites[i] = self._root.CipherSuite(self._io, self, self._root) + + class CompressionMethods(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.len = self._io.read_u1() + self.compression_methods = self._io.read_bytes(self.len) + + class Alpn(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.ext_len = self._io.read_u2be() + self.alpn_protocols = [] + while not self._io.is_eof(): + self.alpn_protocols.append(self._root.Protocol(self._io, self, self._root)) + + class Extensions(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.len = self._io.read_u2be() + self.extensions = [] + while not self._io.is_eof(): + self.extensions.append(self._root.Extension(self._io, self, self._root)) + + class Version(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.major = self._io.read_u1() + self.minor = self._io.read_u1() + + class CipherSuite(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.cipher_suite = self._io.read_u2be() + + class Protocol(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.strlen = self._io.read_u1() + self.name = self._io.read_bytes(self.strlen) + + class Extension(KaitaiStruct): + def __init__(self, _io, _parent=None, _root=None): + self._io = _io + self._parent = _parent + self._root = _root if _root else self + self.type = self._io.read_u2be() + self.len = self._io.read_u2be() + _on = self.type + if _on == 0: + self._raw_body = self._io.read_bytes(self.len) + io = KaitaiStream(BytesIO(self._raw_body)) + self.body = self._root.Sni(io, self, self._root) + elif _on == 16: + self._raw_body = self._io.read_bytes(self.len) + io = KaitaiStream(BytesIO(self._raw_body)) + self.body = self._root.Alpn(io, self, self._root) + else: + self.body = self._io.read_bytes(self.len) diff --git a/mitmproxy/contrib/tls_client_hello.ksy b/mitmproxy/contrib/tls_client_hello.ksy new file mode 100644 index 00000000..5b6eb0fb --- /dev/null +++ b/mitmproxy/contrib/tls_client_hello.ksy @@ -0,0 +1,139 @@ +meta: + id: tls_client_hello + endian: be + +seq: + - id: version + type: version + + - id: random + type: random + + - id: session_id + type: session_id + + - id: cipher_suites + type: cipher_suites + + - id: compression_methods + type: compression_methods + + - id: extensions + size: 0 + repeat: expr + repeat-expr: 0 + if: _io.eof == true + + - id: extensions + type: extensions + if: _io.eof == false + +types: + version: + seq: + - id: major + type: u1 + + - id: minor + type: u1 + + random: + seq: + - id: gmt_unix_time + type: u4 + + - id: random + size: 28 + + session_id: + seq: + - id: len + type: u1 + + - id: sid + size: len + + cipher_suites: + seq: + - id: len + type: u2 + + - id: cipher_suites + type: cipher_suite + repeat: expr + repeat-expr: len/2 + + cipher_suite: + seq: + - id: cipher_suite + type: u2 + + compression_methods: + seq: + - id: len + type: u1 + + - id: compression_methods + size: len + + extensions: + seq: + - id: len + type: u2 + + - id: extensions + type: extension + repeat: eos + + extension: + seq: + - id: type + type: u2 + + - id: len + type: u2 + + - id: body + size: len + type: + switch-on: type + cases: + 0: sni + 16: alpn + + sni: + seq: + - id: list_length + type: u2 + + - id: server_names + type: server_name + repeat: eos + + server_name: + seq: + - id: name_type + type: u1 + + - id: length + type: u2 + + - id: host_name + size: length + + alpn: + seq: + - id: ext_len + type: u2 + + - id: alpn_protocols + type: protocol + repeat: eos + + protocol: + seq: + - id: strlen + type: u1 + + - id: name + size: strlen diff --git a/mitmproxy/contrib/tls_parser.py b/mitmproxy/contrib/tls_parser.py deleted file mode 100644 index 61fb3e3e..00000000 --- a/mitmproxy/contrib/tls_parser.py +++ /dev/null @@ -1,208 +0,0 @@ -# This file originally comes from https://github.com/pyca/tls/blob/master/tls/_constructs.py. -# Modified by the mitmproxy team. - -# This file is dual licensed under the terms of the Apache License, Version -# 2.0, and the BSD License. See the LICENSE file in the root of this repository -# for complete details. - - -from construct import ( - Array, - Bytes, - Struct, - VarInt, - Int8ub, - Int16ub, - Int24ub, - Int32ub, - PascalString, - Embedded, - Prefixed, - Range, - GreedyRange, - Switch, - Optional, -) - -ProtocolVersion = "version" / Struct( - "major" / Int8ub, - "minor" / Int8ub, -) - -TLSPlaintext = "TLSPlaintext" / Struct( - "type" / Int8ub, - ProtocolVersion, - "length" / Int16ub, # TODO: Reject packets with length > 2 ** 14 - "fragment" / Bytes(lambda ctx: ctx.length), -) - -TLSCompressed = "TLSCompressed" / Struct( - "type" / Int8ub, - ProtocolVersion, - "length" / Int16ub, # TODO: Reject packets with length > 2 ** 14 + 1024 - "fragment" / Bytes(lambda ctx: ctx.length), -) - -TLSCiphertext = "TLSCiphertext" / Struct( - "type" / Int8ub, - ProtocolVersion, - "length" / Int16ub, # TODO: Reject packets with length > 2 ** 14 + 2048 - "fragment" / Bytes(lambda ctx: ctx.length), -) - -Random = "random" / Struct( - "gmt_unix_time" / Int32ub, - "random_bytes" / Bytes(28), -) - -SessionID = "session_id" / Struct( - "length" / Int8ub, - "session_id" / Bytes(lambda ctx: ctx.length), -) - -CipherSuites = "cipher_suites" / Struct( - "length" / Int16ub, # TODO: Reject packets of length 0 - Array(lambda ctx: ctx.length // 2, "cipher_suites" / Int16ub), -) - -CompressionMethods = "compression_methods" / Struct( - "length" / Int8ub, # TODO: Reject packets of length 0 - Array(lambda ctx: ctx.length, "compression_methods" / Int8ub), -) - -ServerName = Struct( - "type" / Int8ub, - "name" / PascalString("length" / Int16ub), -) - -SNIExtension = Prefixed( - Int16ub, - Struct( - Int16ub, - "server_names" / GreedyRange( - "server_name" / Struct( - "name_type" / Int8ub, - "host_name" / PascalString("length" / Int16ub), - ) - ) - ) -) - -ALPNExtension = Prefixed( - Int16ub, - Struct( - Int16ub, - "alpn_protocols" / GreedyRange( - "name" / PascalString(Int8ub), - ), - ) -) - -UnknownExtension = Struct( - "bytes" / PascalString("length" / Int16ub) -) - -Extension = "Extension" / Struct( - "type" / Int16ub, - Embedded( - Switch( - lambda ctx: ctx.type, - { - 0x00: SNIExtension, - 0x10: ALPNExtension, - }, - default=UnknownExtension - ) - ) -) - -extensions = "extensions" / Optional( - Struct( - Int16ub, - "extensions" / GreedyRange(Extension) - ) -) - -ClientHello = "ClientHello" / Struct( - ProtocolVersion, - Random, - SessionID, - CipherSuites, - CompressionMethods, - extensions, -) - -ServerHello = "ServerHello" / Struct( - ProtocolVersion, - Random, - SessionID, - "cipher_suite" / Bytes(2), - "compression_method" / Int8ub, - extensions, -) - -ClientCertificateType = "certificate_types" / Struct( - "length" / Int8ub, # TODO: Reject packets of length 0 - Array(lambda ctx: ctx.length, "certificate_types" / Int8ub), -) - -SignatureAndHashAlgorithm = "algorithms" / Struct( - "hash" / Int8ub, - "signature" / Int8ub, -) - -SupportedSignatureAlgorithms = "supported_signature_algorithms" / Struct( - "supported_signature_algorithms_length" / Int16ub, - # TODO: Reject packets of length 0 - Array( - lambda ctx: ctx.supported_signature_algorithms_length / 2, - SignatureAndHashAlgorithm, - ), -) - -DistinguishedName = "certificate_authorities" / Struct( - "length" / Int16ub, - "certificate_authorities" / Bytes(lambda ctx: ctx.length), -) - -CertificateRequest = "CertificateRequest" / Struct( - ClientCertificateType, - SupportedSignatureAlgorithms, - DistinguishedName, -) - -ServerDHParams = "ServerDHParams" / Struct( - "dh_p_length" / Int16ub, - "dh_p" / Bytes(lambda ctx: ctx.dh_p_length), - "dh_g_length" / Int16ub, - "dh_g" / Bytes(lambda ctx: ctx.dh_g_length), - "dh_Ys_length" / Int16ub, - "dh_Ys" / Bytes(lambda ctx: ctx.dh_Ys_length), -) - -PreMasterSecret = "pre_master_secret" / Struct( - ProtocolVersion, - "random_bytes" / Bytes(46), -) - -ASN1Cert = "ASN1Cert" / Struct( - "length" / Int32ub, # TODO: Reject packets with length not in 1..2^24-1 - "asn1_cert" / Bytes(lambda ctx: ctx.length), -) - -Certificate = "Certificate" / Struct( - # TODO: Reject packets with length > 2 ** 24 - 1 - "certificates_length" / Int32ub, - "certificates_bytes" / Bytes(lambda ctx: ctx.certificates_length), -) - -Handshake = "Handshake" / Struct( - "msg_type" / Int8ub, - "length" / Int24ub, - "body" / Bytes(lambda ctx: ctx.length), -) - -Alert = "Alert" / Struct( - "level" / Int8ub, - "description" / Int8ub, -) diff --git a/mitmproxy/flowfilter.py b/mitmproxy/flowfilter.py index 83c98bad..4edf0413 100644 --- a/mitmproxy/flowfilter.py +++ b/mitmproxy/flowfilter.py @@ -344,6 +344,8 @@ class FUrl(_Rex): @only(http.HTTPFlow) def __call__(self, f): + if not f.request: + return False return self.re.search(f.request.pretty_url) diff --git a/mitmproxy/net/socks.py b/mitmproxy/net/socks.py index 570a4afb..fdfcfb80 100644 --- a/mitmproxy/net/socks.py +++ b/mitmproxy/net/socks.py @@ -82,12 +82,12 @@ class ClientGreeting: client_greeting = cls(ver, []) if fail_early: client_greeting.assert_socks5() - client_greeting.methods.fromstring(f.safe_read(nmethods)) + client_greeting.methods.frombytes(f.safe_read(nmethods)) return client_greeting def to_file(self, f): f.write(struct.pack("!BB", self.ver, len(self.methods))) - f.write(self.methods.tostring()) + f.write(self.methods.tobytes()) class ServerGreeting: diff --git a/mitmproxy/net/tcp.py b/mitmproxy/net/tcp.py index 81568d24..12cf7337 100644 --- a/mitmproxy/net/tcp.py +++ b/mitmproxy/net/tcp.py @@ -502,7 +502,7 @@ class _Connection: # Cipher List if cipher_list: try: - context.set_cipher_list(cipher_list) + context.set_cipher_list(cipher_list.encode()) context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) except SSL.Error as v: raise exceptions.TlsException("SSL cipher specification error: %s" % str(v)) @@ -569,7 +569,9 @@ class TCPClient(_Connection): # Make sure to close the real socket, not the SSL proxy. # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, # it tries to renegotiate... - if isinstance(self.connection, SSL.Connection): + if not self.connection: + return + elif isinstance(self.connection, SSL.Connection): close_socket(self.connection._socket) else: close_socket(self.connection) @@ -674,6 +676,8 @@ class TCPClient(_Connection): sock.setsockopt(socket.SOL_IP, socket.IP_TRANSPARENT, 1) # pragma: windows no cover pragma: osx no cover except Exception as e: # socket.IP_TRANSPARENT might not be available on every OS and Python version + if sock is not None: + sock.close() raise exceptions.TcpException( "Failed to spoof the source address: " + str(e) ) @@ -864,6 +868,8 @@ class TCPServer: self.socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) self.socket.bind(self.address) except socket.error: + if self.socket: + self.socket.close() self.socket = None if not self.socket: diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py index 2191b54b..ace7ecde 100644 --- a/mitmproxy/proxy/protocol/http2.py +++ b/mitmproxy/proxy/protocol/http2.py @@ -206,14 +206,15 @@ class Http2Layer(base.Layer): return True def _handle_stream_reset(self, eid, event, is_server, other_conn): - self.streams[eid].kill() - if eid in self.streams and event.error_code == h2.errors.ErrorCodes.CANCEL: - if is_server: - other_stream_id = self.streams[eid].client_stream_id - else: - other_stream_id = self.streams[eid].server_stream_id - if other_stream_id is not None: - self.connections[other_conn].safe_reset_stream(other_stream_id, event.error_code) + if eid in self.streams: + self.streams[eid].kill() + if event.error_code == h2.errors.ErrorCodes.CANCEL: + if is_server: + other_stream_id = self.streams[eid].client_stream_id + else: + other_stream_id = self.streams[eid].server_stream_id + if other_stream_id is not None: + self.connections[other_conn].safe_reset_stream(other_stream_id, event.error_code) return True def _handle_remote_settings_changed(self, event, other_conn): diff --git a/mitmproxy/proxy/protocol/http_replay.py b/mitmproxy/proxy/protocol/http_replay.py index 25867871..1aa91847 100644 --- a/mitmproxy/proxy/protocol/http_replay.py +++ b/mitmproxy/proxy/protocol/http_replay.py @@ -83,7 +83,11 @@ class RequestReplayThread(basethread.BaseThread): server.wfile.write(http1.assemble_request(r)) server.wfile.flush() + + if self.f.server_conn: + self.f.server_conn.close() self.f.server_conn = server + self.f.response = http.HTTPResponse.wrap( http1.read_response( server.rfile, diff --git a/mitmproxy/proxy/protocol/tls.py b/mitmproxy/proxy/protocol/tls.py index f55855f0..d42c7fdd 100644 --- a/mitmproxy/proxy/protocol/tls.py +++ b/mitmproxy/proxy/protocol/tls.py @@ -1,10 +1,11 @@ import struct from typing import Optional # noqa from typing import Union +import io -import construct +from kaitaistruct import KaitaiStream from mitmproxy import exceptions -from mitmproxy.contrib import tls_parser +from mitmproxy.contrib.kaitaistruct import tls_client_hello from mitmproxy.proxy.protocol import base from mitmproxy.net import check @@ -263,7 +264,7 @@ def get_client_hello(client_conn): class TlsClientHello: def __init__(self, raw_client_hello): - self._client_hello = tls_parser.ClientHello.parse(raw_client_hello) + self._client_hello = tls_client_hello.TlsClientHello(KaitaiStream(io.BytesIO(raw_client_hello))) def raw(self): return self._client_hello @@ -278,12 +279,12 @@ class TlsClientHello: for extension in self._client_hello.extensions.extensions: is_valid_sni_extension = ( extension.type == 0x00 and - len(extension.server_names) == 1 and - extension.server_names[0].name_type == 0 and - check.is_valid_host(extension.server_names[0].host_name) + len(extension.body.server_names) == 1 and + extension.body.server_names[0].name_type == 0 and + check.is_valid_host(extension.body.server_names[0].host_name) ) if is_valid_sni_extension: - return extension.server_names[0].host_name.decode("idna") + return extension.body.server_names[0].host_name.decode("idna") return None @property @@ -291,7 +292,7 @@ class TlsClientHello: if self._client_hello.extensions: for extension in self._client_hello.extensions.extensions: if extension.type == 0x10: - return list(extension.alpn_protocols) + return list(extension.body.alpn_protocols) return [] @classmethod @@ -310,7 +311,7 @@ class TlsClientHello: try: return cls(raw_client_hello) - except construct.ConstructError as e: + except EOFError as e: raise exceptions.TlsProtocolException( 'Cannot parse Client Hello: %s, Raw Client Hello: %s' % (repr(e), raw_client_hello.encode("hex")) @@ -518,7 +519,8 @@ class TlsLayer(base.Layer): # We only support http/1.1 and h2. # If the server only supports spdy (next to http/1.1), it may select that # and mitmproxy would enter TCP passthrough mode, which we want to avoid. - alpn = [x for x in self._client_hello.alpn_protocols if not (x.startswith(b"h2-") or x.startswith(b"spdy"))] + alpn = [x.name for x in self._client_hello.alpn_protocols if + not (x.name.startswith(b"h2-") or x.name.startswith(b"spdy"))] if alpn and b"h2" in alpn and not self.config.options.http2: alpn.remove(b"h2") @@ -537,8 +539,8 @@ class TlsLayer(base.Layer): if not ciphers_server and self._client_tls: ciphers_server = [] for id in self._client_hello.cipher_suites: - if id in CIPHER_ID_NAME_MAP.keys(): - ciphers_server.append(CIPHER_ID_NAME_MAP[id]) + if id.cipher_suite in CIPHER_ID_NAME_MAP.keys(): + ciphers_server.append(CIPHER_ID_NAME_MAP[id.cipher_suite]) ciphers_server = ':'.join(ciphers_server) self.server_conn.establish_ssl( diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 50a2b76b..5171fbee 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -48,6 +48,8 @@ class ProxyServer(tcp.TCPServer): if config.options.mode == "transparent": platform.init_transparent_mode() except Exception as e: + if self.socket: + self.socket.close() raise exceptions.ServerException( 'Error starting proxy server: ' + repr(e) ) from e diff --git a/mitmproxy/tools/main.py b/mitmproxy/tools/main.py index d8fac077..84dab1fe 100644 --- a/mitmproxy/tools/main.py +++ b/mitmproxy/tools/main.py @@ -1,6 +1,8 @@ from __future__ import print_function # this is here for the version check to work on Python 2. import sys +# This must be at the very top, before importing anything else that might break! +# Keep all other imports below with the 'noqa' magic comment. if sys.version_info < (3, 5): print("#" * 49, file=sys.stderr) print("# mitmproxy only supports Python 3.5 and above! #", file=sys.stderr) @@ -13,8 +15,7 @@ from mitmproxy.tools import cmdline # noqa from mitmproxy import exceptions # noqa from mitmproxy import options # noqa from mitmproxy import optmanager # noqa -from mitmproxy.proxy import config # noqa -from mitmproxy.proxy import server # noqa +from mitmproxy import proxy # noqa from mitmproxy.utils import version_check # noqa from mitmproxy.utils import debug # noqa @@ -49,15 +50,7 @@ def process_options(parser, opts, args): adict[n] = getattr(args, n) opts.merge(adict) - pconf = config.ProxyConfig(opts) - if opts.server: - try: - return server.ProxyServer(pconf) - except exceptions.ServerException as v: - print(str(v), file=sys.stderr) - sys.exit(1) - else: - return server.DummyServer(pconf) + return proxy.config.ProxyConfig(opts) def run(MasterKlass, args, extra=None): # pragma: no cover @@ -74,7 +67,16 @@ def run(MasterKlass, args, extra=None): # pragma: no cover master = None try: unknown = optmanager.load_paths(opts, args.conf) - server = process_options(parser, opts, args) + pconf = process_options(parser, opts, args) + if pconf.options.server: + try: + server = proxy.server.ProxyServer(pconf) + except exceptions.ServerException as v: + print(str(v), file=sys.stderr) + sys.exit(1) + else: + server = proxy.server.DummyServer(pconf) + master = MasterKlass(opts, server) master.addons.trigger("configure", opts.keys()) master.addons.trigger("tick") diff --git a/pathod/language/actions.py b/pathod/language/actions.py index fc57a18b..3e48f40d 100644 --- a/pathod/language/actions.py +++ b/pathod/language/actions.py @@ -50,7 +50,7 @@ class _Action(base.Token): class PauseAt(_Action): - unique_name = None # type: ignore + unique_name = None def __init__(self, offset, seconds): _Action.__init__(self, offset) diff --git a/pathod/language/base.py b/pathod/language/base.py index c8892748..97871e7e 100644 --- a/pathod/language/base.py +++ b/pathod/language/base.py @@ -6,7 +6,8 @@ import pyparsing as pp from mitmproxy.utils import strutils from mitmproxy.utils import human import typing # noqa -from . import generators, exceptions +from . import generators +from . import exceptions class Settings: @@ -375,7 +376,7 @@ class OptionsOrValue(_Component): class Integer(_Component): - bounds = (None, None) # type: typing.Tuple[typing.Union[int, None], typing.Union[int , None]] + bounds = (None, None) # type: typing.Tuple[typing.Optional[int], typing.Optional[int]] preamble = "" def __init__(self, value): @@ -537,43 +538,3 @@ class IntField(_Component): def spec(self): return "%s%s" % (self.preamble, self.origvalue) - - -class NestedMessage(Token): - - """ - A nested message, as an escaped string with a preamble. - """ - preamble = "" - nest_type = None # type: ignore - - def __init__(self, value): - Token.__init__(self) - self.value = value - try: - self.parsed = self.nest_type( - self.nest_type.expr().parseString( - value.val.decode(), - parseAll=True - ) - ) - except pp.ParseException as v: - raise exceptions.ParseException(v.msg, v.line, v.col) - - @classmethod - def expr(cls): - e = pp.Literal(cls.preamble).suppress() - e = e + TokValueLiteral.expr() - return e.setParseAction(lambda x: cls(*x)) - - def values(self, settings): - return [ - self.value.get_generator(settings), - ] - - def spec(self): - return "%s%s" % (self.preamble, self.value.spec()) - - def freeze(self, settings): - f = self.parsed.freeze(settings).spec() - return self.__class__(TokValueLiteral(strutils.bytes_to_escaped_str(f.encode(), escape_single_quotes=True))) diff --git a/pathod/language/generators.py b/pathod/language/generators.py index d716804d..1961df74 100644 --- a/pathod/language/generators.py +++ b/pathod/language/generators.py @@ -1,7 +1,7 @@ +import os import string import random import mmap - import sys DATATYPES = dict( @@ -74,20 +74,20 @@ class RandomGenerator: class FileGenerator: - def __init__(self, path): self.path = path - self.fp = open(path, "rb") - self.map = mmap.mmap(self.fp.fileno(), 0, access=mmap.ACCESS_READ) def __len__(self): - return len(self.map) + return os.path.getsize(self.path) def __getitem__(self, x): - if isinstance(x, slice): - return self.map.__getitem__(x) - # A slice of length 1 returns a byte object (not an integer) - return self.map.__getitem__(slice(x, x + 1 or self.map.size())) + with open(self.path, mode="rb") as f: + if isinstance(x, slice): + with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mapped: + return mapped.__getitem__(x) + else: + f.seek(x) + return f.read(1) def __repr__(self): return "<%s" % self.path diff --git a/pathod/language/http.py b/pathod/language/http.py index 5cd717a9..5a962145 100644 --- a/pathod/language/http.py +++ b/pathod/language/http.py @@ -54,7 +54,9 @@ class Method(base.OptionsOrValue): class _HeaderMixin: - unique_name = None # type: ignore + @property + def unique_name(self): + return None def format_header(self, key, value): return [key, b": ", value, b"\r\n"] @@ -251,7 +253,7 @@ class Response(_HTTPMessage): return ":".join([i.spec() for i in self.tokens]) -class NestedResponse(base.NestedMessage): +class NestedResponse(message.NestedMessage): preamble = "s" nest_type = Response diff --git a/pathod/language/http2.py b/pathod/language/http2.py index 47d6e370..5b27d5bf 100644 --- a/pathod/language/http2.py +++ b/pathod/language/http2.py @@ -1,9 +1,9 @@ import pyparsing as pp + from mitmproxy.net import http from mitmproxy.net.http import user_agents, Headers from . import base, message - """ Normal HTTP requests: <method>:<path>:<header>:<body> @@ -41,7 +41,9 @@ def get_header(val, headers): class _HeaderMixin: - unique_name = None # type: ignore + @property + def unique_name(self): + return None def values(self, settings): return ( @@ -146,7 +148,7 @@ class Times(base.Integer): class Response(_HTTP2Message): - unique_name = None # type: ignore + unique_name = None comps = ( Header, Body, @@ -203,7 +205,7 @@ class Response(_HTTP2Message): return ":".join([i.spec() for i in self.tokens]) -class NestedResponse(base.NestedMessage): +class NestedResponse(message.NestedMessage): preamble = "s" nest_type = Response diff --git a/pathod/language/message.py b/pathod/language/message.py index 6b4c5021..5dda654b 100644 --- a/pathod/language/message.py +++ b/pathod/language/message.py @@ -1,8 +1,11 @@ import abc -from . import actions, exceptions -from mitmproxy.utils import strutils import typing # noqa +import pyparsing as pp + +from mitmproxy.utils import strutils +from . import actions, exceptions, base + LOG_TRUNCATE = 1024 @@ -96,3 +99,46 @@ class Message: def __repr__(self): return self.spec() + + +class NestedMessage(base.Token): + """ + A nested message, as an escaped string with a preamble. + """ + preamble = "" + nest_type = None # type: typing.Optional[typing.Type[Message]] + + def __init__(self, value): + super().__init__() + self.value = value + try: + self.parsed = self.nest_type( + self.nest_type.expr().parseString( + value.val.decode(), + parseAll=True + ) + ) + except pp.ParseException as v: + raise exceptions.ParseException(v.msg, v.line, v.col) + + @classmethod + def expr(cls): + e = pp.Literal(cls.preamble).suppress() + e = e + base.TokValueLiteral.expr() + return e.setParseAction(lambda x: cls(*x)) + + def values(self, settings): + return [ + self.value.get_generator(settings), + ] + + def spec(self): + return "%s%s" % (self.preamble, self.value.spec()) + + def freeze(self, settings): + f = self.parsed.freeze(settings).spec() + return self.__class__( + base.TokValueLiteral( + strutils.bytes_to_escaped_str(f.encode(), escape_single_quotes=True) + ) + ) diff --git a/pathod/language/websockets.py b/pathod/language/websockets.py index b4faf59b..cc00bcf1 100644 --- a/pathod/language/websockets.py +++ b/pathod/language/websockets.py @@ -1,10 +1,12 @@ import random import string +import typing # noqa + +import pyparsing as pp + import mitmproxy.net.websockets from mitmproxy.utils import strutils -import pyparsing as pp from . import base, generators, actions, message -import typing # noqa NESTED_LEADER = b"pathod!" @@ -74,7 +76,7 @@ class Times(base.Integer): preamble = "x" -COMPONENTS = ( +COMPONENTS = [ OpCode, Length, # Bit flags @@ -89,14 +91,13 @@ COMPONENTS = ( KeyNone, Key, Times, - Body, RawBody, -) +] class WebsocketFrame(message.Message): - components = COMPONENTS + components = COMPONENTS # type: typing.List[typing.Type[base._Component]] logattrs = ["body"] # Used for nested frames unique_name = "body" @@ -235,19 +236,10 @@ class WebsocketFrame(message.Message): return ":".join([i.spec() for i in self.tokens]) -class NestedFrame(base.NestedMessage): +class NestedFrame(message.NestedMessage): preamble = "f" nest_type = WebsocketFrame -COMP = typing.Tuple[ - typing.Type[OpCode], typing.Type[Length], typing.Type[Fin], typing.Type[RSV1], typing.Type[RSV2], typing.Type[RSV3], typing.Type[Mask], - typing.Type[actions.PauseAt], typing.Type[actions.DisconnectAt], typing.Type[actions.InjectAt], typing.Type[KeyNone], typing.Type[Key], - typing.Type[Times], typing.Type[Body], typing.Type[RawBody] -] - - class WebsocketClientFrame(WebsocketFrame): - components = typing.cast(COMP, COMPONENTS + ( - NestedFrame, - )) + components = COMPONENTS + [NestedFrame] diff --git a/pathod/pathod_cmdline.py b/pathod/pathod_cmdline.py index ef1e983f..dee19f4f 100644 --- a/pathod/pathod_cmdline.py +++ b/pathod/pathod_cmdline.py @@ -216,7 +216,8 @@ def args_pathod(argv, stdout_=sys.stdout, stderr_=sys.stderr): anchors = [] for patt, spec in args.anchors: if os.path.isfile(spec): - data = open(spec).read() + with open(spec) as f: + data = f.read() spec = data try: arex = re.compile(patt) diff --git a/release/setup.py b/release/setup.py index 01d0672d..17b02ebc 100644 --- a/release/setup.py +++ b/release/setup.py @@ -6,9 +6,9 @@ setup( py_modules=["rtool"], install_requires=[ "click>=6.2, <7.0", - "twine>=1.6.5, <1.9", + "twine>=1.6.5, <1.10", "pysftp==0.2.8", - "cryptography>=1.6, <1.7", + "cryptography>=1.6, <1.9", ], entry_points={ "console_scripts": [ diff --git a/requirements.txt b/requirements.txt index ab8e8a0b..28a0b495 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ --e .[dev,examples,contentviews] +-e .[dev,examples] @@ -21,7 +21,6 @@ exclude_lines = [tool:full_coverage] exclude = - mitmproxy/contentviews/wbxml.py mitmproxy/proxy/protocol/ mitmproxy/proxy/config.py mitmproxy/proxy/root_context.py @@ -39,7 +38,6 @@ exclude = mitmproxy/addons/onboardingapp/app.py mitmproxy/addons/termlog.py mitmproxy/contentviews/base.py - mitmproxy/contentviews/wbxml.py mitmproxy/controller.py mitmproxy/ctx.py mitmproxy/exceptions.py @@ -61,9 +61,9 @@ setup( # It is not considered best practice to use install_requires to pin dependencies to specific versions. install_requires=[ "blinker>=1.4, <1.5", - "click>=6.2, <7", + "brotlipy>=0.5.1, <0.7", "certifi>=2015.11.20.1", # no semver here - this should always be on the last release! - "construct>=2.8, <2.9", + "click>=6.2, <7", "cryptography>=1.4, <1.9", "cssutils>=1.0.1, <1.1", "h2>=3.0, <4", @@ -79,37 +79,29 @@ setup( "pyperclip>=1.5.22, <1.6", "requests>=2.9.1, <3", "ruamel.yaml>=0.13.2, <0.15", + "sortedcontainers>=1.5.4, <1.6", "tornado>=4.3, <4.6", "urwid>=1.3.1, <1.4", - "brotlipy>=0.5.1, <0.7", - "sortedcontainers>=1.5.4, <1.6", - # transitive from cryptography, we just blacklist here. - # https://github.com/pypa/setuptools/issues/861 - "setuptools>=11.3, !=29.0.0", ], extras_require={ ':sys_platform == "win32"': [ "pydivert>=2.0.3, <2.1", ], - ':sys_platform != "win32"': [ - ], 'dev': [ - "Flask>=0.10.1, <0.13", "flake8>=3.2.1, <3.4", - "mypy>=0.501, <0.502", - "rstcheck>=2.2, <4.0", - "tox>=2.3, <3", - "pytest>=3, <3.1", + "Flask>=0.10.1, <0.13", + "mypy>=0.501, <0.512", "pytest-cov>=2.2.1, <3", + "pytest-faulthandler>=1.3.0, <2", "pytest-timeout>=1.0.0, <2", "pytest-xdist>=1.14, <2", - "pytest-faulthandler>=1.3.0, <2", - "sphinx>=1.3.5, <1.7", + "pytest>=3.1, <4", + "rstcheck>=2.2, <4.0", + "sphinx_rtd_theme>=0.1.9, <0.3", "sphinx-autobuild>=0.5.2, <0.7", + "sphinx>=1.3.5, <1.7", "sphinxcontrib-documentedlist>=0.5.0, <0.7", - "sphinx_rtd_theme>=0.1.9, <0.3", - ], - 'contentviews': [ + "tox>=2.3, <3", ], 'examples': [ "beautifulsoup4>=4.4.1, <4.7", diff --git a/test/conftest.py b/test/conftest.py index b4e1da93..bb913548 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,8 +1,6 @@ import os import pytest import OpenSSL -import functools -from contextlib import contextmanager import mitmproxy.net.tcp @@ -32,21 +30,3 @@ skip_appveyor = pytest.mark.skipif( def disable_alpn(monkeypatch): monkeypatch.setattr(mitmproxy.net.tcp, 'HAS_ALPN', False) monkeypatch.setattr(OpenSSL.SSL._lib, 'Cryptography_HAS_ALPN', False) - - -################################################################################ -# TODO: remove this wrapper when pytest 3.1.0 is released -original_pytest_raises = pytest.raises - - -@contextmanager -@functools.wraps(original_pytest_raises) -def raises(exc, *args, **kwargs): - with original_pytest_raises(exc, *args, **kwargs) as exc_info: - yield - if 'match' in kwargs: - assert exc_info.match(kwargs['match']) - - -pytest.raises = raises -################################################################################ diff --git a/test/mitmproxy/addons/test_clientplayback.py b/test/mitmproxy/addons/test_clientplayback.py index 7ffda317..6089b2d5 100644 --- a/test/mitmproxy/addons/test_clientplayback.py +++ b/test/mitmproxy/addons/test_clientplayback.py @@ -10,9 +10,10 @@ from mitmproxy.test import taddons def tdump(path, flows): - w = io.FlowWriter(open(path, "wb")) - for i in flows: - w.add(i) + with open(path, "wb") as f: + w = io.FlowWriter(f) + for i in flows: + w.add(i) class MockThread(): diff --git a/test/mitmproxy/addons/test_save.py b/test/mitmproxy/addons/test_save.py index 85c2a398..a4e425cd 100644 --- a/test/mitmproxy/addons/test_save.py +++ b/test/mitmproxy/addons/test_save.py @@ -26,8 +26,9 @@ def test_configure(tmpdir): def rd(p): - x = io.FlowReader(open(p, "rb")) - return list(x.stream()) + with open(p, "rb") as f: + x = io.FlowReader(f) + return list(x.stream()) def test_tcp(tmpdir): diff --git a/test/mitmproxy/addons/test_serverplayback.py b/test/mitmproxy/addons/test_serverplayback.py index 3ceab3fa..7605a5d9 100644 --- a/test/mitmproxy/addons/test_serverplayback.py +++ b/test/mitmproxy/addons/test_serverplayback.py @@ -11,9 +11,10 @@ from mitmproxy import io def tdump(path, flows): - w = io.FlowWriter(open(path, "wb")) - for i in flows: - w.add(i) + with open(path, "wb") as f: + w = io.FlowWriter(f) + for i in flows: + w.add(i) def test_load_file(tmpdir): diff --git a/test/mitmproxy/addons/test_view.py b/test/mitmproxy/addons/test_view.py index 6da13650..d5a3a456 100644 --- a/test/mitmproxy/addons/test_view.py +++ b/test/mitmproxy/addons/test_view.py @@ -132,9 +132,10 @@ def test_filter(): def tdump(path, flows): - w = io.FlowWriter(open(path, "wb")) - for i in flows: - w.add(i) + with open(path, "wb") as f: + w = io.FlowWriter(f) + for i in flows: + w.add(i) def test_create(): diff --git a/test/mitmproxy/contentviews/test_protobuf.py b/test/mitmproxy/contentviews/test_protobuf.py index 31e382ec..71e51576 100644 --- a/test/mitmproxy/contentviews/test_protobuf.py +++ b/test/mitmproxy/contentviews/test_protobuf.py @@ -17,7 +17,9 @@ def test_view_protobuf_request(): m.configure_mock(**attrs) n.return_value = m - content_type, output = v(open(p, "rb").read()) + with open(p, "rb") as f: + data = f.read() + content_type, output = v(data) assert content_type == "Protobuf" assert output[0] == [('text', b'1: "3bbc333c-e61c-433b-819a-0b9a8cc103b8"')] diff --git a/test/mitmproxy/contentviews/test_wbxml.py b/test/mitmproxy/contentviews/test_wbxml.py index 777ab4dd..09c770e7 100644 --- a/test/mitmproxy/contentviews/test_wbxml.py +++ b/test/mitmproxy/contentviews/test_wbxml.py @@ -1 +1,21 @@ -# TODO: write tests +from mitmproxy.contentviews import wbxml +from mitmproxy.test import tutils +from . import full_eval + +data = tutils.test_data.push("mitmproxy/contentviews/test_wbxml_data/") + + +def test_wbxml(): + v = full_eval(wbxml.ViewWBXML()) + + assert v(b'\x03\x01\x6A\x00') == ('WBXML', [[('text', '<?xml version="1.0" ?>')]]) + assert v(b'foo') is None + + path = data.path("data.wbxml") # File taken from https://github.com/davidpshaw/PyWBXMLDecoder/tree/master/wbxml_samples + with open(path, 'rb') as f: + input = f.read() + with open("-formatted.".join(path.rsplit(".", 1))) as f: + expected = f.read() + + p = wbxml.ASCommandResponse.ASCommandResponse(input) + assert p.xmlString == expected diff --git a/test/mitmproxy/contentviews/test_wbxml_data/data-formatted.wbxml b/test/mitmproxy/contentviews/test_wbxml_data/data-formatted.wbxml new file mode 100644 index 00000000..fed293bd --- /dev/null +++ b/test/mitmproxy/contentviews/test_wbxml_data/data-formatted.wbxml @@ -0,0 +1,10 @@ +<?xml version="1.0" ?> +<Sync> + <Collections> + <Collection> + <SyncKey>1509029063</SyncKey> + <CollectionId>7</CollectionId> + <Status>1</Status> + </Collection> + </Collections> +</Sync> diff --git a/test/mitmproxy/contentviews/test_wbxml_data/data.wbxml b/test/mitmproxy/contentviews/test_wbxml_data/data.wbxml Binary files differnew file mode 100644 index 00000000..7c7a2004 --- /dev/null +++ b/test/mitmproxy/contentviews/test_wbxml_data/data.wbxml diff --git a/test/mitmproxy/contrib/test_tls_parser.py b/test/mitmproxy/contrib/test_tls_parser.py deleted file mode 100644 index 66972b62..00000000 --- a/test/mitmproxy/contrib/test_tls_parser.py +++ /dev/null @@ -1,38 +0,0 @@ -from mitmproxy.contrib import tls_parser - - -def test_parse_chrome(): - """ - Test if we properly parse a ClientHello sent by Chrome 54. - """ - data = bytes.fromhex( - "03033b70638d2523e1cba15f8364868295305e9c52aceabda4b5147210abc783e6e1000022c02bc02fc02cc030" - "cca9cca8cc14cc13c009c013c00ac014009c009d002f0035000a0100006cff0100010000000010000e00000b65" - "78616d706c652e636f6d0017000000230000000d00120010060106030501050304010403020102030005000501" - "00000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a00080006001d00" - "170018" - ) - c = tls_parser.ClientHello.parse(data) - assert c.version.major == 3 - assert c.version.minor == 3 - - alpn = [a for a in c.extensions.extensions if a.type == 16] - assert len(alpn) == 1 - assert alpn[0].alpn_protocols == [b"h2", b"http/1.1"] - - sni = [a for a in c.extensions.extensions if a.type == 0] - assert len(sni) == 1 - assert sni[0].server_names[0].name_type == 0 - assert sni[0].server_names[0].host_name == b"example.com" - - -def test_parse_no_extensions(): - data = bytes.fromhex( - "03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637" - "78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000" - "61006200640100" - ) - c = tls_parser.ClientHello.parse(data) - assert c.version.major == 3 - assert c.version.minor == 1 - assert c.extensions is None diff --git a/test/mitmproxy/net/test_tcp.py b/test/mitmproxy/net/test_tcp.py index 81d51888..adf8701a 100644 --- a/test/mitmproxy/net/test_tcp.py +++ b/test/mitmproxy/net/test_tcp.py @@ -34,7 +34,7 @@ class ClientCipherListHandler(tcp.BaseHandler): sni = None def handle(self): - self.wfile.write("%s" % self.connection.get_cipher_list()) + self.wfile.write(str(self.connection.get_cipher_list()).encode()) self.wfile.flush() @@ -398,7 +398,8 @@ class TestServerCipherList(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): c.convert_to_ssl(sni="foo.com") - assert c.rfile.readline() == b"['AES256-GCM-SHA384']" + expected = b"['AES256-GCM-SHA384']" + assert c.rfile.read(len(expected) + 2) == expected class TestServerCurrentCipher(tservers.ServerTestBase): @@ -424,7 +425,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase): class TestServerCipherListError(tservers.ServerTestBase): handler = ClientCipherListHandler ssl = dict( - cipher_list='bogus' + cipher_list=b'bogus' ) def test_echo(self): @@ -632,6 +633,7 @@ class TestTCPServer: with s.handler_counter: with pytest.raises(exceptions.Timeout): s.wait_for_silence() + s.shutdown() class TestFileLike: diff --git a/test/mitmproxy/net/tservers.py b/test/mitmproxy/net/tservers.py index ebe6d3eb..44701aa5 100644 --- a/test/mitmproxy/net/tservers.py +++ b/test/mitmproxy/net/tservers.py @@ -16,9 +16,6 @@ class _ServerThread(threading.Thread): def run(self): self.server.serve_forever() - def shutdown(self): - self.server.shutdown() - class _TServer(tcp.TCPServer): @@ -54,9 +51,9 @@ class _TServer(tcp.TCPServer): raw_key = self.ssl.get( "key", tutils.test_data.path("mitmproxy/net/data/server.key")) - key = OpenSSL.crypto.load_privatekey( - OpenSSL.crypto.FILETYPE_PEM, - open(raw_key, "rb").read()) + with open(raw_key) as f: + raw_key = f.read() + key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw_key) if self.ssl.get("v3_only", False): method = OpenSSL.SSL.SSLv3_METHOD options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1 @@ -64,7 +61,8 @@ class _TServer(tcp.TCPServer): method = OpenSSL.SSL.SSLv23_METHOD options = None h.convert_to_ssl( - cert, key, + cert, + key, method=method, options=options, handle_sni=getattr(h, "handle_sni", None), @@ -103,7 +101,7 @@ class ServerTestBase: @classmethod def teardown_class(cls): - cls.server.shutdown() + cls.server.server.shutdown() def teardown(self): self.server.server.wait_for_silence() diff --git a/test/mitmproxy/platform/test_pf.py b/test/mitmproxy/platform/test_pf.py index f644bcc5..3292d345 100644 --- a/test/mitmproxy/platform/test_pf.py +++ b/test/mitmproxy/platform/test_pf.py @@ -9,10 +9,11 @@ class TestLookup: def test_simple(self): if sys.platform == "freebsd10": p = tutils.test_data.path("mitmproxy/data/pf02") - d = open(p, "rb").read() else: p = tutils.test_data.path("mitmproxy/data/pf01") - d = open(p, "rb").read() + with open(p, "rb") as f: + d = f.read() + assert pf.lookup("192.168.1.111", 40000, d) == ("5.5.5.5", 80) with pytest.raises(Exception, match="Could not resolve original destination"): pf.lookup("192.168.1.112", 40000, d) diff --git a/test/mitmproxy/proxy/protocol/test_http1.py b/test/mitmproxy/proxy/protocol/test_http1.py index 07cd7dcc..b642afb3 100644 --- a/test/mitmproxy/proxy/protocol/test_http1.py +++ b/test/mitmproxy/proxy/protocol/test_http1.py @@ -65,6 +65,7 @@ class TestExpectHeader(tservers.HTTPProxyTest): assert resp.status_code == 200 client.finish() + client.close() class TestHeadContentLength(tservers.HTTPProxyTest): diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py index b07257b3..261f8415 100644 --- a/test/mitmproxy/proxy/protocol/test_http2.py +++ b/test/mitmproxy/proxy/protocol/test_http2.py @@ -118,12 +118,16 @@ class _Http2TestBase: self.master.reset([]) self.server.server.handle_server_event = self.handle_server_event - def _setup_connection(self): - client = mitmproxy.net.tcp.TCPClient(("127.0.0.1", self.proxy.port)) - client.connect() + def teardown(self): + if self.client: + self.client.close() + + def setup_connection(self): + self.client = mitmproxy.net.tcp.TCPClient(("127.0.0.1", self.proxy.port)) + self.client.connect() # send CONNECT request - client.wfile.write(http1.assemble_request(mitmproxy.net.http.Request( + self.client.wfile.write(http1.assemble_request(mitmproxy.net.http.Request( 'authority', b'CONNECT', b'', @@ -134,13 +138,13 @@ class _Http2TestBase: [(b'host', b'localhost:%d' % self.server.server.address[1])], b'', ))) - client.wfile.flush() + self.client.wfile.flush() # read CONNECT response - while client.rfile.readline() != b"\r\n": + while self.client.rfile.readline() != b"\r\n": pass - client.convert_to_ssl(alpn_protos=[b'h2']) + self.client.convert_to_ssl(alpn_protos=[b'h2']) config = h2.config.H2Configuration( client_side=True, @@ -148,10 +152,10 @@ class _Http2TestBase: validate_inbound_headers=False) h2_conn = h2.connection.H2Connection(config) h2_conn.initiate_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() - return client, h2_conn + return h2_conn def _send_request(self, wfile, @@ -205,8 +209,8 @@ class TestSimple(_Http2Test): if isinstance(event, h2.events.ConnectionTerminated): return False elif isinstance(event, h2.events.RequestReceived): - assert (b'client-foo', b'client-bar-1') in event.headers - assert (b'client-foo', b'client-bar-2') in event.headers + assert (b'self.client-foo', b'self.client-bar-1') in event.headers + assert (b'self.client-foo', b'self.client-bar-2') in event.headers elif isinstance(event, h2.events.StreamEnded): import warnings with warnings.catch_warnings(): @@ -233,32 +237,32 @@ class TestSimple(_Http2Test): def test_simple(self): response_body_buffer = b'' - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), - ('ClIeNt-FoO', 'client-bar-1'), - ('ClIeNt-FoO', 'client-bar-2'), + ('self.client-FoO', 'self.client-bar-1'), + ('self.client-FoO', 'self.client-bar-2'), ], body=b'request body') done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.DataReceived): @@ -267,8 +271,8 @@ class TestSimple(_Http2Test): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 1 assert self.master.state.flows[0].response.status_code == 200 @@ -317,10 +321,10 @@ class TestRequestWithPriority(_Http2Test): def test_request_with_priority(self, http2_priority_enabled, priority, expected_priority): self.config.options.http2_priority = http2_priority_enabled - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), @@ -336,22 +340,22 @@ class TestRequestWithPriority(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 1 @@ -397,15 +401,15 @@ class TestPriority(_Http2Test): self.config.options.http2_priority = http2_priority_enabled self.__class__.priority_data = [] - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() if prioritize_before: h2_conn.prioritize(1, exclusive=priority[0], depends_on=priority[1], weight=priority[2]) - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), @@ -419,28 +423,28 @@ class TestPriority(_Http2Test): if not prioritize_before: h2_conn.prioritize(1, exclusive=priority[0], depends_on=priority[1], weight=priority[2]) h2_conn.end_stream(1) - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 1 assert self.priority_data == expected_priority @@ -460,10 +464,10 @@ class TestStreamResetFromServer(_Http2Test): return True def test_request_with_priority(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), @@ -476,22 +480,22 @@ class TestStreamResetFromServer(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamReset): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 1 assert self.master.state.flows[0].response is None @@ -510,10 +514,10 @@ class TestBodySizeLimit(_Http2Test): self.config.options.body_size_limit = "20" self.config.options._processed["body_size_limit"] = 20 - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() self._send_request( - client.wfile, + self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), @@ -527,22 +531,22 @@ class TestBodySizeLimit(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamReset): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == 0 @@ -609,9 +613,9 @@ class TestPushPromise(_Http2Test): return True def test_push_promise(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() - self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + self._send_request(self.client.wfile, h2_conn, stream_id=1, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -625,15 +629,15 @@ class TestPushPromise(_Http2Test): responses = 0 while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False except: break - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded): @@ -649,8 +653,8 @@ class TestPushPromise(_Http2Test): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert ended_streams == 3 assert pushed_streams == 2 @@ -665,9 +669,9 @@ class TestPushPromise(_Http2Test): assert len(pushed_flows) == 2 def test_push_promise_reset(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() - self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + self._send_request(self.client.wfile, h2_conn, stream_id=1, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -681,14 +685,14 @@ class TestPushPromise(_Http2Test): responses = 0 while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) assert False - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1: @@ -696,8 +700,8 @@ class TestPushPromise(_Http2Test): elif isinstance(event, h2.events.PushedStreamReceived): pushed_streams += 1 h2_conn.reset_stream(event.pushed_stream_id, error_code=0x8) - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() elif isinstance(event, h2.events.ResponseReceived): responses += 1 if isinstance(event, h2.events.ConnectionTerminated): @@ -707,8 +711,8 @@ class TestPushPromise(_Http2Test): done = True h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() bodies = [flow.response.content for flow in self.master.state.flows if flow.response] assert len(bodies) >= 1 @@ -728,9 +732,9 @@ class TestConnectionLost(_Http2Test): return False def test_connection_lost(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() - self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ + self._send_request(self.client.wfile, h2_conn, stream_id=1, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -741,7 +745,7 @@ class TestConnectionLost(_Http2Test): done = False while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) h2_conn.receive_data(raw) except exceptions.HttpException: print(traceback.format_exc()) @@ -749,8 +753,8 @@ class TestConnectionLost(_Http2Test): except: break try: - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() except: break @@ -782,12 +786,12 @@ class TestMaxConcurrentStreams(_Http2Test): return True def test_max_concurrent_streams(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() new_streams = [1, 3, 5, 7, 9, 11] for stream_id in new_streams: # this will exceed MAX_CONCURRENT_STREAMS on the server connection # and cause mitmproxy to throttle stream creation to the server - self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[ + self._send_request(self.client.wfile, h2_conn, stream_id=stream_id, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -798,20 +802,20 @@ class TestMaxConcurrentStreams(_Http2Test): ended_streams = 0 while ended_streams != len(new_streams): try: - header, body = http2.read_raw_frame(client.rfile) + header, body = http2.read_raw_frame(self.client.rfile) events = h2_conn.receive_data(b''.join([header, body])) except: break - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() for event in events: if isinstance(event, h2.events.StreamEnded): ended_streams += 1 h2_conn.close_connection() - client.wfile.write(h2_conn.data_to_send()) - client.wfile.flush() + self.client.wfile.write(h2_conn.data_to_send()) + self.client.wfile.flush() assert len(self.master.state.flows) == len(new_streams) for flow in self.master.state.flows: @@ -831,9 +835,9 @@ class TestConnectionTerminated(_Http2Test): return True def test_connection_terminated(self): - client, h2_conn = self._setup_connection() + h2_conn = self.setup_connection() - self._send_request(client.wfile, h2_conn, headers=[ + self._send_request(self.client.wfile, h2_conn, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), (':method', 'GET'), (':scheme', 'https'), @@ -844,7 +848,7 @@ class TestConnectionTerminated(_Http2Test): connection_terminated_event = None while not done: try: - raw = b''.join(http2.read_raw_frame(client.rfile)) + raw = b''.join(http2.read_raw_frame(self.client.rfile)) events = h2_conn.receive_data(raw) for event in events: if isinstance(event, h2.events.ConnectionTerminated): diff --git a/test/mitmproxy/proxy/protocol/test_tls.py b/test/mitmproxy/proxy/protocol/test_tls.py index e17ee46f..980ba7bd 100644 --- a/test/mitmproxy/proxy/protocol/test_tls.py +++ b/test/mitmproxy/proxy/protocol/test_tls.py @@ -23,4 +23,5 @@ class TestClientHello: ) c = TlsClientHello(data) assert c.sni == 'example.com' - assert c.alpn_protocols == [b'h2', b'http/1.1'] + assert c.alpn_protocols[0].name == b'h2' + assert c.alpn_protocols[1].name == b'http/1.1' diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index 8dfc4f2b..f78e173f 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -79,9 +79,13 @@ class _WebSocketTestBase: self.master.reset([]) self.server.server.handle_websockets = self.handle_websockets - def _setup_connection(self): - client = tcp.TCPClient(("127.0.0.1", self.proxy.port)) - client.connect() + def teardown(self): + if self.client: + self.client.close() + + def setup_connection(self): + self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port)) + self.client.connect() request = http.Request( "authority", @@ -92,14 +96,14 @@ class _WebSocketTestBase: "", "HTTP/1.1", content=b'') - client.wfile.write(http.http1.assemble_request(request)) - client.wfile.flush() + self.client.wfile.write(http.http1.assemble_request(request)) + self.client.wfile.flush() - response = http.http1.read_response(client.rfile, request) + response = http.http1.read_response(self.client.rfile, request) if self.ssl: - client.convert_to_ssl() - assert client.ssl_established + self.client.convert_to_ssl() + assert self.client.ssl_established request = http.Request( "relative", @@ -116,14 +120,12 @@ class _WebSocketTestBase: sec_websocket_key="1234", ), content=b'') - client.wfile.write(http.http1.assemble_request(request)) - client.wfile.flush() + self.client.wfile.write(http.http1.assemble_request(request)) + self.client.wfile.flush() - response = http.http1.read_response(client.rfile, request) + response = http.http1.read_response(self.client.rfile, request) assert websockets.check_handshake(response.headers) - return client - class _WebSocketTest(_WebSocketTestBase, _WebSocketServerBase): @@ -154,25 +156,25 @@ class TestSimple(_WebSocketTest): wfile.flush() def test_simple(self): - client = self._setup_connection() + self.setup_connection() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) - assert frame.payload == b'client-foobar' + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'self.client-foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef'))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'\xde\xad\xbe\xef' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() assert len(self.master.state.flows) == 2 assert isinstance(self.master.state.flows[0], HTTPFlow) @@ -180,9 +182,9 @@ class TestSimple(_WebSocketTest): assert len(self.master.state.flows[1].messages) == 5 assert self.master.state.flows[1].messages[0].content == b'server-foobar' assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[1].content == b'client-foobar' + assert self.master.state.flows[1].messages[1].content == b'self.client-foobar' assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT - assert self.master.state.flows[1].messages[2].content == b'client-foobar' + assert self.master.state.flows[1].messages[2].content == b'self.client-foobar' assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef' assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY @@ -203,19 +205,19 @@ class TestSimpleTLS(_WebSocketTest): wfile.flush() def test_simple_tls(self): - client = self._setup_connection() + self.setup_connection() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.payload == b'server-foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar'))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) - assert frame.payload == b'client-foobar' + frame = websockets.Frame.from_file(self.client.rfile) + assert frame.payload == b'self.client-foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() class TestPing(_WebSocketTest): @@ -233,16 +235,16 @@ class TestPing(_WebSocketTest): wfile.flush() def test_ping(self): - client = self._setup_connection() + self.setup_connection() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.header.opcode == websockets.OPCODE.PING assert frame.payload == b'foobar' - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.header.opcode == websockets.OPCODE.TEXT assert frame.payload == b'pong-received' @@ -259,12 +261,12 @@ class TestPong(_WebSocketTest): wfile.flush() def test_pong(self): - client = self._setup_connection() + self.setup_connection() - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + self.client.wfile.flush() - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.header.opcode == websockets.OPCODE.PONG assert frame.payload == b'foobar' @@ -282,34 +284,34 @@ class TestClose(_WebSocketTest): websockets.Frame.from_file(rfile) def test_close(self): - client = self._setup_connection() + self.setup_connection() - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + self.client.wfile.flush() - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) def test_close_payload_1(self): - client = self._setup_connection() + self.setup_connection() - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) + self.client.wfile.flush() - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) def test_close_payload_2(self): - client = self._setup_connection() + self.setup_connection() - client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) - client.wfile.flush() + self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) + self.client.wfile.flush() - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) with pytest.raises(exceptions.TcpDisconnect): - websockets.Frame.from_file(client.rfile) + websockets.Frame.from_file(self.client.rfile) class TestInvalidFrame(_WebSocketTest): @@ -320,9 +322,9 @@ class TestInvalidFrame(_WebSocketTest): wfile.flush() def test_invalid_frame(self): - client = self._setup_connection() + self.setup_connection() # with pytest.raises(exceptions.TcpDisconnect): - frame = websockets.Frame.from_file(client.rfile) + frame = websockets.Frame.from_file(self.client.rfile) assert frame.header.opcode == 15 assert frame.payload == b'foobar' diff --git a/test/mitmproxy/test_connections.py b/test/mitmproxy/test_connections.py index e320885d..99367bb6 100644 --- a/test/mitmproxy/test_connections.py +++ b/test/mitmproxy/test_connections.py @@ -1,5 +1,4 @@ import socket -import os import threading import ssl import OpenSSL @@ -140,16 +139,17 @@ class TestServerConnection: assert d.last_log() c.finish() + c.close() d.shutdown() def test_terminate_error(self): d = test.Daemon() c = connections.ServerConnection((d.IFACE, d.port)) c.connect() + c.close() c.connection = mock.Mock() c.connection.recv = mock.Mock(return_value=False) c.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect) - c.finish() d.shutdown() def test_sni(self): @@ -194,22 +194,25 @@ class TestClientConnectionTLS: s = socket.create_connection(address) s = ctx.wrap_socket(s, server_hostname=sni) s.send(b'foobar') - s.shutdown(socket.SHUT_RDWR) + s.close() threading.Thread(target=client_run).start() connection, client_address = sock.accept() c = connections.ClientConnection(connection, client_address, None) cert = tutils.test_data.path("mitmproxy/net/data/server.crt") + with open(tutils.test_data.path("mitmproxy/net/data/server.key")) as f: + raw_key = f.read() key = OpenSSL.crypto.load_privatekey( OpenSSL.crypto.FILETYPE_PEM, - open(tutils.test_data.path("mitmproxy/net/data/server.key"), "rb").read()) + raw_key) c.convert_to_ssl(cert, key) assert c.connected() assert c.sni == sni assert c.tls_established assert c.rfile.read(6) == b'foobar' c.finish() + sock.close() class TestServerConnectionTLS(tservers.ServerTestBase): @@ -222,7 +225,7 @@ class TestServerConnectionTLS(tservers.ServerTestBase): @pytest.mark.parametrize("clientcert", [ None, tutils.test_data.path("mitmproxy/data/clientcert"), - os.path.join(tutils.test_data.path("mitmproxy/data/clientcert"), "client.pem"), + tutils.test_data.path("mitmproxy/data/clientcert/client.pem"), ]) def test_tls(self, clientcert): c = connections.ServerConnection(("127.0.0.1", self.port)) diff --git a/test/mitmproxy/test_flowfilter.py b/test/mitmproxy/test_flowfilter.py index 46fff477..fe9b2408 100644 --- a/test/mitmproxy/test_flowfilter.py +++ b/test/mitmproxy/test_flowfilter.py @@ -209,6 +209,9 @@ class TestMatchingHTTPFlow: assert self.q("~u address:22/path", q) assert not self.q("~u moo/path", q) + q.request = None + assert not self.q("~u address", q) + assert self.q("~u address", s) assert self.q("~u address:22/path", s) assert not self.q("~u moo/path", s) diff --git a/test/mitmproxy/test_proxy.py b/test/mitmproxy/test_proxy.py index e1d0da00..299abab3 100644 --- a/test/mitmproxy/test_proxy.py +++ b/test/mitmproxy/test_proxy.py @@ -32,8 +32,7 @@ class TestProcessProxyOptions: opts = options.Options() cmdline.common_options(parser, opts) args = parser.parse_args(args=args) - main.process_options(parser, opts, args) - pconf = config.ProxyConfig(opts) + pconf = main.process_options(parser, opts, args) return parser, pconf def assert_noerr(self, *args): diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py index b8005529..3a2050e1 100644 --- a/test/mitmproxy/tservers.py +++ b/test/mitmproxy/tservers.py @@ -133,7 +133,7 @@ class ProxyTestBase: @classmethod def teardown_class(cls): - # perf: we want to run tests in parallell + # perf: we want to run tests in parallel # should this ever cause an error, travis should catch it. # shutil.rmtree(cls.cadir) cls.proxy.shutdown() diff --git a/test/pathod/language/test_base.py b/test/pathod/language/test_base.py index ec460b07..910d298a 100644 --- a/test/pathod/language/test_base.py +++ b/test/pathod/language/test_base.py @@ -202,12 +202,14 @@ class TestMisc: e.parseString("m@1") s = base.Settings(staticdir=str(tmpdir)) - tmpdir.join("path").write_binary(b"a" * 20, ensure=True) + with open(str(tmpdir.join("path")), 'wb') as f: + f.write(b"a" * 20) v = e.parseString("m<path")[0] with pytest.raises(Exception, match="Invalid value length"): v.values(s) - tmpdir.join("path2").write_binary(b"a" * 4, ensure=True) + with open(str(tmpdir.join("path2")), 'wb') as f: + f.write(b"a" * 4) v = e.parseString("m<path2")[0] assert v.values(s) diff --git a/test/pathod/language/test_generators.py b/test/pathod/language/test_generators.py index dc15aaa1..5e64c726 100644 --- a/test/pathod/language/test_generators.py +++ b/test/pathod/language/test_generators.py @@ -14,18 +14,14 @@ def test_randomgenerator(): def test_filegenerator(tmpdir): f = tmpdir.join("foo") - f.write(b"x" * 10000) + f.write(b"abcdefghijklmnopqrstuvwxyz" * 1000) g = generators.FileGenerator(str(f)) - assert len(g) == 10000 - assert g[0] == b"x" - assert g[-1] == b"x" - assert g[0:5] == b"xxxxx" + assert len(g) == 26000 + assert g[0] == b"a" + assert g[2:7] == b"cdefg" assert len(g[1:10]) == 9 - assert len(g[10000:10001]) == 0 + assert len(g[26000:26001]) == 0 assert repr(g) - # remove all references to FileGenerator instance to close the file - # handle. - del g def test_transform_generator(): diff --git a/test/pathod/protocols/test_http2.py b/test/pathod/protocols/test_http2.py index 1c074197..c16a6d40 100644 --- a/test/pathod/protocols/test_http2.py +++ b/test/pathod/protocols/test_http2.py @@ -202,7 +202,7 @@ class TestApplySettings(net_tservers.ServerTestBase): def handle(self): # check settings acknowledgement assert self.rfile.read(9) == codecs.decode('000000040100000000', 'hex_codec') - self.wfile.write("OK") + self.wfile.write(b"OK") self.wfile.flush() self.rfile.safe_read(9) # just to keep the connection alive a bit longer diff --git a/test/pathod/test_test.py b/test/pathod/test_test.py index 40f45f53..d51a2c7a 100644 --- a/test/pathod/test_test.py +++ b/test/pathod/test_test.py @@ -1,15 +1,9 @@ -import logging +import os import requests import pytest from pathod import test - -from mitmproxy.test import tutils - -import requests.packages.urllib3 - -requests.packages.urllib3.disable_warnings() -logging.disable(logging.CRITICAL) +from pathod.pathod import SSLOptions, CA_CERT_NAME class TestDaemonManual: @@ -22,29 +16,17 @@ class TestDaemonManual: with pytest.raises(requests.ConnectionError): requests.get("http://localhost:%s/p/202:da" % d.port) - def test_startstop_ssl(self): - d = test.Daemon(ssl=True) - rsp = requests.get( - "https://localhost:%s/p/202:da" % - d.port, - verify=False) - assert rsp.ok - assert rsp.status_code == 202 - d.shutdown() - with pytest.raises(requests.ConnectionError): - requests.get("http://localhost:%s/p/202:da" % d.port) - - def test_startstop_ssl_explicit(self): - ssloptions = dict( - certfile=tutils.test_data.path("pathod/data/testkey.pem"), - cacert=tutils.test_data.path("pathod/data/testkey.pem"), - ssl_after_connect=False + @pytest.mark.parametrize('not_after_connect', [True, False]) + def test_startstop_ssl(self, not_after_connect): + ssloptions = SSLOptions( + cn=b'localhost', + sans=[b'localhost', b'127.0.0.1'], + not_after_connect=not_after_connect, ) - d = test.Daemon(ssl=ssloptions) + d = test.Daemon(ssl=True, ssloptions=ssloptions) rsp = requests.get( - "https://localhost:%s/p/202:da" % - d.port, - verify=False) + "https://localhost:%s/p/202:da" % d.port, + verify=os.path.expanduser(os.path.join(d.thread.server.ssloptions.confdir, CA_CERT_NAME))) assert rsp.ok assert rsp.status_code == 202 d.shutdown() diff --git a/test/pathod/tservers.py b/test/pathod/tservers.py index fab09288..a7c92964 100644 --- a/test/pathod/tservers.py +++ b/test/pathod/tservers.py @@ -1,3 +1,4 @@ +import os import tempfile import re import shutil @@ -13,6 +14,7 @@ from pathod import language from pathod import pathoc from pathod import pathod from pathod import test +from pathod.pathod import CA_CERT_NAME def treader(bytes): @@ -72,7 +74,7 @@ class DaemonTests: self.d.port, path ), - verify=False, + verify=os.path.join(self.d.thread.server.ssloptions.confdir, CA_CERT_NAME), params=params ) return resp |