aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/http/http1/protocol.py14
-rw-r--r--netlib/http/semantics.py43
-rw-r--r--netlib/odict.py3
-rw-r--r--netlib/tutils.py4
-rw-r--r--netlib/utils.py56
-rw-r--r--test/http/http1/test_protocol.py265
-rw-r--r--test/http/http2/test_protocol.py247
-rw-r--r--test/http/test_exceptions.py29
-rw-r--r--test/http/test_semantics.py389
-rw-r--r--test/test_utils.py31
-rw-r--r--test/websockets/test_websockets.py7
11 files changed, 778 insertions, 310 deletions
diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py
index 2e85a762..31e9cc85 100644
--- a/netlib/http/http1/protocol.py
+++ b/netlib/http/http1/protocol.py
@@ -360,20 +360,6 @@ class HTTP1Protocol(semantics.ProtocolMixin):
@classmethod
- def request_preamble(self, method, resource, http_major="1", http_minor="1"):
- return '%s %s HTTP/%s.%s' % (
- method, resource, http_major, http_minor
- )
-
-
- @classmethod
- def response_preamble(self, code, message=None, http_major="1", http_minor="1"):
- if message is None:
- message = status_codes.RESPONSES.get(code)
- return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message)
-
-
- @classmethod
def has_chunked_encoding(self, headers):
return "chunked" in [
i.lower() for i in utils.get_header_tokens(headers, "transfer-encoding")
diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py
index e7ae2b5f..974fe6e6 100644
--- a/netlib/http/semantics.py
+++ b/netlib/http/semantics.py
@@ -7,7 +7,7 @@ import urllib
import urlparse
from .. import utils, odict
-from . import cookies
+from . import cookies, exceptions
from netlib import utils, encoding
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
@@ -18,10 +18,10 @@ CONTENT_MISSING = 0
class ProtocolMixin(object):
- def read_request(self):
+ def read_request(self, *args, **kwargs): # pragma: no cover
raise NotImplemented
- def read_response(self):
+ def read_response(self, *args, **kwargs): # pragma: no cover
raise NotImplemented
def assemble(self, message):
@@ -32,14 +32,23 @@ class ProtocolMixin(object):
else:
raise ValueError("HTTP message not supported.")
- def assemble_request(self, request):
+ def assemble_request(self, request): # pragma: no cover
raise NotImplemented
- def assemble_response(self, response):
+ def assemble_response(self, response): # pragma: no cover
raise NotImplemented
class Request(object):
+ # This list is adopted legacy code.
+ # We probably don't need to strip off keep-alive.
+ _headers_to_strip_off = [
+ 'Proxy-Connection',
+ 'Keep-Alive',
+ 'Connection',
+ 'Transfer-Encoding',
+ 'Upgrade',
+ ]
def __init__(
self,
@@ -71,7 +80,6 @@ class Request(object):
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
-
def __eq__(self, other):
try:
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
@@ -114,7 +122,7 @@ class Request(object):
self.httpversion[1],
)
else:
- raise http.HttpError(400, "Invalid request form")
+ raise exceptions.HttpError(400, "Invalid request form")
def anticache(self):
"""
@@ -143,7 +151,7 @@ class Request(object):
if self.headers["accept-encoding"]:
self.headers["accept-encoding"] = [
', '.join(
- e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0])]
+ e for e in encoding.ENCODINGS if e in self.headers.get_first("accept-encoding"))]
def update_host_header(self):
"""
@@ -317,12 +325,12 @@ class Request(object):
self.scheme, self.host, self.port, self.path = parts
@property
- def content(self):
+ def content(self): # pragma: no cover
# TODO: remove deprecated getter
return self.body
@content.setter
- def content(self, content):
+ def content(self, content): # pragma: no cover
# TODO: remove deprecated setter
self.body = content
@@ -343,6 +351,11 @@ class EmptyRequest(Request):
class Response(object):
+ _headers_to_strip_off = [
+ 'Proxy-Connection',
+ 'Alternate-Protocol',
+ 'Alt-Svc',
+ ]
def __init__(
self,
@@ -368,7 +381,6 @@ class Response(object):
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
-
def __eq__(self, other):
try:
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
@@ -393,7 +405,6 @@ class Response(object):
size=size
)
-
def get_cookies(self):
"""
Get the contents of all Set-Cookie headers.
@@ -430,21 +441,21 @@ class Response(object):
self.headers["Set-Cookie"] = values
@property
- def content(self):
+ def content(self): # pragma: no cover
# TODO: remove deprecated getter
return self.body
@content.setter
- def content(self, content):
+ def content(self, content): # pragma: no cover
# TODO: remove deprecated setter
self.body = content
@property
- def code(self):
+ def code(self): # pragma: no cover
# TODO: remove deprecated getter
return self.status_code
@code.setter
- def code(self, code):
+ def code(self, code): # pragma: no cover
# TODO: remove deprecated setter
self.status_code = code
diff --git a/netlib/odict.py b/netlib/odict.py
index d02de08d..11d5d52a 100644
--- a/netlib/odict.py
+++ b/netlib/odict.py
@@ -91,8 +91,9 @@ class ODict(object):
self.lst = self._filter_lst(k, self.lst)
def __contains__(self, k):
+ k = self._kconv(k)
for i in self.lst:
- if self._kconv(i[0]) == self._kconv(k):
+ if self._kconv(i[0]) == k:
return True
return False
diff --git a/netlib/tutils.py b/netlib/tutils.py
index 5018b9e8..3c471d0d 100644
--- a/netlib/tutils.py
+++ b/netlib/tutils.py
@@ -119,7 +119,7 @@ def tresp(content="message"):
"OK",
headers,
content,
- time.time(),
- time.time(),
+ timestamp_start=time.time(),
+ timestamp_end=time.time(),
)
return resp
diff --git a/netlib/utils.py b/netlib/utils.py
index 35ea0ec7..2dfcafc6 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -4,6 +4,7 @@ import cgi
import urllib
import urlparse
import string
+import re
def isascii(s):
@@ -239,3 +240,58 @@ def urldecode(s):
Takes a urlencoded string and returns a list of (key, value) tuples.
"""
return cgi.parse_qsl(s, keep_blank_values=True)
+
+
+def parse_content_type(c):
+ """
+ A simple parser for content-type values. Returns a (type, subtype,
+ parameters) tuple, where type and subtype are strings, and parameters
+ is a dict. If the string could not be parsed, return None.
+
+ E.g. the following string:
+
+ text/html; charset=UTF-8
+
+ Returns:
+
+ ("text", "html", {"charset": "UTF-8"})
+ """
+ parts = c.split(";", 1)
+ ts = parts[0].split("/", 1)
+ if len(ts) != 2:
+ return None
+ d = {}
+ if len(parts) == 2:
+ for i in parts[1].split(";"):
+ clause = i.split("=", 1)
+ if len(clause) == 2:
+ d[clause[0].strip()] = clause[1].strip()
+ return ts[0].lower(), ts[1].lower(), d
+
+
+def multipartdecode(hdrs, content):
+ """
+ Takes a multipart boundary encoded string and returns list of (key, value) tuples.
+ """
+ v = hdrs.get_first("content-type")
+ if v:
+ v = parse_content_type(v)
+ if not v:
+ return []
+ boundary = v[2].get("boundary")
+ if not boundary:
+ return []
+
+ rx = re.compile(r'\bname="([^"]+)"')
+ r = []
+
+ for i in content.split("--" + boundary):
+ parts = i.splitlines()
+ if len(parts) > 1 and parts[0][0:2] != "--":
+ match = rx.search(parts[1])
+ if match:
+ key = match.group(1)
+ value = "".join(parts[3 + parts[2:].index(""):])
+ r.append((key, value))
+ return r
+ return []
diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py
index e3c3ff43..ff70b87d 100644
--- a/test/http/http1/test_protocol.py
+++ b/test/http/http1/test_protocol.py
@@ -3,16 +3,40 @@ import textwrap
import binascii
from netlib import http, odict, tcp, tutils
+from netlib.http import semantics
from netlib.http.http1 import HTTP1Protocol
from ... import tservers
-def mock_protocol(data='', chunked=False):
+class NoContentLengthHTTPHandler(tcp.BaseHandler):
+ def handle(self):
+ self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n")
+ self.wfile.flush()
+
+
+def mock_protocol(data=''):
rfile = cStringIO.StringIO(data)
wfile = cStringIO.StringIO()
return HTTP1Protocol(rfile=rfile, wfile=wfile)
+def match_http_string(data):
+ return textwrap.dedent(data).strip().replace('\n', '\r\n')
+
+
+def test_stripped_chunked_encoding_no_content():
+ """
+ https://github.com/mitmproxy/mitmproxy/issues/186
+ """
+
+ r = tutils.treq(content="")
+ r.headers["Transfer-Encoding"] = ["chunked"]
+ assert "Content-Length" in mock_protocol()._assemble_request_headers(r)
+
+ r = tutils.tresp(content="")
+ r.headers["Transfer-Encoding"] = ["chunked"]
+ assert "Content-Length" in mock_protocol()._assemble_response_headers(r)
+
def test_has_chunked_encoding():
h = odict.ODictCaseless()
@@ -75,7 +99,6 @@ def test_connection_close():
assert HTTP1Protocol.connection_close((1, 1), h)
-
def test_read_http_body_request():
h = odict.ODictCaseless()
data = "testing"
@@ -85,7 +108,7 @@ def test_read_http_body_request():
def test_read_http_body_response():
h = odict.ODictCaseless()
data = "testing"
- assert mock_protocol(data, chunked=True).read_http_body(h, None, "GET", 200, False) == "testing"
+ assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing"
def test_read_http_body():
@@ -129,13 +152,13 @@ def test_read_http_body():
# test no content length: limit > actual content
h = odict.ODictCaseless()
data = "testing"
- assert len(mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False)) == 7
+ assert len(mock_protocol(data).read_http_body(h, 100, "GET", 200, False)) == 7
# test no content length: limit < actual content
data = "testing"
tutils.raises(
http.HttpError,
- mock_protocol(data, chunked=True).read_http_body,
+ mock_protocol(data).read_http_body,
h, 4, "GET", 200, False
)
@@ -143,7 +166,7 @@ def test_read_http_body():
h = odict.ODictCaseless()
h["transfer-encoding"] = ["chunked"]
data = "5\r\naaaaa\r\n0\r\n\r\n"
- assert mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False) == "aaaaa"
+ assert mock_protocol(data).read_http_body(h, 100, "GET", 200, False) == "aaaaa"
def test_expected_http_body_size():
@@ -167,6 +190,13 @@ def test_expected_http_body_size():
assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0
+def test_get_request_line():
+ data = "\nfoo"
+ p = mock_protocol(data)
+ assert p._get_request_line() == "foo"
+ assert not p._get_request_line()
+
+
def test_parse_http_protocol():
assert HTTP1Protocol._parse_http_protocol("HTTP/1.1") == (1, 1)
assert HTTP1Protocol._parse_http_protocol("HTTP/0.0") == (0, 0)
@@ -269,96 +299,7 @@ class TestReadHeaders:
assert self._read(data) is None
-class NoContentLengthHTTPHandler(tcp.BaseHandler):
-
- def handle(self):
- self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n")
- self.wfile.flush()
-
-
-class TestReadResponseNoContentLength(tservers.ServerTestBase):
- handler = NoContentLengthHTTPHandler
-
- def test_no_content_length(self):
- c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- resp = HTTP1Protocol(c).read_response("GET", None)
- assert resp.body == "bar\r\n\r\n"
-
-
-def test_read_response():
- def tst(data, method, body_size_limit, include_body=True):
- data = textwrap.dedent(data)
- return mock_protocol(data).read_response(
- method, body_size_limit, include_body=include_body
- )
-
- tutils.raises("server disconnect", tst, "", "GET", None)
- tutils.raises("invalid server response", tst, "foo", "GET", None)
- data = """
- HTTP/1.1 200 OK
- """
- assert tst(data, "GET", None) == http.Response(
- (1, 1), 200, 'OK', odict.ODictCaseless(), ''
- )
- data = """
- HTTP/1.1 200
- """
- assert tst(data, "GET", None) == http.Response(
- (1, 1), 200, '', odict.ODictCaseless(), ''
- )
- data = """
- HTTP/x 200 OK
- """
- tutils.raises("invalid http version", tst, data, "GET", None)
- data = """
- HTTP/1.1 xx OK
- """
- tutils.raises("invalid server response", tst, data, "GET", None)
-
- data = """
- HTTP/1.1 100 CONTINUE
-
- HTTP/1.1 200 OK
- """
- assert tst(data, "GET", None) == http.Response(
- (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), ''
- )
-
- data = """
- HTTP/1.1 200 OK
- Content-Length: 3
-
- foo
- """
- assert tst(data, "GET", None).body == 'foo'
- assert tst(data, "HEAD", None).body == ''
-
- data = """
- HTTP/1.1 200 OK
- \tContent-Length: 3
-
- foo
- """
- tutils.raises("invalid headers", tst, data, "GET", None)
-
- data = """
- HTTP/1.1 200 OK
- Content-Length: 3
-
- foo
- """
- assert tst(data, "GET", None, include_body=False).body is None
-
-
-def test_get_request_line():
- data = "\nfoo"
- p = mock_protocol(data)
- assert p._get_request_line() == "foo"
- assert not p._get_request_line()
-
-
-class TestReadRequest():
+class TestReadRequest(object):
def tst(self, data, **kwargs):
return mock_protocol(data).read_request(**kwargs)
@@ -385,6 +326,10 @@ class TestReadRequest():
"\r\n"
)
+ def test_empty(self):
+ v = self.tst("", allow_empty=True)
+ assert isinstance(v, semantics.EmptyRequest)
+
def test_asterisk_form_in(self):
v = self.tst("OPTIONS * HTTP/1.1")
assert v.form_in == "relative"
@@ -427,3 +372,131 @@ class TestReadRequest():
assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n"
assert v.body == "foo"
assert p.tcp_handler.rfile.read(3) == "bar"
+
+
+class TestReadResponse(object):
+ def tst(self, data, method, body_size_limit, include_body=True):
+ data = textwrap.dedent(data)
+ return mock_protocol(data).read_response(
+ method, body_size_limit, include_body=include_body
+ )
+
+ def test_errors(self):
+ tutils.raises("server disconnect", self.tst, "", "GET", None)
+ tutils.raises("invalid server response", self.tst, "foo", "GET", None)
+
+ def test_simple(self):
+ data = """
+ HTTP/1.1 200
+ """
+ assert self.tst(data, "GET", None) == http.Response(
+ (1, 1), 200, '', odict.ODictCaseless(), ''
+ )
+
+ def test_simple_message(self):
+ data = """
+ HTTP/1.1 200 OK
+ """
+ assert self.tst(data, "GET", None) == http.Response(
+ (1, 1), 200, 'OK', odict.ODictCaseless(), ''
+ )
+
+ def test_invalid_http_version(self):
+ data = """
+ HTTP/x 200 OK
+ """
+ tutils.raises("invalid http version", self.tst, data, "GET", None)
+
+ def test_invalid_status_code(self):
+ data = """
+ HTTP/1.1 xx OK
+ """
+ tutils.raises("invalid server response", self.tst, data, "GET", None)
+
+ def test_valid_with_continue(self):
+ data = """
+ HTTP/1.1 100 CONTINUE
+
+ HTTP/1.1 200 OK
+ """
+ assert self.tst(data, "GET", None) == http.Response(
+ (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), ''
+ )
+
+ def test_simple_body(self):
+ data = """
+ HTTP/1.1 200 OK
+ Content-Length: 3
+
+ foo
+ """
+ assert self.tst(data, "GET", None).body == 'foo'
+ assert self.tst(data, "HEAD", None).body == ''
+
+ def test_invalid_headers(self):
+ data = """
+ HTTP/1.1 200 OK
+ \tContent-Length: 3
+
+ foo
+ """
+ tutils.raises("invalid headers", self.tst, data, "GET", None)
+
+ def test_without_body(self):
+ data = """
+ HTTP/1.1 200 OK
+ Content-Length: 3
+
+ foo
+ """
+ assert self.tst(data, "GET", None, include_body=False).body is None
+
+
+class TestReadResponseNoContentLength(tservers.ServerTestBase):
+ handler = NoContentLengthHTTPHandler
+
+ def test_no_content_length(self):
+ c = tcp.TCPClient(("127.0.0.1", self.port))
+ c.connect()
+ resp = HTTP1Protocol(c).read_response("GET", None)
+ assert resp.body == "bar\r\n\r\n"
+
+
+class TestAssembleRequest(object):
+ def test_simple(self):
+ req = tutils.treq()
+ b = HTTP1Protocol().assemble_request(req)
+ assert b == match_http_string("""
+ GET /path HTTP/1.1
+ header: qvalue
+ Host: address:22
+ Content-Length: 7
+
+ content""")
+
+ def test_body_missing(self):
+ req = tutils.treq(content=semantics.CONTENT_MISSING)
+ tutils.raises(http.HttpError, HTTP1Protocol().assemble_request, req)
+
+ def test_not_a_request(self):
+ tutils.raises(AssertionError, HTTP1Protocol().assemble_request, 'foo')
+
+
+class TestAssembleResponse(object):
+ def test_simple(self):
+ resp = tutils.tresp()
+ b = HTTP1Protocol().assemble_response(resp)
+ print(b)
+ assert b == match_http_string("""
+ HTTP/1.1 200 OK
+ header_response: svalue
+ Content-Length: 7
+
+ message""")
+
+ def test_body_missing(self):
+ resp = tutils.tresp(content=semantics.CONTENT_MISSING)
+ tutils.raises(http.HttpError, HTTP1Protocol().assemble_response, resp)
+
+ def test_not_a_request(self):
+ tutils.raises(AssertionError, HTTP1Protocol().assemble_response, 'foo')
diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py
index 8a27bbb1..3044179f 100644
--- a/test/http/http2/test_protocol.py
+++ b/test/http/http2/test_protocol.py
@@ -1,10 +1,25 @@
import OpenSSL
+import mock
from netlib import tcp, odict, http, tutils
from netlib.http import http2
+from netlib.http.http2 import HTTP2Protocol
from netlib.http.http2.frame import *
from ... import tservers
+class TestTCPHandlerWrapper:
+ def test_wrapped(self):
+ h = http2.TCPHandler(rfile='foo', wfile='bar')
+ p = HTTP2Protocol(h)
+ assert p.tcp_handler.rfile == 'foo'
+ assert p.tcp_handler.wfile == 'bar'
+
+ def test_direct(self):
+ p = HTTP2Protocol(rfile='foo', wfile='bar')
+ assert isinstance(p.tcp_handler, http2.TCPHandler)
+ assert p.tcp_handler.rfile == 'foo'
+ assert p.tcp_handler.wfile == 'bar'
+
class EchoHandler(tcp.BaseHandler):
sni = None
@@ -16,10 +31,40 @@ class EchoHandler(tcp.BaseHandler):
self.wfile.flush()
+class TestProtocol:
+ @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface")
+ @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface")
+ def test_perform_connection_preface(self, mock_client_method, mock_server_method):
+ protocol = HTTP2Protocol(is_server=False)
+ protocol.connection_preface_performed = True
+
+ protocol.perform_connection_preface()
+ assert not mock_client_method.called
+ assert not mock_server_method.called
+
+ protocol.perform_connection_preface(force=True)
+ assert mock_client_method.called
+ assert not mock_server_method.called
+
+ @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface")
+ @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface")
+ def test_perform_connection_preface_server(self, mock_client_method, mock_server_method):
+ protocol = HTTP2Protocol(is_server=True)
+ protocol.connection_preface_performed = True
+
+ protocol.perform_connection_preface()
+ assert not mock_client_method.called
+ assert not mock_server_method.called
+
+ protocol.perform_connection_preface(force=True)
+ assert not mock_client_method.called
+ assert mock_server_method.called
+
+
class TestCheckALPNMatch(tservers.ServerTestBase):
handler = EchoHandler
ssl = dict(
- alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2,
+ alpn_select=HTTP2Protocol.ALPN_PROTO_H2,
)
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
@@ -27,8 +72,8 @@ class TestCheckALPNMatch(tservers.ServerTestBase):
def test_check_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
- c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2])
- protocol = http2.HTTP2Protocol(c)
+ c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2])
+ protocol = HTTP2Protocol(c)
assert protocol.check_alpn()
@@ -43,8 +88,8 @@ class TestCheckALPNMismatch(tservers.ServerTestBase):
def test_check_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
- c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2])
- protocol = http2.HTTP2Protocol(c)
+ c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2])
+ protocol = HTTP2Protocol(c)
tutils.raises(NotImplementedError, protocol.check_alpn)
@@ -76,8 +121,13 @@ class TestPerformServerConnectionPreface(tservers.ServerTestBase):
def test_perform_server_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
- protocol = http2.HTTP2Protocol(c)
+ protocol = HTTP2Protocol(c)
+
+ assert not protocol.connection_preface_performed
protocol.perform_server_connection_preface()
+ assert protocol.connection_preface_performed
+
+ tutils.raises(tcp.NetLibIncomplete, protocol.perform_server_connection_preface, force=True)
class TestPerformClientConnectionPreface(tservers.ServerTestBase):
@@ -107,13 +157,16 @@ class TestPerformClientConnectionPreface(tservers.ServerTestBase):
def test_perform_client_connection_preface(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
- protocol = http2.HTTP2Protocol(c)
+ protocol = HTTP2Protocol(c)
+
+ assert not protocol.connection_preface_performed
protocol.perform_client_connection_preface()
+ assert protocol.connection_preface_performed
class TestClientStreamIds():
c = tcp.TCPClient(("127.0.0.1", 0))
- protocol = http2.HTTP2Protocol(c)
+ protocol = HTTP2Protocol(c)
def test_client_stream_ids(self):
assert self.protocol.current_stream_id is None
@@ -127,7 +180,7 @@ class TestClientStreamIds():
class TestServerStreamIds():
c = tcp.TCPClient(("127.0.0.1", 0))
- protocol = http2.HTTP2Protocol(c, is_server=True)
+ protocol = HTTP2Protocol(c, is_server=True)
def test_server_stream_ids(self):
assert self.protocol.current_stream_id is None
@@ -154,7 +207,7 @@ class TestApplySettings(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
- protocol = http2.HTTP2Protocol(c)
+ protocol = HTTP2Protocol(c)
protocol._apply_settings({
SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo',
@@ -182,13 +235,13 @@ class TestCreateHeaders():
(b':scheme', b'https'),
(b'foo', b'bar')]
- bytes = http2.HTTP2Protocol(self.c)._create_headers(
+ bytes = HTTP2Protocol(self.c)._create_headers(
headers, 1, end_stream=True)
assert b''.join(bytes) ==\
'000014010500000001824488355217caf3a69a3f87408294e7838c767f'\
.decode('hex')
- bytes = http2.HTTP2Protocol(self.c)._create_headers(
+ bytes = HTTP2Protocol(self.c)._create_headers(
headers, 1, end_stream=False)
assert b''.join(bytes) ==\
'000014010400000001824488355217caf3a69a3f87408294e7838c767f'\
@@ -199,7 +252,7 @@ class TestCreateHeaders():
class TestCreateBody():
c = tcp.TCPClient(("127.0.0.1", 0))
- protocol = http2.HTTP2Protocol(c)
+ protocol = HTTP2Protocol(c)
def test_create_body_empty(self):
bytes = self.protocol._create_body(b'', 1)
@@ -215,41 +268,30 @@ class TestCreateBody():
# TODO: add test for too large frames
-class TestAssembleRequest():
- c = tcp.TCPClient(("127.0.0.1", 0))
+class TestReadRequest(tservers.ServerTestBase):
+ class handler(tcp.BaseHandler):
- def test_assemble_request_simple(self):
- bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
- '',
- 'GET',
- 'https',
- '',
- '',
- '/',
- (2, 0),
- None,
- None,
- ))
- assert len(bytes) == 1
- assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex')
+ def handle(self):
+ self.wfile.write(
+ b'000003010400000001828487'.decode('hex'))
+ self.wfile.write(
+ b'000006000100000001666f6f626172'.decode('hex'))
+ self.wfile.flush()
- def test_assemble_request_with_body(self):
- bytes = http2.HTTP2Protocol(self.c).assemble_request(http.Request(
- '',
- 'GET',
- 'https',
- '',
- '',
- '/',
- (2, 0),
- odict.ODictCaseless([('foo', 'bar')]),
- 'foobar',
- ))
- assert len(bytes) == 2
- assert bytes[0] ==\
- '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex')
- assert bytes[1] ==\
- '000006000100000001666f6f626172'.decode('hex')
+ ssl = True
+
+ def test_read_request(self):
+ c = tcp.TCPClient(("127.0.0.1", self.port))
+ c.connect()
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c, is_server=True)
+ protocol.connection_preface_performed = True
+
+ resp = protocol.read_request()
+
+ assert resp.stream_id
+ assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']]
+ assert resp.body == b'foobar'
class TestReadResponse(tservers.ServerTestBase):
@@ -268,7 +310,7 @@ class TestReadResponse(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
- protocol = http2.HTTP2Protocol(c)
+ protocol = HTTP2Protocol(c)
protocol.connection_preface_performed = True
resp = protocol.read_response()
@@ -278,6 +320,23 @@ class TestReadResponse(tservers.ServerTestBase):
assert resp.msg == ""
assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
assert resp.body == b'foobar'
+ assert resp.timestamp_end
+
+ def test_read_response_no_body(self):
+ c = tcp.TCPClient(("127.0.0.1", self.port))
+ c.connect()
+ c.convert_to_ssl()
+ protocol = HTTP2Protocol(c)
+ protocol.connection_preface_performed = True
+
+ resp = protocol.read_response(include_body=False)
+
+ assert resp.httpversion == (2, 0)
+ assert resp.status_code == 200
+ assert resp.msg == ""
+ assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']]
+ assert resp.body == b'foobar' # TODO: this should be true: assert resp.body == http.CONTENT_MISSING
+ assert not resp.timestamp_end
class TestReadEmptyResponse(tservers.ServerTestBase):
@@ -294,7 +353,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
c.convert_to_ssl()
- protocol = http2.HTTP2Protocol(c)
+ protocol = HTTP2Protocol(c)
protocol.connection_preface_performed = True
resp = protocol.read_response()
@@ -307,37 +366,66 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
assert resp.body == b''
-class TestReadRequest(tservers.ServerTestBase):
- class handler(tcp.BaseHandler):
-
- def handle(self):
- self.wfile.write(
- b'000003010400000001828487'.decode('hex'))
- self.wfile.write(
- b'000006000100000001666f6f626172'.decode('hex'))
- self.wfile.flush()
-
- ssl = True
+class TestAssembleRequest(object):
+ c = tcp.TCPClient(("127.0.0.1", 0))
- def test_read_request(self):
- c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- c.convert_to_ssl()
- protocol = http2.HTTP2Protocol(c, is_server=True)
- protocol.connection_preface_performed = True
+ def test_request_simple(self):
+ bytes = HTTP2Protocol(self.c).assemble_request(http.Request(
+ '',
+ 'GET',
+ 'https',
+ '',
+ '',
+ '/',
+ (2, 0),
+ None,
+ None,
+ ))
+ assert len(bytes) == 1
+ assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex')
- resp = protocol.read_request()
+ def test_request_with_stream_id(self):
+ req = http.Request(
+ '',
+ 'GET',
+ 'https',
+ '',
+ '',
+ '/',
+ (2, 0),
+ None,
+ None,
+ )
+ req.stream_id = 0x42
+ bytes = HTTP2Protocol(self.c).assemble_request(req)
+ assert len(bytes) == 1
+ print(bytes[0].encode('hex'))
+ assert bytes[0] == '00000d0105000000428284874188089d5c0b8170dc07'.decode('hex')
- assert resp.stream_id
- assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']]
- assert resp.body == b'foobar'
+ def test_request_with_body(self):
+ bytes = HTTP2Protocol(self.c).assemble_request(http.Request(
+ '',
+ 'GET',
+ 'https',
+ '',
+ '',
+ '/',
+ (2, 0),
+ odict.ODictCaseless([('foo', 'bar')]),
+ 'foobar',
+ ))
+ assert len(bytes) == 2
+ assert bytes[0] ==\
+ '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex')
+ assert bytes[1] ==\
+ '000006000100000001666f6f626172'.decode('hex')
-class TestCreateResponse():
+class TestAssembleResponse(object):
c = tcp.TCPClient(("127.0.0.1", 0))
- def test_create_response_simple(self):
- bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
+ def test_simple(self):
+ bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
(2, 0),
200,
))
@@ -345,8 +433,19 @@ class TestCreateResponse():
assert bytes[0] ==\
'00000101050000000288'.decode('hex')
- def test_create_response_with_body(self):
- bytes = http2.HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
+ def test_with_stream_id(self):
+ resp = http.Response(
+ (2, 0),
+ 200,
+ )
+ resp.stream_id = 0x42
+ bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(resp)
+ assert len(bytes) == 1
+ assert bytes[0] ==\
+ '00000101050000004288'.decode('hex')
+
+ def test_with_body(self):
+ bytes = HTTP2Protocol(self.c, is_server=True).assemble_response(http.Response(
(2, 0),
200,
'',
diff --git a/test/http/test_exceptions.py b/test/http/test_exceptions.py
index aa57f831..0131c7ef 100644
--- a/test/http/test_exceptions.py
+++ b/test/http/test_exceptions.py
@@ -1,6 +1,27 @@
from netlib.http.exceptions import *
+from netlib import odict
-def test_HttpAuthenticationError():
- x = HttpAuthenticationError({"foo": "bar"})
- assert str(x)
- assert "foo" in x.headers
+class TestHttpError:
+ def test_simple(self):
+ e = HttpError(404, "Not found")
+ assert str(e)
+
+class TestHttpAuthenticationError:
+ def test_init(self):
+ headers = odict.ODictCaseless([("foo", "bar")])
+ x = HttpAuthenticationError(headers)
+ assert str(x)
+ assert isinstance(x.headers, odict.ODictCaseless)
+ assert x.code == 407
+ assert x.headers == headers
+ print(x.headers.keys())
+ assert "foo" in x.headers.keys()
+
+ def test_header_conversion(self):
+ headers = {"foo": "bar"}
+ x = HttpAuthenticationError(headers)
+ assert isinstance(x.headers, odict.ODictCaseless)
+ assert x.headers.lst == headers.items()
+
+ def test_repr(self):
+ assert repr(HttpAuthenticationError()) == "Proxy Authentication Required"
diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py
index d58a44d2..59364eae 100644
--- a/test/http/test_semantics.py
+++ b/test/http/test_semantics.py
@@ -1,18 +1,277 @@
import cStringIO
import textwrap
import binascii
+import mock
from mock import MagicMock
-from netlib import http, odict, tcp, tutils
-from netlib.http import http1
+from netlib import http, odict, tcp, tutils, utils
+from netlib.http import semantics
from netlib.http.semantics import CONTENT_MISSING
from .. import tservers
-def test_httperror():
- e = http.exceptions.HttpError(404, "Not found")
- assert str(e)
+class TestProtocolMixin(object):
+ @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response")
+ @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request")
+ def test_assemble_request(self, mock_request_method, mock_response_method):
+ p = semantics.ProtocolMixin()
+ p.assemble(tutils.treq())
+ assert mock_request_method.called
+ assert not mock_response_method.called
+
+ @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response")
+ @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request")
+ def test_assemble_response(self, mock_request_method, mock_response_method):
+ p = semantics.ProtocolMixin()
+ p.assemble(tutils.tresp())
+ assert not mock_request_method.called
+ assert mock_response_method.called
+
+ def test_assemble_foo(self):
+ p = semantics.ProtocolMixin()
+ tutils.raises(ValueError, p.assemble, 'foo')
+
+class TestRequest(object):
+ def test_repr(self):
+ r = tutils.treq()
+ assert repr(r)
+
+ def test_headers_odict(self):
+ tutils.raises(AssertionError, semantics.Request,
+ 'form_in',
+ 'method',
+ 'scheme',
+ 'host',
+ 'port',
+ 'path',
+ (1, 1),
+ 'foobar',
+ )
+
+ req = semantics.Request(
+ 'form_in',
+ 'method',
+ 'scheme',
+ 'host',
+ 'port',
+ 'path',
+ (1, 1),
+ )
+ assert isinstance(req.headers, odict.ODictCaseless)
+
+ def test_equal(self):
+ a = tutils.treq()
+ b = tutils.treq()
+ assert a == b
+
+ assert not a == 'foo'
+ assert not b == 'foo'
+ assert not 'foo' == a
+ assert not 'foo' == b
+
+ def test_legacy_first_line(self):
+ req = tutils.treq()
+
+ req.form_in = 'relative'
+ assert req.legacy_first_line() == "GET /path HTTP/1.1"
+
+ req.form_in = 'authority'
+ assert req.legacy_first_line() == "GET address:22 HTTP/1.1"
+
+ req.form_in = 'absolute'
+ assert req.legacy_first_line() == "GET http://address:22/path HTTP/1.1"
+
+ req.form_in = 'foobar'
+ tutils.raises(http.HttpError, req.legacy_first_line)
+
+ def test_anticache(self):
+ req = tutils.treq()
+ req.headers.add("If-Modified-Since", "foo")
+ req.headers.add("If-None-Match", "bar")
+ req.anticache()
+ assert "If-Modified-Since" not in req.headers
+ assert "If-None-Match" not in req.headers
+
+ def test_anticomp(self):
+ req = tutils.treq()
+ req.headers.add("Accept-Encoding", "foobar")
+ req.anticomp()
+ assert req.headers["Accept-Encoding"] == ["identity"]
+
+ def test_constrain_encoding(self):
+ req = tutils.treq()
+ req.headers.add("Accept-Encoding", "identity, gzip, foo")
+ req.constrain_encoding()
+ assert "foo" not in req.headers.get_first("Accept-Encoding")
+
+ def test_update_host(self):
+ req = tutils.treq()
+ req.headers.add("Host", "")
+ req.host = "foobar"
+ req.update_host_header()
+ assert req.headers.get_first("Host") == "foobar"
+
+ def test_get_form(self):
+ req = tutils.treq()
+ assert req.get_form() == odict.ODict()
+
+ @mock.patch("netlib.http.semantics.Request.get_form_multipart")
+ @mock.patch("netlib.http.semantics.Request.get_form_urlencoded")
+ def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart):
+ req = tutils.treq()
+ assert req.get_form() == odict.ODict()
+
+ req = tutils.treq()
+ req.body = "foobar"
+ req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED]
+ req.get_form()
+ assert req.get_form_urlencoded.called
+ assert not req.get_form_multipart.called
+
+ @mock.patch("netlib.http.semantics.Request.get_form_multipart")
+ @mock.patch("netlib.http.semantics.Request.get_form_urlencoded")
+ def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart):
+ req = tutils.treq()
+ req.body = "foobar"
+ req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART]
+ req.get_form()
+ assert not req.get_form_urlencoded.called
+ assert req.get_form_multipart.called
+
+ def test_get_form_urlencoded(self):
+ req = tutils.treq("foobar")
+ assert req.get_form_urlencoded() == odict.ODict()
+
+ req.headers["Content-Type"] = [semantics.HDR_FORM_URLENCODED]
+ assert req.get_form_urlencoded() == odict.ODict(utils.urldecode(req.body))
+
+ def test_get_form_multipart(self):
+ req = tutils.treq("foobar")
+ assert req.get_form_multipart() == odict.ODict()
+
+ req.headers["Content-Type"] = [semantics.HDR_FORM_MULTIPART]
+ assert req.get_form_multipart() == odict.ODict(
+ utils.multipartdecode(
+ req.headers,
+ req.body))
+
+ def test_set_form_urlencoded(self):
+ req = tutils.treq()
+ req.set_form_urlencoded(odict.ODict([('foo', 'bar'), ('rab', 'oof')]))
+ assert req.headers.get_first("Content-Type") == semantics.HDR_FORM_URLENCODED
+ assert req.body
+
+ def test_get_path_components(self):
+ req = tutils.treq()
+ assert req.get_path_components()
+ # TODO: add meaningful assertions
+
+ def test_set_path_components(self):
+ req = tutils.treq()
+ req.set_path_components(["foo", "bar"])
+ # TODO: add meaningful assertions
+
+ def test_get_query(self):
+ req = tutils.treq()
+ assert req.get_query().lst == []
+
+ req.url = "http://localhost:80/foo?bar=42"
+ assert req.get_query().lst == [("bar", "42")]
+
+ def test_set_query(self):
+ req = tutils.treq()
+ req.set_query(odict.ODict([]))
+
+ def test_pretty_host(self):
+ r = tutils.treq()
+ assert r.pretty_host(True) == "address"
+ assert r.pretty_host(False) == "address"
+ r.headers["host"] = ["other"]
+ assert r.pretty_host(True) == "other"
+ assert r.pretty_host(False) == "address"
+ r.host = None
+ assert r.pretty_host(True) == "other"
+ assert r.pretty_host(False) is None
+ del r.headers["host"]
+ assert r.pretty_host(True) is None
+ assert r.pretty_host(False) is None
+
+ # Invalid IDNA
+ r.headers["host"] = [".disqus.com"]
+ assert r.pretty_host(True) == ".disqus.com"
+
+ def test_pretty_url(self):
+ req = tutils.treq()
+ req.form_out = "authority"
+ assert req.pretty_url(True) == "address:22"
+ assert req.pretty_url(False) == "address:22"
+
+ req.form_out = "relative"
+ assert req.pretty_url(True) == "http://address:22/path"
+ assert req.pretty_url(False) == "http://address:22/path"
+
+ def test_get_cookies_none(self):
+ h = odict.ODictCaseless()
+ r = tutils.treq()
+ r.headers = h
+ assert len(r.get_cookies()) == 0
+
+ def test_get_cookies_single(self):
+ h = odict.ODictCaseless()
+ h["Cookie"] = ["cookiename=cookievalue"]
+ r = tutils.treq()
+ r.headers = h
+ result = r.get_cookies()
+ assert len(result) == 1
+ assert result['cookiename'] == ['cookievalue']
+
+ def test_get_cookies_double(self):
+ h = odict.ODictCaseless()
+ h["Cookie"] = [
+ "cookiename=cookievalue;othercookiename=othercookievalue"
+ ]
+ r = tutils.treq()
+ r.headers = h
+ result = r.get_cookies()
+ assert len(result) == 2
+ assert result['cookiename'] == ['cookievalue']
+ assert result['othercookiename'] == ['othercookievalue']
+
+ def test_get_cookies_withequalsign(self):
+ h = odict.ODictCaseless()
+ h["Cookie"] = [
+ "cookiename=coo=kievalue;othercookiename=othercookievalue"
+ ]
+ r = tutils.treq()
+ r.headers = h
+ result = r.get_cookies()
+ assert len(result) == 2
+ assert result['cookiename'] == ['coo=kievalue']
+ assert result['othercookiename'] == ['othercookievalue']
+
+ def test_set_cookies(self):
+ h = odict.ODictCaseless()
+ h["Cookie"] = ["cookiename=cookievalue"]
+ r = tutils.treq()
+ r.headers = h
+ result = r.get_cookies()
+ result["cookiename"] = ["foo"]
+ r.set_cookies(result)
+ assert r.get_cookies()["cookiename"] == ["foo"]
+
+ def test_set_url(self):
+ r = tutils.treq_absolute()
+ r.url = "https://otheraddress:42/ORLY"
+ assert r.scheme == "https"
+ assert r.host == "otheraddress"
+ assert r.port == 42
+ assert r.path == "/ORLY"
+
+ try:
+ r.url = "//localhost:80/foo@bar"
+ assert False
+ except:
+ assert True
-class TestRequest:
# def test_asterisk_form_in(self):
# f = tutils.tflow(req=None)
# protocol = mock_protocol("OPTIONS * HTTP/1.1")
@@ -92,105 +351,35 @@ class TestRequest:
# "Host: address\r\n"
# "Content-Length: 0\r\n\r\n")
- def test_set_url(self):
- r = tutils.treq_absolute()
- r.url = "https://otheraddress:42/ORLY"
- assert r.scheme == "https"
- assert r.host == "otheraddress"
- assert r.port == 42
- assert r.path == "/ORLY"
-
- def test_repr(self):
- r = tutils.treq()
- assert repr(r)
-
- def test_pretty_host(self):
- r = tutils.treq()
- assert r.pretty_host(True) == "address"
- assert r.pretty_host(False) == "address"
- r.headers["host"] = ["other"]
- assert r.pretty_host(True) == "other"
- assert r.pretty_host(False) == "address"
- r.host = None
- assert r.pretty_host(True) == "other"
- assert r.pretty_host(False) is None
- del r.headers["host"]
- assert r.pretty_host(True) is None
- assert r.pretty_host(False) is None
-
- # Invalid IDNA
- r.headers["host"] = [".disqus.com"]
- assert r.pretty_host(True) == ".disqus.com"
-
- def test_get_form_for_urlencoded(self):
- r = tutils.treq()
- r.headers.add("content-type", "application/x-www-form-urlencoded")
- r.get_form_urlencoded = MagicMock()
-
- r.get_form()
-
- assert r.get_form_urlencoded.called
-
- def test_get_form_for_multipart(self):
- r = tutils.treq()
- r.headers.add("content-type", "multipart/form-data")
- r.get_form_multipart = MagicMock()
-
- r.get_form()
+class TestEmptyRequest(object):
+ def test_init(self):
+ req = semantics.EmptyRequest()
+ assert req
- assert r.get_form_multipart.called
-
- def test_get_cookies_none(self):
- h = odict.ODictCaseless()
- r = tutils.treq()
- r.headers = h
- assert len(r.get_cookies()) == 0
-
- def test_get_cookies_single(self):
- h = odict.ODictCaseless()
- h["Cookie"] = ["cookiename=cookievalue"]
- r = tutils.treq()
- r.headers = h
- result = r.get_cookies()
- assert len(result) == 1
- assert result['cookiename'] == ['cookievalue']
-
- def test_get_cookies_double(self):
- h = odict.ODictCaseless()
- h["Cookie"] = [
- "cookiename=cookievalue;othercookiename=othercookievalue"
- ]
- r = tutils.treq()
- r.headers = h
- result = r.get_cookies()
- assert len(result) == 2
- assert result['cookiename'] == ['cookievalue']
- assert result['othercookiename'] == ['othercookievalue']
+class TestResponse(object):
+ def test_headers_odict(self):
+ tutils.raises(AssertionError, semantics.Response,
+ (1, 1),
+ 200,
+ headers='foobar',
+ )
- def test_get_cookies_withequalsign(self):
- h = odict.ODictCaseless()
- h["Cookie"] = [
- "cookiename=coo=kievalue;othercookiename=othercookievalue"
- ]
- r = tutils.treq()
- r.headers = h
- result = r.get_cookies()
- assert len(result) == 2
- assert result['cookiename'] == ['coo=kievalue']
- assert result['othercookiename'] == ['othercookievalue']
+ resp = semantics.Response(
+ (1, 1),
+ 200,
+ )
+ assert isinstance(resp.headers, odict.ODictCaseless)
- def test_set_cookies(self):
- h = odict.ODictCaseless()
- h["Cookie"] = ["cookiename=cookievalue"]
- r = tutils.treq()
- r.headers = h
- result = r.get_cookies()
- result["cookiename"] = ["foo"]
- r.set_cookies(result)
- assert r.get_cookies()["cookiename"] == ["foo"]
+ def test_equal(self):
+ a = tutils.tresp()
+ b = tutils.tresp()
+ assert a == b
+ assert not a == 'foo'
+ assert not b == 'foo'
+ assert not 'foo' == a
+ assert not 'foo' == b
-class TestResponse(object):
def test_repr(self):
r = tutils.tresp()
assert "unknown content type" in repr(r)
diff --git a/test/test_utils.py b/test/test_utils.py
index 5e681eb6..aafa1571 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -101,3 +101,34 @@ def test_get_header_tokens():
assert utils.get_header_tokens(h, "foo") == ["bar", "voing"]
h["foo"] = ["bar, voing", "oink"]
assert utils.get_header_tokens(h, "foo") == ["bar", "voing", "oink"]
+
+
+
+
+
+def test_multipartdecode():
+ boundary = 'somefancyboundary'
+ headers = odict.ODict(
+ [('content-type', ('multipart/form-data; boundary=%s' % boundary))])
+ content = "--{0}\n" \
+ "Content-Disposition: form-data; name=\"field1\"\n\n" \
+ "value1\n" \
+ "--{0}\n" \
+ "Content-Disposition: form-data; name=\"field2\"\n\n" \
+ "value2\n" \
+ "--{0}--".format(boundary)
+
+ form = utils.multipartdecode(headers, content)
+
+ assert len(form) == 2
+ assert form[0] == ('field1', 'value1')
+ assert form[1] == ('field2', 'value2')
+
+
+def test_parse_content_type():
+ p = utils.parse_content_type
+ assert p("text/html") == ("text", "html", {})
+ assert p("text") is None
+
+ v = p("text/html; charset=UTF-8")
+ assert v == ('text', 'html', {'charset': 'UTF-8'})
diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py
index 28dbb833..9fa98172 100644
--- a/test/websockets/test_websockets.py
+++ b/test/websockets/test_websockets.py
@@ -3,6 +3,7 @@ import os
from nose.tools import raises
from netlib import tcp, http, websockets, tutils
+from netlib.http import status_codes
from netlib.http.exceptions import *
from netlib.http.http1 import HTTP1Protocol
from .. import tservers
@@ -38,7 +39,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
req = http1_protocol.read_request()
key = self.protocol.check_client_handshake(req.headers)
- preamble = http1_protocol.response_preamble(101)
+ preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
self.wfile.write(preamble + "\r\n")
headers = self.protocol.server_handshake_headers(key)
self.wfile.write(headers.format() + "\r\n")
@@ -62,7 +63,7 @@ class WebSocketsClient(tcp.TCPClient):
http1_protocol = HTTP1Protocol(self)
- preamble = http1_protocol.request_preamble("GET", "/")
+ preamble = 'GET / HTTP/1.1'
self.wfile.write(preamble + "\r\n")
headers = self.protocol.client_handshake_headers()
self.client_nonce = headers.get_first("sec-websocket-key")
@@ -162,7 +163,7 @@ class BadHandshakeHandler(WebSocketsEchoHandler):
client_hs = http1_protocol.read_request()
self.protocol.check_client_handshake(client_hs.headers)
- preamble = http1_protocol.response_preamble(101)
+ preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
self.wfile.write(preamble + "\r\n")
headers = self.protocol.server_handshake_headers("malformed key")
self.wfile.write(headers.format() + "\r\n")