aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--examples/simple/custom_contentview.py6
-rw-r--r--examples/simple/io_read_dumpfile.py2
-rw-r--r--examples/simple/io_write_dumpfile.py10
-rw-r--r--mitmproxy/addons/view.py11
-rw-r--r--mitmproxy/contentviews/__init__.py4
-rw-r--r--mitmproxy/contentviews/base.py26
-rwxr-xr-xmitmproxy/contrib/kaitaistruct/make.sh1
-rw-r--r--mitmproxy/contrib/kaitaistruct/tls_client_hello.py146
-rw-r--r--mitmproxy/contrib/tls_client_hello.ksy139
-rw-r--r--mitmproxy/contrib/tls_parser.py208
-rw-r--r--mitmproxy/flowfilter.py2
-rw-r--r--mitmproxy/net/socks.py4
-rw-r--r--mitmproxy/net/tcp.py10
-rw-r--r--mitmproxy/proxy/protocol/http2.py17
-rw-r--r--mitmproxy/proxy/protocol/http_replay.py4
-rw-r--r--mitmproxy/proxy/protocol/tls.py26
-rw-r--r--mitmproxy/proxy/server.py2
-rw-r--r--mitmproxy/tools/main.py26
-rw-r--r--pathod/language/actions.py2
-rw-r--r--pathod/language/base.py45
-rw-r--r--pathod/language/generators.py18
-rw-r--r--pathod/language/http.py6
-rw-r--r--pathod/language/http2.py10
-rw-r--r--pathod/language/message.py50
-rw-r--r--pathod/language/websockets.py26
-rw-r--r--pathod/pathod_cmdline.py3
-rw-r--r--release/setup.py4
-rw-r--r--requirements.txt2
-rw-r--r--setup.cfg2
-rw-r--r--setup.py30
-rw-r--r--test/conftest.py20
-rw-r--r--test/mitmproxy/addons/test_clientplayback.py7
-rw-r--r--test/mitmproxy/addons/test_save.py5
-rw-r--r--test/mitmproxy/addons/test_serverplayback.py7
-rw-r--r--test/mitmproxy/addons/test_view.py7
-rw-r--r--test/mitmproxy/contentviews/test_protobuf.py4
-rw-r--r--test/mitmproxy/contentviews/test_wbxml.py22
-rw-r--r--test/mitmproxy/contentviews/test_wbxml_data/data-formatted.wbxml10
-rw-r--r--test/mitmproxy/contentviews/test_wbxml_data/data.wbxmlbin0 -> 34 bytes
-rw-r--r--test/mitmproxy/contrib/test_tls_parser.py38
-rw-r--r--test/mitmproxy/net/test_tcp.py8
-rw-r--r--test/mitmproxy/net/tservers.py14
-rw-r--r--test/mitmproxy/platform/test_pf.py5
-rw-r--r--test/mitmproxy/proxy/protocol/test_http1.py1
-rw-r--r--test/mitmproxy/proxy/protocol/test_http2.py172
-rw-r--r--test/mitmproxy/proxy/protocol/test_tls.py3
-rw-r--r--test/mitmproxy/proxy/protocol/test_websocket.py122
-rw-r--r--test/mitmproxy/test_connections.py13
-rw-r--r--test/mitmproxy/test_flowfilter.py3
-rw-r--r--test/mitmproxy/test_proxy.py3
-rw-r--r--test/mitmproxy/tservers.py2
-rw-r--r--test/pathod/language/test_base.py6
-rw-r--r--test/pathod/language/test_generators.py14
-rw-r--r--test/pathod/protocols/test_http2.py2
-rw-r--r--test/pathod/test_test.py40
-rw-r--r--test/pathod/tservers.py4
57 files changed, 720 insertions, 655 deletions
diff --git a/.gitignore b/.gitignore
index a37a1f31..f88a2917 100644
--- a/.gitignore
+++ b/.gitignore
@@ -23,3 +23,4 @@ sslkeylogfile.log
.python-version
coverage.xml
web/coverage/
+.mypy_cache/
diff --git a/examples/simple/custom_contentview.py b/examples/simple/custom_contentview.py
index 71f92575..b958bdce 100644
--- a/examples/simple/custom_contentview.py
+++ b/examples/simple/custom_contentview.py
@@ -3,10 +3,6 @@ This example shows how one can add a custom contentview to mitmproxy.
The content view API is explained in the mitmproxy.contentviews module.
"""
from mitmproxy import contentviews
-import typing
-
-
-CVIEWSWAPCASE = typing.Tuple[str, typing.Iterable[typing.List[typing.Tuple[str, typing.AnyStr]]]]
class ViewSwapCase(contentviews.View):
@@ -17,7 +13,7 @@ class ViewSwapCase(contentviews.View):
prompt = ("swap case text", "z")
content_types = ["text/plain"]
- def __call__(self, data: typing.AnyStr, **metadata) -> CVIEWSWAPCASE:
+ def __call__(self, data, **metadata) -> contentviews.TViewResult:
return "case-swapped text", contentviews.format_text(data.swapcase())
diff --git a/examples/simple/io_read_dumpfile.py b/examples/simple/io_read_dumpfile.py
index ea544cc4..534f357b 100644
--- a/examples/simple/io_read_dumpfile.py
+++ b/examples/simple/io_read_dumpfile.py
@@ -1,6 +1,4 @@
#!/usr/bin/env python
-
-# type: ignore
#
# Simple script showing how to read a mitmproxy dump file
#
diff --git a/examples/simple/io_write_dumpfile.py b/examples/simple/io_write_dumpfile.py
index cf7c4f52..be6e4121 100644
--- a/examples/simple/io_write_dumpfile.py
+++ b/examples/simple/io_write_dumpfile.py
@@ -13,15 +13,15 @@ import typing # noqa
class Writer:
def __init__(self, path: str) -> None:
- if path == "-":
- f = sys.stdout # type: typing.IO[typing.Any]
- else:
- f = open(path, "wb")
- self.w = io.FlowWriter(f)
+ self.f = open(path, "wb") # type: typing.IO[bytes]
+ self.w = io.FlowWriter(self.f)
def response(self, flow: http.HTTPFlow) -> None:
if random.choice([True, False]):
self.w.add(flow)
+ def done(self):
+ self.f.close()
+
addons = [Writer(sys.argv[1])]
diff --git a/mitmproxy/addons/view.py b/mitmproxy/addons/view.py
index 13a17c56..aa3e11ed 100644
--- a/mitmproxy/addons/view.py
+++ b/mitmproxy/addons/view.py
@@ -339,11 +339,12 @@ class View(collections.Sequence):
"""
Load flows into the view, without processing them with addons.
"""
- for i in io.FlowReader(open(path, "rb")).stream():
- # Do this to get a new ID, so we can load the same file N times and
- # get new flows each time. It would be more efficient to just have a
- # .newid() method or something.
- self.add([i.copy()])
+ with open(path, "rb") as f:
+ for i in io.FlowReader(f).stream():
+ # Do this to get a new ID, so we can load the same file N times and
+ # get new flows each time. It would be more efficient to just have a
+ # .newid() method or something.
+ self.add([i.copy()])
@command.command("view.go")
def go(self, dst: int) -> None:
diff --git a/mitmproxy/contentviews/__init__.py b/mitmproxy/contentviews/__init__.py
index f57b27c7..a1866851 100644
--- a/mitmproxy/contentviews/__init__.py
+++ b/mitmproxy/contentviews/__init__.py
@@ -25,7 +25,7 @@ from . import (
auto, raw, hex, json, xml_html, html_outline, wbxml, javascript, css,
urlencoded, multipart, image, query, protobuf
)
-from .base import View, VIEW_CUTOFF, KEY_MAX, format_text, format_dict
+from .base import View, VIEW_CUTOFF, KEY_MAX, format_text, format_dict, TViewResult
views = [] # type: List[View]
content_types_map = {} # type: Dict[str, List[View]]
@@ -178,7 +178,7 @@ add(query.ViewQuery())
add(protobuf.ViewProtobuf())
__all__ = [
- "View", "VIEW_CUTOFF", "KEY_MAX", "format_text", "format_dict",
+ "View", "VIEW_CUTOFF", "KEY_MAX", "format_text", "format_dict", "TViewResult",
"get", "get_by_shortcut", "add", "remove",
"get_content_view", "get_message_content_view",
]
diff --git a/mitmproxy/contentviews/base.py b/mitmproxy/contentviews/base.py
index 0de4f786..97740eea 100644
--- a/mitmproxy/contentviews/base.py
+++ b/mitmproxy/contentviews/base.py
@@ -1,20 +1,21 @@
# Default view cutoff *in lines*
-
-from typing import Iterable, AnyStr, List
-from typing import Mapping
-from typing import Tuple
+import typing
VIEW_CUTOFF = 512
KEY_MAX = 30
+TTextType = typing.Union[str, bytes] # FIXME: This should be either bytes or str ultimately.
+TViewLine = typing.List[typing.Tuple[str, TTextType]]
+TViewResult = typing.Tuple[str, typing.Iterator[TViewLine]]
+
class View:
name = None # type: str
- prompt = None # type: Tuple[str,str]
- content_types = [] # type: List[str]
+ prompt = None # type: typing.Tuple[str,str]
+ content_types = [] # type: typing.List[str]
- def __call__(self, data: bytes, **metadata):
+ def __call__(self, data: bytes, **metadata) -> TViewResult:
"""
Transform raw data into human-readable output.
@@ -38,8 +39,8 @@ class View:
def format_dict(
- d: Mapping[AnyStr, AnyStr]
-) -> Iterable[List[Tuple[str, AnyStr]]]:
+ d: typing.Mapping[TTextType, TTextType]
+) -> typing.Iterator[TViewLine]:
"""
Helper function that transforms the given dictionary into a list of
("key", key )
@@ -49,7 +50,10 @@ def format_dict(
max_key_len = max(len(k) for k in d.keys())
max_key_len = min(max_key_len, KEY_MAX)
for key, value in d.items():
- key += b":" if isinstance(key, bytes) else u":"
+ if isinstance(key, bytes):
+ key += b":"
+ else:
+ key += ":"
key = key.ljust(max_key_len + 2)
yield [
("header", key),
@@ -57,7 +61,7 @@ def format_dict(
]
-def format_text(text: AnyStr) -> Iterable[List[Tuple[str, AnyStr]]]:
+def format_text(text: TTextType) -> typing.Iterator[TViewLine]:
"""
Helper function that transforms bytes into the view output format.
"""
diff --git a/mitmproxy/contrib/kaitaistruct/make.sh b/mitmproxy/contrib/kaitaistruct/make.sh
index 218d5198..9ef68886 100755
--- a/mitmproxy/contrib/kaitaistruct/make.sh
+++ b/mitmproxy/contrib/kaitaistruct/make.sh
@@ -6,5 +6,6 @@ wget -N https://raw.githubusercontent.com/kaitai-io/kaitai_struct_formats/master
wget -N https://raw.githubusercontent.com/kaitai-io/kaitai_struct_formats/master/image/gif.ksy
wget -N https://raw.githubusercontent.com/kaitai-io/kaitai_struct_formats/master/image/jpeg.ksy
wget -N https://raw.githubusercontent.com/kaitai-io/kaitai_struct_formats/master/image/png.ksy
+wget -N https://raw.githubusercontent.com/mitmproxy/mitmproxy/master/mitmproxy/contrib/tls_client_hello.py
kaitai-struct-compiler --target python --opaque-types=true *.ksy
diff --git a/mitmproxy/contrib/kaitaistruct/tls_client_hello.py b/mitmproxy/contrib/kaitaistruct/tls_client_hello.py
new file mode 100644
index 00000000..6aff9b14
--- /dev/null
+++ b/mitmproxy/contrib/kaitaistruct/tls_client_hello.py
@@ -0,0 +1,146 @@
+# This is a generated file! Please edit source .ksy file and use kaitai-struct-compiler to rebuild
+
+import array
+import struct
+import zlib
+from enum import Enum
+from pkg_resources import parse_version
+
+from kaitaistruct import __version__ as ks_version, KaitaiStruct, KaitaiStream, BytesIO
+
+if parse_version(ks_version) < parse_version('0.7'):
+ raise Exception("Incompatible Kaitai Struct Python API: 0.7 or later is required, but you have %s" % (ks_version))
+
+
+class TlsClientHello(KaitaiStruct):
+ def __init__(self, _io, _parent=None, _root=None):
+ self._io = _io
+ self._parent = _parent
+ self._root = _root if _root else self
+ self.version = self._root.Version(self._io, self, self._root)
+ self.random = self._root.Random(self._io, self, self._root)
+ self.session_id = self._root.SessionId(self._io, self, self._root)
+ self.cipher_suites = self._root.CipherSuites(self._io, self, self._root)
+ self.compression_methods = self._root.CompressionMethods(self._io, self, self._root)
+ if self._io.is_eof() == True:
+ self.extensions = [None] * (0)
+ for i in range(0):
+ self.extensions[i] = self._io.read_bytes(0)
+
+ if self._io.is_eof() == False:
+ self.extensions = self._root.Extensions(self._io, self, self._root)
+
+ class ServerName(KaitaiStruct):
+ def __init__(self, _io, _parent=None, _root=None):
+ self._io = _io
+ self._parent = _parent
+ self._root = _root if _root else self
+ self.name_type = self._io.read_u1()
+ self.length = self._io.read_u2be()
+ self.host_name = self._io.read_bytes(self.length)
+
+ class Random(KaitaiStruct):
+ def __init__(self, _io, _parent=None, _root=None):
+ self._io = _io
+ self._parent = _parent
+ self._root = _root if _root else self
+ self.gmt_unix_time = self._io.read_u4be()
+ self.random = self._io.read_bytes(28)
+
+ class SessionId(KaitaiStruct):
+ def __init__(self, _io, _parent=None, _root=None):
+ self._io = _io
+ self._parent = _parent
+ self._root = _root if _root else self
+ self.len = self._io.read_u1()
+ self.sid = self._io.read_bytes(self.len)
+
+ class Sni(KaitaiStruct):
+ def __init__(self, _io, _parent=None, _root=None):
+ self._io = _io
+ self._parent = _parent
+ self._root = _root if _root else self
+ self.list_length = self._io.read_u2be()
+ self.server_names = []
+ while not self._io.is_eof():
+ self.server_names.append(self._root.ServerName(self._io, self, self._root))
+
+ class CipherSuites(KaitaiStruct):
+ def __init__(self, _io, _parent=None, _root=None):
+ self._io = _io
+ self._parent = _parent
+ self._root = _root if _root else self
+ self.len = self._io.read_u2be()
+ self.cipher_suites = [None] * (self.len // 2)
+ for i in range(self.len // 2):
+ self.cipher_suites[i] = self._root.CipherSuite(self._io, self, self._root)
+
+ class CompressionMethods(KaitaiStruct):
+ def __init__(self, _io, _parent=None, _root=None):
+ self._io = _io
+ self._parent = _parent
+ self._root = _root if _root else self
+ self.len = self._io.read_u1()
+ self.compression_methods = self._io.read_bytes(self.len)
+
+ class Alpn(KaitaiStruct):
+ def __init__(self, _io, _parent=None, _root=None):
+ self._io = _io
+ self._parent = _parent
+ self._root = _root if _root else self
+ self.ext_len = self._io.read_u2be()
+ self.alpn_protocols = []
+ while not self._io.is_eof():
+ self.alpn_protocols.append(self._root.Protocol(self._io, self, self._root))
+
+ class Extensions(KaitaiStruct):
+ def __init__(self, _io, _parent=None, _root=None):
+ self._io = _io
+ self._parent = _parent
+ self._root = _root if _root else self
+ self.len = self._io.read_u2be()
+ self.extensions = []
+ while not self._io.is_eof():
+ self.extensions.append(self._root.Extension(self._io, self, self._root))
+
+ class Version(KaitaiStruct):
+ def __init__(self, _io, _parent=None, _root=None):
+ self._io = _io
+ self._parent = _parent
+ self._root = _root if _root else self
+ self.major = self._io.read_u1()
+ self.minor = self._io.read_u1()
+
+ class CipherSuite(KaitaiStruct):
+ def __init__(self, _io, _parent=None, _root=None):
+ self._io = _io
+ self._parent = _parent
+ self._root = _root if _root else self
+ self.cipher_suite = self._io.read_u2be()
+
+ class Protocol(KaitaiStruct):
+ def __init__(self, _io, _parent=None, _root=None):
+ self._io = _io
+ self._parent = _parent
+ self._root = _root if _root else self
+ self.strlen = self._io.read_u1()
+ self.name = self._io.read_bytes(self.strlen)
+
+ class Extension(KaitaiStruct):
+ def __init__(self, _io, _parent=None, _root=None):
+ self._io = _io
+ self._parent = _parent
+ self._root = _root if _root else self
+ self.type = self._io.read_u2be()
+ self.len = self._io.read_u2be()
+ _on = self.type
+ if _on == 0:
+ self._raw_body = self._io.read_bytes(self.len)
+ io = KaitaiStream(BytesIO(self._raw_body))
+ self.body = self._root.Sni(io, self, self._root)
+ elif _on == 16:
+ self._raw_body = self._io.read_bytes(self.len)
+ io = KaitaiStream(BytesIO(self._raw_body))
+ self.body = self._root.Alpn(io, self, self._root)
+ else:
+ self.body = self._io.read_bytes(self.len)
diff --git a/mitmproxy/contrib/tls_client_hello.ksy b/mitmproxy/contrib/tls_client_hello.ksy
new file mode 100644
index 00000000..5b6eb0fb
--- /dev/null
+++ b/mitmproxy/contrib/tls_client_hello.ksy
@@ -0,0 +1,139 @@
+meta:
+ id: tls_client_hello
+ endian: be
+
+seq:
+ - id: version
+ type: version
+
+ - id: random
+ type: random
+
+ - id: session_id
+ type: session_id
+
+ - id: cipher_suites
+ type: cipher_suites
+
+ - id: compression_methods
+ type: compression_methods
+
+ - id: extensions
+ size: 0
+ repeat: expr
+ repeat-expr: 0
+ if: _io.eof == true
+
+ - id: extensions
+ type: extensions
+ if: _io.eof == false
+
+types:
+ version:
+ seq:
+ - id: major
+ type: u1
+
+ - id: minor
+ type: u1
+
+ random:
+ seq:
+ - id: gmt_unix_time
+ type: u4
+
+ - id: random
+ size: 28
+
+ session_id:
+ seq:
+ - id: len
+ type: u1
+
+ - id: sid
+ size: len
+
+ cipher_suites:
+ seq:
+ - id: len
+ type: u2
+
+ - id: cipher_suites
+ type: cipher_suite
+ repeat: expr
+ repeat-expr: len/2
+
+ cipher_suite:
+ seq:
+ - id: cipher_suite
+ type: u2
+
+ compression_methods:
+ seq:
+ - id: len
+ type: u1
+
+ - id: compression_methods
+ size: len
+
+ extensions:
+ seq:
+ - id: len
+ type: u2
+
+ - id: extensions
+ type: extension
+ repeat: eos
+
+ extension:
+ seq:
+ - id: type
+ type: u2
+
+ - id: len
+ type: u2
+
+ - id: body
+ size: len
+ type:
+ switch-on: type
+ cases:
+ 0: sni
+ 16: alpn
+
+ sni:
+ seq:
+ - id: list_length
+ type: u2
+
+ - id: server_names
+ type: server_name
+ repeat: eos
+
+ server_name:
+ seq:
+ - id: name_type
+ type: u1
+
+ - id: length
+ type: u2
+
+ - id: host_name
+ size: length
+
+ alpn:
+ seq:
+ - id: ext_len
+ type: u2
+
+ - id: alpn_protocols
+ type: protocol
+ repeat: eos
+
+ protocol:
+ seq:
+ - id: strlen
+ type: u1
+
+ - id: name
+ size: strlen
diff --git a/mitmproxy/contrib/tls_parser.py b/mitmproxy/contrib/tls_parser.py
deleted file mode 100644
index 61fb3e3e..00000000
--- a/mitmproxy/contrib/tls_parser.py
+++ /dev/null
@@ -1,208 +0,0 @@
-# This file originally comes from https://github.com/pyca/tls/blob/master/tls/_constructs.py.
-# Modified by the mitmproxy team.
-
-# This file is dual licensed under the terms of the Apache License, Version
-# 2.0, and the BSD License. See the LICENSE file in the root of this repository
-# for complete details.
-
-
-from construct import (
- Array,
- Bytes,
- Struct,
- VarInt,
- Int8ub,
- Int16ub,
- Int24ub,
- Int32ub,
- PascalString,
- Embedded,
- Prefixed,
- Range,
- GreedyRange,
- Switch,
- Optional,
-)
-
-ProtocolVersion = "version" / Struct(
- "major" / Int8ub,
- "minor" / Int8ub,
-)
-
-TLSPlaintext = "TLSPlaintext" / Struct(
- "type" / Int8ub,
- ProtocolVersion,
- "length" / Int16ub, # TODO: Reject packets with length > 2 ** 14
- "fragment" / Bytes(lambda ctx: ctx.length),
-)
-
-TLSCompressed = "TLSCompressed" / Struct(
- "type" / Int8ub,
- ProtocolVersion,
- "length" / Int16ub, # TODO: Reject packets with length > 2 ** 14 + 1024
- "fragment" / Bytes(lambda ctx: ctx.length),
-)
-
-TLSCiphertext = "TLSCiphertext" / Struct(
- "type" / Int8ub,
- ProtocolVersion,
- "length" / Int16ub, # TODO: Reject packets with length > 2 ** 14 + 2048
- "fragment" / Bytes(lambda ctx: ctx.length),
-)
-
-Random = "random" / Struct(
- "gmt_unix_time" / Int32ub,
- "random_bytes" / Bytes(28),
-)
-
-SessionID = "session_id" / Struct(
- "length" / Int8ub,
- "session_id" / Bytes(lambda ctx: ctx.length),
-)
-
-CipherSuites = "cipher_suites" / Struct(
- "length" / Int16ub, # TODO: Reject packets of length 0
- Array(lambda ctx: ctx.length // 2, "cipher_suites" / Int16ub),
-)
-
-CompressionMethods = "compression_methods" / Struct(
- "length" / Int8ub, # TODO: Reject packets of length 0
- Array(lambda ctx: ctx.length, "compression_methods" / Int8ub),
-)
-
-ServerName = Struct(
- "type" / Int8ub,
- "name" / PascalString("length" / Int16ub),
-)
-
-SNIExtension = Prefixed(
- Int16ub,
- Struct(
- Int16ub,
- "server_names" / GreedyRange(
- "server_name" / Struct(
- "name_type" / Int8ub,
- "host_name" / PascalString("length" / Int16ub),
- )
- )
- )
-)
-
-ALPNExtension = Prefixed(
- Int16ub,
- Struct(
- Int16ub,
- "alpn_protocols" / GreedyRange(
- "name" / PascalString(Int8ub),
- ),
- )
-)
-
-UnknownExtension = Struct(
- "bytes" / PascalString("length" / Int16ub)
-)
-
-Extension = "Extension" / Struct(
- "type" / Int16ub,
- Embedded(
- Switch(
- lambda ctx: ctx.type,
- {
- 0x00: SNIExtension,
- 0x10: ALPNExtension,
- },
- default=UnknownExtension
- )
- )
-)
-
-extensions = "extensions" / Optional(
- Struct(
- Int16ub,
- "extensions" / GreedyRange(Extension)
- )
-)
-
-ClientHello = "ClientHello" / Struct(
- ProtocolVersion,
- Random,
- SessionID,
- CipherSuites,
- CompressionMethods,
- extensions,
-)
-
-ServerHello = "ServerHello" / Struct(
- ProtocolVersion,
- Random,
- SessionID,
- "cipher_suite" / Bytes(2),
- "compression_method" / Int8ub,
- extensions,
-)
-
-ClientCertificateType = "certificate_types" / Struct(
- "length" / Int8ub, # TODO: Reject packets of length 0
- Array(lambda ctx: ctx.length, "certificate_types" / Int8ub),
-)
-
-SignatureAndHashAlgorithm = "algorithms" / Struct(
- "hash" / Int8ub,
- "signature" / Int8ub,
-)
-
-SupportedSignatureAlgorithms = "supported_signature_algorithms" / Struct(
- "supported_signature_algorithms_length" / Int16ub,
- # TODO: Reject packets of length 0
- Array(
- lambda ctx: ctx.supported_signature_algorithms_length / 2,
- SignatureAndHashAlgorithm,
- ),
-)
-
-DistinguishedName = "certificate_authorities" / Struct(
- "length" / Int16ub,
- "certificate_authorities" / Bytes(lambda ctx: ctx.length),
-)
-
-CertificateRequest = "CertificateRequest" / Struct(
- ClientCertificateType,
- SupportedSignatureAlgorithms,
- DistinguishedName,
-)
-
-ServerDHParams = "ServerDHParams" / Struct(
- "dh_p_length" / Int16ub,
- "dh_p" / Bytes(lambda ctx: ctx.dh_p_length),
- "dh_g_length" / Int16ub,
- "dh_g" / Bytes(lambda ctx: ctx.dh_g_length),
- "dh_Ys_length" / Int16ub,
- "dh_Ys" / Bytes(lambda ctx: ctx.dh_Ys_length),
-)
-
-PreMasterSecret = "pre_master_secret" / Struct(
- ProtocolVersion,
- "random_bytes" / Bytes(46),
-)
-
-ASN1Cert = "ASN1Cert" / Struct(
- "length" / Int32ub, # TODO: Reject packets with length not in 1..2^24-1
- "asn1_cert" / Bytes(lambda ctx: ctx.length),
-)
-
-Certificate = "Certificate" / Struct(
- # TODO: Reject packets with length > 2 ** 24 - 1
- "certificates_length" / Int32ub,
- "certificates_bytes" / Bytes(lambda ctx: ctx.certificates_length),
-)
-
-Handshake = "Handshake" / Struct(
- "msg_type" / Int8ub,
- "length" / Int24ub,
- "body" / Bytes(lambda ctx: ctx.length),
-)
-
-Alert = "Alert" / Struct(
- "level" / Int8ub,
- "description" / Int8ub,
-)
diff --git a/mitmproxy/flowfilter.py b/mitmproxy/flowfilter.py
index 83c98bad..4edf0413 100644
--- a/mitmproxy/flowfilter.py
+++ b/mitmproxy/flowfilter.py
@@ -344,6 +344,8 @@ class FUrl(_Rex):
@only(http.HTTPFlow)
def __call__(self, f):
+ if not f.request:
+ return False
return self.re.search(f.request.pretty_url)
diff --git a/mitmproxy/net/socks.py b/mitmproxy/net/socks.py
index 570a4afb..fdfcfb80 100644
--- a/mitmproxy/net/socks.py
+++ b/mitmproxy/net/socks.py
@@ -82,12 +82,12 @@ class ClientGreeting:
client_greeting = cls(ver, [])
if fail_early:
client_greeting.assert_socks5()
- client_greeting.methods.fromstring(f.safe_read(nmethods))
+ client_greeting.methods.frombytes(f.safe_read(nmethods))
return client_greeting
def to_file(self, f):
f.write(struct.pack("!BB", self.ver, len(self.methods)))
- f.write(self.methods.tostring())
+ f.write(self.methods.tobytes())
class ServerGreeting:
diff --git a/mitmproxy/net/tcp.py b/mitmproxy/net/tcp.py
index 81568d24..12cf7337 100644
--- a/mitmproxy/net/tcp.py
+++ b/mitmproxy/net/tcp.py
@@ -502,7 +502,7 @@ class _Connection:
# Cipher List
if cipher_list:
try:
- context.set_cipher_list(cipher_list)
+ context.set_cipher_list(cipher_list.encode())
context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
except SSL.Error as v:
raise exceptions.TlsException("SSL cipher specification error: %s" % str(v))
@@ -569,7 +569,9 @@ class TCPClient(_Connection):
# Make sure to close the real socket, not the SSL proxy.
# OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,
# it tries to renegotiate...
- if isinstance(self.connection, SSL.Connection):
+ if not self.connection:
+ return
+ elif isinstance(self.connection, SSL.Connection):
close_socket(self.connection._socket)
else:
close_socket(self.connection)
@@ -674,6 +676,8 @@ class TCPClient(_Connection):
sock.setsockopt(socket.SOL_IP, socket.IP_TRANSPARENT, 1) # pragma: windows no cover pragma: osx no cover
except Exception as e:
# socket.IP_TRANSPARENT might not be available on every OS and Python version
+ if sock is not None:
+ sock.close()
raise exceptions.TcpException(
"Failed to spoof the source address: " + str(e)
)
@@ -864,6 +868,8 @@ class TCPServer:
self.socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
self.socket.bind(self.address)
except socket.error:
+ if self.socket:
+ self.socket.close()
self.socket = None
if not self.socket:
diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py
index 2191b54b..ace7ecde 100644
--- a/mitmproxy/proxy/protocol/http2.py
+++ b/mitmproxy/proxy/protocol/http2.py
@@ -206,14 +206,15 @@ class Http2Layer(base.Layer):
return True
def _handle_stream_reset(self, eid, event, is_server, other_conn):
- self.streams[eid].kill()
- if eid in self.streams and event.error_code == h2.errors.ErrorCodes.CANCEL:
- if is_server:
- other_stream_id = self.streams[eid].client_stream_id
- else:
- other_stream_id = self.streams[eid].server_stream_id
- if other_stream_id is not None:
- self.connections[other_conn].safe_reset_stream(other_stream_id, event.error_code)
+ if eid in self.streams:
+ self.streams[eid].kill()
+ if event.error_code == h2.errors.ErrorCodes.CANCEL:
+ if is_server:
+ other_stream_id = self.streams[eid].client_stream_id
+ else:
+ other_stream_id = self.streams[eid].server_stream_id
+ if other_stream_id is not None:
+ self.connections[other_conn].safe_reset_stream(other_stream_id, event.error_code)
return True
def _handle_remote_settings_changed(self, event, other_conn):
diff --git a/mitmproxy/proxy/protocol/http_replay.py b/mitmproxy/proxy/protocol/http_replay.py
index 25867871..1aa91847 100644
--- a/mitmproxy/proxy/protocol/http_replay.py
+++ b/mitmproxy/proxy/protocol/http_replay.py
@@ -83,7 +83,11 @@ class RequestReplayThread(basethread.BaseThread):
server.wfile.write(http1.assemble_request(r))
server.wfile.flush()
+
+ if self.f.server_conn:
+ self.f.server_conn.close()
self.f.server_conn = server
+
self.f.response = http.HTTPResponse.wrap(
http1.read_response(
server.rfile,
diff --git a/mitmproxy/proxy/protocol/tls.py b/mitmproxy/proxy/protocol/tls.py
index f55855f0..d42c7fdd 100644
--- a/mitmproxy/proxy/protocol/tls.py
+++ b/mitmproxy/proxy/protocol/tls.py
@@ -1,10 +1,11 @@
import struct
from typing import Optional # noqa
from typing import Union
+import io
-import construct
+from kaitaistruct import KaitaiStream
from mitmproxy import exceptions
-from mitmproxy.contrib import tls_parser
+from mitmproxy.contrib.kaitaistruct import tls_client_hello
from mitmproxy.proxy.protocol import base
from mitmproxy.net import check
@@ -263,7 +264,7 @@ def get_client_hello(client_conn):
class TlsClientHello:
def __init__(self, raw_client_hello):
- self._client_hello = tls_parser.ClientHello.parse(raw_client_hello)
+ self._client_hello = tls_client_hello.TlsClientHello(KaitaiStream(io.BytesIO(raw_client_hello)))
def raw(self):
return self._client_hello
@@ -278,12 +279,12 @@ class TlsClientHello:
for extension in self._client_hello.extensions.extensions:
is_valid_sni_extension = (
extension.type == 0x00 and
- len(extension.server_names) == 1 and
- extension.server_names[0].name_type == 0 and
- check.is_valid_host(extension.server_names[0].host_name)
+ len(extension.body.server_names) == 1 and
+ extension.body.server_names[0].name_type == 0 and
+ check.is_valid_host(extension.body.server_names[0].host_name)
)
if is_valid_sni_extension:
- return extension.server_names[0].host_name.decode("idna")
+ return extension.body.server_names[0].host_name.decode("idna")
return None
@property
@@ -291,7 +292,7 @@ class TlsClientHello:
if self._client_hello.extensions:
for extension in self._client_hello.extensions.extensions:
if extension.type == 0x10:
- return list(extension.alpn_protocols)
+ return list(extension.body.alpn_protocols)
return []
@classmethod
@@ -310,7 +311,7 @@ class TlsClientHello:
try:
return cls(raw_client_hello)
- except construct.ConstructError as e:
+ except EOFError as e:
raise exceptions.TlsProtocolException(
'Cannot parse Client Hello: %s, Raw Client Hello: %s' %
(repr(e), raw_client_hello.encode("hex"))
@@ -518,7 +519,8 @@ class TlsLayer(base.Layer):
# We only support http/1.1 and h2.
# If the server only supports spdy (next to http/1.1), it may select that
# and mitmproxy would enter TCP passthrough mode, which we want to avoid.
- alpn = [x for x in self._client_hello.alpn_protocols if not (x.startswith(b"h2-") or x.startswith(b"spdy"))]
+ alpn = [x.name for x in self._client_hello.alpn_protocols if
+ not (x.name.startswith(b"h2-") or x.name.startswith(b"spdy"))]
if alpn and b"h2" in alpn and not self.config.options.http2:
alpn.remove(b"h2")
@@ -537,8 +539,8 @@ class TlsLayer(base.Layer):
if not ciphers_server and self._client_tls:
ciphers_server = []
for id in self._client_hello.cipher_suites:
- if id in CIPHER_ID_NAME_MAP.keys():
- ciphers_server.append(CIPHER_ID_NAME_MAP[id])
+ if id.cipher_suite in CIPHER_ID_NAME_MAP.keys():
+ ciphers_server.append(CIPHER_ID_NAME_MAP[id.cipher_suite])
ciphers_server = ':'.join(ciphers_server)
self.server_conn.establish_ssl(
diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py
index 50a2b76b..5171fbee 100644
--- a/mitmproxy/proxy/server.py
+++ b/mitmproxy/proxy/server.py
@@ -48,6 +48,8 @@ class ProxyServer(tcp.TCPServer):
if config.options.mode == "transparent":
platform.init_transparent_mode()
except Exception as e:
+ if self.socket:
+ self.socket.close()
raise exceptions.ServerException(
'Error starting proxy server: ' + repr(e)
) from e
diff --git a/mitmproxy/tools/main.py b/mitmproxy/tools/main.py
index d8fac077..84dab1fe 100644
--- a/mitmproxy/tools/main.py
+++ b/mitmproxy/tools/main.py
@@ -1,6 +1,8 @@
from __future__ import print_function # this is here for the version check to work on Python 2.
import sys
+# This must be at the very top, before importing anything else that might break!
+# Keep all other imports below with the 'noqa' magic comment.
if sys.version_info < (3, 5):
print("#" * 49, file=sys.stderr)
print("# mitmproxy only supports Python 3.5 and above! #", file=sys.stderr)
@@ -13,8 +15,7 @@ from mitmproxy.tools import cmdline # noqa
from mitmproxy import exceptions # noqa
from mitmproxy import options # noqa
from mitmproxy import optmanager # noqa
-from mitmproxy.proxy import config # noqa
-from mitmproxy.proxy import server # noqa
+from mitmproxy import proxy # noqa
from mitmproxy.utils import version_check # noqa
from mitmproxy.utils import debug # noqa
@@ -49,15 +50,7 @@ def process_options(parser, opts, args):
adict[n] = getattr(args, n)
opts.merge(adict)
- pconf = config.ProxyConfig(opts)
- if opts.server:
- try:
- return server.ProxyServer(pconf)
- except exceptions.ServerException as v:
- print(str(v), file=sys.stderr)
- sys.exit(1)
- else:
- return server.DummyServer(pconf)
+ return proxy.config.ProxyConfig(opts)
def run(MasterKlass, args, extra=None): # pragma: no cover
@@ -74,7 +67,16 @@ def run(MasterKlass, args, extra=None): # pragma: no cover
master = None
try:
unknown = optmanager.load_paths(opts, args.conf)
- server = process_options(parser, opts, args)
+ pconf = process_options(parser, opts, args)
+ if pconf.options.server:
+ try:
+ server = proxy.server.ProxyServer(pconf)
+ except exceptions.ServerException as v:
+ print(str(v), file=sys.stderr)
+ sys.exit(1)
+ else:
+ server = proxy.server.DummyServer(pconf)
+
master = MasterKlass(opts, server)
master.addons.trigger("configure", opts.keys())
master.addons.trigger("tick")
diff --git a/pathod/language/actions.py b/pathod/language/actions.py
index fc57a18b..3e48f40d 100644
--- a/pathod/language/actions.py
+++ b/pathod/language/actions.py
@@ -50,7 +50,7 @@ class _Action(base.Token):
class PauseAt(_Action):
- unique_name = None # type: ignore
+ unique_name = None
def __init__(self, offset, seconds):
_Action.__init__(self, offset)
diff --git a/pathod/language/base.py b/pathod/language/base.py
index c8892748..97871e7e 100644
--- a/pathod/language/base.py
+++ b/pathod/language/base.py
@@ -6,7 +6,8 @@ import pyparsing as pp
from mitmproxy.utils import strutils
from mitmproxy.utils import human
import typing # noqa
-from . import generators, exceptions
+from . import generators
+from . import exceptions
class Settings:
@@ -375,7 +376,7 @@ class OptionsOrValue(_Component):
class Integer(_Component):
- bounds = (None, None) # type: typing.Tuple[typing.Union[int, None], typing.Union[int , None]]
+ bounds = (None, None) # type: typing.Tuple[typing.Optional[int], typing.Optional[int]]
preamble = ""
def __init__(self, value):
@@ -537,43 +538,3 @@ class IntField(_Component):
def spec(self):
return "%s%s" % (self.preamble, self.origvalue)
-
-
-class NestedMessage(Token):
-
- """
- A nested message, as an escaped string with a preamble.
- """
- preamble = ""
- nest_type = None # type: ignore
-
- def __init__(self, value):
- Token.__init__(self)
- self.value = value
- try:
- self.parsed = self.nest_type(
- self.nest_type.expr().parseString(
- value.val.decode(),
- parseAll=True
- )
- )
- except pp.ParseException as v:
- raise exceptions.ParseException(v.msg, v.line, v.col)
-
- @classmethod
- def expr(cls):
- e = pp.Literal(cls.preamble).suppress()
- e = e + TokValueLiteral.expr()
- return e.setParseAction(lambda x: cls(*x))
-
- def values(self, settings):
- return [
- self.value.get_generator(settings),
- ]
-
- def spec(self):
- return "%s%s" % (self.preamble, self.value.spec())
-
- def freeze(self, settings):
- f = self.parsed.freeze(settings).spec()
- return self.__class__(TokValueLiteral(strutils.bytes_to_escaped_str(f.encode(), escape_single_quotes=True)))
diff --git a/pathod/language/generators.py b/pathod/language/generators.py
index d716804d..1961df74 100644
--- a/pathod/language/generators.py
+++ b/pathod/language/generators.py
@@ -1,7 +1,7 @@
+import os
import string
import random
import mmap
-
import sys
DATATYPES = dict(
@@ -74,20 +74,20 @@ class RandomGenerator:
class FileGenerator:
-
def __init__(self, path):
self.path = path
- self.fp = open(path, "rb")
- self.map = mmap.mmap(self.fp.fileno(), 0, access=mmap.ACCESS_READ)
def __len__(self):
- return len(self.map)
+ return os.path.getsize(self.path)
def __getitem__(self, x):
- if isinstance(x, slice):
- return self.map.__getitem__(x)
- # A slice of length 1 returns a byte object (not an integer)
- return self.map.__getitem__(slice(x, x + 1 or self.map.size()))
+ with open(self.path, mode="rb") as f:
+ if isinstance(x, slice):
+ with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mapped:
+ return mapped.__getitem__(x)
+ else:
+ f.seek(x)
+ return f.read(1)
def __repr__(self):
return "<%s" % self.path
diff --git a/pathod/language/http.py b/pathod/language/http.py
index 5cd717a9..5a962145 100644
--- a/pathod/language/http.py
+++ b/pathod/language/http.py
@@ -54,7 +54,9 @@ class Method(base.OptionsOrValue):
class _HeaderMixin:
- unique_name = None # type: ignore
+ @property
+ def unique_name(self):
+ return None
def format_header(self, key, value):
return [key, b": ", value, b"\r\n"]
@@ -251,7 +253,7 @@ class Response(_HTTPMessage):
return ":".join([i.spec() for i in self.tokens])
-class NestedResponse(base.NestedMessage):
+class NestedResponse(message.NestedMessage):
preamble = "s"
nest_type = Response
diff --git a/pathod/language/http2.py b/pathod/language/http2.py
index 47d6e370..5b27d5bf 100644
--- a/pathod/language/http2.py
+++ b/pathod/language/http2.py
@@ -1,9 +1,9 @@
import pyparsing as pp
+
from mitmproxy.net import http
from mitmproxy.net.http import user_agents, Headers
from . import base, message
-
"""
Normal HTTP requests:
<method>:<path>:<header>:<body>
@@ -41,7 +41,9 @@ def get_header(val, headers):
class _HeaderMixin:
- unique_name = None # type: ignore
+ @property
+ def unique_name(self):
+ return None
def values(self, settings):
return (
@@ -146,7 +148,7 @@ class Times(base.Integer):
class Response(_HTTP2Message):
- unique_name = None # type: ignore
+ unique_name = None
comps = (
Header,
Body,
@@ -203,7 +205,7 @@ class Response(_HTTP2Message):
return ":".join([i.spec() for i in self.tokens])
-class NestedResponse(base.NestedMessage):
+class NestedResponse(message.NestedMessage):
preamble = "s"
nest_type = Response
diff --git a/pathod/language/message.py b/pathod/language/message.py
index 6b4c5021..5dda654b 100644
--- a/pathod/language/message.py
+++ b/pathod/language/message.py
@@ -1,8 +1,11 @@
import abc
-from . import actions, exceptions
-from mitmproxy.utils import strutils
import typing # noqa
+import pyparsing as pp
+
+from mitmproxy.utils import strutils
+from . import actions, exceptions, base
+
LOG_TRUNCATE = 1024
@@ -96,3 +99,46 @@ class Message:
def __repr__(self):
return self.spec()
+
+
+class NestedMessage(base.Token):
+ """
+ A nested message, as an escaped string with a preamble.
+ """
+ preamble = ""
+ nest_type = None # type: typing.Optional[typing.Type[Message]]
+
+ def __init__(self, value):
+ super().__init__()
+ self.value = value
+ try:
+ self.parsed = self.nest_type(
+ self.nest_type.expr().parseString(
+ value.val.decode(),
+ parseAll=True
+ )
+ )
+ except pp.ParseException as v:
+ raise exceptions.ParseException(v.msg, v.line, v.col)
+
+ @classmethod
+ def expr(cls):
+ e = pp.Literal(cls.preamble).suppress()
+ e = e + base.TokValueLiteral.expr()
+ return e.setParseAction(lambda x: cls(*x))
+
+ def values(self, settings):
+ return [
+ self.value.get_generator(settings),
+ ]
+
+ def spec(self):
+ return "%s%s" % (self.preamble, self.value.spec())
+
+ def freeze(self, settings):
+ f = self.parsed.freeze(settings).spec()
+ return self.__class__(
+ base.TokValueLiteral(
+ strutils.bytes_to_escaped_str(f.encode(), escape_single_quotes=True)
+ )
+ )
diff --git a/pathod/language/websockets.py b/pathod/language/websockets.py
index b4faf59b..cc00bcf1 100644
--- a/pathod/language/websockets.py
+++ b/pathod/language/websockets.py
@@ -1,10 +1,12 @@
import random
import string
+import typing # noqa
+
+import pyparsing as pp
+
import mitmproxy.net.websockets
from mitmproxy.utils import strutils
-import pyparsing as pp
from . import base, generators, actions, message
-import typing # noqa
NESTED_LEADER = b"pathod!"
@@ -74,7 +76,7 @@ class Times(base.Integer):
preamble = "x"
-COMPONENTS = (
+COMPONENTS = [
OpCode,
Length,
# Bit flags
@@ -89,14 +91,13 @@ COMPONENTS = (
KeyNone,
Key,
Times,
-
Body,
RawBody,
-)
+]
class WebsocketFrame(message.Message):
- components = COMPONENTS
+ components = COMPONENTS # type: typing.List[typing.Type[base._Component]]
logattrs = ["body"]
# Used for nested frames
unique_name = "body"
@@ -235,19 +236,10 @@ class WebsocketFrame(message.Message):
return ":".join([i.spec() for i in self.tokens])
-class NestedFrame(base.NestedMessage):
+class NestedFrame(message.NestedMessage):
preamble = "f"
nest_type = WebsocketFrame
-COMP = typing.Tuple[
- typing.Type[OpCode], typing.Type[Length], typing.Type[Fin], typing.Type[RSV1], typing.Type[RSV2], typing.Type[RSV3], typing.Type[Mask],
- typing.Type[actions.PauseAt], typing.Type[actions.DisconnectAt], typing.Type[actions.InjectAt], typing.Type[KeyNone], typing.Type[Key],
- typing.Type[Times], typing.Type[Body], typing.Type[RawBody]
-]
-
-
class WebsocketClientFrame(WebsocketFrame):
- components = typing.cast(COMP, COMPONENTS + (
- NestedFrame,
- ))
+ components = COMPONENTS + [NestedFrame]
diff --git a/pathod/pathod_cmdline.py b/pathod/pathod_cmdline.py
index ef1e983f..dee19f4f 100644
--- a/pathod/pathod_cmdline.py
+++ b/pathod/pathod_cmdline.py
@@ -216,7 +216,8 @@ def args_pathod(argv, stdout_=sys.stdout, stderr_=sys.stderr):
anchors = []
for patt, spec in args.anchors:
if os.path.isfile(spec):
- data = open(spec).read()
+ with open(spec) as f:
+ data = f.read()
spec = data
try:
arex = re.compile(patt)
diff --git a/release/setup.py b/release/setup.py
index 01d0672d..17b02ebc 100644
--- a/release/setup.py
+++ b/release/setup.py
@@ -6,9 +6,9 @@ setup(
py_modules=["rtool"],
install_requires=[
"click>=6.2, <7.0",
- "twine>=1.6.5, <1.9",
+ "twine>=1.6.5, <1.10",
"pysftp==0.2.8",
- "cryptography>=1.6, <1.7",
+ "cryptography>=1.6, <1.9",
],
entry_points={
"console_scripts": [
diff --git a/requirements.txt b/requirements.txt
index ab8e8a0b..28a0b495 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1 @@
--e .[dev,examples,contentviews]
+-e .[dev,examples]
diff --git a/setup.cfg b/setup.cfg
index 1721975e..fc35021c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -21,7 +21,6 @@ exclude_lines =
[tool:full_coverage]
exclude =
- mitmproxy/contentviews/wbxml.py
mitmproxy/proxy/protocol/
mitmproxy/proxy/config.py
mitmproxy/proxy/root_context.py
@@ -39,7 +38,6 @@ exclude =
mitmproxy/addons/onboardingapp/app.py
mitmproxy/addons/termlog.py
mitmproxy/contentviews/base.py
- mitmproxy/contentviews/wbxml.py
mitmproxy/controller.py
mitmproxy/ctx.py
mitmproxy/exceptions.py
diff --git a/setup.py b/setup.py
index a03d74fb..c2fb4718 100644
--- a/setup.py
+++ b/setup.py
@@ -61,9 +61,9 @@ setup(
# It is not considered best practice to use install_requires to pin dependencies to specific versions.
install_requires=[
"blinker>=1.4, <1.5",
- "click>=6.2, <7",
+ "brotlipy>=0.5.1, <0.7",
"certifi>=2015.11.20.1", # no semver here - this should always be on the last release!
- "construct>=2.8, <2.9",
+ "click>=6.2, <7",
"cryptography>=1.4, <1.9",
"cssutils>=1.0.1, <1.1",
"h2>=3.0, <4",
@@ -79,37 +79,29 @@ setup(
"pyperclip>=1.5.22, <1.6",
"requests>=2.9.1, <3",
"ruamel.yaml>=0.13.2, <0.15",
+ "sortedcontainers>=1.5.4, <1.6",
"tornado>=4.3, <4.6",
"urwid>=1.3.1, <1.4",
- "brotlipy>=0.5.1, <0.7",
- "sortedcontainers>=1.5.4, <1.6",
- # transitive from cryptography, we just blacklist here.
- # https://github.com/pypa/setuptools/issues/861
- "setuptools>=11.3, !=29.0.0",
],
extras_require={
':sys_platform == "win32"': [
"pydivert>=2.0.3, <2.1",
],
- ':sys_platform != "win32"': [
- ],
'dev': [
- "Flask>=0.10.1, <0.13",
"flake8>=3.2.1, <3.4",
- "mypy>=0.501, <0.502",
- "rstcheck>=2.2, <4.0",
- "tox>=2.3, <3",
- "pytest>=3, <3.1",
+ "Flask>=0.10.1, <0.13",
+ "mypy>=0.501, <0.512",
"pytest-cov>=2.2.1, <3",
+ "pytest-faulthandler>=1.3.0, <2",
"pytest-timeout>=1.0.0, <2",
"pytest-xdist>=1.14, <2",
- "pytest-faulthandler>=1.3.0, <2",
- "sphinx>=1.3.5, <1.7",
+ "pytest>=3.1, <4",
+ "rstcheck>=2.2, <4.0",
+ "sphinx_rtd_theme>=0.1.9, <0.3",
"sphinx-autobuild>=0.5.2, <0.7",
+ "sphinx>=1.3.5, <1.7",
"sphinxcontrib-documentedlist>=0.5.0, <0.7",
- "sphinx_rtd_theme>=0.1.9, <0.3",
- ],
- 'contentviews': [
+ "tox>=2.3, <3",
],
'examples': [
"beautifulsoup4>=4.4.1, <4.7",
diff --git a/test/conftest.py b/test/conftest.py
index b4e1da93..bb913548 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -1,8 +1,6 @@
import os
import pytest
import OpenSSL
-import functools
-from contextlib import contextmanager
import mitmproxy.net.tcp
@@ -32,21 +30,3 @@ skip_appveyor = pytest.mark.skipif(
def disable_alpn(monkeypatch):
monkeypatch.setattr(mitmproxy.net.tcp, 'HAS_ALPN', False)
monkeypatch.setattr(OpenSSL.SSL._lib, 'Cryptography_HAS_ALPN', False)
-
-
-################################################################################
-# TODO: remove this wrapper when pytest 3.1.0 is released
-original_pytest_raises = pytest.raises
-
-
-@contextmanager
-@functools.wraps(original_pytest_raises)
-def raises(exc, *args, **kwargs):
- with original_pytest_raises(exc, *args, **kwargs) as exc_info:
- yield
- if 'match' in kwargs:
- assert exc_info.match(kwargs['match'])
-
-
-pytest.raises = raises
-################################################################################
diff --git a/test/mitmproxy/addons/test_clientplayback.py b/test/mitmproxy/addons/test_clientplayback.py
index 7ffda317..6089b2d5 100644
--- a/test/mitmproxy/addons/test_clientplayback.py
+++ b/test/mitmproxy/addons/test_clientplayback.py
@@ -10,9 +10,10 @@ from mitmproxy.test import taddons
def tdump(path, flows):
- w = io.FlowWriter(open(path, "wb"))
- for i in flows:
- w.add(i)
+ with open(path, "wb") as f:
+ w = io.FlowWriter(f)
+ for i in flows:
+ w.add(i)
class MockThread():
diff --git a/test/mitmproxy/addons/test_save.py b/test/mitmproxy/addons/test_save.py
index 85c2a398..a4e425cd 100644
--- a/test/mitmproxy/addons/test_save.py
+++ b/test/mitmproxy/addons/test_save.py
@@ -26,8 +26,9 @@ def test_configure(tmpdir):
def rd(p):
- x = io.FlowReader(open(p, "rb"))
- return list(x.stream())
+ with open(p, "rb") as f:
+ x = io.FlowReader(f)
+ return list(x.stream())
def test_tcp(tmpdir):
diff --git a/test/mitmproxy/addons/test_serverplayback.py b/test/mitmproxy/addons/test_serverplayback.py
index 3ceab3fa..7605a5d9 100644
--- a/test/mitmproxy/addons/test_serverplayback.py
+++ b/test/mitmproxy/addons/test_serverplayback.py
@@ -11,9 +11,10 @@ from mitmproxy import io
def tdump(path, flows):
- w = io.FlowWriter(open(path, "wb"))
- for i in flows:
- w.add(i)
+ with open(path, "wb") as f:
+ w = io.FlowWriter(f)
+ for i in flows:
+ w.add(i)
def test_load_file(tmpdir):
diff --git a/test/mitmproxy/addons/test_view.py b/test/mitmproxy/addons/test_view.py
index 6da13650..d5a3a456 100644
--- a/test/mitmproxy/addons/test_view.py
+++ b/test/mitmproxy/addons/test_view.py
@@ -132,9 +132,10 @@ def test_filter():
def tdump(path, flows):
- w = io.FlowWriter(open(path, "wb"))
- for i in flows:
- w.add(i)
+ with open(path, "wb") as f:
+ w = io.FlowWriter(f)
+ for i in flows:
+ w.add(i)
def test_create():
diff --git a/test/mitmproxy/contentviews/test_protobuf.py b/test/mitmproxy/contentviews/test_protobuf.py
index 31e382ec..71e51576 100644
--- a/test/mitmproxy/contentviews/test_protobuf.py
+++ b/test/mitmproxy/contentviews/test_protobuf.py
@@ -17,7 +17,9 @@ def test_view_protobuf_request():
m.configure_mock(**attrs)
n.return_value = m
- content_type, output = v(open(p, "rb").read())
+ with open(p, "rb") as f:
+ data = f.read()
+ content_type, output = v(data)
assert content_type == "Protobuf"
assert output[0] == [('text', b'1: "3bbc333c-e61c-433b-819a-0b9a8cc103b8"')]
diff --git a/test/mitmproxy/contentviews/test_wbxml.py b/test/mitmproxy/contentviews/test_wbxml.py
index 777ab4dd..09c770e7 100644
--- a/test/mitmproxy/contentviews/test_wbxml.py
+++ b/test/mitmproxy/contentviews/test_wbxml.py
@@ -1 +1,21 @@
-# TODO: write tests
+from mitmproxy.contentviews import wbxml
+from mitmproxy.test import tutils
+from . import full_eval
+
+data = tutils.test_data.push("mitmproxy/contentviews/test_wbxml_data/")
+
+
+def test_wbxml():
+ v = full_eval(wbxml.ViewWBXML())
+
+ assert v(b'\x03\x01\x6A\x00') == ('WBXML', [[('text', '<?xml version="1.0" ?>')]])
+ assert v(b'foo') is None
+
+ path = data.path("data.wbxml") # File taken from https://github.com/davidpshaw/PyWBXMLDecoder/tree/master/wbxml_samples
+ with open(path, 'rb') as f:
+ input = f.read()
+ with open("-formatted.".join(path.rsplit(".", 1))) as f:
+ expected = f.read()
+
+ p = wbxml.ASCommandResponse.ASCommandResponse(input)
+ assert p.xmlString == expected
diff --git a/test/mitmproxy/contentviews/test_wbxml_data/data-formatted.wbxml b/test/mitmproxy/contentviews/test_wbxml_data/data-formatted.wbxml
new file mode 100644
index 00000000..fed293bd
--- /dev/null
+++ b/test/mitmproxy/contentviews/test_wbxml_data/data-formatted.wbxml
@@ -0,0 +1,10 @@
+<?xml version="1.0" ?>
+<Sync>
+ <Collections>
+ <Collection>
+ <SyncKey>1509029063</SyncKey>
+ <CollectionId>7</CollectionId>
+ <Status>1</Status>
+ </Collection>
+ </Collections>
+</Sync>
diff --git a/test/mitmproxy/contentviews/test_wbxml_data/data.wbxml b/test/mitmproxy/contentviews/test_wbxml_data/data.wbxml
new file mode 100644
index 00000000..7c7a2004
--- /dev/null
+++ b/test/mitmproxy/contentviews/test_wbxml_data/data.wbxml
Binary files differ
diff --git a/test/mitmproxy/contrib/test_tls_parser.py b/test/mitmproxy/contrib/test_tls_parser.py
deleted file mode 100644
index 66972b62..00000000
--- a/test/mitmproxy/contrib/test_tls_parser.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from mitmproxy.contrib import tls_parser
-
-
-def test_parse_chrome():
- """
- Test if we properly parse a ClientHello sent by Chrome 54.
- """
- data = bytes.fromhex(
- "03033b70638d2523e1cba15f8364868295305e9c52aceabda4b5147210abc783e6e1000022c02bc02fc02cc030"
- "cca9cca8cc14cc13c009c013c00ac014009c009d002f0035000a0100006cff0100010000000010000e00000b65"
- "78616d706c652e636f6d0017000000230000000d00120010060106030501050304010403020102030005000501"
- "00000000001200000010000e000c02683208687474702f312e3175500000000b00020100000a00080006001d00"
- "170018"
- )
- c = tls_parser.ClientHello.parse(data)
- assert c.version.major == 3
- assert c.version.minor == 3
-
- alpn = [a for a in c.extensions.extensions if a.type == 16]
- assert len(alpn) == 1
- assert alpn[0].alpn_protocols == [b"h2", b"http/1.1"]
-
- sni = [a for a in c.extensions.extensions if a.type == 0]
- assert len(sni) == 1
- assert sni[0].server_names[0].name_type == 0
- assert sni[0].server_names[0].host_name == b"example.com"
-
-
-def test_parse_no_extensions():
- data = bytes.fromhex(
- "03015658a756ab2c2bff55f636814deac086b7ca56b65058c7893ffc6074f5245f70205658a75475103a152637"
- "78e1bb6d22e8bbd5b6b0a3a59760ad354e91ba20d353001a0035002f000a000500040009000300060008006000"
- "61006200640100"
- )
- c = tls_parser.ClientHello.parse(data)
- assert c.version.major == 3
- assert c.version.minor == 1
- assert c.extensions is None
diff --git a/test/mitmproxy/net/test_tcp.py b/test/mitmproxy/net/test_tcp.py
index 81d51888..adf8701a 100644
--- a/test/mitmproxy/net/test_tcp.py
+++ b/test/mitmproxy/net/test_tcp.py
@@ -34,7 +34,7 @@ class ClientCipherListHandler(tcp.BaseHandler):
sni = None
def handle(self):
- self.wfile.write("%s" % self.connection.get_cipher_list())
+ self.wfile.write(str(self.connection.get_cipher_list()).encode())
self.wfile.flush()
@@ -398,7 +398,8 @@ class TestServerCipherList(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl(sni="foo.com")
- assert c.rfile.readline() == b"['AES256-GCM-SHA384']"
+ expected = b"['AES256-GCM-SHA384']"
+ assert c.rfile.read(len(expected) + 2) == expected
class TestServerCurrentCipher(tservers.ServerTestBase):
@@ -424,7 +425,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase):
class TestServerCipherListError(tservers.ServerTestBase):
handler = ClientCipherListHandler
ssl = dict(
- cipher_list='bogus'
+ cipher_list=b'bogus'
)
def test_echo(self):
@@ -632,6 +633,7 @@ class TestTCPServer:
with s.handler_counter:
with pytest.raises(exceptions.Timeout):
s.wait_for_silence()
+ s.shutdown()
class TestFileLike:
diff --git a/test/mitmproxy/net/tservers.py b/test/mitmproxy/net/tservers.py
index ebe6d3eb..44701aa5 100644
--- a/test/mitmproxy/net/tservers.py
+++ b/test/mitmproxy/net/tservers.py
@@ -16,9 +16,6 @@ class _ServerThread(threading.Thread):
def run(self):
self.server.serve_forever()
- def shutdown(self):
- self.server.shutdown()
-
class _TServer(tcp.TCPServer):
@@ -54,9 +51,9 @@ class _TServer(tcp.TCPServer):
raw_key = self.ssl.get(
"key",
tutils.test_data.path("mitmproxy/net/data/server.key"))
- key = OpenSSL.crypto.load_privatekey(
- OpenSSL.crypto.FILETYPE_PEM,
- open(raw_key, "rb").read())
+ with open(raw_key) as f:
+ raw_key = f.read()
+ key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, raw_key)
if self.ssl.get("v3_only", False):
method = OpenSSL.SSL.SSLv3_METHOD
options = OpenSSL.SSL.OP_NO_SSLv2 | OpenSSL.SSL.OP_NO_TLSv1
@@ -64,7 +61,8 @@ class _TServer(tcp.TCPServer):
method = OpenSSL.SSL.SSLv23_METHOD
options = None
h.convert_to_ssl(
- cert, key,
+ cert,
+ key,
method=method,
options=options,
handle_sni=getattr(h, "handle_sni", None),
@@ -103,7 +101,7 @@ class ServerTestBase:
@classmethod
def teardown_class(cls):
- cls.server.shutdown()
+ cls.server.server.shutdown()
def teardown(self):
self.server.server.wait_for_silence()
diff --git a/test/mitmproxy/platform/test_pf.py b/test/mitmproxy/platform/test_pf.py
index f644bcc5..3292d345 100644
--- a/test/mitmproxy/platform/test_pf.py
+++ b/test/mitmproxy/platform/test_pf.py
@@ -9,10 +9,11 @@ class TestLookup:
def test_simple(self):
if sys.platform == "freebsd10":
p = tutils.test_data.path("mitmproxy/data/pf02")
- d = open(p, "rb").read()
else:
p = tutils.test_data.path("mitmproxy/data/pf01")
- d = open(p, "rb").read()
+ with open(p, "rb") as f:
+ d = f.read()
+
assert pf.lookup("192.168.1.111", 40000, d) == ("5.5.5.5", 80)
with pytest.raises(Exception, match="Could not resolve original destination"):
pf.lookup("192.168.1.112", 40000, d)
diff --git a/test/mitmproxy/proxy/protocol/test_http1.py b/test/mitmproxy/proxy/protocol/test_http1.py
index 07cd7dcc..b642afb3 100644
--- a/test/mitmproxy/proxy/protocol/test_http1.py
+++ b/test/mitmproxy/proxy/protocol/test_http1.py
@@ -65,6 +65,7 @@ class TestExpectHeader(tservers.HTTPProxyTest):
assert resp.status_code == 200
client.finish()
+ client.close()
class TestHeadContentLength(tservers.HTTPProxyTest):
diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py
index b07257b3..261f8415 100644
--- a/test/mitmproxy/proxy/protocol/test_http2.py
+++ b/test/mitmproxy/proxy/protocol/test_http2.py
@@ -118,12 +118,16 @@ class _Http2TestBase:
self.master.reset([])
self.server.server.handle_server_event = self.handle_server_event
- def _setup_connection(self):
- client = mitmproxy.net.tcp.TCPClient(("127.0.0.1", self.proxy.port))
- client.connect()
+ def teardown(self):
+ if self.client:
+ self.client.close()
+
+ def setup_connection(self):
+ self.client = mitmproxy.net.tcp.TCPClient(("127.0.0.1", self.proxy.port))
+ self.client.connect()
# send CONNECT request
- client.wfile.write(http1.assemble_request(mitmproxy.net.http.Request(
+ self.client.wfile.write(http1.assemble_request(mitmproxy.net.http.Request(
'authority',
b'CONNECT',
b'',
@@ -134,13 +138,13 @@ class _Http2TestBase:
[(b'host', b'localhost:%d' % self.server.server.address[1])],
b'',
)))
- client.wfile.flush()
+ self.client.wfile.flush()
# read CONNECT response
- while client.rfile.readline() != b"\r\n":
+ while self.client.rfile.readline() != b"\r\n":
pass
- client.convert_to_ssl(alpn_protos=[b'h2'])
+ self.client.convert_to_ssl(alpn_protos=[b'h2'])
config = h2.config.H2Configuration(
client_side=True,
@@ -148,10 +152,10 @@ class _Http2TestBase:
validate_inbound_headers=False)
h2_conn = h2.connection.H2Connection(config)
h2_conn.initiate_connection()
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
- return client, h2_conn
+ return h2_conn
def _send_request(self,
wfile,
@@ -205,8 +209,8 @@ class TestSimple(_Http2Test):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
- assert (b'client-foo', b'client-bar-1') in event.headers
- assert (b'client-foo', b'client-bar-2') in event.headers
+ assert (b'self.client-foo', b'self.client-bar-1') in event.headers
+ assert (b'self.client-foo', b'self.client-bar-2') in event.headers
elif isinstance(event, h2.events.StreamEnded):
import warnings
with warnings.catch_warnings():
@@ -233,32 +237,32 @@ class TestSimple(_Http2Test):
def test_simple(self):
response_body_buffer = b''
- client, h2_conn = self._setup_connection()
+ h2_conn = self.setup_connection()
self._send_request(
- client.wfile,
+ self.client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
- ('ClIeNt-FoO', 'client-bar-1'),
- ('ClIeNt-FoO', 'client-bar-2'),
+ ('self.client-FoO', 'self.client-bar-1'),
+ ('self.client-FoO', 'self.client-bar-2'),
],
body=b'request body')
done = False
while not done:
try:
- raw = b''.join(http2.read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
assert False
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
for event in events:
if isinstance(event, h2.events.DataReceived):
@@ -267,8 +271,8 @@ class TestSimple(_Http2Test):
done = True
h2_conn.close_connection()
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
assert len(self.master.state.flows) == 1
assert self.master.state.flows[0].response.status_code == 200
@@ -317,10 +321,10 @@ class TestRequestWithPriority(_Http2Test):
def test_request_with_priority(self, http2_priority_enabled, priority, expected_priority):
self.config.options.http2_priority = http2_priority_enabled
- client, h2_conn = self._setup_connection()
+ h2_conn = self.setup_connection()
self._send_request(
- client.wfile,
+ self.client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
@@ -336,22 +340,22 @@ class TestRequestWithPriority(_Http2Test):
done = False
while not done:
try:
- raw = b''.join(http2.read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
assert False
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
for event in events:
if isinstance(event, h2.events.StreamEnded):
done = True
h2_conn.close_connection()
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
assert len(self.master.state.flows) == 1
@@ -397,15 +401,15 @@ class TestPriority(_Http2Test):
self.config.options.http2_priority = http2_priority_enabled
self.__class__.priority_data = []
- client, h2_conn = self._setup_connection()
+ h2_conn = self.setup_connection()
if prioritize_before:
h2_conn.prioritize(1, exclusive=priority[0], depends_on=priority[1], weight=priority[2])
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
self._send_request(
- client.wfile,
+ self.client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
@@ -419,28 +423,28 @@ class TestPriority(_Http2Test):
if not prioritize_before:
h2_conn.prioritize(1, exclusive=priority[0], depends_on=priority[1], weight=priority[2])
h2_conn.end_stream(1)
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
done = False
while not done:
try:
- raw = b''.join(http2.read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
assert False
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
for event in events:
if isinstance(event, h2.events.StreamEnded):
done = True
h2_conn.close_connection()
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
assert len(self.master.state.flows) == 1
assert self.priority_data == expected_priority
@@ -460,10 +464,10 @@ class TestStreamResetFromServer(_Http2Test):
return True
def test_request_with_priority(self):
- client, h2_conn = self._setup_connection()
+ h2_conn = self.setup_connection()
self._send_request(
- client.wfile,
+ self.client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
@@ -476,22 +480,22 @@ class TestStreamResetFromServer(_Http2Test):
done = False
while not done:
try:
- raw = b''.join(http2.read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
assert False
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
for event in events:
if isinstance(event, h2.events.StreamReset):
done = True
h2_conn.close_connection()
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
assert len(self.master.state.flows) == 1
assert self.master.state.flows[0].response is None
@@ -510,10 +514,10 @@ class TestBodySizeLimit(_Http2Test):
self.config.options.body_size_limit = "20"
self.config.options._processed["body_size_limit"] = 20
- client, h2_conn = self._setup_connection()
+ h2_conn = self.setup_connection()
self._send_request(
- client.wfile,
+ self.client.wfile,
h2_conn,
headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
@@ -527,22 +531,22 @@ class TestBodySizeLimit(_Http2Test):
done = False
while not done:
try:
- raw = b''.join(http2.read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
assert False
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
for event in events:
if isinstance(event, h2.events.StreamReset):
done = True
h2_conn.close_connection()
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
assert len(self.master.state.flows) == 0
@@ -609,9 +613,9 @@ class TestPushPromise(_Http2Test):
return True
def test_push_promise(self):
- client, h2_conn = self._setup_connection()
+ h2_conn = self.setup_connection()
- self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
+ self._send_request(self.client.wfile, h2_conn, stream_id=1, headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'),
(':scheme', 'https'),
@@ -625,15 +629,15 @@ class TestPushPromise(_Http2Test):
responses = 0
while not done:
try:
- raw = b''.join(http2.read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
assert False
except:
break
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
for event in events:
if isinstance(event, h2.events.StreamEnded):
@@ -649,8 +653,8 @@ class TestPushPromise(_Http2Test):
done = True
h2_conn.close_connection()
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
assert ended_streams == 3
assert pushed_streams == 2
@@ -665,9 +669,9 @@ class TestPushPromise(_Http2Test):
assert len(pushed_flows) == 2
def test_push_promise_reset(self):
- client, h2_conn = self._setup_connection()
+ h2_conn = self.setup_connection()
- self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
+ self._send_request(self.client.wfile, h2_conn, stream_id=1, headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'),
(':scheme', 'https'),
@@ -681,14 +685,14 @@ class TestPushPromise(_Http2Test):
responses = 0
while not done:
try:
- raw = b''.join(http2.read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
assert False
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
for event in events:
if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1:
@@ -696,8 +700,8 @@ class TestPushPromise(_Http2Test):
elif isinstance(event, h2.events.PushedStreamReceived):
pushed_streams += 1
h2_conn.reset_stream(event.pushed_stream_id, error_code=0x8)
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
elif isinstance(event, h2.events.ResponseReceived):
responses += 1
if isinstance(event, h2.events.ConnectionTerminated):
@@ -707,8 +711,8 @@ class TestPushPromise(_Http2Test):
done = True
h2_conn.close_connection()
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
bodies = [flow.response.content for flow in self.master.state.flows if flow.response]
assert len(bodies) >= 1
@@ -728,9 +732,9 @@ class TestConnectionLost(_Http2Test):
return False
def test_connection_lost(self):
- client, h2_conn = self._setup_connection()
+ h2_conn = self.setup_connection()
- self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
+ self._send_request(self.client.wfile, h2_conn, stream_id=1, headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'),
(':scheme', 'https'),
@@ -741,7 +745,7 @@ class TestConnectionLost(_Http2Test):
done = False
while not done:
try:
- raw = b''.join(http2.read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(self.client.rfile))
h2_conn.receive_data(raw)
except exceptions.HttpException:
print(traceback.format_exc())
@@ -749,8 +753,8 @@ class TestConnectionLost(_Http2Test):
except:
break
try:
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
except:
break
@@ -782,12 +786,12 @@ class TestMaxConcurrentStreams(_Http2Test):
return True
def test_max_concurrent_streams(self):
- client, h2_conn = self._setup_connection()
+ h2_conn = self.setup_connection()
new_streams = [1, 3, 5, 7, 9, 11]
for stream_id in new_streams:
# this will exceed MAX_CONCURRENT_STREAMS on the server connection
# and cause mitmproxy to throttle stream creation to the server
- self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[
+ self._send_request(self.client.wfile, h2_conn, stream_id=stream_id, headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'),
(':scheme', 'https'),
@@ -798,20 +802,20 @@ class TestMaxConcurrentStreams(_Http2Test):
ended_streams = 0
while ended_streams != len(new_streams):
try:
- header, body = http2.read_raw_frame(client.rfile)
+ header, body = http2.read_raw_frame(self.client.rfile)
events = h2_conn.receive_data(b''.join([header, body]))
except:
break
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
for event in events:
if isinstance(event, h2.events.StreamEnded):
ended_streams += 1
h2_conn.close_connection()
- client.wfile.write(h2_conn.data_to_send())
- client.wfile.flush()
+ self.client.wfile.write(h2_conn.data_to_send())
+ self.client.wfile.flush()
assert len(self.master.state.flows) == len(new_streams)
for flow in self.master.state.flows:
@@ -831,9 +835,9 @@ class TestConnectionTerminated(_Http2Test):
return True
def test_connection_terminated(self):
- client, h2_conn = self._setup_connection()
+ h2_conn = self.setup_connection()
- self._send_request(client.wfile, h2_conn, headers=[
+ self._send_request(self.client.wfile, h2_conn, headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
(':method', 'GET'),
(':scheme', 'https'),
@@ -844,7 +848,7 @@ class TestConnectionTerminated(_Http2Test):
connection_terminated_event = None
while not done:
try:
- raw = b''.join(http2.read_raw_frame(client.rfile))
+ raw = b''.join(http2.read_raw_frame(self.client.rfile))
events = h2_conn.receive_data(raw)
for event in events:
if isinstance(event, h2.events.ConnectionTerminated):
diff --git a/test/mitmproxy/proxy/protocol/test_tls.py b/test/mitmproxy/proxy/protocol/test_tls.py
index e17ee46f..980ba7bd 100644
--- a/test/mitmproxy/proxy/protocol/test_tls.py
+++ b/test/mitmproxy/proxy/protocol/test_tls.py
@@ -23,4 +23,5 @@ class TestClientHello:
)
c = TlsClientHello(data)
assert c.sni == 'example.com'
- assert c.alpn_protocols == [b'h2', b'http/1.1']
+ assert c.alpn_protocols[0].name == b'h2'
+ assert c.alpn_protocols[1].name == b'http/1.1'
diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py
index 8dfc4f2b..f78e173f 100644
--- a/test/mitmproxy/proxy/protocol/test_websocket.py
+++ b/test/mitmproxy/proxy/protocol/test_websocket.py
@@ -79,9 +79,13 @@ class _WebSocketTestBase:
self.master.reset([])
self.server.server.handle_websockets = self.handle_websockets
- def _setup_connection(self):
- client = tcp.TCPClient(("127.0.0.1", self.proxy.port))
- client.connect()
+ def teardown(self):
+ if self.client:
+ self.client.close()
+
+ def setup_connection(self):
+ self.client = tcp.TCPClient(("127.0.0.1", self.proxy.port))
+ self.client.connect()
request = http.Request(
"authority",
@@ -92,14 +96,14 @@ class _WebSocketTestBase:
"",
"HTTP/1.1",
content=b'')
- client.wfile.write(http.http1.assemble_request(request))
- client.wfile.flush()
+ self.client.wfile.write(http.http1.assemble_request(request))
+ self.client.wfile.flush()
- response = http.http1.read_response(client.rfile, request)
+ response = http.http1.read_response(self.client.rfile, request)
if self.ssl:
- client.convert_to_ssl()
- assert client.ssl_established
+ self.client.convert_to_ssl()
+ assert self.client.ssl_established
request = http.Request(
"relative",
@@ -116,14 +120,12 @@ class _WebSocketTestBase:
sec_websocket_key="1234",
),
content=b'')
- client.wfile.write(http.http1.assemble_request(request))
- client.wfile.flush()
+ self.client.wfile.write(http.http1.assemble_request(request))
+ self.client.wfile.flush()
- response = http.http1.read_response(client.rfile, request)
+ response = http.http1.read_response(self.client.rfile, request)
assert websockets.check_handshake(response.headers)
- return client
-
class _WebSocketTest(_WebSocketTestBase, _WebSocketServerBase):
@@ -154,25 +156,25 @@ class TestSimple(_WebSocketTest):
wfile.flush()
def test_simple(self):
- client = self._setup_connection()
+ self.setup_connection()
- frame = websockets.Frame.from_file(client.rfile)
+ frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'server-foobar'
- client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
- client.wfile.flush()
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
+ self.client.wfile.flush()
- frame = websockets.Frame.from_file(client.rfile)
- assert frame.payload == b'client-foobar'
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.payload == b'self.client-foobar'
- client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
- client.wfile.flush()
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.BINARY, payload=b'\xde\xad\xbe\xef')))
+ self.client.wfile.flush()
- frame = websockets.Frame.from_file(client.rfile)
+ frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'\xde\xad\xbe\xef'
- client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
- client.wfile.flush()
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ self.client.wfile.flush()
assert len(self.master.state.flows) == 2
assert isinstance(self.master.state.flows[0], HTTPFlow)
@@ -180,9 +182,9 @@ class TestSimple(_WebSocketTest):
assert len(self.master.state.flows[1].messages) == 5
assert self.master.state.flows[1].messages[0].content == b'server-foobar'
assert self.master.state.flows[1].messages[0].type == websockets.OPCODE.TEXT
- assert self.master.state.flows[1].messages[1].content == b'client-foobar'
+ assert self.master.state.flows[1].messages[1].content == b'self.client-foobar'
assert self.master.state.flows[1].messages[1].type == websockets.OPCODE.TEXT
- assert self.master.state.flows[1].messages[2].content == b'client-foobar'
+ assert self.master.state.flows[1].messages[2].content == b'self.client-foobar'
assert self.master.state.flows[1].messages[2].type == websockets.OPCODE.TEXT
assert self.master.state.flows[1].messages[3].content == b'\xde\xad\xbe\xef'
assert self.master.state.flows[1].messages[3].type == websockets.OPCODE.BINARY
@@ -203,19 +205,19 @@ class TestSimpleTLS(_WebSocketTest):
wfile.flush()
def test_simple_tls(self):
- client = self._setup_connection()
+ self.setup_connection()
- frame = websockets.Frame.from_file(client.rfile)
+ frame = websockets.Frame.from_file(self.client.rfile)
assert frame.payload == b'server-foobar'
- client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
- client.wfile.flush()
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'self.client-foobar')))
+ self.client.wfile.flush()
- frame = websockets.Frame.from_file(client.rfile)
- assert frame.payload == b'client-foobar'
+ frame = websockets.Frame.from_file(self.client.rfile)
+ assert frame.payload == b'self.client-foobar'
- client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
- client.wfile.flush()
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ self.client.wfile.flush()
class TestPing(_WebSocketTest):
@@ -233,16 +235,16 @@ class TestPing(_WebSocketTest):
wfile.flush()
def test_ping(self):
- client = self._setup_connection()
+ self.setup_connection()
- frame = websockets.Frame.from_file(client.rfile)
+ frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.opcode == websockets.OPCODE.PING
assert frame.payload == b'foobar'
- client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
- client.wfile.flush()
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
+ self.client.wfile.flush()
- frame = websockets.Frame.from_file(client.rfile)
+ frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.opcode == websockets.OPCODE.TEXT
assert frame.payload == b'pong-received'
@@ -259,12 +261,12 @@ class TestPong(_WebSocketTest):
wfile.flush()
def test_pong(self):
- client = self._setup_connection()
+ self.setup_connection()
- client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
- client.wfile.flush()
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
+ self.client.wfile.flush()
- frame = websockets.Frame.from_file(client.rfile)
+ frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.opcode == websockets.OPCODE.PONG
assert frame.payload == b'foobar'
@@ -282,34 +284,34 @@ class TestClose(_WebSocketTest):
websockets.Frame.from_file(rfile)
def test_close(self):
- client = self._setup_connection()
+ self.setup_connection()
- client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
- client.wfile.flush()
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ self.client.wfile.flush()
- websockets.Frame.from_file(client.rfile)
+ websockets.Frame.from_file(self.client.rfile)
with pytest.raises(exceptions.TcpDisconnect):
- websockets.Frame.from_file(client.rfile)
+ websockets.Frame.from_file(self.client.rfile)
def test_close_payload_1(self):
- client = self._setup_connection()
+ self.setup_connection()
- client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
- client.wfile.flush()
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
+ self.client.wfile.flush()
- websockets.Frame.from_file(client.rfile)
+ websockets.Frame.from_file(self.client.rfile)
with pytest.raises(exceptions.TcpDisconnect):
- websockets.Frame.from_file(client.rfile)
+ websockets.Frame.from_file(self.client.rfile)
def test_close_payload_2(self):
- client = self._setup_connection()
+ self.setup_connection()
- client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
- client.wfile.flush()
+ self.client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
+ self.client.wfile.flush()
- websockets.Frame.from_file(client.rfile)
+ websockets.Frame.from_file(self.client.rfile)
with pytest.raises(exceptions.TcpDisconnect):
- websockets.Frame.from_file(client.rfile)
+ websockets.Frame.from_file(self.client.rfile)
class TestInvalidFrame(_WebSocketTest):
@@ -320,9 +322,9 @@ class TestInvalidFrame(_WebSocketTest):
wfile.flush()
def test_invalid_frame(self):
- client = self._setup_connection()
+ self.setup_connection()
# with pytest.raises(exceptions.TcpDisconnect):
- frame = websockets.Frame.from_file(client.rfile)
+ frame = websockets.Frame.from_file(self.client.rfile)
assert frame.header.opcode == 15
assert frame.payload == b'foobar'
diff --git a/test/mitmproxy/test_connections.py b/test/mitmproxy/test_connections.py
index e320885d..99367bb6 100644
--- a/test/mitmproxy/test_connections.py
+++ b/test/mitmproxy/test_connections.py
@@ -1,5 +1,4 @@
import socket
-import os
import threading
import ssl
import OpenSSL
@@ -140,16 +139,17 @@ class TestServerConnection:
assert d.last_log()
c.finish()
+ c.close()
d.shutdown()
def test_terminate_error(self):
d = test.Daemon()
c = connections.ServerConnection((d.IFACE, d.port))
c.connect()
+ c.close()
c.connection = mock.Mock()
c.connection.recv = mock.Mock(return_value=False)
c.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect)
- c.finish()
d.shutdown()
def test_sni(self):
@@ -194,22 +194,25 @@ class TestClientConnectionTLS:
s = socket.create_connection(address)
s = ctx.wrap_socket(s, server_hostname=sni)
s.send(b'foobar')
- s.shutdown(socket.SHUT_RDWR)
+ s.close()
threading.Thread(target=client_run).start()
connection, client_address = sock.accept()
c = connections.ClientConnection(connection, client_address, None)
cert = tutils.test_data.path("mitmproxy/net/data/server.crt")
+ with open(tutils.test_data.path("mitmproxy/net/data/server.key")) as f:
+ raw_key = f.read()
key = OpenSSL.crypto.load_privatekey(
OpenSSL.crypto.FILETYPE_PEM,
- open(tutils.test_data.path("mitmproxy/net/data/server.key"), "rb").read())
+ raw_key)
c.convert_to_ssl(cert, key)
assert c.connected()
assert c.sni == sni
assert c.tls_established
assert c.rfile.read(6) == b'foobar'
c.finish()
+ sock.close()
class TestServerConnectionTLS(tservers.ServerTestBase):
@@ -222,7 +225,7 @@ class TestServerConnectionTLS(tservers.ServerTestBase):
@pytest.mark.parametrize("clientcert", [
None,
tutils.test_data.path("mitmproxy/data/clientcert"),
- os.path.join(tutils.test_data.path("mitmproxy/data/clientcert"), "client.pem"),
+ tutils.test_data.path("mitmproxy/data/clientcert/client.pem"),
])
def test_tls(self, clientcert):
c = connections.ServerConnection(("127.0.0.1", self.port))
diff --git a/test/mitmproxy/test_flowfilter.py b/test/mitmproxy/test_flowfilter.py
index 46fff477..fe9b2408 100644
--- a/test/mitmproxy/test_flowfilter.py
+++ b/test/mitmproxy/test_flowfilter.py
@@ -209,6 +209,9 @@ class TestMatchingHTTPFlow:
assert self.q("~u address:22/path", q)
assert not self.q("~u moo/path", q)
+ q.request = None
+ assert not self.q("~u address", q)
+
assert self.q("~u address", s)
assert self.q("~u address:22/path", s)
assert not self.q("~u moo/path", s)
diff --git a/test/mitmproxy/test_proxy.py b/test/mitmproxy/test_proxy.py
index e1d0da00..299abab3 100644
--- a/test/mitmproxy/test_proxy.py
+++ b/test/mitmproxy/test_proxy.py
@@ -32,8 +32,7 @@ class TestProcessProxyOptions:
opts = options.Options()
cmdline.common_options(parser, opts)
args = parser.parse_args(args=args)
- main.process_options(parser, opts, args)
- pconf = config.ProxyConfig(opts)
+ pconf = main.process_options(parser, opts, args)
return parser, pconf
def assert_noerr(self, *args):
diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py
index b8005529..3a2050e1 100644
--- a/test/mitmproxy/tservers.py
+++ b/test/mitmproxy/tservers.py
@@ -133,7 +133,7 @@ class ProxyTestBase:
@classmethod
def teardown_class(cls):
- # perf: we want to run tests in parallell
+ # perf: we want to run tests in parallel
# should this ever cause an error, travis should catch it.
# shutil.rmtree(cls.cadir)
cls.proxy.shutdown()
diff --git a/test/pathod/language/test_base.py b/test/pathod/language/test_base.py
index ec460b07..910d298a 100644
--- a/test/pathod/language/test_base.py
+++ b/test/pathod/language/test_base.py
@@ -202,12 +202,14 @@ class TestMisc:
e.parseString("m@1")
s = base.Settings(staticdir=str(tmpdir))
- tmpdir.join("path").write_binary(b"a" * 20, ensure=True)
+ with open(str(tmpdir.join("path")), 'wb') as f:
+ f.write(b"a" * 20)
v = e.parseString("m<path")[0]
with pytest.raises(Exception, match="Invalid value length"):
v.values(s)
- tmpdir.join("path2").write_binary(b"a" * 4, ensure=True)
+ with open(str(tmpdir.join("path2")), 'wb') as f:
+ f.write(b"a" * 4)
v = e.parseString("m<path2")[0]
assert v.values(s)
diff --git a/test/pathod/language/test_generators.py b/test/pathod/language/test_generators.py
index dc15aaa1..5e64c726 100644
--- a/test/pathod/language/test_generators.py
+++ b/test/pathod/language/test_generators.py
@@ -14,18 +14,14 @@ def test_randomgenerator():
def test_filegenerator(tmpdir):
f = tmpdir.join("foo")
- f.write(b"x" * 10000)
+ f.write(b"abcdefghijklmnopqrstuvwxyz" * 1000)
g = generators.FileGenerator(str(f))
- assert len(g) == 10000
- assert g[0] == b"x"
- assert g[-1] == b"x"
- assert g[0:5] == b"xxxxx"
+ assert len(g) == 26000
+ assert g[0] == b"a"
+ assert g[2:7] == b"cdefg"
assert len(g[1:10]) == 9
- assert len(g[10000:10001]) == 0
+ assert len(g[26000:26001]) == 0
assert repr(g)
- # remove all references to FileGenerator instance to close the file
- # handle.
- del g
def test_transform_generator():
diff --git a/test/pathod/protocols/test_http2.py b/test/pathod/protocols/test_http2.py
index 1c074197..c16a6d40 100644
--- a/test/pathod/protocols/test_http2.py
+++ b/test/pathod/protocols/test_http2.py
@@ -202,7 +202,7 @@ class TestApplySettings(net_tservers.ServerTestBase):
def handle(self):
# check settings acknowledgement
assert self.rfile.read(9) == codecs.decode('000000040100000000', 'hex_codec')
- self.wfile.write("OK")
+ self.wfile.write(b"OK")
self.wfile.flush()
self.rfile.safe_read(9) # just to keep the connection alive a bit longer
diff --git a/test/pathod/test_test.py b/test/pathod/test_test.py
index 40f45f53..d51a2c7a 100644
--- a/test/pathod/test_test.py
+++ b/test/pathod/test_test.py
@@ -1,15 +1,9 @@
-import logging
+import os
import requests
import pytest
from pathod import test
-
-from mitmproxy.test import tutils
-
-import requests.packages.urllib3
-
-requests.packages.urllib3.disable_warnings()
-logging.disable(logging.CRITICAL)
+from pathod.pathod import SSLOptions, CA_CERT_NAME
class TestDaemonManual:
@@ -22,29 +16,17 @@ class TestDaemonManual:
with pytest.raises(requests.ConnectionError):
requests.get("http://localhost:%s/p/202:da" % d.port)
- def test_startstop_ssl(self):
- d = test.Daemon(ssl=True)
- rsp = requests.get(
- "https://localhost:%s/p/202:da" %
- d.port,
- verify=False)
- assert rsp.ok
- assert rsp.status_code == 202
- d.shutdown()
- with pytest.raises(requests.ConnectionError):
- requests.get("http://localhost:%s/p/202:da" % d.port)
-
- def test_startstop_ssl_explicit(self):
- ssloptions = dict(
- certfile=tutils.test_data.path("pathod/data/testkey.pem"),
- cacert=tutils.test_data.path("pathod/data/testkey.pem"),
- ssl_after_connect=False
+ @pytest.mark.parametrize('not_after_connect', [True, False])
+ def test_startstop_ssl(self, not_after_connect):
+ ssloptions = SSLOptions(
+ cn=b'localhost',
+ sans=[b'localhost', b'127.0.0.1'],
+ not_after_connect=not_after_connect,
)
- d = test.Daemon(ssl=ssloptions)
+ d = test.Daemon(ssl=True, ssloptions=ssloptions)
rsp = requests.get(
- "https://localhost:%s/p/202:da" %
- d.port,
- verify=False)
+ "https://localhost:%s/p/202:da" % d.port,
+ verify=os.path.expanduser(os.path.join(d.thread.server.ssloptions.confdir, CA_CERT_NAME)))
assert rsp.ok
assert rsp.status_code == 202
d.shutdown()
diff --git a/test/pathod/tservers.py b/test/pathod/tservers.py
index fab09288..a7c92964 100644
--- a/test/pathod/tservers.py
+++ b/test/pathod/tservers.py
@@ -1,3 +1,4 @@
+import os
import tempfile
import re
import shutil
@@ -13,6 +14,7 @@ from pathod import language
from pathod import pathoc
from pathod import pathod
from pathod import test
+from pathod.pathod import CA_CERT_NAME
def treader(bytes):
@@ -72,7 +74,7 @@ class DaemonTests:
self.d.port,
path
),
- verify=False,
+ verify=os.path.join(self.d.thread.server.ssloptions.confdir, CA_CERT_NAME),
params=params
)
return resp