aboutsummaryrefslogtreecommitdiffstats
path: root/test
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2015-09-16 00:04:23 +0200
committerMaximilian Hils <git@maximilianhils.com>2015-09-16 00:04:23 +0200
commita077d8877d210562f703c23e9625e8467c81222d (patch)
tree47608f9f99d149634f6c5dcb755bdf534a096d45 /test
parent11e7f476bd4bbcd6d072fa3659f628ae3a19705d (diff)
downloadmitmproxy-a077d8877d210562f703c23e9625e8467c81222d.tar.gz
mitmproxy-a077d8877d210562f703c23e9625e8467c81222d.tar.bz2
mitmproxy-a077d8877d210562f703c23e9625e8467c81222d.zip
finish netlib.http.http1 refactor
Diffstat (limited to 'test')
-rw-r--r--test/http/http1/test_assemble.py91
-rw-r--r--test/http/http1/test_protocol.py466
-rw-r--r--test/http/http1/test_read.py313
-rw-r--r--test/http/http2/test_frames.py2
-rw-r--r--test/http/http2/test_protocol.py16
-rw-r--r--test/http/test_exceptions.py6
-rw-r--r--test/http/test_models.py (renamed from test/http/test_semantics.py)163
-rw-r--r--test/websockets/test_websockets.py16
8 files changed, 485 insertions, 588 deletions
diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py
new file mode 100644
index 00000000..8a0a54f1
--- /dev/null
+++ b/test/http/http1/test_assemble.py
@@ -0,0 +1,91 @@
+from __future__ import absolute_import, print_function, division
+from netlib.exceptions import HttpException
+from netlib.http import CONTENT_MISSING, Headers
+from netlib.http.http1.assemble import (
+ assemble_request, assemble_request_head, assemble_response,
+ assemble_response_head, _assemble_request_line, _assemble_request_headers,
+ _assemble_response_headers
+)
+from netlib.tutils import treq, raises, tresp
+
+
+def test_assemble_request():
+ c = assemble_request(treq()) == (
+ b"GET /path HTTP/1.1\r\n"
+ b"header: qvalue\r\n"
+ b"Host: address:22\r\n"
+ b"Content-Length: 7\r\n"
+ b"\r\n"
+ b"content"
+ )
+
+ with raises(HttpException):
+ assemble_request(treq(body=CONTENT_MISSING))
+
+
+def test_assemble_request_head():
+ c = assemble_request_head(treq())
+ assert b"GET" in c
+ assert b"qvalue" in c
+ assert b"content" not in c
+
+
+def test_assemble_response():
+ c = assemble_response(tresp()) == (
+ b"HTTP/1.1 200 OK\r\n"
+ b"header-response: svalue\r\n"
+ b"Content-Length: 7\r\n"
+ b"\r\n"
+ b"message"
+ )
+
+ with raises(HttpException):
+ assemble_response(tresp(body=CONTENT_MISSING))
+
+
+def test_assemble_response_head():
+ c = assemble_response_head(tresp())
+ assert b"200" in c
+ assert b"svalue" in c
+ assert b"message" not in c
+
+
+def test_assemble_request_line():
+ assert _assemble_request_line(treq()) == b"GET /path HTTP/1.1"
+
+ authority_request = treq(method=b"CONNECT", form_in="authority")
+ assert _assemble_request_line(authority_request) == b"CONNECT address:22 HTTP/1.1"
+
+ absolute_request = treq(form_in="absolute")
+ assert _assemble_request_line(absolute_request) == b"GET http://address:22/path HTTP/1.1"
+
+ with raises(RuntimeError):
+ _assemble_request_line(treq(), "invalid_form")
+
+
+def test_assemble_request_headers():
+ # https://github.com/mitmproxy/mitmproxy/issues/186
+ r = treq(body=b"")
+ r.headers[b"Transfer-Encoding"] = b"chunked"
+ c = _assemble_request_headers(r)
+ assert b"Content-Length" in c
+ assert b"Transfer-Encoding" not in c
+
+ assert b"Host" in _assemble_request_headers(treq(headers=Headers()))
+
+ assert b"Proxy-Connection" not in _assemble_request_headers(
+ treq(headers=Headers(Proxy_Connection="42"))
+ )
+
+
+def test_assemble_response_headers():
+ # https://github.com/mitmproxy/mitmproxy/issues/186
+ r = tresp(body=b"")
+ r.headers["Transfer-Encoding"] = b"chunked"
+ c = _assemble_response_headers(r)
+ assert b"Content-Length" in c
+ assert b"Transfer-Encoding" not in c
+
+ assert b"Proxy-Connection" not in _assemble_response_headers(
+ tresp(headers=Headers(Proxy_Connection=b"42"))
+ )
diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py
index bdcba5cb..e69de29b 100644
--- a/test/http/http1/test_protocol.py
+++ b/test/http/http1/test_protocol.py
@@ -1,466 +0,0 @@
-from io import BytesIO
-import textwrap
-from http.http1.protocol import _parse_authority_form
-from netlib.exceptions import HttpSyntaxException, HttpReadDisconnect, HttpException
-
-from netlib import http, tcp, tutils
-from netlib.http import semantics, Headers
-from netlib.http.http1 import HTTP1Protocol, read_message_body, read_request, \
- read_message_body_chunked, expected_http_body_size
-from ... import tservers
-
-
-class NoContentLengthHTTPHandler(tcp.BaseHandler):
- def handle(self):
- self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n")
- self.wfile.flush()
-
-
-def mock_protocol(data=''):
- rfile = BytesIO(data)
- wfile = BytesIO()
- return HTTP1Protocol(rfile=rfile, wfile=wfile)
-
-
-def match_http_string(data):
- return textwrap.dedent(data).strip().replace('\n', '\r\n')
-
-
-def test_stripped_chunked_encoding_no_content():
- """
- https://github.com/mitmproxy/mitmproxy/issues/186
- """
-
- r = tutils.treq(content="")
- r.headers["Transfer-Encoding"] = "chunked"
- assert "Content-Length" in mock_protocol()._assemble_request_headers(r)
-
- r = tutils.tresp(content="")
- r.headers["Transfer-Encoding"] = "chunked"
- assert "Content-Length" in mock_protocol()._assemble_response_headers(r)
-
-
-def test_read_chunked():
- req = tutils.treq(None)
- req.headers["Transfer-Encoding"] = "chunked"
-
- data = b"1\r\na\r\n0\r\n"
- with tutils.raises(HttpSyntaxException):
- read_message_body(BytesIO(data), req)
-
- data = b"1\r\na\r\n0\r\n\r\n"
- assert read_message_body(BytesIO(data), req) == b"a"
-
- data = b"\r\n\r\n1\r\na\r\n1\r\nb\r\n0\r\n\r\n"
- assert read_message_body(BytesIO(data), req) == b"ab"
-
- data = b"\r\n"
- with tutils.raises("closed prematurely"):
- read_message_body(BytesIO(data), req)
-
- data = b"1\r\nfoo"
- with tutils.raises("malformed chunked body"):
- read_message_body(BytesIO(data), req)
-
- data = b"foo\r\nfoo"
- with tutils.raises(HttpSyntaxException):
- read_message_body(BytesIO(data), req)
-
- data = b"5\r\naaaaa\r\n0\r\n\r\n"
- with tutils.raises("too large"):
- read_message_body(BytesIO(data), req, limit=2)
-
-
-def test_connection_close():
- headers = Headers()
- assert HTTP1Protocol.connection_close((1, 0), headers)
- assert not HTTP1Protocol.connection_close((1, 1), headers)
-
- headers["connection"] = "keep-alive"
- assert not HTTP1Protocol.connection_close((1, 1), headers)
-
- headers["connection"] = "close"
- assert HTTP1Protocol.connection_close((1, 1), headers)
-
-
-def test_read_http_body_request():
- headers = Headers()
- data = "testing"
- assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == ""
-
-
-def test_read_http_body_response():
- headers = Headers()
- data = "testing"
- assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing"
-
-
-def test_read_http_body():
- # test default case
- headers = Headers()
- headers["content-length"] = "7"
- data = "testing"
- assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing"
-
- # test content length: invalid header
- headers["content-length"] = "foo"
- data = "testing"
- tutils.raises(
- http.HttpError,
- mock_protocol(data).read_http_body,
- headers, None, "GET", 200, False
- )
-
- # test content length: invalid header #2
- headers["content-length"] = "-1"
- data = "testing"
- tutils.raises(
- http.HttpError,
- mock_protocol(data).read_http_body,
- headers, None, "GET", 200, False
- )
-
- # test content length: content length > actual content
- headers["content-length"] = "5"
- data = "testing"
- tutils.raises(
- http.HttpError,
- mock_protocol(data).read_http_body,
- headers, 4, "GET", 200, False
- )
-
- # test content length: content length < actual content
- data = "testing"
- assert len(mock_protocol(data).read_http_body(headers, None, "GET", 200, False)) == 5
-
- # test no content length: limit > actual content
- headers = Headers()
- data = "testing"
- assert len(mock_protocol(data).read_http_body(headers, 100, "GET", 200, False)) == 7
-
- # test no content length: limit < actual content
- data = "testing"
- tutils.raises(
- http.HttpError,
- mock_protocol(data).read_http_body,
- headers, 4, "GET", 200, False
- )
-
- # test chunked
- headers = Headers()
- headers["transfer-encoding"] = "chunked"
- data = "5\r\naaaaa\r\n0\r\n\r\n"
- assert mock_protocol(data).read_http_body(headers, 100, "GET", 200, False) == "aaaaa"
-
-
-def test_expected_http_body_size():
- # gibber in the content-length field
- headers = Headers(content_length="foo")
- with tutils.raises(HttpSyntaxException):
- expected_http_body_size(headers, False, "GET", 200) is None
- # negative number in the content-length field
- headers = Headers(content_length="-7")
- with tutils.raises(HttpSyntaxException):
- expected_http_body_size(headers, False, "GET", 200) is None
- # explicit length
- headers = Headers(content_length="5")
- assert expected_http_body_size(headers, False, "GET", 200) == 5
- # no length
- headers = Headers()
- assert expected_http_body_size(headers, False, "GET", 200) == -1
- # no length request
- headers = Headers()
- assert expected_http_body_size(headers, True, "GET", None) == 0
- # expect header
- headers = Headers(content_length="5", expect="100-continue")
- assert expected_http_body_size(headers, True, "GET", None) == 0
-
-
-def test_parse_init_connect():
- assert _parse_authority_form(b"CONNECT host.com:443 HTTP/1.0")
- tutils.raises(ValueError,_parse_authority_form, b"\0host.com:443")
- tutils.raises(ValueError,_parse_authority_form, b"host.com:444444")
- tutils.raises(ValueError,_parse_authority_form, b"CONNECT host.com443 HTTP/1.0")
- tutils.raises(ValueError,_parse_authority_form, b"CONNECT host.com:foo HTTP/1.0")
-
-
-def test_parse_init_proxy():
- u = b"GET http://foo.com:8888/test HTTP/1.1"
- m, s, h, po, pa, httpversion = HTTP1Protocol._parse_absolute_form(u)
- assert m == "GET"
- assert s == "http"
- assert h == "foo.com"
- assert po == 8888
- assert pa == "/test"
- assert httpversion == (1, 1)
-
- u = "G\xfeET http://foo.com:8888/test HTTP/1.1"
- assert not HTTP1Protocol._parse_absolute_form(u)
-
- with tutils.raises(ValueError):
- assert not HTTP1Protocol._parse_absolute_form("invalid")
- with tutils.raises(ValueError):
- assert not HTTP1Protocol._parse_absolute_form("GET invalid HTTP/1.1")
- with tutils.raises(ValueError):
- assert not HTTP1Protocol._parse_absolute_form("GET http://foo.com:8888/test foo/1.1")
-
-
-def test_parse_init_http():
- u = "GET /test HTTP/1.1"
- m, u, httpversion = HTTP1Protocol._parse_init_http(u)
- assert m == "GET"
- assert u == "/test"
- assert httpversion == (1, 1)
-
- u = "G\xfeET /test HTTP/1.1"
- assert not HTTP1Protocol._parse_init_http(u)
-
- assert not HTTP1Protocol._parse_init_http("invalid")
- assert not HTTP1Protocol._parse_init_http("GET invalid HTTP/1.1")
- assert not HTTP1Protocol._parse_init_http("GET /test foo/1.1")
- assert not HTTP1Protocol._parse_init_http("GET /test\xc0 HTTP/1.1")
-
-
-class TestReadHeaders:
-
- def _read(self, data, verbatim=False):
- if not verbatim:
- data = textwrap.dedent(data)
- data = data.strip()
- return mock_protocol(data).read_headers()
-
- def test_read_simple(self):
- data = """
- Header: one
- Header2: two
- \r\n
- """
- headers = self._read(data)
- assert headers.fields == [["Header", "one"], ["Header2", "two"]]
-
- def test_read_multi(self):
- data = """
- Header: one
- Header: two
- \r\n
- """
- headers = self._read(data)
- assert headers.fields == [["Header", "one"], ["Header", "two"]]
-
- def test_read_continued(self):
- data = """
- Header: one
- \ttwo
- Header2: three
- \r\n
- """
- headers = self._read(data)
- assert headers.fields == [["Header", "one\r\n two"], ["Header2", "three"]]
-
- def test_read_continued_err(self):
- data = "\tfoo: bar\r\n"
- assert self._read(data, True) is None
-
- def test_read_err(self):
- data = """
- foo
- """
- assert self._read(data) is None
-
-
-class TestReadRequest(object):
-
- def tst(self, data, **kwargs):
- return mock_protocol(data).read_request(**kwargs)
-
- def test_invalid(self):
- tutils.raises(
- "bad http request",
- self.tst,
- "xxx"
- )
- tutils.raises(
- "bad http request line",
- self.tst,
- "get /\xff HTTP/1.1"
- )
- tutils.raises(
- "invalid headers",
- self.tst,
- "get / HTTP/1.1\r\nfoo"
- )
- tutils.raises(
- HttpReadDisconnect,
- self.tst,
- "\r\n"
- )
-
- def test_asterisk_form_in(self):
- v = self.tst("OPTIONS * HTTP/1.1")
- assert v.form_in == "relative"
- assert v.method == "OPTIONS"
-
- def test_absolute_form_in(self):
- tutils.raises(
- "Bad HTTP request line",
- self.tst,
- "GET oops-no-protocol.com HTTP/1.1"
- )
- v = self.tst("GET http://address:22/ HTTP/1.1")
- assert v.form_in == "absolute"
- assert v.port == 22
- assert v.host == "address"
- assert v.scheme == "http"
-
- def test_connect(self):
- tutils.raises(
- "Bad HTTP request line",
- self.tst,
- "CONNECT oops-no-port.com HTTP/1.1"
- )
- v = self.tst("CONNECT foo.com:443 HTTP/1.1")
- assert v.form_in == "authority"
- assert v.method == "CONNECT"
- assert v.port == 443
- assert v.host == "foo.com"
-
- def test_expect(self):
- data = (
- b"GET / HTTP/1.1\r\n"
- b"Content-Length: 3\r\n"
- b"Expect: 100-continue\r\n"
- b"\r\n"
- b"foobar"
- )
-
- rfile = BytesIO(data)
- r = read_request(rfile)
- assert r.body == b""
- assert rfile.read(-1) == b"foobar"
-
-
-class TestReadResponse(object):
- def tst(self, data, method, body_size_limit, include_body=True):
- data = textwrap.dedent(data)
- return mock_protocol(data).read_response(
- method, body_size_limit, include_body=include_body
- )
-
- def test_errors(self):
- tutils.raises("server disconnect", self.tst, "", "GET", None)
- tutils.raises("invalid server response", self.tst, "foo", "GET", None)
-
- def test_simple(self):
- data = """
- HTTP/1.1 200
- """
- assert self.tst(data, "GET", None) == http.Response(
- (1, 1), 200, '', Headers(), ''
- )
-
- def test_simple_message(self):
- data = """
- HTTP/1.1 200 OK
- """
- assert self.tst(data, "GET", None) == http.Response(
- (1, 1), 200, 'OK', Headers(), ''
- )
-
- def test_invalid_http_version(self):
- data = """
- HTTP/x 200 OK
- """
- tutils.raises("invalid http version", self.tst, data, "GET", None)
-
- def test_invalid_status_code(self):
- data = """
- HTTP/1.1 xx OK
- """
- tutils.raises("invalid server response", self.tst, data, "GET", None)
-
- def test_valid_with_continue(self):
- data = """
- HTTP/1.1 100 CONTINUE
-
- HTTP/1.1 200 OK
- """
- assert self.tst(data, "GET", None) == http.Response(
- (1, 1), 100, 'CONTINUE', Headers(), ''
- )
-
- def test_simple_body(self):
- data = """
- HTTP/1.1 200 OK
- Content-Length: 3
-
- foo
- """
- assert self.tst(data, "GET", None).body == 'foo'
- assert self.tst(data, "HEAD", None).body == ''
-
- def test_invalid_headers(self):
- data = """
- HTTP/1.1 200 OK
- \tContent-Length: 3
-
- foo
- """
- tutils.raises("invalid headers", self.tst, data, "GET", None)
-
- def test_without_body(self):
- data = """
- HTTP/1.1 200 OK
- Content-Length: 3
-
- foo
- """
- assert self.tst(data, "GET", None, include_body=False).body is None
-
-
-class TestReadResponseNoContentLength(tservers.ServerTestBase):
- handler = NoContentLengthHTTPHandler
-
- def test_no_content_length(self):
- c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- resp = HTTP1Protocol(c).read_response("GET", None)
- assert resp.body == "bar\r\n\r\n"
-
-
-class TestAssembleRequest(object):
- def test_simple(self):
- req = tutils.treq()
- b = HTTP1Protocol().assemble_request(req)
- assert b == match_http_string("""
- GET /path HTTP/1.1
- header: qvalue
- Host: address:22
- Content-Length: 7
-
- content""")
-
- def test_body_missing(self):
- req = tutils.treq(content=semantics.CONTENT_MISSING)
- tutils.raises(http.HttpError, HTTP1Protocol().assemble_request, req)
-
- def test_not_a_request(self):
- tutils.raises(AssertionError, HTTP1Protocol().assemble_request, 'foo')
-
-
-class TestAssembleResponse(object):
- def test_simple(self):
- resp = tutils.tresp()
- b = HTTP1Protocol().assemble_response(resp)
- assert b == match_http_string("""
- HTTP/1.1 200 OK
- header_response: svalue
- Content-Length: 7
-
- message""")
-
- def test_body_missing(self):
- resp = tutils.tresp(content=semantics.CONTENT_MISSING)
- tutils.raises(http.HttpError, HTTP1Protocol().assemble_response, resp)
-
- def test_not_a_request(self):
- tutils.raises(AssertionError, HTTP1Protocol().assemble_response, 'foo')
diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py
new file mode 100644
index 00000000..5e6680af
--- /dev/null
+++ b/test/http/http1/test_read.py
@@ -0,0 +1,313 @@
+from __future__ import absolute_import, print_function, division
+from io import BytesIO
+import textwrap
+
+from mock import Mock
+
+from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect
+from netlib.http import Headers
+from netlib.http.http1.read import (
+ read_request, read_response, read_request_head,
+ read_response_head, read_body, connection_close, expected_http_body_size, _get_first_line,
+ _read_request_line, _parse_authority_form, _read_response_line, _check_http_version,
+ _read_headers, _read_chunked
+)
+from netlib.tutils import treq, tresp, raises
+
+
+def test_read_request():
+ rfile = BytesIO(b"GET / HTTP/1.1\r\n\r\nskip")
+ r = read_request(rfile)
+ assert r.method == b"GET"
+ assert r.body == b""
+ assert r.timestamp_end
+ assert rfile.read() == b"skip"
+
+
+def test_read_request_head():
+ rfile = BytesIO(
+ b"GET / HTTP/1.1\r\n"
+ b"Content-Length: 4\r\n"
+ b"\r\n"
+ b"skip"
+ )
+ rfile.reset_timestamps = Mock()
+ rfile.first_byte_timestamp = 42
+ r = read_request_head(rfile)
+ assert r.method == b"GET"
+ assert r.headers["Content-Length"] == b"4"
+ assert r.body is None
+ assert rfile.reset_timestamps.called
+ assert r.timestamp_start == 42
+ assert rfile.read() == b"skip"
+
+
+def test_read_response():
+ req = treq()
+ rfile = BytesIO(b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody")
+ r = read_response(rfile, req)
+ assert r.status_code == 418
+ assert r.body == b"body"
+ assert r.timestamp_end
+
+
+def test_read_response_head():
+ rfile = BytesIO(
+ b"HTTP/1.1 418 I'm a teapot\r\n"
+ b"Content-Length: 4\r\n"
+ b"\r\n"
+ b"skip"
+ )
+ rfile.reset_timestamps = Mock()
+ rfile.first_byte_timestamp = 42
+ r = read_response_head(rfile)
+ assert r.status_code == 418
+ assert r.headers["Content-Length"] == b"4"
+ assert r.body is None
+ assert rfile.reset_timestamps.called
+ assert r.timestamp_start == 42
+ assert rfile.read() == b"skip"
+
+
+class TestReadBody(object):
+ def test_chunked(self):
+ rfile = BytesIO(b"3\r\nfoo\r\n0\r\n\r\nbar")
+ body = b"".join(read_body(rfile, None))
+ assert body == b"foo"
+ assert rfile.read() == b"bar"
+
+
+ def test_known_size(self):
+ rfile = BytesIO(b"foobar")
+ body = b"".join(read_body(rfile, 3))
+ assert body == b"foo"
+ assert rfile.read() == b"bar"
+
+
+ def test_known_size_limit(self):
+ rfile = BytesIO(b"foobar")
+ with raises(HttpException):
+ b"".join(read_body(rfile, 3, 2))
+
+ def test_known_size_too_short(self):
+ rfile = BytesIO(b"foo")
+ with raises(HttpException):
+ b"".join(read_body(rfile, 6))
+
+ def test_unknown_size(self):
+ rfile = BytesIO(b"foobar")
+ body = b"".join(read_body(rfile, -1))
+ assert body == b"foobar"
+
+
+ def test_unknown_size_limit(self):
+ rfile = BytesIO(b"foobar")
+ with raises(HttpException):
+ b"".join(read_body(rfile, -1, 3))
+
+
+def test_connection_close():
+ headers = Headers()
+ assert connection_close((1, 0), headers)
+ assert not connection_close((1, 1), headers)
+
+ headers["connection"] = "keep-alive"
+ assert not connection_close((1, 1), headers)
+
+ headers["connection"] = "close"
+ assert connection_close((1, 1), headers)
+
+
+def test_expected_http_body_size():
+ # Expect: 100-continue
+ assert expected_http_body_size(
+ treq(headers=Headers(expect=b"100-continue", content_length=b"42"))
+ ) == 0
+
+ # http://tools.ietf.org/html/rfc7230#section-3.3
+ assert expected_http_body_size(
+ treq(method=b"HEAD"),
+ tresp(headers=Headers(content_length=b"42"))
+ ) == 0
+ assert expected_http_body_size(
+ treq(method=b"CONNECT"),
+ tresp()
+ ) == 0
+ for code in (100, 204, 304):
+ assert expected_http_body_size(
+ treq(),
+ tresp(status_code=code)
+ ) == 0
+
+ # chunked
+ assert expected_http_body_size(
+ treq(headers=Headers(transfer_encoding=b"chunked")),
+ ) is None
+
+ # explicit length
+ for l in (b"foo", b"-7"):
+ with raises(HttpSyntaxException):
+ expected_http_body_size(
+ treq(headers=Headers(content_length=l))
+ )
+ assert expected_http_body_size(
+ treq(headers=Headers(content_length=b"42"))
+ ) == 42
+
+ # no length
+ assert expected_http_body_size(
+ treq()
+ ) == 0
+ assert expected_http_body_size(
+ treq(), tresp()
+ ) == -1
+
+
+def test_get_first_line():
+ rfile = BytesIO(b"foo\r\nbar")
+ assert _get_first_line(rfile) == b"foo"
+
+ rfile = BytesIO(b"\r\nfoo\r\nbar")
+ assert _get_first_line(rfile) == b"foo"
+
+ with raises(HttpReadDisconnect):
+ rfile = BytesIO(b"")
+ _get_first_line(rfile)
+
+ with raises(HttpSyntaxException):
+ rfile = BytesIO(b"GET /\xff HTTP/1.1")
+ _get_first_line(rfile)
+
+
+def test_read_request_line():
+ def t(b):
+ return _read_request_line(BytesIO(b))
+
+ assert (t(b"GET / HTTP/1.1") ==
+ ("relative", b"GET", None, None, None, b"/", b"HTTP/1.1"))
+ assert (t(b"OPTIONS * HTTP/1.1") ==
+ ("relative", b"OPTIONS", None, None, None, b"*", b"HTTP/1.1"))
+ assert (t(b"CONNECT foo:42 HTTP/1.1") ==
+ ("authority", b"CONNECT", None, b"foo", 42, None, b"HTTP/1.1"))
+ assert (t(b"GET http://foo:42/bar HTTP/1.1") ==
+ ("absolute", b"GET", b"http", b"foo", 42, b"/bar", b"HTTP/1.1"))
+
+ with raises(HttpSyntaxException):
+ t(b"GET / WTF/1.1")
+ with raises(HttpSyntaxException):
+ t(b"this is not http")
+
+
+def test_parse_authority_form():
+ assert _parse_authority_form(b"foo:42") == (b"foo", 42)
+ with raises(HttpSyntaxException):
+ _parse_authority_form(b"foo")
+ with raises(HttpSyntaxException):
+ _parse_authority_form(b"foo:bar")
+ with raises(HttpSyntaxException):
+ _parse_authority_form(b"foo:99999999")
+ with raises(HttpSyntaxException):
+ _parse_authority_form(b"f\x00oo:80")
+
+
+def test_read_response_line():
+ def t(b):
+ return _read_response_line(BytesIO(b))
+
+ assert t(b"HTTP/1.1 200 OK") == (b"HTTP/1.1", 200, b"OK")
+ assert t(b"HTTP/1.1 200") == (b"HTTP/1.1", 200, b"")
+ with raises(HttpSyntaxException):
+ assert t(b"HTTP/1.1")
+
+ with raises(HttpSyntaxException):
+ t(b"HTTP/1.1 OK OK")
+ with raises(HttpSyntaxException):
+ t(b"WTF/1.1 200 OK")
+
+
+def test_check_http_version():
+ _check_http_version(b"HTTP/0.9")
+ _check_http_version(b"HTTP/1.0")
+ _check_http_version(b"HTTP/1.1")
+ _check_http_version(b"HTTP/2.0")
+ with raises(HttpSyntaxException):
+ _check_http_version(b"WTF/1.0")
+ with raises(HttpSyntaxException):
+ _check_http_version(b"HTTP/1.10")
+ with raises(HttpSyntaxException):
+ _check_http_version(b"HTTP/1.b")
+
+
+class TestReadHeaders(object):
+ @staticmethod
+ def _read(data):
+ return _read_headers(BytesIO(data))
+
+ def test_read_simple(self):
+ data = (
+ b"Header: one\r\n"
+ b"Header2: two\r\n"
+ b"\r\n"
+ )
+ headers = self._read(data)
+ assert headers.fields == [[b"Header", b"one"], [b"Header2", b"two"]]
+
+ def test_read_multi(self):
+ data = (
+ b"Header: one\r\n"
+ b"Header: two\r\n"
+ b"\r\n"
+ )
+ headers = self._read(data)
+ assert headers.fields == [[b"Header", b"one"], [b"Header", b"two"]]
+
+ def test_read_continued(self):
+ data = (
+ b"Header: one\r\n"
+ b"\ttwo\r\n"
+ b"Header2: three\r\n"
+ b"\r\n"
+ )
+ headers = self._read(data)
+ assert headers.fields == [[b"Header", b"one\r\n two"], [b"Header2", b"three"]]
+
+ def test_read_continued_err(self):
+ data = b"\tfoo: bar\r\n"
+ with raises(HttpSyntaxException):
+ self._read(data)
+
+ def test_read_err(self):
+ data = b"foo"
+ with raises(HttpSyntaxException):
+ self._read(data)
+
+
+def test_read_chunked():
+ req = treq(body=None)
+ req.headers["Transfer-Encoding"] = "chunked"
+
+ data = b"1\r\na\r\n0\r\n"
+ with raises(HttpSyntaxException):
+ b"".join(_read_chunked(BytesIO(data)))
+
+ data = b"1\r\na\r\n0\r\n\r\n"
+ assert b"".join(_read_chunked(BytesIO(data))) == b"a"
+
+ data = b"\r\n\r\n1\r\na\r\n1\r\nb\r\n0\r\n\r\n"
+ assert b"".join(_read_chunked(BytesIO(data))) == b"ab"
+
+ data = b"\r\n"
+ with raises("closed prematurely"):
+ b"".join(_read_chunked(BytesIO(data)))
+
+ data = b"1\r\nfoo"
+ with raises("malformed chunked body"):
+ b"".join(_read_chunked(BytesIO(data)))
+
+ data = b"foo\r\nfoo"
+ with raises(HttpSyntaxException):
+ b"".join(_read_chunked(BytesIO(data)))
+
+ data = b"5\r\naaaaa\r\n0\r\n\r\n"
+ with raises("too large"):
+ b"".join(_read_chunked(BytesIO(data), limit=2))
diff --git a/test/http/http2/test_frames.py b/test/http/http2/test_frames.py
index efdb55e2..4c89b023 100644
--- a/test/http/http2/test_frames.py
+++ b/test/http/http2/test_frames.py
@@ -39,7 +39,7 @@ def test_too_large_frames():
flags=Frame.FLAG_END_STREAM,
stream_id=0x1234567,
payload='foobar' * 3000)
- tutils.raises(FrameSizeError, f.to_bytes)
+ tutils.raises(HttpSyntaxException, f.to_bytes)
def test_data_frame_to_bytes():
diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py
index 2b7d7958..789b6e63 100644
--- a/test/http/http2/test_protocol.py
+++ b/test/http/http2/test_protocol.py
@@ -2,21 +2,21 @@ import OpenSSL
import mock
from netlib import tcp, http, tutils
-from netlib.http import http2, Headers
-from netlib.http.http2 import HTTP2Protocol
+from netlib.http import Headers
+from netlib.http.http2.connections import HTTP2Protocol, TCPHandler
from netlib.http.http2.frame import *
from ... import tservers
class TestTCPHandlerWrapper:
def test_wrapped(self):
- h = http2.TCPHandler(rfile='foo', wfile='bar')
+ h = TCPHandler(rfile='foo', wfile='bar')
p = HTTP2Protocol(h)
assert p.tcp_handler.rfile == 'foo'
assert p.tcp_handler.wfile == 'bar'
def test_direct(self):
p = HTTP2Protocol(rfile='foo', wfile='bar')
- assert isinstance(p.tcp_handler, http2.TCPHandler)
+ assert isinstance(p.tcp_handler, TCPHandler)
assert p.tcp_handler.rfile == 'foo'
assert p.tcp_handler.wfile == 'bar'
@@ -32,8 +32,8 @@ class EchoHandler(tcp.BaseHandler):
class TestProtocol:
- @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface")
- @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface")
+ @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface")
+ @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface")
def test_perform_connection_preface(self, mock_client_method, mock_server_method):
protocol = HTTP2Protocol(is_server=False)
protocol.connection_preface_performed = True
@@ -46,8 +46,8 @@ class TestProtocol:
assert mock_client_method.called
assert not mock_server_method.called
- @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface")
- @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface")
+ @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface")
+ @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface")
def test_perform_connection_preface_server(self, mock_client_method, mock_server_method):
protocol = HTTP2Protocol(is_server=True)
protocol.connection_preface_performed = True
diff --git a/test/http/test_exceptions.py b/test/http/test_exceptions.py
deleted file mode 100644
index 49588d0a..00000000
--- a/test/http/test_exceptions.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from netlib.http.exceptions import *
-
-class TestHttpError:
- def test_simple(self):
- e = HttpError(404, "Not found")
- assert str(e)
diff --git a/test/http/test_semantics.py b/test/http/test_models.py
index 44d3c85e..0f4dcc3b 100644
--- a/test/http/test_semantics.py
+++ b/test/http/test_models.py
@@ -1,32 +1,11 @@
import mock
-from netlib import http
-from netlib import odict
from netlib import tutils
from netlib import utils
-from netlib.http import semantics
-from netlib.http.semantics import CONTENT_MISSING
-
-class TestProtocolMixin(object):
- @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response")
- @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request")
- def test_assemble_request(self, mock_request_method, mock_response_method):
- p = semantics.ProtocolMixin()
- p.assemble(tutils.treq())
- assert mock_request_method.called
- assert not mock_response_method.called
-
- @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response")
- @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request")
- def test_assemble_response(self, mock_request_method, mock_response_method):
- p = semantics.ProtocolMixin()
- p.assemble(tutils.tresp())
- assert not mock_request_method.called
- assert mock_response_method.called
-
- def test_assemble_foo(self):
- p = semantics.ProtocolMixin()
- tutils.raises(ValueError, p.assemble, 'foo')
+from netlib.odict import ODict, ODictCaseless
+from netlib.http import Request, Response, Headers, CONTENT_MISSING, HDR_FORM_URLENCODED, \
+ HDR_FORM_MULTIPART
+
class TestRequest(object):
def test_repr(self):
@@ -34,7 +13,7 @@ class TestRequest(object):
assert repr(r)
def test_headers(self):
- tutils.raises(AssertionError, semantics.Request,
+ tutils.raises(AssertionError, Request,
'form_in',
'method',
'scheme',
@@ -45,7 +24,7 @@ class TestRequest(object):
'foobar',
)
- req = semantics.Request(
+ req = Request(
'form_in',
'method',
'scheme',
@@ -54,7 +33,7 @@ class TestRequest(object):
'path',
(1, 1),
)
- assert isinstance(req.headers, http.Headers)
+ assert isinstance(req.headers, Headers)
def test_equal(self):
a = tutils.treq()
@@ -66,13 +45,6 @@ class TestRequest(object):
assert not 'foo' == a
assert not 'foo' == b
- def test_legacy_first_line(self):
- req = tutils.treq()
-
- assert req.legacy_first_line('relative') == "GET /path HTTP/1.1"
- assert req.legacy_first_line('authority') == "GET address:22 HTTP/1.1"
- assert req.legacy_first_line('absolute') == "GET http://address:22/path HTTP/1.1"
- tutils.raises(http.HttpError, req.legacy_first_line, 'foobar')
def test_anticache(self):
req = tutils.treq()
@@ -103,44 +75,44 @@ class TestRequest(object):
def test_get_form(self):
req = tutils.treq()
- assert req.get_form() == odict.ODict()
+ assert req.get_form() == ODict()
- @mock.patch("netlib.http.semantics.Request.get_form_multipart")
- @mock.patch("netlib.http.semantics.Request.get_form_urlencoded")
+ @mock.patch("netlib.http.Request.get_form_multipart")
+ @mock.patch("netlib.http.Request.get_form_urlencoded")
def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart):
req = tutils.treq()
- assert req.get_form() == odict.ODict()
+ assert req.get_form() == ODict()
req = tutils.treq()
req.body = "foobar"
- req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED
+ req.headers["Content-Type"] = HDR_FORM_URLENCODED
req.get_form()
assert req.get_form_urlencoded.called
assert not req.get_form_multipart.called
- @mock.patch("netlib.http.semantics.Request.get_form_multipart")
- @mock.patch("netlib.http.semantics.Request.get_form_urlencoded")
+ @mock.patch("netlib.http.Request.get_form_multipart")
+ @mock.patch("netlib.http.Request.get_form_urlencoded")
def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart):
req = tutils.treq()
req.body = "foobar"
- req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART
+ req.headers["Content-Type"] = HDR_FORM_MULTIPART
req.get_form()
assert not req.get_form_urlencoded.called
assert req.get_form_multipart.called
def test_get_form_urlencoded(self):
- req = tutils.treq("foobar")
- assert req.get_form_urlencoded() == odict.ODict()
+ req = tutils.treq(body="foobar")
+ assert req.get_form_urlencoded() == ODict()
- req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED
- assert req.get_form_urlencoded() == odict.ODict(utils.urldecode(req.body))
+ req.headers["Content-Type"] = HDR_FORM_URLENCODED
+ assert req.get_form_urlencoded() == ODict(utils.urldecode(req.body))
def test_get_form_multipart(self):
- req = tutils.treq("foobar")
- assert req.get_form_multipart() == odict.ODict()
+ req = tutils.treq(body="foobar")
+ assert req.get_form_multipart() == ODict()
- req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART
- assert req.get_form_multipart() == odict.ODict(
+ req.headers["Content-Type"] = HDR_FORM_MULTIPART
+ assert req.get_form_multipart() == ODict(
utils.multipartdecode(
req.headers,
req.body
@@ -149,8 +121,8 @@ class TestRequest(object):
def test_set_form_urlencoded(self):
req = tutils.treq()
- req.set_form_urlencoded(odict.ODict([('foo', 'bar'), ('rab', 'oof')]))
- assert req.headers["Content-Type"] == semantics.HDR_FORM_URLENCODED
+ req.set_form_urlencoded(ODict([('foo', 'bar'), ('rab', 'oof')]))
+ assert req.headers["Content-Type"] == HDR_FORM_URLENCODED
assert req.body
def test_get_path_components(self):
@@ -172,7 +144,7 @@ class TestRequest(object):
def test_set_query(self):
req = tutils.treq()
- req.set_query(odict.ODict([]))
+ req.set_query(ODict([]))
def test_pretty_host(self):
r = tutils.treq()
@@ -203,21 +175,21 @@ class TestRequest(object):
assert req.pretty_url(False) == "http://address:22/path"
def test_get_cookies_none(self):
- headers = http.Headers()
+ headers = Headers()
r = tutils.treq()
r.headers = headers
assert len(r.get_cookies()) == 0
def test_get_cookies_single(self):
r = tutils.treq()
- r.headers = http.Headers(cookie="cookiename=cookievalue")
+ r.headers = Headers(cookie="cookiename=cookievalue")
result = r.get_cookies()
assert len(result) == 1
assert result['cookiename'] == ['cookievalue']
def test_get_cookies_double(self):
r = tutils.treq()
- r.headers = http.Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue")
+ r.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue")
result = r.get_cookies()
assert len(result) == 2
assert result['cookiename'] == ['cookievalue']
@@ -225,7 +197,7 @@ class TestRequest(object):
def test_get_cookies_withequalsign(self):
r = tutils.treq()
- r.headers = http.Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue")
+ r.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue")
result = r.get_cookies()
assert len(result) == 2
assert result['cookiename'] == ['coo=kievalue']
@@ -233,14 +205,14 @@ class TestRequest(object):
def test_set_cookies(self):
r = tutils.treq()
- r.headers = http.Headers(cookie="cookiename=cookievalue")
+ r.headers = Headers(cookie="cookiename=cookievalue")
result = r.get_cookies()
result["cookiename"] = ["foo"]
r.set_cookies(result)
assert r.get_cookies()["cookiename"] == ["foo"]
def test_set_url(self):
- r = tutils.treq_absolute()
+ r = tutils.treq(form_in="absolute")
r.url = "https://otheraddress:42/ORLY"
assert r.scheme == "https"
assert r.host == "otheraddress"
@@ -332,24 +304,19 @@ class TestRequest(object):
# "Host: address\r\n"
# "Content-Length: 0\r\n\r\n")
-class TestEmptyRequest(object):
- def test_init(self):
- req = semantics.EmptyRequest()
- assert req
-
class TestResponse(object):
def test_headers(self):
- tutils.raises(AssertionError, semantics.Response,
+ tutils.raises(AssertionError, Response,
(1, 1),
200,
headers='foobar',
)
- resp = semantics.Response(
+ resp = Response(
(1, 1),
200,
)
- assert isinstance(resp.headers, http.Headers)
+ assert isinstance(resp.headers, Headers)
def test_equal(self):
a = tutils.tresp()
@@ -366,24 +333,24 @@ class TestResponse(object):
assert "unknown content type" in repr(r)
r.headers["content-type"] = "foo"
assert "foo" in repr(r)
- assert repr(tutils.tresp(content=CONTENT_MISSING))
+ assert repr(tutils.tresp(body=CONTENT_MISSING))
def test_get_cookies_none(self):
resp = tutils.tresp()
- resp.headers = http.Headers()
+ resp.headers = Headers()
assert not resp.get_cookies()
def test_get_cookies_simple(self):
resp = tutils.tresp()
- resp.headers = http.Headers(set_cookie="cookiename=cookievalue")
+ resp.headers = Headers(set_cookie="cookiename=cookievalue")
result = resp.get_cookies()
assert len(result) == 1
assert "cookiename" in result
- assert result["cookiename"][0] == ["cookievalue", odict.ODict()]
+ assert result["cookiename"][0] == ["cookievalue", ODict()]
def test_get_cookies_with_parameters(self):
resp = tutils.tresp()
- resp.headers = http.Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly")
+ resp.headers = Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly")
result = resp.get_cookies()
assert len(result) == 1
assert "cookiename" in result
@@ -397,7 +364,7 @@ class TestResponse(object):
def test_get_cookies_no_value(self):
resp = tutils.tresp()
- resp.headers = http.Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/")
+ resp.headers = Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/")
result = resp.get_cookies()
assert len(result) == 1
assert "cookiename" in result
@@ -406,31 +373,31 @@ class TestResponse(object):
def test_get_cookies_twocookies(self):
resp = tutils.tresp()
- resp.headers = http.Headers([
+ resp.headers = Headers([
["Set-Cookie", "cookiename=cookievalue"],
["Set-Cookie", "othercookie=othervalue"]
])
result = resp.get_cookies()
assert len(result) == 2
assert "cookiename" in result
- assert result["cookiename"][0] == ["cookievalue", odict.ODict()]
+ assert result["cookiename"][0] == ["cookievalue", ODict()]
assert "othercookie" in result
- assert result["othercookie"][0] == ["othervalue", odict.ODict()]
+ assert result["othercookie"][0] == ["othervalue", ODict()]
def test_set_cookies(self):
resp = tutils.tresp()
v = resp.get_cookies()
- v.add("foo", ["bar", odict.ODictCaseless()])
+ v.add("foo", ["bar", ODictCaseless()])
resp.set_cookies(v)
v = resp.get_cookies()
assert len(v) == 1
- assert v["foo"] == [["bar", odict.ODictCaseless()]]
+ assert v["foo"] == [["bar", ODictCaseless()]]
class TestHeaders(object):
def _2host(self):
- return semantics.Headers(
+ return Headers(
[
["Host", "example.com"],
["host", "example.org"]
@@ -438,25 +405,25 @@ class TestHeaders(object):
)
def test_init(self):
- headers = semantics.Headers()
+ headers = Headers()
assert len(headers) == 0
- headers = semantics.Headers([["Host", "example.com"]])
+ headers = Headers([["Host", "example.com"]])
assert len(headers) == 1
assert headers["Host"] == "example.com"
- headers = semantics.Headers(Host="example.com")
+ headers = Headers(Host="example.com")
assert len(headers) == 1
assert headers["Host"] == "example.com"
- headers = semantics.Headers(
+ headers = Headers(
[["Host", "invalid"]],
Host="example.com"
)
assert len(headers) == 1
assert headers["Host"] == "example.com"
- headers = semantics.Headers(
+ headers = Headers(
[["Host", "invalid"], ["Accept", "text/plain"]],
Host="example.com"
)
@@ -465,7 +432,7 @@ class TestHeaders(object):
assert headers["Accept"] == "text/plain"
def test_getitem(self):
- headers = semantics.Headers(Host="example.com")
+ headers = Headers(Host="example.com")
assert headers["Host"] == "example.com"
assert headers["host"] == "example.com"
tutils.raises(KeyError, headers.__getitem__, "Accept")
@@ -474,17 +441,17 @@ class TestHeaders(object):
assert headers["Host"] == "example.com, example.org"
def test_str(self):
- headers = semantics.Headers(Host="example.com")
+ headers = Headers(Host="example.com")
assert bytes(headers) == "Host: example.com\r\n"
- headers = semantics.Headers([
+ headers = Headers([
["Host", "example.com"],
["Accept", "text/plain"]
])
assert str(headers) == "Host: example.com\r\nAccept: text/plain\r\n"
def test_setitem(self):
- headers = semantics.Headers()
+ headers = Headers()
headers["Host"] = "example.com"
assert "Host" in headers
assert "host" in headers
@@ -507,7 +474,7 @@ class TestHeaders(object):
assert "Host" in headers
def test_delitem(self):
- headers = semantics.Headers(Host="example.com")
+ headers = Headers(Host="example.com")
assert len(headers) == 1
del headers["host"]
assert len(headers) == 0
@@ -523,7 +490,7 @@ class TestHeaders(object):
assert len(headers) == 0
def test_keys(self):
- headers = semantics.Headers(Host="example.com")
+ headers = Headers(Host="example.com")
assert len(headers.keys()) == 1
assert headers.keys()[0] == "Host"
@@ -532,13 +499,13 @@ class TestHeaders(object):
assert headers.keys()[0] == "Host"
def test_eq_ne(self):
- headers1 = semantics.Headers(Host="example.com")
- headers2 = semantics.Headers(host="example.com")
+ headers1 = Headers(Host="example.com")
+ headers2 = Headers(host="example.com")
assert not (headers1 == headers2)
assert headers1 != headers2
- headers1 = semantics.Headers(Host="example.com")
- headers2 = semantics.Headers(Host="example.com")
+ headers1 = Headers(Host="example.com")
+ headers2 = Headers(Host="example.com")
assert headers1 == headers2
assert not (headers1 != headers2)
@@ -550,7 +517,7 @@ class TestHeaders(object):
assert headers.get_all("accept") == []
def test_set_all(self):
- headers = semantics.Headers(Host="example.com")
+ headers = Headers(Host="example.com")
headers.set_all("Accept", ["text/plain"])
assert len(headers) == 2
assert "accept" in headers
@@ -565,9 +532,9 @@ class TestHeaders(object):
def test_state(self):
headers = self._2host()
assert len(headers.get_state()) == 2
- assert headers == semantics.Headers.from_state(headers.get_state())
+ assert headers == Headers.from_state(headers.get_state())
- headers2 = semantics.Headers()
+ headers2 = Headers()
assert headers != headers2
headers2.load_state(headers.get_state())
assert headers == headers2
diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py
index 57cfd166..3fdeb683 100644
--- a/test/websockets/test_websockets.py
+++ b/test/websockets/test_websockets.py
@@ -1,11 +1,13 @@
import os
from nose.tools import raises
+from netlib.http.http1 import read_response, read_request
from netlib import tcp, tutils, websockets, http
from netlib.http import status_codes
-from netlib.http.exceptions import *
-from netlib.http.http1 import HTTP1Protocol
+from netlib.tutils import treq
+
+from netlib.exceptions import *
from .. import tservers
@@ -34,9 +36,8 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
frame.to_file(self.wfile)
def handshake(self):
- http1_protocol = HTTP1Protocol(self)
- req = http1_protocol.read_request()
+ req = read_request(self.rfile)
key = self.protocol.check_client_handshake(req.headers)
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
@@ -61,8 +62,6 @@ class WebSocketsClient(tcp.TCPClient):
def connect(self):
super(WebSocketsClient, self).connect()
- http1_protocol = HTTP1Protocol(self)
-
preamble = 'GET / HTTP/1.1'
self.wfile.write(preamble + "\r\n")
headers = self.protocol.client_handshake_headers()
@@ -70,7 +69,7 @@ class WebSocketsClient(tcp.TCPClient):
self.wfile.write(str(headers) + "\r\n")
self.wfile.flush()
- resp = http1_protocol.read_response("GET", None)
+ resp = read_response(self.rfile, treq(method="GET"))
server_nonce = self.protocol.check_server_handshake(resp.headers)
if not server_nonce == self.protocol.create_server_nonce(
@@ -158,9 +157,8 @@ class TestWebSockets(tservers.ServerTestBase):
class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self):
- http1_protocol = HTTP1Protocol(self)
- client_hs = http1_protocol.read_request()
+ client_hs = read_request(self.rfile)
self.protocol.check_client_handshake(client_hs.headers)
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)