diff options
author | Maximilian Hils <git@maximilianhils.com> | 2015-09-16 20:19:52 +0200 |
---|---|---|
committer | Maximilian Hils <git@maximilianhils.com> | 2015-09-16 20:19:52 +0200 |
commit | e1659f3fcf83b5993b776a4ef3d2de70fbe27aa2 (patch) | |
tree | c0eba50b522d1d0183b057e9cae7bf7cc38c4fc3 | |
parent | 2f9c566e480c377566a0ae044d698a75b45cd54c (diff) | |
parent | 265f31e8782ee9da511ce4b63aa2da00221cbf66 (diff) | |
download | mitmproxy-e1659f3fcf83b5993b776a4ef3d2de70fbe27aa2.tar.gz mitmproxy-e1659f3fcf83b5993b776a4ef3d2de70fbe27aa2.tar.bz2 mitmproxy-e1659f3fcf83b5993b776a4ef3d2de70fbe27aa2.zip |
Merge pull request #92 from mitmproxy/python3
Python3 & HTTP1 Refactor
32 files changed, 1441 insertions, 1604 deletions
diff --git a/.travis.yml b/.travis.yml index fd2fba3d..fa997542 100644 --- a/.travis.yml +++ b/.travis.yml @@ -15,6 +15,8 @@ matrix: - debian-sid packages: - libssl-dev + - python: 3.5 + script: "nosetests --with-cov --cov-report term-missing test/http/http1" - python: pypy - python: pypy env: OPENSSL=1.0.2 @@ -67,4 +69,4 @@ cache: - /home/travis/virtualenv/python2.7.9/lib/python2.7/site-packages - /home/travis/virtualenv/python2.7.9/bin - /home/travis/virtualenv/pypy-2.5.0/site-packages - - /home/travis/virtualenv/pypy-2.5.0/bin
\ No newline at end of file + - /home/travis/virtualenv/pypy-2.5.0/bin diff --git a/netlib/encoding.py b/netlib/encoding.py index f107eb5f..06830f2c 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -2,13 +2,13 @@ Utility functions for decoding response bodies. """ from __future__ import absolute_import -import cStringIO +from io import BytesIO import gzip import zlib __ALL__ = ["ENCODINGS"] -ENCODINGS = set(["identity", "gzip", "deflate"]) +ENCODINGS = {"identity", "gzip", "deflate"} def decode(e, content): @@ -42,7 +42,7 @@ def identity(content): def decode_gzip(content): - gfile = gzip.GzipFile(fileobj=cStringIO.StringIO(content)) + gfile = gzip.GzipFile(fileobj=BytesIO(content)) try: return gfile.read() except (IOError, EOFError): @@ -50,7 +50,7 @@ def decode_gzip(content): def encode_gzip(content): - s = cStringIO.StringIO() + s = BytesIO() gf = gzip.GzipFile(fileobj=s, mode='wb') gf.write(content) gf.close() diff --git a/netlib/exceptions.py b/netlib/exceptions.py new file mode 100644 index 00000000..e13af473 --- /dev/null +++ b/netlib/exceptions.py @@ -0,0 +1,32 @@ +""" +We try to be very hygienic regarding the exceptions we throw: +Every Exception netlib raises shall be a subclass of NetlibException. + + +See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/ +""" +from __future__ import absolute_import, print_function, division + + +class NetlibException(Exception): + """ + Base class for all exceptions thrown by netlib. + """ + def __init__(self, message=None): + super(NetlibException, self).__init__(message) + + +class ReadDisconnect(object): + """Immediate EOF""" + + +class HttpException(NetlibException): + pass + + +class HttpReadDisconnect(HttpException, ReadDisconnect): + pass + + +class HttpSyntaxException(HttpException): + pass diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 9b4b0e6b..d72884b3 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,2 +1,12 @@ -from exceptions import * -from semantics import * +from __future__ import absolute_import, print_function, division +from .models import Request, Response, Headers +from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2 +from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING +from . import http1, http2 + +__all__ = [ + "Request", "Response", "Headers", + "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2", + "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", + "http1", "http2", +] diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index fe1f0d14..2055f843 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -19,8 +19,8 @@ def parse_http_basic_auth(s): def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64(username + ":" + password) - return scheme + " " + v + v = binascii.b2a_base64(username + b":" + password) + return scheme + b" " + v class NullProxyAuth(object): diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py deleted file mode 100644 index 8a2bbebc..00000000 --- a/netlib/http/exceptions.py +++ /dev/null @@ -1,9 +0,0 @@ -class HttpError(Exception): - - def __init__(self, code, message): - super(HttpError, self).__init__(message) - self.code = code - - -class HttpErrorConnClosed(HttpError): - pass diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py index 6b5043af..2d33ff8a 100644 --- a/netlib/http/http1/__init__.py +++ b/netlib/http/http1/__init__.py @@ -1 +1,23 @@ -from protocol import * +from __future__ import absolute_import, print_function, division +from .read import ( + read_request, read_request_head, + read_response, read_response_head, + read_body, + connection_close, + expected_http_body_size, +) +from .assemble import ( + assemble_request, assemble_request_head, + assemble_response, assemble_response_head, +) + + +__all__ = [ + "read_request", "read_request_head", + "read_response", "read_response_head", + "read_body", + "connection_close", + "expected_http_body_size", + "assemble_request", "assemble_request_head", + "assemble_response", "assemble_response_head", +] diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py new file mode 100644 index 00000000..ace25d79 --- /dev/null +++ b/netlib/http/http1/assemble.py @@ -0,0 +1,103 @@ +from __future__ import absolute_import, print_function, division + +from ... import utils +from ...exceptions import HttpException +from .. import CONTENT_MISSING + + +def assemble_request(request): + if request.body == CONTENT_MISSING: + raise HttpException("Cannot assemble flow with CONTENT_MISSING") + head = assemble_request_head(request) + return head + request.body + + +def assemble_request_head(request): + first_line = _assemble_request_line(request) + headers = _assemble_request_headers(request) + return b"%s\r\n%s\r\n" % (first_line, headers) + + +def assemble_response(response): + if response.body == CONTENT_MISSING: + raise HttpException("Cannot assemble flow with CONTENT_MISSING") + head = assemble_response_head(response) + return head + response.body + + +def assemble_response_head(response, preserve_transfer_encoding=False): + first_line = _assemble_response_line(response) + headers = _assemble_response_headers(response, preserve_transfer_encoding) + return b"%s\r\n%s\r\n" % (first_line, headers) + + +def _assemble_request_line(request, form=None): + if form is None: + form = request.form_out + if form == "relative": + return b"%s %s %s" % ( + request.method, + request.path, + request.httpversion + ) + elif form == "authority": + return b"%s %s:%d %s" % ( + request.method, + request.host, + request.port, + request.httpversion + ) + elif form == "absolute": + return b"%s %s://%s:%d%s %s" % ( + request.method, + request.scheme, + request.host, + request.port, + request.path, + request.httpversion + ) + else: # pragma: nocover + raise RuntimeError("Invalid request form") + + +def _assemble_request_headers(request): + headers = request.headers.copy() + for k in request._headers_to_strip_off: + headers.pop(k, None) + if b"host" not in headers and request.scheme and request.host and request.port: + headers[b"Host"] = utils.hostport( + request.scheme, + request.host, + request.port + ) + + # If content is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if request.body or request.body == b"": + headers[b"Content-Length"] = str(len(request.body)).encode("ascii") + + return bytes(headers) + + +def _assemble_response_line(response): + return b"%s %d %s" % ( + response.httpversion, + response.status_code, + response.msg, + ) + + +def _assemble_response_headers(response, preserve_transfer_encoding=False): + # TODO: Remove preserve_transfer_encoding + headers = response.headers.copy() + for k in response._headers_to_strip_off: + headers.pop(k, None) + if not preserve_transfer_encoding: + headers.pop(b"Transfer-Encoding", None) + + # If body is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if response.body or response.body == b"": + headers[b"Content-Length"] = str(len(response.body)).encode("ascii") + + return bytes(headers) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py deleted file mode 100644 index cf1dffa3..00000000 --- a/netlib/http/http1/protocol.py +++ /dev/null @@ -1,586 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import string -import sys -import time - -from ... import utils, tcp, http -from .. import semantics, Headers -from ..exceptions import * - - -class TCPHandler(object): - - def __init__(self, rfile, wfile=None): - self.rfile = rfile - self.wfile = wfile - - -class HTTP1Protocol(semantics.ProtocolMixin): - - ALPN_PROTO_HTTP1 = 'http/1.1' - - def __init__(self, tcp_handler=None, rfile=None, wfile=None): - self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) - - def read_request( - self, - include_body=True, - body_size_limit=None, - allow_empty=False, - ): - """ - Parse an HTTP request from a file stream - - Args: - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. - """ - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - httpversion, host, port, scheme, method, path, headers, body = ( - None, None, None, None, None, None, None, None) - - request_line = self._get_request_line() - if not request_line: - if allow_empty: - return http.EmptyRequest() - else: - raise tcp.NetLibDisconnect() - - request_line_parts = self._parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - method, path, httpversion = request_line_parts - - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - elif method == 'CONNECT': - form_in = "authority" - r = self._parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - host, port, httpversion = r - path = None - else: - form_in = "absolute" - r = self._parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - _, scheme, host, port, path, _ = r - - headers = self.read_headers() - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get("expect", "").lower() - if expect_header == "100-continue" and httpversion == (1, 1): - self.tcp_handler.wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' - ) - self.tcp_handler.wfile.flush() - del headers['expect'] - - if include_body: - body = self.read_http_body( - headers, - body_size_limit, - method, - None, - True - ) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - timestamp_end = time.time() - - return http.Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - body, - timestamp_start, - timestamp_end, - ) - - def read_response( - self, - request_method, - body_size_limit=None, - include_body=True, - ): - """ - Returns an http.Response - - By default, both response header and body are read. - If include_body=False is specified, body may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - line = self.tcp_handler.rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = self.tcp_handler.rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = self.parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = self._parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = self.read_headers() - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - body = self.read_http_body( - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None body means the body should be - # read separately - body = None - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - if include_body: - timestamp_end = time.time() - else: - timestamp_end = None - - return http.Response( - httpversion, - code, - msg, - headers, - body, - timestamp_start=timestamp_start, - timestamp_end=timestamp_end, - ) - - def assemble_request(self, request): - assert isinstance(request, semantics.Request) - - if request.body == semantics.CONTENT_MISSING: - raise http.HttpError( - 502, - "Cannot assemble flow with CONTENT_MISSING" - ) - first_line = self._assemble_request_first_line(request) - headers = self._assemble_request_headers(request) - return "%s\r\n%s\r\n%s" % (first_line, headers, request.body) - - def assemble_response(self, response): - assert isinstance(response, semantics.Response) - - if response.body == semantics.CONTENT_MISSING: - raise http.HttpError( - 502, - "Cannot assemble flow with CONTENT_MISSING" - ) - first_line = self._assemble_response_first_line(response) - headers = self._assemble_response_headers(response) - return "%s\r\n%s\r\n%s" % (first_line, headers, response.body) - - def read_headers(self): - """ - Read a set of headers. - Stop once a blank line is reached. - - Return a Header object, or None if headers are invalid. - """ - ret = [] - while True: - line = self.tcp_handler.rfile.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - if not ret: - return None - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i + 1:].strip() - ret.append([name, value]) - else: - return None - return Headers(ret) - - - def read_http_body(self, *args, **kwargs): - return "".join(self.read_http_body_chunked(*args, **kwargs)) - - - def read_http_body_chunked( - self, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None - ): - """ - Read an HTTP message body: - headers: A Header object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = self.expected_http_body_size( - headers, is_request, request_method, response_code - ) - - if expected_size is None: - if self.has_chunked_encoding(headers): - # Python 3: yield from - for x in self._read_chunked(limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = self.tcp_handler.rfile.read(chunk_size) - yield content - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = self.tcp_handler.rfile.read(chunk_size) - if not content: - return - yield content - bytes_left -= chunk_size - not_done = self.tcp_handler.rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) - - @classmethod - def expected_http_body_size( - self, - headers, - is_request, - request_method, - response_code, - ): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if self.has_chunked_encoding(headers): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"]) - if size < 0: - raise ValueError() - return size - except ValueError: - return None - if is_request: - return 0 - return -1 - - - @classmethod - def has_chunked_encoding(self, headers): - return "chunked" in headers.get("transfer-encoding", "").lower() - - - def _get_request_line(self): - """ - Get a line, possibly preceded by a blank. - """ - line = self.tcp_handler.rfile.readline() - if line == "\r\n" or line == "\n": - # Possible leftover from previous message - line = self.tcp_handler.rfile.readline() - return line - - def _read_chunked(self, limit, is_request): - """ - Read a chunked HTTP body. - - May raise HttpError. - """ - # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - # 3.3 2. - total = 0 - code = 400 if is_request else 502 - while True: - line = self.tcp_handler.rfile.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n' and line != '\n': - try: - length = int(line, 16) - except ValueError: - raise HttpError( - code, - "Invalid chunked encoding length: %s" % line - ) - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large. Limit is %s," \ - " chunked content longer than %s" % (limit, total) - raise HttpError(code, msg) - chunk = self.tcp_handler.rfile.read(length) - suffix = self.tcp_handler.rfile.readline(5) - if suffix != '\r\n': - raise HttpError(code, "Malformed chunked body") - if length == 0: - return - yield chunk - - @classmethod - def _parse_http_protocol(self, line): - """ - Parse an HTTP protocol declaration. - Returns a (major, minor) tuple, or None. - """ - if not line.startswith("HTTP/"): - return None - _, version = line.split('/', 1) - if "." not in version: - return None - major, minor = version.split('.', 1) - try: - major = int(major) - minor = int(minor) - except ValueError: - return None - return major, minor - - @classmethod - def _parse_init(self, line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - httpversion = self._parse_http_protocol(protocol) - if not httpversion: - return None - if not utils.isascii(method): - return None - return method, url, httpversion - - @classmethod - def _parse_init_connect(self, line): - """ - Returns (host, port, httpversion) if line is a valid CONNECT line. - http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - """ - v = self._parse_init(line) - if not v: - return None - method, url, httpversion = v - - if method.upper() != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - try: - port = int(port) - except ValueError: - return None - if not utils.is_valid_port(port): - return None - if not utils.is_valid_host(host): - return None - return host, port, httpversion - - @classmethod - def _parse_init_proxy(self, line): - v = self._parse_init(line) - if not v: - return None - method, url, httpversion = v - - parts = utils.parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - return method, scheme, host, port, path, httpversion - - @classmethod - def _parse_init_http(self, line): - """ - Returns (method, url, httpversion) - """ - v = self._parse_init(line) - if not v: - return None - method, url, httpversion = v - if not utils.isascii(url): - return None - if not (url.startswith("/") or url == "*"): - return None - return method, url, httpversion - - @classmethod - def connection_close(self, httpversion, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1 Note that a connection should be - closed as well if the response has been read until end of the stream. - """ - # At first, check if we have an explicit Connection header. - if "connection" in headers: - toks = utils.get_header_tokens(headers, "connection") - if "close" in toks: - return True - elif "keep-alive" in toks: - return False - - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - return httpversion != (1, 1) - - @classmethod - def parse_response_line(self, line): - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if len(parts) != 3: - return None - proto, code, msg = parts - try: - code = int(code) - except ValueError: - return None - return (proto, code, msg) - - @classmethod - def _assemble_request_first_line(self, request): - return request.legacy_first_line() - - def _assemble_request_headers(self, request): - headers = request.headers.copy() - for k in request._headers_to_strip_off: - headers.pop(k, None) - if 'host' not in headers and request.scheme and request.host and request.port: - headers["Host"] = utils.hostport( - request.scheme, - request.host, - request.port - ) - - # If content is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if request.body or request.body == "": - headers["Content-Length"] = str(len(request.body)) - - return str(headers) - - def _assemble_response_first_line(self, response): - return 'HTTP/%s.%s %s %s' % ( - response.httpversion[0], - response.httpversion[1], - response.status_code, - response.msg, - ) - - def _assemble_response_headers( - self, - response, - preserve_transfer_encoding=False, - ): - headers = response.headers.copy() - for k in response._headers_to_strip_off: - headers.pop(k, None) - if not preserve_transfer_encoding: - headers.pop('Transfer-Encoding', None) - - # If body is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if response.body or response.body == "": - headers["Content-Length"] = str(len(response.body)) - - return str(headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py new file mode 100644 index 00000000..62025d15 --- /dev/null +++ b/netlib/http/http1/read.py @@ -0,0 +1,360 @@ +from __future__ import absolute_import, print_function, division +import time +import sys +import re + +from ... import utils +from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException +from .. import Request, Response, Headers +from netlib.tcp import NetLibDisconnect + + +def read_request(rfile, body_size_limit=None): + request = read_request_head(rfile) + expected_body_size = expected_http_body_size(request) + request.body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) + request.timestamp_end = time.time() + return request + + +def read_request_head(rfile): + """ + Parse an HTTP request head (request line + headers) from an input stream + + Args: + rfile: The input stream + + Returns: + The HTTP request object (without body) + + Raises: + HttpReadDisconnect: No bytes can be read from rfile. + HttpSyntaxException: The input is malformed HTTP. + HttpException: Any other error occured. + """ + timestamp_start = time.time() + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() + + form, method, scheme, host, port, path, http_version = _read_request_line(rfile) + headers = _read_headers(rfile) + + if hasattr(rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = rfile.first_byte_timestamp + + return Request( + form, method, scheme, host, port, path, http_version, headers, None, timestamp_start + ) + + +def read_response(rfile, request, body_size_limit=None): + response = read_response_head(rfile) + expected_body_size = expected_http_body_size(request, response) + response.body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) + response.timestamp_end = time.time() + return response + + +def read_response_head(rfile): + """ + Parse an HTTP response head (response line + headers) from an input stream + + Args: + rfile: The input stream + + Returns: + The HTTP request object (without body) + + Raises: + HttpReadDisconnect: No bytes can be read from rfile. + HttpSyntaxException: The input is malformed HTTP. + HttpException: Any other error occured. + """ + + timestamp_start = time.time() + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() + + http_version, status_code, message = _read_response_line(rfile) + headers = _read_headers(rfile) + + if hasattr(rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = rfile.first_byte_timestamp + + return Response(http_version, status_code, message, headers, None, timestamp_start) + + +def read_body(rfile, expected_size, limit=None, max_chunk_size=4096): + """ + Read an HTTP message body + + Args: + rfile: The input stream + expected_size: The expected body size (see :py:meth:`expected_body_size`) + limit: Maximum body size + max_chunk_size: Maximium chunk size that gets yielded + + Returns: + A generator that yields byte chunks of the content. + + Raises: + HttpException, if an error occurs + + Caveats: + max_chunk_size is not considered if the transfer encoding is chunked. + """ + if not limit or limit < 0: + limit = sys.maxsize + if not max_chunk_size: + max_chunk_size = limit + + if expected_size is None: + for x in _read_chunked(rfile, limit): + yield x + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpException( + "HTTP Body too large. " + "Limit is {}, content length was advertised as {}".format(limit, expected_size) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + if len(content) < chunk_size: + raise HttpException("Unexpected EOF") + yield content + bytes_left -= chunk_size + else: + bytes_left = limit + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + if not content: + return + yield content + bytes_left -= chunk_size + not_done = rfile.read(1) + if not_done: + raise HttpException("HTTP body too large. Limit is {}.".format(limit)) + + +def connection_close(http_version, headers): + """ + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1. + """ + # At first, check if we have an explicit Connection header. + if b"connection" in headers: + tokens = utils.get_header_tokens(headers, "connection") + if b"close" in tokens: + return True + elif b"keep-alive" in tokens: + return False + + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent + return http_version != b"HTTP/1.1" + + +def expected_http_body_size(request, response=None): + """ + Returns: + The expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding) + - -1, if all data should be read until end of stream. + + Raises: + HttpSyntaxException, if the content length header is invalid + """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if not response: + headers = request.headers + response_code = None + is_request = True + else: + headers = response.headers + response_code = response.status_code + is_request = False + + if is_request: + if headers.get(b"expect", b"").lower() == b"100-continue": + return 0 + else: + if request.method.upper() == b"HEAD": + return 0 + if 100 <= response_code <= 199: + return 0 + if response_code == 200 and request.method.upper() == b"CONNECT": + return 0 + if response_code in (204, 304): + return 0 + + if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): + return None + if b"content-length" in headers: + try: + size = int(headers[b"content-length"]) + if size < 0: + raise ValueError() + return size + except ValueError: + raise HttpSyntaxException("Unparseable Content Length") + if is_request: + return 0 + return -1 + + +def _get_first_line(rfile): + try: + line = rfile.readline() + if line == b"\r\n" or line == b"\n": + # Possible leftover from previous message + line = rfile.readline() + except NetLibDisconnect: + raise HttpReadDisconnect() + if not line: + raise HttpReadDisconnect() + line = line.strip() + try: + line.decode("ascii") + except ValueError: + raise HttpSyntaxException("Non-ascii characters in first line: {}".format(line)) + return line.strip() + + +def _read_request_line(rfile): + line = _get_first_line(rfile) + + try: + method, path, http_version = line.split(b" ") + + if path == b"*" or path.startswith(b"/"): + form = "relative" + scheme, host, port = None, None, None + elif method == b"CONNECT": + form = "authority" + host, port = _parse_authority_form(path) + scheme, path = None, None + else: + form = "absolute" + scheme, host, port, path = utils.parse_url(path) + + _check_http_version(http_version) + except ValueError: + raise HttpSyntaxException("Bad HTTP request line: {}".format(line)) + + return form, method, scheme, host, port, path, http_version + + +def _parse_authority_form(hostport): + """ + Returns (host, port) if hostport is a valid authority-form host specification. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + + Raises: + ValueError, if the input is malformed + """ + try: + host, port = hostport.split(b":") + port = int(port) + if not utils.is_valid_host(host) or not utils.is_valid_port(port): + raise ValueError() + except ValueError: + raise HttpSyntaxException("Invalid host specification: {}".format(hostport)) + + return host, port + + +def _read_response_line(rfile): + line = _get_first_line(rfile) + + try: + + parts = line.split(b" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append(b"") + + http_version, status_code, message = parts + status_code = int(status_code) + _check_http_version(http_version) + + except ValueError: + raise HttpSyntaxException("Bad HTTP response line: {}".format(line)) + + return http_version, status_code, message + + +def _check_http_version(http_version): + if not re.match(br"^HTTP/\d\.\d$", http_version): + raise HttpSyntaxException("Unknown HTTP version: {}".format(http_version)) + + +def _read_headers(rfile): + """ + Read a set of headers. + Stop once a blank line is reached. + + Returns: + A headers object + + Raises: + HttpSyntaxException + """ + ret = [] + while True: + line = rfile.readline() + if not line or line == b"\r\n" or line == b"\n": + break + if line[0] in b" \t": + if not ret: + raise HttpSyntaxException("Invalid headers") + # continued header + ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip() + else: + try: + name, value = line.split(b":", 1) + value = value.strip() + if not name or not value: + raise ValueError() + ret.append([name, value]) + except ValueError: + raise HttpSyntaxException("Invalid headers") + return Headers(ret) + + +def _read_chunked(rfile, limit=sys.maxsize): + """ + Read a HTTP body with chunked transfer encoding. + + Args: + rfile: the input file + limit: A positive integer + """ + total = 0 + while True: + line = rfile.readline(128) + if line == b"": + raise HttpException("Connection closed prematurely") + if line != b"\r\n" and line != b"\n": + try: + length = int(line, 16) + except ValueError: + raise HttpSyntaxException("Invalid chunked encoding length: {}".format(line)) + total += length + if total > limit: + raise HttpException( + "HTTP Body too large. Limit is {}, " + "chunked content longer than {}".format(limit, total) + ) + chunk = rfile.read(length) + suffix = rfile.readline(5) + if suffix != b"\r\n": + raise HttpSyntaxException("Malformed chunked body") + if length == 0: + return + yield chunk diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py index 5acf7696..7043d36f 100644 --- a/netlib/http/http2/__init__.py +++ b/netlib/http/http2/__init__.py @@ -1,2 +1,6 @@ -from frame import * -from protocol import * +from __future__ import absolute_import, print_function, division +from .connections import HTTP2Protocol + +__all__ = [ + "HTTP2Protocol" +] diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/connections.py index b6d376d3..5220d5d2 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/connections.py @@ -3,8 +3,8 @@ import itertools import time from hpack.hpack import Encoder, Decoder -from netlib import http, utils -from netlib.http import semantics +from ... import utils +from .. import Headers, Response, Request, ALPN_PROTO_H2 from . import frame @@ -15,7 +15,7 @@ class TCPHandler(object): self.wfile = wfile -class HTTP2Protocol(semantics.ProtocolMixin): +class HTTP2Protocol(object): ERROR_CODES = utils.BiDi( NO_ERROR=0x0, @@ -36,8 +36,6 @@ class HTTP2Protocol(semantics.ProtocolMixin): CLIENT_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - ALPN_PROTO_H2 = 'h2' - def __init__( self, tcp_handler=None, @@ -62,6 +60,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): def read_request( self, + __rfile, include_body=True, body_size_limit=None, allow_empty=False, @@ -111,7 +110,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): port = 80 if scheme == 'http' else 443 port = int(port) - request = http.Request( + request = Request( form_in, method, scheme, @@ -131,6 +130,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): def read_response( self, + __rfile, request_method='', body_size_limit=None, include_body=True, @@ -159,7 +159,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): else: timestamp_end = None - response = http.Response( + response = Response( (2, 0), int(headers.get(':status', 502)), "", @@ -172,8 +172,16 @@ class HTTP2Protocol(semantics.ProtocolMixin): return response + def assemble(self, message): + if isinstance(message, Request): + return self.assemble_request(message) + elif isinstance(message, Response): + return self.assemble_response(message) + else: + raise ValueError("HTTP message not supported.") + def assemble_request(self, request): - assert isinstance(request, semantics.Request) + assert isinstance(request, Request) authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host if self.tcp_handler.address.port != 443: @@ -200,7 +208,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): self._create_body(request.body, stream_id))) def assemble_response(self, response): - assert isinstance(response, semantics.Response) + assert isinstance(response, Response) headers = response.headers.copy() @@ -275,7 +283,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: + if alp != ALPN_PROTO_H2: raise NotImplementedError( "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True @@ -405,7 +413,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): else: self._handle_unexpected_frame(frm) - headers = http.Headers( + headers = Headers( [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)] ) diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py index b36b3adf..cb2cde99 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frame.py @@ -1,12 +1,31 @@ -import sys +from __future__ import absolute_import, print_function, division import struct from hpack.hpack import Encoder, Decoder -from .. import utils +from ...utils import BiDi +from ...exceptions import HttpSyntaxException -class FrameSizeError(Exception): - pass +ERROR_CODES = BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd +) + +CLIENT_CONNECTION_PREFACE = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + +ALPN_PROTO_H2 = b'h2' class Frame(object): @@ -30,7 +49,9 @@ class Frame(object): length=0, flags=FLAG_NO_FLAGS, stream_id=0x0): - valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) + valid_flags = 0 + for flag in self.VALID_FLAGS: + valid_flags |= flag if flags | valid_flags != valid_flags: raise ValueError('invalid flags detected.') @@ -61,7 +82,7 @@ class Frame(object): SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] if length > max_frame_size: - raise FrameSizeError( + raise HttpSyntaxException( "Frame size exceeded: %d, but only %d allowed." % ( length, max_frame_size)) @@ -80,7 +101,7 @@ class Frame(object): stream_id = fields[4] if raw_header[:4] == b'HTTP': # pragma no cover - print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" + raise HttpSyntaxException("Expected HTTP2 Frame, got HTTP/1 connection") cls._check_frame_size(length, state) @@ -339,7 +360,7 @@ class SettingsFrame(Frame): TYPE = 0x4 VALID_FLAGS = [Frame.FLAG_ACK] - SETTINGS = utils.BiDi( + SETTINGS = BiDi( SETTINGS_HEADER_TABLE_SIZE=0x1, SETTINGS_ENABLE_PUSH=0x2, SETTINGS_MAX_CONCURRENT_STREAMS=0x3, @@ -366,7 +387,7 @@ class SettingsFrame(Frame): def from_bytes(cls, state, length, flags, stream_id, payload): f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - for i in xrange(0, len(payload), 6): + for i in range(0, len(payload), 6): identifier, value = struct.unpack("!HL", payload[i:i + 6]) f.settings[identifier] = value diff --git a/netlib/http/semantics.py b/netlib/http/models.py index 5bb098a7..2d09535c 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/models.py @@ -1,20 +1,28 @@ -from __future__ import (absolute_import, print_function, division) -import UserDict +from __future__ import absolute_import, print_function, division import copy -import urllib -import urlparse -from .. import odict -from . import cookies, exceptions -from netlib import utils, encoding +from ..odict import ODict +from .. import utils, encoding +from ..utils import always_bytes, always_byte_args +from . import cookies -HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" -HDR_FORM_MULTIPART = "multipart/form-data" +import six +from six.moves import urllib +try: + from collections import MutableMapping +except ImportError: + from collections.abc import MutableMapping + +# TODO: Move somewhere else? +ALPN_PROTO_HTTP1 = b'http/1.1' +ALPN_PROTO_H2 = b'h2' +HDR_FORM_URLENCODED = b"application/x-www-form-urlencoded" +HDR_FORM_MULTIPART = b"multipart/form-data" CONTENT_MISSING = 0 -class Headers(object, UserDict.DictMixin): +class Headers(MutableMapping, object): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. @@ -62,10 +70,12 @@ class Headers(object, UserDict.DictMixin): For use with the "Set-Cookie" header, see :py:meth:`get_all`. """ + @always_byte_args("ascii") def __init__(self, fields=None, **headers): """ Args: - fields: (optional) list of ``(name, value)`` header tuples, e.g. ``[("Host","example.com")]`` + fields: (optional) list of ``(name, value)`` header tuples, + e.g. ``[("Host","example.com")]``. All names and values must be bytes. **headers: Additional headers to set. Will overwrite existing values from `fields`. For convenience, underscores in header names will be transformed to dashes - this behaviour does not extend to other methods. @@ -76,21 +86,25 @@ class Headers(object, UserDict.DictMixin): # content_type -> content-type headers = { - name.replace("_", "-"): value - for name, value in headers.iteritems() + name.encode("ascii").replace(b"_", b"-"): value + for name, value in six.iteritems(headers) } self.update(headers) - def __str__(self): - return "\r\n".join(": ".join(field) for field in self.fields) + "\r\n" + def __bytes__(self): + return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" + + if six.PY2: + __str__ = __bytes__ + @always_byte_args("ascii") def __getitem__(self, name): values = self.get_all(name) if not values: raise KeyError(name) - else: - return ", ".join(values) + return b", ".join(values) + @always_byte_args("ascii") def __setitem__(self, name, value): idx = self._index(name) @@ -101,6 +115,7 @@ class Headers(object, UserDict.DictMixin): else: self.fields.append([name, value]) + @always_byte_args("ascii") def __delitem__(self, name): if name not in self: raise KeyError(name) @@ -110,6 +125,19 @@ class Headers(object, UserDict.DictMixin): if name != field[0].lower() ] + def __iter__(self): + seen = set() + for name, _ in self.fields: + name_lower = name.lower() + if name_lower not in seen: + seen.add(name_lower) + yield name + + def __len__(self): + return len(set(name.lower() for name, _ in self.fields)) + + #__hash__ = object.__hash__ + def _index(self, name): name = name.lower() for i, field in enumerate(self.fields): @@ -117,16 +145,6 @@ class Headers(object, UserDict.DictMixin): return i return None - def keys(self): - seen = set() - names = [] - for name, _ in self.fields: - name_lower = name.lower() - if name_lower not in seen: - seen.add(name_lower) - names.append(name) - return names - def __eq__(self, other): if isinstance(other, Headers): return self.fields == other.fields @@ -135,6 +153,7 @@ class Headers(object, UserDict.DictMixin): def __ne__(self, other): return not self.__eq__(other) + @always_byte_args("ascii") def get_all(self, name): """ Like :py:meth:`get`, but does not fold multiple headers into a single one. @@ -142,8 +161,8 @@ class Headers(object, UserDict.DictMixin): See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 """ - name = name.lower() - values = [value for n, value in self.fields if n.lower() == name] + name_lower = name.lower() + values = [value for n, value in self.fields if n.lower() == name_lower] return values def set_all(self, name, values): @@ -151,6 +170,8 @@ class Headers(object, UserDict.DictMixin): Explicitly set multiple headers for the given key. See: :py:meth:`get_all` """ + name = always_bytes(name, "ascii") + values = (always_bytes(value, "ascii") for value in values) if name in self: del self[name] self.fields.extend( @@ -172,28 +193,6 @@ class Headers(object, UserDict.DictMixin): return cls([list(field) for field in state]) -class ProtocolMixin(object): - def read_request(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - def read_response(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - def assemble(self, message): - if isinstance(message, Request): - return self.assemble_request(message) - elif isinstance(message, Response): - return self.assemble_response(message) - else: - raise ValueError("HTTP message not supported.") - - def assemble_request(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - def assemble_response(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - class Request(object): # This list is adopted legacy code. # We probably don't need to strip off keep-alive. @@ -248,42 +247,14 @@ class Request(object): return False def __repr__(self): - # return "Request(%s - %s, %s)" % (self.method, self.host, self.path) - - return "<HTTPRequest: {0}>".format( - self.legacy_first_line()[:-9] - ) - - def legacy_first_line(self, form=None): - if form is None: - form = self.form_out - if form == "relative": - return '%s %s HTTP/%s.%s' % ( - self.method, - self.path, - self.httpversion[0], - self.httpversion[1], - ) - elif form == "authority": - return '%s %s:%s HTTP/%s.%s' % ( - self.method, - self.host, - self.port, - self.httpversion[0], - self.httpversion[1], - ) - elif form == "absolute": - return '%s %s://%s:%s%s HTTP/%s.%s' % ( - self.method, - self.scheme, - self.host, - self.port, - self.path, - self.httpversion[0], - self.httpversion[1], - ) + if self.host and self.port: + hostport = "{}:{}".format(self.host, self.port) else: - raise exceptions.HttpError(400, "Invalid request form") + hostport = "" + path = self.path or "" + return "HTTPRequest({} {}{})".format( + self.method, hostport, path + ) def anticache(self): """ @@ -336,7 +307,7 @@ class Request(object): return self.get_form_urlencoded() elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): return self.get_form_multipart() - return odict.ODict([]) + return ODict([]) def get_form_urlencoded(self): """ @@ -345,16 +316,16 @@ class Request(object): indicates non-form data. """ if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): - return odict.ODict(utils.urldecode(self.body)) - return odict.ODict([]) + return ODict(utils.urldecode(self.body)) + return ODict([]) def get_form_multipart(self): if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): - return odict.ODict( + return ODict( utils.multipartdecode( self.headers, self.body)) - return odict.ODict([]) + return ODict([]) def set_form_urlencoded(self, odict): """ @@ -373,8 +344,8 @@ class Request(object): Components are unquoted. """ - _, _, path, _, _, _ = urlparse.urlparse(self.url) - return [urllib.unquote(i) for i in path.split("/") if i] + _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + return [urllib.parse.unquote(i) for i in path.split(b"/") if i] def set_path_components(self, lst): """ @@ -382,10 +353,10 @@ class Request(object): Components are quoted. """ - lst = [urllib.quote(i, safe="") for i in lst] - path = "/" + "/".join(lst) - scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.url) - self.url = urlparse.urlunparse( + lst = [urllib.parse.quote(i, safe="") for i in lst] + path = b"/" + b"/".join(lst) + scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse( [scheme, netloc, path, params, query, fragment] ) @@ -393,18 +364,18 @@ class Request(object): """ Gets the request query string. Returns an ODict object. """ - _, _, _, _, query, _ = urlparse.urlparse(self.url) + _, _, _, _, query, _ = urllib.parse.urlparse(self.url) if query: - return odict.ODict(utils.urldecode(query)) - return odict.ODict([]) + return ODict(utils.urldecode(query)) + return ODict([]) def set_query(self, odict): """ Takes an ODict object, and sets the request query string. """ - scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.url) + scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) query = utils.urlencode(odict.lst) - self.url = urlparse.urlunparse( + self.url = urllib.parse.urlunparse( [scheme, netloc, path, params, query, fragment] ) @@ -421,18 +392,13 @@ class Request(object): but not the resolved name. This is disabled by default, as an attacker may spoof the host header to confuse an analyst. """ - host = None - if hostheader: - host = self.headers.get("Host") - if not host: - host = self.host - if host: + if hostheader and b"Host" in self.headers: try: - return host.encode("idna") + return self.headers[b"Host"].decode("idna") except ValueError: - return host - else: - return None + pass + if self.host: + return self.host.decode("idna") def pretty_url(self, hostheader): if self.form_out == "authority": # upstream proxy mode @@ -446,7 +412,7 @@ class Request(object): """ Returns a possibly empty netlib.odict.ODict object. """ - ret = odict.ODict() + ret = ODict() for i in self.headers.get_all("cookie"): ret.extend(cookies.parse_cookie_header(i)) return ret @@ -477,8 +443,10 @@ class Request(object): Parses a URL specification, and updates the Request's information accordingly. - Returns False if the URL was invalid, True if the request succeeded. + Raises: + ValueError if the URL was invalid """ + # TODO: Should handle incoming unicode here. parts = utils.parse_url(url) if not parts: raise ValueError("Invalid URL: %s" % url) @@ -495,32 +463,6 @@ class Request(object): self.body = content -class EmptyRequest(Request): - def __init__( - self, - form_in="", - method="", - scheme="", - host="", - port="", - path="", - httpversion=(0, 0), - headers=None, - body="" - ): - super(EmptyRequest, self).__init__( - form_in=form_in, - method=method, - scheme=scheme, - host=host, - port=port, - path=path, - httpversion=httpversion, - headers=headers, - body=body, - ) - - class Response(object): _headers_to_strip_off = [ 'Proxy-Connection', @@ -535,7 +477,6 @@ class Response(object): msg=None, headers=None, body=None, - sslinfo=None, timestamp_start=None, timestamp_end=None, ): @@ -548,7 +489,6 @@ class Response(object): self.msg = msg self.headers = headers self.body = body - self.sslinfo = sslinfo self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end @@ -591,7 +531,7 @@ class Response(object): if v: name, value, attrs = v ret.append([name, [value, attrs]]) - return odict.ODict(ret) + return ODict(ret) def set_cookies(self, odict): """ diff --git a/netlib/tcp.py b/netlib/tcp.py index 4a7f6153..1eb417b4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -834,14 +834,14 @@ class TCPServer(object): # If a thread has persisted after interpreter exit, the module might be # none. if traceback: - exc = traceback.format_exc() - print('-' * 40, file=fp) + exc = six.text_type(traceback.format_exc()) + print(u'-' * 40, file=fp) print( - "Error in processing of request from %s:%s" % ( + u"Error in processing of request from %s:%s" % ( client_address.host, client_address.port ), file=fp) print(exc, file=fp) - print('-' * 40, file=fp) + print(u'-' * 40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ diff --git a/netlib/tutils.py b/netlib/tutils.py index 951ef3d9..05791c49 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -1,18 +1,22 @@ -import cStringIO +from io import BytesIO import tempfile import os import time import shutil from contextlib import contextmanager +import six +import sys -from netlib import tcp, utils, http +from . import utils +from .http import Request, Response, Headers def treader(bytes): """ Construct a tcp.Read object from bytes. """ - fp = cStringIO.StringIO(bytes) + from . import tcp # TODO: move to top once cryptography is on Python 3.5 + fp = BytesIO(bytes) return tcp.Reader(fp) @@ -28,7 +32,24 @@ def tmpdir(*args, **kwargs): shutil.rmtree(temp_workdir) -def raises(exc, obj, *args, **kwargs): +def _check_exception(expected, actual, exc_tb): + if isinstance(expected, six.string_types): + if expected.lower() not in str(actual).lower(): + six.reraise(AssertionError, AssertionError( + "Expected %s, but caught %s" % ( + repr(expected), repr(actual) + ) + ), exc_tb) + else: + if not isinstance(actual, expected): + six.reraise(AssertionError, AssertionError( + "Expected %s, but caught %s %s" % ( + expected.__name__, actual.__class__.__name__, repr(actual) + ) + ), exc_tb) + + +def raises(expected_exception, obj=None, *args, **kwargs): """ Assert that a callable raises a specified exception. @@ -43,81 +64,68 @@ def raises(exc, obj, *args, **kwargs): :kwargs Arguments to be passed to the callable. """ - try: - ret = obj(*args, **kwargs) - except Exception as v: - if isinstance(exc, basestring): - if exc.lower() in str(v).lower(): - return - else: - raise AssertionError( - "Expected %s, but caught %s" % ( - repr(str(exc)), v - ) - ) + if obj is None: + return RaisesContext(expected_exception) + else: + try: + ret = obj(*args, **kwargs) + except Exception as actual: + _check_exception(expected_exception, actual, sys.exc_info()[2]) else: - if isinstance(v, exc): - return - else: - raise AssertionError( - "Expected %s, but caught %s %s" % ( - exc.__name__, v.__class__.__name__, str(v) - ) - ) - raise AssertionError("No exception raised. Return value: {}".format(ret)) + raise AssertionError("No exception raised. Return value: {}".format(ret)) -test_data = utils.Data(__name__) +class RaisesContext(object): + def __init__(self, expected_exception): + self.expected_exception = expected_exception -def treq(content="content", scheme="http", host="address", port=22): - """ - @return: libmproxy.protocol.http.HTTPRequest - """ - headers = http.Headers() - headers["header"] = "qvalue" - req = http.Request( - "relative", - "GET", - scheme, - host, - port, - "/path", - (1, 1), - headers, - content, - None, - None, - ) - return req + def __enter__(self): + return + def __exit__(self, exc_type, exc_val, exc_tb): + if not exc_type: + raise AssertionError("No exception raised.") + else: + _check_exception(self.expected_exception, exc_val, exc_tb) + return True -def treq_absolute(content="content"): - """ - @return: libmproxy.protocol.http.HTTPRequest - """ - r = treq(content) - r.form_in = r.form_out = "absolute" - r.host = "address" - r.port = 22 - r.scheme = "http" - return r +test_data = utils.Data(__name__) -def tresp(content="message"): + +def treq(**kwargs): """ - @return: libmproxy.protocol.http.HTTPResponse + Returns: + netlib.http.Request """ + default = dict( + form_in="relative", + method=b"GET", + scheme=b"http", + host=b"address", + port=22, + path=b"/path", + httpversion=b"HTTP/1.1", + headers=Headers(header=b"qvalue"), + body=b"content" + ) + default.update(kwargs) + return Request(**default) - headers = http.Headers() - headers["header_response"] = "svalue" - resp = http.semantics.Response( - (1, 1), - 200, - "OK", - headers, - content, +def tresp(**kwargs): + """ + Returns: + netlib.http.Response + """ + default = dict( + httpversion=b"HTTP/1.1", + status_code=200, + msg=b"OK", + headers=Headers(header_response=b"svalue"), + body=b"message", timestamp_start=time.time(), - timestamp_end=time.time(), + timestamp_end=time.time() ) - return resp + default.update(kwargs) + return Response(**default) diff --git a/netlib/utils.py b/netlib/utils.py index d6774419..a86b8019 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,17 +1,17 @@ -from __future__ import (absolute_import, print_function, division) +from __future__ import absolute_import, print_function, division import os.path -import cgi -import urllib -import urlparse -import string import re -import six +import string import unicodedata +import six + +from six.moves import urllib + -def isascii(s): +def isascii(bytes): try: - s.decode("ascii") + bytes.decode("ascii") except ValueError: return False return True @@ -40,12 +40,12 @@ def clean_bin(s, keep_spacing=True): ) else: if keep_spacing: - keep = b"\n\r\t" + keep = (9, 10, 13) # \t, \n, \r, else: - keep = b"" + keep = () return b"".join( - ch if (31 < ord(ch) < 127 or ch in keep) else b"." - for ch in s + six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." + for ch in six.iterbytes(s) ) @@ -149,10 +149,7 @@ class Data(object): return fullpath -def is_valid_port(port): - if not 0 <= port <= 65535: - return False - return True +_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE) def is_valid_host(host): @@ -160,53 +157,79 @@ def is_valid_host(host): host.decode("idna") except ValueError: return False - if "\0" in host: - return None - return True + if len(host) > 255: + return False + if host[-1] == ".": + host = host[:-1] + return all(_label_valid.match(x) for x in host.split(b".")) + + +def is_valid_port(port): + return 0 <= port <= 65535 + + +# PY2 workaround +def decode_parse_result(result, enc): + if hasattr(result, "decode"): + return result.decode(enc) + else: + return urllib.parse.ParseResult(*[x.decode(enc) for x in result]) + + +# PY2 workaround +def encode_parse_result(result, enc): + if hasattr(result, "encode"): + return result.encode(enc) + else: + return urllib.parse.ParseResult(*[x.encode(enc) for x in result]) def parse_url(url): """ - Returns a (scheme, host, port, path) tuple, or None on error. + URL-parsing function that checks that + - port is an integer 0-65535 + - host is a valid IDNA-encoded hostname with no null-bytes + - path is valid ASCII - Checks that: - port is an integer 0-65535 - host is a valid IDNA-encoded hostname with no null-bytes - path is valid ASCII + Args: + A URL (as bytes or as unicode) + + Returns: + A (scheme, host, port, path) tuple + + Raises: + ValueError, if the URL is not properly formatted. """ - try: - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - except ValueError: - return None - if not scheme: - return None - if '@' in netloc: - # FIXME: Consider what to do with the discarded credentials here Most - # probably we should extend the signature to return these as a separate - # value. - _, netloc = string.rsplit(netloc, '@', maxsplit=1) - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None + parsed = urllib.parse.urlparse(url) + + if not parsed.hostname: + raise ValueError("No hostname given") + + if isinstance(url, six.binary_type): + host = parsed.hostname + + # this should not raise a ValueError + decode_parse_result(parsed, "ascii") else: - host = netloc - if scheme.endswith("https"): - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path + host = parsed.hostname.encode("idna") + parsed = encode_parse_result(parsed, "ascii") + + port = parsed.port + if not port: + port = 443 if parsed.scheme == b"https" else 80 + + full_path = urllib.parse.urlunparse( + (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment) + ) + if not full_path.startswith(b"/"): + full_path = b"/" + full_path + if not is_valid_host(host): - return None - if not isascii(path): - return None + raise ValueError("Invalid Host") if not is_valid_port(port): - return None - return scheme, host, port, path + raise ValueError("Invalid Port") + + return parsed.scheme, host, port, full_path def get_header_tokens(headers, key): @@ -217,7 +240,7 @@ def get_header_tokens(headers, key): """ if key not in headers: return [] - tokens = headers[key].split(",") + tokens = headers[key].split(b",") return [token.strip() for token in tokens] @@ -228,7 +251,7 @@ def hostport(scheme, host, port): if (port, scheme) in [(80, "http"), (443, "https")]: return host else: - return "%s:%s" % (host, port) + return b"%s:%d" % (host, port) def unparse_url(scheme, host, port, path=""): @@ -243,14 +266,14 @@ def urlencode(s): Takes a list of (key, value) tuples and returns a urlencoded string. """ s = [tuple(i) for i in s] - return urllib.urlencode(s, False) + return urllib.parse.urlencode(s, False) def urldecode(s): """ Takes a urlencoded string and returns a list of (key, value) tuples. """ - return cgi.parse_qsl(s, keep_blank_values=True) + return urllib.parse.parse_qsl(s, keep_blank_values=True) def parse_content_type(c): @@ -267,14 +290,14 @@ def parse_content_type(c): ("text", "html", {"charset": "UTF-8"}) """ - parts = c.split(";", 1) - ts = parts[0].split("/", 1) + parts = c.split(b";", 1) + ts = parts[0].split(b"/", 1) if len(ts) != 2: return None d = {} if len(parts) == 2: - for i in parts[1].split(";"): - clause = i.split("=", 1) + for i in parts[1].split(b";"): + clause = i.split(b"=", 1) if len(clause) == 2: d[clause[0].strip()] = clause[1].strip() return ts[0].lower(), ts[1].lower(), d @@ -289,7 +312,7 @@ def multipartdecode(headers, content): v = parse_content_type(v) if not v: return [] - boundary = v[2].get("boundary") + boundary = v[2].get(b"boundary") if not boundary: return [] @@ -306,3 +329,20 @@ def multipartdecode(headers, content): r.append((key, value)) return r return [] + + +def always_bytes(unicode_or_bytes, encoding): + if isinstance(unicode_or_bytes, six.text_type): + return unicode_or_bytes.encode(encoding) + return unicode_or_bytes + + +def always_byte_args(encoding): + """Decorator that transparently encodes all arguments passed as unicode""" + def decorator(fun): + def _fun(*args, **kwargs): + args = [always_bytes(arg, encoding) for arg in args] + kwargs = {k: always_bytes(v, encoding) for k, v in six.iteritems(kwargs)} + return fun(*args, **kwargs) + return _fun + return decorator diff --git a/netlib/version_check.py b/netlib/version_check.py index 1d7e025c..9cf27eea 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -7,6 +7,7 @@ from __future__ import division, absolute_import, print_function import sys import inspect import os.path +import six import OpenSSL from . import version @@ -19,8 +20,8 @@ def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): # consider major and minor version. if version.IVERSION[:2] != mitmproxy_version[:2]: print( - "You are using mitmproxy %s with netlib %s. " - "Most likely, that won't work - please upgrade!" % ( + u"You are using mitmproxy %s with netlib %s. " + u"Most likely, that won't work - please upgrade!" % ( mitmproxy_version, version.VERSION ), file=fp @@ -29,13 +30,13 @@ def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): - min_version_str = ".".join(str(x) for x in min_version) + min_version_str = u".".join(six.text_type(x) for x in min_version) try: v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) except ValueError: print( - "Cannot parse pyOpenSSL version: {}" - "mitmproxy requires pyOpenSSL {} or greater.".format( + u"Cannot parse pyOpenSSL version: {}" + u"mitmproxy requires pyOpenSSL {} or greater.".format( OpenSSL.__version__, min_version_str ), file=fp @@ -43,15 +44,15 @@ def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): return if v < min_version: print( - "You are using an outdated version of pyOpenSSL: " - "mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str), + u"You are using an outdated version of pyOpenSSL: " + u"mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str), file=fp ) # Some users apparently have multiple versions of pyOpenSSL installed. # Report which one we got. pyopenssl_path = os.path.dirname(inspect.getfile(OpenSSL)) print( - "Your pyOpenSSL {} installation is located at {}".format( + u"Your pyOpenSSL {} installation is located at {}".format( OpenSSL.__version__, pyopenssl_path ), file=fp diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py index 5acf7696..1c143919 100644 --- a/netlib/websockets/__init__.py +++ b/netlib/websockets/__init__.py @@ -1,2 +1,2 @@ -from frame import * -from protocol import * +from .frame import * +from .protocol import * diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py new file mode 100644 index 00000000..8a0a54f1 --- /dev/null +++ b/test/http/http1/test_assemble.py @@ -0,0 +1,91 @@ +from __future__ import absolute_import, print_function, division +from netlib.exceptions import HttpException +from netlib.http import CONTENT_MISSING, Headers +from netlib.http.http1.assemble import ( + assemble_request, assemble_request_head, assemble_response, + assemble_response_head, _assemble_request_line, _assemble_request_headers, + _assemble_response_headers +) +from netlib.tutils import treq, raises, tresp + + +def test_assemble_request(): + c = assemble_request(treq()) == ( + b"GET /path HTTP/1.1\r\n" + b"header: qvalue\r\n" + b"Host: address:22\r\n" + b"Content-Length: 7\r\n" + b"\r\n" + b"content" + ) + + with raises(HttpException): + assemble_request(treq(body=CONTENT_MISSING)) + + +def test_assemble_request_head(): + c = assemble_request_head(treq()) + assert b"GET" in c + assert b"qvalue" in c + assert b"content" not in c + + +def test_assemble_response(): + c = assemble_response(tresp()) == ( + b"HTTP/1.1 200 OK\r\n" + b"header-response: svalue\r\n" + b"Content-Length: 7\r\n" + b"\r\n" + b"message" + ) + + with raises(HttpException): + assemble_response(tresp(body=CONTENT_MISSING)) + + +def test_assemble_response_head(): + c = assemble_response_head(tresp()) + assert b"200" in c + assert b"svalue" in c + assert b"message" not in c + + +def test_assemble_request_line(): + assert _assemble_request_line(treq()) == b"GET /path HTTP/1.1" + + authority_request = treq(method=b"CONNECT", form_in="authority") + assert _assemble_request_line(authority_request) == b"CONNECT address:22 HTTP/1.1" + + absolute_request = treq(form_in="absolute") + assert _assemble_request_line(absolute_request) == b"GET http://address:22/path HTTP/1.1" + + with raises(RuntimeError): + _assemble_request_line(treq(), "invalid_form") + + +def test_assemble_request_headers(): + # https://github.com/mitmproxy/mitmproxy/issues/186 + r = treq(body=b"") + r.headers[b"Transfer-Encoding"] = b"chunked" + c = _assemble_request_headers(r) + assert b"Content-Length" in c + assert b"Transfer-Encoding" not in c + + assert b"Host" in _assemble_request_headers(treq(headers=Headers())) + + assert b"Proxy-Connection" not in _assemble_request_headers( + treq(headers=Headers(Proxy_Connection="42")) + ) + + +def test_assemble_response_headers(): + # https://github.com/mitmproxy/mitmproxy/issues/186 + r = tresp(body=b"") + r.headers["Transfer-Encoding"] = b"chunked" + c = _assemble_response_headers(r) + assert b"Content-Length" in c + assert b"Transfer-Encoding" not in c + + assert b"Proxy-Connection" not in _assemble_response_headers( + tresp(headers=Headers(Proxy_Connection=b"42")) + ) diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index f7c615bd..e69de29b 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -1,497 +0,0 @@ -import cStringIO -import textwrap - -from netlib import http, odict, tcp, tutils -from netlib.http import semantics, Headers -from netlib.http.http1 import HTTP1Protocol -from ... import tservers - - -class NoContentLengthHTTPHandler(tcp.BaseHandler): - def handle(self): - self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") - self.wfile.flush() - - -def mock_protocol(data=''): - rfile = cStringIO.StringIO(data) - wfile = cStringIO.StringIO() - return HTTP1Protocol(rfile=rfile, wfile=wfile) - - -def match_http_string(data): - return textwrap.dedent(data).strip().replace('\n', '\r\n') - - -def test_stripped_chunked_encoding_no_content(): - """ - https://github.com/mitmproxy/mitmproxy/issues/186 - """ - - r = tutils.treq(content="") - r.headers["Transfer-Encoding"] = "chunked" - assert "Content-Length" in mock_protocol()._assemble_request_headers(r) - - r = tutils.tresp(content="") - r.headers["Transfer-Encoding"] = "chunked" - assert "Content-Length" in mock_protocol()._assemble_response_headers(r) - - -def test_has_chunked_encoding(): - headers = http.Headers() - assert not HTTP1Protocol.has_chunked_encoding(headers) - headers["transfer-encoding"] = "chunked" - assert HTTP1Protocol.has_chunked_encoding(headers) - - -def test_read_chunked(): - headers = http.Headers() - headers["transfer-encoding"] = "chunked" - - data = "1\r\na\r\n0\r\n" - tutils.raises( - "malformed chunked body", - mock_protocol(data).read_http_body, - headers, None, "GET", None, True - ) - - data = "1\r\na\r\n0\r\n\r\n" - assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "a" - - data = "\r\n\r\n1\r\na\r\n0\r\n\r\n" - assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "a" - - data = "\r\n" - tutils.raises( - "closed prematurely", - mock_protocol(data).read_http_body, - headers, None, "GET", None, True - ) - - data = "1\r\nfoo" - tutils.raises( - "malformed chunked body", - mock_protocol(data).read_http_body, - headers, None, "GET", None, True - ) - - data = "foo\r\nfoo" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, None, "GET", None, True - ) - - data = "5\r\naaaaa\r\n0\r\n\r\n" - tutils.raises("too large", mock_protocol(data).read_http_body, headers, 2, "GET", None, True) - - -def test_connection_close(): - headers = Headers() - assert HTTP1Protocol.connection_close((1, 0), headers) - assert not HTTP1Protocol.connection_close((1, 1), headers) - - headers["connection"] = "keep-alive" - assert not HTTP1Protocol.connection_close((1, 1), headers) - - headers["connection"] = "close" - assert HTTP1Protocol.connection_close((1, 1), headers) - - -def test_read_http_body_request(): - headers = Headers() - data = "testing" - assert mock_protocol(data).read_http_body(headers, None, "GET", None, True) == "" - - -def test_read_http_body_response(): - headers = Headers() - data = "testing" - assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing" - - -def test_read_http_body(): - # test default case - headers = Headers() - headers["content-length"] = "7" - data = "testing" - assert mock_protocol(data).read_http_body(headers, None, "GET", 200, False) == "testing" - - # test content length: invalid header - headers["content-length"] = "foo" - data = "testing" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, None, "GET", 200, False - ) - - # test content length: invalid header #2 - headers["content-length"] = "-1" - data = "testing" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, None, "GET", 200, False - ) - - # test content length: content length > actual content - headers["content-length"] = "5" - data = "testing" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, 4, "GET", 200, False - ) - - # test content length: content length < actual content - data = "testing" - assert len(mock_protocol(data).read_http_body(headers, None, "GET", 200, False)) == 5 - - # test no content length: limit > actual content - headers = Headers() - data = "testing" - assert len(mock_protocol(data).read_http_body(headers, 100, "GET", 200, False)) == 7 - - # test no content length: limit < actual content - data = "testing" - tutils.raises( - http.HttpError, - mock_protocol(data).read_http_body, - headers, 4, "GET", 200, False - ) - - # test chunked - headers = Headers() - headers["transfer-encoding"] = "chunked" - data = "5\r\naaaaa\r\n0\r\n\r\n" - assert mock_protocol(data).read_http_body(headers, 100, "GET", 200, False) == "aaaaa" - - -def test_expected_http_body_size(): - # gibber in the content-length field - headers = Headers(content_length="foo") - assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) is None - # negative number in the content-length field - headers = Headers(content_length="-7") - assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) is None - # explicit length - headers = Headers(content_length="5") - assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) == 5 - # no length - headers = Headers() - assert HTTP1Protocol.expected_http_body_size(headers, False, "GET", 200) == -1 - # no length request - headers = Headers() - assert HTTP1Protocol.expected_http_body_size(headers, True, "GET", None) == 0 - - -def test_get_request_line(): - data = "\nfoo" - p = mock_protocol(data) - assert p._get_request_line() == "foo" - assert not p._get_request_line() - - -def test_parse_http_protocol(): - assert HTTP1Protocol._parse_http_protocol("HTTP/1.1") == (1, 1) - assert HTTP1Protocol._parse_http_protocol("HTTP/0.0") == (0, 0) - assert not HTTP1Protocol._parse_http_protocol("HTTP/a.1") - assert not HTTP1Protocol._parse_http_protocol("HTTP/1.a") - assert not HTTP1Protocol._parse_http_protocol("foo/0.0") - assert not HTTP1Protocol._parse_http_protocol("HTTP/x") - - -def test_parse_init_connect(): - assert HTTP1Protocol._parse_init_connect("CONNECT host.com:443 HTTP/1.0") - assert not HTTP1Protocol._parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") - assert not HTTP1Protocol._parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") - assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:444444 HTTP/1.0") - assert not HTTP1Protocol._parse_init_connect("bogus") - assert not HTTP1Protocol._parse_init_connect("GET host.com:443 HTTP/1.0") - assert not HTTP1Protocol._parse_init_connect("CONNECT host.com443 HTTP/1.0") - assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:443 foo/1.0") - assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:foo HTTP/1.0") - - -def test_parse_init_proxy(): - u = "GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = HTTP1Protocol._parse_init_proxy(u) - assert m == "GET" - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - assert httpversion == (1, 1) - - u = "G\xfeET http://foo.com:8888/test HTTP/1.1" - assert not HTTP1Protocol._parse_init_proxy(u) - - assert not HTTP1Protocol._parse_init_proxy("invalid") - assert not HTTP1Protocol._parse_init_proxy("GET invalid HTTP/1.1") - assert not HTTP1Protocol._parse_init_proxy("GET http://foo.com:8888/test foo/1.1") - - -def test_parse_init_http(): - u = "GET /test HTTP/1.1" - m, u, httpversion = HTTP1Protocol._parse_init_http(u) - assert m == "GET" - assert u == "/test" - assert httpversion == (1, 1) - - u = "G\xfeET /test HTTP/1.1" - assert not HTTP1Protocol._parse_init_http(u) - - assert not HTTP1Protocol._parse_init_http("invalid") - assert not HTTP1Protocol._parse_init_http("GET invalid HTTP/1.1") - assert not HTTP1Protocol._parse_init_http("GET /test foo/1.1") - assert not HTTP1Protocol._parse_init_http("GET /test\xc0 HTTP/1.1") - - -class TestReadHeaders: - - def _read(self, data, verbatim=False): - if not verbatim: - data = textwrap.dedent(data) - data = data.strip() - return mock_protocol(data).read_headers() - - def test_read_simple(self): - data = """ - Header: one - Header2: two - \r\n - """ - headers = self._read(data) - assert headers.fields == [["Header", "one"], ["Header2", "two"]] - - def test_read_multi(self): - data = """ - Header: one - Header: two - \r\n - """ - headers = self._read(data) - assert headers.fields == [["Header", "one"], ["Header", "two"]] - - def test_read_continued(self): - data = """ - Header: one - \ttwo - Header2: three - \r\n - """ - headers = self._read(data) - assert headers.fields == [["Header", "one\r\n two"], ["Header2", "three"]] - - def test_read_continued_err(self): - data = "\tfoo: bar\r\n" - assert self._read(data, True) is None - - def test_read_err(self): - data = """ - foo - """ - assert self._read(data) is None - - -class TestReadRequest(object): - - def tst(self, data, **kwargs): - return mock_protocol(data).read_request(**kwargs) - - def test_invalid(self): - tutils.raises( - "bad http request", - self.tst, - "xxx" - ) - tutils.raises( - "bad http request line", - self.tst, - "get /\xff HTTP/1.1" - ) - tutils.raises( - "invalid headers", - self.tst, - "get / HTTP/1.1\r\nfoo" - ) - tutils.raises( - tcp.NetLibDisconnect, - self.tst, - "\r\n" - ) - - def test_empty(self): - v = self.tst("", allow_empty=True) - assert isinstance(v, semantics.EmptyRequest) - - def test_asterisk_form_in(self): - v = self.tst("OPTIONS * HTTP/1.1") - assert v.form_in == "relative" - assert v.method == "OPTIONS" - - def test_absolute_form_in(self): - tutils.raises( - "Bad HTTP request line", - self.tst, - "GET oops-no-protocol.com HTTP/1.1" - ) - v = self.tst("GET http://address:22/ HTTP/1.1") - assert v.form_in == "absolute" - assert v.port == 22 - assert v.host == "address" - assert v.scheme == "http" - - def test_connect(self): - tutils.raises( - "Bad HTTP request line", - self.tst, - "CONNECT oops-no-port.com HTTP/1.1" - ) - v = self.tst("CONNECT foo.com:443 HTTP/1.1") - assert v.form_in == "authority" - assert v.method == "CONNECT" - assert v.port == 443 - assert v.host == "foo.com" - - def test_expect(self): - data = "".join( - "GET / HTTP/1.1\r\n" - "Content-Length: 3\r\n" - "Expect: 100-continue\r\n\r\n" - "foobar" - ) - - p = mock_protocol(data) - v = p.read_request() - assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" - assert v.body == "foo" - assert p.tcp_handler.rfile.read(3) == "bar" - - -class TestReadResponse(object): - def tst(self, data, method, body_size_limit, include_body=True): - data = textwrap.dedent(data) - return mock_protocol(data).read_response( - method, body_size_limit, include_body=include_body - ) - - def test_errors(self): - tutils.raises("server disconnect", self.tst, "", "GET", None) - tutils.raises("invalid server response", self.tst, "foo", "GET", None) - - def test_simple(self): - data = """ - HTTP/1.1 200 - """ - assert self.tst(data, "GET", None) == http.Response( - (1, 1), 200, '', Headers(), '' - ) - - def test_simple_message(self): - data = """ - HTTP/1.1 200 OK - """ - assert self.tst(data, "GET", None) == http.Response( - (1, 1), 200, 'OK', Headers(), '' - ) - - def test_invalid_http_version(self): - data = """ - HTTP/x 200 OK - """ - tutils.raises("invalid http version", self.tst, data, "GET", None) - - def test_invalid_status_code(self): - data = """ - HTTP/1.1 xx OK - """ - tutils.raises("invalid server response", self.tst, data, "GET", None) - - def test_valid_with_continue(self): - data = """ - HTTP/1.1 100 CONTINUE - - HTTP/1.1 200 OK - """ - assert self.tst(data, "GET", None) == http.Response( - (1, 1), 100, 'CONTINUE', Headers(), '' - ) - - def test_simple_body(self): - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert self.tst(data, "GET", None).body == 'foo' - assert self.tst(data, "HEAD", None).body == '' - - def test_invalid_headers(self): - data = """ - HTTP/1.1 200 OK - \tContent-Length: 3 - - foo - """ - tutils.raises("invalid headers", self.tst, data, "GET", None) - - def test_without_body(self): - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert self.tst(data, "GET", None, include_body=False).body is None - - -class TestReadResponseNoContentLength(tservers.ServerTestBase): - handler = NoContentLengthHTTPHandler - - def test_no_content_length(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - resp = HTTP1Protocol(c).read_response("GET", None) - assert resp.body == "bar\r\n\r\n" - - -class TestAssembleRequest(object): - def test_simple(self): - req = tutils.treq() - b = HTTP1Protocol().assemble_request(req) - assert b == match_http_string(""" - GET /path HTTP/1.1 - header: qvalue - Host: address:22 - Content-Length: 7 - - content""") - - def test_body_missing(self): - req = tutils.treq(content=semantics.CONTENT_MISSING) - tutils.raises(http.HttpError, HTTP1Protocol().assemble_request, req) - - def test_not_a_request(self): - tutils.raises(AssertionError, HTTP1Protocol().assemble_request, 'foo') - - -class TestAssembleResponse(object): - def test_simple(self): - resp = tutils.tresp() - b = HTTP1Protocol().assemble_response(resp) - assert b == match_http_string(""" - HTTP/1.1 200 OK - header_response: svalue - Content-Length: 7 - - message""") - - def test_body_missing(self): - resp = tutils.tresp(content=semantics.CONTENT_MISSING) - tutils.raises(http.HttpError, HTTP1Protocol().assemble_response, resp) - - def test_not_a_request(self): - tutils.raises(AssertionError, HTTP1Protocol().assemble_response, 'foo') diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py new file mode 100644 index 00000000..55def2a5 --- /dev/null +++ b/test/http/http1/test_read.py @@ -0,0 +1,317 @@ +from __future__ import absolute_import, print_function, division +from io import BytesIO +import textwrap + +from mock import Mock + +from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect +from netlib.http import Headers +from netlib.http.http1.read import ( + read_request, read_response, read_request_head, + read_response_head, read_body, connection_close, expected_http_body_size, _get_first_line, + _read_request_line, _parse_authority_form, _read_response_line, _check_http_version, + _read_headers, _read_chunked +) +from netlib.tutils import treq, tresp, raises + + +def test_read_request(): + rfile = BytesIO(b"GET / HTTP/1.1\r\n\r\nskip") + r = read_request(rfile) + assert r.method == b"GET" + assert r.body == b"" + assert r.timestamp_end + assert rfile.read() == b"skip" + + +def test_read_request_head(): + rfile = BytesIO( + b"GET / HTTP/1.1\r\n" + b"Content-Length: 4\r\n" + b"\r\n" + b"skip" + ) + rfile.reset_timestamps = Mock() + rfile.first_byte_timestamp = 42 + r = read_request_head(rfile) + assert r.method == b"GET" + assert r.headers["Content-Length"] == b"4" + assert r.body is None + assert rfile.reset_timestamps.called + assert r.timestamp_start == 42 + assert rfile.read() == b"skip" + + +def test_read_response(): + req = treq() + rfile = BytesIO(b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody") + r = read_response(rfile, req) + assert r.status_code == 418 + assert r.body == b"body" + assert r.timestamp_end + + +def test_read_response_head(): + rfile = BytesIO( + b"HTTP/1.1 418 I'm a teapot\r\n" + b"Content-Length: 4\r\n" + b"\r\n" + b"skip" + ) + rfile.reset_timestamps = Mock() + rfile.first_byte_timestamp = 42 + r = read_response_head(rfile) + assert r.status_code == 418 + assert r.headers["Content-Length"] == b"4" + assert r.body is None + assert rfile.reset_timestamps.called + assert r.timestamp_start == 42 + assert rfile.read() == b"skip" + + +class TestReadBody(object): + def test_chunked(self): + rfile = BytesIO(b"3\r\nfoo\r\n0\r\n\r\nbar") + body = b"".join(read_body(rfile, None)) + assert body == b"foo" + assert rfile.read() == b"bar" + + + def test_known_size(self): + rfile = BytesIO(b"foobar") + body = b"".join(read_body(rfile, 3)) + assert body == b"foo" + assert rfile.read() == b"bar" + + + def test_known_size_limit(self): + rfile = BytesIO(b"foobar") + with raises(HttpException): + b"".join(read_body(rfile, 3, 2)) + + def test_known_size_too_short(self): + rfile = BytesIO(b"foo") + with raises(HttpException): + b"".join(read_body(rfile, 6)) + + def test_unknown_size(self): + rfile = BytesIO(b"foobar") + body = b"".join(read_body(rfile, -1)) + assert body == b"foobar" + + + def test_unknown_size_limit(self): + rfile = BytesIO(b"foobar") + with raises(HttpException): + b"".join(read_body(rfile, -1, 3)) + + +def test_connection_close(): + headers = Headers() + assert connection_close(b"HTTP/1.0", headers) + assert not connection_close(b"HTTP/1.1", headers) + + headers["connection"] = "keep-alive" + assert not connection_close(b"HTTP/1.1", headers) + + headers["connection"] = "close" + assert connection_close(b"HTTP/1.1", headers) + + +def test_expected_http_body_size(): + # Expect: 100-continue + assert expected_http_body_size( + treq(headers=Headers(expect=b"100-continue", content_length=b"42")) + ) == 0 + + # http://tools.ietf.org/html/rfc7230#section-3.3 + assert expected_http_body_size( + treq(method=b"HEAD"), + tresp(headers=Headers(content_length=b"42")) + ) == 0 + assert expected_http_body_size( + treq(method=b"CONNECT"), + tresp() + ) == 0 + for code in (100, 204, 304): + assert expected_http_body_size( + treq(), + tresp(status_code=code) + ) == 0 + + # chunked + assert expected_http_body_size( + treq(headers=Headers(transfer_encoding=b"chunked")), + ) is None + + # explicit length + for l in (b"foo", b"-7"): + with raises(HttpSyntaxException): + expected_http_body_size( + treq(headers=Headers(content_length=l)) + ) + assert expected_http_body_size( + treq(headers=Headers(content_length=b"42")) + ) == 42 + + # no length + assert expected_http_body_size( + treq() + ) == 0 + assert expected_http_body_size( + treq(), tresp() + ) == -1 + + +def test_get_first_line(): + rfile = BytesIO(b"foo\r\nbar") + assert _get_first_line(rfile) == b"foo" + + rfile = BytesIO(b"\r\nfoo\r\nbar") + assert _get_first_line(rfile) == b"foo" + + with raises(HttpReadDisconnect): + rfile = BytesIO(b"") + _get_first_line(rfile) + + with raises(HttpSyntaxException): + rfile = BytesIO(b"GET /\xff HTTP/1.1") + _get_first_line(rfile) + + +def test_read_request_line(): + def t(b): + return _read_request_line(BytesIO(b)) + + assert (t(b"GET / HTTP/1.1") == + ("relative", b"GET", None, None, None, b"/", b"HTTP/1.1")) + assert (t(b"OPTIONS * HTTP/1.1") == + ("relative", b"OPTIONS", None, None, None, b"*", b"HTTP/1.1")) + assert (t(b"CONNECT foo:42 HTTP/1.1") == + ("authority", b"CONNECT", None, b"foo", 42, None, b"HTTP/1.1")) + assert (t(b"GET http://foo:42/bar HTTP/1.1") == + ("absolute", b"GET", b"http", b"foo", 42, b"/bar", b"HTTP/1.1")) + + with raises(HttpSyntaxException): + t(b"GET / WTF/1.1") + with raises(HttpSyntaxException): + t(b"this is not http") + + +def test_parse_authority_form(): + assert _parse_authority_form(b"foo:42") == (b"foo", 42) + with raises(HttpSyntaxException): + _parse_authority_form(b"foo") + with raises(HttpSyntaxException): + _parse_authority_form(b"foo:bar") + with raises(HttpSyntaxException): + _parse_authority_form(b"foo:99999999") + with raises(HttpSyntaxException): + _parse_authority_form(b"f\x00oo:80") + + +def test_read_response_line(): + def t(b): + return _read_response_line(BytesIO(b)) + + assert t(b"HTTP/1.1 200 OK") == (b"HTTP/1.1", 200, b"OK") + assert t(b"HTTP/1.1 200") == (b"HTTP/1.1", 200, b"") + with raises(HttpSyntaxException): + assert t(b"HTTP/1.1") + + with raises(HttpSyntaxException): + t(b"HTTP/1.1 OK OK") + with raises(HttpSyntaxException): + t(b"WTF/1.1 200 OK") + + +def test_check_http_version(): + _check_http_version(b"HTTP/0.9") + _check_http_version(b"HTTP/1.0") + _check_http_version(b"HTTP/1.1") + _check_http_version(b"HTTP/2.0") + with raises(HttpSyntaxException): + _check_http_version(b"WTF/1.0") + with raises(HttpSyntaxException): + _check_http_version(b"HTTP/1.10") + with raises(HttpSyntaxException): + _check_http_version(b"HTTP/1.b") + + +class TestReadHeaders(object): + @staticmethod + def _read(data): + return _read_headers(BytesIO(data)) + + def test_read_simple(self): + data = ( + b"Header: one\r\n" + b"Header2: two\r\n" + b"\r\n" + ) + headers = self._read(data) + assert headers.fields == [[b"Header", b"one"], [b"Header2", b"two"]] + + def test_read_multi(self): + data = ( + b"Header: one\r\n" + b"Header: two\r\n" + b"\r\n" + ) + headers = self._read(data) + assert headers.fields == [[b"Header", b"one"], [b"Header", b"two"]] + + def test_read_continued(self): + data = ( + b"Header: one\r\n" + b"\ttwo\r\n" + b"Header2: three\r\n" + b"\r\n" + ) + headers = self._read(data) + assert headers.fields == [[b"Header", b"one\r\n two"], [b"Header2", b"three"]] + + def test_read_continued_err(self): + data = b"\tfoo: bar\r\n" + with raises(HttpSyntaxException): + self._read(data) + + def test_read_err(self): + data = b"foo" + with raises(HttpSyntaxException): + self._read(data) + + def test_read_empty_name(self): + data = b":foo" + with raises(HttpSyntaxException): + self._read(data) + +def test_read_chunked(): + req = treq(body=None) + req.headers["Transfer-Encoding"] = "chunked" + + data = b"1\r\na\r\n0\r\n" + with raises(HttpSyntaxException): + b"".join(_read_chunked(BytesIO(data))) + + data = b"1\r\na\r\n0\r\n\r\n" + assert b"".join(_read_chunked(BytesIO(data))) == b"a" + + data = b"\r\n\r\n1\r\na\r\n1\r\nb\r\n0\r\n\r\n" + assert b"".join(_read_chunked(BytesIO(data))) == b"ab" + + data = b"\r\n" + with raises("closed prematurely"): + b"".join(_read_chunked(BytesIO(data))) + + data = b"1\r\nfoo" + with raises("malformed chunked body"): + b"".join(_read_chunked(BytesIO(data))) + + data = b"foo\r\nfoo" + with raises(HttpSyntaxException): + b"".join(_read_chunked(BytesIO(data))) + + data = b"5\r\naaaaa\r\n0\r\n\r\n" + with raises("too large"): + b"".join(_read_chunked(BytesIO(data), limit=2)) diff --git a/test/http/http2/test_frames.py b/test/http/http2/test_frames.py index 5d5cb0ba..4c89b023 100644 --- a/test/http/http2/test_frames.py +++ b/test/http/http2/test_frames.py @@ -1,4 +1,4 @@ -import cStringIO +from io import BytesIO from nose.tools import assert_equal from netlib import tcp, tutils @@ -7,7 +7,7 @@ from netlib.http.http2.frame import * def hex_to_file(data): data = data.decode('hex') - return tcp.Reader(cStringIO.StringIO(data)) + return tcp.Reader(BytesIO(data)) def test_invalid_flags(): @@ -39,7 +39,7 @@ def test_too_large_frames(): flags=Frame.FLAG_END_STREAM, stream_id=0x1234567, payload='foobar' * 3000) - tutils.raises(FrameSizeError, f.to_bytes) + tutils.raises(HttpSyntaxException, f.to_bytes) def test_data_frame_to_bytes(): diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 2b7d7958..a369eb49 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -2,21 +2,21 @@ import OpenSSL import mock from netlib import tcp, http, tutils -from netlib.http import http2, Headers -from netlib.http.http2 import HTTP2Protocol +from netlib.http import Headers +from netlib.http.http2.connections import HTTP2Protocol, TCPHandler from netlib.http.http2.frame import * from ... import tservers class TestTCPHandlerWrapper: def test_wrapped(self): - h = http2.TCPHandler(rfile='foo', wfile='bar') + h = TCPHandler(rfile='foo', wfile='bar') p = HTTP2Protocol(h) assert p.tcp_handler.rfile == 'foo' assert p.tcp_handler.wfile == 'bar' def test_direct(self): p = HTTP2Protocol(rfile='foo', wfile='bar') - assert isinstance(p.tcp_handler, http2.TCPHandler) + assert isinstance(p.tcp_handler, TCPHandler) assert p.tcp_handler.rfile == 'foo' assert p.tcp_handler.wfile == 'bar' @@ -32,8 +32,8 @@ class EchoHandler(tcp.BaseHandler): class TestProtocol: - @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") - @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") def test_perform_connection_preface(self, mock_client_method, mock_server_method): protocol = HTTP2Protocol(is_server=False) protocol.connection_preface_performed = True @@ -46,8 +46,8 @@ class TestProtocol: assert mock_client_method.called assert not mock_server_method.called - @mock.patch("netlib.http.http2.HTTP2Protocol.perform_server_connection_preface") - @mock.patch("netlib.http.http2.HTTP2Protocol.perform_client_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_server_connection_preface") + @mock.patch("netlib.http.http2.connections.HTTP2Protocol.perform_client_connection_preface") def test_perform_connection_preface_server(self, mock_client_method, mock_server_method): protocol = HTTP2Protocol(is_server=True) protocol.connection_preface_performed = True @@ -64,7 +64,7 @@ class TestProtocol: class TestCheckALPNMatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( - alpn_select=HTTP2Protocol.ALPN_PROTO_H2, + alpn_select=ALPN_PROTO_H2, ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: @@ -72,7 +72,7 @@ class TestCheckALPNMatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) + c.convert_to_ssl(alpn_protos=[ALPN_PROTO_H2]) protocol = HTTP2Protocol(c) assert protocol.check_alpn() @@ -88,7 +88,7 @@ class TestCheckALPNMismatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[HTTP2Protocol.ALPN_PROTO_H2]) + c.convert_to_ssl(alpn_protos=[ALPN_PROTO_H2]) protocol = HTTP2Protocol(c) tutils.raises(NotImplementedError, protocol.check_alpn) @@ -306,7 +306,7 @@ class TestReadRequest(tservers.ServerTestBase): protocol = HTTP2Protocol(c, is_server=True) protocol.connection_preface_performed = True - req = protocol.read_request() + req = protocol.read_request(NotImplemented) assert req.stream_id assert req.headers.fields == [[':method', 'GET'], [':path', '/'], [':scheme', 'https']] @@ -329,7 +329,7 @@ class TestReadRequestRelative(tservers.ServerTestBase): protocol = HTTP2Protocol(c, is_server=True) protocol.connection_preface_performed = True - req = protocol.read_request() + req = protocol.read_request(NotImplemented) assert req.form_in == "relative" assert req.method == "OPTIONS" @@ -352,7 +352,7 @@ class TestReadRequestAbsolute(tservers.ServerTestBase): protocol = HTTP2Protocol(c, is_server=True) protocol.connection_preface_performed = True - req = protocol.read_request() + req = protocol.read_request(NotImplemented) assert req.form_in == "absolute" assert req.scheme == "http" @@ -378,13 +378,13 @@ class TestReadRequestConnect(tservers.ServerTestBase): protocol = HTTP2Protocol(c, is_server=True) protocol.connection_preface_performed = True - req = protocol.read_request() + req = protocol.read_request(NotImplemented) assert req.form_in == "authority" assert req.method == "CONNECT" assert req.host == "address" assert req.port == 22 - req = protocol.read_request() + req = protocol.read_request(NotImplemented) assert req.form_in == "authority" assert req.method == "CONNECT" assert req.host == "example.com" @@ -410,7 +410,7 @@ class TestReadResponse(tservers.ServerTestBase): protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True - resp = protocol.read_response(stream_id=42) + resp = protocol.read_response(NotImplemented, stream_id=42) assert resp.httpversion == (2, 0) assert resp.status_code == 200 @@ -436,7 +436,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): protocol = HTTP2Protocol(c) protocol.connection_preface_performed = True - resp = protocol.read_response(stream_id=42) + resp = protocol.read_response(NotImplemented, stream_id=42) assert resp.stream_id == 42 assert resp.httpversion == (2, 0) diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py index 17c91fe5..ee192dd7 100644 --- a/test/http/test_authentication.py +++ b/test/http/test_authentication.py @@ -5,7 +5,7 @@ from netlib.http import authentication, Headers def test_parse_http_basic_auth(): - vals = ("basic", "foo", "bar") + vals = (b"basic", b"foo", b"bar") assert authentication.parse_http_basic_auth( authentication.assemble_http_basic_auth(*vals) ) == vals diff --git a/test/http/test_exceptions.py b/test/http/test_exceptions.py deleted file mode 100644 index 49588d0a..00000000 --- a/test/http/test_exceptions.py +++ /dev/null @@ -1,6 +0,0 @@ -from netlib.http.exceptions import * - -class TestHttpError: - def test_simple(self): - e = HttpError(404, "Not found") - assert str(e) diff --git a/test/http/test_semantics.py b/test/http/test_models.py index 6dcbbe07..8fce2e9d 100644 --- a/test/http/test_semantics.py +++ b/test/http/test_models.py @@ -1,32 +1,11 @@ import mock -from netlib import http -from netlib import odict from netlib import tutils from netlib import utils -from netlib.http import semantics -from netlib.http.semantics import CONTENT_MISSING - -class TestProtocolMixin(object): - @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response") - @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request") - def test_assemble_request(self, mock_request_method, mock_response_method): - p = semantics.ProtocolMixin() - p.assemble(tutils.treq()) - assert mock_request_method.called - assert not mock_response_method.called - - @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_response") - @mock.patch("netlib.http.semantics.ProtocolMixin.assemble_request") - def test_assemble_response(self, mock_request_method, mock_response_method): - p = semantics.ProtocolMixin() - p.assemble(tutils.tresp()) - assert not mock_request_method.called - assert mock_response_method.called - - def test_assemble_foo(self): - p = semantics.ProtocolMixin() - tutils.raises(ValueError, p.assemble, 'foo') +from netlib.odict import ODict, ODictCaseless +from netlib.http import Request, Response, Headers, CONTENT_MISSING, HDR_FORM_URLENCODED, \ + HDR_FORM_MULTIPART + class TestRequest(object): def test_repr(self): @@ -34,27 +13,27 @@ class TestRequest(object): assert repr(r) def test_headers(self): - tutils.raises(AssertionError, semantics.Request, + tutils.raises(AssertionError, Request, 'form_in', 'method', 'scheme', 'host', 'port', 'path', - (1, 1), + b"HTTP/1.1", 'foobar', ) - req = semantics.Request( + req = Request( 'form_in', 'method', 'scheme', 'host', 'port', 'path', - (1, 1), + b"HTTP/1.1", ) - assert isinstance(req.headers, http.Headers) + assert isinstance(req.headers, Headers) def test_equal(self): a = tutils.treq() @@ -66,13 +45,6 @@ class TestRequest(object): assert not 'foo' == a assert not 'foo' == b - def test_legacy_first_line(self): - req = tutils.treq() - - assert req.legacy_first_line('relative') == "GET /path HTTP/1.1" - assert req.legacy_first_line('authority') == "GET address:22 HTTP/1.1" - assert req.legacy_first_line('absolute') == "GET http://address:22/path HTTP/1.1" - tutils.raises(http.HttpError, req.legacy_first_line, 'foobar') def test_anticache(self): req = tutils.treq() @@ -103,44 +75,44 @@ class TestRequest(object): def test_get_form(self): req = tutils.treq() - assert req.get_form() == odict.ODict() + assert req.get_form() == ODict() - @mock.patch("netlib.http.semantics.Request.get_form_multipart") - @mock.patch("netlib.http.semantics.Request.get_form_urlencoded") + @mock.patch("netlib.http.Request.get_form_multipart") + @mock.patch("netlib.http.Request.get_form_urlencoded") def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart): req = tutils.treq() - assert req.get_form() == odict.ODict() + assert req.get_form() == ODict() req = tutils.treq() req.body = "foobar" - req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED + req.headers["Content-Type"] = HDR_FORM_URLENCODED req.get_form() assert req.get_form_urlencoded.called assert not req.get_form_multipart.called - @mock.patch("netlib.http.semantics.Request.get_form_multipart") - @mock.patch("netlib.http.semantics.Request.get_form_urlencoded") + @mock.patch("netlib.http.Request.get_form_multipart") + @mock.patch("netlib.http.Request.get_form_urlencoded") def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart): req = tutils.treq() req.body = "foobar" - req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART + req.headers["Content-Type"] = HDR_FORM_MULTIPART req.get_form() assert not req.get_form_urlencoded.called assert req.get_form_multipart.called def test_get_form_urlencoded(self): - req = tutils.treq("foobar") - assert req.get_form_urlencoded() == odict.ODict() + req = tutils.treq(body="foobar") + assert req.get_form_urlencoded() == ODict() - req.headers["Content-Type"] = semantics.HDR_FORM_URLENCODED - assert req.get_form_urlencoded() == odict.ODict(utils.urldecode(req.body)) + req.headers["Content-Type"] = HDR_FORM_URLENCODED + assert req.get_form_urlencoded() == ODict(utils.urldecode(req.body)) def test_get_form_multipart(self): - req = tutils.treq("foobar") - assert req.get_form_multipart() == odict.ODict() + req = tutils.treq(body="foobar") + assert req.get_form_multipart() == ODict() - req.headers["Content-Type"] = semantics.HDR_FORM_MULTIPART - assert req.get_form_multipart() == odict.ODict( + req.headers["Content-Type"] = HDR_FORM_MULTIPART + assert req.get_form_multipart() == ODict( utils.multipartdecode( req.headers, req.body @@ -149,8 +121,8 @@ class TestRequest(object): def test_set_form_urlencoded(self): req = tutils.treq() - req.set_form_urlencoded(odict.ODict([('foo', 'bar'), ('rab', 'oof')])) - assert req.headers["Content-Type"] == semantics.HDR_FORM_URLENCODED + req.set_form_urlencoded(ODict([('foo', 'bar'), ('rab', 'oof')])) + assert req.headers["Content-Type"] == HDR_FORM_URLENCODED assert req.body def test_get_path_components(self): @@ -172,7 +144,7 @@ class TestRequest(object): def test_set_query(self): req = tutils.treq() - req.set_query(odict.ODict([])) + req.set_query(ODict([])) def test_pretty_host(self): r = tutils.treq() @@ -203,21 +175,21 @@ class TestRequest(object): assert req.pretty_url(False) == "http://address:22/path" def test_get_cookies_none(self): - headers = http.Headers() + headers = Headers() r = tutils.treq() r.headers = headers assert len(r.get_cookies()) == 0 def test_get_cookies_single(self): r = tutils.treq() - r.headers = http.Headers(cookie="cookiename=cookievalue") + r.headers = Headers(cookie="cookiename=cookievalue") result = r.get_cookies() assert len(result) == 1 assert result['cookiename'] == ['cookievalue'] def test_get_cookies_double(self): r = tutils.treq() - r.headers = http.Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") + r.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") result = r.get_cookies() assert len(result) == 2 assert result['cookiename'] == ['cookievalue'] @@ -225,7 +197,7 @@ class TestRequest(object): def test_get_cookies_withequalsign(self): r = tutils.treq() - r.headers = http.Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") + r.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") result = r.get_cookies() assert len(result) == 2 assert result['cookiename'] == ['coo=kievalue'] @@ -233,14 +205,14 @@ class TestRequest(object): def test_set_cookies(self): r = tutils.treq() - r.headers = http.Headers(cookie="cookiename=cookievalue") + r.headers = Headers(cookie="cookiename=cookievalue") result = r.get_cookies() result["cookiename"] = ["foo"] r.set_cookies(result) assert r.get_cookies()["cookiename"] == ["foo"] def test_set_url(self): - r = tutils.treq_absolute() + r = tutils.treq(form_in="absolute") r.url = "https://otheraddress:42/ORLY" assert r.scheme == "https" assert r.host == "otheraddress" @@ -332,24 +304,19 @@ class TestRequest(object): # "Host: address\r\n" # "Content-Length: 0\r\n\r\n") -class TestEmptyRequest(object): - def test_init(self): - req = semantics.EmptyRequest() - assert req - class TestResponse(object): def test_headers(self): - tutils.raises(AssertionError, semantics.Response, - (1, 1), + tutils.raises(AssertionError, Response, + b"HTTP/1.1", 200, headers='foobar', ) - resp = semantics.Response( - (1, 1), + resp = Response( + b"HTTP/1.1", 200, ) - assert isinstance(resp.headers, http.Headers) + assert isinstance(resp.headers, Headers) def test_equal(self): a = tutils.tresp() @@ -366,24 +333,24 @@ class TestResponse(object): assert "unknown content type" in repr(r) r.headers["content-type"] = "foo" assert "foo" in repr(r) - assert repr(tutils.tresp(content=CONTENT_MISSING)) + assert repr(tutils.tresp(body=CONTENT_MISSING)) def test_get_cookies_none(self): resp = tutils.tresp() - resp.headers = http.Headers() + resp.headers = Headers() assert not resp.get_cookies() def test_get_cookies_simple(self): resp = tutils.tresp() - resp.headers = http.Headers(set_cookie="cookiename=cookievalue") + resp.headers = Headers(set_cookie="cookiename=cookievalue") result = resp.get_cookies() assert len(result) == 1 assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", odict.ODict()] + assert result["cookiename"][0] == ["cookievalue", ODict()] def test_get_cookies_with_parameters(self): resp = tutils.tresp() - resp.headers = http.Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly") + resp.headers = Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly") result = resp.get_cookies() assert len(result) == 1 assert "cookiename" in result @@ -397,7 +364,7 @@ class TestResponse(object): def test_get_cookies_no_value(self): resp = tutils.tresp() - resp.headers = http.Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/") + resp.headers = Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/") result = resp.get_cookies() assert len(result) == 1 assert "cookiename" in result @@ -406,31 +373,31 @@ class TestResponse(object): def test_get_cookies_twocookies(self): resp = tutils.tresp() - resp.headers = http.Headers([ + resp.headers = Headers([ ["Set-Cookie", "cookiename=cookievalue"], ["Set-Cookie", "othercookie=othervalue"] ]) result = resp.get_cookies() assert len(result) == 2 assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", odict.ODict()] + assert result["cookiename"][0] == ["cookievalue", ODict()] assert "othercookie" in result - assert result["othercookie"][0] == ["othervalue", odict.ODict()] + assert result["othercookie"][0] == ["othervalue", ODict()] def test_set_cookies(self): resp = tutils.tresp() v = resp.get_cookies() - v.add("foo", ["bar", odict.ODictCaseless()]) + v.add("foo", ["bar", ODictCaseless()]) resp.set_cookies(v) v = resp.get_cookies() assert len(v) == 1 - assert v["foo"] == [["bar", odict.ODictCaseless()]] + assert v["foo"] == [["bar", ODictCaseless()]] class TestHeaders(object): def _2host(self): - return semantics.Headers( + return Headers( [ ["Host", "example.com"], ["host", "example.org"] @@ -438,25 +405,25 @@ class TestHeaders(object): ) def test_init(self): - headers = semantics.Headers() + headers = Headers() assert len(headers) == 0 - headers = semantics.Headers([["Host", "example.com"]]) + headers = Headers([["Host", "example.com"]]) assert len(headers) == 1 assert headers["Host"] == "example.com" - headers = semantics.Headers(Host="example.com") + headers = Headers(Host="example.com") assert len(headers) == 1 assert headers["Host"] == "example.com" - headers = semantics.Headers( + headers = Headers( [["Host", "invalid"]], Host="example.com" ) assert len(headers) == 1 assert headers["Host"] == "example.com" - headers = semantics.Headers( + headers = Headers( [["Host", "invalid"], ["Accept", "text/plain"]], Host="example.com" ) @@ -465,7 +432,7 @@ class TestHeaders(object): assert headers["Accept"] == "text/plain" def test_getitem(self): - headers = semantics.Headers(Host="example.com") + headers = Headers(Host="example.com") assert headers["Host"] == "example.com" assert headers["host"] == "example.com" tutils.raises(KeyError, headers.__getitem__, "Accept") @@ -474,17 +441,17 @@ class TestHeaders(object): assert headers["Host"] == "example.com, example.org" def test_str(self): - headers = semantics.Headers(Host="example.com") - assert str(headers) == "Host: example.com\r\n" + headers = Headers(Host="example.com") + assert bytes(headers) == "Host: example.com\r\n" - headers = semantics.Headers([ + headers = Headers([ ["Host", "example.com"], ["Accept", "text/plain"] ]) assert str(headers) == "Host: example.com\r\nAccept: text/plain\r\n" def test_setitem(self): - headers = semantics.Headers() + headers = Headers() headers["Host"] = "example.com" assert "Host" in headers assert "host" in headers @@ -507,7 +474,7 @@ class TestHeaders(object): assert "Host" in headers def test_delitem(self): - headers = semantics.Headers(Host="example.com") + headers = Headers(Host="example.com") assert len(headers) == 1 del headers["host"] assert len(headers) == 0 @@ -523,7 +490,7 @@ class TestHeaders(object): assert len(headers) == 0 def test_keys(self): - headers = semantics.Headers(Host="example.com") + headers = Headers(Host="example.com") assert len(headers.keys()) == 1 assert headers.keys()[0] == "Host" @@ -532,13 +499,13 @@ class TestHeaders(object): assert headers.keys()[0] == "Host" def test_eq_ne(self): - headers1 = semantics.Headers(Host="example.com") - headers2 = semantics.Headers(host="example.com") + headers1 = Headers(Host="example.com") + headers2 = Headers(host="example.com") assert not (headers1 == headers2) assert headers1 != headers2 - headers1 = semantics.Headers(Host="example.com") - headers2 = semantics.Headers(Host="example.com") + headers1 = Headers(Host="example.com") + headers2 = Headers(Host="example.com") assert headers1 == headers2 assert not (headers1 != headers2) @@ -550,7 +517,7 @@ class TestHeaders(object): assert headers.get_all("accept") == [] def test_set_all(self): - headers = semantics.Headers(Host="example.com") + headers = Headers(Host="example.com") headers.set_all("Accept", ["text/plain"]) assert len(headers) == 2 assert "accept" in headers @@ -565,9 +532,9 @@ class TestHeaders(object): def test_state(self): headers = self._2host() assert len(headers.get_state()) == 2 - assert headers == semantics.Headers.from_state(headers.get_state()) + assert headers == Headers.from_state(headers.get_state()) - headers2 = semantics.Headers() + headers2 = Headers() assert headers != headers2 headers2.load_state(headers.get_state()) assert headers == headers2 diff --git a/test/test_encoding.py b/test/test_encoding.py index 612aea89..9da3a38d 100644 --- a/test/test_encoding.py +++ b/test/test_encoding.py @@ -9,25 +9,29 @@ def test_identity(): def test_gzip(): - assert "string" == encoding.decode( + assert b"string" == encoding.decode( "gzip", encoding.encode( "gzip", - "string")) - assert None == encoding.decode("gzip", "bogus") + b"string" + ) + ) + assert encoding.decode("gzip", b"bogus") is None def test_deflate(): - assert "string" == encoding.decode( + assert b"string" == encoding.decode( "deflate", encoding.encode( "deflate", - "string")) - assert "string" == encoding.decode( + b"string" + ) + ) + assert b"string" == encoding.decode( "deflate", encoding.encode( "deflate", - "string")[ - 2:- - 4]) - assert None == encoding.decode("deflate", "bogus") + b"string" + )[2:-4] + ) + assert encoding.decode("deflate", b"bogus") is None diff --git a/test/test_utils.py b/test/test_utils.py index 9dba5d35..8b2ddae4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -36,46 +36,51 @@ def test_pretty_size(): def test_parse_url(): - assert not utils.parse_url("") + with tutils.raises(ValueError): + utils.parse_url("") - u = "http://foo.com:8888/test" - s, h, po, pa = utils.parse_url(u) - assert s == "http" - assert h == "foo.com" + s, h, po, pa = utils.parse_url(b"http://foo.com:8888/test") + assert s == b"http" + assert h == b"foo.com" assert po == 8888 - assert pa == "/test" + assert pa == b"/test" s, h, po, pa = utils.parse_url("http://foo/bar") - assert s == "http" - assert h == "foo" + assert s == b"http" + assert h == b"foo" assert po == 80 - assert pa == "/bar" + assert pa == b"/bar" - s, h, po, pa = utils.parse_url("http://user:pass@foo/bar") - assert s == "http" - assert h == "foo" + s, h, po, pa = utils.parse_url(b"http://user:pass@foo/bar") + assert s == b"http" + assert h == b"foo" assert po == 80 - assert pa == "/bar" + assert pa == b"/bar" - s, h, po, pa = utils.parse_url("http://foo") - assert pa == "/" + s, h, po, pa = utils.parse_url(b"http://foo") + assert pa == b"/" - s, h, po, pa = utils.parse_url("https://foo") + s, h, po, pa = utils.parse_url(b"https://foo") assert po == 443 - assert not utils.parse_url("https://foo:bar") - assert not utils.parse_url("https://foo:") + with tutils.raises(ValueError): + utils.parse_url(b"https://foo:bar") # Invalid IDNA - assert not utils.parse_url("http://\xfafoo") + with tutils.raises(ValueError): + utils.parse_url("http://\xfafoo") # Invalid PATH - assert not utils.parse_url("http:/\xc6/localhost:56121") + with tutils.raises(ValueError): + utils.parse_url("http:/\xc6/localhost:56121") # Null byte in host - assert not utils.parse_url("http://foo\0") + with tutils.raises(ValueError): + utils.parse_url("http://foo\0") # Port out of range - assert not utils.parse_url("http://foo:999999") + _, _, port, _ = utils.parse_url("http://foo:999999") + assert port == 80 # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt - assert not utils.parse_url('http://lo[calhost') + with tutils.raises(ValueError): + utils.parse_url('http://lo[calhost') def test_unparse_url(): @@ -106,23 +111,25 @@ def test_get_header_tokens(): def test_multipartdecode(): - boundary = 'somefancyboundary' + boundary = b'somefancyboundary' headers = Headers( - content_type='multipart/form-data; boundary=%s' % boundary + content_type=b'multipart/form-data; boundary=' + boundary + ) + content = ( + "--{0}\n" + "Content-Disposition: form-data; name=\"field1\"\n\n" + "value1\n" + "--{0}\n" + "Content-Disposition: form-data; name=\"field2\"\n\n" + "value2\n" + "--{0}--".format(boundary).encode("ascii") ) - content = "--{0}\n" \ - "Content-Disposition: form-data; name=\"field1\"\n\n" \ - "value1\n" \ - "--{0}\n" \ - "Content-Disposition: form-data; name=\"field2\"\n\n" \ - "value2\n" \ - "--{0}--".format(boundary) form = utils.multipartdecode(headers, content) assert len(form) == 2 - assert form[0] == ('field1', 'value1') - assert form[1] == ('field2', 'value2') + assert form[0] == (b"field1", b"value1") + assert form[1] == (b"field2", b"value2") def test_parse_content_type(): diff --git a/test/test_version_check.py b/test/test_version_check.py index 9a127814..ec2396fe 100644 --- a/test/test_version_check.py +++ b/test/test_version_check.py @@ -1,11 +1,11 @@ -import cStringIO +from io import StringIO import mock from netlib import version_check, version @mock.patch("sys.exit") def test_check_mitmproxy_version(sexit): - fp = cStringIO.StringIO() + fp = StringIO() version_check.check_mitmproxy_version(version.IVERSION, fp=fp) assert not fp.getvalue() assert not sexit.called @@ -18,7 +18,7 @@ def test_check_mitmproxy_version(sexit): @mock.patch("sys.exit") def test_check_pyopenssl_version(sexit): - fp = cStringIO.StringIO() + fp = StringIO() version_check.check_pyopenssl_version(fp=fp) assert not fp.getvalue() assert not sexit.called @@ -32,7 +32,7 @@ def test_check_pyopenssl_version(sexit): @mock.patch("OpenSSL.__version__") def test_unparseable_pyopenssl_version(version, sexit): version.split.return_value = ["foo", "bar"] - fp = cStringIO.StringIO() + fp = StringIO() version_check.check_pyopenssl_version(fp=fp) assert "Cannot parse" in fp.getvalue() assert not sexit.called diff --git a/test/tservers.py b/test/tservers.py index 682a9144..1f4ce725 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -1,7 +1,7 @@ from __future__ import (absolute_import, print_function, division) import threading -import Queue -import cStringIO +from six.moves import queue +from io import StringIO import OpenSSL from netlib import tcp from netlib import tutils @@ -27,7 +27,7 @@ class ServerTestBase(object): @classmethod def setupAll(cls): - cls.q = Queue.Queue() + cls.q = queue.Queue() s = cls.makeserver() cls.port = s.address.port cls.server = ServerThread(s) @@ -102,6 +102,6 @@ class TServer(tcp.TCPServer): h.finish() def handle_error(self, connection, client_address, fp=None): - s = cStringIO.StringIO() + s = StringIO() tcp.TCPServer.handle_error(self, connection, client_address, s) self.q.put(s.getvalue()) diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 57cfd166..3fdeb683 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -1,11 +1,13 @@ import os from nose.tools import raises +from netlib.http.http1 import read_response, read_request from netlib import tcp, tutils, websockets, http from netlib.http import status_codes -from netlib.http.exceptions import * -from netlib.http.http1 import HTTP1Protocol +from netlib.tutils import treq + +from netlib.exceptions import * from .. import tservers @@ -34,9 +36,8 @@ class WebSocketsEchoHandler(tcp.BaseHandler): frame.to_file(self.wfile) def handshake(self): - http1_protocol = HTTP1Protocol(self) - req = http1_protocol.read_request() + req = read_request(self.rfile) key = self.protocol.check_client_handshake(req.headers) preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) @@ -61,8 +62,6 @@ class WebSocketsClient(tcp.TCPClient): def connect(self): super(WebSocketsClient, self).connect() - http1_protocol = HTTP1Protocol(self) - preamble = 'GET / HTTP/1.1' self.wfile.write(preamble + "\r\n") headers = self.protocol.client_handshake_headers() @@ -70,7 +69,7 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(str(headers) + "\r\n") self.wfile.flush() - resp = http1_protocol.read_response("GET", None) + resp = read_response(self.rfile, treq(method="GET")) server_nonce = self.protocol.check_server_handshake(resp.headers) if not server_nonce == self.protocol.create_server_nonce( @@ -158,9 +157,8 @@ class TestWebSockets(tservers.ServerTestBase): class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): - http1_protocol = HTTP1Protocol(self) - client_hs = http1_protocol.read_request() + client_hs = read_request(self.rfile) self.protocol.check_client_handshake(client_hs.headers) preamble = 'HTTP/1.1 101 %s' % status_codes.RESPONSES.get(101) |