diff options
Diffstat (limited to 'test')
-rw-r--r-- | test/mitmproxy/protocol/__init__.py | 0 | ||||
-rw-r--r-- | test/mitmproxy/protocol/test_http1.py (renamed from test/mitmproxy/test_protocol_http1.py) | 4 | ||||
-rw-r--r-- | test/mitmproxy/protocol/test_http2.py (renamed from test/mitmproxy/test_protocol_http2.py) | 10 | ||||
-rw-r--r-- | test/mitmproxy/protocol/test_websockets.py | 299 | ||||
-rw-r--r-- | test/mitmproxy/test_examples.py | 90 | ||||
-rw-r--r-- | test/mitmproxy/test_proxy.py | 26 | ||||
-rw-r--r-- | test/mitmproxy/test_server.py | 21 | ||||
-rw-r--r-- | test/netlib/http/test_cookies.py | 68 | ||||
-rw-r--r-- | test/netlib/http/test_headers.py | 5 | ||||
-rw-r--r-- | test/netlib/http/test_message.py | 10 | ||||
-rw-r--r-- | test/netlib/http/test_request.py | 10 | ||||
-rw-r--r-- | test/netlib/http/test_response.py | 20 | ||||
-rw-r--r-- | test/netlib/test_strutils.py | 1 | ||||
-rw-r--r-- | test/netlib/tservers.py | 3 | ||||
-rw-r--r-- | test/netlib/websockets/test_frame.py | 164 | ||||
-rw-r--r-- | test/netlib/websockets/test_masker.py | 23 | ||||
-rw-r--r-- | test/netlib/websockets/test_utils.py | 105 | ||||
-rw-r--r-- | test/netlib/websockets/test_websockets.py | 269 |
18 files changed, 812 insertions, 316 deletions
diff --git a/test/mitmproxy/protocol/__init__.py b/test/mitmproxy/protocol/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/mitmproxy/protocol/__init__.py diff --git a/test/mitmproxy/test_protocol_http1.py b/test/mitmproxy/protocol/test_http1.py index cf7bd598..7d04c56b 100644 --- a/test/mitmproxy/test_protocol_http1.py +++ b/test/mitmproxy/protocol/test_http1.py @@ -1,7 +1,9 @@ +from __future__ import (absolute_import, print_function, division) + from netlib.http import http1 from netlib.tcp import TCPClient from netlib.tutils import treq -from . import tutils, tservers +from .. import tutils, tservers class TestHTTPFlow(object): diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/protocol/test_http2.py index f0fa9a40..1eabebf1 100644 --- a/test/mitmproxy/test_protocol_http2.py +++ b/test/mitmproxy/protocol/test_http2.py @@ -13,11 +13,11 @@ from mitmproxy import options from mitmproxy.proxy.config import ProxyConfig import netlib -from ..netlib import tservers as netlib_tservers +from ...netlib import tservers as netlib_tservers from netlib.exceptions import HttpException from netlib.http.http2 import framereader -from . import tservers +from .. import tservers import logging logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING) @@ -849,15 +849,15 @@ class TestMaxConcurrentStreams(_Http2Test): def test_max_concurrent_streams(self): client, h2_conn = self._setup_connection() new_streams = [1, 3, 5, 7, 9, 11] - for id in new_streams: + for stream_id in new_streams: # this will exceed MAX_CONCURRENT_STREAMS on the server connection # and cause mitmproxy to throttle stream creation to the server - self._send_request(client.wfile, h2_conn, stream_id=id, headers=[ + self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), (':method', 'GET'), (':scheme', 'https'), (':path', '/'), - ('X-Stream-ID', str(id)), + ('X-Stream-ID', str(stream_id)), ]) ended_streams = 0 diff --git a/test/mitmproxy/protocol/test_websockets.py b/test/mitmproxy/protocol/test_websockets.py new file mode 100644 index 00000000..e7e2684f --- /dev/null +++ b/test/mitmproxy/protocol/test_websockets.py @@ -0,0 +1,299 @@ +from __future__ import absolute_import, print_function, division + +import pytest +import os +import tempfile +import traceback + +from mitmproxy import options +from mitmproxy.proxy.config import ProxyConfig + +import netlib +from netlib import http +from ...netlib import tservers as netlib_tservers +from .. import tservers + +from netlib import websockets + + +class _WebSocketsServerBase(netlib_tservers.ServerTestBase): + + class handler(netlib.tcp.BaseHandler): + + def handle(self): + try: + request = http.http1.read_request(self.rfile) + assert websockets.check_handshake(request.headers) + + response = http.Response( + "HTTP/1.1", + 101, + reason=http.status_codes.RESPONSES.get(101), + headers=http.Headers( + connection='upgrade', + upgrade='websocket', + sec_websocket_accept=b'', + ), + content=b'', + ) + self.wfile.write(http.http1.assemble_response(response)) + self.wfile.flush() + + self.server.handle_websockets(self.rfile, self.wfile) + except: + traceback.print_exc() + + +class _WebSocketsTestBase(object): + + @classmethod + def setup_class(cls): + opts = cls.get_options() + cls.config = ProxyConfig(opts) + + tmaster = tservers.TestMaster(opts, cls.config) + tmaster.start_app(options.APP_HOST, options.APP_PORT) + cls.proxy = tservers.ProxyThread(tmaster) + cls.proxy.start() + + @classmethod + def teardown_class(cls): + cls.proxy.shutdown() + + @classmethod + def get_options(cls): + opts = options.Options( + listen_port=0, + no_upstream_cert=False, + ssl_insecure=True + ) + opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy") + return opts + + @property + def master(self): + return self.proxy.tmaster + + def setup(self): + self.master.clear_log() + self.master.state.clear() + self.server.server.handle_websockets = self.handle_websockets + + def _setup_connection(self): + client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port)) + client.connect() + + request = http.Request( + "authority", + "CONNECT", + "", + "localhost", + self.server.server.address.port, + "", + "HTTP/1.1", + content=b'') + client.wfile.write(http.http1.assemble_request(request)) + client.wfile.flush() + + response = http.http1.read_response(client.rfile, request) + + if self.ssl: + client.convert_to_ssl() + assert client.ssl_established + + request = http.Request( + "relative", + "GET", + "http", + "localhost", + self.server.server.address.port, + "/ws", + "HTTP/1.1", + headers=http.Headers( + connection="upgrade", + upgrade="websocket", + sec_websocket_version="13", + sec_websocket_key="1234", + ), + content=b'') + client.wfile.write(http.http1.assemble_request(request)) + client.wfile.flush() + + response = http.http1.read_response(client.rfile, request) + assert websockets.check_handshake(response.headers) + + return client + + +class _WebSocketsTest(_WebSocketsTestBase, _WebSocketsServerBase): + + @classmethod + def setup_class(cls): + _WebSocketsTestBase.setup_class() + _WebSocketsServerBase.setup_class(ssl=cls.ssl) + + @classmethod + def teardown_class(cls): + _WebSocketsTestBase.teardown_class() + _WebSocketsServerBase.teardown_class() + + +class TestSimple(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + def test_simple(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'server-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'client-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + +class TestSimpleTLS(_WebSocketsTest): + ssl = True + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + def test_simple_tls(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'server-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.payload == b'client-foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + +class TestPing(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + wfile.flush() + + frame = websockets.Frame.from_file(rfile) + assert frame.header.opcode == websockets.OPCODE.PONG + assert frame.payload == b'foobar' + + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received'))) + wfile.flush() + + def test_ping(self): + client = self._setup_connection() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.PING + assert frame.payload == b'foobar' + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.TEXT + assert frame.payload == b'pong-received' + + +class TestPong(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + frame = websockets.Frame.from_file(rfile) + assert frame.header.opcode == websockets.OPCODE.PING + assert frame.payload == b'foobar' + + wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload))) + wfile.flush() + + def test_pong(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar'))) + client.wfile.flush() + + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == websockets.OPCODE.PONG + assert frame.payload == b'foobar' + + +class TestClose(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + frame = websockets.Frame.from_file(rfile) + wfile.write(bytes(frame)) + wfile.flush() + + with pytest.raises(netlib.exceptions.TcpDisconnect): + websockets.Frame.from_file(rfile) + + def test_close(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE))) + client.wfile.flush() + + with pytest.raises(netlib.exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + def test_close_payload_1(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42'))) + client.wfile.flush() + + with pytest.raises(netlib.exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + def test_close_payload_2(self): + client = self._setup_connection() + + client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar'))) + client.wfile.flush() + + with pytest.raises(netlib.exceptions.TcpDisconnect): + websockets.Frame.from_file(client.rfile) + + +class TestInvalidFrame(_WebSocketsTest): + + @classmethod + def handle_websockets(cls, rfile, wfile): + wfile.write(bytes(websockets.Frame(fin=1, opcode=15, payload=b'foobar'))) + wfile.flush() + + def test_invalid_frame(self): + client = self._setup_connection() + + # with pytest.raises(netlib.exceptions.TcpDisconnect): + frame = websockets.Frame.from_file(client.rfile) + assert frame.header.opcode == 15 + assert frame.payload == b'foobar' diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py index 6c24ace5..83a37a36 100644 --- a/test/mitmproxy/test_examples.py +++ b/test/mitmproxy/test_examples.py @@ -1,16 +1,20 @@ import json +import os import six -import sys -import os.path -from mitmproxy.flow import master -from mitmproxy.flow import state + from mitmproxy import options from mitmproxy import contentviews from mitmproxy.builtins import script +from mitmproxy.flow import master +from mitmproxy.flow import state + import netlib.utils + from netlib import tutils as netutils from netlib.http import Headers +from netlib.http import cookies + from . import tutils, mastertest example_dir = netlib.utils.Data(__name__).push("../../examples") @@ -98,30 +102,66 @@ class TestScripts(mastertest.MasterTest): m.request(f) assert f.request.host == "mitmproxy.org" - def test_har_extractor(self): - if sys.version_info >= (3, 0): - with tutils.raises("does not work on Python 3"): - tscript("har_extractor.py") - return +class TestHARDump(): + + def flow(self, resp_content=b'message'): + times = dict( + timestamp_start=746203272, + timestamp_end=746203272, + ) + + # Create a dummy flow for testing + return tutils.tflow( + req=netutils.treq(method=b'GET', **times), + resp=netutils.tresp(content=resp_content, **times) + ) + + def test_no_file_arg(self): with tutils.raises(ScriptError): - tscript("har_extractor.py") + tscript("har_dump.py") + + def test_simple(self): + with tutils.tmpdir() as tdir: + path = os.path.join(tdir, "somefile") + + m, sc = tscript("har_dump.py", six.moves.shlex_quote(path)) + m.addons.invoke(m, "response", self.flow()) + m.addons.remove(sc) + with open(path, "r") as inp: + har = json.load(inp) + + assert len(har["log"]["entries"]) == 1 + + def test_base64(self): with tutils.tmpdir() as tdir: - times = dict( - timestamp_start=746203272, - timestamp_end=746203272, - ) - - path = os.path.join(tdir, "file") - m, sc = tscript("har_extractor.py", six.moves.shlex_quote(path)) - f = tutils.tflow( - req=netutils.treq(**times), - resp=netutils.tresp(**times) - ) - m.response(f) + path = os.path.join(tdir, "somefile") + + m, sc = tscript("har_dump.py", six.moves.shlex_quote(path)) + m.addons.invoke(m, "response", self.flow(resp_content=b"foo" + b"\xFF" * 10)) m.addons.remove(sc) - with open(path, "rb") as f: - test_data = json.load(f) - assert len(test_data["log"]["pages"]) == 1 + with open(path, "r") as inp: + har = json.load(inp) + + assert har["log"]["entries"][0]["response"]["content"]["encoding"] == "base64" + + def test_format_cookies(self): + m, sc = tscript("har_dump.py", "-") + format_cookies = sc.ns.ns["format_cookies"] + + CA = cookies.CookieAttrs + + f = format_cookies([("n", "v", CA([("k", "v")]))])[0] + assert f['name'] == "n" + assert f['value'] == "v" + assert not f['httpOnly'] + assert not f['secure'] + + f = format_cookies([("n", "v", CA([("httponly", None), ("secure", None)]))])[0] + assert f['httpOnly'] + assert f['secure'] + + f = format_cookies([("n", "v", CA([("expires", "Mon, 24-Aug-2037 00:00:00 GMT")]))])[0] + assert f['expires'] diff --git a/test/mitmproxy/test_proxy.py b/test/mitmproxy/test_proxy.py index 84838018..f7c64e50 100644 --- a/test/mitmproxy/test_proxy.py +++ b/test/mitmproxy/test_proxy.py @@ -85,22 +85,22 @@ class TestProcessProxyOptions: @mock.patch("mitmproxy.platform.resolver") def test_modes(self, _): - # self.assert_noerr("-R", "http://localhost") - # self.assert_err("expected one argument", "-R") - # self.assert_err("Invalid server specification", "-R", "reverse") - # - # self.assert_noerr("-T") - # - # self.assert_noerr("-U", "http://localhost") - # self.assert_err("expected one argument", "-U") - # self.assert_err("Invalid server specification", "-U", "upstream") - # - # self.assert_noerr("--upstream-auth", "test:test") - # self.assert_err("expected one argument", "--upstream-auth") + self.assert_noerr("-R", "http://localhost") + self.assert_err("expected one argument", "-R") + self.assert_err("Invalid server specification", "-R", "reverse") + + self.assert_noerr("-T") + + self.assert_noerr("-U", "http://localhost") + self.assert_err("expected one argument", "-U") + self.assert_err("Invalid server specification", "-U", "upstream") + + self.assert_noerr("--upstream-auth", "test:test") + self.assert_err("expected one argument", "--upstream-auth") self.assert_err( "Invalid upstream auth specification", "--upstream-auth", "test" ) - # self.assert_err("mutually exclusive", "-R", "http://localhost", "-T") + self.assert_err("mutually exclusive", "-R", "http://localhost", "-T") def test_socks_auth(self): self.assert_err( diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index 78e9b5c7..e0a8da47 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -313,6 +313,22 @@ class TestHTTPAuth(tservers.HTTPProxyTest): assert ret.status_code == 202 +class TestHTTPReverseAuth(tservers.ReverseProxyTest): + def test_auth(self): + self.master.options.auth_singleuser = "test:test" + assert self.pathod("202").status_code == 401 + p = self.pathoc() + ret = p.request(""" + get + '/p/202' + h'%s'='%s' + """ % ( + http.authentication.BasicWebsiteAuth.AUTH_HEADER, + authentication.assemble_http_basic_auth("basic", "test", "test") + )) + assert ret.status_code == 202 + + class TestHTTPS(tservers.HTTPProxyTest, CommonMixin, TcpMixin): ssl = True ssloptions = pathod.SSLOptions(request_client_cert=True) @@ -456,6 +472,11 @@ class TestReverse(tservers.ReverseProxyTest, CommonMixin, TcpMixin): reverse = True +class TestReverseSSL(tservers.ReverseProxyTest, CommonMixin, TcpMixin): + reverse = True + ssl = True + + class TestSocks5(tservers.SocksModeTest): def test_simple(self): diff --git a/test/netlib/http/test_cookies.py b/test/netlib/http/test_cookies.py index 17e21b94..efd8ba80 100644 --- a/test/netlib/http/test_cookies.py +++ b/test/netlib/http/test_cookies.py @@ -1,6 +1,10 @@ +import time + from netlib.http import cookies from netlib.tutils import raises +import mock + def test_read_token(): tokens = [ @@ -247,6 +251,22 @@ def test_refresh_cookie(): assert cookies.refresh_set_cookie_header(c, 0) +@mock.patch('time.time') +def test_get_expiration_ts(*args): + # Freeze time + now_ts = 17 + time.time.return_value = now_ts + + CA = cookies.CookieAttrs + F = cookies.get_expiration_ts + + assert F(CA([("Expires", "Thu, 01-Jan-1970 00:00:00 GMT")])) == 0 + assert F(CA([("Expires", "Mon, 24-Aug-2037 00:00:00 GMT")])) == 2134684800 + + assert F(CA([("Max-Age", "0")])) == now_ts + assert F(CA([("Max-Age", "31")])) == now_ts + 31 + + def test_is_expired(): CA = cookies.CookieAttrs @@ -260,9 +280,53 @@ def test_is_expired(): # or both assert cookies.is_expired(CA([("Expires", "Thu, 01-Jan-1970 00:00:00 GMT"), ("Max-Age", "0")])) - assert not cookies.is_expired(CA([("Expires", "Thu, 24-Aug-2063 00:00:00 GMT")])) + assert not cookies.is_expired(CA([("Expires", "Mon, 24-Aug-2037 00:00:00 GMT")])) assert not cookies.is_expired(CA([("Max-Age", "1")])) - assert not cookies.is_expired(CA([("Expires", "Thu, 15-Jul-2068 00:00:00 GMT"), ("Max-Age", "1")])) + assert not cookies.is_expired(CA([("Expires", "Wed, 15-Jul-2037 00:00:00 GMT"), ("Max-Age", "1")])) assert not cookies.is_expired(CA([("Max-Age", "nan")])) assert not cookies.is_expired(CA([("Expires", "false")])) + + +def test_group_cookies(): + CA = cookies.CookieAttrs + groups = [ + [ + "one=uno; foo=bar; foo=baz", + [ + ('one', 'uno', CA([])), + ('foo', 'bar', CA([])), + ('foo', 'baz', CA([])) + ] + ], + [ + "one=uno; Path=/; foo=bar; Max-Age=0; foo=baz; expires=24-08-1993", + [ + ('one', 'uno', CA([('Path', '/')])), + ('foo', 'bar', CA([('Max-Age', '0')])), + ('foo', 'baz', CA([('expires', '24-08-1993')])) + ] + ], + [ + "one=uno;", + [ + ('one', 'uno', CA([])) + ] + ], + [ + "one=uno; Path=/; Max-Age=0; Expires=24-08-1993", + [ + ('one', 'uno', CA([('Path', '/'), ('Max-Age', '0'), ('Expires', '24-08-1993')])) + ] + ], + [ + "path=val; Path=/", + [ + ('path', 'val', CA([('Path', '/')])) + ] + ] + ] + + for c, expected in groups: + observed = cookies.group_cookies(cookies.parse_cookie_header(c)) + assert observed == expected diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index 51537310..ad2bc548 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -75,6 +75,11 @@ class TestHeaders(object): assert replacements == 0 assert headers["Host"] == "example.com" + def test_replace_with_count(self): + headers = Headers(Host="foobarfoo.com", Accept="foo/bar") + replacements = headers.replace("foo", "bar", count=1) + assert replacements == 1 + def test_parse_content_type(): p = parse_content_type diff --git a/test/netlib/http/test_message.py b/test/netlib/http/test_message.py index 12e4706c..74272309 100644 --- a/test/netlib/http/test_message.py +++ b/test/netlib/http/test_message.py @@ -99,6 +99,16 @@ class TestMessage(object): def test_http_version(self): _test_decoded_attr(tresp(), "http_version") + def test_replace(self): + r = tresp() + r.content = b"foofootoo" + r.replace(b"foo", "gg") + assert r.content == b"ggggtoo" + + r.content = b"foofootoo" + r.replace(b"foo", "gg", count=1) + assert r.content == b"ggfootoo" + class TestMessageContentEncoding(object): def test_simple(self): diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py index f3cd8b71..9baabaa6 100644 --- a/test/netlib/http/test_request.py +++ b/test/netlib/http/test_request.py @@ -26,6 +26,16 @@ class TestRequestCore(object): request.host = None assert repr(request) == "Request(GET /path)" + def replace(self): + r = treq() + r.path = b"foobarfoo" + r.replace(b"foo", "bar") + assert r.path == b"barbarbar" + + r.path = b"foobarfoo" + r.replace(b"foo", "bar", count=1) + assert r.path == b"barbarfoo" + def test_first_line_format(self): _test_passthrough_attr(treq(), "first_line_format") diff --git a/test/netlib/http/test_response.py b/test/netlib/http/test_response.py index b3c2f736..c7b1b646 100644 --- a/test/netlib/http/test_response.py +++ b/test/netlib/http/test_response.py @@ -5,6 +5,7 @@ import email import time from netlib.http import Headers +from netlib.http import Response from netlib.http.cookies import CookieAttrs from netlib.tutils import raises, tresp from .test_message import _test_passthrough_attr, _test_decoded_attr @@ -28,6 +29,25 @@ class TestResponseCore(object): response.content = None assert repr(response) == "Response(200 OK, no content)" + def test_make(self): + r = Response.make() + assert r.status_code == 200 + assert r.content == b"" + + Response.make(content=b"foo") + Response.make(content="foo") + with raises(TypeError): + Response.make(content=42) + + r = Response.make(headers=[(b"foo", b"bar")]) + assert r.headers["foo"] == "bar" + + r = Response.make(headers=({"foo": "baz"})) + assert r.headers["foo"] == "baz" + + with raises(TypeError): + Response.make(headers=42) + def test_status_code(self): _test_passthrough_attr(tresp(), "status_code") diff --git a/test/netlib/test_strutils.py b/test/netlib/test_strutils.py index 52299e59..5be254a3 100644 --- a/test/netlib/test_strutils.py +++ b/test/netlib/test_strutils.py @@ -85,6 +85,7 @@ def test_escaped_str_to_bytes(): def test_is_mostly_bin(): assert not strutils.is_mostly_bin(b"foo\xFF") assert strutils.is_mostly_bin(b"foo" + b"\xFF" * 10) + assert not strutils.is_mostly_bin("") def test_is_xml(): diff --git a/test/netlib/tservers.py b/test/netlib/tservers.py index 666f97ac..a80dcb28 100644 --- a/test/netlib/tservers.py +++ b/test/netlib/tservers.py @@ -100,7 +100,8 @@ class ServerTestBase(object): @classmethod def makeserver(cls, **kwargs): - return _TServer(cls.ssl, cls.q, cls.handler, cls.addr, **kwargs) + ssl = kwargs.pop('ssl', cls.ssl) + return _TServer(ssl, cls.q, cls.handler, cls.addr, **kwargs) @classmethod def teardown_class(cls): diff --git a/test/netlib/websockets/test_frame.py b/test/netlib/websockets/test_frame.py new file mode 100644 index 00000000..cce39454 --- /dev/null +++ b/test/netlib/websockets/test_frame.py @@ -0,0 +1,164 @@ +import os +import codecs +import pytest + +from netlib import websockets +from netlib import tutils + + +class TestFrameHeader(object): + + @pytest.mark.parametrize("input,expected", [ + (0, '0100'), + (125, '017D'), + (126, '017E007E'), + (127, '017E007F'), + (142, '017E008E'), + (65534, '017EFFFE'), + (65535, '017EFFFF'), + (65536, '017F0000000000010000'), + (8589934591, '017F00000001FFFFFFFF'), + (2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'), + ]) + def test_serialization_length(self, input, expected): + h = websockets.FrameHeader( + opcode=websockets.OPCODE.TEXT, + payload_length=input, + ) + assert bytes(h) == codecs.decode(expected, 'hex') + + def test_serialization_too_large(self): + h = websockets.FrameHeader( + payload_length=2 ** 64 + 1, + ) + with pytest.raises(ValueError): + bytes(h) + + @pytest.mark.parametrize("input,expected", [ + ('0100', 0), + ('017D', 125), + ('017E007E', 126), + ('017E007F', 127), + ('017E008E', 142), + ('017EFFFE', 65534), + ('017EFFFF', 65535), + ('017F0000000000010000', 65536), + ('017F00000001FFFFFFFF', 8589934591), + ('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1), + ]) + def test_deserialization_length(self, input, expected): + h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex'))) + assert h.payload_length == expected + + @pytest.mark.parametrize("input,expected", [ + ('0100', (False, None)), + ('018000000000', (True, '00000000')), + ('018012345678', (True, '12345678')), + ]) + def test_deserialization_masking(self, input, expected): + h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex'))) + assert h.mask == expected[0] + if h.mask: + assert h.masking_key == codecs.decode(expected[1], 'hex') + + def test_equality(self): + h = websockets.FrameHeader(mask=True, masking_key=b'1234') + h2 = websockets.FrameHeader(mask=True, masking_key=b'1234') + assert h == h2 + + h = websockets.FrameHeader(fin=True) + h2 = websockets.FrameHeader(fin=False) + assert h != h2 + + assert h != 'foobar' + + def test_roundtrip(self): + def round(*args, **kwargs): + h = websockets.FrameHeader(*args, **kwargs) + h2 = websockets.FrameHeader.from_file(tutils.treader(bytes(h))) + assert h == h2 + + round() + round(fin=True) + round(rsv1=True) + round(rsv2=True) + round(rsv3=True) + round(payload_length=1) + round(payload_length=100) + round(payload_length=1000) + round(payload_length=10000) + round(opcode=websockets.OPCODE.PING) + round(masking_key=b"test") + + def test_human_readable(self): + f = websockets.FrameHeader( + masking_key=b"test", + fin=True, + payload_length=10 + ) + assert repr(f) + + f = websockets.FrameHeader() + assert repr(f) + + def test_funky(self): + f = websockets.FrameHeader(masking_key=b"test", mask=False) + raw = bytes(f) + f2 = websockets.FrameHeader.from_file(tutils.treader(raw)) + assert not f2.mask + + def test_violations(self): + tutils.raises("opcode", websockets.FrameHeader, opcode=17) + tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") + + def test_automask(self): + f = websockets.FrameHeader(mask=True) + assert f.masking_key + + f = websockets.FrameHeader(masking_key=b"foob") + assert f.mask + + f = websockets.FrameHeader(masking_key=b"foob", mask=0) + assert not f.mask + assert f.masking_key + + +class TestFrame(object): + def test_equality(self): + f = websockets.Frame(payload=b'1234') + f2 = websockets.Frame(payload=b'1234') + assert f == f2 + + assert f != b'1234' + + def test_roundtrip(self): + def round(*args, **kwargs): + f = websockets.Frame(*args, **kwargs) + raw = bytes(f) + f2 = websockets.Frame.from_file(tutils.treader(raw)) + assert f == f2 + round(b"test") + round(b"test", fin=1) + round(b"test", rsv1=1) + round(b"test", opcode=websockets.OPCODE.PING) + round(b"test", masking_key=b"test") + + def test_human_readable(self): + f = websockets.Frame() + assert repr(f) + + f = websockets.Frame(b"foobar") + assert "foobar" in repr(f) + + @pytest.mark.parametrize("masked", [True, False]) + @pytest.mark.parametrize("length", [100, 50000, 150000]) + def test_serialization_bijection(self, masked, length): + frame = websockets.Frame( + os.urandom(length), + fin=True, + opcode=websockets.OPCODE.TEXT, + mask=int(masked), + masking_key=(os.urandom(4) if masked else None) + ) + serialized = bytes(frame) + assert frame == websockets.Frame.from_bytes(serialized) diff --git a/test/netlib/websockets/test_masker.py b/test/netlib/websockets/test_masker.py new file mode 100644 index 00000000..528fce71 --- /dev/null +++ b/test/netlib/websockets/test_masker.py @@ -0,0 +1,23 @@ +import codecs +import pytest + +from netlib import websockets + + +class TestMasker(object): + + @pytest.mark.parametrize("input,expected", [ + ([b"a"], '00'), + ([b"four"], '070d1616'), + ([b"fourf"], '070d161607'), + ([b"fourfive"], '070d1616070b1501'), + ([b"a", b"aasdfasdfa", b"asdf"], '000302170504021705040205120605'), + ([b"a" * 50, b"aasdfasdfa", b"asdf"], '00030205000302050003020500030205000302050003020500030205000302050003020500030205000302050003020500030205120605051206050500110702'), # noqa + ]) + def test_masker(self, input, expected): + m = websockets.Masker(b"abcd") + data = b"".join([m(t) for t in input]) + assert data == codecs.decode(expected, 'hex') + + data = websockets.Masker(b"abcd")(data) + assert data == b"".join(input) diff --git a/test/netlib/websockets/test_utils.py b/test/netlib/websockets/test_utils.py new file mode 100644 index 00000000..34765e04 --- /dev/null +++ b/test/netlib/websockets/test_utils.py @@ -0,0 +1,105 @@ +import pytest + +from netlib import http +from netlib import websockets + + +class TestUtils(object): + + def test_client_handshake_headers(self): + h = websockets.client_handshake_headers(version='42') + assert h['sec-websocket-version'] == '42' + + h = websockets.client_handshake_headers(key='some-key') + assert h['sec-websocket-key'] == 'some-key' + + h = websockets.client_handshake_headers(protocol='foobar') + assert h['sec-websocket-protocol'] == 'foobar' + + h = websockets.client_handshake_headers(extensions='foo; bar') + assert h['sec-websocket-extensions'] == 'foo; bar' + + def test_server_handshake_headers(self): + h = websockets.server_handshake_headers('some-key') + assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw=' + assert 'sec-websocket-protocol' not in h + assert 'sec-websocket-extensions' not in h + + h = websockets.server_handshake_headers('some-key', 'foobar', 'foo; bar') + assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw=' + assert h['sec-websocket-protocol'] == 'foobar' + assert h['sec-websocket-extensions'] == 'foo; bar' + + @pytest.mark.parametrize("input,expected", [ + ([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], True), + ([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-accept', b'foobar')], True), + ([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-KeY', b'foobar')], True), + ([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-AccePt', b'foobar')], True), + ([(b'connection', b'foo'), (b'upgrade', b'bar'), (b'sec-websocket-key', b'foobar')], False), + ([(b'connection', b'upgrade'), (b'upgrade', b'websocket')], False), + ([(b'connection', b'upgrade'), (b'sec-websocket-key', b'foobar')], False), + ([(b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], False), + ([], False), + ]) + def test_check_handshake(self, input, expected): + h = http.Headers(input) + assert websockets.check_handshake(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-version', b'13')], True), + ([(b'Sec-WebSockeT-VerSion', b'13')], True), + ([(b'sec-websocket-version', b'9')], False), + ([(b'sec-websocket-version', b'42')], False), + ([(b'sec-websocket-version', b'')], False), + ([], False), + ]) + def test_check_client_version(self, input, expected): + h = http.Headers(input) + assert websockets.check_client_version(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ('foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='), + (b'foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='), + ]) + def test_create_server_nonce(self, input, expected): + assert websockets.create_server_nonce(input) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-extensions', b'foo; bar')], 'foo; bar'), + ([(b'Sec-WebSockeT-ExteNsionS', b'foo; bar')], 'foo; bar'), + ([(b'sec-websocket-extensions', b'')], ''), + ([], None), + ]) + def test_get_extensions(self, input, expected): + h = http.Headers(input) + assert websockets.get_extensions(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-protocol', b'foobar')], 'foobar'), + ([(b'Sec-WebSockeT-ProTocoL', b'foobar')], 'foobar'), + ([(b'sec-websocket-protocol', b'')], ''), + ([], None), + ]) + def test_get_protocol(self, input, expected): + h = http.Headers(input) + assert websockets.get_protocol(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-key', b'foobar')], 'foobar'), + ([(b'Sec-WebSockeT-KeY', b'foobar')], 'foobar'), + ([(b'sec-websocket-key', b'')], ''), + ([], None), + ]) + def test_get_client_key(self, input, expected): + h = http.Headers(input) + assert websockets.get_client_key(h) == expected + + @pytest.mark.parametrize("input,expected", [ + ([(b'sec-websocket-accept', b'foobar')], 'foobar'), + ([(b'Sec-WebSockeT-AccepT', b'foobar')], 'foobar'), + ([(b'sec-websocket-accept', b'')], ''), + ([], None), + ]) + def test_get_server_accept(self, input, expected): + h = http.Headers(input) + assert websockets.get_server_accept(h) == expected diff --git a/test/netlib/websockets/test_websockets.py b/test/netlib/websockets/test_websockets.py deleted file mode 100644 index 50fa26e6..00000000 --- a/test/netlib/websockets/test_websockets.py +++ /dev/null @@ -1,269 +0,0 @@ -import os - -from netlib.http.http1 import read_response, read_request - -from netlib import tcp -from netlib import tutils -from netlib import websockets -from netlib.http import status_codes -from netlib.tutils import treq -from netlib import exceptions - -from .. import tservers - - -class WebSocketsEchoHandler(tcp.BaseHandler): - - def __init__(self, connection, address, server): - super(WebSocketsEchoHandler, self).__init__( - connection, address, server - ) - self.protocol = websockets.WebsocketsProtocol() - self.handshake_done = False - - def handle(self): - while True: - if not self.handshake_done: - self.handshake() - else: - self.read_next_message() - - def read_next_message(self): - frame = websockets.Frame.from_file(self.rfile) - self.on_message(frame.payload) - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=False) - frame.to_file(self.wfile) - - def handshake(self): - - req = read_request(self.rfile) - key = self.protocol.check_client_handshake(req.headers) - - preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble.encode() + b"\r\n") - headers = self.protocol.server_handshake_headers(key) - self.wfile.write(str(headers) + "\r\n") - self.wfile.flush() - self.handshake_done = True - - def on_message(self, message): - if message is not None: - self.send_message(message) - - -class WebSocketsClient(tcp.TCPClient): - - def __init__(self, address, source_address=None): - super(WebSocketsClient, self).__init__(address, source_address) - self.protocol = websockets.WebsocketsProtocol() - self.client_nonce = None - - def connect(self): - super(WebSocketsClient, self).connect() - - preamble = b'GET / HTTP/1.1' - self.wfile.write(preamble + b"\r\n") - headers = self.protocol.client_handshake_headers() - self.client_nonce = headers["sec-websocket-key"].encode("ascii") - self.wfile.write(bytes(headers) + b"\r\n") - self.wfile.flush() - - resp = read_response(self.rfile, treq(method=b"GET")) - server_nonce = self.protocol.check_server_handshake(resp.headers) - - if not server_nonce == self.protocol.create_server_nonce(self.client_nonce): - self.close() - - def read_next_message(self): - return websockets.Frame.from_file(self.rfile).payload - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=True) - frame.to_file(self.wfile) - - -class TestWebSockets(tservers.ServerTestBase): - handler = WebSocketsEchoHandler - - def __init__(self): - self.protocol = websockets.WebsocketsProtocol() - - def random_bytes(self, n=100): - return os.urandom(n) - - def echo(self, msg): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message(msg) - response = client.read_next_message() - assert response == msg - - def test_simple_echo(self): - self.echo(b"hello I'm the client") - - def test_frame_sizes(self): - # length can fit in the the 7 bit payload length - small_msg = self.random_bytes(100) - # 50kb, sligthly larger than can fit in a 7 bit int - medium_msg = self.random_bytes(50000) - # 150kb, slightly larger than can fit in a 16 bit int - large_msg = self.random_bytes(150000) - - self.echo(small_msg) - self.echo(medium_msg) - self.echo(large_msg) - - def test_default_builder(self): - """ - default builder should always generate valid frames - """ - msg = self.random_bytes() - assert websockets.Frame.default(msg, from_client=True) - assert websockets.Frame.default(msg, from_client=False) - - def test_serialization_bijection(self): - """ - Ensure that various frame types can be serialized/deserialized back - and forth between to_bytes() and from_bytes() - """ - for is_client in [True, False]: - for num_bytes in [100, 50000, 150000]: - frame = websockets.Frame.default( - self.random_bytes(num_bytes), is_client - ) - frame2 = websockets.Frame.from_bytes( - frame.to_bytes() - ) - assert frame == frame2 - - bytes = b'\x81\x03cba' - assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes - - def test_check_server_handshake(self): - headers = self.protocol.server_handshake_headers("key") - assert self.protocol.check_server_handshake(headers) - headers["Upgrade"] = "not_websocket" - assert not self.protocol.check_server_handshake(headers) - - def test_check_client_handshake(self): - headers = self.protocol.client_handshake_headers("key") - assert self.protocol.check_client_handshake(headers) == "key" - headers["Upgrade"] = "not_websocket" - assert not self.protocol.check_client_handshake(headers) - - -class BadHandshakeHandler(WebSocketsEchoHandler): - - def handshake(self): - - client_hs = read_request(self.rfile) - self.protocol.check_client_handshake(client_hs.headers) - - preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101) - self.wfile.write(preamble.encode()) - headers = self.protocol.server_handshake_headers(b"malformed key") - self.wfile.write(bytes(headers) + b"\r\n") - self.wfile.flush() - self.handshake_done = True - - -class TestBadHandshake(tservers.ServerTestBase): - - """ - Ensure that the client disconnects if the server handshake is malformed - """ - handler = BadHandshakeHandler - - def test(self): - with tutils.raises(exceptions.TcpDisconnect): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message(b"hello") - - -class TestFrameHeader: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.FrameHeader(*args, **kwargs) - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f))) - assert f == f2 - round() - round(fin=1) - round(rsv1=1) - round(rsv2=1) - round(rsv3=1) - round(payload_length=1) - round(payload_length=100) - round(payload_length=1000) - round(payload_length=10000) - round(opcode=websockets.OPCODE.PING) - round(masking_key=b"test") - - def test_human_readable(self): - f = websockets.FrameHeader( - masking_key=b"test", - fin=True, - payload_length=10 - ) - assert repr(f) - f = websockets.FrameHeader() - assert repr(f) - - def test_funky(self): - f = websockets.FrameHeader(masking_key=b"test", mask=False) - raw = bytes(f) - f2 = websockets.FrameHeader.from_file(tutils.treader(raw)) - assert not f2.mask - - def test_violations(self): - tutils.raises("opcode", websockets.FrameHeader, opcode=17) - tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x") - - def test_automask(self): - f = websockets.FrameHeader(mask=True) - assert f.masking_key - - f = websockets.FrameHeader(masking_key=b"foob") - assert f.mask - - f = websockets.FrameHeader(masking_key=b"foob", mask=0) - assert not f.mask - assert f.masking_key - - -class TestFrame: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.Frame(*args, **kwargs) - raw = bytes(f) - f2 = websockets.Frame.from_file(tutils.treader(raw)) - assert f == f2 - round(b"test") - round(b"test", fin=1) - round(b"test", rsv1=1) - round(b"test", opcode=websockets.OPCODE.PING) - round(b"test", masking_key=b"test") - - def test_human_readable(self): - f = websockets.Frame() - assert repr(f) - - -def test_masker(): - tests = [ - [b"a"], - [b"four"], - [b"fourf"], - [b"fourfive"], - [b"a", b"aasdfasdfa", b"asdf"], - [b"a" * 50, b"aasdfasdfa", b"asdf"], - ] - for i in tests: - m = websockets.Masker(b"abcd") - data = b"".join([m(t) for t in i]) - data2 = websockets.Masker(b"abcd")(data) - assert data2 == b"".join(i) |