diff options
Diffstat (limited to 'test/netlib')
-rw-r--r-- | test/netlib/http/http1/test_assemble.py | 2 | ||||
-rw-r--r-- | test/netlib/http/http1/test_read.py | 26 | ||||
-rw-r--r-- | test/netlib/http/test_cookies.py | 21 | ||||
-rw-r--r-- | test/netlib/http/test_headers.py | 11 | ||||
-rw-r--r-- | test/netlib/http/test_message.py | 224 | ||||
-rw-r--r-- | test/netlib/http/test_request.py | 8 | ||||
-rw-r--r-- | test/netlib/test_encoding.py | 40 | ||||
-rw-r--r-- | test/netlib/test_strutils.py | 71 | ||||
-rw-r--r-- | test/netlib/test_tcp.py | 26 | ||||
-rw-r--r-- | test/netlib/tservers.py | 12 |
10 files changed, 328 insertions, 113 deletions
diff --git a/test/netlib/http/http1/test_assemble.py b/test/netlib/http/http1/test_assemble.py index 50d29384..841ea58a 100644 --- a/test/netlib/http/http1/test_assemble.py +++ b/test/netlib/http/http1/test_assemble.py @@ -24,7 +24,7 @@ def test_assemble_request(): def test_assemble_request_head(): - c = assemble_request_head(treq(content="foo")) + c = assemble_request_head(treq(content=b"foo")) assert b"GET" in c assert b"qvalue" in c assert b"content-length" in c diff --git a/test/netlib/http/http1/test_read.py b/test/netlib/http/http1/test_read.py index 5285ac1d..c8a40ecb 100644 --- a/test/netlib/http/http1/test_read.py +++ b/test/netlib/http/http1/test_read.py @@ -1,6 +1,9 @@ from __future__ import absolute_import, print_function, division + from io import BytesIO from mock import Mock +import pytest + from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect, TcpDisconnect from netlib.http import Headers from netlib.http.http1.read import ( @@ -23,11 +26,18 @@ def test_get_header_tokens(): assert get_header_tokens(headers, "foo") == ["bar", "voing", "oink"] -def test_read_request(): - rfile = BytesIO(b"GET / HTTP/1.1\r\n\r\nskip") +@pytest.mark.parametrize("input", [ + b"GET / HTTP/1.1\r\n\r\nskip", + b"GET / HTTP/1.1\r\n\r\nskip", + b"GET / HTTP/1.1\r\n\r\nskip", + b"GET / HTTP/1.1 \r\n\r\nskip", +]) +def test_read_request(input): + rfile = BytesIO(input) r = read_request(rfile) assert r.method == "GET" assert r.content == b"" + assert r.http_version == "HTTP/1.1" assert r.timestamp_end assert rfile.read() == b"skip" @@ -50,11 +60,19 @@ def test_read_request_head(): assert rfile.read() == b"skip" -def test_read_response(): +@pytest.mark.parametrize("input", [ + b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody", + b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody", + b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody", + b"HTTP/1.1 418 I'm a teapot \r\n\r\nbody", +]) +def test_read_response(input): req = treq() - rfile = BytesIO(b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody") + rfile = BytesIO(input) r = read_response(rfile, req) + assert r.http_version == "HTTP/1.1" assert r.status_code == 418 + assert r.reason == "I'm a teapot" assert r.content == b"body" assert r.timestamp_end diff --git a/test/netlib/http/test_cookies.py b/test/netlib/http/test_cookies.py index 83b85656..17e21b94 100644 --- a/test/netlib/http/test_cookies.py +++ b/test/netlib/http/test_cookies.py @@ -245,3 +245,24 @@ def test_refresh_cookie(): assert cookies.refresh_set_cookie_header(c, 0) c = "foo/bar=bla" assert cookies.refresh_set_cookie_header(c, 0) + + +def test_is_expired(): + CA = cookies.CookieAttrs + + # A cookie can be expired + # by setting the expire time in the past + assert cookies.is_expired(CA([("Expires", "Thu, 01-Jan-1970 00:00:00 GMT")])) + + # or by setting Max-Age to 0 + assert cookies.is_expired(CA([("Max-Age", "0")])) + + # or both + assert cookies.is_expired(CA([("Expires", "Thu, 01-Jan-1970 00:00:00 GMT"), ("Max-Age", "0")])) + + assert not cookies.is_expired(CA([("Expires", "Thu, 24-Aug-2063 00:00:00 GMT")])) + assert not cookies.is_expired(CA([("Max-Age", "1")])) + assert not cookies.is_expired(CA([("Expires", "Thu, 15-Jul-2068 00:00:00 GMT"), ("Max-Age", "1")])) + + assert not cookies.is_expired(CA([("Max-Age", "nan")])) + assert not cookies.is_expired(CA([("Expires", "false")])) diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index 51819b86..51537310 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -1,4 +1,6 @@ -from netlib.http import Headers, parse_content_type +import collections + +from netlib.http.headers import Headers, parse_content_type, assemble_content_type from netlib.tutils import raises @@ -81,3 +83,10 @@ def test_parse_content_type(): v = p("text/html; charset=UTF-8") assert v == ('text', 'html', {'charset': 'UTF-8'}) + + +def test_assemble_content_type(): + p = assemble_content_type + assert p("text", "html", {}) == "text/html" + assert p("text", "html", {"charset": "utf8"}) == "text/html; charset=utf8" + assert p("text", "html", collections.OrderedDict([("charset", "utf8"), ("foo", "bar")])) == "text/html; charset=utf8; foo=bar" diff --git a/test/netlib/http/test_message.py b/test/netlib/http/test_message.py index f5bf7f0c..deebd6f2 100644 --- a/test/netlib/http/test_message.py +++ b/test/netlib/http/test_message.py @@ -1,14 +1,17 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function, division -from netlib.http import decoded +import mock +import six + from netlib.tutils import tresp +from netlib import http, tutils def _test_passthrough_attr(message, attr): assert getattr(message, attr) == getattr(message.data, attr) - setattr(message, attr, "foo") - assert getattr(message.data, attr) == "foo" + setattr(message, attr, b"foo") + assert getattr(message.data, attr) == b"foo" def _test_decoded_attr(message, attr): @@ -68,6 +71,15 @@ class TestMessage(object): assert resp != 0 + def test_hash(self): + resp = tresp() + assert hash(resp) + + def test_serializable(self): + resp = tresp() + resp2 = http.Response.from_state(resp.get_state()) + assert resp == resp2 + def test_content_length_update(self): resp = tresp() resp.content = b"foo" @@ -76,9 +88,9 @@ class TestMessage(object): resp.content = b"" assert resp.data.content == b"" assert resp.headers["content-length"] == "0" - - def test_content_basic(self): - _test_passthrough_attr(tresp(), "content") + resp.raw_content = b"bar" + assert resp.data.content == b"bar" + assert resp.headers["content-length"] == "0" def test_headers(self): _test_passthrough_attr(tresp(), "headers") @@ -89,65 +101,201 @@ class TestMessage(object): def test_timestamp_end(self): _test_passthrough_attr(tresp(), "timestamp_end") - def teste_http_version(self): + def test_http_version(self): _test_decoded_attr(tresp(), "http_version") -class TestDecodedDecorator(object): - +class TestMessageContentEncoding(object): def test_simple(self): r = tresp() - assert r.content == b"message" + assert r.raw_content == b"message" assert "content-encoding" not in r.headers - assert r.encode("gzip") + r.encode("gzip") assert r.headers["content-encoding"] - assert r.content != b"message" - with decoded(r): - assert "content-encoding" not in r.headers - assert r.content == b"message" - assert r.headers["content-encoding"] - assert r.content != b"message" + assert r.raw_content != b"message" + assert r.content == b"message" + assert r.raw_content != b"message" + + r.raw_content = b"foo" + with mock.patch("netlib.encoding.decode") as e: + assert r.content + assert e.call_count == 1 + e.reset_mock() + assert r.content + assert e.call_count == 0 def test_modify(self): r = tresp() assert "content-encoding" not in r.headers - assert r.encode("gzip") + r.encode("gzip") - with decoded(r): + r.content = b"foo" + assert r.raw_content != b"foo" + r.decode() + assert r.raw_content == b"foo" + + r.encode("identity") + with mock.patch("netlib.encoding.encode") as e: r.content = b"foo" + assert e.call_count == 0 + r.content = b"bar" + assert e.call_count == 1 - assert r.content != b"foo" - r.decode() - assert r.content == b"foo" + with tutils.raises(TypeError): + r.content = u"foo" def test_unknown_ce(self): r = tresp() r.headers["content-encoding"] = "zopfli" - r.content = b"foo" - with decoded(r): - assert r.headers["content-encoding"] - assert r.content == b"foo" + r.raw_content = b"foo" + with tutils.raises(ValueError): + assert r.content assert r.headers["content-encoding"] - assert r.content == b"foo" + assert r.get_content(strict=False) == b"foo" def test_cannot_decode(self): r = tresp() - assert r.encode("gzip") - r.content = b"foo" - with decoded(r): - assert r.headers["content-encoding"] - assert r.content == b"foo" + r.encode("gzip") + r.raw_content = b"foo" + with tutils.raises(ValueError): + assert r.content assert r.headers["content-encoding"] - assert r.content != b"foo" - r.decode() + assert r.get_content(strict=False) == b"foo" + + with tutils.raises(ValueError): + r.decode() + assert r.raw_content == b"foo" + assert "content-encoding" in r.headers + + r.decode(strict=False) assert r.content == b"foo" + assert "content-encoding" not in r.headers + + def test_none(self): + r = tresp(content=None) + assert r.content is None + r.content = b"foo" + assert r.content is not None + r.content = None + assert r.content is None def test_cannot_encode(self): r = tresp() - assert r.encode("gzip") - with decoded(r): - r.content = None + r.encode("gzip") + r.content = None + assert r.headers["content-encoding"] + assert r.raw_content is None + r.headers["content-encoding"] = "zopfli" + r.content = b"foo" assert "content-encoding" not in r.headers - assert r.content is None + assert r.raw_content == b"foo" + + with tutils.raises(ValueError): + r.encode("zopfli") + assert r.raw_content == b"foo" + assert "content-encoding" not in r.headers + + +class TestMessageText(object): + def test_simple(self): + r = tresp(content=b'\xfc') + assert r.raw_content == b"\xfc" + assert r.content == b"\xfc" + assert r.text == u"ü" + + r.encode("gzip") + assert r.text == u"ü" + r.decode() + assert r.text == u"ü" + + r.headers["content-type"] = "text/html; charset=latin1" + r.content = b"\xc3\xbc" + assert r.text == u"ü" + r.headers["content-type"] = "text/html; charset=utf8" + assert r.text == u"ü" + + r.encode("identity") + r.raw_content = b"foo" + with mock.patch("netlib.encoding.decode") as e: + assert r.text + assert e.call_count == 2 + e.reset_mock() + assert r.text + assert e.call_count == 0 + + def test_guess_json(self): + r = tresp(content=b'"\xc3\xbc"') + r.headers["content-type"] = "application/json" + assert r.text == u'"ü"' + + def test_none(self): + r = tresp(content=None) + assert r.text is None + r.text = u"foo" + assert r.text is not None + r.text = None + assert r.text is None + + def test_modify(self): + r = tresp() + + r.text = u"ü" + assert r.raw_content == b"\xfc" + + r.headers["content-type"] = "text/html; charset=utf8" + r.text = u"ü" + assert r.raw_content == b"\xc3\xbc" + assert r.headers["content-length"] == "2" + + r.encode("identity") + with mock.patch("netlib.encoding.encode") as e: + e.return_value = b"" + r.text = u"ü" + assert e.call_count == 0 + r.text = u"ä" + assert e.call_count == 2 + + def test_unknown_ce(self): + r = tresp() + r.headers["content-type"] = "text/html; charset=wtf" + r.raw_content = b"foo" + with tutils.raises(ValueError): + assert r.text == u"foo" + assert r.get_text(strict=False) == u"foo" + + def test_cannot_decode(self): + r = tresp() + r.headers["content-type"] = "text/html; charset=utf8" + r.raw_content = b"\xFF" + with tutils.raises(ValueError): + assert r.text + + assert r.get_text(strict=False) == u'\ufffd' if six.PY2 else '\udcff' + + def test_cannot_encode(self): + r = tresp() + r.content = None + assert "content-type" not in r.headers + assert r.raw_content is None + + r.headers["content-type"] = "text/html; charset=latin1; foo=bar" + r.text = u"☃" + assert r.headers["content-type"] == "text/html; charset=utf-8; foo=bar" + assert r.raw_content == b'\xe2\x98\x83' + + r.headers["content-type"] = "gibberish" + r.text = u"☃" + assert r.headers["content-type"] == "text/plain; charset=utf-8" + assert r.raw_content == b'\xe2\x98\x83' + + del r.headers["content-type"] + r.text = u"☃" + assert r.headers["content-type"] == "text/plain; charset=utf-8" + assert r.raw_content == b'\xe2\x98\x83' + + r.headers["content-type"] = "text/html; charset=latin1" + r.text = u'\udcff' + assert r.headers["content-type"] == "text/html; charset=utf-8" + assert r.raw_content == b'\xed\xb3\xbf' if six.PY2 else b"\xFF" diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py index c03db339..f3cd8b71 100644 --- a/test/netlib/http/test_request.py +++ b/test/netlib/http/test_request.py @@ -248,20 +248,20 @@ class TestRequestUtils(object): assert "gzip" in request.headers["Accept-Encoding"] def test_get_urlencoded_form(self): - request = treq(content="foobar=baz") + request = treq(content=b"foobar=baz") assert not request.urlencoded_form request.headers["Content-Type"] = "application/x-www-form-urlencoded" - assert list(request.urlencoded_form.items()) == [("foobar", "baz")] + assert list(request.urlencoded_form.items()) == [(b"foobar", b"baz")] def test_set_urlencoded_form(self): request = treq() - request.urlencoded_form = [('foo', 'bar'), ('rab', 'oof')] + request.urlencoded_form = [(b'foo', b'bar'), (b'rab', b'oof')] assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" assert request.content def test_get_multipart_form(self): - request = treq(content="foobar") + request = treq(content=b"foobar") assert not request.multipart_form request.headers["Content-Type"] = "multipart/form-data" diff --git a/test/netlib/test_encoding.py b/test/netlib/test_encoding.py index 0ff1aad1..de10fc48 100644 --- a/test/netlib/test_encoding.py +++ b/test/netlib/test_encoding.py @@ -1,37 +1,39 @@ -from netlib import encoding +from netlib import encoding, tutils def test_identity(): - assert b"string" == encoding.decode("identity", b"string") - assert b"string" == encoding.encode("identity", b"string") - assert not encoding.encode("nonexistent", b"string") - assert not encoding.decode("nonexistent encoding", b"string") + assert b"string" == encoding.decode(b"string", "identity") + assert b"string" == encoding.encode(b"string", "identity") + with tutils.raises(ValueError): + encoding.encode(b"string", "nonexistent encoding") def test_gzip(): assert b"string" == encoding.decode( - "gzip", encoding.encode( - "gzip", - b"string" - ) + b"string", + "gzip" + ), + "gzip" ) - assert encoding.decode("gzip", b"bogus") is None + with tutils.raises(ValueError): + encoding.decode(b"bogus", "gzip") def test_deflate(): assert b"string" == encoding.decode( - "deflate", encoding.encode( - "deflate", - b"string" - ) + b"string", + "deflate" + ), + "deflate" ) assert b"string" == encoding.decode( - "deflate", encoding.encode( - "deflate", - b"string" - )[2:-4] + b"string", + "deflate" + )[2:-4], + "deflate" ) - assert encoding.decode("deflate", b"bogus") is None + with tutils.raises(ValueError): + encoding.decode(b"bogus", "deflate") diff --git a/test/netlib/test_strutils.py b/test/netlib/test_strutils.py index 84a0dded..7c3eacc6 100644 --- a/test/netlib/test_strutils.py +++ b/test/netlib/test_strutils.py @@ -1,9 +1,15 @@ -# coding=utf-8 import six from netlib import strutils, tutils +def test_always_bytes(): + assert strutils.always_bytes(bytes(bytearray(range(256)))) == bytes(bytearray(range(256))) + assert strutils.always_bytes("foo") == b"foo" + with tutils.raises(ValueError): + strutils.always_bytes(u"\u2605", "ascii") + + def test_native(): with tutils.raises(TypeError): strutils.native(42) @@ -15,22 +21,26 @@ def test_native(): assert strutils.native(b"foo") == u"foo" -def test_clean_bin(): - assert strutils.clean_bin(b"one") == b"one" - assert strutils.clean_bin(b"\00ne") == b".ne" - assert strutils.clean_bin(b"\nne") == b"\nne" - assert strutils.clean_bin(b"\nne", False) == b".ne" - assert strutils.clean_bin(u"\u2605".encode("utf8")) == b"..." - - assert strutils.clean_bin(u"one") == u"one" - assert strutils.clean_bin(u"\00ne") == u".ne" - assert strutils.clean_bin(u"\nne") == u"\nne" - assert strutils.clean_bin(u"\nne", False) == u".ne" - assert strutils.clean_bin(u"\u2605") == u"\u2605" - - -def test_safe_subn(): - assert strutils.safe_subn("foo", u"bar", "\xc2foo") +def test_escape_control_characters(): + assert strutils.escape_control_characters(u"one") == u"one" + assert strutils.escape_control_characters(u"\00ne") == u".ne" + assert strutils.escape_control_characters(u"\nne") == u"\nne" + assert strutils.escape_control_characters(u"\nne", False) == u".ne" + assert strutils.escape_control_characters(u"\u2605") == u"\u2605" + assert ( + strutils.escape_control_characters(bytes(bytearray(range(128))).decode()) == + u'.........\t\n..\r.................. !"#$%&\'()*+,-./0123456789:;<' + u'=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~.' + ) + assert ( + strutils.escape_control_characters(bytes(bytearray(range(128))).decode(), False) == + u'................................ !"#$%&\'()*+,-./0123456789:;<' + u'=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~.' + ) + + if not six.PY2: + with tutils.raises(ValueError): + strutils.escape_control_characters(b"foo") def test_bytes_to_escaped_str(): @@ -41,6 +51,14 @@ def test_bytes_to_escaped_str(): assert strutils.bytes_to_escaped_str(b"'") == r"\'" assert strutils.bytes_to_escaped_str(b'"') == r'"' + assert strutils.bytes_to_escaped_str(b"\r\n\t") == "\\r\\n\\t" + assert strutils.bytes_to_escaped_str(b"\r\n\t", True) == "\r\n\t" + + assert strutils.bytes_to_escaped_str(b"\n", True) == "\n" + assert strutils.bytes_to_escaped_str(b"\\n", True) == "\\ \\ n".replace(" ", "") + assert strutils.bytes_to_escaped_str(b"\\\n", True) == "\\ \\ \n".replace(" ", "") + assert strutils.bytes_to_escaped_str(b"\\\\n", True) == "\\ \\ \\ \\ n".replace(" ", "") + with tutils.raises(ValueError): strutils.bytes_to_escaped_str(u"such unicode") @@ -49,10 +67,9 @@ def test_escaped_str_to_bytes(): assert strutils.escaped_str_to_bytes("foo") == b"foo" assert strutils.escaped_str_to_bytes("\x08") == b"\b" assert strutils.escaped_str_to_bytes("&!?=\\\\)") == br"&!?=\)" - assert strutils.escaped_str_to_bytes("ü") == b'\xc3\xbc' assert strutils.escaped_str_to_bytes(u"\\x08") == b"\b" assert strutils.escaped_str_to_bytes(u"&!?=\\\\)") == br"&!?=\)" - assert strutils.escaped_str_to_bytes(u"ü") == b'\xc3\xbc' + assert strutils.escaped_str_to_bytes(u"\u00fc") == b'\xc3\xbc' if six.PY2: with tutils.raises(ValueError): @@ -62,17 +79,15 @@ def test_escaped_str_to_bytes(): strutils.escaped_str_to_bytes(b"very byte") -def test_isBin(): - assert not strutils.isBin("testing\n\r") - assert strutils.isBin("testing\x01") - assert strutils.isBin("testing\x0e") - assert strutils.isBin("testing\x7f") +def test_is_mostly_bin(): + assert not strutils.is_mostly_bin(b"foo\xFF") + assert strutils.is_mostly_bin(b"foo" + b"\xFF" * 10) -def test_isXml(): - assert not strutils.isXML("foo") - assert strutils.isXML("<foo") - assert strutils.isXML(" \n<foo") +def test_is_xml(): + assert not strutils.is_xml(b"foo") + assert strutils.is_xml(b"<foo") + assert strutils.is_xml(b" \n<foo") def test_clean_hanging_newline(): diff --git a/test/netlib/test_tcp.py b/test/netlib/test_tcp.py index 590bcc01..273427d5 100644 --- a/test/netlib/test_tcp.py +++ b/test/netlib/test_tcp.py @@ -169,7 +169,7 @@ class TestServerSSL(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL) + c.convert_to_ssl(sni="foo.com", options=SSL.OP_ALL) testval = b"echo!\n" c.wfile.write(testval) c.wfile.flush() @@ -179,7 +179,7 @@ class TestServerSSL(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): assert not c.get_current_cipher() - c.convert_to_ssl(sni=b"foo.com") + c.convert_to_ssl(sni="foo.com") ret = c.get_current_cipher() assert ret assert "AES" in ret[0] @@ -195,7 +195,7 @@ class TestSSLv3Only(tservers.ServerTestBase): def test_failure(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com") + tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com") class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): @@ -238,7 +238,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): with c.connect(): with tutils.raises(InvalidCertificateException): c.convert_to_ssl( - sni=b"example.mitmproxy.org", + sni="example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") ) @@ -272,7 +272,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): with c.connect(): with tutils.raises(InvalidCertificateException): c.convert_to_ssl( - sni=b"mitmproxy.org", + sni="mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") ) @@ -291,7 +291,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): c.convert_to_ssl( - sni=b"example.mitmproxy.org", + sni="example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") ) @@ -307,7 +307,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): c.convert_to_ssl( - sni=b"example.mitmproxy.org", + sni="example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_path=tutils.test_data.path("data/verificationcerts/") ) @@ -371,8 +371,8 @@ class TestSNI(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com") - assert c.sni == b"foo.com" + c.convert_to_ssl(sni="foo.com") + assert c.sni == "foo.com" assert c.rfile.readline() == b"foo.com" @@ -385,7 +385,7 @@ class TestServerCipherList(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com") + c.convert_to_ssl(sni="foo.com") assert c.rfile.readline() == b"['RC4-SHA']" @@ -405,7 +405,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com") + c.convert_to_ssl(sni="foo.com") assert b"RC4-SHA" in c.rfile.readline() @@ -418,7 +418,7 @@ class TestServerCipherListError(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com") + tutils.raises("handshake error", c.convert_to_ssl, sni="foo.com") class TestClientCipherListError(tservers.ServerTestBase): @@ -433,7 +433,7 @@ class TestClientCipherListError(tservers.ServerTestBase): tutils.raises( "cipher specification", c.convert_to_ssl, - sni=b"foo.com", + sni="foo.com", cipher_list="bogus" ) diff --git a/test/netlib/tservers.py b/test/netlib/tservers.py index 803aaa72..666f97ac 100644 --- a/test/netlib/tservers.py +++ b/test/netlib/tservers.py @@ -24,7 +24,7 @@ class _ServerThread(threading.Thread): class _TServer(tcp.TCPServer): - def __init__(self, ssl, q, handler_klass, addr): + def __init__(self, ssl, q, handler_klass, addr, **kwargs): """ ssl: A dictionary of SSL parameters: @@ -42,6 +42,8 @@ class _TServer(tcp.TCPServer): self.q = q self.handler_klass = handler_klass + if self.handler_klass is not None: + self.handler_klass.kwargs = kwargs self.last_handler = None def handle_client_connection(self, request, client_address): @@ -89,16 +91,16 @@ class ServerTestBase(object): addr = ("localhost", 0) @classmethod - def setup_class(cls): + def setup_class(cls, **kwargs): cls.q = queue.Queue() - s = cls.makeserver() + s = cls.makeserver(**kwargs) cls.port = s.address.port cls.server = _ServerThread(s) cls.server.start() @classmethod - def makeserver(cls): - return _TServer(cls.ssl, cls.q, cls.handler, cls.addr) + def makeserver(cls, **kwargs): + return _TServer(cls.ssl, cls.q, cls.handler, cls.addr, **kwargs) @classmethod def teardown_class(cls): |