author     Maximilian Hils <git@maximilianhils.com>    2015-09-16 20:19:52 +0200
committer  Maximilian Hils <git@maximilianhils.com>    2015-09-16 20:19:52 +0200
commit     e1659f3fcf83b5993b776a4ef3d2de70fbe27aa2 (patch)
tree       c0eba50b522d1d0183b057e9cae7bf7cc38c4fc3
parent     2f9c566e480c377566a0ae044d698a75b45cd54c (diff)
parent     265f31e8782ee9da511ce4b63aa2da00221cbf66 (diff)
download   mitmproxy-e1659f3fcf83b5993b776a4ef3d2de70fbe27aa2.tar.gz
           mitmproxy-e1659f3fcf83b5993b776a4ef3d2de70fbe27aa2.tar.bz2
           mitmproxy-e1659f3fcf83b5993b776a4ef3d2de70fbe27aa2.zip
Merge pull request #92 from mitmproxy/python3
Python3 & HTTP1 Refactor
-rw-r--r--  .travis.yml | 4
-rw-r--r--  netlib/encoding.py | 8
-rw-r--r--  netlib/exceptions.py | 32
-rw-r--r--  netlib/http/__init__.py | 14
-rw-r--r--  netlib/http/authentication.py | 4
-rw-r--r--  netlib/http/exceptions.py | 9
-rw-r--r--  netlib/http/http1/__init__.py | 24
-rw-r--r--  netlib/http/http1/assemble.py | 103
-rw-r--r--  netlib/http/http1/protocol.py | 586
-rw-r--r--  netlib/http/http1/read.py | 360
-rw-r--r--  netlib/http/http2/__init__.py | 8
-rw-r--r--  netlib/http/http2/connections.py (renamed from netlib/http/http2/protocol.py) | 30
-rw-r--r--  netlib/http/http2/frame.py | 39
-rw-r--r--  netlib/http/models.py (renamed from netlib/http/semantics.py) | 226
-rw-r--r--  netlib/tcp.py | 8
-rw-r--r--  netlib/tutils.py | 144
-rw-r--r--  netlib/utils.py | 166
-rw-r--r--  netlib/version_check.py | 17
-rw-r--r--  netlib/websockets/__init__.py | 4
-rw-r--r--  test/http/http1/test_assemble.py | 91
-rw-r--r--  test/http/http1/test_protocol.py | 497
-rw-r--r--  test/http/http1/test_read.py | 317
-rw-r--r--  test/http/http2/test_frames.py | 6
-rw-r--r--  test/http/http2/test_protocol.py | 36
-rw-r--r--  test/http/test_authentication.py | 2
-rw-r--r--  test/http/test_exceptions.py | 6
-rw-r--r--  test/http/test_models.py (renamed from test/http/test_semantics.py) | 173
-rw-r--r--  test/test_encoding.py | 24
-rw-r--r--  test/test_utils.py | 75
-rw-r--r--  test/test_version_check.py | 8
-rw-r--r--  test/tservers.py | 8
-rw-r--r--  test/websockets/test_websockets.py | 16
32 files changed, 1441 insertions, 1604 deletions
diff --git a/.travis.yml b/.travis.yml
index fd2fba3d..fa997542 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -15,6 +15,8 @@ matrix:
- debian-sid
packages:
- libssl-dev
+ - python: 3.5
+ script: "nosetests --with-cov --cov-report term-missing test/http/http1"
- python: pypy
- python: pypy
env: OPENSSL=1.0.2
@@ -67,4 +69,4 @@ cache:
- /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages
- /home/travis/virtualenv/python2.7.9/bin
- /home/travis/virtualenv/pypy-2.5.0/site-packages
- - /home/travis/virtualenv/pypy-2.5.0/bin
\ No newline at end of file
+ - /home/travis/virtualenv/pypy-2.5.0/bin
diff --git a/netlib/encoding.py b/netlib/encoding.py
index f107eb5f..06830f2c 100644
--- a/netlib/encoding.py
+++ b/netlib/encoding.py
@@ -2,13 +2,13 @@
Utility functions for decoding response bodies.
"""
from __future__ import absolute_import
-import cStringIO
+from io import BytesIO
import gzip
import zlib
__ALL__ = ["ENCODINGS"]
-ENCODINGS = set(["identity", "gzip", "deflate"])
+ENCODINGS = {"identity", "gzip", "deflate"}
def decode(e, content):
@@ -42,7 +42,7 @@ def identity(content):
def decode_gzip(content):
- gfile = gzip.GzipFile(fileobj=cStringIO.StringIO(content))
+ gfile = gzip.GzipFile(fileobj=BytesIO(content))
try:
return gfile.read()
except (IOError, EOFError):
@@ -50,7 +50,7 @@ def decode_gzip(content):
def encode_gzip(content):
- s = cStringIO.StringIO()
+ s = BytesIO()
gf = gzip.GzipFile(fileobj=s, mode='wb')
gf.write(content)
gf.close()
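For reference, the ported gzip helpers round-trip the same way on Python 2 and 3; a minimal sketch using only the functions touched above:

    from netlib import encoding

    data = b"hello world" * 100
    compressed = encoding.encode_gzip(data)
    assert encoding.decode_gzip(compressed) == data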
diff --git a/netlib/exceptions.py b/netlib/exceptions.py
new file mode 100644
index 00000000..e13af473
--- /dev/null
+++ b/netlib/exceptions.py
@@ -0,0 +1,32 @@
+"""
+We try to be very hygienic regarding the exceptions we throw:
+Every Exception netlib raises shall be a subclass of NetlibException.
+
+
+See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/
+"""
+from __future__ import absolute_import, print_function, division
+
+
+class NetlibException(Exception):
+ """
+ Base class for all exceptions thrown by netlib.
+ """
+ def __init__(self, message=None):
+ super(NetlibException, self).__init__(message)
+
+
+class ReadDisconnect(object):
+ """Immediate EOF"""
+
+
+class HttpException(NetlibException):
+ pass
+
+
+class HttpReadDisconnect(HttpException, ReadDisconnect):
+ pass
+
+
+class HttpSyntaxException(HttpException):
+ pass
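In practice, callers can now catch every error netlib raises through the common base class; a minimal sketch:

    from netlib.exceptions import NetlibException, HttpSyntaxException

    try:
        raise HttpSyntaxException("malformed request line")
    except NetlibException as e:
        # HttpSyntaxException -> HttpException -> NetlibException
        print("netlib error: %s" % e)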
diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py
index 9b4b0e6b..d72884b3 100644
--- a/netlib/http/__init__.py
+++ b/netlib/http/__init__.py
@@ -1,2 +1,12 @@
-from exceptions import *
-from semantics import *
+from __future__ import absolute_import, print_function, division
+from .models import Request, Response, Headers
+from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2
+from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING
+from . import http1, http2
+
+__all__ = [
+ "Request", "Response", "Headers",
+ "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2",
+ "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING",
+ "http1", "http2",
+]
diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py
index fe1f0d14..2055f843 100644
--- a/netlib/http/authentication.py
+++ b/netlib/http/authentication.py
@@ -19,8 +19,8 @@ def parse_http_basic_auth(s):
def assemble_http_basic_auth(scheme, username, password):
- v = binascii.b2a_base64(username + ":" + password)
- return scheme + " " + v
+ v = binascii.b2a_base64(username + b":" + password)
+ return scheme + b" " + v
class NullProxyAuth(object):
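With bytes-only arguments, assembling a Basic auth header looks like this (sketch; note that b2a_base64 leaves a trailing newline on the value):

    from netlib.http import authentication

    header = authentication.assemble_http_basic_auth(b"Basic", b"user", b"secret")
    # -> b"Basic dXNlcjpzZWNyZXQ=\n"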
diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py
deleted file mode 100644
index 8a2bbebc..00000000
--- a/netlib/http/exceptions.py
+++ /dev/null
@@ -1,9 +0,0 @@
-class HttpError(Exception):
-
- def __init__(self, code, message):
- super(HttpError, self).__init__(message)
- self.code = code
-
-
-class HttpErrorConnClosed(HttpError):
- pass
diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py
index 6b5043af..2d33ff8a 100644
--- a/netlib/http/http1/__init__.py
+++ b/netlib/http/http1/__init__.py
@@ -1 +1,23 @@
-from protocol import *
+from __future__ import absolute_import, print_function, division
+from .read import (
+ read_request, read_request_head,
+ read_response, read_response_head,
+ read_body,
+ connection_close,
+ expected_http_body_size,
+)
+from .assemble import (
+ assemble_request, assemble_request_head,
+ assemble_response, assemble_response_head,
+)
+
+
+__all__ = [
+ "read_request", "read_request_head",
+ "read_response", "read_response_head",
+ "read_body",
+ "connection_close",
+ "expected_http_body_size",
+ "assemble_request", "assemble_request_head",
+ "assemble_response", "assemble_response_head",
+]
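The HTTP1Protocol class is gone; parsing is now a plain function over a file-like object, with assemble_request/assemble_response as the inverse direction. A sketch of the read side:

    from io import BytesIO
    from netlib.http import http1

    rfile = BytesIO(b"GET /path HTTP/1.1\r\nHost: example.com\r\n\r\n")
    request = http1.read_request(rfile)
    # request.method == b"GET", request.path == b"/path", request.body == b""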
diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py
new file mode 100644
index 00000000..ace25d79
--- /dev/null
+++ b/netlib/http/http1/assemble.py
@@ -0,0 +1,103 @@
+from __future__ import absolute_import, print_function, division
+
+from ... import utils
+from ...exceptions import HttpException
+from .. import CONTENT_MISSING
+
+
+def assemble_request(request):
+ if request.body == CONTENT_MISSING:
+ raise HttpException("Cannot assemble flow with CONTENT_MISSING")
+ head = assemble_request_head(request)
+ return head + request.body
+
+
+def assemble_request_head(request):
+ first_line = _assemble_request_line(request)
+ headers = _assemble_request_headers(request)
+ return b"%s\r\n%s\r\n" % (first_line, headers)
+
+
+def assemble_response(response):
+ if response.body == CONTENT_MISSING:
+ raise HttpException("Cannot assemble flow with CONTENT_MISSING")
+ head = assemble_response_head(response)
+ return head + response.body
+
+
+def assemble_response_head(response, preserve_transfer_encoding=False):
+ first_line = _assemble_response_line(response)
+ headers = _assemble_response_headers(response, preserve_transfer_encoding)
+ return b"%s\r\n%s\r\n" % (first_line, headers)
+
+
+def _assemble_request_line(request, form=None):
+ if form is None:
+ form = request.form_out
+ if form == "relative":
+ return b"%s %s %s" % (
+ request.method,
+ request.path,
+ request.httpversion
+ )
+ elif form == "authority":
+ return b"%s %s:%d %s" % (
+ request.method,
+ request.host,
+ request.port,
+ request.httpversion
+ )
+ elif form == "absolute":
+ return b"%s %s://%s:%d%s %s" % (
+ request.method,
+ request.scheme,
+ request.host,
+ request.port,
+ request.path,
+ request.httpversion
+ )
+ else: # pragma: nocover
+ raise RuntimeError("Invalid request form")
+
+
+def _assemble_request_headers(request):
+ headers = request.headers.copy()
+ for k in request._headers_to_strip_off:
+ headers.pop(k, None)
+ if b"host" not in headers and request.scheme and request.host and request.port:
+ headers[b"Host"] = utils.hostport(
+ request.scheme,
+ request.host,
+ request.port
+ )
+
+ # If content is defined (i.e. not None or CONTENT_MISSING), we always
+ # add a content-length header.
+ if request.body or request.body == b"":
+ headers[b"Content-Length"] = str(len(request.body)).encode("ascii")
+
+ return bytes(headers)
+
+
+def _assemble_response_line(response):
+ return b"%s %d %s" % (
+ response.httpversion,
+ response.status_code,
+ response.msg,
+ )
+
+
+def _assemble_response_headers(response, preserve_transfer_encoding=False):
+ # TODO: Remove preserve_transfer_encoding
+ headers = response.headers.copy()
+ for k in response._headers_to_strip_off:
+ headers.pop(k, None)
+ if not preserve_transfer_encoding:
+ headers.pop(b"Transfer-Encoding", None)
+
+ # If body is defined (i.e. not None or CONTENT_MISSING), we always
+ # add a content-length header.
+ if response.body or response.body == b"":
+ headers[b"Content-Length"] = str(len(response.body)).encode("ascii")
+
+ return bytes(headers)
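And a sketch of the assemble side, using the tresp() helper defined later in this patch (netlib/tutils.py) and assuming the unchanged parts of Headers (e.g. copy()) behave as before:

    from netlib import tutils
    from netlib.http import http1

    resp = tutils.tresp()
    raw = http1.assemble_response(resp)
    # raw starts with b"HTTP/1.1 200 OK\r\n", followed by the headers,
    # an auto-generated Content-Length, a blank line and b"message"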
diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py
deleted file mode 100644
index cf1dffa3..00000000
--- a/netlib/http/http1/protocol.py
+++ /dev/null
@@ -1,586 +0,0 @@
-from __future__ import (absolute_import, print_function, division)
-import string
-import sys
-import time
-
-from ... import utils, tcp, http
-from .. import semantics, Headers
-from ..exceptions import *
-
-
-class TCPHandler(object):
-
- def __init__(self, rfile, wfile=None):
- self.rfile = rfile
- self.wfile = wfile
-
-
-class HTTP1Protocol(semantics.ProtocolMixin):
-
- ALPN_PROTO_HTTP1 = 'http/1.1'
-
- def __init__(self, tcp_handler=None, rfile=None, wfile=None):
- self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
-
- def read_request(
- self,
- include_body=True,
- body_size_limit=None,
- allow_empty=False,
- ):
- """
- Parse an HTTP request from a file stream
-
- Args:
- include_body (bool): Read response body as well
- body_size_limit (bool): Maximum body size
- wfile (file): If specified, HTTP Expect headers are handled
- automatically, by writing a HTTP 100 CONTINUE response to the stream.
-
- Returns:
- Request: The HTTP request
-
- Raises:
- HttpError: If the input is invalid.
- """
- timestamp_start = time.time()
- if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
- self.tcp_handler.rfile.reset_timestamps()
-
- httpversion, host, port, scheme, method, path, headers, body = (
- None, None, None, None, None, None, None, None)
-
- request_line = self._get_request_line()
- if not request_line:
- if allow_empty:
- return http.EmptyRequest()
- else:
- raise tcp.NetLibDisconnect()
-
- request_line_parts = self._parse_init(request_line)
- if not request_line_parts:
- raise HttpError(
- 400,
- "Bad HTTP request line: %s" % repr(request_line)
- )
- method, path, httpversion = request_line_parts
-
- if path == '*' or path.startswith("/"):
- form_in = "relative"
- if not utils.isascii(path):
- raise HttpError(
- 400,
- "Bad HTTP request line: %s" % repr(request_line)
- )
- elif method == 'CONNECT':
- form_in = "authority"
- r = self._parse_init_connect(request_line)
- if not r:
- raise HttpError(
- 400,
- "Bad HTTP request line: %s" % repr(request_line)
- )
- host, port, httpversion = r
- path = None
- else:
- form_in = "absolute"
- r = self._parse_init_proxy(request_line)
- if not r:
- raise HttpError(
- 400,
- "Bad HTTP request line: %s" % repr(request_line)
- )
- _, scheme, host, port, path, _ = r
-
- headers = self.read_headers()
- if headers is None:
- raise HttpError(400, "Invalid headers")
-
- expect_header = headers.get("expect", "").lower()
- if expect_header == "100-continue" and httpversion == (1, 1):
- self.tcp_handler.wfile.write(
- 'HTTP/1.1 100 Continue\r\n'
- '\r\n'
- )
- self.tcp_handler.wfile.flush()
- del headers['expect']
-
- if include_body:
- body = self.read_http_body(
- headers,
- body_size_limit,
- method,
- None,
- True
- )
-
- if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
- # more accurate timestamp_start
- timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
-
- timestamp_end = time.time()
-
- return http.Request(
- form_in,
- method,
- scheme,
- host,
- port,
- path,
- httpversion,
- headers,
- body,
- timestamp_start,
- timestamp_end,
- )
-
- def read_response(
- self,
- request_method,
- body_size_limit=None,
- include_body=True,
- ):
- """
- Returns an http.Response
-
- By default, both response header and body are read.
- If include_body=False is specified, body may be one of the
- following:
- - None, if the response is technically allowed to have a response body
- - "", if the response must not have a response body (e.g. it's a
- response to a HEAD request)
- """
- timestamp_start = time.time()
- if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
- self.tcp_handler.rfile.reset_timestamps()
-
- line = self.tcp_handler.rfile.readline()
- # Possible leftover from previous message
- if line == "\r\n" or line == "\n":
- line = self.tcp_handler.rfile.readline()
- if not line:
- raise HttpErrorConnClosed(502, "Server disconnect.")
- parts = self.parse_response_line(line)
- if not parts:
- raise HttpError(502, "Invalid server response: %s" % repr(line))
- proto, code, msg = parts
- httpversion = self._parse_http_protocol(proto)
- if httpversion is None:
- raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto))
- headers = self.read_headers()
- if headers is None:
- raise HttpError(502, "Invalid headers.")
-
- if include_body:
- body = self.read_http_body(
- headers,
- body_size_limit,
- request_method,
- code,
- False
- )
- else:
- # if include_body==False then a None body means the body should be
- # read separately
- body = None
-
- if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
- # more accurate timestamp_start
- timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
-
- if include_body:
- timestamp_end = time.time()
- else:
- timestamp_end = None
-
- return http.Response(
- httpversion,
- code,
- msg,
- headers,
- body,
- timestamp_start=timestamp_start,
- timestamp_end=timestamp_end,
- )
-
- def assemble_request(self, request):
- assert isinstance(request, semantics.Request)
-
- if request.body == semantics.CONTENT_MISSING:
- raise http.HttpError(
- 502,
- "Cannot assemble flow with CONTENT_MISSING"
- )
- first_line = self._assemble_request_first_line(request)
- headers = self._assemble_request_headers(request)
- return "%s\r\n%s\r\n%s" % (first_line, headers, request.body)
-
- def assemble_response(self, response):
- assert isinstance(response, semantics.Response)
-
- if response.body == semantics.CONTENT_MISSING:
- raise http.HttpError(
- 502,
- "Cannot assemble flow with CONTENT_MISSING"
- )
- first_line = self._assemble_response_first_line(response)
- headers = self._assemble_response_headers(response)
- return "%s\r\n%s\r\n%s" % (first_line, headers, response.body)
-
- def read_headers(self):
- """
- Read a set of headers.
- Stop once a blank line is reached.
-
- Return a Header object, or None if headers are invalid.
- """
- ret = []
- while True:
- line = self.tcp_handler.rfile.readline()
- if not line or line == '\r\n' or line == '\n':
- break
- if line[0] in ' \t':
- if not ret:
- return None
- # continued header
- ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
- else:
- i = line.find(':')
- # We're being liberal in what we accept, here.
- if i > 0:
- name = line[:i]
- value = line[i + 1:].strip()
- ret.append([name, value])
- else:
- return None
- return Headers(ret)
-
-
- def read_http_body(self, *args, **kwargs):
- return "".join(self.read_http_body_chunked(*args, **kwargs))
-
-
- def read_http_body_chunked(
- self,
- headers,
- limit,
- request_method,
- response_code,
- is_request,
- max_chunk_size=None
- ):
- """
- Read an HTTP message body:
- headers: A Header object
- limit: Size limit.
- is_request: True if the body to read belongs to a request, False
- otherwise
- """
- if max_chunk_size is None:
- max_chunk_size = limit or sys.maxsize
-
- expected_size = self.expected_http_body_size(
- headers, is_request, request_method, response_code
- )
-
- if expected_size is None:
- if self.has_chunked_encoding(headers):
- # Python 3: yield from
- for x in self._read_chunked(limit, is_request):
- yield x
- else: # pragma: nocover
- raise HttpError(
- 400 if is_request else 502,
- "Content-Length unknown but no chunked encoding"
- )
- elif expected_size >= 0:
- if limit is not None and expected_size > limit:
- raise HttpError(
- 400 if is_request else 509,
- "HTTP Body too large. Limit is %s, content-length was %s" % (
- limit, expected_size
- )
- )
- bytes_left = expected_size
- while bytes_left:
- chunk_size = min(bytes_left, max_chunk_size)
- content = self.tcp_handler.rfile.read(chunk_size)
- yield content
- bytes_left -= chunk_size
- else:
- bytes_left = limit or -1
- while bytes_left:
- chunk_size = min(bytes_left, max_chunk_size)
- content = self.tcp_handler.rfile.read(chunk_size)
- if not content:
- return
- yield content
- bytes_left -= chunk_size
- not_done = self.tcp_handler.rfile.read(1)
- if not_done:
- raise HttpError(
- 400 if is_request else 509,
- "HTTP Body too large. Limit is %s," % limit
- )
-
- @classmethod
- def expected_http_body_size(
- self,
- headers,
- is_request,
- request_method,
- response_code,
- ):
- """
- Returns the expected body length:
- - a positive integer, if the size is known in advance
- - None, if the size in unknown in advance (chunked encoding or invalid
- data)
- - -1, if all data should be read until end of stream.
-
- May raise HttpError.
- """
- # Determine response size according to
- # http://tools.ietf.org/html/rfc7230#section-3.3
- if request_method:
- request_method = request_method.upper()
-
- if (not is_request and (
- request_method == "HEAD" or
- (request_method == "CONNECT" and response_code == 200) or
- response_code in [204, 304] or
- 100 <= response_code <= 199)):
- return 0
- if self.has_chunked_encoding(headers):
- return None
- if "content-length" in headers:
- try:
- size = int(headers["content-length"])
- if size < 0:
- raise ValueError()
- return size
- except ValueError:
- return None
- if is_request:
- return 0
- return -1
-
-
- @classmethod
- def has_chunked_encoding(self, headers):
- return "chunked" in headers.get("transfer-encoding", "").lower()
-
-
- def _get_request_line(self):
- """
- Get a line, possibly preceded by a blank.
- """
- line = self.tcp_handler.rfile.readline()
- if line == "\r\n" or line == "\n":
- # Possible leftover from previous message
- line = self.tcp_handler.rfile.readline()
- return line
-
- def _read_chunked(self, limit, is_request):
- """
- Read a chunked HTTP body.
-
- May raise HttpError.
- """
- # FIXME: Should check if chunked is the final encoding in the headers
- # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3
- # 3.3 2.
- total = 0
- code = 400 if is_request else 502
- while True:
- line = self.tcp_handler.rfile.readline(128)
- if line == "":
- raise HttpErrorConnClosed(code, "Connection closed prematurely")
- if line != '\r\n' and line != '\n':
- try:
- length = int(line, 16)
- except ValueError:
- raise HttpError(
- code,
- "Invalid chunked encoding length: %s" % line
- )
- total += length
- if limit is not None and total > limit:
- msg = "HTTP Body too large. Limit is %s," \
- " chunked content longer than %s" % (limit, total)
- raise HttpError(code, msg)
- chunk = self.tcp_handler.rfile.read(length)
- suffix = self.tcp_handler.rfile.readline(5)
- if suffix != '\r\n':
- raise HttpError(code, "Malformed chunked body")
- if length == 0:
- return
- yield chunk
-
- @classmethod
- def _parse_http_protocol(self, line):
- """
- Parse an HTTP protocol declaration.
- Returns a (major, minor) tuple, or None.
- """
- if not line.startswith("HTTP/"):
- return None
- _, version = line.split('/', 1)
- if "." not in version:
- return None
- major, minor = version.split('.', 1)
- try:
- major = int(major)
- minor = int(minor)
- except ValueError:
- return None
- return major, minor
-
- @classmethod
- def _parse_init(self, line):
- try:
- method, url, protocol = string.split(line)
- except ValueError:
- return None
- httpversion = self._parse_http_protocol(protocol)
- if not httpversion:
- return None
- if not utils.isascii(method):
- return None
- return method, url, httpversion
-
- @classmethod
- def _parse_init_connect(self, line):
- """
- Returns (host, port, httpversion) if line is a valid CONNECT line.
- http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1
- """
- v = self._parse_init(line)
- if not v:
- return None
- method, url, httpversion = v
-
- if method.upper() != 'CONNECT':
- return None
- try:
- host, port = url.split(":")
- except ValueError:
- return None
- try:
- port = int(port)
- except ValueError:
- return None
- if not utils.is_valid_port(port):
- return None
- if not utils.is_valid_host(host):
- return None
- return host, port, httpversion
-
- @classmethod
- def _parse_init_proxy(self, line):
- v = self._parse_init(line)
- if not v:
- return None
- method, url, httpversion = v
-
- parts = utils.parse_url(url)
- if not parts:
- return None
- scheme, host, port, path = parts
- return method, scheme, host, port, path, httpversion
-
- @classmethod
- def _parse_init_http(self, line):
- """
- Returns (method, url, httpversion)
- """
- v = self._parse_init(line)
- if not v:
- return None
- method, url, httpversion = v
- if not utils.isascii(url):
- return None
- if not (url.startswith("/") or url == "*"):
- return None
- return method, url, httpversion
-
- @classmethod
- def connection_close(self, httpversion, headers):
- """
- Checks the message to see if the client connection should be closed
- according to RFC 2616 Section 8.1 Note that a connection should be
- closed as well if the response has been read until end of the stream.
- """
- # At first, check if we have an explicit Connection header.
- if "connection" in headers:
- toks = utils.get_header_tokens(headers, "connection")
- if "close" in toks:
- return True
- elif "keep-alive" in toks:
- return False
-
- # If we don't have a Connection header, HTTP 1.1 connections are assumed to
- # be persistent
- return httpversion != (1, 1)
-
- @classmethod
- def parse_response_line(self, line):
- parts = line.strip().split(" ", 2)
- if len(parts) == 2: # handle missing message gracefully
- parts.append("")
- if len(parts) != 3:
- return None
- proto, code, msg = parts
- try:
- code = int(code)
- except ValueError:
- return None
- return (proto, code, msg)
-
- @classmethod
- def _assemble_request_first_line(self, request):
- return request.legacy_first_line()
-
- def _assemble_request_headers(self, request):
- headers = request.headers.copy()
- for k in request._headers_to_strip_off:
- headers.pop(k, None)
- if 'host' not in headers and request.scheme and request.host and request.port:
- headers["Host"] = utils.hostport(
- request.scheme,
- request.host,
- request.port
- )
-
- # If content is defined (i.e. not None or CONTENT_MISSING), we always
- # add a content-length header.
- if request.body or request.body == "":
- headers["Content-Length"] = str(len(request.body))
-
- return str(headers)
-
- def _assemble_response_first_line(self, response):
- return 'HTTP/%s.%s %s %s' % (
- response.httpversion[0],
- response.httpversion[1],
- response.status_code,
- response.msg,
- )
-
- def _assemble_response_headers(
- self,
- response,
- preserve_transfer_encoding=False,
- ):
- headers = response.headers.copy()
- for k in response._headers_to_strip_off:
- headers.pop(k, None)
- if not preserve_transfer_encoding:
- headers.pop('Transfer-Encoding', None)
-
- # If body is defined (i.e. not None or CONTENT_MISSING), we always
- # add a content-length header.
- if response.body or response.body == "":
- headers["Content-Length"] = str(len(response.body))
-
- return str(headers)
diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py
new file mode 100644
index 00000000..62025d15
--- /dev/null
+++ b/netlib/http/http1/read.py
@@ -0,0 +1,360 @@
+from __future__ import absolute_import, print_function, division
+import time
+import sys
+import re
+
+from ... import utils
+from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException
+from .. import Request, Response, Headers
+from netlib.tcp import NetLibDisconnect
+
+
+def read_request(rfile, body_size_limit=None):
+ request = read_request_head(rfile)
+ expected_body_size = expected_http_body_size(request)
+ request.body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit))
+ request.timestamp_end = time.time()
+ return request
+
+
+def read_request_head(rfile):
+ """
+ Parse an HTTP request head (request line + headers) from an input stream
+
+ Args:
+ rfile: The input stream
+
+ Returns:
+ The HTTP request object (without body)
+
+ Raises:
+ HttpReadDisconnect: No bytes can be read from rfile.
+ HttpSyntaxException: The input is malformed HTTP.
+ HttpException: Any other error occurred.
+ """
+ timestamp_start = time.time()
+ if hasattr(rfile, "reset_timestamps"):
+ rfile.reset_timestamps()
+
+ form, method, scheme, host, port, path, http_version = _read_request_line(rfile)
+ headers = _read_headers(rfile)
+
+ if hasattr(rfile, "first_byte_timestamp"):
+ # more accurate timestamp_start
+ timestamp_start = rfile.first_byte_timestamp
+
+ return Request(
+ form, method, scheme, host, port, path, http_version, headers, None, timestamp_start
+ )
+
+
+def read_response(rfile, request, body_size_limit=None):
+ response = read_response_head(rfile)
+ expected_body_size = expected_http_body_size(request, response)
+ response.body = b"".join(read_body(rfile, expected_body_size, body_size_limit))
+ response.timestamp_end = time.time()
+ return response
+
+
+def read_response_head(rfile):
+ """
+ Parse an HTTP response head (response line + headers) from an input stream
+
+ Args:
+ rfile: The input stream
+
+ Returns:
+ The HTTP request object (without body)
+
+ Raises:
+ HttpReadDisconnect: No bytes can be read from rfile.
+ HttpSyntaxException: The input is malformed HTTP.
+ HttpException: Any other error occurred.
+ """
+
+ timestamp_start = time.time()
+ if hasattr(rfile, "reset_timestamps"):
+ rfile.reset_timestamps()
+
+ http_version, status_code, message = _read_response_line(rfile)
+ headers = _read_headers(rfile)
+
+ if hasattr(rfile, "first_byte_timestamp"):
+ # more accurate timestamp_start
+ timestamp_start = rfile.first_byte_timestamp
+
+ return Response(http_version, status_code, message, headers, None, timestamp_start)
+
+
+def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
+ """
+ Read an HTTP message body
+
+ Args:
+ rfile: The input stream
+ expected_size: The expected body size (see :py:meth:`expected_http_body_size`)
+ limit: Maximum body size
+ max_chunk_size: Maximum chunk size that gets yielded
+
+ Returns:
+ A generator that yields byte chunks of the content.
+
+ Raises:
+ HttpException, if an error occurs
+
+ Caveats:
+ max_chunk_size is not considered if the transfer encoding is chunked.
+ """
+ if not limit or limit < 0:
+ limit = sys.maxsize
+ if not max_chunk_size:
+ max_chunk_size = limit
+
+ if expected_size is None:
+ for x in _read_chunked(rfile, limit):
+ yield x
+ elif expected_size >= 0:
+ if limit is not None and expected_size > limit:
+ raise HttpException(
+ "HTTP Body too large. "
+ "Limit is {}, content length was advertised as {}".format(limit, expected_size)
+ )
+ bytes_left = expected_size
+ while bytes_left:
+ chunk_size = min(bytes_left, max_chunk_size)
+ content = rfile.read(chunk_size)
+ if len(content) < chunk_size:
+ raise HttpException("Unexpected EOF")
+ yield content
+ bytes_left -= chunk_size
+ else:
+ bytes_left = limit
+ while bytes_left:
+ chunk_size = min(bytes_left, max_chunk_size)
+ content = rfile.read(chunk_size)
+ if not content:
+ return
+ yield content
+ bytes_left -= chunk_size
+ not_done = rfile.read(1)
+ if not_done:
+ raise HttpException("HTTP body too large. Limit is {}.".format(limit))
+
+
+def connection_close(http_version, headers):
+ """
+ Checks the message to see if the client connection should be closed
+ according to RFC 2616 Section 8.1.
+ """
+ # At first, check if we have an explicit Connection header.
+ if b"connection" in headers:
+ tokens = utils.get_header_tokens(headers, "connection")
+ if b"close" in tokens:
+ return True
+ elif b"keep-alive" in tokens:
+ return False
+
+ # If we don't have a Connection header, HTTP 1.1 connections are assumed to
+ # be persistent
+ return http_version != b"HTTP/1.1"
+
+
+def expected_http_body_size(request, response=None):
+ """
+ Returns:
+ The expected body length:
+ - a positive integer, if the size is known in advance
+ - None, if the size is unknown in advance (chunked encoding)
+ - -1, if all data should be read until end of stream.
+
+ Raises:
+ HttpSyntaxException, if the content length header is invalid
+ """
+ # Determine response size according to
+ # http://tools.ietf.org/html/rfc7230#section-3.3
+ if not response:
+ headers = request.headers
+ response_code = None
+ is_request = True
+ else:
+ headers = response.headers
+ response_code = response.status_code
+ is_request = False
+
+ if is_request:
+ if headers.get(b"expect", b"").lower() == b"100-continue":
+ return 0
+ else:
+ if request.method.upper() == b"HEAD":
+ return 0
+ if 100 <= response_code <= 199:
+ return 0
+ if response_code == 200 and request.method.upper() == b"CONNECT":
+ return 0
+ if response_code in (204, 304):
+ return 0
+
+ if b"chunked" in headers.get(b"transfer-encoding", b"").lower():
+ return None
+ if b"content-length" in headers:
+ try:
+ size = int(headers[b"content-length"])
+ if size < 0:
+ raise ValueError()
+ return size
+ except ValueError:
+ raise HttpSyntaxException("Unparseable Content Length")
+ if is_request:
+ return 0
+ return -1
+
+
+def _get_first_line(rfile):
+ try:
+ line = rfile.readline()
+ if line == b"\r\n" or line == b"\n":
+ # Possible leftover from previous message
+ line = rfile.readline()
+ except NetLibDisconnect:
+ raise HttpReadDisconnect()
+ if not line:
+ raise HttpReadDisconnect()
+ line = line.strip()
+ try:
+ line.decode("ascii")
+ except ValueError:
+ raise HttpSyntaxException("Non-ascii characters in first line: {}".format(line))
+ return line.strip()
+
+
+def _read_request_line(rfile):
+ line = _get_first_line(rfile)
+
+ try:
+ method, path, http_version = line.split(b" ")
+
+ if path == b"*" or path.startswith(b"/"):
+ form = "relative"
+ scheme, host, port = None, None, None
+ elif method == b"CONNECT":
+ form = "authority"
+ host, port = _parse_authority_form(path)
+ scheme, path = None, None
+ else:
+ form = "absolute"
+ scheme, host, port, path = utils.parse_url(path)
+
+ _check_http_version(http_version)
+ except ValueError:
+ raise HttpSyntaxException("Bad HTTP request line: {}".format(line))
+
+ return form, method, scheme, host, port, path, http_version
+
+
+def _parse_authority_form(hostport):
+ """
+ Returns (host, port) if hostport is a valid authority-form host specification.
+ http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1
+
+ Raises:
+ ValueError, if the input is malformed
+ """
+ try:
+ host, port = hostport.split(b":")
+ port = int(port)
+ if not utils.is_valid_host(host) or not utils.is_valid_port(port):
+ raise ValueError()
+ except ValueError:
+ raise HttpSyntaxException("Invalid host specification: {}".format(hostport))
+
+ return host, port
+
+
+def _read_response_line(rfile):
+ line = _get_first_line(rfile)
+
+ try:
+
+ parts = line.split(b" ", 2)
+ if len(parts) == 2: # handle missing message gracefully
+ parts.append(b"")
+
+ http_version, status_code, message = parts
+ status_code = int(status_code)
+ _check_http_version(http_version)
+
+ except ValueError:
+ raise HttpSyntaxException("Bad HTTP response line: {}".format(line))
+
+ return http_version, status_code, message
+
+
+def _check_http_version(http_version):
+ if not re.match(br"^HTTP/\d\.\d$", http_version):
+ raise HttpSyntaxException("Unknown HTTP version: {}".format(http_version))
+
+
+def _read_headers(rfile):
+ """
+ Read a set of headers.
+ Stop once a blank line is reached.
+
+ Returns:
+ A headers object
+
+ Raises:
+ HttpSyntaxException
+ """
+ ret = []
+ while True:
+ line = rfile.readline()
+ if not line or line == b"\r\n" or line == b"\n":
+ break
+ if line[0] in b" \t":
+ if not ret:
+ raise HttpSyntaxException("Invalid headers")
+ # continued header
+ ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip()
+ else:
+ try:
+ name, value = line.split(b":", 1)
+ value = value.strip()
+ if not name or not value:
+ raise ValueError()
+ ret.append([name, value])
+ except ValueError:
+ raise HttpSyntaxException("Invalid headers")
+ return Headers(ret)
+
+
+def _read_chunked(rfile, limit=sys.maxsize):
+ """
+ Read an HTTP body with chunked transfer encoding.
+
+ Args:
+ rfile: the input file
+ limit: A positive integer
+ """
+ total = 0
+ while True:
+ line = rfile.readline(128)
+ if line == b"":
+ raise HttpException("Connection closed prematurely")
+ if line != b"\r\n" and line != b"\n":
+ try:
+ length = int(line, 16)
+ except ValueError:
+ raise HttpSyntaxException("Invalid chunked encoding length: {}".format(line))
+ total += length
+ if total > limit:
+ raise HttpException(
+ "HTTP Body too large. Limit is {}, "
+ "chunked content longer than {}".format(limit, total)
+ )
+ chunk = rfile.read(length)
+ suffix = rfile.readline(5)
+ if suffix != b"\r\n":
+ raise HttpSyntaxException("Malformed chunked body")
+ if length == 0:
+ return
+ yield chunk
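read_body is driven entirely by expected_size; a short sketch covering the two common cases, a known Content-Length and chunked transfer encoding (where expected_size is None):

    from io import BytesIO
    from netlib.http.http1 import read_body

    # known content length
    body = b"".join(read_body(BytesIO(b"hello world"), expected_size=11))
    assert body == b"hello world"

    # chunked transfer encoding
    chunked = BytesIO(b"5\r\nhello\r\n6\r\n world\r\n0\r\n\r\n")
    assert b"".join(read_body(chunked, expected_size=None)) == b"hello world"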
diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py
index 5acf7696..7043d36f 100644
--- a/netlib/http/http2/__init__.py
+++ b/netlib/http/http2/__init__.py
@@ -1,2 +1,6 @@
-from frame import *
-from protocol import *
+from __future__ import absolute_import, print_function, division
+from .connections import HTTP2Protocol
+
+__all__ = [
+ "HTTP2Protocol"
+]
diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/connections.py
index b6d376d3..5220d5d2 100644
--- a/netlib/http/http2/protocol.py
+++ b/netlib/http/http2/connections.py
@@ -3,8 +3,8 @@ import itertools
import time
from hpack.hpack import Encoder, Decoder
-from netlib import http, utils
-from netlib.http import semantics
+from ... import utils
+from .. import Headers, Response, Request, ALPN_PROTO_H2
from . import frame
@@ -15,7 +15,7 @@ class TCPHandler(object):
self.wfile = wfile
-class HTTP2Protocol(semantics.ProtocolMixin):
+class HTTP2Protocol(object):
ERROR_CODES = utils.BiDi(
NO_ERROR=0x0,
@@ -36,8 +36,6 @@ class HTTP2Protocol(semantics.ProtocolMixin):
CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
- ALPN_PROTO_H2 = 'h2'
-
def __init__(
self,
tcp_handler=None,
@@ -62,6 +60,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
def read_request(
self,
+ __rfile,
include_body=True,
body_size_limit=None,
allow_empty=False,
@@ -111,7 +110,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
port = 80 if scheme == 'http' else 443
port = int(port)
- request = http.Request(
+ request = Request(
form_in,
method,
scheme,
@@ -131,6 +130,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
def read_response(
self,
+ __rfile,
request_method='',
body_size_limit=None,
include_body=True,
@@ -159,7 +159,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
else:
timestamp_end = None
- response = http.Response(
+ response = Response(
(2, 0),
int(headers.get(':status', 502)),
"",
@@ -172,8 +172,16 @@ class HTTP2Protocol(semantics.ProtocolMixin):
return response
+ def assemble(self, message):
+ if isinstance(message, Request):
+ return self.assemble_request(message)
+ elif isinstance(message, Response):
+ return self.assemble_response(message)
+ else:
+ raise ValueError("HTTP message not supported.")
+
def assemble_request(self, request):
- assert isinstance(request, semantics.Request)
+ assert isinstance(request, Request)
authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
if self.tcp_handler.address.port != 443:
@@ -200,7 +208,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
self._create_body(request.body, stream_id)))
def assemble_response(self, response):
- assert isinstance(response, semantics.Response)
+ assert isinstance(response, Response)
headers = response.headers.copy()
@@ -275,7 +283,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
def check_alpn(self):
alp = self.tcp_handler.get_alpn_proto_negotiated()
- if alp != self.ALPN_PROTO_H2:
+ if alp != ALPN_PROTO_H2:
raise NotImplementedError(
"HTTP2Protocol can not handle unknown ALP: %s" % alp)
return True
@@ -405,7 +413,7 @@ class HTTP2Protocol(semantics.ProtocolMixin):
else:
self._handle_unexpected_frame(frm)
- headers = http.Headers(
+ headers = Headers(
[[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)]
)
diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py
index b36b3adf..cb2cde99 100644
--- a/netlib/http/http2/frame.py
+++ b/netlib/http/http2/frame.py
@@ -1,12 +1,31 @@
-import sys
+from __future__ import absolute_import, print_function, division
import struct
from hpack.hpack import Encoder, Decoder
-from .. import utils
+from ...utils import BiDi
+from ...exceptions import HttpSyntaxException
-class FrameSizeError(Exception):
- pass
+ERROR_CODES = BiDi(
+ NO_ERROR=0x0,
+ PROTOCOL_ERROR=0x1,
+ INTERNAL_ERROR=0x2,
+ FLOW_CONTROL_ERROR=0x3,
+ SETTINGS_TIMEOUT=0x4,
+ STREAM_CLOSED=0x5,
+ FRAME_SIZE_ERROR=0x6,
+ REFUSED_STREAM=0x7,
+ CANCEL=0x8,
+ COMPRESSION_ERROR=0x9,
+ CONNECT_ERROR=0xa,
+ ENHANCE_YOUR_CALM=0xb,
+ INADEQUATE_SECURITY=0xc,
+ HTTP_1_1_REQUIRED=0xd
+)
+
+CLIENT_CONNECTION_PREFACE = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
+
+ALPN_PROTO_H2 = b'h2'
class Frame(object):
@@ -30,7 +49,9 @@ class Frame(object):
length=0,
flags=FLAG_NO_FLAGS,
stream_id=0x0):
- valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0)
+ valid_flags = 0
+ for flag in self.VALID_FLAGS:
+ valid_flags |= flag
if flags | valid_flags != valid_flags:
raise ValueError('invalid flags detected.')
@@ -61,7 +82,7 @@ class Frame(object):
SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE]
if length > max_frame_size:
- raise FrameSizeError(
+ raise HttpSyntaxException(
"Frame size exceeded: %d, but only %d allowed." % (
length, max_frame_size))
@@ -80,7 +101,7 @@ class Frame(object):
stream_id = fields[4]
if raw_header[:4] == b'HTTP': # pragma no cover
- print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!"
+ raise HttpSyntaxException("Expected HTTP2 Frame, got HTTP/1 connection")
cls._check_frame_size(length, state)
@@ -339,7 +360,7 @@ class SettingsFrame(Frame):
TYPE = 0x4
VALID_FLAGS = [Frame.FLAG_ACK]
- SETTINGS = utils.BiDi(
+ SETTINGS = BiDi(
SETTINGS_HEADER_TABLE_SIZE=0x1,
SETTINGS_ENABLE_PUSH=0x2,
SETTINGS_MAX_CONCURRENT_STREAMS=0x3,
@@ -366,7 +387,7 @@ class SettingsFrame(Frame):
def from_bytes(cls, state, length, flags, stream_id, payload):
f = cls(state=state, length=length, flags=flags, stream_id=stream_id)
- for i in xrange(0, len(payload), 6):
+ for i in range(0, len(payload), 6):
identifier, value = struct.unpack("!HL", payload[i:i + 6])
f.settings[identifier] = value
diff --git a/netlib/http/semantics.py b/netlib/http/models.py
index 5bb098a7..2d09535c 100644
--- a/netlib/http/semantics.py
+++ b/netlib/http/models.py
@@ -1,20 +1,28 @@
-from __future__ import (absolute_import, print_function, division)
-import UserDict
+from __future__ import absolute_import, print_function, division
import copy
-import urllib
-import urlparse
-from .. import odict
-from . import cookies, exceptions
-from netlib import utils, encoding
+from ..odict import ODict
+from .. import utils, encoding
+from ..utils import always_bytes, always_byte_args
+from . import cookies
-HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
-HDR_FORM_MULTIPART = "multipart/form-data"
+import six
+from six.moves import urllib
+try:
+ from collections import MutableMapping
+except ImportError:
+ from collections.abc import MutableMapping
+
+# TODO: Move somewhere else?
+ALPN_PROTO_HTTP1 = b'http/1.1'
+ALPN_PROTO_H2 = b'h2'
+HDR_FORM_URLENCODED = b"application/x-www-form-urlencoded"
+HDR_FORM_MULTIPART = b"multipart/form-data"
CONTENT_MISSING = 0
-class Headers(object, UserDict.DictMixin):
+class Headers(MutableMapping, object):
"""
Header class which allows both convenient access to individual headers as well as
direct access to the underlying raw data. Provides a full dictionary interface.
@@ -62,10 +70,12 @@ class Headers(object, UserDict.DictMixin):
For use with the "Set-Cookie" header, see :py:meth:`get_all`.
"""
+ @always_byte_args("ascii")
def __init__(self, fields=None, **headers):
"""
Args:
- fields: (optional) list of ``(name, value)`` header tuples, e.g. ``[("Host","example.com")]``
+ fields: (optional) list of ``(name, value)`` header tuples,
+ e.g. ``[("Host","example.com")]``. All names and values must be bytes.
**headers: Additional headers to set. Will overwrite existing values from `fields`.
For convenience, underscores in header names will be transformed to dashes -
this behaviour does not extend to other methods.
@@ -76,21 +86,25 @@ class Headers(object, UserDict.DictMixin):
# content_type -> content-type
headers = {
- name.replace("_", "-"): value
- for name, value in headers.iteritems()
+ name.encode("ascii").replace(b"_", b"-"): value
+ for name, value in six.iteritems(headers)
}
self.update(headers)
- def __str__(self):
- return "\r\n".join(": ".join(field) for field in self.fields) + "\r\n"
+ def __bytes__(self):
+ return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n"
+
+ if six.PY2:
+ __str__ = __bytes__
+ @always_byte_args("ascii")
def __getitem__(self, name):
values = self.get_all(name)
if not values:
raise KeyError(name)
- else:
- return ", ".join(values)
+ return b", ".join(values)
+ @always_byte_args("ascii")
def __setitem__(self, name, value):
idx = self._index(name)
@@ -101,6 +115,7 @@ class Headers(object, UserDict.DictMixin):
else:
self.fields.append([name, value])
+ @always_byte_args("ascii")
def __delitem__(self, name):
if name not in self:
raise KeyError(name)
@@ -110,6 +125,19 @@ class Headers(object, UserDict.DictMixin):
if name != field[0].lower()
]
+ def __iter__(self):
+ seen = set()
+ for name, _ in self.fields:
+ name_lower = name.lower()
+ if name_lower not in seen:
+ seen.add(name_lower)
+ yield name
+
+ def __len__(self):
+ return len(set(name.lower() for name, _ in self.fields))
+
+ #__hash__ = object.__hash__
+
def _index(self, name):
name = name.lower()
for i, field in enumerate(self.fields):
@@ -117,16 +145,6 @@ class Headers(object, UserDict.DictMixin):
return i
return None
- def keys(self):
- seen = set()
- names = []
- for name, _ in self.fields:
- name_lower = name.lower()
- if name_lower not in seen:
- seen.add(name_lower)
- names.append(name)
- return names
-
def __eq__(self, other):
if isinstance(other, Headers):
return self.fields == other.fields
@@ -135,6 +153,7 @@ class Headers(object, UserDict.DictMixin):
def __ne__(self, other):
return not self.__eq__(other)
+ @always_byte_args("ascii")
def get_all(self, name):
"""
Like :py:meth:`get`, but does not fold multiple headers into a single one.
@@ -142,8 +161,8 @@ class Headers(object, UserDict.DictMixin):
See also: https://tools.ietf.org/html/rfc7230#section-3.2.2
"""
- name = name.lower()
- values = [value for n, value in self.fields if n.lower() == name]
+ name_lower = name.lower()
+ values = [value for n, value in self.fields if n.lower() == name_lower]
return values
def set_all(self, name, values):
@@ -151,6 +170,8 @@ class Headers(object, UserDict.DictMixin):
Explicitly set multiple headers for the given key.
See: :py:meth:`get_all`
"""
+ name = always_bytes(name, "ascii")
+ values = (always_bytes(value, "ascii") for value in values)
if name in self:
del self[name]
self.fields.extend(
@@ -172,28 +193,6 @@ class Headers(object, UserDict.DictMixin):
return cls([list(field) for field in state])
-class ProtocolMixin(object):
- def read_request(self, *args, **kwargs): # pragma: no cover
- raise NotImplementedError
-
- def read_response(self, *args, **kwargs): # pragma: no cover
- raise NotImplementedError
-
- def assemble(self, message):
- if isinstance(message, Request):
- return self.assemble_request(message)
- elif isinstance(message, Response):
- return self.assemble_response(message)
- else:
- raise ValueError("HTTP message not supported.")
-
- def assemble_request(self, *args, **kwargs): # pragma: no cover
- raise NotImplementedError
-
- def assemble_response(self, *args, **kwargs): # pragma: no cover
- raise NotImplementedError
-
-
class Request(object):
# This list is adopted legacy code.
# We probably don't need to strip off keep-alive.
@@ -248,42 +247,14 @@ class Request(object):
return False
def __repr__(self):
- # return "Request(%s - %s, %s)" % (self.method, self.host, self.path)
-
- return "<HTTPRequest: {0}>".format(
- self.legacy_first_line()[:-9]
- )
-
- def legacy_first_line(self, form=None):
- if form is None:
- form = self.form_out
- if form == "relative":
- return '%s %s HTTP/%s.%s' % (
- self.method,
- self.path,
- self.httpversion[0],
- self.httpversion[1],
- )
- elif form == "authority":
- return '%s %s:%s HTTP/%s.%s' % (
- self.method,
- self.host,
- self.port,
- self.httpversion[0],
- self.httpversion[1],
- )
- elif form == "absolute":
- return '%s %s://%s:%s%s HTTP/%s.%s' % (
- self.method,
- self.scheme,
- self.host,
- self.port,
- self.path,
- self.httpversion[0],
- self.httpversion[1],
- )
+ if self.host and self.port:
+ hostport = "{}:{}".format(self.host, self.port)
else:
- raise exceptions.HttpError(400, "Invalid request form")
+ hostport = ""
+ path = self.path or ""
+ return "HTTPRequest({} {}{})".format(
+ self.method, hostport, path
+ )
def anticache(self):
"""
@@ -336,7 +307,7 @@ class Request(object):
return self.get_form_urlencoded()
elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower():
return self.get_form_multipart()
- return odict.ODict([])
+ return ODict([])
def get_form_urlencoded(self):
"""
@@ -345,16 +316,16 @@ class Request(object):
indicates non-form data.
"""
if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower():
- return odict.ODict(utils.urldecode(self.body))
- return odict.ODict([])
+ return ODict(utils.urldecode(self.body))
+ return ODict([])
def get_form_multipart(self):
if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower():
- return odict.ODict(
+ return ODict(
utils.multipartdecode(
self.headers,
self.body))
- return odict.ODict([])
+ return ODict([])
def set_form_urlencoded(self, odict):
"""
@@ -373,8 +344,8 @@ class Request(object):
Components are unquoted.
"""
- _, _, path, _, _, _ = urlparse.urlparse(self.url)
- return [urllib.unquote(i) for i in path.split("/") if i]
+ _, _, path, _, _, _ = urllib.parse.urlparse(self.url)
+ return [urllib.parse.unquote(i) for i in path.split(b"/") if i]
def set_path_components(self, lst):
"""
@@ -382,10 +353,10 @@ class Request(object):
Components are quoted.
"""
- lst = [urllib.quote(i, safe="") for i in lst]
- path = "/" + "/".join(lst)
- scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.url)
- self.url = urlparse.urlunparse(
+ lst = [urllib.parse.quote(i, safe="") for i in lst]
+ path = b"/" + b"/".join(lst)
+ scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url)
+ self.url = urllib.parse.urlunparse(
[scheme, netloc, path, params, query, fragment]
)
@@ -393,18 +364,18 @@ class Request(object):
"""
Gets the request query string. Returns an ODict object.
"""
- _, _, _, _, query, _ = urlparse.urlparse(self.url)
+ _, _, _, _, query, _ = urllib.parse.urlparse(self.url)
if query:
- return odict.ODict(utils.urldecode(query))
- return odict.ODict([])
+ return ODict(utils.urldecode(query))
+ return ODict([])
def set_query(self, odict):
"""
Takes an ODict object, and sets the request query string.
"""
- scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.url)
+ scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)
query = utils.urlencode(odict.lst)
- self.url = urlparse.urlunparse(
+ self.url = urllib.parse.urlunparse(
[scheme, netloc, path, params, query, fragment]
)
@@ -421,18 +392,13 @@ class Request(object):
but not the resolved name. This is disabled by default, as an
attacker may spoof the host header to confuse an analyst.
"""
- host = None
- if hostheader:
- host = self.headers.get("Host")
- if not host:
- host = self.host
- if host:
+ if hostheader and b"Host" in self.headers:
try:
- return host.encode("idna")
+ return self.headers[b"Host"].decode("idna")
except ValueError:
- return host
- else:
- return None
+ pass
+ if self.host:
+ return self.host.decode("idna")
def pretty_url(self, hostheader):
if self.form_out == "authority": # upstream proxy mode
@@ -446,7 +412,7 @@ class Request(object):
"""
Returns a possibly empty netlib.odict.ODict object.
"""
- ret = odict.ODict()
+ ret = ODict()
for i in self.headers.get_all("cookie"):
ret.extend(cookies.parse_cookie_header(i))
return ret
@@ -477,8 +443,10 @@ class Request(object):
Parses a URL specification, and updates the Request's information
accordingly.
- Returns False if the URL was invalid, True if the request succeeded.
+ Raises:
+ ValueError if the URL was invalid
"""
+ # TODO: Should handle incoming unicode here.
parts = utils.parse_url(url)
if not parts:
raise ValueError("Invalid URL: %s" % url)
@@ -495,32 +463,6 @@ class Request(object):
self.body = content
-class EmptyRequest(Request):
- def __init__(
- self,
- form_in="",
- method="",
- scheme="",
- host="",
- port="",
- path="",
- httpversion=(0, 0),
- headers=None,
- body=""
- ):
- super(EmptyRequest, self).__init__(
- form_in=form_in,
- method=method,
- scheme=scheme,
- host=host,
- port=port,
- path=path,
- httpversion=httpversion,
- headers=headers,
- body=body,
- )
-
-
class Response(object):
_headers_to_strip_off = [
'Proxy-Connection',
@@ -535,7 +477,6 @@ class Response(object):
msg=None,
headers=None,
body=None,
- sslinfo=None,
timestamp_start=None,
timestamp_end=None,
):
@@ -548,7 +489,6 @@ class Response(object):
self.msg = msg
self.headers = headers
self.body = body
- self.sslinfo = sslinfo
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
@@ -591,7 +531,7 @@ class Response(object):
if v:
name, value, attrs = v
ret.append([name, [value, attrs]])
- return odict.ODict(ret)
+ return ODict(ret)
def set_cookies(self, odict):
"""
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 4a7f6153..1eb417b4 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -834,14 +834,14 @@ class TCPServer(object):
# If a thread has persisted after interpreter exit, the module might be
# none.
if traceback:
- exc = traceback.format_exc()
- print('-' * 40, file=fp)
+ exc = six.text_type(traceback.format_exc())
+ print(u'-' * 40, file=fp)
print(
- "Error in processing of request from %s:%s" % (
+ u"Error in processing of request from %s:%s" % (
client_address.host, client_address.port
), file=fp)
print(exc, file=fp)
- print('-' * 40, file=fp)
+ print(u'-' * 40, file=fp)
def handle_client_connection(self, conn, client_address): # pragma: no cover
"""
diff --git a/netlib/tutils.py b/netlib/tutils.py
index 951ef3d9..05791c49 100644
--- a/netlib/tutils.py
+++ b/netlib/tutils.py
@@ -1,18 +1,22 @@
-import cStringIO
+from io import BytesIO
import tempfile
import os
import time
import shutil
from contextlib import contextmanager
+import six
+import sys
-from netlib import tcp, utils, http
+from . import utils
+from .http import Request, Response, Headers
def treader(bytes):
"""
Construct a tcp.Read object from bytes.
"""
- fp = cStringIO.StringIO(bytes)
+ from . import tcp # TODO: move to top once cryptography is on Python 3.5
+ fp = BytesIO(bytes)
return tcp.Reader(fp)
@@ -28,7 +32,24 @@ def tmpdir(*args, **kwargs):
shutil.rmtree(temp_workdir)
-def raises(exc, obj, *args, **kwargs):
+def _check_exception(expected, actual, exc_tb):
+ if isinstance(expected, six.string_types):
+ if expected.lower() not in str(actual).lower():
+ six.reraise(AssertionError, AssertionError(
+ "Expected %s, but caught %s" % (
+ repr(expected), repr(actual)
+ )
+ ), exc_tb)
+ else:
+ if not isinstance(actual, expected):
+ six.reraise(AssertionError, AssertionError(
+ "Expected %s, but caught %s %s" % (
+ expected.__name__, actual.__class__.__name__, repr(actual)
+ )
+ ), exc_tb)
+
+
+def raises(expected_exception, obj=None, *args, **kwargs):
"""
Assert that a callable raises a specified exception.
@@ -43,81 +64,68 @@ def raises(exc, obj, *args, **kwargs):
:kwargs Arguments to be passed to the callable.
"""
- try:
- ret = obj(*args, **kwargs)
- except Exception as v:
- if isinstance(exc, basestring):
- if exc.lower() in str(v).lower():
- return
- else:
- raise AssertionError(
- "Expected %s, but caught %s" % (
- repr(str(exc)), v
- )
- )
+ if obj is None:
+ return RaisesContext(expected_exception)
+ else:
+ try:
+ ret = obj(*args, **kwargs)
+ except Exception as actual:
+ _check_exception(expected_exception, actual, sys.exc_info()[2])
else:
- if isinstance(v, exc):
- return
- else:
- raise AssertionError(
- "Expected %s, but caught %s %s" % (
- exc.__name__, v.__class__.__name__, str(v)
- )
- )
- raise AssertionError("No exception raised. Return value: {}".format(ret))
+ raise AssertionError("No exception raised. Return value: {}".format(ret))
-test_data = utils.Data(__name__)
+class RaisesContext(object):
+ def __init__(self, expected_exception):
+ self.expected_exception = expected_exception
-def treq(content="content", scheme="http", host="address", port=22):
- """
- @return: libmproxy.protocol.http.HTTPRequest
- """
- headers = http.Headers()
- headers["header"] = "qvalue"
- req = http.Request(
- "relative",
- "GET",
- scheme,
- host,
- port,
- "/path",
- (1, 1),
- headers,
- content,
- None,
- None,
- )
- return req
+ def __enter__(self):
+ return
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if not exc_type:
+ raise AssertionError("No exception raised.")
+ else:
+ _check_exception(self.expected_exception, exc_val, exc_tb)
+ return True
-def treq_absolute(content="content"):
- """
- @return: libmproxy.protocol.http.HTTPRequest
- """
- r = treq(content)
- r.form_in = r.form_out = "absolute"
- r.host = "address"
- r.port = 22
- r.scheme = "http"
- return r
+test_data = utils.Data(__name__)
-def tresp(content="message"):
+
+def treq(**kwargs):
"""
- @return: libmproxy.protocol.http.HTTPResponse
+ Returns:
+ netlib.http.Request
"""
+ default = dict(
+ form_in="relative",
+ method=b"GET",
+ scheme=b"http",
+ host=b"address",
+ port=22,
+ path=b"/path",
+ httpversion=b"HTTP/1.1",
+ headers=Headers(header=b"qvalue"),
+ body=b"content"
+ )
+ default.update(kwargs)
+ return Request(**default)
- headers = http.Headers()
- headers["header_response"] = "svalue"
- resp = http.semantics.Response(
- (1, 1),
- 200,
- "OK",
- headers,
- content,
+def tresp(**kwargs):
+ """
+ Returns:
+ netlib.http.Response
+ """
+ default = dict(
+ httpversion=b"HTTP/1.1",
+ status_code=200,
+ msg=b"OK",
+ headers=Headers(header_response=b"svalue"),
+ body=b"message",
timestamp_start=time.time(),
- timestamp_end=time.time(),
+ timestamp_end=time.time()
)
- return resp
+ default.update(kwargs)
+ return Response(**default)
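raises() keeps the legacy call form and gains a context-manager form; both also accept a substring of the expected error message instead of an exception class. Sketch:

    from netlib import tutils

    # legacy call form
    tutils.raises(ValueError, int, "not a number")

    # new context-manager form
    with tutils.raises(ValueError):
        int("not a number")

    # matching on a message substring
    with tutils.raises("invalid literal"):
        int("not a number")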
diff --git a/netlib/utils.py b/netlib/utils.py
index d6774419..a86b8019 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -1,17 +1,17 @@
-from __future__ import (absolute_import, print_function, division)
+from __future__ import absolute_import, print_function, division
import os.path
-import cgi
-import urllib
-import urlparse
-import string
import re
-import six
+import string
import unicodedata
+import six
+
+from six.moves import urllib
+
-def isascii(s):
+def isascii(bytes):
try:
- s.decode("ascii")
+ bytes.decode("ascii")
except ValueError:
return False
return True
@@ -40,12 +40,12 @@ def clean_bin(s, keep_spacing=True):
)
else:
if keep_spacing:
- keep = b"\n\r\t"
+ keep = (9, 10, 13) # \t, \n, \r,
else:
- keep = b""
+ keep = ()
return b"".join(
- ch if (31 < ord(ch) < 127 or ch in keep) else b"."
- for ch in s
+ six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"."
+ for ch in six.iterbytes(s)
)
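Note: the clean_bin() change reflects the core bytes-iteration difference between the two Python lines: iterating a bytes object yields one-character strings on Python 2 but integers on Python 3, so six.iterbytes()/six.int2byte() are used to work with integer code points portably. A standalone illustration of that pattern (not part of netlib itself):

    import six

    data = b"a\tb\x00c"
    assert list(six.iterbytes(data)) == [97, 9, 98, 0, 99]

    # Same filtering idea as clean_bin: keep printable ASCII and whitespace
    # control characters, replace everything else with a dot.
    cleaned = b"".join(
        six.int2byte(ch) if (31 < ch < 127 or ch in (9, 10, 13)) else b"."
        for ch in six.iterbytes(data)
    )
    assert cleaned == b"a\tb.c"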
@@ -149,10 +149,7 @@ class Data(object):
return fullpath
-def is_valid_port(port):
- if not 0 <= port <= 65535:
- return False
- return True
+_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
def is_valid_host(host):
@@ -160,53 +157,79 @@ def is_valid_host(host):
host.decode("idna")
except ValueError:
return False
- if "\0" in host:
- return None
- return True
+ if len(host) > 255:
+ return False
+ if host[-1] == ".":
+ host = host[:-1]
+ return all(_label_valid.match(x) for x in host.split(b"."))
+
+
+def is_valid_port(port):
+ return 0 <= port <= 65535
+
+
+# PY2 workaround
+def decode_parse_result(result, enc):
+ if hasattr(result, "decode"):
+ return result.decode(enc)
+ else:
+ return urllib.parse.ParseResult(*[x.decode(enc) for x in result])
+
+
+# PY2 workaround
+def encode_parse_result(result, enc):
+ if hasattr(result, "encode"):
+ return result.encode(enc)
+ else:
+ return urllib.parse.ParseResult(*[x.encode(enc) for x in result])
def parse_url(url):
"""
- Returns a (scheme, host, port, path) tuple, or None on error.
+ URL-parsing function that checks that
+ - port is an integer 0-65535
+ - host is a valid IDNA-encoded hostname with no null-bytes
+ - path is valid ASCII
- Checks that:
- port is an integer 0-65535
- host is a valid IDNA-encoded hostname with no null-bytes
- path is valid ASCII
+ Args:
+ A URL (as bytes or as unicode)
+
+ Returns:
+ A (scheme, host, port, path) tuple
+
+ Raises:
+ ValueError, if the URL is not properly formatted.
"""
- try:
- scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
- except ValueError:
- return None
- if not scheme:
- return None
- if '@' in netloc:
- # FIXME: Consider what to do with the discarded credentials here Most
- # probably we should extend the signature to return these as a separate
- # value.
- _, netloc = string.rsplit(netloc, '@', maxsplit=1)
- if ':' in netloc:
- host, port = string.rsplit(netloc, ':', maxsplit=1)
- try:
- port = int(port)
- except ValueError:
- return None
+ parsed = urllib.parse.urlparse(url)
+
+ if not parsed.hostname:
+ raise ValueError("No hostname given")
+
+ if isinstance(url, six.binary_type):
+ host = parsed.hostname
+
+ # this should not raise a ValueError
+ decode_parse_result(parsed, "ascii")
else:
- host = netloc
- if scheme.endswith("https"):
- port = 443
- else:
- port = 80
- path = urlparse.urlunparse(('', '', path, params, query, fragment))
- if not path.startswith("/"):
- path = "/" + path
+ host = parsed.hostname.encode("idna")
+ parsed = encode_parse_result(parsed, "ascii")
+
+ port = parsed.port
+ if not port:
+ port = 443 if parsed.scheme == b"https" else 80
+
+ full_path = urllib.parse.urlunparse(
+ (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment)
+ )
+ if not full_path.startswith(b"/"):
+ full_path = b"/" + full_path
+
if not is_valid_host(host):
- return None
- if not isascii(path):
- return None
+ raise ValueError("Invalid Host")
if not is_valid_port(port):
- return None
- return scheme, host, port, path
+ raise ValueError("Invalid Port")
+
+ return parsed.scheme, host, port, full_path
def get_header_tokens(headers, key):
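Note: as the new docstring spells out, parse_url() now raises ValueError on malformed input instead of returning None, and for bytes input it returns bytes components with the default port filled in from the scheme. A short sketch of the new contract, mirroring the updated tests in test_utils.py further below:

    from netlib import utils

    scheme, host, port, path = utils.parse_url(b"https://example.com/foo?x=1")
    assert (scheme, host, port, path) == (b"https", b"example.com", 443, b"/foo?x=1")
    assert utils.is_valid_host(b"example.com")

    # Errors now surface as exceptions rather than a None return value.
    try:
        utils.parse_url("")
    except ValueError:
        pass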
@@ -217,7 +240,7 @@ def get_header_tokens(headers, key):
"""
if key not in headers:
return []
- tokens = headers[key].split(",")
+ tokens = headers[key].split(b",")
return [token.strip() for token in tokens]
@@ -228,7 +251,7 @@ def hostport(scheme, host, port):
if (port, scheme) in [(80, "http"), (443, "https")]:
return host
else:
- return "%s:%s" % (host, port)
+ return b"%s:%d" % (host, port)
def unparse_url(scheme, host, port, path=""):
@@ -243,14 +266,14 @@ def urlencode(s):
Takes a list of (key, value) tuples and returns a urlencoded string.
"""
s = [tuple(i) for i in s]
- return urllib.urlencode(s, False)
+ return urllib.parse.urlencode(s, False)
def urldecode(s):
"""
Takes a urlencoded string and returns a list of (key, value) tuples.
"""
- return cgi.parse_qsl(s, keep_blank_values=True)
+ return urllib.parse.parse_qsl(s, keep_blank_values=True)
def parse_content_type(c):
@@ -267,14 +290,14 @@ def parse_content_type(c):
("text", "html", {"charset": "UTF-8"})
"""
- parts = c.split(";", 1)
- ts = parts[0].split("/", 1)
+ parts = c.split(b";", 1)
+ ts = parts[0].split(b"/", 1)
if len(ts) != 2:
return None
d = {}
if len(parts) == 2:
- for i in parts[1].split(";"):
- clause = i.split("=", 1)
+ for i in parts[1].split(b";"):
+ clause = i.split(b"=", 1)
if len(clause) == 2:
d[clause[0].strip()] = clause[1].strip()
return ts[0].lower(), ts[1].lower(), d
@@ -289,7 +312,7 @@ def multipartdecode(headers, content):
v = parse_content_type(v)
if not v:
return []
- boundary = v[2].get("boundary")
+ boundary = v[2].get(b"boundary")
if not boundary:
return []
@@ -306,3 +329,20 @@ def multipartdecode(headers, content):
r.append((key, value))
return r
return []
+
+
+def always_bytes(unicode_or_bytes, encoding):
+ if isinstance(unicode_or_bytes, six.text_type):
+ return unicode_or_bytes.encode(encoding)
+ return unicode_or_bytes
+
+
+def always_byte_args(encoding):
+ """Decorator that transparently encodes all arguments passed as unicode"""
+ def decorator(fun):
+ def _fun(*args, **kwargs):
+ args = [always_bytes(arg, encoding) for arg in args]
+ kwargs = {k: always_bytes(v, encoding) for k, v in six.iteritems(kwargs)}
+ return fun(*args, **kwargs)
+ return _fun
+ return decorator
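Note: always_bytes() and the always_byte_args() decorator centralize the text-to-bytes coercion the Python 3 port needs at API boundaries: any unicode argument is encoded, bytes pass through untouched. A hypothetical use of the decorator (build_header_line is an illustrative name, not something defined in this diff):

    from netlib.utils import always_byte_args

    @always_byte_args("utf-8")
    def build_header_line(name, value):
        # Both arguments are guaranteed to be bytes here, even if the
        # caller passed unicode strings.
        return name + b": " + value + b"\r\n"

    assert build_header_line(u"Host", u"example.com") == b"Host: example.com\r\n"
    assert build_header_line(b"Host", b"example.com") == b"Host: example.com\r\n"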
diff --git a/netlib/version_check.py b/netlib/version_check.py
index 1d7e025c..9cf27eea 100644
--- a/netlib/version_check.py
+++ b/netlib/version_check.py
@@ -7,6 +7,7 @@ from __future__ import division, absolute_import, print_function
import sys
import inspect
import os.path
+import six
import OpenSSL
from . import version
@@ -19,8 +20,8 @@ def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr):
# consider major and minor version.
if version.IVERSION[:2] != mitmproxy_version[:2]:
print(
- "You are using mitmproxy %s with netlib %s. "
- "Most likely, that won't work - please upgrade!" % (
+ u"You are using mitmproxy %s with netlib %s. "
+ u"Most likely, that won't work - please upgrade!" % (
mitmproxy_version, version.VERSION
),
file=fp
@@ -29,13 +30,13 @@ def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr):
def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr):
- min_version_str = ".".join(str(x) for x in min_version)
+ min_version_str = u".".join(six.text_type(x) for x in min_version)
try:
v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2])
except ValueError:
print(
- "Cannot parse pyOpenSSL version: {}"
- "mitmproxy requires pyOpenSSL {} or greater.".format(
+ u"Cannot parse pyOpenSSL version: {}"
+ u"mitmproxy requires pyOpenSSL {} or greater.".format(
OpenSSL.__version__, min_version_str
),
file=fp
@@ -43,15 +44,15 @@ def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr):
return
if v < min_version:
print(
- "You are using an outdated version of pyOpenSSL: "
- "mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str),
+ u"You are using an outdated version of pyOpenSSL: "
+ u"mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str),
file=fp
)
# Some users apparently have multiple versions of pyOpenSSL installed.
# Report which one we got.
pyopenssl_path = os.path.dirname(inspect.getfile(OpenSSL))
print(
- "Your pyOpenSSL {} installation is located at {}".format(
+ u"Your pyOpenSSL {} installation is located at {}".format(
OpenSSL.__version__, pyopenssl_path
),
file=fp
diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py
index 5acf7696..1c143919 100644
--- a/netlib/websockets/__init__.py
+++ b/netlib/websockets/__init__.py
@@ -1,2 +1,2 @@
-from frame import *
-from protocol import *
+from .frame import *
+from .protocol import *
diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py
new file mode 100644
index 00000000..8a0a54f1
--- /dev/null
+++ b/test/http/http1/test_assemble.py
@@ -0,0 +1,91 @@
+from __future__ import absolute_import, print_function, division
+from netlib.exceptions import HttpException
+from netlib.http import CONTENT_MISSING, Headers
+from netlib.http.http1.assemble import (
+ assemble_request, assemble_request_head, assemble_response,
+ assemble_response_head, _assemble_request_line, _assemble_request_headers,
+ _assemble_response_headers
+)
+from netlib.tutils import treq, raises, tresp
+
+
+def test_assemble_request():
+ assert assemble_request(treq()) == (
+ b"GET /path HTTP/1.1\r\n"
+ b"header: qvalue\r\n"
+ b"Host: address:22\r\n"
+ b"Content-Length: 7\r\n"
+ b"\r\n"
+ b"content"
+ )
+
+ with raises(HttpException):
+ assemble_request(treq(body=CONTENT_MISSING))
+
+
+def test_assemble_request_head():
+ c = assemble_request_head(treq())
+ assert b"GET" in c
+ assert b"qvalue" in c
+ assert b"content" not in c
+
+
+def test_assemble_response():
+ assert assemble_response(tresp()) == (
+ b"HTTP/1.1 200 OK\r\n"
+ b"header-response: svalue\r\n"
+ b"Content-Length: 7\r\n"
+ b"\r\n"
+ b"message"
+ )
+
+ with raises(HttpException):
+ assemble_response(tresp(body=CONTENT_MISSING))
+
+
+def test_assemble_response_head():
+ c = assemble_response_head(tresp())
+ assert b"200" in c
+ assert b"svalue" in c
+ assert b"message" not in c
+
+
+def test_assemble_request_line():
+ assert _assemble_request_line(treq()) == b"GET /path HTTP/1.1"
+
+ authority_request = treq(method=b"CONNECT", form_in="authority")
+ assert _assemble_request_line(authority_request) == b"CONNECT address:22 HTTP/1.1"
+
+ absolute_request = treq(form_in="absolute")
+ assert _assemble_request_line(absolute_request) == b"GET http://address:22/path HTTP/1.1"
+
+ with raises(RuntimeError):
+ _assemble_request_line(treq(), "invalid_form")
+
+
+def test_assemble_request_headers():
+ # https://github.com/mitmproxy/mitmproxy/issues/186
+ r = treq(body=b"")
+ r.headers[b"Transfer-Encoding"] = b"chunked"
+ c = _assemble_request_headers(r)
+ assert b"Content-Length" in c
+ assert b"Transfer-Encoding" not in c
+
+ assert b"Host" in _assemble_request_headers(treq(headers=Headers()))
+
+ assert b"Proxy-Connection" not in _assemble_request_headers(
+ treq(headers=Headers(Proxy_Connection="42"))
+ )
+
+
+def test_assemble_response_headers():
+ # https://github.com/mitmproxy/mitmproxy/issues/186
+ r = tresp(body=b"")
+ r.headers["Transfer-Encoding"] = b"chunked"
+ c = _assemble_response_headers(r)
+ assert b"Content-Length" in c
+ assert b"Transfer-Encoding" not in c
+
+ assert b"Proxy-Connection" not in _assemble_response_headers(
+ tresp(headers=Headers(Proxy_Connection=b"42"))
+ )
diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py
index f7c615bd..e69de29b 100644
--- a/test/http/http1/test_protocol.py
+++ b/test/http/http1/test_protocol.py
@@ -1,497 +0,0 @@
-import cStringIO
-import textwrap
-
-from netlib import http, odict, tcp, tutils
-from netlib.http import semantics, Headers
-from netlib.http.http1 import HTTP1Protocol
-from ... import tservers
-
-
-class NoContentLengthHTTPHandler(tcp.BaseHandler):
- def handle(self):
- self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n")
- self.wfile.flush()
-
-
-def mock_protocol(data=''):
- rfile = cStringIO.StringIO(data)
- wfile = cStringIO.StringIO()
- return HTTP1Protocol(rfile=rfile, wfile=wfile)
-
-
-def match_http_string(data):
- return textwrap.dedent(data).strip().replace('\n', '\r\n')
-
-
-def test_stripped_chunked_encoding_no_content():
- """
- https://github.com/mitmproxy/mitmproxy/issues/186
- """
-
- r = tutils.treq(content="")
- r.headers["Transfer-Encoding"] = "chunked"
- assert "Content-Length" in mock_protocol()._assemble_request_headers(r)
-
- r = tutils.tresp(content="")
- r.headers["Transfer-Encoding"] = "chunked"
- assert "Content-Length" in mock_protocol()._assemble_response_headers(r)
-
-
-def test_has_chunked_encoding():
- headers = http.Headers()
- assert not HTTP1Protocol.has_chunked_encoding(headers)
- headers["transfer-encoding"] = "chunked"
- assert HTTP1Protocol.has_chunked_encoding(headers)
-
-
-def test_read_chunked():
- headers = http.Headers()
- headers["transfer-encoding"] = "chunked"
-
- data = "1\r\na\r\n0\r\n"
- tutils.raises(
- "malformed chunked body",
- mock_protocol(data).read_http_body,
- headers, None, "GET", None, True
- )
-
- data = "1\r\na\r\n0\r\n\r\n"
- assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "a"
-
- data = "\r\n\r\n1\r\na\r\n0\r\n\r\n"
- assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "a"
-
- data = "\r\n"
- tutils.raises(
- "closed prematurely",
- mock_protocol(data).read_http_body,
- headers, None, "GET", None, True
- )
-
- data = "1\r\nfoo"
- tutils.raises(
- "malformed chunked body",
- mock_protocol(data).read_http_body,
- headers, None, "GET", None, True
- )
-
- data = "foo\r\nfoo"
- tutils.raises(
- http.HttpError,
- mock_protocol(data).read_http_body,
- headers, None, "GET", None, True
- )
-
- data = "5\r\naaaaa\r\n0\r\n\r\n"
- tutils.raises("too large", mock_protocol(data).read_http_body, headers, 2, "GET", None, True)
-
-
-def test_connection_close():
- headers = Headers()
- assert HTTP1Protocol.connection_close((1, 0), headers)
- assert not HTTP1Protocol.connection_close((1, 1), headers)
-
- headers["connection"] = "keep-alive"
- assert not HTTP1Protocol.connection_close((1, 1), headers)
-
- headers["connection"] = "close"
- assert HTTP1Protocol.connection_close((1, 1), headers)
-
-
-def test_read_http_body_request():
- headers = Headers()
- data = "testing"
- assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == ""
-
-
-def test_read_http_body_response():
- headers = Headers()
- data = "testing"
- assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing"
-
-
-def test_read_http_body():
- # test default case
- headers = Headers()
- headers["content-length"] = "7"
- data = "testing"
- assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing"
-
- # test content length: invalid header
- headers["content-length"] = "foo"
- data = "testing"
- tutils.raises(
- http.HttpError,
- mock_protocol(data).read_http_body,
- headers, None, "GET", 200, False
- )
-
- # test content length: invalid header #2
- headers["content-length"] = "-1"
- data = "testing"
- tutils.raises(
- http.HttpError,
- mock_protocol(data).read_http_body,
- headers, None, "GET", 200, False
- )
-
- # test content length: content length > actual content
- headers["content-length"] = "5"
- data = "testing"
- tutils.raises(
- http.HttpError,
- mock_protocol(data).read_http_body,
- headers, 4, "GET", 200, False
- )
-
- # test content length: content length < actual content
- data = "testing"
- assert len(mock_protocol(data).read_http_body(headers, None, "GET", 200, False)) == 5
-
- # test no content length: limit > actual content
- headers = Headers()
- data = "testing"
- assert len(mock_protocol(data).read_http_body(headers, 100, "GET", 200, False)) == 7
-
- # test no content length: limit < actual content
- data = "testing"
- tutils.raises(
- http.HttpError,
- mock_protocol(data).read_http_body,
- headers, 4, "GET", 200, False
- )
-
- # test chunked
- headers = Headers()
- headers["transfer-encoding"] = "chunked"
- data = "5\r\naaaaa\r\n0\r\n\r\n"
- assert mock_protocol(data).read_http_body(headers, 100, "GET", 200, False) == "aaaaa"
-
-
-def test_expected_http_body_size():
- # gibber in the content-length field
- headers = Headers(content_length="foo")
- assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) is None
- # negative number in the content-length field
- headers = Headers(content_length="-7")
- assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) is None
- # explicit length
- headers = Headers(content_length="5")
- assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) == 5
- # no length
- headers = Headers()
- assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) == -1
- # no length request
- headers = Headers()
- assert HTTP1Protocol.expected_http_body_size(headers, True, "GET", None) == 0
-
-
-def test_get_request_line():
- data = "\nfoo"
- p = mock_protocol(data)
- assert p._get_request_line() == "foo"
- assert not p._get_request_line()
-
-
-def test_parse_http_protocol():
- assert HTTP1Protocol._parse_http_protocol("HTTP/1.1") == (1, 1)
- assert HTTP1Protocol._parse_http_protocol("HTTP/0.0") == (0, 0)
- assert not HTTP1Protocol._parse_http_protocol("HTTP/a.1")
- assert not HTTP1Protocol._parse_http_protocol("HTTP/1.a")
- assert not HTTP1Protocol._parse_http_protocol("foo/0.0")
- assert not HTTP1Protocol._parse_http_protocol("HTTP/x")
-
-
-def test_parse_init_connect():
- assert HTTP1Protocol._parse_init_connect("CONNECT host.com:443 HTTP/1.0")
- assert not HTTP1Protocol._parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0")
- assert not HTTP1Protocol._parse_init_connect("CONNECT \0host.com:443 HTTP/1.0")
- assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:444444 HTTP/1.0")
- assert not HTTP1Protocol._parse_init_connect("bogus")
- assert not HTTP1Protocol._parse_init_connect("GET host.com:443 HTTP/1.0")
- assert not HTTP1Protocol._parse_init_connect("CONNECT host.com443 HTTP/1.0")
- assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:443 foo/1.0")
- assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:foo HTTP/1.0")
-
-
-def test_parse_init_proxy():
- u = "GET http://foo.com:8888/test HTTP/1.1"
- m, s, h, po, pa, httpversion = HTTP1Protocol._parse_init_proxy(u)
- assert m == "GET"
- assert s == "http"
- assert h == "foo.com"
- assert po == 8888
- assert pa == "/test"
- assert httpversion == (1, 1)
-
- u = "G\xfeET http://foo.com:8888/test HTTP/1.1"
- assert not HTTP1Protocol._parse_init_proxy(u)
-
- assert not HTTP1Protocol._parse_init_proxy("invalid")
- assert not HTTP1Protocol._parse_init_proxy("GET invalid HTTP/1.1")
- assert not HTTP1Protocol._parse_init_proxy("GET http://foo.com:8888/test foo/1.1")
-
-
-def test_parse_init_http():
- u = "GET /test HTTP/1.1"
- m, u, httpversion = HTTP1Protocol._parse_init_http(u)
- assert m == "GET"
- assert u == "/test"
- assert httpversion == (1, 1)
-
- u = "G\xfeET /test HTTP/1.1"
- assert not HTTP1Protocol._parse_init_http(u)
-
- assert not HTTP1Protocol._parse_init_http("invalid")
- assert not HTTP1Protocol._parse_init_http("GET invalid HTTP/1.1")
- assert not HTTP1Protocol._parse_init_http("GET /test foo/1.1")
- assert not HTTP1Protocol._parse_init_http("GET /test\xc0 HTTP/1.1")
-
-
-class TestReadHeaders:
-
- def _read(self, data, verbatim=False):
- if not verbatim:
- data = textwrap.dedent(data)
- data = data.strip()
- return mock_protocol(data).read_headers()
-
- def test_read_simple(self):
- data = """
- Header: one
- Header2: two
- \r\n
- """
- headers = self._read(data)
- assert headers.fields == [["Header", "one"], ["Header2", "two"]]
-
- def test_read_multi(self):
- data = """
- Header: one
- Header: two
- \r\n
- """
- headers = self._read(data)
- assert headers.fields == [["Header", "one"], ["Header", "two"]]
-
- def test_read_continued(self):
- data = """
- Header: one
- \ttwo
- Header2: three
- \r\n
- """
- headers = self._read(data)
- assert headers.fields == [["Header", "one\r\n two"], ["Header2", "three"]]
-
- def test_read_continued_err(self):
- data = "\tfoo: bar\r\n"
- assert self._read(data, True) is None
-
- def test_read_err(self):
- data = """
- foo
- """
- assert self._read(data) is None
-
-
-class TestReadRequest(object):
-
- def tst(self, data, **kwargs):
- return mock_protocol(data).read_request(**kwargs)
-
- def test_invalid(self):
- tutils.raises(
- "bad http request",
- self.tst,
- "xxx"
- )
- tutils.raises(
- "bad http request line",
- self.tst,
- "get /\xff HTTP/1.1"
- )
- tutils.raises(
- "invalid headers",
- self.tst,
- "get / HTTP/1.1\r\nfoo"
- )
- tutils.raises(
- tcp.NetLibDisconnect,
- self.tst,
- "\r\n"
- )
-
- def test_empty(self):
- v = self.tst("", allow_empty=True)
- assert isinstance(v, semantics.EmptyRequest)
-
- def test_asterisk_form_in(self):
- v = self.tst("OPTIONS * HTTP/1.1")
- assert v.form_in == "relative"
- assert v.method == "OPTIONS"
-
- def test_absolute_form_in(self):
- tutils.raises(
- "Bad HTTP request line",
- self.tst,
- "GET oops-no-protocol.com HTTP/1.1"
- )
- v = self.tst("GET http://address:22/ HTTP/1.1")
- assert v.form_in == "absolute"
- assert v.port == 22
- assert v.host == "address"
- assert v.scheme == "http"
-
- def test_connect(self):
- tutils.raises(
- "Bad HTTP request line",
- self.tst,
- "CONNECT oops-no-port.com HTTP/1.1"
- )
- v = self.tst("CONNECT foo.com:443 HTTP/1.1")
- assert v.form_in == "authority"
- assert v.method == "CONNECT"
- assert v.port == 443
- assert v.host == "foo.com"
-
- def test_expect(self):
- data = "".join(
- "GET / HTTP/1.1\r\n"
- "Content-Length: 3\r\n"
- "Expect: 100-continue\r\n\r\n"
- "foobar"
- )
-
- p = mock_protocol(data)
- v = p.read_request()
- assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n"
- assert v.body == "foo"
- assert p.tcp_handler.rfile.read(3) == "bar"
-
-
-class TestReadResponse(object):
- def tst(self, data, method, body_size_limit, include_body=True):
- data = textwrap.dedent(data)
- return mock_protocol(data).read_response(
- method, body_size_limit, include_body=include_body
- )
-
- def test_errors(self):
- tutils.raises("server disconnect", self.tst, "", "GET", None)
- tutils.raises("invalid server response", self.tst, "foo", "GET", None)
-
- def test_simple(self):
- data = """
- HTTP/1.1 200
- """
- assert self.tst(data, "GET", None) == http.Response(
- (1, 1), 200, '', Headers(), ''
- )
-
- def test_simple_message(self):
- data = """
- HTTP/1.1 200 OK
- """
- assert self.tst(data, "GET", None) == http.Response(
- (1, 1), 200, 'OK', Headers(), ''
- )
-
- def test_invalid_http_version(self):
- data = """
- HTTP/x 200 OK
- """
- tutils.raises("invalid http version", self.tst, data, "GET", None)
-
- def test_invalid_status_code(self):
- data = """
- HTTP/1.1 xx OK
- """
- tutils.raises("invalid server response", self.tst, data, "GET", None)
-
- def test_valid_with_continue(self):
- data = """
- HTTP/1.1 100 CONTINUE
-
- HTTP/1.1 200 OK
- """
- assert self.tst(data, "GET", None) == http.Response(
- (1, 1), 100, 'CONTINUE', Headers(), ''
- )
-
- def test_simple_body(self):
- data = """
- HTTP/1.1 200 OK
- Content-Length: 3
-
- foo
- """
- assert self.tst(data, "GET", None).body == 'foo'
- assert self.tst(data, "HEAD", None).body == ''
-
- def test_invalid_headers(self):
- data = """
- HTTP/1.1 200 OK
- \tContent-Length: 3
-
- foo
- """
- tutils.raises("invalid headers", self.tst, data, "GET", None)
-
- def test_without_body(self):
- data = """
- HTTP/1.1 200 OK
- Content-Length: 3
-
- foo
- """
- assert self.tst(data, "GET", None, include_body=False).body is None
-
-
-class TestReadResponseNoContentLength(tservers.ServerTestBase):
- handler = NoContentLengthHTTPHandler
-
- def test_no_content_length(self):
- c = tcp.TCPClient(("127.0.0.1", self.port))
- c.connect()
- resp = HTTP1Protocol(c).read_response("GET", None)
- assert resp.body == "bar\r\n\r\n"
-
-
-class TestAssembleRequest(object):
- def test_simple(self):
- req = tutils.treq()
- b = HTTP1Protocol().assemble_request(req)
- assert b == match_http_string("""
- GET /path HTTP/1.1
- header: qvalue
- Host: address:22
- Content-Length: 7
-
- content""")
-
- def test_body_missing(self):
- req = tutils.treq(content=semantics.CONTENT_MISSING)
- tutils.raises(http.HttpError, HTTP1Protocol().assemble_request, req)
-
- def test_not_a_request(self):
- tutils.raises(AssertionError, HTTP1Protocol().assemble_request, 'foo')
-
-
-class TestAssembleResponse(object):
- def test_simple(self):
- resp = tutils.tresp()
- b = HTTP1Protocol().assemble_response(resp)
- assert b == match_http_string("""
- HTTP/1.1 200 OK
- header_response: svalue
- Content-Length: 7
-
- message""")
-
- def test_body_missing(self):
- resp = tutils.tresp(content=semantics.CONTENT_MISSING)
- tutils.raises(http.HttpError, HTTP1Protocol().assemble_response, resp)
-
- def test_not_a_request(self):
- tutils.raises(AssertionError, HTTP1Protocol().assemble_response, 'foo')
diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py
new file mode 100644
index 00000000..55def2a5
--- /dev/null
+++ b/test/http/http1/test_read.py
@@ -0,0 +1,317 @@
+from __future__ import absolute_import, print_function, division
+from io import BytesIO
+import textwrap
+
+from mock import Mock
+
+from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect
+from netlib.http import Headers
+from netlib.http.http1.read import (
+ read_request, read_response, read_request_head,
+ read_response_head, read_body, connection_close, expected_http_body_size, _get_first_line,
+ _read_request_line, _parse_authority_form, _read_response_line, _check_http_version,
+ _read_headers, _read_chunked
+)
+from netlib.tutils import treq, tresp, raises
+
+
+def test_read_request():
+ rfile = BytesIO(b"GET / HTTP/1.1\r\n\r\nskip")
+ r = read_request(rfile)
+ assert r.method == b"GET"
+ assert r.body == b""
+ assert r.timestamp_end
+ assert rfile.read() == b"skip"
+
+
+def test_read_request_head():
+ rfile = BytesIO(
+ b"GET / HTTP/1.1\r\n"
+ b"Content-Length: 4\r\n"
+ b"\r\n"
+ b"skip"
+ )
+ rfile.reset_timestamps = Mock()
+ rfile.first_byte_timestamp = 42
+ r = read_request_head(rfile)
+ assert r.method == b"GET"
+ assert r.headers["Content-Length"] == b"4"
+ assert r.body is None
+ assert rfile.reset_timestamps.called
+ assert r.timestamp_start == 42
+ assert rfile.read() == b"skip"
+
+
+def test_read_response():
+ req = treq()
+ rfile = BytesIO(b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody")
+ r = read_response(rfile, req)
+ assert r.status_code == 418
+ assert r.body == b"body"
+ assert r.timestamp_end
+
+
+def test_read_response_head():
+ rfile = BytesIO(
+ b"HTTP/1.1 418 I'm a teapot\r\n"
+ b"Content-Length: 4\r\n"
+ b"\r\n"
+ b"skip"
+ )
+ rfile.reset_timestamps = Mock()
+ rfile.first_byte_timestamp = 42
+ r = read_response_head(rfile)
+ assert r.status_code == 418
+ assert r.headers["Content-Length"] == b"4"
+ assert r.body is None
+ assert rfile.reset_timestamps.called
+ assert r.timestamp_start == 42
+ assert rfile.read() == b"skip"
+
+
+class TestReadBody(object):
+ def test_chunked(self):
+ rfile = BytesIO(b"3\r\nfoo\r\n0\r\n\r\nbar")
+ body = b"".join(read_body(rfile, None))
+ assert body == b"foo"
+ assert rfile.read() == b"bar"
+
+
+ def test_known_size(self):
+ rfile = BytesIO(b"foobar")
+ body = b"".join(read_body(rfile, 3))
+ assert body == b"foo"
+ assert rfile.read() == b"bar"
+
+
+ def test_known_size_limit(self):
+ rfile = BytesIO(b"foobar")
+ with raises(HttpException):
+ b"".join(read_body(rfile, 3, 2))
+
+ def test_known_size_too_short(self):
+ rfile = BytesIO(b"foo")
+ with raises(HttpException):
+ b"".join(read_body(rfile, 6))
+
+ def test_unknown_size(self):
+ rfile = BytesIO(b"foobar")
+ body = b"".join(read_body(rfile, -1))
+ assert body == b"foobar"
+
+
+ def test_unknown_size_limit(self):
+ rfile = BytesIO(b"foobar")
+ with raises(HttpException):
+ b"".join(read_body(rfile, -1, 3))
+
+
+def test_connection_close():
+ headers = Headers()
+ assert connection_close(b"HTTP/1.0", headers)
+ assert not connection_close(b"HTTP/1.1", headers)
+
+ headers["connection"] = "keep-alive"
+ assert not connection_close(b"HTTP/1.1", headers)
+
+ headers["connection"] = "close"
+ assert connection_close(b"HTTP/1.1", headers)
+
+
+def test_expected_http_body_size():
+ # Expect: 100-continue
+ assert expected_http_body_size(
+ treq(headers=Headers(expect=b"100-continue", content_length=b"42"))
+ ) == 0
+
+ # http://tools.ietf.org/html/rfc7230#section-3.3
+ assert expected_http_body_size(
+ treq(method=b"HEAD"),
+ tresp(headers=Headers(content_length=b"42"))
+ ) == 0
+ assert expected_http_body_size(
+ treq(method=b"CONNECT"),
+ tresp()
+ ) == 0
+ for code in (100, 204, 304):
+ assert expected_http_body_size(
+ treq(),
+ tresp(status_code=code)
+ ) == 0
+
+ # chunked
+ assert expected_http_body_size(
+ treq(headers=Headers(transfer_encoding=b"chunked")),
+ ) is None
+
+ # explicit length
+ for l in (b"foo", b"-7"):
+ with raises(HttpSyntaxException):
+ expected_http_body_size(
+ treq(headers=Headers(content_length=l))
+ )
+ assert expected_http_body_size(
+ treq(headers=Headers(content_length=b"42"))
+ ) == 42
+
+ # no length
+ assert expected_http_body_size(
+ treq()
+ ) == 0
+ assert expected_http_body_size(
+ treq(), tresp()
+ ) == -1
+
+
+def test_get_first_line():
+ rfile = BytesIO(b"foo\r\nbar")
+ assert _get_first_line(rfile) == b"foo"
+
+ rfile = BytesIO(b"\r\nfoo\r\nbar")
+ assert _get_first_line(rfile) == b"foo"
+
+ with raises(HttpReadDisconnect):
+ rfile = BytesIO(b"")
+ _get_first_line(rfile)
+
+ with raises(HttpSyntaxException):
+ rfile = BytesIO(b"GET /\xff HTTP/1.1")
+ _get_first_line(rfile)
+
+
+def test_read_request_line():
+ def t(b):
+ return _read_request_line(BytesIO(b))
+
+ assert (t(b"GET / HTTP/1.1") ==
+ ("relative", b"GET", None, None, None, b"/", b"HTTP/1.1"))
+ assert (t(b"OPTIONS * HTTP/1.1") ==
+ ("relative", b"OPTIONS", None, None, None, b"*", b"HTTP/1.1"))
+ assert (t(b"CONNECT foo:42 HTTP/1.1") ==
+ ("authority", b"CONNECT", None, b"foo", 42, None, b"HTTP/1.1"))
+ assert (t(b"GET http://foo:42/bar HTTP/1.1") ==
+ ("absolute", b"GET", b"http", b"foo", 42, b"/bar", b"HTTP/1.1"))
+
+ with raises(HttpSyntaxException):
+ t(b"GET / WTF/1.1")
+ with raises(HttpSyntaxException):
+ t(b"this is not http")
+
+
+def test_parse_authority_form():
+ assert _parse_authority_form(b"foo:42") == (b"foo", 42)
+ with raises(HttpSyntaxException):
+ _parse_authority_form(b"foo")
+ with raises(HttpSyntaxException):
+ _parse_authority_form(b"foo:bar")
+ with raises(HttpSyntaxException):
+ _parse_authority_form(b"foo:99999999")
+ with raises(HttpSyntaxException):
+ _parse_authority_form(b"f\x00oo:80")
+
+
+def test_read_response_line():
+ def t(b):
+ return _read_response_line(BytesIO(b))
+
+ assert t(b"HTTP/1.1 200 OK") == (b"HTTP/1.1", 200, b"OK")
+ assert t(b"HTTP/1.1 200") == (b"HTTP/1.1", 200, b"")
+ with raises(HttpSyntaxException):
+ assert t(b"HTTP/1.1")
+
+ with raises(HttpSyntaxException):
+ t(b"HTTP/1.1 OK OK")
+ with raises(HttpSyntaxException):
+ t(b"WTF/1.1 200 OK")
+
+
+def test_check_http_version():
+ _check_http_version(b"HTTP/0.9")
+ _check_http_version(b"HTTP/1.0")
+ _check_http_version(b"HTTP/1.1")
+ _check_http_version(b"HTTP/2.0")
+ with raises(HttpSyntaxException):
+ _check_http_version(b"WTF/1.0")
+ with raises(HttpSyntaxException):
+ _check_http_version(b"HTTP/1.10")
+ with raises(HttpSyntaxException):
+ _check_http_version(b"HTTP/1.b")
+
+
+class TestReadHeaders(object):
+ @staticmethod
+ def _read(data):
+ return _read_headers(BytesIO(data))
+
+ def test_read_simple(self):
+ data = (
+ b"Header: one\r\n"
+ b"Header2: two\r\n"
+ b"\r\n"
+ )
+ headers = self._read(data)
+ assert headers.fields == [[b"Header", b"one"], [b"Header2", b"two"]]
+
+ def test_read_multi(self):
+ data = (
+ b"Header: one\r\n"
+ b"Header: two\r\n"
+ b"\r\n"
+ )
+ headers = self._read(data)
+ assert headers.fields == [[b"Header", b"one"], [b"Header", b"two"]]
+
+ def test_read_continued(self):
+ data = (
+ b"Header: one\r\n"
+ b"\ttwo\r\n"
+ b"Header2: three\r\n"
+ b"\r\n"
+ )
+ headers = self._read(data)
+ assert headers.fields == [[b"Header", b"one\r\n two"], [b"Header2", b"three"]]
+
+ def test_read_continued_err(self):
+ data = b"\tfoo: bar\r\n"
+ with raises(HttpSyntaxException):
+ self._read(data)
+
+ def test_read_err(self):
+ data = b"foo"
+ with raises(HttpSyntaxException):
+ self._read(data)
+
+ def test_read_empty_name(self):
+ data = b":foo"
+ with raises(HttpSyntaxException):
+ self._read(data)
+
+def test_read_chunked():
+ req = treq(body=None)
+ req.headers["Transfer-Encoding"] = "chunked"
+
+ data = b"1\r\na\r\n0\r\n"
+ with raises(HttpSyntaxException):
+ b"".join(_read_chunked(BytesIO(data)))
+
+ data = b"1\r\na\r\n0\r\n\r\n"
+ assert b"".join(_read_chunked(BytesIO(data))) == b"a"
+
+ data = b"\r\n\r\n1\r\na\r\n1\r\nb\r\n0\r\n\r\n"
+ assert b"".join(_read_chunked(BytesIO(data))) == b"ab"
+
+ data = b"\r\n"
+ with raises("closed prematurely"):
+ b"".join(_read_chunked(BytesIO(data)))
+
+ data = b"1\r\nfoo"
+ with raises("malformed chunked body"):
+ b"".join(_read_chunked(BytesIO(data)))
+
+ data = b"foo\r\nfoo"
+ with raises(HttpSyntaxException):
+ b"".join(_read_chunked(BytesIO(data)))
+
+ data = b"5\r\naaaaa\r\n0\r\n\r\n"
+ with raises("too large"):
+ b"".join(_read_chunked(BytesIO(data), limit=2))
diff --git a/test/http/http2/test_frames.py b/test/http/http2/test_frames.py
index 5d5cb0ba..4c89b023 100644
--- a/test/http/http2/test_frames.py
+++ b/test/http/http2/test_frames.py
@@ -1,4 +1,4 @@
-import cStringIO
+from io import BytesIO
from nose.tools import assert_equal
from netlib import tcp, tutils
@@ -7,7 +7,7 @@ from netlib.http.http2.frame import *
def hex_to_file(data):
data = data.decode('hex')
- return tcp.Reader(cStringIO.StringIO(data))
+ return tcp.Reader(BytesIO(data))
def test_invalid_flags():
@@ -39,7 +39,7 @@ def test_too_large_frames():
flags=Frame.FLAG_END_STREAM,
stream_id=0x1234567,
payload='foobar' * 3000)
- tutils.raises(FrameSizeError, f.to_bytes)
+ tutils.raises(HttpSyntaxException, f.to_bytes)
def test_data_frame_to_bytes():
diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py
index 2b7d7958..a369eb49 100644
--- a/test/http/http2/test_protocol.py
+++ b/test/http/http2/test_protocol.py
@@ -2,21 +2,21 @@ import OpenSSL
import mock
from netlib import tcp, http, tutils
-from netlib.http import http2, Headers
-from netlib.http.http2 import HTTP2Protocol
+from netlib.http import Headers
+from netlib.http.http2.connections import HTTP2Protocol, TCPHandler
from netlib.http.http2.frame import *
from ... import tservers
class TestTCPHandlerWrapper:
def test_wrapped(self):
- h = http2.TCPHandler(rfile='foo', wfile='bar')
+ h = TCPHandler(rfile='foo', wfile='bar')
p = HTTP2Protocol(h)
assert p.tcp_handler.rfile == 'foo'
assert p.tcp_handler.wfile == 'bar'
def test_direct(self):
p = HTTP2Protocol(rfile='foo', wfile='bar')
- assert isinstance(p.tcp_handler, http2.TCPHandler)
+ assert isinstance(p.tcp_handler, TCPHandler)
assert p.tcp_handler.rfile == 'foo'
assert p.tcp_handler.wfile == 'bar'
@@ -32,8 +32,8 @@ class EchoHandler(tcp.BaseHandler):
class TestProtocol:
- @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface")
- @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface")
+ @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface")
+ @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface")
def test_perform_connection_preface(self, mock_client_method, mock_server_method):
protocol = HTTP2Protocol(is_server=False)
protocol.connection_preface_performed = True
@@ -46,8 +46,8 @@ class TestProtocol:
assert mock_client_method.called
assert not mock_server_method.called
- @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface")
- @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface")
+ @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface")
+ @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface")
def test_perform_connection_preface_server(self, mock_client_method, mock_server_method):
protocol = HTTP2Protocol(is_server=True)
protocol.connection_preface_performed = True
@@ -64,7 +64,7 @@ class TestProtocol:
class TestCheckALPNMatch(tservers.ServerTestBase):
handler = EchoHandler
ssl = dict(
- alpn_select=HTTP2Protocol.ALPN_PROTO_H2,
+ alpn_select=ALPN_PROTO_H2,
)
if OpenSSL._util.lib.Cryptography_HAS_ALPN:
@@ -72,7 +72,7 @@ class TestCheckALPNMatch(tservers.ServerTestBase):
def test_check_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
- c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2])
+ c.convert_to_ssl(alpn_protos=[ALPN_PROTO_H2])
protocol = HTTP2Protocol(c)
assert protocol.check_alpn()
@@ -88,7 +88,7 @@ class TestCheckALPNMismatch(tservers.ServerTestBase):
def test_check_alpn(self):
c = tcp.TCPClient(("127.0.0.1", self.port))
c.connect()
- c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2])
+ c.convert_to_ssl(alpn_protos=[ALPN_PROTO_H2])
protocol = HTTP2Protocol(c)
tutils.raises(NotImplementedError, protocol.check_alpn)
@@ -306,7 +306,7 @@ class TestReadRequest(tservers.ServerTestBase):
protocol = HTTP2Protocol(c, is_server=True)
protocol.connection_preface_performed = True
- req = protocol.read_request()
+ req = protocol.read_request(NotImplemented)
assert req.stream_id
assert req.headers.fields == [[':method', 'GET'], [':path', '/'], [':scheme', 'https']]
@@ -329,7 +329,7 @@ class TestReadRequestRelative(tservers.ServerTestBase):
protocol = HTTP2Protocol(c, is_server=True)
protocol.connection_preface_performed = True
- req = protocol.read_request()
+ req = protocol.read_request(NotImplemented)
assert req.form_in == "relative"
assert req.method == "OPTIONS"
@@ -352,7 +352,7 @@ class TestReadRequestAbsolute(tservers.ServerTestBase):
protocol = HTTP2Protocol(c, is_server=True)
protocol.connection_preface_performed = True
- req = protocol.read_request()
+ req = protocol.read_request(NotImplemented)
assert req.form_in == "absolute"
assert req.scheme == "http"
@@ -378,13 +378,13 @@ class TestReadRequestConnect(tservers.ServerTestBase):
protocol = HTTP2Protocol(c, is_server=True)
protocol.connection_preface_performed = True
- req = protocol.read_request()
+ req = protocol.read_request(NotImplemented)
assert req.form_in == "authority"
assert req.method == "CONNECT"
assert req.host == "address"
assert req.port == 22
- req = protocol.read_request()
+ req = protocol.read_request(NotImplemented)
assert req.form_in == "authority"
assert req.method == "CONNECT"
assert req.host == "example.com"
@@ -410,7 +410,7 @@ class TestReadResponse(tservers.ServerTestBase):
protocol = HTTP2Protocol(c)
protocol.connection_preface_performed = True
- resp = protocol.read_response(stream_id=42)
+ resp = protocol.read_response(NotImplemented, stream_id=42)
assert resp.httpversion == (2, 0)
assert resp.status_code == 200
@@ -436,7 +436,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase):
protocol = HTTP2Protocol(c)
protocol.connection_preface_performed = True
- resp = protocol.read_response(stream_id=42)
+ resp = protocol.read_response(NotImplemented, stream_id=42)
assert resp.stream_id == 42
assert resp.httpversion == (2, 0)
diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py
index 17c91fe5..ee192dd7 100644
--- a/test/http/test_authentication.py
+++ b/test/http/test_authentication.py
@@ -5,7 +5,7 @@ from netlib.http import authentication, Headers
def test_parse_http_basic_auth():
- vals = ("basic", "foo", "bar")
+ vals = (b"basic", b"foo", b"bar")
assert authentication.parse_http_basic_auth(
authentication.assemble_http_basic_auth(*vals)
) == vals
diff --git a/test/http/test_exceptions.py b/test/http/test_exceptions.py
deleted file mode 100644
index 49588d0a..00000000
--- a/test/http/test_exceptions.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from netlib.http.exceptions import *
-
-class TestHttpError:
- def test_simple(self):
- e = HttpError(404, "Not found")
- assert str(e)
diff --git a/test/http/test_semantics.py b/test/http/test_models.py
index 6dcbbe07..8fce2e9d 100644
--- a/test/http/test_semantics.py
+++ b/test/http/test_models.py
@@ -1,32 +1,11 @@
import mock
-from netlib import http
-from netlib import odict
from netlib import tutils
from netlib import utils
-from netlib.http import semantics
-from netlib.http.semantics import CONTENT_MISSING
-
-class TestProtocolMixin(object):
- @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response")
- @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request")
- def test_assemble_request(self, mock_request_method, mock_response_method):
- p = semantics.ProtocolMixin()
- p.assemble(tutils.treq())
- assert mock_request_method.called
- assert not mock_response_method.called
-
- @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response")
- @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request")
- def test_assemble_response(self, mock_request_method, mock_response_method):
- p = semantics.ProtocolMixin()
- p.assemble(tutils.tresp())
- assert not mock_request_method.called
- assert mock_response_method.called
-
- def test_assemble_foo(self):
- p = semantics.ProtocolMixin()
- tutils.raises(ValueError, p.assemble, 'foo')
+from netlib.odict import ODict, ODictCaseless
+from netlib.http import Request, Response, Headers, CONTENT_MISSING, HDR_FORM_URLENCODED, \
+ HDR_FORM_MULTIPART
+
class TestRequest(object):
def test_repr(self):
@@ -34,27 +13,27 @@ class TestRequest(object):
assert repr(r)
def test_headers(self):
- tutils.raises(AssertionError, semantics.Request,
+ tutils.raises(AssertionError, Request,
'form_in',
'method',
'scheme',
'host',
'port',
'path',
- (1, 1),
+ b"HTTP/1.1",
'foobar',
)
- req = semantics.Request(
+ req = Request(
'form_in',
'method',
'scheme',
'host',
'port',
'path',
- (1, 1),
+ b"HTTP/1.1",
)
- assert isinstance(req.headers, http.Headers)
+ assert isinstance(req.headers, Headers)
def test_equal(self):
a = tutils.treq()
@@ -66,13 +45,6 @@ class TestRequest(object):
assert not 'foo' == a
assert not 'foo' == b
- def test_legacy_first_line(self):
- req = tutils.treq()
-
- assert req.legacy_first_line('relative') == "GET /path HTTP/1.1"
- assert req.legacy_first_line('authority') == "GET address:22 HTTP/1.1"
- assert req.legacy_first_line('absolute') == "GET http://address:22/path HTTP/1.1"
- tutils.raises(http.HttpError, req.legacy_first_line, 'foobar')
def test_anticache(self):
req = tutils.treq()
@@ -103,44 +75,44 @@ class TestRequest(object):
def test_get_form(self):
req = tutils.treq()
- assert req.get_form() == odict.ODict()
+ assert req.get_form() == ODict()
- @mock.patch("netlib.http.semantics.Request.get_form_multipart")
- @mock.patch("netlib.http.semantics.Request.get_form_urlencoded")
+ @mock.patch("netlib.http.Request.get_form_multipart")
+ @mock.patch("netlib.http.Request.get_form_urlencoded")
def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart):
req = tutils.treq()
- assert req.get_form() == odict.ODict()
+ assert req.get_form() == ODict()
req = tutils.treq()
req.body = "foobar"
- req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED
+ req.headers["Content-Type"] = HDR_FORM_URLENCODED
req.get_form()
assert req.get_form_urlencoded.called
assert not req.get_form_multipart.called
- @mock.patch("netlib.http.semantics.Request.get_form_multipart")
- @mock.patch("netlib.http.semantics.Request.get_form_urlencoded")
+ @mock.patch("netlib.http.Request.get_form_multipart")
+ @mock.patch("netlib.http.Request.get_form_urlencoded")
def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart):
req = tutils.treq()
req.body = "foobar"
- req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART
+ req.headers["Content-Type"] = HDR_FORM_MULTIPART
req.get_form()
assert not req.get_form_urlencoded.called
assert req.get_form_multipart.called
def test_get_form_urlencoded(self):
- req = tutils.treq("foobar")
- assert req.get_form_urlencoded() == odict.ODict()
+ req = tutils.treq(body="foobar")
+ assert req.get_form_urlencoded() == ODict()
- req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED
- assert req.get_form_urlencoded() == odict.ODict(utils.urldecode(req.body))
+ req.headers["Content-Type"] = HDR_FORM_URLENCODED
+ assert req.get_form_urlencoded() == ODict(utils.urldecode(req.body))
def test_get_form_multipart(self):
- req = tutils.treq("foobar")
- assert req.get_form_multipart() == odict.ODict()
+ req = tutils.treq(body="foobar")
+ assert req.get_form_multipart() == ODict()
- req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART
- assert req.get_form_multipart() == odict.ODict(
+ req.headers["Content-Type"] = HDR_FORM_MULTIPART
+ assert req.get_form_multipart() == ODict(
utils.multipartdecode(
req.headers,
req.body
@@ -149,8 +121,8 @@ class TestRequest(object):
def test_set_form_urlencoded(self):
req = tutils.treq()
- req.set_form_urlencoded(odict.ODict([('foo', 'bar'), ('rab', 'oof')]))
- assert req.headers["Content-Type"] == semantics.HDR_FORM_URLENCODED
+ req.set_form_urlencoded(ODict([('foo', 'bar'), ('rab', 'oof')]))
+ assert req.headers["Content-Type"] == HDR_FORM_URLENCODED
assert req.body
def test_get_path_components(self):
@@ -172,7 +144,7 @@ class TestRequest(object):
def test_set_query(self):
req = tutils.treq()
- req.set_query(odict.ODict([]))
+ req.set_query(ODict([]))
def test_pretty_host(self):
r = tutils.treq()
@@ -203,21 +175,21 @@ class TestRequest(object):
assert req.pretty_url(False) == "http://address:22/path"
def test_get_cookies_none(self):
- headers = http.Headers()
+ headers = Headers()
r = tutils.treq()
r.headers = headers
assert len(r.get_cookies()) == 0
def test_get_cookies_single(self):
r = tutils.treq()
- r.headers = http.Headers(cookie="cookiename=cookievalue")
+ r.headers = Headers(cookie="cookiename=cookievalue")
result = r.get_cookies()
assert len(result) == 1
assert result['cookiename'] == ['cookievalue']
def test_get_cookies_double(self):
r = tutils.treq()
- r.headers = http.Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue")
+ r.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue")
result = r.get_cookies()
assert len(result) == 2
assert result['cookiename'] == ['cookievalue']
@@ -225,7 +197,7 @@ class TestRequest(object):
def test_get_cookies_withequalsign(self):
r = tutils.treq()
- r.headers = http.Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue")
+ r.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue")
result = r.get_cookies()
assert len(result) == 2
assert result['cookiename'] == ['coo=kievalue']
@@ -233,14 +205,14 @@ class TestRequest(object):
def test_set_cookies(self):
r = tutils.treq()
- r.headers = http.Headers(cookie="cookiename=cookievalue")
+ r.headers = Headers(cookie="cookiename=cookievalue")
result = r.get_cookies()
result["cookiename"] = ["foo"]
r.set_cookies(result)
assert r.get_cookies()["cookiename"] == ["foo"]
def test_set_url(self):
- r = tutils.treq_absolute()
+ r = tutils.treq(form_in="absolute")
r.url = "https://otheraddress:42/ORLY"
assert r.scheme == "https"
assert r.host == "otheraddress"
@@ -332,24 +304,19 @@ class TestRequest(object):
# "Host: address\r\n"
# "Content-Length: 0\r\n\r\n")
-class TestEmptyRequest(object):
- def test_init(self):
- req = semantics.EmptyRequest()
- assert req
-
class TestResponse(object):
def test_headers(self):
- tutils.raises(AssertionError, semantics.Response,
- (1, 1),
+ tutils.raises(AssertionError, Response,
+ b"HTTP/1.1",
200,
headers='foobar',
)
- resp = semantics.Response(
- (1, 1),
+ resp = Response(
+ b"HTTP/1.1",
200,
)
- assert isinstance(resp.headers, http.Headers)
+ assert isinstance(resp.headers, Headers)
def test_equal(self):
a = tutils.tresp()
@@ -366,24 +333,24 @@ class TestResponse(object):
assert "unknown content type" in repr(r)
r.headers["content-type"] = "foo"
assert "foo" in repr(r)
- assert repr(tutils.tresp(content=CONTENT_MISSING))
+ assert repr(tutils.tresp(body=CONTENT_MISSING))
def test_get_cookies_none(self):
resp = tutils.tresp()
- resp.headers = http.Headers()
+ resp.headers = Headers()
assert not resp.get_cookies()
def test_get_cookies_simple(self):
resp = tutils.tresp()
- resp.headers = http.Headers(set_cookie="cookiename=cookievalue")
+ resp.headers = Headers(set_cookie="cookiename=cookievalue")
result = resp.get_cookies()
assert len(result) == 1
assert "cookiename" in result
- assert result["cookiename"][0] == ["cookievalue", odict.ODict()]
+ assert result["cookiename"][0] == ["cookievalue", ODict()]
def test_get_cookies_with_parameters(self):
resp = tutils.tresp()
- resp.headers = http.Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly")
+ resp.headers = Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly")
result = resp.get_cookies()
assert len(result) == 1
assert "cookiename" in result
@@ -397,7 +364,7 @@ class TestResponse(object):
def test_get_cookies_no_value(self):
resp = tutils.tresp()
- resp.headers = http.Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/")
+ resp.headers = Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/")
result = resp.get_cookies()
assert len(result) == 1
assert "cookiename" in result
@@ -406,31 +373,31 @@ class TestResponse(object):
def test_get_cookies_twocookies(self):
resp = tutils.tresp()
- resp.headers = http.Headers([
+ resp.headers = Headers([
["Set-Cookie", "cookiename=cookievalue"],
["Set-Cookie", "othercookie=othervalue"]
])
result = resp.get_cookies()
assert len(result) == 2
assert "cookiename" in result
- assert result["cookiename"][0] == ["cookievalue", odict.ODict()]
+ assert result["cookiename"][0] == ["cookievalue", ODict()]
assert "othercookie" in result
- assert result["othercookie"][0] == ["othervalue", odict.ODict()]
+ assert result["othercookie"][0] == ["othervalue", ODict()]
def test_set_cookies(self):
resp = tutils.tresp()
v = resp.get_cookies()
- v.add("foo", ["bar", odict.ODictCaseless()])
+ v.add("foo", ["bar", ODictCaseless()])
resp.set_cookies(v)
v = resp.get_cookies()
assert len(v) == 1
- assert v["foo"] == [["bar", odict.ODictCaseless()]]
+ assert v["foo"] == [["bar", ODictCaseless()]]
class TestHeaders(object):
def _2host(self):
- return semantics.Headers(
+ return Headers(
[
["Host", "example.com"],
["host", "example.org"]
@@ -438,25 +405,25 @@ class TestHeaders(object):
)
def test_init(self):
- headers = semantics.Headers()
+ headers = Headers()
assert len(headers) == 0
- headers = semantics.Headers([["Host", "example.com"]])
+ headers = Headers([["Host", "example.com"]])
assert len(headers) == 1
assert headers["Host"] == "example.com"
- headers = semantics.Headers(Host="example.com")
+ headers = Headers(Host="example.com")
assert len(headers) == 1
assert headers["Host"] == "example.com"
- headers = semantics.Headers(
+ headers = Headers(
[["Host", "invalid"]],
Host="example.com"
)
assert len(headers) == 1
assert headers["Host"] == "example.com"
- headers = semantics.Headers(
+ headers = Headers(
[["Host", "invalid"], ["Accept", "text/plain"]],
Host="example.com"
)
@@ -465,7 +432,7 @@ class TestHeaders(object):
assert headers["Accept"] == "text/plain"
def test_getitem(self):
- headers = semantics.Headers(Host="example.com")
+ headers = Headers(Host="example.com")
assert headers["Host"] == "example.com"
assert headers["host"] == "example.com"
tutils.raises(KeyError, headers.__getitem__, "Accept")
@@ -474,17 +441,17 @@ class TestHeaders(object):
assert headers["Host"] == "example.com, example.org"
def test_str(self):
- headers = semantics.Headers(Host="example.com")
- assert str(headers) == "Host: example.com\r\n"
+ headers = Headers(Host="example.com")
+ assert bytes(headers) == b"Host: example.com\r\n"
- headers = semantics.Headers([
+ headers = Headers([
["Host", "example.com"],
["Accept", "text/plain"]
])
assert str(headers) == "Host: example.com\r\nAccept: text/plain\r\n"
def test_setitem(self):
- headers = semantics.Headers()
+ headers = Headers()
headers["Host"] = "example.com"
assert "Host" in headers
assert "host" in headers
@@ -507,7 +474,7 @@ class TestHeaders(object):
assert "Host" in headers
def test_delitem(self):
- headers = semantics.Headers(Host="example.com")
+ headers = Headers(Host="example.com")
assert len(headers) == 1
del headers["host"]
assert len(headers) == 0
@@ -523,7 +490,7 @@ class TestHeaders(object):
assert len(headers) == 0
def test_keys(self):
- headers = semantics.Headers(Host="example.com")
+ headers = Headers(Host="example.com")
assert len(headers.keys()) == 1
assert headers.keys()[0] == "Host"
@@ -532,13 +499,13 @@ class TestHeaders(object):
assert headers.keys()[0] == "Host"
def test_eq_ne(self):
- headers1 = semantics.Headers(Host="example.com")
- headers2 = semantics.Headers(host="example.com")
+ headers1 = Headers(Host="example.com")
+ headers2 = Headers(host="example.com")
assert not (headers1 == headers2)
assert headers1 != headers2
- headers1 = semantics.Headers(Host="example.com")
- headers2 = semantics.Headers(Host="example.com")
+ headers1 = Headers(Host="example.com")
+ headers2 = Headers(Host="example.com")
assert headers1 == headers2
assert not (headers1 != headers2)
@@ -550,7 +517,7 @@ class TestHeaders(object):
assert headers.get_all("accept") == []
def test_set_all(self):
- headers = semantics.Headers(Host="example.com")
+ headers = Headers(Host="example.com")
headers.set_all("Accept", ["text/plain"])
assert len(headers) == 2
assert "accept" in headers
@@ -565,9 +532,9 @@ class TestHeaders(object):
def test_state(self):
headers = self._2host()
assert len(headers.get_state()) == 2
- assert headers == semantics.Headers.from_state(headers.get_state())
+ assert headers == Headers.from_state(headers.get_state())
- headers2 = semantics.Headers()
+ headers2 = Headers()
assert headers != headers2
headers2.load_state(headers.get_state())
assert headers == headers2
diff --git a/test/test_encoding.py b/test/test_encoding.py
index 612aea89..9da3a38d 100644
--- a/test/test_encoding.py
+++ b/test/test_encoding.py
@@ -9,25 +9,29 @@ def test_identity():
def test_gzip():
- assert "string" == encoding.decode(
+ assert b"string" == encoding.decode(
"gzip",
encoding.encode(
"gzip",
- "string"))
- assert None == encoding.decode("gzip", "bogus")
+ b"string"
+ )
+ )
+ assert encoding.decode("gzip", b"bogus") is None
def test_deflate():
- assert "string" == encoding.decode(
+ assert b"string" == encoding.decode(
"deflate",
encoding.encode(
"deflate",
- "string"))
- assert "string" == encoding.decode(
+ b"string"
+ )
+ )
+ assert b"string" == encoding.decode(
"deflate",
encoding.encode(
"deflate",
- "string")[
- 2:-
- 4])
- assert None == encoding.decode("deflate", "bogus")
+ b"string"
+ )[2:-4]
+ )
+ assert encoding.decode("deflate", b"bogus") is None
diff --git a/test/test_utils.py b/test/test_utils.py
index 9dba5d35..8b2ddae4 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -36,46 +36,51 @@ def test_pretty_size():
def test_parse_url():
- assert not utils.parse_url("")
+ with tutils.raises(ValueError):
+ utils.parse_url("")
- u = "http://foo.com:8888/test"
- s, h, po, pa = utils.parse_url(u)
- assert s == "http"
- assert h == "foo.com"
+ s, h, po, pa = utils.parse_url(b"http://foo.com:8888/test")
+ assert s == b"http"
+ assert h == b"foo.com"
assert po == 8888
- assert pa == "/test"
+ assert pa == b"/test"
s, h, po, pa = utils.parse_url("http://foo/bar")
- assert s == "http"
- assert h == "foo"
+ assert s == b"http"
+ assert h == b"foo"
assert po == 80
- assert pa == "/bar"
+ assert pa == b"/bar"
- s, h, po, pa = utils.parse_url("http://user:pass@foo/bar")
- assert s == "http"
- assert h == "foo"
+ s, h, po, pa = utils.parse_url(b"http://user:pass@foo/bar")
+ assert s == b"http"
+ assert h == b"foo"
assert po == 80
- assert pa == "/bar"
+ assert pa == b"/bar"
- s, h, po, pa = utils.parse_url("http://foo")
- assert pa == "/"
+ s, h, po, pa = utils.parse_url(b"http://foo")
+ assert pa == b"/"
- s, h, po, pa = utils.parse_url("https://foo")
+ s, h, po, pa = utils.parse_url(b"https://foo")
assert po == 443
- assert not utils.parse_url("https://foo:bar")
- assert not utils.parse_url("https://foo:")
+ with tutils.raises(ValueError):
+ utils.parse_url(b"https://foo:bar")
# Invalid IDNA
- assert not utils.parse_url("http://\xfafoo")
+ with tutils.raises(ValueError):
+ utils.parse_url("http://\xfafoo")
# Invalid PATH
- assert not utils.parse_url("http:/\xc6/localhost:56121")
+ with tutils.raises(ValueError):
+ utils.parse_url("http:/\xc6/localhost:56121")
# Null byte in host
- assert not utils.parse_url("http://foo\0")
+ with tutils.raises(ValueError):
+ utils.parse_url("http://foo\0")
# Port out of range
- assert not utils.parse_url("http://foo:999999")
+ _, _, port, _ = utils.parse_url("http://foo:999999")
+ assert port == 80
# Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt
- assert not utils.parse_url('http://lo[calhost')
+ with tutils.raises(ValueError):
+ utils.parse_url('http://lo[calhost')
def test_unparse_url():
@@ -106,23 +111,25 @@ def test_get_header_tokens():
def test_multipartdecode():
- boundary = 'somefancyboundary'
+ boundary = b'somefancyboundary'
headers = Headers(
- content_type='multipart/form-data; boundary=%s' % boundary
+ content_type=b'multipart/form-data; boundary=' + boundary
+ )
+ content = (
+ "--{0}\n"
+ "Content-Disposition: form-data; name=\"field1\"\n\n"
+ "value1\n"
+ "--{0}\n"
+ "Content-Disposition: form-data; name=\"field2\"\n\n"
+ "value2\n"
+ "--{0}--".format(boundary).encode("ascii")
)
- content = "--{0}\n" \
- "Content-Disposition: form-data; name=\"field1\"\n\n" \
- "value1\n" \
- "--{0}\n" \
- "Content-Disposition: form-data; name=\"field2\"\n\n" \
- "value2\n" \
- "--{0}--".format(boundary)
form = utils.multipartdecode(headers, content)
assert len(form) == 2
- assert form[0] == ('field1', 'value1')
- assert form[1] == ('field2', 'value2')
+ assert form[0] == (b"field1", b"value1")
+ assert form[1] == (b"field2", b"value2")
def test_parse_content_type():
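[Note] The behavioural shift in test_parse_url is that parse_url now raises ValueError on malformed input instead of returning a falsy value, and hands back bytes components. A rough sketch of the new contract, using only the calls shown in the test above:

    from netlib import utils, tutils

    # A well-formed URL parses into (scheme, host, port, path); string parts come back as bytes.
    scheme, host, port, path = utils.parse_url(b"http://foo.com:8888/test")
    assert (scheme, host, port, path) == (b"http", b"foo.com", 8888, b"/test")

    # Malformed URLs now raise ValueError instead of returning a falsy value.
    with tutils.raises(ValueError):
        utils.parse_url(b"https://foo:bar")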
diff --git a/test/test_version_check.py b/test/test_version_check.py
index 9a127814..ec2396fe 100644
--- a/test/test_version_check.py
+++ b/test/test_version_check.py
@@ -1,11 +1,11 @@
-import cStringIO
+from io import StringIO
import mock
from netlib import version_check, version
@mock.patch("sys.exit")
def test_check_mitmproxy_version(sexit):
- fp = cStringIO.StringIO()
+ fp = StringIO()
version_check.check_mitmproxy_version(version.IVERSION, fp=fp)
assert not fp.getvalue()
assert not sexit.called
@@ -18,7 +18,7 @@ def test_check_mitmproxy_version(sexit):
@mock.patch("sys.exit")
def test_check_pyopenssl_version(sexit):
- fp = cStringIO.StringIO()
+ fp = StringIO()
version_check.check_pyopenssl_version(fp=fp)
assert not fp.getvalue()
assert not sexit.called
@@ -32,7 +32,7 @@ def test_check_pyopenssl_version(sexit):
@mock.patch("OpenSSL.__version__")
def test_unparseable_pyopenssl_version(version, sexit):
version.split.return_value = ["foo", "bar"]
- fp = cStringIO.StringIO()
+ fp = StringIO()
version_check.check_pyopenssl_version(fp=fp)
assert "Cannot parse" in fp.getvalue()
assert not sexit.called
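[Note] The version_check tests only swap cStringIO for the Python 3-compatible io.StringIO text sink; the calling pattern is unchanged. Roughly:

    from io import StringIO
    from netlib import version, version_check

    # Any warning output lands in a text buffer instead of going to stderr.
    fp = StringIO()
    version_check.check_mitmproxy_version(version.IVERSION, fp=fp)
    assert not fp.getvalue()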
diff --git a/test/tservers.py b/test/tservers.py
index 682a9144..1f4ce725 100644
--- a/test/tservers.py
+++ b/test/tservers.py
@@ -1,7 +1,7 @@
from __future__ import (absolute_import, print_function, division)
import threading
-import Queue
-import cStringIO
+from six.moves import queue
+from io import StringIO
import OpenSSL
from netlib import tcp
from netlib import tutils
@@ -27,7 +27,7 @@ class ServerTestBase(object):
@classmethod
def setupAll(cls):
- cls.q = Queue.Queue()
+ cls.q = queue.Queue()
s = cls.makeserver()
cls.port = s.address.port
cls.server = ServerThread(s)
@@ -102,6 +102,6 @@ class TServer(tcp.TCPServer):
h.finish()
def handle_error(self, connection, client_address, fp=None):
- s = cStringIO.StringIO()
+ s = StringIO()
tcp.TCPServer.handle_error(self, connection, client_address, s)
self.q.put(s.getvalue())
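[Note] tservers.py is a straight port to the compatibility layer: six.moves.queue resolves to Queue on Python 2 and queue on Python 3, and io.StringIO replaces cStringIO. A minimal sketch of the pattern used by handle_error above (the traceback text is illustrative):

    from io import StringIO
    from six.moves import queue  # Queue on Python 2, queue on Python 3

    # Same Queue API under both interpreters.
    q = queue.Queue()

    # Capture an error report in a text buffer and hand it off, as handle_error() does.
    s = StringIO()
    s.write(u"traceback goes here")
    q.put(s.getvalue())
    assert q.get() == u"traceback goes here"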
diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py
index 57cfd166..3fdeb683 100644
--- a/test/websockets/test_websockets.py
+++ b/test/websockets/test_websockets.py
@@ -1,11 +1,13 @@
import os
from nose.tools import raises
+from netlib.http.http1 import read_response, read_request
from netlib import tcp, tutils, websockets, http
from netlib.http import status_codes
-from netlib.http.exceptions import *
-from netlib.http.http1 import HTTP1Protocol
+from netlib.tutils import treq
+
+from netlib.exceptions import *
from .. import tservers
@@ -34,9 +36,8 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
frame.to_file(self.wfile)
def handshake(self):
- http1_protocol = HTTP1Protocol(self)
- req = http1_protocol.read_request()
+ req = read_request(self.rfile)
key = self.protocol.check_client_handshake(req.headers)
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)
@@ -61,8 +62,6 @@ class WebSocketsClient(tcp.TCPClient):
def connect(self):
super(WebSocketsClient, self).connect()
- http1_protocol = HTTP1Protocol(self)
-
preamble = 'GET / HTTP/1.1'
self.wfile.write(preamble + "\r\n")
headers = self.protocol.client_handshake_headers()
@@ -70,7 +69,7 @@ class WebSocketsClient(tcp.TCPClient):
self.wfile.write(str(headers) + "\r\n")
self.wfile.flush()
- resp = http1_protocol.read_response("GET", None)
+ resp = read_response(self.rfile, treq(method="GET"))
server_nonce = self.protocol.check_server_handshake(resp.headers)
if not server_nonce == self.protocol.create_server_nonce(
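[Note] The handshake code now calls the module-level http1 readers instead of instantiating HTTP1Protocol. A rough usage sketch, assuming read_request and read_response accept any readable file-like object the way they accept rfile above (the BytesIO wrapping is illustrative, not taken from this diff):

    from io import BytesIO
    from netlib.http.http1 import read_request, read_response
    from netlib.tutils import treq

    # Parse a request straight off a readable byte stream (rfile in the handlers above).
    req = read_request(BytesIO(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))

    # Parsing a response requires the request it answers; tutils.treq() provides a stand-in.
    resp = read_response(
        BytesIO(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n"),
        treq(method="GET"),
    )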
@@ -158,9 +157,8 @@ class TestWebSockets(tservers.ServerTestBase):
class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self):
- http1_protocol = HTTP1Protocol(self)
- client_hs = http1_protocol.read_request()
+ client_hs = read_request(self.rfile)
self.protocol.check_client_handshake(client_hs.headers)
preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101)