aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/scripting/inlinescripts.rst2
-rw-r--r--examples/custom_contentviews.py2
-rw-r--r--examples/filt.py8
-rw-r--r--examples/flowwriter.py8
-rw-r--r--examples/har_extractor.py7
-rw-r--r--examples/iframe_injector.py7
-rw-r--r--examples/modify_response_body.py8
-rw-r--r--examples/proxapp.py2
-rw-r--r--examples/sslstrip.py2
-rw-r--r--examples/stub.py2
-rw-r--r--examples/tcp_message.py4
-rw-r--r--examples/tls_passthrough.py7
-rw-r--r--mitmproxy/script/script.py26
-rw-r--r--netlib/exceptions.py4
-rw-r--r--netlib/tcp.py11
-rw-r--r--pathod/test.py14
-rw-r--r--test/mitmproxy/data/scripts/a.py6
-rw-r--r--test/mitmproxy/data/scripts/concurrent_decorator_err.py2
-rw-r--r--test/mitmproxy/data/scripts/starterr.py4
-rw-r--r--test/netlib/http/http2/test_connections.py200
-rw-r--r--test/netlib/test_tcp.py429
-rw-r--r--test/netlib/tservers.py3
22 files changed, 400 insertions, 358 deletions
diff --git a/docs/scripting/inlinescripts.rst b/docs/scripting/inlinescripts.rst
index d282dfa6..2065923d 100644
--- a/docs/scripting/inlinescripts.rst
+++ b/docs/scripting/inlinescripts.rst
@@ -44,7 +44,7 @@ to store any form of state you require.
Script Lifecycle Events
^^^^^^^^^^^^^^^^^^^^^^^
-.. py:function:: start(context, argv)
+.. py:function:: start(context)
Called once on startup, before any other events.
diff --git a/examples/custom_contentviews.py b/examples/custom_contentviews.py
index 034f356c..05ebeb69 100644
--- a/examples/custom_contentviews.py
+++ b/examples/custom_contentviews.py
@@ -62,7 +62,7 @@ class ViewPigLatin(contentviews.View):
pig_view = ViewPigLatin()
-def start(context, argv):
+def start(context):
context.add_contentview(pig_view)
diff --git a/examples/filt.py b/examples/filt.py
index f99b675c..1a423845 100644
--- a/examples/filt.py
+++ b/examples/filt.py
@@ -1,13 +1,13 @@
# This scripts demonstrates how to use mitmproxy's filter pattern in inline scripts.
# Usage: mitmdump -s "filt.py FILTER"
-
+import sys
from mitmproxy import filt
-def start(context, argv):
- if len(argv) != 2:
+def start(context):
+ if len(sys.argv) != 2:
raise ValueError("Usage: -s 'filt.py FILTER'")
- context.filter = filt.parse(argv[1])
+ context.filter = filt.parse(sys.argv[1])
def response(context, flow):
diff --git a/examples/flowwriter.py b/examples/flowwriter.py
index 8fb8cc60..cb5ccb0d 100644
--- a/examples/flowwriter.py
+++ b/examples/flowwriter.py
@@ -4,14 +4,14 @@ import sys
from mitmproxy.flow import FlowWriter
-def start(context, argv):
- if len(argv) != 2:
+def start(context):
+ if len(sys.argv) != 2:
raise ValueError('Usage: -s "flowriter.py filename"')
- if argv[1] == "-":
+ if sys.argv[1] == "-":
f = sys.stdout
else:
- f = open(argv[1], "wb")
+ f = open(sys.argv[1], "wb")
context.flow_writer = FlowWriter(f)
diff --git a/examples/har_extractor.py b/examples/har_extractor.py
index 6806989d..c21f1a8f 100644
--- a/examples/har_extractor.py
+++ b/examples/har_extractor.py
@@ -3,6 +3,7 @@
https://github.com/JustusW/harparser to generate a HAR log object.
"""
import six
+import sys
from harparser import HAR
from datetime import datetime
@@ -52,15 +53,15 @@ class _HARLog(HAR.log):
return self.__page_list__
-def start(context, argv):
+def start(context):
"""
On start we create a HARLog instance. You will have to adapt this to
suit your actual needs of HAR generation. As it will probably be
necessary to cluster logs by IPs or reset them from time to time.
"""
context.dump_file = None
- if len(argv) > 1:
- context.dump_file = argv[1]
+ if len(sys.argv) > 1:
+ context.dump_file = sys.argv[1]
else:
raise ValueError(
'Usage: -s "har_extractor.py filename" '
diff --git a/examples/iframe_injector.py b/examples/iframe_injector.py
index ad844f19..9495da93 100644
--- a/examples/iframe_injector.py
+++ b/examples/iframe_injector.py
@@ -1,13 +1,14 @@
# Usage: mitmdump -s "iframe_injector.py url"
# (this script works best with --anticache)
+import sys
from bs4 import BeautifulSoup
from mitmproxy.models import decoded
-def start(context, argv):
- if len(argv) != 2:
+def start(context):
+ if len(sys.argv) != 2:
raise ValueError('Usage: -s "iframe_injector.py url"')
- context.iframe_url = argv[1]
+ context.iframe_url = sys.argv[1]
def response(context, flow):
diff --git a/examples/modify_response_body.py b/examples/modify_response_body.py
index d68bcf63..3034892e 100644
--- a/examples/modify_response_body.py
+++ b/examples/modify_response_body.py
@@ -1,14 +1,16 @@
# Usage: mitmdump -s "modify_response_body.py mitmproxy bananas"
# (this script works best with --anticache)
+import sys
+
from mitmproxy.models import decoded
-def start(context, argv):
- if len(argv) != 3:
+def start(context):
+ if len(sys.argv) != 3:
raise ValueError('Usage: -s "modify_response_body.py old new"')
# You may want to use Python's argparse for more sophisticated argument
# parsing.
- context.old, context.new = argv[1], argv[2]
+ context.old, context.new = sys.argv[1], sys.argv[2]
def response(context, flow):
diff --git a/examples/proxapp.py b/examples/proxapp.py
index 4d8e7b58..613d3f8b 100644
--- a/examples/proxapp.py
+++ b/examples/proxapp.py
@@ -15,7 +15,7 @@ def hello_world():
# Register the app using the magic domain "proxapp" on port 80. Requests to
# this domain and port combination will now be routed to the WSGI app instance.
-def start(context, argv):
+def start(context):
context.app_registry.add(app, "proxapp", 80)
# SSL works too, but the magic domain needs to be resolvable from the mitmproxy machine due to mitmproxy's design.
diff --git a/examples/sslstrip.py b/examples/sslstrip.py
index 1bc89946..8dde8e3e 100644
--- a/examples/sslstrip.py
+++ b/examples/sslstrip.py
@@ -3,7 +3,7 @@ import re
from six.moves import urllib
-def start(context, argv):
+def start(context):
# set of SSL/TLS capable hosts
context.secure_hosts = set()
diff --git a/examples/stub.py b/examples/stub.py
index 516b71a5..a0f73538 100644
--- a/examples/stub.py
+++ b/examples/stub.py
@@ -3,7 +3,7 @@
"""
-def start(context, argv):
+def start(context):
"""
Called once on script startup, before any other events.
"""
diff --git a/examples/tcp_message.py b/examples/tcp_message.py
index 2c210618..78500c19 100644
--- a/examples/tcp_message.py
+++ b/examples/tcp_message.py
@@ -1,4 +1,4 @@
-'''
+"""
tcp_message Inline Script Hook API Demonstration
------------------------------------------------
@@ -7,7 +7,7 @@ tcp_message Inline Script Hook API Demonstration
example cmdline invocation:
mitmdump -T --host --tcp ".*" -q -s examples/tcp_message.py
-'''
+"""
from netlib import strutils
diff --git a/examples/tls_passthrough.py b/examples/tls_passthrough.py
index 0c6d450d..50aab65b 100644
--- a/examples/tls_passthrough.py
+++ b/examples/tls_passthrough.py
@@ -24,6 +24,7 @@ from __future__ import (absolute_import, print_function, division)
import collections
import random
+import sys
from enum import Enum
from mitmproxy.exceptions import TlsProtocolException
@@ -110,9 +111,9 @@ class TlsFeedback(TlsLayer):
# inline script hooks below.
-def start(context, argv):
- if len(argv) == 2:
- context.tls_strategy = ProbabilisticStrategy(float(argv[1]))
+def start(context):
+ if len(sys.argv) == 2:
+ context.tls_strategy = ProbabilisticStrategy(float(sys.argv[1]))
else:
context.tls_strategy = ConservativeStrategy()
diff --git a/mitmproxy/script/script.py b/mitmproxy/script/script.py
index 70f74817..9ff79f52 100644
--- a/mitmproxy/script/script.py
+++ b/mitmproxy/script/script.py
@@ -6,15 +6,28 @@ by the mitmproxy-specific ScriptContext.
# Do not import __future__ here, this would apply transitively to the inline scripts.
from __future__ import absolute_import, print_function, division
+import inspect
import os
import shlex
import sys
+import contextlib
+import warnings
import six
from mitmproxy import exceptions
+@contextlib.contextmanager
+def setargs(args):
+ oldargs = sys.argv
+ sys.argv = args
+ try:
+ yield
+ finally:
+ sys.argv = oldargs
+
+
class Script(object):
"""
@@ -89,7 +102,15 @@ class Script(object):
finally:
sys.path.pop()
sys.path.pop()
- return self.run("start", self.args)
+
+ start_fn = self.ns.get("start")
+ if start_fn and len(inspect.getargspec(start_fn).args) == 2:
+ warnings.warn(
+ "The 'args' argument of the start() script hook is deprecated. "
+ "Please use sys.argv instead."
+ )
+ return self.run("start", self.args)
+ return self.run("start")
def unload(self):
try:
@@ -113,7 +134,8 @@ class Script(object):
f = self.ns.get(name)
if f:
try:
- return f(self.ctx, *args, **kwargs)
+ with setargs(self.args):
+ return f(self.ctx, *args, **kwargs)
except Exception:
six.reraise(
exceptions.ScriptException,
diff --git a/netlib/exceptions.py b/netlib/exceptions.py
index 05f1054b..dec79c22 100644
--- a/netlib/exceptions.py
+++ b/netlib/exceptions.py
@@ -54,3 +54,7 @@ class TlsException(NetlibException):
class InvalidCertificateException(TlsException):
pass
+
+
+class Timeout(TcpException):
+ pass
diff --git a/netlib/tcp.py b/netlib/tcp.py
index a8a68139..69dafc1f 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -967,3 +967,14 @@ class TCPServer(object):
"""
Called after server shutdown.
"""
+
+ def wait_for_silence(self, timeout=5):
+ start = time.time()
+ while 1:
+ if time.time() - start >= timeout:
+ raise exceptions.Timeout(
+ "%s service threads still alive" %
+ self.handler_counter.count
+ )
+ if self.handler_counter.count == 0:
+ return
diff --git a/pathod/test.py b/pathod/test.py
index 3ba541b1..4992945d 100644
--- a/pathod/test.py
+++ b/pathod/test.py
@@ -7,10 +7,6 @@ from . import pathod
from netlib import basethread
-class TimeoutError(Exception):
- pass
-
-
class Daemon:
IFACE = "127.0.0.1"
@@ -45,15 +41,7 @@ class Daemon:
return self.logfp.getvalue()
def wait_for_silence(self, timeout=5):
- start = time.time()
- while 1:
- if time.time() - start >= timeout:
- raise TimeoutError(
- "%s service threads still alive" %
- self.thread.server.handler_counter.count
- )
- if self.thread.server.handler_counter.count == 0:
- return
+ self.thread.server.wait_for_silence(timeout=timeout)
def expect_log(self, n, timeout=5):
l = []
diff --git a/test/mitmproxy/data/scripts/a.py b/test/mitmproxy/data/scripts/a.py
index d4272ac8..33dbaa64 100644
--- a/test/mitmproxy/data/scripts/a.py
+++ b/test/mitmproxy/data/scripts/a.py
@@ -1,11 +1,13 @@
+import sys
+
from a_helper import parser
var = 0
-def start(ctx, argv):
+def start(ctx):
global var
- var = parser.parse_args(argv[1:]).var
+ var = parser.parse_args(sys.argv[1:]).var
def here(ctx):
diff --git a/test/mitmproxy/data/scripts/concurrent_decorator_err.py b/test/mitmproxy/data/scripts/concurrent_decorator_err.py
index 071b8889..349e5dd6 100644
--- a/test/mitmproxy/data/scripts/concurrent_decorator_err.py
+++ b/test/mitmproxy/data/scripts/concurrent_decorator_err.py
@@ -2,5 +2,5 @@ from mitmproxy.script import concurrent
@concurrent
-def start(context, argv):
+def start(context):
pass
diff --git a/test/mitmproxy/data/scripts/starterr.py b/test/mitmproxy/data/scripts/starterr.py
index b217bdfe..82d773bd 100644
--- a/test/mitmproxy/data/scripts/starterr.py
+++ b/test/mitmproxy/data/scripts/starterr.py
@@ -1,3 +1,3 @@
-def start(ctx, argv):
- raise ValueError
+def start(ctx):
+ raise ValueError()
diff --git a/test/netlib/http/http2/test_connections.py b/test/netlib/http/http2/test_connections.py
index 27cc30ba..2a43627a 100644
--- a/test/netlib/http/http2/test_connections.py
+++ b/test/netlib/http/http2/test_connections.py
@@ -75,10 +75,10 @@ class TestCheckALPNMatch(tservers.ServerTestBase):
def test_check_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(alpn_protos=[b'h2'])
- protocol = HTTP2Protocol(c)
- assert protocol.check_alpn()
+ with c.connect():
+ c.convert_to_ssl(alpn_protos=[b'h2'])
+ protocol = HTTP2Protocol(c)
+ assert protocol.check_alpn()
class TestCheckALPNMismatch(tservers.ServerTestBase):
@@ -91,11 +91,11 @@ class TestCheckALPNMismatch(tservers.ServerTestBase):
def test_check_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(alpn_protos=[b'h2'])
- protocol = HTTP2Protocol(c)
- with raises(NotImplementedError):
- protocol.check_alpn()
+ with c.connect():
+ c.convert_to_ssl(alpn_protos=[b'h2'])
+ protocol = HTTP2Protocol(c)
+ with raises(NotImplementedError):
+ protocol.check_alpn()
class TestPerformServerConnectionPreface(tservers.ServerTestBase):
@@ -124,15 +124,15 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
def test_perform_server_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- protocol = HTTP2Protocol(c)
+ with c.connect():
+ protocol = HTTP2Protocol(c)
- assert not protocol.connection_preface_performed
- protocol.perform_server_connection_preface()
- assert protocol.connection_preface_performed
+ assert not protocol.connection_preface_performed
+ protocol.perform_server_connection_preface()
+ assert protocol.connection_preface_performed
- with raises(TcpDisconnect):
- protocol.perform_server_connection_preface(force=True)
+ with raises(TcpDisconnect):
+ protocol.perform_server_connection_preface(force=True)
class TestPerformClientConnectionPreface(tservers.ServerTestBase):
@@ -160,12 +160,12 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase):
def test_perform_client_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- protocol = HTTP2Protocol(c)
+ with c.connect():
+ protocol = HTTP2Protocol(c)
- assert not protocol.connection_preface_performed
- protocol.perform_client_connection_preface()
- assert protocol.connection_preface_performed
+ assert not protocol.connection_preface_performed
+ protocol.perform_client_connection_preface()
+ assert protocol.connection_preface_performed
class TestClientStreamIds(object):
@@ -209,24 +209,24 @@ class TestApplySettings(tservers.ServerTestBase):
def test_apply_settings(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c)
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c)
- protocol._apply_settings({
- hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo',
- hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar',
- hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef',
- })
+ protocol._apply_settings({
+ hyperframe.frame.SettingsFrame.ENABLE_PUSH: 'foo',
+ hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS: 'bar',
+ hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE: 'deadbeef',
+ })
- assert c.rfile.safe_read(2) == b"OK"
+ assert c.rfile.safe_read(2) == b"OK"
- assert protocol.http2_settings[
- hyperframe.frame.SettingsFrame.ENABLE_PUSH] == 'foo'
- assert protocol.http2_settings[
- hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar'
- assert protocol.http2_settings[
- hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef'
+ assert protocol.http2_settings[
+ hyperframe.frame.SettingsFrame.ENABLE_PUSH] == 'foo'
+ assert protocol.http2_settings[
+ hyperframe.frame.SettingsFrame.MAX_CONCURRENT_STREAMS] == 'bar'
+ assert protocol.http2_settings[
+ hyperframe.frame.SettingsFrame.INITIAL_WINDOW_SIZE] == 'deadbeef'
class TestCreateHeaders(object):
@@ -304,19 +304,19 @@ class TestReadRequest(tservers.ServerTestBase):
def test_read_request(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c, is_server=True)
- protocol.connection_preface_performed = True
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c, is_server=True)
+ protocol.connection_preface_performed = True
- req = protocol.read_request(NotImplemented)
+ req = protocol.read_request(NotImplemented)
- assert req.stream_id
- assert req.headers.fields == ()
- assert req.method == "GET"
- assert req.path == "/"
- assert req.scheme == "https"
- assert req.content == b'foobar'
+ assert req.stream_id
+ assert req.headers.fields == ()
+ assert req.method == "GET"
+ assert req.path == "/"
+ assert req.scheme == "https"
+ assert req.content == b'foobar'
class TestReadRequestRelative(tservers.ServerTestBase):
@@ -330,16 +330,16 @@ class TestReadRequestRelative(tservers.ServerTestBase):
def test_asterisk_form(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c, is_server=True)
- protocol.connection_preface_performed = True
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c, is_server=True)
+ protocol.connection_preface_performed = True
- req = protocol.read_request(NotImplemented)
+ req = protocol.read_request(NotImplemented)
- assert req.first_line_format == "relative"
- assert req.method == "OPTIONS"
- assert req.path == "*"
+ assert req.first_line_format == "relative"
+ assert req.method == "OPTIONS"
+ assert req.path == "*"
class TestReadRequestAbsolute(tservers.ServerTestBase):
@@ -353,17 +353,17 @@ class TestReadRequestAbsolute(tservers.ServerTestBase):
def test_absolute_form(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c, is_server=True)
- protocol.connection_preface_performed = True
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c, is_server=True)
+ protocol.connection_preface_performed = True
- req = protocol.read_request(NotImplemented)
+ req = protocol.read_request(NotImplemented)
- assert req.first_line_format == "absolute"
- assert req.scheme == "http"
- assert req.host == "address"
- assert req.port == 22
+ assert req.first_line_format == "absolute"
+ assert req.scheme == "http"
+ assert req.host == "address"
+ assert req.port == 22
class TestReadRequestConnect(tservers.ServerTestBase):
@@ -379,22 +379,22 @@ class TestReadRequestConnect(tservers.ServerTestBase):
def test_connect(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c, is_server=True)
- protocol.connection_preface_performed = True
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c, is_server=True)
+ protocol.connection_preface_performed = True
- req = protocol.read_request(NotImplemented)
- assert req.first_line_format == "authority"
- assert req.method == "CONNECT"
- assert req.host == "address"
- assert req.port == 22
+ req = protocol.read_request(NotImplemented)
+ assert req.first_line_format == "authority"
+ assert req.method == "CONNECT"
+ assert req.host == "address"
+ assert req.port == 22
- req = protocol.read_request(NotImplemented)
- assert req.first_line_format == "authority"
- assert req.method == "CONNECT"
- assert req.host == "example.com"
- assert req.port == 443
+ req = protocol.read_request(NotImplemented)
+ assert req.first_line_format == "authority"
+ assert req.method == "CONNECT"
+ assert req.host == "example.com"
+ assert req.port == 443
class TestReadResponse(tservers.ServerTestBase):
@@ -411,19 +411,19 @@ class TestReadResponse(tservers.ServerTestBase):
def test_read_response(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c)
- protocol.connection_preface_performed = True
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c)
+ protocol.connection_preface_performed = True
- resp = protocol.read_response(NotImplemented, stream_id=42)
+ resp = protocol.read_response(NotImplemented, stream_id=42)
- assert resp.http_version == "HTTP/2.0"
- assert resp.status_code == 200
- assert resp.reason == ''
- assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
- assert resp.content == b'foobar'
- assert resp.timestamp_end
+ assert resp.http_version == "HTTP/2.0"
+ assert resp.status_code == 200
+ assert resp.reason == ''
+ assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
+ assert resp.content == b'foobar'
+ assert resp.timestamp_end
class TestReadEmptyResponse(tservers.ServerTestBase):
@@ -437,19 +437,19 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
def test_read_empty_response(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = HTTP2Protocol(c)
- protocol.connection_preface_performed = True
+ with c.connect():
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c)
+ protocol.connection_preface_performed = True
- resp = protocol.read_response(NotImplemented, stream_id=42)
+ resp = protocol.read_response(NotImplemented, stream_id=42)
- assert resp.stream_id == 42
- assert resp.http_version == "HTTP/2.0"
- assert resp.status_code == 200
- assert resp.reason == ''
- assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
- assert resp.content == b''
+ assert resp.stream_id == 42
+ assert resp.http_version == "HTTP/2.0"
+ assert resp.status_code == 200
+ assert resp.reason == ''
+ assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))
+ assert resp.content == b''
class TestAssembleRequest(object):
diff --git a/test/netlib/test_tcp.py b/test/netlib/test_tcp.py
index 083360b4..590bcc01 100644
--- a/test/netlib/test_tcp.py
+++ b/test/netlib/test_tcp.py
@@ -39,8 +39,21 @@ class ClientCipherListHandler(tcp.BaseHandler):
class HangHandler(tcp.BaseHandler):
def handle(self):
+ # Hang as long as the client connection is alive
while True:
- time.sleep(1)
+ try:
+ self.connection.setblocking(0)
+ ret = self.connection.recv(1)
+ # Client connection is dead...
+ if ret == "" or ret == b"":
+ return
+ except socket.error:
+ pass
+ except SSL.WantReadError:
+ pass
+ except Exception:
+ return
+ time.sleep(0.1)
class ALPNHandler(tcp.BaseHandler):
@@ -61,18 +74,18 @@ class TestServer(tservers.ServerTestBase):
def test_echo(self):
testval = b"echo!\n"
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ with c.connect():
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
def test_thread_start_error(self):
with mock.patch.object(threading.Thread, "start", side_effect=threading.ThreadError("nonewthread")) as m:
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- assert not c.rfile.read(1)
- assert m.called
- assert "nonewthread" in self.q.get_nowait()
+ with c.connect():
+ assert not c.rfile.read(1)
+ assert m.called
+ assert "nonewthread" in self.q.get_nowait()
self.test_echo()
@@ -92,9 +105,9 @@ class TestServerBind(tservers.ServerTestBase):
c = tcp.TCPClient(
("127.0.0.1", self.port), source_address=(
"127.0.0.1", random_port))
- c.connect()
- assert c.rfile.readline() == str(("127.0.0.1", random_port)).encode()
- return
+ with c.connect():
+ assert c.rfile.readline() == str(("127.0.0.1", random_port)).encode()
+ return
except TcpException: # port probably already in use
pass
@@ -106,10 +119,10 @@ class TestServerIPv6(tservers.ServerTestBase):
def test_echo(self):
testval = b"echo!\n"
c = tcp.TCPClient(tcp.Address(("::1", self.port), use_ipv6=True))
- c.connect()
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ with c.connect():
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
class TestEcho(tservers.ServerTestBase):
@@ -118,10 +131,10 @@ class TestEcho(tservers.ServerTestBase):
def test_echo(self):
testval = b"echo!\n"
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ with c.connect():
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
class HardDisconnectHandler(tcp.BaseHandler):
@@ -140,10 +153,10 @@ class TestFinishFail(tservers.ServerTestBase):
def test_disconnect_in_finish(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.wfile.write(b"foo\n")
- c.wfile.flush = mock.Mock(side_effect=TcpDisconnect)
- c.finish()
+ with c.connect():
+ c.wfile.write(b"foo\n")
+ c.wfile.flush = mock.Mock(side_effect=TcpDisconnect)
+ c.finish()
class TestServerSSL(tservers.ServerTestBase):
@@ -155,21 +168,21 @@ class TestServerSSL(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL)
- testval = b"echo!\n"
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ with c.connect():
+ c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL)
+ testval = b"echo!\n"
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
def test_get_current_cipher(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- assert not c.get_current_cipher()
- c.convert_to_ssl(sni=b"foo.com")
- ret = c.get_current_cipher()
- assert ret
- assert "AES" in ret[0]
+ with c.connect():
+ assert not c.get_current_cipher()
+ c.convert_to_ssl(sni=b"foo.com")
+ ret = c.get_current_cipher()
+ assert ret
+ assert "AES" in ret[0]
class TestSSLv3Only(tservers.ServerTestBase):
@@ -181,8 +194,8 @@ class TestSSLv3Only(tservers.ServerTestBase):
def test_failure(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com")
+ with c.connect():
+ tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com")
class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
@@ -195,49 +208,46 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
def test_mode_default_should_pass(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- c.convert_to_ssl()
+ with c.connect():
+ c.convert_to_ssl()
- # Verification errors should be saved even if connection isn't aborted
- # aborted
- assert c.ssl_verification_error is not None
+ # Verification errors should be saved even if connection isn't aborted
+ # aborted
+ assert c.ssl_verification_error is not None
- testval = b"echo!\n"
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ testval = b"echo!\n"
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
def test_mode_none_should_pass(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- c.convert_to_ssl(verify_options=SSL.VERIFY_NONE)
+ with c.connect():
+ c.convert_to_ssl(verify_options=SSL.VERIFY_NONE)
- # Verification errors should be saved even if connection isn't aborted
- assert c.ssl_verification_error is not None
+ # Verification errors should be saved even if connection isn't aborted
+ assert c.ssl_verification_error is not None
- testval = b"echo!\n"
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ testval = b"echo!\n"
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
def test_mode_strict_should_fail(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- with tutils.raises(InvalidCertificateException):
- c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
- verify_options=SSL.VERIFY_PEER,
- ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
- )
+ with c.connect():
+ with tutils.raises(InvalidCertificateException):
+ c.convert_to_ssl(
+ sni=b"example.mitmproxy.org",
+ verify_options=SSL.VERIFY_PEER,
+ ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
+ )
- assert c.ssl_verification_error is not None
+ assert c.ssl_verification_error is not None
- # Unknown issuing certificate authority for first certificate
- assert c.ssl_verification_error['errno'] == 18
- assert c.ssl_verification_error['depth'] == 0
+ # Unknown issuing certificate authority for first certificate
+ assert c.ssl_verification_error['errno'] == 18
+ assert c.ssl_verification_error['depth'] == 0
class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):
@@ -250,26 +260,23 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):
def test_should_fail_without_sni(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- with tutils.raises(TlsException):
- c.convert_to_ssl(
- verify_options=SSL.VERIFY_PEER,
- ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
- )
+ with c.connect():
+ with tutils.raises(TlsException):
+ c.convert_to_ssl(
+ verify_options=SSL.VERIFY_PEER,
+ ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
+ )
def test_should_fail(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- with tutils.raises(InvalidCertificateException):
- c.convert_to_ssl(
- sni=b"mitmproxy.org",
- verify_options=SSL.VERIFY_PEER,
- ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
- )
-
- assert c.ssl_verification_error is not None
+ with c.connect():
+ with tutils.raises(InvalidCertificateException):
+ c.convert_to_ssl(
+ sni=b"mitmproxy.org",
+ verify_options=SSL.VERIFY_PEER,
+ ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
+ )
+ assert c.ssl_verification_error is not None
class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
@@ -282,37 +289,35 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
def test_mode_strict_w_pemfile_should_pass(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
- verify_options=SSL.VERIFY_PEER,
- ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
- )
+ with c.connect():
+ c.convert_to_ssl(
+ sni=b"example.mitmproxy.org",
+ verify_options=SSL.VERIFY_PEER,
+ ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
+ )
- assert c.ssl_verification_error is None
+ assert c.ssl_verification_error is None
- testval = b"echo!\n"
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ testval = b"echo!\n"
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
def test_mode_strict_w_cadir_should_pass(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
-
- c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
- verify_options=SSL.VERIFY_PEER,
- ca_path=tutils.test_data.path("data/verificationcerts/")
- )
+ with c.connect():
+ c.convert_to_ssl(
+ sni=b"example.mitmproxy.org",
+ verify_options=SSL.VERIFY_PEER,
+ ca_path=tutils.test_data.path("data/verificationcerts/")
+ )
- assert c.ssl_verification_error is None
+ assert c.ssl_verification_error is None
- testval = b"echo!\n"
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
+ testval = b"echo!\n"
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
class TestSSLClientCert(tservers.ServerTestBase):
@@ -334,19 +339,19 @@ class TestSSLClientCert(tservers.ServerTestBase):
def test_clientcert(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(
- cert=tutils.test_data.path("data/clientcert/client.pem"))
- assert c.rfile.readline().strip() == b"1"
+ with c.connect():
+ c.convert_to_ssl(
+ cert=tutils.test_data.path("data/clientcert/client.pem"))
+ assert c.rfile.readline().strip() == b"1"
def test_clientcert_err(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- tutils.raises(
- TlsException,
- c.convert_to_ssl,
- cert=tutils.test_data.path("data/clientcert/make")
- )
+ with c.connect():
+ tutils.raises(
+ TlsException,
+ c.convert_to_ssl,
+ cert=tutils.test_data.path("data/clientcert/make")
+ )
class TestSNI(tservers.ServerTestBase):
@@ -365,10 +370,10 @@ class TestSNI(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(sni=b"foo.com")
- assert c.sni == b"foo.com"
- assert c.rfile.readline() == b"foo.com"
+ with c.connect():
+ c.convert_to_ssl(sni=b"foo.com")
+ assert c.sni == b"foo.com"
+ assert c.rfile.readline() == b"foo.com"
class TestServerCipherList(tservers.ServerTestBase):
@@ -379,9 +384,9 @@ class TestServerCipherList(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(sni=b"foo.com")
- assert c.rfile.readline() == b"['RC4-SHA']"
+ with c.connect():
+ c.convert_to_ssl(sni=b"foo.com")
+ assert c.rfile.readline() == b"['RC4-SHA']"
class TestServerCurrentCipher(tservers.ServerTestBase):
@@ -399,9 +404,9 @@ class TestServerCurrentCipher(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(sni=b"foo.com")
- assert b"RC4-SHA" in c.rfile.readline()
+ with c.connect():
+ c.convert_to_ssl(sni=b"foo.com")
+ assert b"RC4-SHA" in c.rfile.readline()
class TestServerCipherListError(tservers.ServerTestBase):
@@ -412,8 +417,8 @@ class TestServerCipherListError(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com")
+ with c.connect():
+ tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com")
class TestClientCipherListError(tservers.ServerTestBase):
@@ -424,12 +429,13 @@ class TestClientCipherListError(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- tutils.raises(
- "cipher specification",
- c.convert_to_ssl,
- sni=b"foo.com",
- cipher_list="bogus")
+ with c.connect():
+ tutils.raises(
+ "cipher specification",
+ c.convert_to_ssl,
+ sni=b"foo.com",
+ cipher_list="bogus"
+ )
class TestSSLDisconnect(tservers.ServerTestBase):
@@ -443,13 +449,13 @@ class TestSSLDisconnect(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- # Excercise SSL.ZeroReturnError
- c.rfile.read(10)
- c.close()
- tutils.raises(TcpDisconnect, c.wfile.write, b"foo")
- tutils.raises(queue.Empty, self.q.get_nowait)
+ with c.connect():
+ c.convert_to_ssl()
+ # Excercise SSL.ZeroReturnError
+ c.rfile.read(10)
+ c.close()
+ tutils.raises(TcpDisconnect, c.wfile.write, b"foo")
+ tutils.raises(queue.Empty, self.q.get_nowait)
class TestSSLHardDisconnect(tservers.ServerTestBase):
@@ -458,23 +464,23 @@ class TestSSLHardDisconnect(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- # Exercise SSL.SysCallError
- c.rfile.read(10)
- c.close()
- tutils.raises(TcpDisconnect, c.wfile.write, b"foo")
+ with c.connect():
+ c.convert_to_ssl()
+ # Exercise SSL.SysCallError
+ c.rfile.read(10)
+ c.close()
+ tutils.raises(TcpDisconnect, c.wfile.write, b"foo")
class TestDisconnect(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.rfile.read(10)
- c.wfile.write(b"foo")
- c.close()
- c.close()
+ with c.connect():
+ c.rfile.read(10)
+ c.wfile.write(b"foo")
+ c.close()
+ c.close()
class TestServerTimeOut(tservers.ServerTestBase):
@@ -491,9 +497,9 @@ class TestServerTimeOut(tservers.ServerTestBase):
def test_timeout(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- time.sleep(0.3)
- assert self.last_handler.timeout
+ with c.connect():
+ time.sleep(0.3)
+ assert self.last_handler.timeout
class TestTimeOut(tservers.ServerTestBase):
@@ -501,10 +507,10 @@ class TestTimeOut(tservers.ServerTestBase):
def test_timeout(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.settimeout(0.1)
- assert c.gettimeout() == 0.1
- tutils.raises(TcpTimeout, c.rfile.read, 10)
+ with c.connect():
+ c.settimeout(0.1)
+ assert c.gettimeout() == 0.1
+ tutils.raises(TcpTimeout, c.rfile.read, 10)
class TestALPNClient(tservers.ServerTestBase):
@@ -516,25 +522,25 @@ class TestALPNClient(tservers.ServerTestBase):
if tcp.HAS_ALPN:
def test_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(alpn_protos=[b"foo", b"bar", b"fasel"])
- assert c.get_alpn_proto_negotiated() == b"bar"
- assert c.rfile.readline().strip() == b"bar"
+ with c.connect():
+ c.convert_to_ssl(alpn_protos=[b"foo", b"bar", b"fasel"])
+ assert c.get_alpn_proto_negotiated() == b"bar"
+ assert c.rfile.readline().strip() == b"bar"
def test_no_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- assert c.get_alpn_proto_negotiated() == b""
- assert c.rfile.readline().strip() == b"NONE"
+ with c.connect():
+ c.convert_to_ssl()
+ assert c.get_alpn_proto_negotiated() == b""
+ assert c.rfile.readline().strip() == b"NONE"
else:
def test_none_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl(alpn_protos=[b"foo", b"bar", b"fasel"])
- assert c.get_alpn_proto_negotiated() == b""
- assert c.rfile.readline() == b"NONE"
+ with c.connect():
+ c.convert_to_ssl(alpn_protos=[b"foo", b"bar", b"fasel"])
+ assert c.get_alpn_proto_negotiated() == b""
+ assert c.rfile.readline() == b"NONE"
class TestNoSSLNoALPNClient(tservers.ServerTestBase):
@@ -542,9 +548,9 @@ class TestNoSSLNoALPNClient(tservers.ServerTestBase):
def test_no_ssl_no_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- assert c.get_alpn_proto_negotiated() == b""
- assert c.rfile.readline().strip() == b"NONE"
+ with c.connect():
+ assert c.get_alpn_proto_negotiated() == b""
+ assert c.rfile.readline().strip() == b"NONE"
class TestSSLTimeOut(tservers.ServerTestBase):
@@ -553,10 +559,10 @@ class TestSSLTimeOut(tservers.ServerTestBase):
def test_timeout_client(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- c.settimeout(0.1)
- tutils.raises(TcpTimeout, c.rfile.read, 10)
+ with c.connect():
+ c.convert_to_ssl()
+ c.settimeout(0.1)
+ tutils.raises(TcpTimeout, c.rfile.read, 10)
class TestDHParams(tservers.ServerTestBase):
@@ -570,10 +576,10 @@ class TestDHParams(tservers.ServerTestBase):
def test_dhparams(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- ret = c.get_current_cipher()
- assert ret[0] == "DHE-RSA-AES256-SHA"
+ with c.connect():
+ c.convert_to_ssl()
+ ret = c.get_current_cipher()
+ assert ret[0] == "DHE-RSA-AES256-SHA"
def test_create_dhparams(self):
with tutils.tmpdir() as d:
@@ -718,33 +724,34 @@ class TestPeek(tservers.ServerTestBase):
handler = EchoHandler
def _connect(self, c):
- c.connect()
+ return c.connect()
def test_peek(self):
testval = b"peek!\n"
c = tcp.TCPClient(("127.0.0.1", self.port))
- self._connect(c)
- c.wfile.write(testval)
- c.wfile.flush()
+ with self._connect(c):
+ c.wfile.write(testval)
+ c.wfile.flush()
- assert c.rfile.peek(4) == b"peek"
- assert c.rfile.peek(6) == b"peek!\n"
- assert c.rfile.readline() == testval
+ assert c.rfile.peek(4) == b"peek"
+ assert c.rfile.peek(6) == b"peek!\n"
+ assert c.rfile.readline() == testval
- c.close()
- with tutils.raises(NetlibException):
- if c.rfile.peek(1) == b"":
- # Workaround for Python 2 on Unix:
- # Peeking a closed connection does not raise an exception here.
- raise NetlibException()
+ c.close()
+ with tutils.raises(NetlibException):
+ if c.rfile.peek(1) == b"":
+ # Workaround for Python 2 on Unix:
+ # Peeking a closed connection does not raise an exception here.
+ raise NetlibException()
class TestPeekSSL(TestPeek):
ssl = True
def _connect(self, c):
- c.connect()
- c.convert_to_ssl()
+ with c.connect() as conn:
+ c.convert_to_ssl()
+ return conn.pop()
class TestAddress:
@@ -774,16 +781,16 @@ class TestSSLKeyLogger(tservers.ServerTestBase):
tcp.log_ssl_key = tcp.SSLKeyLogger(logfile)
c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
- c.finish()
-
- tcp.log_ssl_key.close()
- with open(logfile, "rb") as f:
- assert f.read().count(b"CLIENT_RANDOM") == 2
+ with c.connect():
+ c.convert_to_ssl()
+ c.wfile.write(testval)
+ c.wfile.flush()
+ assert c.rfile.readline() == testval
+ c.finish()
+
+ tcp.log_ssl_key.close()
+ with open(logfile, "rb") as f:
+ assert f.read().count(b"CLIENT_RANDOM") == 2
tcp.log_ssl_key = _logfun
diff --git a/test/netlib/tservers.py b/test/netlib/tservers.py
index 569745e6..803aaa72 100644
--- a/test/netlib/tservers.py
+++ b/test/netlib/tservers.py
@@ -104,6 +104,9 @@ class ServerTestBase(object):
def teardown_class(cls):
cls.server.shutdown()
+ def teardown(self):
+ self.server.server.wait_for_silence()
+
@property
def last_handler(self):
return self.server.server.last_handler