diff options
author | Maximilian Hils <git@maximilianhils.com> | 2015-09-15 19:12:15 +0200 |
---|---|---|
committer | Maximilian Hils <git@maximilianhils.com> | 2015-09-15 19:12:15 +0200 |
commit | 11e7f476bd4bbcd6d072fa3659f628ae3a19705d (patch) | |
tree | 7fe3f67bcf41af6c573e312ef4e6adfa18f9f870 /netlib | |
parent | 2f9c566e480c377566a0ae044d698a75b45cd54c (diff) | |
download | mitmproxy-11e7f476bd4bbcd6d072fa3659f628ae3a19705d.tar.gz mitmproxy-11e7f476bd4bbcd6d072fa3659f628ae3a19705d.tar.bz2 mitmproxy-11e7f476bd4bbcd6d072fa3659f628ae3a19705d.zip |
wip
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/encoding.py | 8 | ||||
-rw-r--r-- | netlib/exceptions.py | 31 | ||||
-rw-r--r-- | netlib/http/__init__.py | 9 | ||||
-rw-r--r-- | netlib/http/authentication.py | 4 | ||||
-rw-r--r-- | netlib/http/exceptions.py | 9 | ||||
-rw-r--r-- | netlib/http/http1/__init__.py | 23 | ||||
-rw-r--r-- | netlib/http/http1/assemble.py | 105 | ||||
-rw-r--r-- | netlib/http/http1/protocol.py | 586 | ||||
-rw-r--r-- | netlib/http/http1/read.py | 346 | ||||
-rw-r--r-- | netlib/http/http2/__init__.py | 2 | ||||
-rw-r--r-- | netlib/http/http2/connections.py (renamed from netlib/http/http2/protocol.py) | 0 | ||||
-rw-r--r-- | netlib/http/http2/frames.py (renamed from netlib/http/http2/frame.py) | 0 | ||||
-rw-r--r-- | netlib/http/models.py (renamed from netlib/http/semantics.py) | 221 | ||||
-rw-r--r-- | netlib/tcp.py | 8 | ||||
-rw-r--r-- | netlib/tutils.py | 70 | ||||
-rw-r--r-- | netlib/utils.py | 162 | ||||
-rw-r--r-- | netlib/version_check.py | 17 | ||||
-rw-r--r-- | netlib/websockets/__init__.py | 4 |
18 files changed, 759 insertions, 846 deletions
diff --git a/netlib/encoding.py b/netlib/encoding.py index f107eb5f..06830f2c 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -2,13 +2,13 @@ Utility functions for decoding response bodies. """ from __future__ import absolute_import -import cStringIO +from io import BytesIO import gzip import zlib __ALL__ = ["ENCODINGS"] -ENCODINGS = set(["identity", "gzip", "deflate"]) +ENCODINGS = {"identity", "gzip", "deflate"} def decode(e, content): @@ -42,7 +42,7 @@ def identity(content): def decode_gzip(content): - gfile = gzip.GzipFile(fileobj=cStringIO.StringIO(content)) + gfile = gzip.GzipFile(fileobj=BytesIO(content)) try: return gfile.read() except (IOError, EOFError): @@ -50,7 +50,7 @@ def decode_gzip(content): def encode_gzip(content): - s = cStringIO.StringIO() + s = BytesIO() gf = gzip.GzipFile(fileobj=s, mode='wb') gf.write(content) gf.close() diff --git a/netlib/exceptions.py b/netlib/exceptions.py new file mode 100644 index 00000000..637be3df --- /dev/null +++ b/netlib/exceptions.py @@ -0,0 +1,31 @@ +""" +We try to be very hygienic regarding the exceptions we throw: +Every Exception netlib raises shall be a subclass of NetlibException. + + +See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/ +""" +from __future__ import absolute_import, print_function, division + + +class NetlibException(Exception): + """ + Base class for all exceptions thrown by netlib. + """ + def __init__(self, message=None): + super(NetlibException, self).__init__(message) + + +class ReadDisconnect(object): + """Immediate EOF""" + + +class HttpException(NetlibException): + pass + + +class HttpReadDisconnect(HttpException, ReadDisconnect): + pass + +class HttpSyntaxException(HttpException): + pass diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 9b4b0e6b..0b1a0bc5 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,2 +1,7 @@ -from exceptions import * -from semantics import * +from .models import Request, Response, Headers, CONTENT_MISSING +from . import http1, http2 + +__all__ = [ + "Request", "Response", "Headers", "CONTENT_MISSING" + "http1", "http2" +] diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index fe1f0d14..2055f843 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -19,8 +19,8 @@ def parse_http_basic_auth(s): def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64(username + ":" + password) - return scheme + " " + v + v = binascii.b2a_base64(username + b":" + password) + return scheme + b" " + v class NullProxyAuth(object): diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py deleted file mode 100644 index 8a2bbebc..00000000 --- a/netlib/http/exceptions.py +++ /dev/null @@ -1,9 +0,0 @@ -class HttpError(Exception): - - def __init__(self, code, message): - super(HttpError, self).__init__(message) - self.code = code - - -class HttpErrorConnClosed(HttpError): - pass diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py index 6b5043af..4d223f97 100644 --- a/netlib/http/http1/__init__.py +++ b/netlib/http/http1/__init__.py @@ -1 +1,22 @@ -from protocol import * +from .read import ( + read_request, read_request_head, + read_response, read_response_head, + read_message_body, read_message_body_chunked, + connection_close, + expected_http_body_size, +) +from .assemble import ( + assemble_request, assemble_request_head, + assemble_response, assemble_response_head, +) + + +__all__ = [ + "read_request", "read_request_head", + "read_response", "read_response_head", + "read_message_body", "read_message_body_chunked", + "connection_close", + "expected_http_body_size", + "assemble_request", "assemble_request_head", + "assemble_response", "assemble_response_head", +] diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py new file mode 100644 index 00000000..a3269eed --- /dev/null +++ b/netlib/http/http1/assemble.py @@ -0,0 +1,105 @@ +from __future__ import absolute_import, print_function, division + +from ... import utils +from ...exceptions import HttpException +from .. import CONTENT_MISSING + + +def assemble_request(request): + if request.body == CONTENT_MISSING: + raise HttpException("Cannot assemble flow with CONTENT_MISSING") + head = assemble_request_head(request) + return head + request.body + + +def assemble_request_head(request): + first_line = _assemble_request_line(request) + headers = _assemble_request_headers(request) + return b"%s\r\n%s\r\n" % (first_line, headers) + + +def assemble_response(response): + if response.body == CONTENT_MISSING: + raise HttpException("Cannot assemble flow with CONTENT_MISSING") + head = assemble_response_head(response) + return head + response.body + + +def assemble_response_head(response): + first_line = _assemble_response_line(response) + headers = _assemble_response_headers(response) + return b"%s\r\n%s\r\n" % (first_line, headers) + + + + +def _assemble_request_line(request, form=None): + if form is None: + form = request.form_out + if form == "relative": + return b"%s %s %s" % ( + request.method, + request.path, + request.httpversion + ) + elif form == "authority": + return b"%s %s:%d %s" % ( + request.method, + request.host, + request.port, + request.httpversion + ) + elif form == "absolute": + return b"%s %s://%s:%s%s %s" % ( + request.method, + request.scheme, + request.host, + request.port, + request.path, + request.httpversion + ) + else: # pragma: nocover + raise RuntimeError("Invalid request form") + + +def _assemble_request_headers(request): + headers = request.headers.copy() + for k in request._headers_to_strip_off: + headers.pop(k, None) + if b"host" not in headers and request.scheme and request.host and request.port: + headers[b"Host"] = utils.hostport( + request.scheme, + request.host, + request.port + ) + + # If content is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if request.body or request.body == b"": + headers[b"Content-Length"] = str(len(request.body)).encode("ascii") + + return str(headers) + + +def _assemble_response_line(response): + return b"%s %s %s" % ( + response.httpversion, + response.status_code, + response.msg, + ) + + +def _assemble_response_headers(response, preserve_transfer_encoding=False): + # TODO: Remove preserve_transfer_encoding + headers = response.headers.copy() + for k in response._headers_to_strip_off: + headers.pop(k, None) + if not preserve_transfer_encoding: + headers.pop(b"Transfer-Encoding", None) + + # If body is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if response.body or response.body == b"": + headers[b"Content-Length"] = str(len(response.body)).encode("ascii") + + return bytes(headers) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py deleted file mode 100644 index cf1dffa3..00000000 --- a/netlib/http/http1/protocol.py +++ /dev/null @@ -1,586 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import string -import sys -import time - -from ... import utils, tcp, http -from .. import semantics, Headers -from ..exceptions import * - - -class TCPHandler(object): - - def __init__(self, rfile, wfile=None): - self.rfile = rfile - self.wfile = wfile - - -class HTTP1Protocol(semantics.ProtocolMixin): - - ALPN_PROTO_HTTP1 = 'http/1.1' - - def __init__(self, tcp_handler=None, rfile=None, wfile=None): - self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) - - def read_request( - self, - include_body=True, - body_size_limit=None, - allow_empty=False, - ): - """ - Parse an HTTP request from a file stream - - Args: - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. - """ - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - httpversion, host, port, scheme, method, path, headers, body = ( - None, None, None, None, None, None, None, None) - - request_line = self._get_request_line() - if not request_line: - if allow_empty: - return http.EmptyRequest() - else: - raise tcp.NetLibDisconnect() - - request_line_parts = self._parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - method, path, httpversion = request_line_parts - - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - elif method == 'CONNECT': - form_in = "authority" - r = self._parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - host, port, httpversion = r - path = None - else: - form_in = "absolute" - r = self._parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - _, scheme, host, port, path, _ = r - - headers = self.read_headers() - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get("expect", "").lower() - if expect_header == "100-continue" and httpversion == (1, 1): - self.tcp_handler.wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' - ) - self.tcp_handler.wfile.flush() - del headers['expect'] - - if include_body: - body = self.read_http_body( - headers, - body_size_limit, - method, - None, - True - ) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - timestamp_end = time.time() - - return http.Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - body, - timestamp_start, - timestamp_end, - ) - - def read_response( - self, - request_method, - body_size_limit=None, - include_body=True, - ): - """ - Returns an http.Response - - By default, both response header and body are read. - If include_body=False is specified, body may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - line = self.tcp_handler.rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = self.tcp_handler.rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = self.parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = self._parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = self.read_headers() - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - body = self.read_http_body( - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None body means the body should be - # read separately - body = None - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - if include_body: - timestamp_end = time.time() - else: - timestamp_end = None - - return http.Response( - httpversion, - code, - msg, - headers, - body, - timestamp_start=timestamp_start, - timestamp_end=timestamp_end, - ) - - def assemble_request(self, request): - assert isinstance(request, semantics.Request) - - if request.body == semantics.CONTENT_MISSING: - raise http.HttpError( - 502, - "Cannot assemble flow with CONTENT_MISSING" - ) - first_line = self._assemble_request_first_line(request) - headers = self._assemble_request_headers(request) - return "%s\r\n%s\r\n%s" % (first_line, headers, request.body) - - def assemble_response(self, response): - assert isinstance(response, semantics.Response) - - if response.body == semantics.CONTENT_MISSING: - raise http.HttpError( - 502, - "Cannot assemble flow with CONTENT_MISSING" - ) - first_line = self._assemble_response_first_line(response) - headers = self._assemble_response_headers(response) - return "%s\r\n%s\r\n%s" % (first_line, headers, response.body) - - def read_headers(self): - """ - Read a set of headers. - Stop once a blank line is reached. - - Return a Header object, or None if headers are invalid. - """ - ret = [] - while True: - line = self.tcp_handler.rfile.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - if not ret: - return None - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i + 1:].strip() - ret.append([name, value]) - else: - return None - return Headers(ret) - - - def read_http_body(self, *args, **kwargs): - return "".join(self.read_http_body_chunked(*args, **kwargs)) - - - def read_http_body_chunked( - self, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None - ): - """ - Read an HTTP message body: - headers: A Header object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = self.expected_http_body_size( - headers, is_request, request_method, response_code - ) - - if expected_size is None: - if self.has_chunked_encoding(headers): - # Python 3: yield from - for x in self._read_chunked(limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = self.tcp_handler.rfile.read(chunk_size) - yield content - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = self.tcp_handler.rfile.read(chunk_size) - if not content: - return - yield content - bytes_left -= chunk_size - not_done = self.tcp_handler.rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) - - @classmethod - def expected_http_body_size( - self, - headers, - is_request, - request_method, - response_code, - ): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if self.has_chunked_encoding(headers): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"]) - if size < 0: - raise ValueError() - return size - except ValueError: - return None - if is_request: - return 0 - return -1 - - - @classmethod - def has_chunked_encoding(self, headers): - return "chunked" in headers.get("transfer-encoding", "").lower() - - - def _get_request_line(self): - """ - Get a line, possibly preceded by a blank. - """ - line = self.tcp_handler.rfile.readline() - if line == "\r\n" or line == "\n": - # Possible leftover from previous message - line = self.tcp_handler.rfile.readline() - return line - - def _read_chunked(self, limit, is_request): - """ - Read a chunked HTTP body. - - May raise HttpError. - """ - # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - # 3.3 2. - total = 0 - code = 400 if is_request else 502 - while True: - line = self.tcp_handler.rfile.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n' and line != '\n': - try: - length = int(line, 16) - except ValueError: - raise HttpError( - code, - "Invalid chunked encoding length: %s" % line - ) - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large. Limit is %s," \ - " chunked content longer than %s" % (limit, total) - raise HttpError(code, msg) - chunk = self.tcp_handler.rfile.read(length) - suffix = self.tcp_handler.rfile.readline(5) - if suffix != '\r\n': - raise HttpError(code, "Malformed chunked body") - if length == 0: - return - yield chunk - - @classmethod - def _parse_http_protocol(self, line): - """ - Parse an HTTP protocol declaration. - Returns a (major, minor) tuple, or None. - """ - if not line.startswith("HTTP/"): - return None - _, version = line.split('/', 1) - if "." not in version: - return None - major, minor = version.split('.', 1) - try: - major = int(major) - minor = int(minor) - except ValueError: - return None - return major, minor - - @classmethod - def _parse_init(self, line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - httpversion = self._parse_http_protocol(protocol) - if not httpversion: - return None - if not utils.isascii(method): - return None - return method, url, httpversion - - @classmethod - def _parse_init_connect(self, line): - """ - Returns (host, port, httpversion) if line is a valid CONNECT line. - http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - """ - v = self._parse_init(line) - if not v: - return None - method, url, httpversion = v - - if method.upper() != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - try: - port = int(port) - except ValueError: - return None - if not utils.is_valid_port(port): - return None - if not utils.is_valid_host(host): - return None - return host, port, httpversion - - @classmethod - def _parse_init_proxy(self, line): - v = self._parse_init(line) - if not v: - return None - method, url, httpversion = v - - parts = utils.parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - return method, scheme, host, port, path, httpversion - - @classmethod - def _parse_init_http(self, line): - """ - Returns (method, url, httpversion) - """ - v = self._parse_init(line) - if not v: - return None - method, url, httpversion = v - if not utils.isascii(url): - return None - if not (url.startswith("/") or url == "*"): - return None - return method, url, httpversion - - @classmethod - def connection_close(self, httpversion, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1 Note that a connection should be - closed as well if the response has been read until end of the stream. - """ - # At first, check if we have an explicit Connection header. - if "connection" in headers: - toks = utils.get_header_tokens(headers, "connection") - if "close" in toks: - return True - elif "keep-alive" in toks: - return False - - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - return httpversion != (1, 1) - - @classmethod - def parse_response_line(self, line): - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if len(parts) != 3: - return None - proto, code, msg = parts - try: - code = int(code) - except ValueError: - return None - return (proto, code, msg) - - @classmethod - def _assemble_request_first_line(self, request): - return request.legacy_first_line() - - def _assemble_request_headers(self, request): - headers = request.headers.copy() - for k in request._headers_to_strip_off: - headers.pop(k, None) - if 'host' not in headers and request.scheme and request.host and request.port: - headers["Host"] = utils.hostport( - request.scheme, - request.host, - request.port - ) - - # If content is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if request.body or request.body == "": - headers["Content-Length"] = str(len(request.body)) - - return str(headers) - - def _assemble_response_first_line(self, response): - return 'HTTP/%s.%s %s %s' % ( - response.httpversion[0], - response.httpversion[1], - response.status_code, - response.msg, - ) - - def _assemble_response_headers( - self, - response, - preserve_transfer_encoding=False, - ): - headers = response.headers.copy() - for k in response._headers_to_strip_off: - headers.pop(k, None) - if not preserve_transfer_encoding: - headers.pop('Transfer-Encoding', None) - - # If body is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if response.body or response.body == "": - headers["Content-Length"] = str(len(response.body)) - - return str(headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py new file mode 100644 index 00000000..573bc739 --- /dev/null +++ b/netlib/http/http1/read.py @@ -0,0 +1,346 @@ +from __future__ import absolute_import, print_function, division +import time +import sys +import re + +from ... import utils +from ...exceptions import HttpReadDisconnect, HttpSyntaxException, HttpException +from .. import Request, Response, Headers + +ALPN_PROTO_HTTP1 = 'http/1.1' + + +def read_request(rfile, body_size_limit=None): + request = read_request_head(rfile) + request.body = read_message_body(rfile, request, limit=body_size_limit) + request.timestamp_end = time.time() + return request + + +def read_request_head(rfile): + """ + Parse an HTTP request head (request line + headers) from an input stream + + Args: + rfile: The input stream + body_size_limit (bool): Maximum body size + + Returns: + The HTTP request object + + Raises: + HttpReadDisconnect: If no bytes can be read from rfile. + HttpSyntaxException: If the input is invalid. + HttpException: A different error occured. + """ + timestamp_start = time.time() + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() + + form, method, scheme, host, port, path, http_version = _read_request_line(rfile) + headers = _read_headers(rfile) + + if hasattr(rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = rfile.first_byte_timestamp + + return Request( + form, method, scheme, host, port, path, http_version, headers, None, timestamp_start + ) + + +def read_response(rfile, request, body_size_limit=None): + response = read_response_head(rfile) + response.body = read_message_body(rfile, request, response, body_size_limit) + response.timestamp_end = time.time() + return response + + +def read_response_head(rfile): + timestamp_start = time.time() + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() + + http_version, status_code, message = _read_response_line(rfile) + headers = _read_headers(rfile) + + if hasattr(rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = rfile.first_byte_timestamp + + return Response( + http_version, + status_code, + message, + headers, + None, + timestamp_start + ) + + +def read_message_body(*args, **kwargs): + chunks = read_message_body_chunked(*args, **kwargs) + return b"".join(chunks) + + +def read_message_body_chunked(rfile, request, response=None, limit=None, max_chunk_size=None): + """ + Read an HTTP message body: + + Args: + If a request body should be read, only request should be passed. + If a response body should be read, both request and response should be passed. + + Raises: + HttpException + """ + if not response: + headers = request.headers + response_code = None + is_request = True + else: + headers = response.headers + response_code = response.status_code + is_request = False + + if not limit or limit < 0: + limit = sys.maxsize + if not max_chunk_size: + max_chunk_size = limit + + expected_size = expected_http_body_size( + headers, is_request, request.method, response_code + ) + + if expected_size is None: + for x in _read_chunked(rfile, limit): + yield x + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpException( + "HTTP Body too large. " + "Limit is {}, content length was advertised as {}".format(limit, expected_size) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + yield content + bytes_left -= chunk_size + else: + bytes_left = limit + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + if not content: + return + yield content + bytes_left -= chunk_size + not_done = rfile.read(1) + if not_done: + raise HttpException("HTTP body too large. Limit is {}.".format(limit)) + + +def connection_close(http_version, headers): + """ + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1. + """ + # At first, check if we have an explicit Connection header. + if b"connection" in headers: + toks = utils.get_header_tokens(headers, "connection") + if b"close" in toks: + return True + elif b"keep-alive" in toks: + return False + + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent + return http_version != (1, 1) + + +def expected_http_body_size( + headers, + is_request, + request_method, + response_code, +): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding) + - -1, if all data should be read until end of stream. + + Raises: + HttpSyntaxException, if the content length header is invalid + """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + is_empty_response = (not is_request and ( + request_method == b"HEAD" or + 100 <= response_code <= 199 or + (response_code == 200 and request_method == b"CONNECT") or + response_code in (204, 304) + )) + + if is_empty_response: + return 0 + if is_request and headers.get(b"expect", b"").lower() == b"100-continue": + return 0 + if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): + return None + if b"content-length" in headers: + try: + size = int(headers[b"content-length"]) + if size < 0: + raise ValueError() + return size + except ValueError: + raise HttpSyntaxException("Unparseable Content Length") + if is_request: + return 0 + return -1 + + +def _get_first_line(rfile): + line = rfile.readline() + if line == b"\r\n" or line == b"\n": + # Possible leftover from previous message + line = rfile.readline() + if not line: + raise HttpReadDisconnect() + return line + + +def _read_request_line(rfile): + line = _get_first_line(rfile) + + try: + method, path, http_version = line.strip().split(b" ") + + if path == b"*" or path.startswith(b"/"): + form = "relative" + path.decode("ascii") # should not raise a ValueError + scheme, host, port = None, None, None + elif method == b"CONNECT": + form = "authority" + host, port = _parse_authority_form(path) + scheme, path = None, None + else: + form = "absolute" + scheme, host, port, path = utils.parse_url(path) + + except ValueError: + raise HttpSyntaxException("Bad HTTP request line: {}".format(line)) + + return form, method, scheme, host, port, path, http_version + + +def _parse_authority_form(hostport): + """ + Returns (host, port) if hostport is a valid authority-form host specification. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + + Raises: + ValueError, if the input is malformed + """ + try: + host, port = hostport.split(b":") + port = int(port) + if not utils.is_valid_host(host) or not utils.is_valid_port(port): + raise ValueError() + except ValueError: + raise ValueError("Invalid host specification: {}".format(hostport)) + + return host, port + + +def _read_response_line(rfile): + line = _get_first_line(rfile) + + try: + + parts = line.strip().split(b" ") + if len(parts) == 2: # handle missing message gracefully + parts.append(b"") + + http_version, status_code, message = parts + status_code = int(status_code) + _check_http_version(http_version) + + except ValueError: + raise HttpSyntaxException("Bad HTTP response line: {}".format(line)) + + return http_version, status_code, message + + +def _check_http_version(http_version): + if not re.match(rb"^HTTP/\d\.\d$", http_version): + raise HttpSyntaxException("Unknown HTTP version: {}".format(http_version)) + + +def _read_headers(rfile): + """ + Read a set of headers. + Stop once a blank line is reached. + + Returns: + A headers object + + Raises: + HttpSyntaxException + """ + ret = [] + while True: + line = rfile.readline() + if not line or line == b"\r\n" or line == b"\n": + break + if line[0] in b" \t": + if not ret: + raise HttpSyntaxException("Invalid headers") + # continued header + ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip() + else: + try: + name, value = line.split(b":", 1) + value = value.strip() + ret.append([name, value]) + except ValueError: + raise HttpSyntaxException("Invalid headers") + return Headers(ret) + + +def _read_chunked(rfile, limit): + """ + Read a HTTP body with chunked transfer encoding. + + Args: + rfile: the input file + limit: A positive integer + """ + total = 0 + while True: + line = rfile.readline(128) + if line == b"": + raise HttpException("Connection closed prematurely") + if line != b"\r\n" and line != b"\n": + try: + length = int(line, 16) + except ValueError: + raise HttpSyntaxException("Invalid chunked encoding length: {}".format(line)) + total += length + if total > limit: + raise HttpException( + "HTTP Body too large. Limit is {}, " + "chunked content longer than {}".format(limit, total) + ) + chunk = rfile.read(length) + suffix = rfile.readline(5) + if suffix != b"\r\n": + raise HttpSyntaxException("Malformed chunked body") + if length == 0: + return + yield chunk diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py index 5acf7696..e69de29b 100644 --- a/netlib/http/http2/__init__.py +++ b/netlib/http/http2/__init__.py @@ -1,2 +0,0 @@ -from frame import * -from protocol import * diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/connections.py index b6d376d3..b6d376d3 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/connections.py diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frames.py index b36b3adf..b36b3adf 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frames.py diff --git a/netlib/http/semantics.py b/netlib/http/models.py index 5bb098a7..bd5863b1 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/models.py @@ -1,20 +1,25 @@ -from __future__ import (absolute_import, print_function, division) -import UserDict +from __future__ import absolute_import, print_function, division import copy -import urllib -import urlparse -from .. import odict -from . import cookies, exceptions -from netlib import utils, encoding +from ..odict import ODict +from .. import utils, encoding +from ..utils import always_bytes, always_byte_args +from . import cookies -HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" -HDR_FORM_MULTIPART = "multipart/form-data" +import six +from six.moves import urllib +try: + from collections import MutableMapping +except ImportError: + from collections.abc import MutableMapping + +HDR_FORM_URLENCODED = b"application/x-www-form-urlencoded" +HDR_FORM_MULTIPART = b"multipart/form-data" CONTENT_MISSING = 0 -class Headers(object, UserDict.DictMixin): +class Headers(MutableMapping, object): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. @@ -62,10 +67,12 @@ class Headers(object, UserDict.DictMixin): For use with the "Set-Cookie" header, see :py:meth:`get_all`. """ + @always_byte_args("ascii") def __init__(self, fields=None, **headers): """ Args: - fields: (optional) list of ``(name, value)`` header tuples, e.g. ``[("Host","example.com")]`` + fields: (optional) list of ``(name, value)`` header tuples, + e.g. ``[("Host","example.com")]``. All names and values must be bytes. **headers: Additional headers to set. Will overwrite existing values from `fields`. For convenience, underscores in header names will be transformed to dashes - this behaviour does not extend to other methods. @@ -76,21 +83,25 @@ class Headers(object, UserDict.DictMixin): # content_type -> content-type headers = { - name.replace("_", "-"): value - for name, value in headers.iteritems() + name.encode("ascii").replace(b"_", b"-"): value + for name, value in six.iteritems(headers) } self.update(headers) - def __str__(self): - return "\r\n".join(": ".join(field) for field in self.fields) + "\r\n" + def __bytes__(self): + return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" + + if six.PY2: + __str__ = __bytes__ + @always_byte_args("ascii") def __getitem__(self, name): values = self.get_all(name) if not values: raise KeyError(name) - else: - return ", ".join(values) + return b", ".join(values) + @always_byte_args("ascii") def __setitem__(self, name, value): idx = self._index(name) @@ -101,6 +112,7 @@ class Headers(object, UserDict.DictMixin): else: self.fields.append([name, value]) + @always_byte_args("ascii") def __delitem__(self, name): if name not in self: raise KeyError(name) @@ -110,6 +122,19 @@ class Headers(object, UserDict.DictMixin): if name != field[0].lower() ] + def __iter__(self): + seen = set() + for name, _ in self.fields: + name_lower = name.lower() + if name_lower not in seen: + seen.add(name_lower) + yield name + + def __len__(self): + return len(set(name.lower() for name, _ in self.fields)) + + #__hash__ = object.__hash__ + def _index(self, name): name = name.lower() for i, field in enumerate(self.fields): @@ -117,16 +142,6 @@ class Headers(object, UserDict.DictMixin): return i return None - def keys(self): - seen = set() - names = [] - for name, _ in self.fields: - name_lower = name.lower() - if name_lower not in seen: - seen.add(name_lower) - names.append(name) - return names - def __eq__(self, other): if isinstance(other, Headers): return self.fields == other.fields @@ -135,6 +150,7 @@ class Headers(object, UserDict.DictMixin): def __ne__(self, other): return not self.__eq__(other) + @always_byte_args("ascii") def get_all(self, name): """ Like :py:meth:`get`, but does not fold multiple headers into a single one. @@ -142,8 +158,8 @@ class Headers(object, UserDict.DictMixin): See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 """ - name = name.lower() - values = [value for n, value in self.fields if n.lower() == name] + name_lower = name.lower() + values = [value for n, value in self.fields if n.lower() == name_lower] return values def set_all(self, name, values): @@ -151,6 +167,8 @@ class Headers(object, UserDict.DictMixin): Explicitly set multiple headers for the given key. See: :py:meth:`get_all` """ + name = always_bytes(name, "ascii") + values = (always_bytes(value, "ascii") for value in values) if name in self: del self[name] self.fields.extend( @@ -172,28 +190,6 @@ class Headers(object, UserDict.DictMixin): return cls([list(field) for field in state]) -class ProtocolMixin(object): - def read_request(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - def read_response(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - def assemble(self, message): - if isinstance(message, Request): - return self.assemble_request(message) - elif isinstance(message, Response): - return self.assemble_response(message) - else: - raise ValueError("HTTP message not supported.") - - def assemble_request(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - def assemble_response(self, *args, **kwargs): # pragma: no cover - raise NotImplementedError - - class Request(object): # This list is adopted legacy code. # We probably don't need to strip off keep-alive. @@ -248,42 +244,14 @@ class Request(object): return False def __repr__(self): - # return "Request(%s - %s, %s)" % (self.method, self.host, self.path) - - return "<HTTPRequest: {0}>".format( - self.legacy_first_line()[:-9] - ) - - def legacy_first_line(self, form=None): - if form is None: - form = self.form_out - if form == "relative": - return '%s %s HTTP/%s.%s' % ( - self.method, - self.path, - self.httpversion[0], - self.httpversion[1], - ) - elif form == "authority": - return '%s %s:%s HTTP/%s.%s' % ( - self.method, - self.host, - self.port, - self.httpversion[0], - self.httpversion[1], - ) - elif form == "absolute": - return '%s %s://%s:%s%s HTTP/%s.%s' % ( - self.method, - self.scheme, - self.host, - self.port, - self.path, - self.httpversion[0], - self.httpversion[1], - ) + if self.host and self.port: + hostport = "{}:{}".format(self.host, self.port) else: - raise exceptions.HttpError(400, "Invalid request form") + hostport = "" + path = self.path or "" + return "HTTPRequest({} {}{})".format( + self.method, hostport, path + ) def anticache(self): """ @@ -336,7 +304,7 @@ class Request(object): return self.get_form_urlencoded() elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): return self.get_form_multipart() - return odict.ODict([]) + return ODict([]) def get_form_urlencoded(self): """ @@ -345,16 +313,16 @@ class Request(object): indicates non-form data. """ if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): - return odict.ODict(utils.urldecode(self.body)) - return odict.ODict([]) + return ODict(utils.urldecode(self.body)) + return ODict([]) def get_form_multipart(self): if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): - return odict.ODict( + return ODict( utils.multipartdecode( self.headers, self.body)) - return odict.ODict([]) + return ODict([]) def set_form_urlencoded(self, odict): """ @@ -373,8 +341,8 @@ class Request(object): Components are unquoted. """ - _, _, path, _, _, _ = urlparse.urlparse(self.url) - return [urllib.unquote(i) for i in path.split("/") if i] + _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + return [urllib.parse.unquote(i) for i in path.split(b"/") if i] def set_path_components(self, lst): """ @@ -382,10 +350,10 @@ class Request(object): Components are quoted. """ - lst = [urllib.quote(i, safe="") for i in lst] - path = "/" + "/".join(lst) - scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.url) - self.url = urlparse.urlunparse( + lst = [urllib.parse.quote(i, safe="") for i in lst] + path = b"/" + b"/".join(lst) + scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse( [scheme, netloc, path, params, query, fragment] ) @@ -393,18 +361,18 @@ class Request(object): """ Gets the request query string. Returns an ODict object. """ - _, _, _, _, query, _ = urlparse.urlparse(self.url) + _, _, _, _, query, _ = urllib.parse.urlparse(self.url) if query: - return odict.ODict(utils.urldecode(query)) - return odict.ODict([]) + return ODict(utils.urldecode(query)) + return ODict([]) def set_query(self, odict): """ Takes an ODict object, and sets the request query string. """ - scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.url) + scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) query = utils.urlencode(odict.lst) - self.url = urlparse.urlunparse( + self.url = urllib.parse.urlunparse( [scheme, netloc, path, params, query, fragment] ) @@ -421,18 +389,13 @@ class Request(object): but not the resolved name. This is disabled by default, as an attacker may spoof the host header to confuse an analyst. """ - host = None - if hostheader: - host = self.headers.get("Host") - if not host: - host = self.host - if host: + if hostheader and b"Host" in self.headers: try: - return host.encode("idna") + return self.headers[b"Host"].decode("idna") except ValueError: - return host - else: - return None + pass + if self.host: + return self.host.decode("idna") def pretty_url(self, hostheader): if self.form_out == "authority": # upstream proxy mode @@ -446,7 +409,7 @@ class Request(object): """ Returns a possibly empty netlib.odict.ODict object. """ - ret = odict.ODict() + ret = ODict() for i in self.headers.get_all("cookie"): ret.extend(cookies.parse_cookie_header(i)) return ret @@ -477,8 +440,10 @@ class Request(object): Parses a URL specification, and updates the Request's information accordingly. - Returns False if the URL was invalid, True if the request succeeded. + Raises: + ValueError if the URL was invalid """ + # TODO: Should handle incoming unicode here. parts = utils.parse_url(url) if not parts: raise ValueError("Invalid URL: %s" % url) @@ -495,32 +460,6 @@ class Request(object): self.body = content -class EmptyRequest(Request): - def __init__( - self, - form_in="", - method="", - scheme="", - host="", - port="", - path="", - httpversion=(0, 0), - headers=None, - body="" - ): - super(EmptyRequest, self).__init__( - form_in=form_in, - method=method, - scheme=scheme, - host=host, - port=port, - path=path, - httpversion=httpversion, - headers=headers, - body=body, - ) - - class Response(object): _headers_to_strip_off = [ 'Proxy-Connection', @@ -591,7 +530,7 @@ class Response(object): if v: name, value, attrs = v ret.append([name, [value, attrs]]) - return odict.ODict(ret) + return ODict(ret) def set_cookies(self, odict): """ diff --git a/netlib/tcp.py b/netlib/tcp.py index 4a7f6153..1eb417b4 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -834,14 +834,14 @@ class TCPServer(object): # If a thread has persisted after interpreter exit, the module might be # none. if traceback: - exc = traceback.format_exc() - print('-' * 40, file=fp) + exc = six.text_type(traceback.format_exc()) + print(u'-' * 40, file=fp) print( - "Error in processing of request from %s:%s" % ( + u"Error in processing of request from %s:%s" % ( client_address.host, client_address.port ), file=fp) print(exc, file=fp) - print('-' * 40, file=fp) + print(u'-' * 40, file=fp) def handle_client_connection(self, conn, client_address): # pragma: no cover """ diff --git a/netlib/tutils.py b/netlib/tutils.py index 951ef3d9..65c4a313 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -1,9 +1,11 @@ -import cStringIO +from io import BytesIO import tempfile import os import time import shutil from contextlib import contextmanager +import six +import sys from netlib import tcp, utils, http @@ -12,7 +14,7 @@ def treader(bytes): """ Construct a tcp.Read object from bytes. """ - fp = cStringIO.StringIO(bytes) + fp = BytesIO(bytes) return tcp.Reader(fp) @@ -28,7 +30,24 @@ def tmpdir(*args, **kwargs): shutil.rmtree(temp_workdir) -def raises(exc, obj, *args, **kwargs): +def _check_exception(expected, actual, exc_tb): + if isinstance(expected, six.string_types): + if expected.lower() not in str(actual).lower(): + six.reraise(AssertionError, AssertionError( + "Expected %s, but caught %s" % ( + repr(str(expected)), actual + ) + ), exc_tb) + else: + if not isinstance(actual, expected): + six.reraise(AssertionError, AssertionError( + "Expected %s, but caught %s %s" % ( + expected.__name__, actual.__class__.__name__, str(actual) + ) + ), exc_tb) + + +def raises(expected_exception, obj=None, *args, **kwargs): """ Assert that a callable raises a specified exception. @@ -43,28 +62,31 @@ def raises(exc, obj, *args, **kwargs): :kwargs Arguments to be passed to the callable. """ - try: - ret = obj(*args, **kwargs) - except Exception as v: - if isinstance(exc, basestring): - if exc.lower() in str(v).lower(): - return - else: - raise AssertionError( - "Expected %s, but caught %s" % ( - repr(str(exc)), v - ) - ) + if obj is None: + return RaisesContext(expected_exception) + else: + try: + ret = obj(*args, **kwargs) + except Exception as actual: + _check_exception(expected_exception, actual, sys.exc_info()[2]) else: - if isinstance(v, exc): - return - else: - raise AssertionError( - "Expected %s, but caught %s %s" % ( - exc.__name__, v.__class__.__name__, str(v) - ) - ) - raise AssertionError("No exception raised. Return value: {}".format(ret)) + raise AssertionError("No exception raised. Return value: {}".format(ret)) + + +class RaisesContext(object): + def __init__(self, expected_exception): + self.expected_exception = expected_exception + + def __enter__(self): + return + + def __exit__(self, exc_type, exc_val, exc_tb): + if not exc_type: + raise AssertionError("No exception raised.") + else: + _check_exception(self.expected_exception, exc_val, exc_tb) + return True + test_data = utils.Data(__name__) diff --git a/netlib/utils.py b/netlib/utils.py index d6774419..fb579cac 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -1,17 +1,17 @@ -from __future__ import (absolute_import, print_function, division) +from __future__ import absolute_import, print_function, division import os.path -import cgi -import urllib -import urlparse -import string import re -import six +import string import unicodedata +import six + +from six.moves import urllib + -def isascii(s): +def isascii(bytes): try: - s.decode("ascii") + bytes.decode("ascii") except ValueError: return False return True @@ -44,8 +44,8 @@ def clean_bin(s, keep_spacing=True): else: keep = b"" return b"".join( - ch if (31 < ord(ch) < 127 or ch in keep) else b"." - for ch in s + six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." + for ch in six.iterbytes(s) ) @@ -149,10 +149,7 @@ class Data(object): return fullpath -def is_valid_port(port): - if not 0 <= port <= 65535: - return False - return True +_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE) def is_valid_host(host): @@ -160,53 +157,79 @@ def is_valid_host(host): host.decode("idna") except ValueError: return False - if "\0" in host: - return None - return True + if len(host) > 255: + return False + if host[-1] == ".": + host = host[:-1] + return all(_label_valid.match(x) for x in host.split(b".")) + + +def is_valid_port(port): + return 0 <= port <= 65535 + + +# PY2 workaround +def decode_parse_result(result, enc): + if hasattr(result, "decode"): + return result.decode(enc) + else: + return urllib.parse.ParseResult(*[x.decode(enc) for x in result]) + + +# PY2 workaround +def encode_parse_result(result, enc): + if hasattr(result, "encode"): + return result.encode(enc) + else: + return urllib.parse.ParseResult(*[x.encode(enc) for x in result]) def parse_url(url): """ - Returns a (scheme, host, port, path) tuple, or None on error. + URL-parsing function that checks that + - port is an integer 0-65535 + - host is a valid IDNA-encoded hostname with no null-bytes + - path is valid ASCII - Checks that: - port is an integer 0-65535 - host is a valid IDNA-encoded hostname with no null-bytes - path is valid ASCII + Args: + A URL (as bytes or as unicode) + + Returns: + A (scheme, host, port, path) tuple + + Raises: + ValueError, if the URL is not properly formatted. """ - try: - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - except ValueError: - return None - if not scheme: - return None - if '@' in netloc: - # FIXME: Consider what to do with the discarded credentials here Most - # probably we should extend the signature to return these as a separate - # value. - _, netloc = string.rsplit(netloc, '@', maxsplit=1) - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None + parsed = urllib.parse.urlparse(url) + + if not parsed.hostname: + raise ValueError("No hostname given") + + if isinstance(url, six.binary_type): + host = parsed.hostname + + # this should not raise a ValueError + decode_parse_result(parsed, "ascii") else: - host = netloc - if scheme.endswith("https"): - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path + host = parsed.hostname.encode("idna") + parsed = encode_parse_result(parsed, "ascii") + + port = parsed.port + if not port: + port = 443 if parsed.scheme == b"https" else 80 + + full_path = urllib.parse.urlunparse( + (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment) + ) + if not full_path.startswith(b"/"): + full_path = b"/" + full_path + if not is_valid_host(host): - return None - if not isascii(path): - return None + raise ValueError("Invalid Host") if not is_valid_port(port): - return None - return scheme, host, port, path + raise ValueError("Invalid Port") + + return parsed.scheme, host, port, full_path def get_header_tokens(headers, key): @@ -217,7 +240,7 @@ def get_header_tokens(headers, key): """ if key not in headers: return [] - tokens = headers[key].split(",") + tokens = headers[key].split(b",") return [token.strip() for token in tokens] @@ -228,7 +251,7 @@ def hostport(scheme, host, port): if (port, scheme) in [(80, "http"), (443, "https")]: return host else: - return "%s:%s" % (host, port) + return b"%s:%s" % (host, port) def unparse_url(scheme, host, port, path=""): @@ -243,14 +266,14 @@ def urlencode(s): Takes a list of (key, value) tuples and returns a urlencoded string. """ s = [tuple(i) for i in s] - return urllib.urlencode(s, False) + return urllib.parse.urlencode(s, False) def urldecode(s): """ Takes a urlencoded string and returns a list of (key, value) tuples. """ - return cgi.parse_qsl(s, keep_blank_values=True) + return urllib.parse.parse_qsl(s, keep_blank_values=True) def parse_content_type(c): @@ -267,14 +290,14 @@ def parse_content_type(c): ("text", "html", {"charset": "UTF-8"}) """ - parts = c.split(";", 1) - ts = parts[0].split("/", 1) + parts = c.split(b";", 1) + ts = parts[0].split(b"/", 1) if len(ts) != 2: return None d = {} if len(parts) == 2: - for i in parts[1].split(";"): - clause = i.split("=", 1) + for i in parts[1].split(b";"): + clause = i.split(b"=", 1) if len(clause) == 2: d[clause[0].strip()] = clause[1].strip() return ts[0].lower(), ts[1].lower(), d @@ -289,7 +312,7 @@ def multipartdecode(headers, content): v = parse_content_type(v) if not v: return [] - boundary = v[2].get("boundary") + boundary = v[2].get(b"boundary") if not boundary: return [] @@ -306,3 +329,20 @@ def multipartdecode(headers, content): r.append((key, value)) return r return [] + + +def always_bytes(unicode_or_bytes, encoding): + if isinstance(unicode_or_bytes, six.text_type): + return unicode_or_bytes.encode(encoding) + return unicode_or_bytes + + +def always_byte_args(encoding): + """Decorator that transparently encodes all arguments passed as unicode""" + def decorator(fun): + def _fun(*args, **kwargs): + args = [always_bytes(arg, encoding) for arg in args] + kwargs = {k: always_bytes(v, encoding) for k, v in six.iteritems(kwargs)} + return fun(*args, **kwargs) + return _fun + return decorator diff --git a/netlib/version_check.py b/netlib/version_check.py index 1d7e025c..9cf27eea 100644 --- a/netlib/version_check.py +++ b/netlib/version_check.py @@ -7,6 +7,7 @@ from __future__ import division, absolute_import, print_function import sys import inspect import os.path +import six import OpenSSL from . import version @@ -19,8 +20,8 @@ def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): # consider major and minor version. if version.IVERSION[:2] != mitmproxy_version[:2]: print( - "You are using mitmproxy %s with netlib %s. " - "Most likely, that won't work - please upgrade!" % ( + u"You are using mitmproxy %s with netlib %s. " + u"Most likely, that won't work - please upgrade!" % ( mitmproxy_version, version.VERSION ), file=fp @@ -29,13 +30,13 @@ def check_mitmproxy_version(mitmproxy_version, fp=sys.stderr): def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): - min_version_str = ".".join(str(x) for x in min_version) + min_version_str = u".".join(six.text_type(x) for x in min_version) try: v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2]) except ValueError: print( - "Cannot parse pyOpenSSL version: {}" - "mitmproxy requires pyOpenSSL {} or greater.".format( + u"Cannot parse pyOpenSSL version: {}" + u"mitmproxy requires pyOpenSSL {} or greater.".format( OpenSSL.__version__, min_version_str ), file=fp @@ -43,15 +44,15 @@ def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr): return if v < min_version: print( - "You are using an outdated version of pyOpenSSL: " - "mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str), + u"You are using an outdated version of pyOpenSSL: " + u"mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str), file=fp ) # Some users apparently have multiple versions of pyOpenSSL installed. # Report which one we got. pyopenssl_path = os.path.dirname(inspect.getfile(OpenSSL)) print( - "Your pyOpenSSL {} installation is located at {}".format( + u"Your pyOpenSSL {} installation is located at {}".format( OpenSSL.__version__, pyopenssl_path ), file=fp diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py index 5acf7696..1c143919 100644 --- a/netlib/websockets/__init__.py +++ b/netlib/websockets/__init__.py @@ -1,2 +1,2 @@ -from frame import * -from protocol import * +from .frame import * +from .protocol import * |