aboutsummaryrefslogtreecommitdiffstats
path: root/test/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'test/netlib')
-rw-r--r--test/netlib/http/http1/test_assemble.py2
-rw-r--r--test/netlib/http/http1/test_read.py26
-rw-r--r--test/netlib/http/test_cookies.py21
-rw-r--r--test/netlib/http/test_headers.py11
-rw-r--r--test/netlib/http/test_message.py224
-rw-r--r--test/netlib/http/test_request.py8
-rw-r--r--test/netlib/test_encoding.py40
-rw-r--r--test/netlib/test_strutils.py71
-rw-r--r--test/netlib/test_tcp.py26
-rw-r--r--test/netlib/tservers.py12
10 files changed, 328 insertions, 113 deletions
diff --git a/test/netlib/http/http1/test_assemble.py b/test/netlib/http/http1/test_assemble.py
index 50d29384..841ea58a 100644
--- a/test/netlib/http/http1/test_assemble.py
+++ b/test/netlib/http/http1/test_assemble.py
@@ -24,7 +24,7 @@ def test_assemble_request():
def test_assemble_request_head():
- c = assemble_request_head(treq(content="foo"))
+ c = assemble_request_head(treq(content=b"foo"))
assert b"GET" in c
assert b"qvalue" in c
assert b"content-length" in c
diff --git a/test/netlib/http/http1/test_read.py b/test/netlib/http/http1/test_read.py
index 5285ac1d..c8a40ecb 100644
--- a/test/netlib/http/http1/test_read.py
+++ b/test/netlib/http/http1/test_read.py
@@ -1,6 +1,9 @@
from __future__ import absolute_import, print_function, division
+
from io import BytesIO
from mock import Mock
+import pytest
+
from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect, TcpDisconnect
from netlib.http import Headers
from netlib.http.http1.read import (
@@ -23,11 +26,18 @@ def test_get_header_tokens():
assert get_header_tokens(headers, "foo") == ["bar", "voing", "oink"]
-def test_read_request():
- rfile = BytesIO(b"GET / HTTP/1.1\r\n\r\nskip")
+@pytest.mark.parametrize("input", [
+ b"GET / HTTP/1.1\r\n\r\nskip",
+ b"GET / HTTP/1.1\r\n\r\nskip",
+ b"GET / HTTP/1.1\r\n\r\nskip",
+ b"GET / HTTP/1.1 \r\n\r\nskip",
+])
+def test_read_request(input):
+ rfile = BytesIO(input)
r = read_request(rfile)
assert r.method == "GET"
assert r.content == b""
+ assert r.http_version == "HTTP/1.1"
assert r.timestamp_end
assert rfile.read() == b"skip"
@@ -50,11 +60,19 @@ def test_read_request_head():
assert rfile.read() == b"skip"
-def test_read_response():
+@pytest.mark.parametrize("input", [
+ b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody",
+ b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody",
+ b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody",
+ b"HTTP/1.1 418 I'm a teapot \r\n\r\nbody",
+])
+def test_read_response(input):
req = treq()
- rfile = BytesIO(b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody")
+ rfile = BytesIO(input)
r = read_response(rfile, req)
+ assert r.http_version == "HTTP/1.1"
assert r.status_code == 418
+ assert r.reason == "I'm a teapot"
assert r.content == b"body"
assert r.timestamp_end
diff --git a/test/netlib/http/test_cookies.py b/test/netlib/http/test_cookies.py
index 83b85656..17e21b94 100644
--- a/test/netlib/http/test_cookies.py
+++ b/test/netlib/http/test_cookies.py
@@ -245,3 +245,24 @@ def test_refresh_cookie():
assert cookies.refresh_set_cookie_header(c, 0)
c = "foo/bar=bla"
assert cookies.refresh_set_cookie_header(c, 0)
+
+
+def test_is_expired():
+ CA = cookies.CookieAttrs
+
+ # A cookie can be expired
+ # by setting the expire time in the past
+ assert cookies.is_expired(CA([("Expires", "Thu, 01-Jan-1970 00:00:00 GMT")]))
+
+ # or by setting Max-Age to 0
+ assert cookies.is_expired(CA([("Max-Age", "0")]))
+
+ # or both
+ assert cookies.is_expired(CA([("Expires", "Thu, 01-Jan-1970 00:00:00 GMT"), ("Max-Age", "0")]))
+
+ assert not cookies.is_expired(CA([("Expires", "Thu, 24-Aug-2063 00:00:00 GMT")]))
+ assert not cookies.is_expired(CA([("Max-Age", "1")]))
+ assert not cookies.is_expired(CA([("Expires", "Thu, 15-Jul-2068 00:00:00 GMT"), ("Max-Age", "1")]))
+
+ assert not cookies.is_expired(CA([("Max-Age", "nan")]))
+ assert not cookies.is_expired(CA([("Expires", "false")]))
diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py
index 51819b86..51537310 100644
--- a/test/netlib/http/test_headers.py
+++ b/test/netlib/http/test_headers.py
@@ -1,4 +1,6 @@
-from netlib.http import Headers, parse_content_type
+import collections
+
+from netlib.http.headers import Headers, parse_content_type, assemble_content_type
from netlib.tutils import raises
@@ -81,3 +83,10 @@ def test_parse_content_type():
v = p("text/html; charset=UTF-8")
assert v == ('text', 'html', {'charset': 'UTF-8'})
+
+
+def test_assemble_content_type():
+ p = assemble_content_type
+ assert p("text", "html", {}) == "text/html"
+ assert p("text", "html", {"charset": "utf8"}) == "text/html; charset=utf8"
+ assert p("text", "html", collections.OrderedDict([("charset", "utf8"), ("foo", "bar")])) == "text/html; charset=utf8; foo=bar"
diff --git a/test/netlib/http/test_message.py b/test/netlib/http/test_message.py
index f5bf7f0c..deebd6f2 100644
--- a/test/netlib/http/test_message.py
+++ b/test/netlib/http/test_message.py
@@ -1,14 +1,17 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function, division
-from netlib.http import decoded
+import mock
+import six
+
from netlib.tutils import tresp
+from netlib import http, tutils
def _test_passthrough_attr(message, attr):
assert getattr(message, attr) == getattr(message.data, attr)
- setattr(message, attr, "foo")
- assert getattr(message.data, attr) == "foo"
+ setattr(message, attr, b"foo")
+ assert getattr(message.data, attr) == b"foo"
def _test_decoded_attr(message, attr):
@@ -68,6 +71,15 @@ class TestMessage(object):
assert resp != 0
+ def test_hash(self):
+ resp = tresp()
+ assert hash(resp)
+
+ def test_serializable(self):
+ resp = tresp()
+ resp2 = http.Response.from_state(resp.get_state())
+ assert resp == resp2
+
def test_content_length_update(self):
resp = tresp()
resp.content = b"foo"
@@ -76,9 +88,9 @@ class TestMessage(object):
resp.content = b""
assert resp.data.content == b""
assert resp.headers["content-length"] == "0"
-
- def test_content_basic(self):
- _test_passthrough_attr(tresp(), "content")
+ resp.raw_content = b"bar"
+ assert resp.data.content == b"bar"
+ assert resp.headers["content-length"] == "0"
def test_headers(self):
_test_passthrough_attr(tresp(), "headers")
@@ -89,65 +101,201 @@ class TestMessage(object):
def test_timestamp_end(self):
_test_passthrough_attr(tresp(), "timestamp_end")
- def teste_http_version(self):
+ def test_http_version(self):
_test_decoded_attr(tresp(), "http_version")
-class TestDecodedDecorator(object):
-
+class TestMessageContentEncoding(object):
def test_simple(self):
r = tresp()
- assert r.content == b"message"
+ assert r.raw_content == b"message"
assert "content-encoding" not in r.headers
- assert r.encode("gzip")
+ r.encode("gzip")
assert r.headers["content-encoding"]
- assert r.content != b"message"
- with decoded(r):
- assert "content-encoding" not in r.headers
- assert r.content == b"message"
- assert r.headers["content-encoding"]
- assert r.content != b"message"
+ assert r.raw_content != b"message"
+ assert r.content == b"message"
+ assert r.raw_content != b"message"
+
+ r.raw_content = b"foo"
+ with mock.patch("netlib.encoding.decode") as e:
+ assert r.content
+ assert e.call_count == 1
+ e.reset_mock()
+ assert r.content
+ assert e.call_count == 0
def test_modify(self):
r = tresp()
assert "content-encoding" not in r.headers
- assert r.encode("gzip")
+ r.encode("gzip")
- with decoded(r):
+ r.content = b"foo"
+ assert r.raw_content != b"foo"
+ r.decode()
+ assert r.raw_content == b"foo"
+
+ r.encode("identity")
+ with mock.patch("netlib.encoding.encode") as e:
r.content = b"foo"
+ assert e.call_count == 0
+ r.content = b"bar"
+ assert e.call_count == 1
- assert r.content != b"foo"
- r.decode()
- assert r.content == b"foo"
+ with tutils.raises(TypeError):
+ r.content = u"foo"
def test_unknown_ce(self):
r = tresp()
r.headers["content-encoding"] = "zopfli"
- r.content = b"foo"
- with decoded(r):
- assert r.headers["content-encoding"]
- assert r.content == b"foo"
+ r.raw_content = b"foo"
+ with tutils.raises(ValueError):
+ assert r.content
assert r.headers["content-encoding"]
- assert r.content == b"foo"
+ assert r.get_content(strict=False) == b"foo"
def test_cannot_decode(self):
r = tresp()
- assert r.encode("gzip")
- r.content = b"foo"
- with decoded(r):
- assert r.headers["content-encoding"]
- assert r.content == b"foo"
+ r.encode("gzip")
+ r.raw_content = b"foo"
+ with tutils.raises(ValueError):
+ assert r.content
assert r.headers["content-encoding"]
- assert r.content != b"foo"
- r.decode()
+ assert r.get_content(strict=False) == b"foo"
+
+ with tutils.raises(ValueError):
+ r.decode()
+ assert r.raw_content == b"foo"
+ assert "content-encoding" in r.headers
+
+ r.decode(strict=False)
assert r.content == b"foo"
+ assert "content-encoding" not in r.headers
+
+ def test_none(self):
+ r = tresp(content=None)
+ assert r.content is None
+ r.content = b"foo"
+ assert r.content is not None
+ r.content = None
+ assert r.content is None
def test_cannot_encode(self):
r = tresp()
- assert r.encode("gzip")
- with decoded(r):
- r.content = None
+ r.encode("gzip")
+ r.content = None
+ assert r.headers["content-encoding"]
+ assert r.raw_content is None
+ r.headers["content-encoding"] = "zopfli"
+ r.content = b"foo"
assert "content-encoding" not in r.headers
- assert r.content is None
+ assert r.raw_content == b"foo"
+
+ with tutils.raises(ValueError):
+ r.encode("zopfli")
+ assert r.raw_content == b"foo"
+ assert "content-encoding" not in r.headers
+
+
+class TestMessageText(object):
+ def test_simple(self):
+ r = tresp(content=b'\xfc')
+ assert r.raw_content == b"\xfc"
+ assert r.content == b"\xfc"
+ assert r.text == u"ü"
+
+ r.encode("gzip")
+ assert r.text == u"ü"
+ r.decode()
+ assert r.text == u"ü"
+
+ r.headers["content-type"] = "text/html; charset=latin1"
+ r.content = b"\xc3\xbc"
+ assert r.text == u"ü"
+ r.headers["content-type"] = "text/html; charset=utf8"
+ assert r.text == u"ü"
+
+ r.encode("identity")
+ r.raw_content = b"foo"
+ with mock.patch("netlib.encoding.decode") as e:
+ assert r.text
+ assert e.call_count == 2
+ e.reset_mock()
+ assert r.text
+ assert e.call_count == 0
+
+ def test_guess_json(self):
+ r = tresp(content=b'"\xc3\xbc"')
+ r.headers["content-type"] = "application/json"
+ assert r.text == u'"ü"'
+
+ def test_none(self):
+ r = tresp(content=None)
+ assert r.text is None
+ r.text = u"foo"
+ assert r.text is not None
+ r.text = None
+ assert r.text is None
+
+ def test_modify(self):
+ r = tresp()
+
+ r.text = u"ü"
+ assert r.raw_content == b"\xfc"
+
+ r.headers["content-type"] = "text/html; charset=utf8"
+ r.text = u"ü"
+ assert r.raw_content == b"\xc3\xbc"
+ assert r.headers["content-length"] == "2"
+
+ r.encode("identity")
+ with mock.patch("netlib.encoding.encode") as e:
+ e.return_value = b""
+ r.text = u"ü"
+ assert e.call_count == 0
+ r.text = u"ä"
+ assert e.call_count == 2
+
+ def test_unknown_ce(self):
+ r = tresp()
+ r.headers["content-type"] = "text/html; charset=wtf"
+ r.raw_content = b"foo"
+ with tutils.raises(ValueError):
+ assert r.text == u"foo"
+ assert r.get_text(strict=False) == u"foo"
+
+ def test_cannot_decode(self):
+ r = tresp()
+ r.headers["content-type"] = "text/html; charset=utf8"
+ r.raw_content = b"\xFF"
+ with tutils.raises(ValueError):
+ assert r.text
+
+ assert r.get_text(strict=False) == u'\ufffd' if six.PY2 else '\udcff'
+
+ def test_cannot_encode(self):
+ r = tresp()
+ r.content = None
+ assert "content-type" not in r.headers
+ assert r.raw_content is None
+
+ r.headers["content-type"] = "text/html; charset=latin1; foo=bar"
+ r.text = u"☃"
+ assert r.headers["content-type"] == "text/html; charset=utf-8; foo=bar"
+ assert r.raw_content == b'\xe2\x98\x83'
+
+ r.headers["content-type"] = "gibberish"
+ r.text = u"☃"
+ assert r.headers["content-type"] == "text/plain; charset=utf-8"
+ assert r.raw_content == b'\xe2\x98\x83'
+
+ del r.headers["content-type"]
+ r.text = u"☃"
+ assert r.headers["content-type"] == "text/plain; charset=utf-8"
+ assert r.raw_content == b'\xe2\x98\x83'
+
+ r.headers["content-type"] = "text/html; charset=latin1"
+ r.text = u'\udcff'
+ assert r.headers["content-type"] == "text/html; charset=utf-8"
+ assert r.raw_content == b'\xed\xb3\xbf' if six.PY2 else b"\xFF"
diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py
index c03db339..f3cd8b71 100644
--- a/test/netlib/http/test_request.py
+++ b/test/netlib/http/test_request.py
@@ -248,20 +248,20 @@ class TestRequestUtils(object):
assert "gzip" in request.headers["Accept-Encoding"]
def test_get_urlencoded_form(self):
- request = treq(content="foobar=baz")
+ request = treq(content=b"foobar=baz")
assert not request.urlencoded_form
request.headers["Content-Type"] = "application/x-www-form-urlencoded"
- assert list(request.urlencoded_form.items()) == [("foobar", "baz")]
+ assert list(request.urlencoded_form.items()) == [(b"foobar", b"baz")]
def test_set_urlencoded_form(self):
request = treq()
- request.urlencoded_form = [('foo', 'bar'), ('rab', 'oof')]
+ request.urlencoded_form = [(b'foo', b'bar'), (b'rab', b'oof')]
assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"
assert request.content
def test_get_multipart_form(self):
- request = treq(content="foobar")
+ request = treq(content=b"foobar")
assert not request.multipart_form
request.headers["Content-Type"] = "multipart/form-data"
diff --git a/test/netlib/test_encoding.py b/test/netlib/test_encoding.py
index 0ff1aad1..de10fc48 100644
--- a/test/netlib/test_encoding.py
+++ b/test/netlib/test_encoding.py
@@ -1,37 +1,39 @@
-from netlib import encoding
+from netlib import encoding, tutils
def test_identity():
- assert b"string" == encoding.decode("identity", b"string")
- assert b"string" == encoding.encode("identity", b"string")
- assert not encoding.encode("nonexistent", b"string")
- assert not encoding.decode("nonexistent encoding", b"string")
+ assert b"string" == encoding.decode(b"string", "identity")
+ assert b"string" == encoding.encode(b"string", "identity")
+ with tutils.raises(ValueError):
+ encoding.encode(b"string", "nonexistent encoding")
def test_gzip():
assert b"string" == encoding.decode(
- "gzip",
encoding.encode(
- "gzip",
- b"string"
- )
+ b"string",
+ "gzip"
+ ),
+ "gzip"
)
- assert encoding.decode("gzip", b"bogus") is None
+ with tutils.raises(ValueError):
+ encoding.decode(b"bogus", "gzip")
def test_deflate():
assert b"string" == encoding.decode(
- "deflate",
encoding.encode(
- "deflate",
- b"string"
- )
+ b"string",
+ "deflate"
+ ),
+ "deflate"
)
assert b"string" == encoding.decode(
- "deflate",
encoding.encode(
- "deflate",
- b"string"
- )[2:-4]
+ b"string",
+ "deflate"
+ )[2:-4],
+ "deflate"
)
- assert encoding.decode("deflate", b"bogus") is None
+ with tutils.raises(ValueError):
+ encoding.decode(b"bogus", "deflate")
diff --git a/test/netlib/test_strutils.py b/test/netlib/test_strutils.py
index 84a0dded..7c3eacc6 100644
--- a/test/netlib/test_strutils.py
+++ b/test/netlib/test_strutils.py
@@ -1,9 +1,15 @@
-# coding=utf-8
import six
from netlib import strutils, tutils
+def test_always_bytes():
+ assert strutils.always_bytes(bytes(bytearray(range(256)))) == bytes(bytearray(range(256)))
+ assert strutils.always_bytes("foo") == b"foo"
+ with tutils.raises(ValueError):
+ strutils.always_bytes(u"\u2605", "ascii")
+
+
def test_native():
with tutils.raises(TypeError):
strutils.native(42)
@@ -15,22 +21,26 @@ def test_native():
assert strutils.native(b"foo") == u"foo"
-def test_clean_bin():
- assert strutils.clean_bin(b"one") == b"one"
- assert strutils.clean_bin(b"\00ne") == b".ne"
- assert strutils.clean_bin(b"\nne") == b"\nne"
- assert strutils.clean_bin(b"\nne", False) == b".ne"
- assert strutils.clean_bin(u"\u2605".encode("utf8")) == b"..."
-
- assert strutils.clean_bin(u"one") == u"one"
- assert strutils.clean_bin(u"\00ne") == u".ne"
- assert strutils.clean_bin(u"\nne") == u"\nne"
- assert strutils.clean_bin(u"\nne", False) == u".ne"
- assert strutils.clean_bin(u"\u2605") == u"\u2605"
-
-
-def test_safe_subn():
- assert strutils.safe_subn("foo", u"bar", "\xc2foo")
+def test_escape_control_characters():
+ assert strutils.escape_control_characters(u"one") == u"one"
+ assert strutils.escape_control_characters(u"\00ne") == u".ne"
+ assert strutils.escape_control_characters(u"\nne") == u"\nne"
+ assert strutils.escape_control_characters(u"\nne", False) == u".ne"
+ assert strutils.escape_control_characters(u"\u2605") == u"\u2605"
+ assert (
+ strutils.escape_control_characters(bytes(bytearray(range(128))).decode()) ==
+ u'.........\t\n..\r.................. !"#$%&\'()*+,-./0123456789:;<'
+ u'=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~.'
+ )
+ assert (
+ strutils.escape_control_characters(bytes(bytearray(range(128))).decode(), False) ==
+ u'................................ !"#$%&\'()*+,-./0123456789:;<'
+ u'=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~.'
+ )
+
+ if not six.PY2:
+ with tutils.raises(ValueError):
+ strutils.escape_control_characters(b"foo")
def test_bytes_to_escaped_str():
@@ -41,6 +51,14 @@ def test_bytes_to_escaped_str():
assert strutils.bytes_to_escaped_str(b"'") == r"\'"
assert strutils.bytes_to_escaped_str(b'"') == r'"'
+ assert strutils.bytes_to_escaped_str(b"\r\n\t") == "\\r\\n\\t"
+ assert strutils.bytes_to_escaped_str(b"\r\n\t", True) == "\r\n\t"
+
+ assert strutils.bytes_to_escaped_str(b"\n", True) == "\n"
+ assert strutils.bytes_to_escaped_str(b"\\n", True) == "\\ \\ n".replace(" ", "")
+ assert strutils.bytes_to_escaped_str(b"\\\n", True) == "\\ \\ \n".replace(" ", "")
+ assert strutils.bytes_to_escaped_str(b"\\\\n", True) == "\\ \\ \\ \\ n".replace(" ", "")
+
with tutils.raises(ValueError):
strutils.bytes_to_escaped_str(u"such unicode")
@@ -49,10 +67,9 @@ def test_escaped_str_to_bytes():
assert strutils.escaped_str_to_bytes("foo") == b"foo"
assert strutils.escaped_str_to_bytes("\x08") == b"\b"
assert strutils.escaped_str_to_bytes("&!?=\\\\)") == br"&!?=\)"
- assert strutils.escaped_str_to_bytes("ü") == b'\xc3\xbc'
assert strutils.escaped_str_to_bytes(u"\\x08") == b"\b"
assert strutils.escaped_str_to_bytes(u"&!?=\\\\)") == br"&!?=\)"
- assert strutils.escaped_str_to_bytes(u"ü") == b'\xc3\xbc'
+ assert strutils.escaped_str_to_bytes(u"\u00fc") == b'\xc3\xbc'
if six.PY2:
with tutils.raises(ValueError):
@@ -62,17 +79,15 @@ def test_escaped_str_to_bytes():
strutils.escaped_str_to_bytes(b"very byte")
-def test_isBin():
- assert not strutils.isBin("testing\n\r")
- assert strutils.isBin("testing\x01")
- assert strutils.isBin("testing\x0e")
- assert strutils.isBin("testing\x7f")
+def test_is_mostly_bin():
+ assert not strutils.is_mostly_bin(b"foo\xFF")
+ assert strutils.is_mostly_bin(b"foo" + b"\xFF" * 10)
-def test_isXml():
- assert not strutils.isXML("foo")
- assert strutils.isXML("<foo")
- assert strutils.isXML(" \n<foo")
+def test_is_xml():
+ assert not strutils.is_xml(b"foo")
+ assert strutils.is_xml(b"<foo")
+ assert strutils.is_xml(b" \n<foo")
def test_clean_hanging_newline():
diff --git a/test/netlib/test_tcp.py b/test/netlib/test_tcp.py
index 590bcc01..273427d5 100644
--- a/test/netlib/test_tcp.py
+++ b/test/netlib/test_tcp.py
@@ -169,7 +169,7 @@ class TestServerSSL(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL)
+ c.convert_to_ssl(sni="foo.com", options=SSL.OP_ALL)
testval = b"echo!\n"
c.wfile.write(testval)
c.wfile.flush()
@@ -179,7 +179,7 @@ class TestServerSSL(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
assert not c.get_current_cipher()
- c.convert_to_ssl(sni=b"foo.com")
+ c.convert_to_ssl(sni="foo.com")
ret = c.get_current_cipher()
assert ret
assert "AES" in ret[0]
@@ -195,7 +195,7 @@ class TestSSLv3Only(tservers.ServerTestBase):
def test_failure(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com")
+ tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com")
class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
@@ -238,7 +238,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase):
with c.connect():
with tutils.raises(InvalidCertificateException):
c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
+ sni="example.mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
)
@@ -272,7 +272,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase):
with c.connect():
with tutils.raises(InvalidCertificateException):
c.convert_to_ssl(
- sni=b"mitmproxy.org",
+ sni="mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
)
@@ -291,7 +291,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
+ sni="example.mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt")
)
@@ -307,7 +307,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
c.convert_to_ssl(
- sni=b"example.mitmproxy.org",
+ sni="example.mitmproxy.org",
verify_options=SSL.VERIFY_PEER,
ca_path=tutils.test_data.path("data/verificationcerts/")
)
@@ -371,8 +371,8 @@ class TestSNI(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- c.convert_to_ssl(sni=b"foo.com")
- assert c.sni == b"foo.com"
+ c.convert_to_ssl(sni="foo.com")
+ assert c.sni == "foo.com"
assert c.rfile.readline() == b"foo.com"
@@ -385,7 +385,7 @@ class TestServerCipherList(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- c.convert_to_ssl(sni=b"foo.com")
+ c.convert_to_ssl(sni="foo.com")
assert c.rfile.readline() == b"['RC4-SHA']"
@@ -405,7 +405,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- c.convert_to_ssl(sni=b"foo.com")
+ c.convert_to_ssl(sni="foo.com")
assert b"RC4-SHA" in c.rfile.readline()
@@ -418,7 +418,7 @@ class TestServerCipherListError(tservers.ServerTestBase):
def test_echo(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
with c.connect():
- tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com")
+ tutils.raises("handshake error", c.convert_to_ssl, sni="foo.com")
class TestClientCipherListError(tservers.ServerTestBase):
@@ -433,7 +433,7 @@ class TestClientCipherListError(tservers.ServerTestBase):
tutils.raises(
"cipher specification",
c.convert_to_ssl,
- sni=b"foo.com",
+ sni="foo.com",
cipher_list="bogus"
)
diff --git a/test/netlib/tservers.py b/test/netlib/tservers.py
index 803aaa72..666f97ac 100644
--- a/test/netlib/tservers.py
+++ b/test/netlib/tservers.py
@@ -24,7 +24,7 @@ class _ServerThread(threading.Thread):
class _TServer(tcp.TCPServer):
- def __init__(self, ssl, q, handler_klass, addr):
+ def __init__(self, ssl, q, handler_klass, addr, **kwargs):
"""
ssl: A dictionary of SSL parameters:
@@ -42,6 +42,8 @@ class _TServer(tcp.TCPServer):
self.q = q
self.handler_klass = handler_klass
+ if self.handler_klass is not None:
+ self.handler_klass.kwargs = kwargs
self.last_handler = None
def handle_client_connection(self, request, client_address):
@@ -89,16 +91,16 @@ class ServerTestBase(object):
addr = ("localhost", 0)
@classmethod
- def setup_class(cls):
+ def setup_class(cls, **kwargs):
cls.q = queue.Queue()
- s = cls.makeserver()
+ s = cls.makeserver(**kwargs)
cls.port = s.address.port
cls.server = _ServerThread(s)
cls.server.start()
@classmethod
- def makeserver(cls):
- return _TServer(cls.ssl, cls.q, cls.handler, cls.addr)
+ def makeserver(cls, **kwargs):
+ return _TServer(cls.ssl, cls.q, cls.handler, cls.addr, **kwargs)
@classmethod
def teardown_class(cls):