aboutsummaryrefslogtreecommitdiffstats
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/mitmproxy/protocol/__init__.py0
-rw-r--r--test/mitmproxy/protocol/test_http1.py (renamed from test/mitmproxy/test_protocol_http1.py)4
-rw-r--r--test/mitmproxy/protocol/test_http2.py (renamed from test/mitmproxy/test_protocol_http2.py)10
-rw-r--r--test/mitmproxy/protocol/test_websockets.py299
-rw-r--r--test/mitmproxy/test_examples.py90
-rw-r--r--test/mitmproxy/test_proxy.py26
-rw-r--r--test/mitmproxy/test_server.py21
-rw-r--r--test/netlib/http/test_cookies.py68
-rw-r--r--test/netlib/http/test_headers.py5
-rw-r--r--test/netlib/http/test_message.py10
-rw-r--r--test/netlib/http/test_request.py10
-rw-r--r--test/netlib/http/test_response.py20
-rw-r--r--test/netlib/test_strutils.py1
-rw-r--r--test/netlib/tservers.py3
-rw-r--r--test/netlib/websockets/test_frame.py164
-rw-r--r--test/netlib/websockets/test_masker.py23
-rw-r--r--test/netlib/websockets/test_utils.py105
-rw-r--r--test/netlib/websockets/test_websockets.py269
18 files changed, 812 insertions, 316 deletions
diff --git a/test/mitmproxy/protocol/__init__.py b/test/mitmproxy/protocol/__init__.py
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/test/mitmproxy/protocol/__init__.py
diff --git a/test/mitmproxy/test_protocol_http1.py b/test/mitmproxy/protocol/test_http1.py
index cf7bd598..7d04c56b 100644
--- a/test/mitmproxy/test_protocol_http1.py
+++ b/test/mitmproxy/protocol/test_http1.py
@@ -1,7 +1,9 @@
+from __future__ import (absolute_import, print_function, division)
+
from netlib.http import http1
from netlib.tcp import TCPClient
from netlib.tutils import treq
-from . import tutils, tservers
+from .. import tutils, tservers
class TestHTTPFlow(object):
diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/protocol/test_http2.py
index f0fa9a40..1eabebf1 100644
--- a/test/mitmproxy/test_protocol_http2.py
+++ b/test/mitmproxy/protocol/test_http2.py
@@ -13,11 +13,11 @@ from mitmproxy import options
from mitmproxy.proxy.config import ProxyConfig
import netlib
-from ..netlib import tservers as netlib_tservers
+from ...netlib import tservers as netlib_tservers
from netlib.exceptions import HttpException
from netlib.http.http2 import framereader
-from . import tservers
+from .. import tservers
import logging
logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING)
@@ -849,15 +849,15 @@ class TestMaxConcurrentStreams(_Http2Test):
def test_max_concurrent_streams(self):
client, h2_conn = self._setup_connection()
new_streams = [1, 3, 5, 7, 9, 11]
- for id in new_streams:
+ for stream_id in new_streams:
# this will exceed MAX_CONCURRENT_STREAMS on the server connection
# and cause mitmproxy to throttle stream creation to the server
- self._send_request(client.wfile, h2_conn, stream_id=id, headers=[
+ self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
- ('X-Stream-ID', str(id)),
+ ('X-Stream-ID', str(stream_id)),
])
ended_streams = 0
diff --git a/test/mitmproxy/protocol/test_websockets.py b/test/mitmproxy/protocol/test_websockets.py
new file mode 100644
index 00000000..e7e2684f
--- /dev/null
+++ b/test/mitmproxy/protocol/test_websockets.py
@@ -0,0 +1,299 @@
+from __future__ import absolute_import, print_function, division
+
+import pytest
+import os
+import tempfile
+import traceback
+
+from mitmproxy import options
+from mitmproxy.proxy.config import ProxyConfig
+
+import netlib
+from netlib import http
+from ...netlib import tservers as netlib_tservers
+from .. import tservers
+
+from netlib import websockets
+
+
+class _WebSocketsServerBase(netlib_tservers.ServerTestBase):
+
+ class handler(netlib.tcp.BaseHandler):
+
+ def handle(self):
+ try:
+ request = http.http1.read_request(self.rfile)
+ assert websockets.check_handshake(request.headers)
+
+ response = http.Response(
+ "HTTP/1.1",
+ 101,
+ reason=http.status_codes.RESPONSES.get(101),
+ headers=http.Headers(
+ connection='upgrade',
+ upgrade='websocket',
+ sec_websocket_accept=b'',
+ ),
+ content=b'',
+ )
+ self.wfile.write(http.http1.assemble_response(response))
+ self.wfile.flush()
+
+ self.server.handle_websockets(self.rfile, self.wfile)
+ except:
+ traceback.print_exc()
+
+
+class _WebSocketsTestBase(object):
+
+ @classmethod
+ def setup_class(cls):
+ opts = cls.get_options()
+ cls.config = ProxyConfig(opts)
+
+ tmaster = tservers.TestMaster(opts, cls.config)
+ tmaster.start_app(options.APP_HOST, options.APP_PORT)
+ cls.proxy = tservers.ProxyThread(tmaster)
+ cls.proxy.start()
+
+ @classmethod
+ def teardown_class(cls):
+ cls.proxy.shutdown()
+
+ @classmethod
+ def get_options(cls):
+ opts = options.Options(
+ listen_port=0,
+ no_upstream_cert=False,
+ ssl_insecure=True
+ )
+ opts.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy")
+ return opts
+
+ @property
+ def master(self):
+ return self.proxy.tmaster
+
+ def setup(self):
+ self.master.clear_log()
+ self.master.state.clear()
+ self.server.server.handle_websockets = self.handle_websockets
+
+ def _setup_connection(self):
+ client = netlib.tcp.TCPClient(("127.0.0.1", self.proxy.port))
+ client.connect()
+
+ request = http.Request(
+ "authority",
+ "CONNECT",
+ "",
+ "localhost",
+ self.server.server.address.port,
+ "",
+ "HTTP/1.1",
+ content=b'')
+ client.wfile.write(http.http1.assemble_request(request))
+ client.wfile.flush()
+
+ response = http.http1.read_response(client.rfile, request)
+
+ if self.ssl:
+ client.convert_to_ssl()
+ assert client.ssl_established
+
+ request = http.Request(
+ "relative",
+ "GET",
+ "http",
+ "localhost",
+ self.server.server.address.port,
+ "/ws",
+ "HTTP/1.1",
+ headers=http.Headers(
+ connection="upgrade",
+ upgrade="websocket",
+ sec_websocket_version="13",
+ sec_websocket_key="1234",
+ ),
+ content=b'')
+ client.wfile.write(http.http1.assemble_request(request))
+ client.wfile.flush()
+
+ response = http.http1.read_response(client.rfile, request)
+ assert websockets.check_handshake(response.headers)
+
+ return client
+
+
+class _WebSocketsTest(_WebSocketsTestBase, _WebSocketsServerBase):
+
+ @classmethod
+ def setup_class(cls):
+ _WebSocketsTestBase.setup_class()
+ _WebSocketsServerBase.setup_class(ssl=cls.ssl)
+
+ @classmethod
+ def teardown_class(cls):
+ _WebSocketsTestBase.teardown_class()
+ _WebSocketsServerBase.teardown_class()
+
+
+class TestSimple(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
+ wfile.flush()
+
+ frame = websockets.Frame.from_file(rfile)
+ wfile.write(bytes(frame))
+ wfile.flush()
+
+ def test_simple(self):
+ client = self._setup_connection()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.payload == b'server-foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
+ client.wfile.flush()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.payload == b'client-foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ client.wfile.flush()
+
+
+class TestSimpleTLS(_WebSocketsTest):
+ ssl = True
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'server-foobar')))
+ wfile.flush()
+
+ frame = websockets.Frame.from_file(rfile)
+ wfile.write(bytes(frame))
+ wfile.flush()
+
+ def test_simple_tls(self):
+ client = self._setup_connection()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.payload == b'server-foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'client-foobar')))
+ client.wfile.flush()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.payload == b'client-foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ client.wfile.flush()
+
+
+class TestPing(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
+ wfile.flush()
+
+ frame = websockets.Frame.from_file(rfile)
+ assert frame.header.opcode == websockets.OPCODE.PONG
+ assert frame.payload == b'foobar'
+
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.TEXT, payload=b'pong-received')))
+ wfile.flush()
+
+ def test_ping(self):
+ client = self._setup_connection()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.header.opcode == websockets.OPCODE.PING
+ assert frame.payload == b'foobar'
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
+ client.wfile.flush()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.header.opcode == websockets.OPCODE.TEXT
+ assert frame.payload == b'pong-received'
+
+
+class TestPong(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ frame = websockets.Frame.from_file(rfile)
+ assert frame.header.opcode == websockets.OPCODE.PING
+ assert frame.payload == b'foobar'
+
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PONG, payload=frame.payload)))
+ wfile.flush()
+
+ def test_pong(self):
+ client = self._setup_connection()
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.PING, payload=b'foobar')))
+ client.wfile.flush()
+
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.header.opcode == websockets.OPCODE.PONG
+ assert frame.payload == b'foobar'
+
+
+class TestClose(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ frame = websockets.Frame.from_file(rfile)
+ wfile.write(bytes(frame))
+ wfile.flush()
+
+ with pytest.raises(netlib.exceptions.TcpDisconnect):
+ websockets.Frame.from_file(rfile)
+
+ def test_close(self):
+ client = self._setup_connection()
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE)))
+ client.wfile.flush()
+
+ with pytest.raises(netlib.exceptions.TcpDisconnect):
+ websockets.Frame.from_file(client.rfile)
+
+ def test_close_payload_1(self):
+ client = self._setup_connection()
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42')))
+ client.wfile.flush()
+
+ with pytest.raises(netlib.exceptions.TcpDisconnect):
+ websockets.Frame.from_file(client.rfile)
+
+ def test_close_payload_2(self):
+ client = self._setup_connection()
+
+ client.wfile.write(bytes(websockets.Frame(fin=1, opcode=websockets.OPCODE.CLOSE, payload=b'\00\42foobar')))
+ client.wfile.flush()
+
+ with pytest.raises(netlib.exceptions.TcpDisconnect):
+ websockets.Frame.from_file(client.rfile)
+
+
+class TestInvalidFrame(_WebSocketsTest):
+
+ @classmethod
+ def handle_websockets(cls, rfile, wfile):
+ wfile.write(bytes(websockets.Frame(fin=1, opcode=15, payload=b'foobar')))
+ wfile.flush()
+
+ def test_invalid_frame(self):
+ client = self._setup_connection()
+
+ # with pytest.raises(netlib.exceptions.TcpDisconnect):
+ frame = websockets.Frame.from_file(client.rfile)
+ assert frame.header.opcode == 15
+ assert frame.payload == b'foobar'
diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py
index 6c24ace5..83a37a36 100644
--- a/test/mitmproxy/test_examples.py
+++ b/test/mitmproxy/test_examples.py
@@ -1,16 +1,20 @@
import json
+import os
import six
-import sys
-import os.path
-from mitmproxy.flow import master
-from mitmproxy.flow import state
+
from mitmproxy import options
from mitmproxy import contentviews
from mitmproxy.builtins import script
+from mitmproxy.flow import master
+from mitmproxy.flow import state
+
import netlib.utils
+
from netlib import tutils as netutils
from netlib.http import Headers
+from netlib.http import cookies
+
from . import tutils, mastertest
example_dir = netlib.utils.Data(__name__).push("../../examples")
@@ -98,30 +102,66 @@ class TestScripts(mastertest.MasterTest):
m.request(f)
assert f.request.host == "mitmproxy.org"
- def test_har_extractor(self):
- if sys.version_info >= (3, 0):
- with tutils.raises("does not work on Python 3"):
- tscript("har_extractor.py")
- return
+class TestHARDump():
+
+ def flow(self, resp_content=b'message'):
+ times = dict(
+ timestamp_start=746203272,
+ timestamp_end=746203272,
+ )
+
+ # Create a dummy flow for testing
+ return tutils.tflow(
+ req=netutils.treq(method=b'GET', **times),
+ resp=netutils.tresp(content=resp_content, **times)
+ )
+
+ def test_no_file_arg(self):
with tutils.raises(ScriptError):
- tscript("har_extractor.py")
+ tscript("har_dump.py")
+
+ def test_simple(self):
+ with tutils.tmpdir() as tdir:
+ path = os.path.join(tdir, "somefile")
+
+ m, sc = tscript("har_dump.py", six.moves.shlex_quote(path))
+ m.addons.invoke(m, "response", self.flow())
+ m.addons.remove(sc)
+ with open(path, "r") as inp:
+ har = json.load(inp)
+
+ assert len(har["log"]["entries"]) == 1
+
+ def test_base64(self):
with tutils.tmpdir() as tdir:
- times = dict(
- timestamp_start=746203272,
- timestamp_end=746203272,
- )
-
- path = os.path.join(tdir, "file")
- m, sc = tscript("har_extractor.py", six.moves.shlex_quote(path))
- f = tutils.tflow(
- req=netutils.treq(**times),
- resp=netutils.tresp(**times)
- )
- m.response(f)
+ path = os.path.join(tdir, "somefile")
+
+ m, sc = tscript("har_dump.py", six.moves.shlex_quote(path))
+ m.addons.invoke(m, "response", self.flow(resp_content=b"foo" + b"\xFF" * 10))
m.addons.remove(sc)
- with open(path, "rb") as f:
- test_data = json.load(f)
- assert len(test_data["log"]["pages"]) == 1
+ with open(path, "r") as inp:
+ har = json.load(inp)
+
+ assert har["log"]["entries"][0]["response"]["content"]["encoding"] == "base64"
+
+ def test_format_cookies(self):
+ m, sc = tscript("har_dump.py", "-")
+ format_cookies = sc.ns.ns["format_cookies"]
+
+ CA = cookies.CookieAttrs
+
+ f = format_cookies([("n", "v", CA([("k", "v")]))])[0]
+ assert f['name'] == "n"
+ assert f['value'] == "v"
+ assert not f['httpOnly']
+ assert not f['secure']
+
+ f = format_cookies([("n", "v", CA([("httponly", None), ("secure", None)]))])[0]
+ assert f['httpOnly']
+ assert f['secure']
+
+ f = format_cookies([("n", "v", CA([("expires", "Mon, 24-Aug-2037 00:00:00 GMT")]))])[0]
+ assert f['expires']
diff --git a/test/mitmproxy/test_proxy.py b/test/mitmproxy/test_proxy.py
index 84838018..f7c64e50 100644
--- a/test/mitmproxy/test_proxy.py
+++ b/test/mitmproxy/test_proxy.py
@@ -85,22 +85,22 @@ class TestProcessProxyOptions:
@mock.patch("mitmproxy.platform.resolver")
def test_modes(self, _):
- # self.assert_noerr("-R", "http://localhost")
- # self.assert_err("expected one argument", "-R")
- # self.assert_err("Invalid server specification", "-R", "reverse")
- #
- # self.assert_noerr("-T")
- #
- # self.assert_noerr("-U", "http://localhost")
- # self.assert_err("expected one argument", "-U")
- # self.assert_err("Invalid server specification", "-U", "upstream")
- #
- # self.assert_noerr("--upstream-auth", "test:test")
- # self.assert_err("expected one argument", "--upstream-auth")
+ self.assert_noerr("-R", "http://localhost")
+ self.assert_err("expected one argument", "-R")
+ self.assert_err("Invalid server specification", "-R", "reverse")
+
+ self.assert_noerr("-T")
+
+ self.assert_noerr("-U", "http://localhost")
+ self.assert_err("expected one argument", "-U")
+ self.assert_err("Invalid server specification", "-U", "upstream")
+
+ self.assert_noerr("--upstream-auth", "test:test")
+ self.assert_err("expected one argument", "--upstream-auth")
self.assert_err(
"Invalid upstream auth specification", "--upstream-auth", "test"
)
- # self.assert_err("mutually exclusive", "-R", "http://localhost", "-T")
+ self.assert_err("mutually exclusive", "-R", "http://localhost", "-T")
def test_socks_auth(self):
self.assert_err(
diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py
index 78e9b5c7..e0a8da47 100644
--- a/test/mitmproxy/test_server.py
+++ b/test/mitmproxy/test_server.py
@@ -313,6 +313,22 @@ class TestHTTPAuth(tservers.HTTPProxyTest):
assert ret.status_code == 202
+class TestHTTPReverseAuth(tservers.ReverseProxyTest):
+ def test_auth(self):
+ self.master.options.auth_singleuser = "test:test"
+ assert self.pathod("202").status_code == 401
+ p = self.pathoc()
+ ret = p.request("""
+ get
+ '/p/202'
+ h'%s'='%s'
+ """ % (
+ http.authentication.BasicWebsiteAuth.AUTH_HEADER,
+ authentication.assemble_http_basic_auth("basic", "test", "test")
+ ))
+ assert ret.status_code == 202
+
+
class TestHTTPS(tservers.HTTPProxyTest, CommonMixin, TcpMixin):
ssl = True
ssloptions = pathod.SSLOptions(request_client_cert=True)
@@ -456,6 +472,11 @@ class TestReverse(tservers.ReverseProxyTest, CommonMixin, TcpMixin):
reverse = True
+class TestReverseSSL(tservers.ReverseProxyTest, CommonMixin, TcpMixin):
+ reverse = True
+ ssl = True
+
+
class TestSocks5(tservers.SocksModeTest):
def test_simple(self):
diff --git a/test/netlib/http/test_cookies.py b/test/netlib/http/test_cookies.py
index 17e21b94..efd8ba80 100644
--- a/test/netlib/http/test_cookies.py
+++ b/test/netlib/http/test_cookies.py
@@ -1,6 +1,10 @@
+import time
+
from netlib.http import cookies
from netlib.tutils import raises
+import mock
+
def test_read_token():
tokens = [
@@ -247,6 +251,22 @@ def test_refresh_cookie():
assert cookies.refresh_set_cookie_header(c, 0)
+@mock.patch('time.time')
+def test_get_expiration_ts(*args):
+ # Freeze time
+ now_ts = 17
+ time.time.return_value = now_ts
+
+ CA = cookies.CookieAttrs
+ F = cookies.get_expiration_ts
+
+ assert F(CA([("Expires", "Thu, 01-Jan-1970 00:00:00 GMT")])) == 0
+ assert F(CA([("Expires", "Mon, 24-Aug-2037 00:00:00 GMT")])) == 2134684800
+
+ assert F(CA([("Max-Age", "0")])) == now_ts
+ assert F(CA([("Max-Age", "31")])) == now_ts + 31
+
+
def test_is_expired():
CA = cookies.CookieAttrs
@@ -260,9 +280,53 @@ def test_is_expired():
# or both
assert cookies.is_expired(CA([("Expires", "Thu, 01-Jan-1970 00:00:00 GMT"), ("Max-Age", "0")]))
- assert not cookies.is_expired(CA([("Expires", "Thu, 24-Aug-2063 00:00:00 GMT")]))
+ assert not cookies.is_expired(CA([("Expires", "Mon, 24-Aug-2037 00:00:00 GMT")]))
assert not cookies.is_expired(CA([("Max-Age", "1")]))
- assert not cookies.is_expired(CA([("Expires", "Thu, 15-Jul-2068 00:00:00 GMT"), ("Max-Age", "1")]))
+ assert not cookies.is_expired(CA([("Expires", "Wed, 15-Jul-2037 00:00:00 GMT"), ("Max-Age", "1")]))
assert not cookies.is_expired(CA([("Max-Age", "nan")]))
assert not cookies.is_expired(CA([("Expires", "false")]))
+
+
+def test_group_cookies():
+ CA = cookies.CookieAttrs
+ groups = [
+ [
+ "one=uno; foo=bar; foo=baz",
+ [
+ ('one', 'uno', CA([])),
+ ('foo', 'bar', CA([])),
+ ('foo', 'baz', CA([]))
+ ]
+ ],
+ [
+ "one=uno; Path=/; foo=bar; Max-Age=0; foo=baz; expires=24-08-1993",
+ [
+ ('one', 'uno', CA([('Path', '/')])),
+ ('foo', 'bar', CA([('Max-Age', '0')])),
+ ('foo', 'baz', CA([('expires', '24-08-1993')]))
+ ]
+ ],
+ [
+ "one=uno;",
+ [
+ ('one', 'uno', CA([]))
+ ]
+ ],
+ [
+ "one=uno; Path=/; Max-Age=0; Expires=24-08-1993",
+ [
+ ('one', 'uno', CA([('Path', '/'), ('Max-Age', '0'), ('Expires', '24-08-1993')]))
+ ]
+ ],
+ [
+ "path=val; Path=/",
+ [
+ ('path', 'val', CA([('Path', '/')]))
+ ]
+ ]
+ ]
+
+ for c, expected in groups:
+ observed = cookies.group_cookies(cookies.parse_cookie_header(c))
+ assert observed == expected
diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py
index 51537310..ad2bc548 100644
--- a/test/netlib/http/test_headers.py
+++ b/test/netlib/http/test_headers.py
@@ -75,6 +75,11 @@ class TestHeaders(object):
assert replacements == 0
assert headers["Host"] == "example.com"
+ def test_replace_with_count(self):
+ headers = Headers(Host="foobarfoo.com", Accept="foo/bar")
+ replacements = headers.replace("foo", "bar", count=1)
+ assert replacements == 1
+
def test_parse_content_type():
p = parse_content_type
diff --git a/test/netlib/http/test_message.py b/test/netlib/http/test_message.py
index 12e4706c..74272309 100644
--- a/test/netlib/http/test_message.py
+++ b/test/netlib/http/test_message.py
@@ -99,6 +99,16 @@ class TestMessage(object):
def test_http_version(self):
_test_decoded_attr(tresp(), "http_version")
+ def test_replace(self):
+ r = tresp()
+ r.content = b"foofootoo"
+ r.replace(b"foo", "gg")
+ assert r.content == b"ggggtoo"
+
+ r.content = b"foofootoo"
+ r.replace(b"foo", "gg", count=1)
+ assert r.content == b"ggfootoo"
+
class TestMessageContentEncoding(object):
def test_simple(self):
diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py
index f3cd8b71..9baabaa6 100644
--- a/test/netlib/http/test_request.py
+++ b/test/netlib/http/test_request.py
@@ -26,6 +26,16 @@ class TestRequestCore(object):
request.host = None
assert repr(request) == "Request(GET /path)"
+ def replace(self):
+ r = treq()
+ r.path = b"foobarfoo"
+ r.replace(b"foo", "bar")
+ assert r.path == b"barbarbar"
+
+ r.path = b"foobarfoo"
+ r.replace(b"foo", "bar", count=1)
+ assert r.path == b"barbarfoo"
+
def test_first_line_format(self):
_test_passthrough_attr(treq(), "first_line_format")
diff --git a/test/netlib/http/test_response.py b/test/netlib/http/test_response.py
index b3c2f736..c7b1b646 100644
--- a/test/netlib/http/test_response.py
+++ b/test/netlib/http/test_response.py
@@ -5,6 +5,7 @@ import email
import time
from netlib.http import Headers
+from netlib.http import Response
from netlib.http.cookies import CookieAttrs
from netlib.tutils import raises, tresp
from .test_message import _test_passthrough_attr, _test_decoded_attr
@@ -28,6 +29,25 @@ class TestResponseCore(object):
response.content = None
assert repr(response) == "Response(200 OK, no content)"
+ def test_make(self):
+ r = Response.make()
+ assert r.status_code == 200
+ assert r.content == b""
+
+ Response.make(content=b"foo")
+ Response.make(content="foo")
+ with raises(TypeError):
+ Response.make(content=42)
+
+ r = Response.make(headers=[(b"foo", b"bar")])
+ assert r.headers["foo"] == "bar"
+
+ r = Response.make(headers=({"foo": "baz"}))
+ assert r.headers["foo"] == "baz"
+
+ with raises(TypeError):
+ Response.make(headers=42)
+
def test_status_code(self):
_test_passthrough_attr(tresp(), "status_code")
diff --git a/test/netlib/test_strutils.py b/test/netlib/test_strutils.py
index 52299e59..5be254a3 100644
--- a/test/netlib/test_strutils.py
+++ b/test/netlib/test_strutils.py
@@ -85,6 +85,7 @@ def test_escaped_str_to_bytes():
def test_is_mostly_bin():
assert not strutils.is_mostly_bin(b"foo\xFF")
assert strutils.is_mostly_bin(b"foo" + b"\xFF" * 10)
+ assert not strutils.is_mostly_bin("")
def test_is_xml():
diff --git a/test/netlib/tservers.py b/test/netlib/tservers.py
index 666f97ac..a80dcb28 100644
--- a/test/netlib/tservers.py
+++ b/test/netlib/tservers.py
@@ -100,7 +100,8 @@ class ServerTestBase(object):
@classmethod
def makeserver(cls, **kwargs):
- return _TServer(cls.ssl, cls.q, cls.handler, cls.addr, **kwargs)
+ ssl = kwargs.pop('ssl', cls.ssl)
+ return _TServer(ssl, cls.q, cls.handler, cls.addr, **kwargs)
@classmethod
def teardown_class(cls):
diff --git a/test/netlib/websockets/test_frame.py b/test/netlib/websockets/test_frame.py
new file mode 100644
index 00000000..cce39454
--- /dev/null
+++ b/test/netlib/websockets/test_frame.py
@@ -0,0 +1,164 @@
+import os
+import codecs
+import pytest
+
+from netlib import websockets
+from netlib import tutils
+
+
+class TestFrameHeader(object):
+
+ @pytest.mark.parametrize("input,expected", [
+ (0, '0100'),
+ (125, '017D'),
+ (126, '017E007E'),
+ (127, '017E007F'),
+ (142, '017E008E'),
+ (65534, '017EFFFE'),
+ (65535, '017EFFFF'),
+ (65536, '017F0000000000010000'),
+ (8589934591, '017F00000001FFFFFFFF'),
+ (2 ** 64 - 1, '017FFFFFFFFFFFFFFFFF'),
+ ])
+ def test_serialization_length(self, input, expected):
+ h = websockets.FrameHeader(
+ opcode=websockets.OPCODE.TEXT,
+ payload_length=input,
+ )
+ assert bytes(h) == codecs.decode(expected, 'hex')
+
+ def test_serialization_too_large(self):
+ h = websockets.FrameHeader(
+ payload_length=2 ** 64 + 1,
+ )
+ with pytest.raises(ValueError):
+ bytes(h)
+
+ @pytest.mark.parametrize("input,expected", [
+ ('0100', 0),
+ ('017D', 125),
+ ('017E007E', 126),
+ ('017E007F', 127),
+ ('017E008E', 142),
+ ('017EFFFE', 65534),
+ ('017EFFFF', 65535),
+ ('017F0000000000010000', 65536),
+ ('017F00000001FFFFFFFF', 8589934591),
+ ('017FFFFFFFFFFFFFFFFF', 2 ** 64 - 1),
+ ])
+ def test_deserialization_length(self, input, expected):
+ h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex')))
+ assert h.payload_length == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ('0100', (False, None)),
+ ('018000000000', (True, '00000000')),
+ ('018012345678', (True, '12345678')),
+ ])
+ def test_deserialization_masking(self, input, expected):
+ h = websockets.FrameHeader.from_file(tutils.treader(codecs.decode(input, 'hex')))
+ assert h.mask == expected[0]
+ if h.mask:
+ assert h.masking_key == codecs.decode(expected[1], 'hex')
+
+ def test_equality(self):
+ h = websockets.FrameHeader(mask=True, masking_key=b'1234')
+ h2 = websockets.FrameHeader(mask=True, masking_key=b'1234')
+ assert h == h2
+
+ h = websockets.FrameHeader(fin=True)
+ h2 = websockets.FrameHeader(fin=False)
+ assert h != h2
+
+ assert h != 'foobar'
+
+ def test_roundtrip(self):
+ def round(*args, **kwargs):
+ h = websockets.FrameHeader(*args, **kwargs)
+ h2 = websockets.FrameHeader.from_file(tutils.treader(bytes(h)))
+ assert h == h2
+
+ round()
+ round(fin=True)
+ round(rsv1=True)
+ round(rsv2=True)
+ round(rsv3=True)
+ round(payload_length=1)
+ round(payload_length=100)
+ round(payload_length=1000)
+ round(payload_length=10000)
+ round(opcode=websockets.OPCODE.PING)
+ round(masking_key=b"test")
+
+ def test_human_readable(self):
+ f = websockets.FrameHeader(
+ masking_key=b"test",
+ fin=True,
+ payload_length=10
+ )
+ assert repr(f)
+
+ f = websockets.FrameHeader()
+ assert repr(f)
+
+ def test_funky(self):
+ f = websockets.FrameHeader(masking_key=b"test", mask=False)
+ raw = bytes(f)
+ f2 = websockets.FrameHeader.from_file(tutils.treader(raw))
+ assert not f2.mask
+
+ def test_violations(self):
+ tutils.raises("opcode", websockets.FrameHeader, opcode=17)
+ tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
+
+ def test_automask(self):
+ f = websockets.FrameHeader(mask=True)
+ assert f.masking_key
+
+ f = websockets.FrameHeader(masking_key=b"foob")
+ assert f.mask
+
+ f = websockets.FrameHeader(masking_key=b"foob", mask=0)
+ assert not f.mask
+ assert f.masking_key
+
+
+class TestFrame(object):
+ def test_equality(self):
+ f = websockets.Frame(payload=b'1234')
+ f2 = websockets.Frame(payload=b'1234')
+ assert f == f2
+
+ assert f != b'1234'
+
+ def test_roundtrip(self):
+ def round(*args, **kwargs):
+ f = websockets.Frame(*args, **kwargs)
+ raw = bytes(f)
+ f2 = websockets.Frame.from_file(tutils.treader(raw))
+ assert f == f2
+ round(b"test")
+ round(b"test", fin=1)
+ round(b"test", rsv1=1)
+ round(b"test", opcode=websockets.OPCODE.PING)
+ round(b"test", masking_key=b"test")
+
+ def test_human_readable(self):
+ f = websockets.Frame()
+ assert repr(f)
+
+ f = websockets.Frame(b"foobar")
+ assert "foobar" in repr(f)
+
+ @pytest.mark.parametrize("masked", [True, False])
+ @pytest.mark.parametrize("length", [100, 50000, 150000])
+ def test_serialization_bijection(self, masked, length):
+ frame = websockets.Frame(
+ os.urandom(length),
+ fin=True,
+ opcode=websockets.OPCODE.TEXT,
+ mask=int(masked),
+ masking_key=(os.urandom(4) if masked else None)
+ )
+ serialized = bytes(frame)
+ assert frame == websockets.Frame.from_bytes(serialized)
diff --git a/test/netlib/websockets/test_masker.py b/test/netlib/websockets/test_masker.py
new file mode 100644
index 00000000..528fce71
--- /dev/null
+++ b/test/netlib/websockets/test_masker.py
@@ -0,0 +1,23 @@
+import codecs
+import pytest
+
+from netlib import websockets
+
+
+class TestMasker(object):
+
+ @pytest.mark.parametrize("input,expected", [
+ ([b"a"], '00'),
+ ([b"four"], '070d1616'),
+ ([b"fourf"], '070d161607'),
+ ([b"fourfive"], '070d1616070b1501'),
+ ([b"a", b"aasdfasdfa", b"asdf"], '000302170504021705040205120605'),
+ ([b"a" * 50, b"aasdfasdfa", b"asdf"], '00030205000302050003020500030205000302050003020500030205000302050003020500030205000302050003020500030205120605051206050500110702'), # noqa
+ ])
+ def test_masker(self, input, expected):
+ m = websockets.Masker(b"abcd")
+ data = b"".join([m(t) for t in input])
+ assert data == codecs.decode(expected, 'hex')
+
+ data = websockets.Masker(b"abcd")(data)
+ assert data == b"".join(input)
diff --git a/test/netlib/websockets/test_utils.py b/test/netlib/websockets/test_utils.py
new file mode 100644
index 00000000..34765e04
--- /dev/null
+++ b/test/netlib/websockets/test_utils.py
@@ -0,0 +1,105 @@
+import pytest
+
+from netlib import http
+from netlib import websockets
+
+
+class TestUtils(object):
+
+ def test_client_handshake_headers(self):
+ h = websockets.client_handshake_headers(version='42')
+ assert h['sec-websocket-version'] == '42'
+
+ h = websockets.client_handshake_headers(key='some-key')
+ assert h['sec-websocket-key'] == 'some-key'
+
+ h = websockets.client_handshake_headers(protocol='foobar')
+ assert h['sec-websocket-protocol'] == 'foobar'
+
+ h = websockets.client_handshake_headers(extensions='foo; bar')
+ assert h['sec-websocket-extensions'] == 'foo; bar'
+
+ def test_server_handshake_headers(self):
+ h = websockets.server_handshake_headers('some-key')
+ assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw='
+ assert 'sec-websocket-protocol' not in h
+ assert 'sec-websocket-extensions' not in h
+
+ h = websockets.server_handshake_headers('some-key', 'foobar', 'foo; bar')
+ assert h['sec-websocket-accept'] == '8iILEZtcVdtFD7MDlPKip9ec9nw='
+ assert h['sec-websocket-protocol'] == 'foobar'
+ assert h['sec-websocket-extensions'] == 'foo; bar'
+
+ @pytest.mark.parametrize("input,expected", [
+ ([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], True),
+ ([(b'connection', b'upgrade'), (b'upgrade', b'websocket'), (b'sec-websocket-accept', b'foobar')], True),
+ ([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-KeY', b'foobar')], True),
+ ([(b'Connection', b'UpgRaDe'), (b'Upgrade', b'WebSocKeT'), (b'Sec-WebSockeT-AccePt', b'foobar')], True),
+ ([(b'connection', b'foo'), (b'upgrade', b'bar'), (b'sec-websocket-key', b'foobar')], False),
+ ([(b'connection', b'upgrade'), (b'upgrade', b'websocket')], False),
+ ([(b'connection', b'upgrade'), (b'sec-websocket-key', b'foobar')], False),
+ ([(b'upgrade', b'websocket'), (b'sec-websocket-key', b'foobar')], False),
+ ([], False),
+ ])
+ def test_check_handshake(self, input, expected):
+ h = http.Headers(input)
+ assert websockets.check_handshake(h) == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ([(b'sec-websocket-version', b'13')], True),
+ ([(b'Sec-WebSockeT-VerSion', b'13')], True),
+ ([(b'sec-websocket-version', b'9')], False),
+ ([(b'sec-websocket-version', b'42')], False),
+ ([(b'sec-websocket-version', b'')], False),
+ ([], False),
+ ])
+ def test_check_client_version(self, input, expected):
+ h = http.Headers(input)
+ assert websockets.check_client_version(h) == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ('foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='),
+ (b'foobar', b'AzhRPA4TNwR6I/riJheN0TfR7+I='),
+ ])
+ def test_create_server_nonce(self, input, expected):
+ assert websockets.create_server_nonce(input) == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ([(b'sec-websocket-extensions', b'foo; bar')], 'foo; bar'),
+ ([(b'Sec-WebSockeT-ExteNsionS', b'foo; bar')], 'foo; bar'),
+ ([(b'sec-websocket-extensions', b'')], ''),
+ ([], None),
+ ])
+ def test_get_extensions(self, input, expected):
+ h = http.Headers(input)
+ assert websockets.get_extensions(h) == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ([(b'sec-websocket-protocol', b'foobar')], 'foobar'),
+ ([(b'Sec-WebSockeT-ProTocoL', b'foobar')], 'foobar'),
+ ([(b'sec-websocket-protocol', b'')], ''),
+ ([], None),
+ ])
+ def test_get_protocol(self, input, expected):
+ h = http.Headers(input)
+ assert websockets.get_protocol(h) == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ([(b'sec-websocket-key', b'foobar')], 'foobar'),
+ ([(b'Sec-WebSockeT-KeY', b'foobar')], 'foobar'),
+ ([(b'sec-websocket-key', b'')], ''),
+ ([], None),
+ ])
+ def test_get_client_key(self, input, expected):
+ h = http.Headers(input)
+ assert websockets.get_client_key(h) == expected
+
+ @pytest.mark.parametrize("input,expected", [
+ ([(b'sec-websocket-accept', b'foobar')], 'foobar'),
+ ([(b'Sec-WebSockeT-AccepT', b'foobar')], 'foobar'),
+ ([(b'sec-websocket-accept', b'')], ''),
+ ([], None),
+ ])
+ def test_get_server_accept(self, input, expected):
+ h = http.Headers(input)
+ assert websockets.get_server_accept(h) == expected
diff --git a/test/netlib/websockets/test_websockets.py b/test/netlib/websockets/test_websockets.py
deleted file mode 100644
index 50fa26e6..00000000
--- a/test/netlib/websockets/test_websockets.py
+++ /dev/null
@@ -1,269 +0,0 @@
-import os
-
-from netlib.http.http1 import read_response, read_request
-
-from netlib import tcp
-from netlib import tutils
-from netlib import websockets
-from netlib.http import status_codes
-from netlib.tutils import treq
-from netlib import exceptions
-
-from .. import tservers
-
-
-class WebSocketsEchoHandler(tcp.BaseHandler):
-
- def __init__(self, connection, address, server):
- super(WebSocketsEchoHandler, self).__init__(
- connection, address, server
- )
- self.protocol = websockets.WebsocketsProtocol()
- self.handshake_done = False
-
- def handle(self):
- while True:
- if not self.handshake_done:
- self.handshake()
- else:
- self.read_next_message()
-
- def read_next_message(self):
- frame = websockets.Frame.from_file(self.rfile)
- self.on_message(frame.payload)
-
- def send_message(self, message):
- frame = websockets.Frame.default(message, from_client=False)
- frame.to_file(self.wfile)
-
- def handshake(self):
-
- req = read_request(self.rfile)
- key = self.protocol.check_client_handshake(req.headers)
-
- preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
- self.wfile.write(preamble.encode() + b"\r\n")
- headers = self.protocol.server_handshake_headers(key)
- self.wfile.write(str(headers) + "\r\n")
- self.wfile.flush()
- self.handshake_done = True
-
- def on_message(self, message):
- if message is not None:
- self.send_message(message)
-
-
-class WebSocketsClient(tcp.TCPClient):
-
- def __init__(self, address, source_address=None):
- super(WebSocketsClient, self).__init__(address, source_address)
- self.protocol = websockets.WebsocketsProtocol()
- self.client_nonce = None
-
- def connect(self):
- super(WebSocketsClient, self).connect()
-
- preamble = b'GET / HTTP/1.1'
- self.wfile.write(preamble + b"\r\n")
- headers = self.protocol.client_handshake_headers()
- self.client_nonce = headers["sec-websocket-key"].encode("ascii")
- self.wfile.write(bytes(headers) + b"\r\n")
- self.wfile.flush()
-
- resp = read_response(self.rfile, treq(method=b"GET"))
- server_nonce = self.protocol.check_server_handshake(resp.headers)
-
- if not server_nonce == self.protocol.create_server_nonce(self.client_nonce):
- self.close()
-
- def read_next_message(self):
- return websockets.Frame.from_file(self.rfile).payload
-
- def send_message(self, message):
- frame = websockets.Frame.default(message, from_client=True)
- frame.to_file(self.wfile)
-
-
-class TestWebSockets(tservers.ServerTestBase):
- handler = WebSocketsEchoHandler
-
- def __init__(self):
- self.protocol = websockets.WebsocketsProtocol()
-
- def random_bytes(self, n=100):
- return os.urandom(n)
-
- def echo(self, msg):
- client = WebSocketsClient(("127.0.0.1", self.port))
- client.connect()
- client.send_message(msg)
- response = client.read_next_message()
- assert response == msg
-
- def test_simple_echo(self):
- self.echo(b"hello I'm the client")
-
- def test_frame_sizes(self):
- # length can fit in the the 7 bit payload length
- small_msg = self.random_bytes(100)
- # 50kb, sligthly larger than can fit in a 7 bit int
- medium_msg = self.random_bytes(50000)
- # 150kb, slightly larger than can fit in a 16 bit int
- large_msg = self.random_bytes(150000)
-
- self.echo(small_msg)
- self.echo(medium_msg)
- self.echo(large_msg)
-
- def test_default_builder(self):
- """
- default builder should always generate valid frames
- """
- msg = self.random_bytes()
- assert websockets.Frame.default(msg, from_client=True)
- assert websockets.Frame.default(msg, from_client=False)
-
- def test_serialization_bijection(self):
- """
- Ensure that various frame types can be serialized/deserialized back
- and forth between to_bytes() and from_bytes()
- """
- for is_client in [True, False]:
- for num_bytes in [100, 50000, 150000]:
- frame = websockets.Frame.default(
- self.random_bytes(num_bytes), is_client
- )
- frame2 = websockets.Frame.from_bytes(
- frame.to_bytes()
- )
- assert frame == frame2
-
- bytes = b'\x81\x03cba'
- assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
-
- def test_check_server_handshake(self):
- headers = self.protocol.server_handshake_headers("key")
- assert self.protocol.check_server_handshake(headers)
- headers["Upgrade"] = "not_websocket"
- assert not self.protocol.check_server_handshake(headers)
-
- def test_check_client_handshake(self):
- headers = self.protocol.client_handshake_headers("key")
- assert self.protocol.check_client_handshake(headers) == "key"
- headers["Upgrade"] = "not_websocket"
- assert not self.protocol.check_client_handshake(headers)
-
-
-class BadHandshakeHandler(WebSocketsEchoHandler):
-
- def handshake(self):
-
- client_hs = read_request(self.rfile)
- self.protocol.check_client_handshake(client_hs.headers)
-
- preamble = 'HTTP/1.1 101 %s\r\n' % status_codes.RESPONSES.get(101)
- self.wfile.write(preamble.encode())
- headers = self.protocol.server_handshake_headers(b"malformed key")
- self.wfile.write(bytes(headers) + b"\r\n")
- self.wfile.flush()
- self.handshake_done = True
-
-
-class TestBadHandshake(tservers.ServerTestBase):
-
- """
- Ensure that the client disconnects if the server handshake is malformed
- """
- handler = BadHandshakeHandler
-
- def test(self):
- with tutils.raises(exceptions.TcpDisconnect):
- client = WebSocketsClient(("127.0.0.1", self.port))
- client.connect()
- client.send_message(b"hello")
-
-
-class TestFrameHeader:
-
- def test_roundtrip(self):
- def round(*args, **kwargs):
- f = websockets.FrameHeader(*args, **kwargs)
- f2 = websockets.FrameHeader.from_file(tutils.treader(bytes(f)))
- assert f == f2
- round()
- round(fin=1)
- round(rsv1=1)
- round(rsv2=1)
- round(rsv3=1)
- round(payload_length=1)
- round(payload_length=100)
- round(payload_length=1000)
- round(payload_length=10000)
- round(opcode=websockets.OPCODE.PING)
- round(masking_key=b"test")
-
- def test_human_readable(self):
- f = websockets.FrameHeader(
- masking_key=b"test",
- fin=True,
- payload_length=10
- )
- assert repr(f)
- f = websockets.FrameHeader()
- assert repr(f)
-
- def test_funky(self):
- f = websockets.FrameHeader(masking_key=b"test", mask=False)
- raw = bytes(f)
- f2 = websockets.FrameHeader.from_file(tutils.treader(raw))
- assert not f2.mask
-
- def test_violations(self):
- tutils.raises("opcode", websockets.FrameHeader, opcode=17)
- tutils.raises("masking key", websockets.FrameHeader, masking_key=b"x")
-
- def test_automask(self):
- f = websockets.FrameHeader(mask=True)
- assert f.masking_key
-
- f = websockets.FrameHeader(masking_key=b"foob")
- assert f.mask
-
- f = websockets.FrameHeader(masking_key=b"foob", mask=0)
- assert not f.mask
- assert f.masking_key
-
-
-class TestFrame:
-
- def test_roundtrip(self):
- def round(*args, **kwargs):
- f = websockets.Frame(*args, **kwargs)
- raw = bytes(f)
- f2 = websockets.Frame.from_file(tutils.treader(raw))
- assert f == f2
- round(b"test")
- round(b"test", fin=1)
- round(b"test", rsv1=1)
- round(b"test", opcode=websockets.OPCODE.PING)
- round(b"test", masking_key=b"test")
-
- def test_human_readable(self):
- f = websockets.Frame()
- assert repr(f)
-
-
-def test_masker():
- tests = [
- [b"a"],
- [b"four"],
- [b"fourf"],
- [b"fourfive"],
- [b"a", b"aasdfasdfa", b"asdf"],
- [b"a" * 50, b"aasdfasdfa", b"asdf"],
- ]
- for i in tests:
- m = websockets.Masker(b"abcd")
- data = b"".join([m(t) for t in i])
- data2 = websockets.Masker(b"abcd")(data)
- assert data2 == b"".join(i)