From f50deb7b763d093a22a4d331e16465a2fb0329cf Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 14 Jul 2015 23:02:14 +0200 Subject: move bits around --- netlib/http.py | 583 ------------------------------ netlib/http/__init__.py | 2 + netlib/http/authentication.py | 149 ++++++++ netlib/http/cookies.py | 193 ++++++++++ netlib/http/exceptions.py | 9 + netlib/http/http1/__init__.py | 1 + netlib/http/http1/protocol.py | 518 +++++++++++++++++++++++++++ netlib/http/http2/__init__.py | 2 + netlib/http/http2/frame.py | 636 +++++++++++++++++++++++++++++++++ netlib/http/http2/protocol.py | 240 +++++++++++++ netlib/http/semantics.py | 94 +++++ netlib/http/status_codes.py | 104 ++++++ netlib/http/user_agents.py | 52 +++ netlib/http2/__init__.py | 2 - netlib/http2/frame.py | 636 --------------------------------- netlib/http2/protocol.py | 240 ------------- netlib/http_auth.py | 148 -------- netlib/http_cookies.py | 193 ---------- netlib/http_semantics.py | 23 -- netlib/http_status.py | 104 ------ netlib/http_uastrings.py | 52 --- netlib/websockets/frame.py | 2 +- netlib/websockets/protocol.py | 2 +- test/http/__init__.py | 0 test/http/http1/__init__.py | 0 test/http/http1/test_protocol.py | 445 +++++++++++++++++++++++ test/http/http2/__init__.py | 0 test/http/http2/test_frames.py | 704 +++++++++++++++++++++++++++++++++++++ test/http/http2/test_protocol.py | 325 +++++++++++++++++ test/http/test_authentication.py | 110 ++++++ test/http/test_cookies.py | 219 ++++++++++++ test/http/test_semantics.py | 54 +++ test/http/test_user_agents.py | 6 + test/http2/__init__.py | 0 test/http2/test_frames.py | 704 ------------------------------------- test/http2/test_protocol.py | 326 ----------------- test/test_http.py | 491 -------------------------- test/test_http_auth.py | 109 ------ test/test_http_cookies.py | 219 ------------ test/test_http_uastrings.py | 6 - test/test_websockets.py | 261 -------------- test/websockets/__init__.py | 0 test/websockets/test_websockets.py | 262 ++++++++++++++ 43 files changed, 4127 insertions(+), 4099 deletions(-) delete mode 100644 netlib/http.py create mode 100644 netlib/http/__init__.py create mode 100644 netlib/http/authentication.py create mode 100644 netlib/http/cookies.py create mode 100644 netlib/http/exceptions.py create mode 100644 netlib/http/http1/__init__.py create mode 100644 netlib/http/http1/protocol.py create mode 100644 netlib/http/http2/__init__.py create mode 100644 netlib/http/http2/frame.py create mode 100644 netlib/http/http2/protocol.py create mode 100644 netlib/http/semantics.py create mode 100644 netlib/http/status_codes.py create mode 100644 netlib/http/user_agents.py delete mode 100644 netlib/http2/__init__.py delete mode 100644 netlib/http2/frame.py delete mode 100644 netlib/http2/protocol.py delete mode 100644 netlib/http_auth.py delete mode 100644 netlib/http_cookies.py delete mode 100644 netlib/http_semantics.py delete mode 100644 netlib/http_status.py delete mode 100644 netlib/http_uastrings.py create mode 100644 test/http/__init__.py create mode 100644 test/http/http1/__init__.py create mode 100644 test/http/http1/test_protocol.py create mode 100644 test/http/http2/__init__.py create mode 100644 test/http/http2/test_frames.py create mode 100644 test/http/http2/test_protocol.py create mode 100644 test/http/test_authentication.py create mode 100644 test/http/test_cookies.py create mode 100644 test/http/test_semantics.py create mode 100644 test/http/test_user_agents.py delete mode 100644 test/http2/__init__.py delete mode 100644 
test/http2/test_frames.py delete mode 100644 test/http2/test_protocol.py delete mode 100644 test/test_http.py delete mode 100644 test/test_http_auth.py delete mode 100644 test/test_http_cookies.py delete mode 100644 test/test_http_uastrings.py delete mode 100644 test/test_websockets.py create mode 100644 test/websockets/__init__.py create mode 100644 test/websockets/test_websockets.py diff --git a/netlib/http.py b/netlib/http.py deleted file mode 100644 index 073e9a3f..00000000 --- a/netlib/http.py +++ /dev/null @@ -1,583 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import collections -import string -import urlparse -import binascii -import sys -from . import odict, utils, tcp, http_semantics, http_status - - -class HttpError(Exception): - - def __init__(self, code, message): - super(HttpError, self).__init__(message) - self.code = code - - -class HttpErrorConnClosed(HttpError): - pass - - -def _is_valid_port(port): - if not 0 <= port <= 65535: - return False - return True - - -def _is_valid_host(host): - try: - host.decode("idna") - except ValueError: - return False - if "\0" in host: - return None - return True - - -def get_request_line(fp): - """ - Get a line, possibly preceded by a blank. - """ - line = fp.readline() - if line == "\r\n" or line == "\n": - # Possible leftover from previous message - line = fp.readline() - return line - - -def parse_url(url): - """ - Returns a (scheme, host, port, path) tuple, or None on error. - - Checks that: - port is an integer 0-65535 - host is a valid IDNA-encoded hostname with no null-bytes - path is valid ASCII - """ - try: - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - except ValueError: - return None - if not scheme: - return None - if '@' in netloc: - # FIXME: Consider what to do with the discarded credentials here Most - # probably we should extend the signature to return these as a separate - # value. - _, netloc = string.rsplit(netloc, '@', maxsplit=1) - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None - else: - host = netloc - if scheme == "https": - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path - if not _is_valid_host(host): - return None - if not utils.isascii(path): - return None - if not _is_valid_port(port): - return None - return scheme, host, port, path - - -def read_headers(fp): - """ - Read a set of headers from a file pointer. Stop once a blank line is - reached. Return a ODictCaseless object, or None if headers are invalid. - """ - ret = [] - name = '' - while True: - line = fp.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - if not ret: - return None - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i + 1:].strip() - ret.append([name, value]) - else: - return None - return odict.ODictCaseless(ret) - - -def read_chunked(fp, limit, is_request): - """ - Read a chunked HTTP body. - - May raise HttpError. - """ - # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - # 3.3 2. 
- total = 0 - code = 400 if is_request else 502 - while True: - line = fp.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n' and line != '\n': - try: - length = int(line, 16) - except ValueError: - raise HttpError( - code, - "Invalid chunked encoding length: %s" % line - ) - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large. Limit is %s," \ - " chunked content longer than %s" % (limit, total) - raise HttpError(code, msg) - chunk = fp.read(length) - suffix = fp.readline(5) - if suffix != '\r\n': - raise HttpError(code, "Malformed chunked body") - yield line, chunk, '\r\n' - if length == 0: - return - - -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. - """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks - - -def has_chunked_encoding(headers): - return "chunked" in [ - i.lower() for i in get_header_tokens(headers, "transfer-encoding") - ] - - -def parse_http_protocol(s): - """ - Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or - None. - """ - if not s.startswith("HTTP/"): - return None - _, version = s.split('/', 1) - if "." not in version: - return None - major, minor = version.split('.', 1) - try: - major = int(major) - minor = int(minor) - except ValueError: - return None - return major, minor - - -def parse_http_basic_auth(s): - words = s.split() - if len(words) != 2: - return None - scheme = words[0] - try: - user = binascii.a2b_base64(words[1]) - except binascii.Error: - return None - parts = user.split(':') - if len(parts) != 2: - return None - return scheme, parts[0], parts[1] - - -def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64(username + ":" + password) - return scheme + " " + v - - -def parse_init(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - if not utils.isascii(method): - return None - return method, url, httpversion - - -def parse_init_connect(line): - """ - Returns (host, port, httpversion) if line is a valid CONNECT line. 
- http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - if method.upper() != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - try: - port = int(port) - except ValueError: - return None - if not _is_valid_port(port): - return None - if not _is_valid_host(host): - return None - return host, port, httpversion - - -def parse_init_proxy(line): - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - parts = parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - return method, scheme, host, port, path, httpversion - - -def parse_init_http(line): - """ - Returns (method, url, httpversion) - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - if not utils.isascii(url): - return None - if not (url.startswith("/") or url == "*"): - return None - return method, url, httpversion - - -def connection_close(httpversion, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1 Note that a connection should be - closed as well if the response has been read until end of the stream. - """ - # At first, check if we have an explicit Connection header. - if "connection" in headers: - toks = get_header_tokens(headers, "connection") - if "close" in toks: - return True - elif "keep-alive" in toks: - return False - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - if httpversion == (1, 1): - return False - return True - - -def parse_response_line(line): - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if len(parts) != 3: - return None - proto, code, msg = parts - try: - code = int(code) - except ValueError: - return None - return (proto, code, msg) - - -def read_http_body(*args, **kwargs): - return "".join( - content for _, content, _ in read_http_body_chunked(*args, **kwargs) - ) - - -def read_http_body_chunked( - rfile, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None -): - """ - Read an HTTP message body: - - rfile: A file descriptor to read from - headers: An ODictCaseless object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = expected_http_body_size( - headers, is_request, request_method, response_code - ) - - if expected_size is None: - if has_chunked_encoding(headers): - # Python 3: yield from - for x in read_chunked(rfile, limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. 
Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - yield "", rfile.read(chunk_size), "" - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if not content: - return - yield "", content, "" - bytes_left -= chunk_size - not_done = rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) - - -def expected_http_body_size(headers, is_request, request_method, response_code): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if has_chunked_encoding(headers): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"][0]) - if size < 0: - raise ValueError() - return size - except ValueError: - return None - if is_request: - return 0 - return -1 - - -Request = collections.namedtuple( - "Request", - [ - "form_in", - "method", - "scheme", - "host", - "port", - "path", - "httpversion", - "headers", - "content" - ] -) - - -def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): - """ - Parse an HTTP request from a file stream - - Args: - rfile (file): Input file to read from - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. 
- """ - httpversion, host, port, scheme, method, path, headers, content = ( - None, None, None, None, None, None, None, None) - - request_line = get_request_line(rfile) - if not request_line: - raise tcp.NetLibDisconnect() - - request_line_parts = parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - method, path, httpversion = request_line_parts - - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - elif method.upper() == 'CONNECT': - form_in = "authority" - r = parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - host, port, _ = r - path = None - else: - form_in = "absolute" - r = parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - _, scheme, host, port, path, _ = r - - headers = read_headers(rfile) - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): - wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' - ) - wfile.flush() - del headers['expect'] - - if include_body: - content = read_http_body( - rfile, headers, body_size_limit, method, None, True - ) - - return Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - content - ) - - -def read_response(rfile, request_method, body_size_limit, include_body=True): - """ - Return an (httpversion, code, msg, headers, content) tuple. - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. 
it's a - response to a HEAD request) - """ - line = rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = read_headers(rfile) - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = read_http_body( - rfile, - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return http_semantics.Response(httpversion, code, msg, headers, content) - - -def request_preamble(method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) - - -def response_preamble(code, message=None, http_major="1", http_minor="1"): - if message is None: - message = http_status.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py new file mode 100644 index 00000000..9b4b0e6b --- /dev/null +++ b/netlib/http/__init__.py @@ -0,0 +1,2 @@ +from exceptions import * +from semantics import * diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py new file mode 100644 index 00000000..26e3c2c4 --- /dev/null +++ b/netlib/http/authentication.py @@ -0,0 +1,149 @@ +from __future__ import (absolute_import, print_function, division) +from argparse import Action, ArgumentTypeError + +from .. import http + + +class NullProxyAuth(object): + + """ + No proxy auth at all (returns empty challange headers) + """ + + def __init__(self, password_manager): + self.password_manager = password_manager + + def clean(self, headers_): + """ + Clean up authentication headers, so they're not passed upstream. + """ + pass + + def authenticate(self, headers_): + """ + Tests that the user is allowed to use the proxy + """ + return True + + def auth_challenge_headers(self): + """ + Returns a dictionary containing the headers require to challenge the user + """ + return {} + + +class BasicProxyAuth(NullProxyAuth): + CHALLENGE_HEADER = 'Proxy-Authenticate' + AUTH_HEADER = 'Proxy-Authorization' + + def __init__(self, password_manager, realm): + NullProxyAuth.__init__(self, password_manager) + self.realm = realm + + def clean(self, headers): + del headers[self.AUTH_HEADER] + + def authenticate(self, headers): + auth_value = headers.get(self.AUTH_HEADER, []) + if not auth_value: + return False + parts = http.http1.parse_http_basic_auth(auth_value[0]) + if not parts: + return False + scheme, username, password = parts + if scheme.lower() != 'basic': + return False + if not self.password_manager.test(username, password): + return False + self.username = username + return True + + def auth_challenge_headers(self): + return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} + + +class PassMan(object): + + def test(self, username_, password_token_): + return False + + +class PassManNonAnon(PassMan): + + """ + Ensure the user specifies a username, accept any password. 
+ """ + + def test(self, username, password_token_): + if username: + return True + return False + + +class PassManHtpasswd(PassMan): + + """ + Read usernames and passwords from an htpasswd file + """ + + def __init__(self, path): + """ + Raises ValueError if htpasswd file is invalid. + """ + import passlib.apache + self.htpasswd = passlib.apache.HtpasswdFile(path) + + def test(self, username, password_token): + return bool(self.htpasswd.check_password(username, password_token)) + + +class PassManSingleUser(PassMan): + + def __init__(self, username, password): + self.username, self.password = username, password + + def test(self, username, password_token): + return self.username == username and self.password == password_token + + +class AuthAction(Action): + + """ + Helper class to allow seamless integration int argparse. Example usage: + parser.add_argument( + "--nonanonymous", + action=NonanonymousAuthAction, nargs=0, + help="Allow access to any user long as a credentials are specified." + ) + """ + + def __call__(self, parser, namespace, values, option_string=None): + passman = self.getPasswordManager(values) + authenticator = BasicProxyAuth(passman, "mitmproxy") + setattr(namespace, self.dest, authenticator) + + def getPasswordManager(self, s): # pragma: nocover + raise NotImplementedError() + + +class SingleuserAuthAction(AuthAction): + + def getPasswordManager(self, s): + if len(s.split(':')) != 2: + raise ArgumentTypeError( + "Invalid single-user specification. Please use the format username:password" + ) + username, password = s.split(':') + return PassManSingleUser(username, password) + + +class NonanonymousAuthAction(AuthAction): + + def getPasswordManager(self, s): + return PassManNonAnon() + + +class HtpasswdAuthAction(AuthAction): + + def getPasswordManager(self, s): + return PassManHtpasswd(s) diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py new file mode 100644 index 00000000..b77e3503 --- /dev/null +++ b/netlib/http/cookies.py @@ -0,0 +1,193 @@ +import re + +from .. import odict + +""" +A flexible module for cookie parsing and manipulation. + +This module differs from usual standards-compliant cookie modules in a number +of ways. We try to be as permissive as possible, and to retain even mal-formed +information. Duplicate cookies are preserved in parsing, and can be set in +formatting. We do attempt to escape and quote values where needed, but will not +reject data that violate the specs. + +Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do +not parse the comma-separated variant of Set-Cookie that allows multiple +cookies to be set in a single header. Technically this should be feasible, but +it turns out that violations of RFC6265 that makes the parsing problem +indeterminate are much more common than genuine occurences of the multi-cookie +variants. Serialization follows RFC6265. + + http://tools.ietf.org/html/rfc6265 + http://tools.ietf.org/html/rfc2109 + http://tools.ietf.org/html/rfc2965 +""" + +# TODO +# - Disallow LHS-only Cookie values + + +def _read_until(s, start, term): + """ + Read until one of the characters in term is reached. + """ + if start == len(s): + return "", start + 1 + for i in range(start, len(s)): + if s[i] in term: + return s[start:i], i + return s[start:i + 1], i + 1 + + +def _read_token(s, start): + """ + Read a token - the LHS of a token/value pair in a cookie. 
+ """ + return _read_until(s, start, ";=") + + +def _read_quoted_string(s, start): + """ + start: offset to the first quote of the string to be read + + A sort of loose super-set of the various quoted string specifications. + + RFC6265 disallows backslashes or double quotes within quoted strings. + Prior RFCs use backslashes to escape. This leaves us free to apply + backslash escaping by default and be compatible with everything. + """ + escaping = False + ret = [] + # Skip the first quote + for i in range(start + 1, len(s)): + if escaping: + ret.append(s[i]) + escaping = False + elif s[i] == '"': + break + elif s[i] == "\\": + escaping = True + else: + ret.append(s[i]) + return "".join(ret), i + 1 + + +def _read_value(s, start, delims): + """ + Reads a value - the RHS of a token/value pair in a cookie. + + special: If the value is special, commas are premitted. Else comma + terminates. This helps us support old and new style values. + """ + if start >= len(s): + return "", start + elif s[start] == '"': + return _read_quoted_string(s, start) + else: + return _read_until(s, start, delims) + + +def _read_pairs(s, off=0): + """ + Read pairs of lhs=rhs values. + + off: start offset + specials: a lower-cased list of keys that may contain commas + """ + vals = [] + while True: + lhs, off = _read_token(s, off) + lhs = lhs.lstrip() + if lhs: + rhs = None + if off < len(s): + if s[off] == "=": + rhs, off = _read_value(s, off + 1, ";") + vals.append([lhs, rhs]) + off += 1 + if not off < len(s): + break + return vals, off + + +def _has_special(s): + for i in s: + if i in '",;\\': + return True + o = ord(i) + if o < 0x21 or o > 0x7e: + return True + return False + + +ESCAPE = re.compile(r"([\"\\])") + + +def _format_pairs(lst, specials=(), sep="; "): + """ + specials: A lower-cased list of keys that will not be quoted. + """ + vals = [] + for k, v in lst: + if v is None: + vals.append(k) + else: + if k.lower() not in specials and _has_special(v): + v = ESCAPE.sub(r"\\\1", v) + v = '"%s"' % v + vals.append("%s=%s" % (k, v)) + return sep.join(vals) + + +def _format_set_cookie_pairs(lst): + return _format_pairs( + lst, + specials=("expires", "path") + ) + + +def _parse_set_cookie_pairs(s): + """ + For Set-Cookie, we support multiple cookies as described in RFC2109. + This function therefore returns a list of lists. + """ + pairs, off_ = _read_pairs(s) + return pairs + + +def parse_set_cookie_header(line): + """ + Parse a Set-Cookie header value + + Returns a (name, value, attrs) tuple, or None, where attrs is an + ODictCaseless set of attributes. No attempt is made to parse attribute + values - they are treated purely as strings. + """ + pairs = _parse_set_cookie_pairs(line) + if pairs: + return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + + +def format_set_cookie_header(name, value, attrs): + """ + Formats a Set-Cookie header value. + """ + pairs = [[name, value]] + pairs.extend(attrs.lst) + return _format_set_cookie_pairs(pairs) + + +def parse_cookie_header(line): + """ + Parse a Cookie header value. + Returns a (possibly empty) ODict object. + """ + pairs, off_ = _read_pairs(line) + return odict.ODict(pairs) + + +def format_cookie_header(od): + """ + Formats a Cookie header value. 
+ """ + return _format_pairs(od.lst) diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py new file mode 100644 index 00000000..8a2bbebc --- /dev/null +++ b/netlib/http/exceptions.py @@ -0,0 +1,9 @@ +class HttpError(Exception): + + def __init__(self, code, message): + super(HttpError, self).__init__(message) + self.code = code + + +class HttpErrorConnClosed(HttpError): + pass diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py new file mode 100644 index 00000000..6b5043af --- /dev/null +++ b/netlib/http/http1/__init__.py @@ -0,0 +1 @@ +from protocol import * diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py new file mode 100644 index 00000000..0f7a0bd3 --- /dev/null +++ b/netlib/http/http1/protocol.py @@ -0,0 +1,518 @@ +from __future__ import (absolute_import, print_function, division) +import binascii +import collections +import string +import sys +import urlparse + +from netlib import odict, utils, tcp, http +from .. import status_codes +from ..exceptions import * + + +def get_request_line(fp): + """ + Get a line, possibly preceded by a blank. + """ + line = fp.readline() + if line == "\r\n" or line == "\n": + # Possible leftover from previous message + line = fp.readline() + return line + +def read_headers(fp): + """ + Read a set of headers from a file pointer. Stop once a blank line is + reached. Return a ODictCaseless object, or None if headers are invalid. + """ + ret = [] + name = '' + while True: + line = fp.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + if not ret: + return None + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() + else: + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i + 1:].strip() + ret.append([name, value]) + else: + return None + return odict.ODictCaseless(ret) + + +def read_chunked(fp, limit, is_request): + """ + Read a chunked HTTP body. + + May raise HttpError. + """ + # FIXME: Should check if chunked is the final encoding in the headers + # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + # 3.3 2. + total = 0 + code = 400 if is_request else 502 + while True: + line = fp.readline(128) + if line == "": + raise HttpErrorConnClosed(code, "Connection closed prematurely") + if line != '\r\n' and line != '\n': + try: + length = int(line, 16) + except ValueError: + raise HttpError( + code, + "Invalid chunked encoding length: %s" % line + ) + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large. Limit is %s," \ + " chunked content longer than %s" % (limit, total) + raise HttpError(code, msg) + chunk = fp.read(length) + suffix = fp.readline(5) + if suffix != '\r\n': + raise HttpError(code, "Malformed chunked body") + yield line, chunk, '\r\n' + if length == 0: + return + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + toks = [] + for i in headers[key]: + for j in i.split(","): + toks.append(j.strip()) + return toks + + +def has_chunked_encoding(headers): + return "chunked" in [ + i.lower() for i in get_header_tokens(headers, "transfer-encoding") + ] + + +def parse_http_protocol(s): + """ + Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or + None. 
+ """ + if not s.startswith("HTTP/"): + return None + _, version = s.split('/', 1) + if "." not in version: + return None + major, minor = version.split('.', 1) + try: + major = int(major) + minor = int(minor) + except ValueError: + return None + return major, minor + + +def parse_http_basic_auth(s): + # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics + words = s.split() + if len(words) != 2: + return None + scheme = words[0] + try: + user = binascii.a2b_base64(words[1]) + except binascii.Error: + return None + parts = user.split(':') + if len(parts) != 2: + return None + return scheme, parts[0], parts[1] + + +def assemble_http_basic_auth(scheme, username, password): + # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics + v = binascii.b2a_base64(username + ":" + password) + return scheme + " " + v + + +def parse_init(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + if not utils.isascii(method): + return None + return method, url, httpversion + + +def parse_init_connect(line): + """ + Returns (host, port, httpversion) if line is a valid CONNECT line. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + """ + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + + if method.upper() != 'CONNECT': + return None + try: + host, port = url.split(":") + except ValueError: + return None + try: + port = int(port) + except ValueError: + return None + if not http.is_valid_port(port): + return None + if not http.is_valid_host(host): + return None + return host, port, httpversion + + +def parse_init_proxy(line): + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + + parts = http.parse_url(url) + if not parts: + return None + scheme, host, port, path = parts + return method, scheme, host, port, path, httpversion + + +def parse_init_http(line): + """ + Returns (method, url, httpversion) + """ + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + if not utils.isascii(url): + return None + if not (url.startswith("/") or url == "*"): + return None + return method, url, httpversion + + +def connection_close(httpversion, headers): + """ + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1 Note that a connection should be + closed as well if the response has been read until end of the stream. + """ + # At first, check if we have an explicit Connection header. 
+ if "connection" in headers: + toks = get_header_tokens(headers, "connection") + if "close" in toks: + return True + elif "keep-alive" in toks: + return False + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent + if httpversion == (1, 1): + return False + return True + + +def parse_response_line(line): + parts = line.strip().split(" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append("") + if len(parts) != 3: + return None + proto, code, msg = parts + try: + code = int(code) + except ValueError: + return None + return (proto, code, msg) + + +def read_http_body(*args, **kwargs): + return "".join( + content for _, content, _ in read_http_body_chunked(*args, **kwargs) + ) + + +def read_http_body_chunked( + rfile, + headers, + limit, + request_method, + response_code, + is_request, + max_chunk_size=None +): + """ + Read an HTTP message body: + + rfile: A file descriptor to read from + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False + otherwise + """ + if max_chunk_size is None: + max_chunk_size = limit or sys.maxsize + + expected_size = expected_http_body_size( + headers, is_request, request_method, response_code + ) + + if expected_size is None: + if has_chunked_encoding(headers): + # Python 3: yield from + for x in read_chunked(rfile, limit, is_request): + yield x + else: # pragma: nocover + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", rfile.read(chunk_size), "" + bytes_left -= chunk_size + else: + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size + not_done = rfile.read(1) + if not_done: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) + + +def expected_http_body_size(headers, is_request, request_method, response_code): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding or invalid + data) + - -1, if all data should be read until end of stream. + + May raise HttpError. 
+ """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if has_chunked_encoding(headers): + return None + if "content-length" in headers: + try: + size = int(headers["content-length"][0]) + if size < 0: + raise ValueError() + return size + except ValueError: + return None + if is_request: + return 0 + return -1 + + +# TODO: make this a regular class - just like Response +Request = collections.namedtuple( + "Request", + [ + "form_in", + "method", + "scheme", + "host", + "port", + "path", + "httpversion", + "headers", + "content" + ] +) + + +def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): + """ + Parse an HTTP request from a file stream + + Args: + rfile (file): Input file to read from + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled + automatically, by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + Request: The HTTP request + + Raises: + HttpError: If the input is invalid. + """ + httpversion, host, port, scheme, method, path, headers, content = ( + None, None, None, None, None, None, None, None) + + request_line = get_request_line(rfile) + if not request_line: + raise tcp.NetLibDisconnect() + + request_line_parts = parse_init(request_line) + if not request_line_parts: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + method, path, httpversion = request_line_parts + + if path == '*' or path.startswith("/"): + form_in = "relative" + if not utils.isascii(path): + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = parse_init_connect(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + host, port, _ = r + path = None + else: + form_in = "absolute" + r = parse_init_proxy(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + _, scheme, host, port, path, _ = r + + headers = read_headers(rfile) + if headers is None: + raise HttpError(400, "Invalid headers") + + expect_header = headers.get_first("expect", "").lower() + if expect_header == "100-continue" and httpversion >= (1, 1): + wfile.write( + 'HTTP/1.1 100 Continue\r\n' + '\r\n' + ) + wfile.flush() + del headers['expect'] + + if include_body: + content = read_http_body( + rfile, headers, body_size_limit, method, None, True + ) + + return Request( + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content + ) + + +def read_response(rfile, request_method, body_size_limit, include_body=True): + """ + Returns an http.Response + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. 
it's a + response to a HEAD request) + """ + + line = rfile.readline() + # Possible leftover from previous message + if line == "\r\n" or line == "\n": + line = rfile.readline() + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = read_headers(rfile) + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = read_http_body( + rfile, + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return http.Response(httpversion, code, msg, headers, content) + + +def request_preamble(method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) + + +def response_preamble(code, message=None, http_major="1", http_minor="1"): + if message is None: + message = status_codes.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py new file mode 100644 index 00000000..5acf7696 --- /dev/null +++ b/netlib/http/http2/__init__.py @@ -0,0 +1,2 @@ +from frame import * +from protocol import * diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py new file mode 100644 index 00000000..f7e60471 --- /dev/null +++ b/netlib/http/http2/frame.py @@ -0,0 +1,636 @@ +import sys +import struct +from hpack.hpack import Encoder, Decoder + +from .. import utils + + +class FrameSizeError(Exception): + pass + + +class Frame(object): + + """ + Baseclass Frame + contains header + payload is defined in subclasses + """ + + FLAG_NO_FLAGS = 0x0 + FLAG_ACK = 0x1 + FLAG_END_STREAM = 0x1 + FLAG_END_HEADERS = 0x4 + FLAG_PADDED = 0x8 + FLAG_PRIORITY = 0x20 + + def __init__( + self, + state=None, + length=0, + flags=FLAG_NO_FLAGS, + stream_id=0x0): + valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) + if flags | valid_flags != valid_flags: + raise ValueError('invalid flags detected.') + + if state is None: + class State(object): + pass + + state = State() + state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() + state.encoder = Encoder() + state.decoder = Decoder() + + self.state = state + + self.length = length + self.type = self.TYPE + self.flags = flags + self.stream_id = stream_id + + @classmethod + def _check_frame_size(cls, length, state): + if state: + settings = state.http2_settings + else: + settings = HTTP2_DEFAULT_SETTINGS.copy() + + max_frame_size = settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + + if length > max_frame_size: + raise FrameSizeError( + "Frame size exceeded: %d, but only %d allowed." % ( + length, max_frame_size)) + + @classmethod + def from_file(cls, fp, state=None): + """ + read a HTTP/2 frame sent by a server or client + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + raw_header = fp.safe_read(9) + + fields = struct.unpack("!HBBBL", raw_header) + length = (fields[0] << 8) + fields[1] + flags = fields[3] + stream_id = fields[4] + + if raw_header[:4] == b'HTTP': # pragma no cover + print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" 
+ + cls._check_frame_size(length, state) + + payload = fp.safe_read(length) + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + payload) + + def to_bytes(self): + payload = self.payload_bytes() + self.length = len(payload) + + self._check_frame_size(self.length, self.state) + + b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) + b += struct.pack('!B', self.TYPE) + b += struct.pack('!B', self.flags) + b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) + b += payload + + return b + + def payload_bytes(self): # pragma: no cover + raise NotImplementedError() + + def payload_human_readable(self): # pragma: no cover + raise NotImplementedError() + + def human_readable(self, direction="-"): + self.length = len(self.payload_bytes()) + + return "\n".join([ + "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( + direction, self.__class__.__name__, self.length, self.flags, self.stream_id), + self.payload_human_readable(), + "===============================================================", + ]) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + + +class DataFrame(Frame): + TYPE = 0x0 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'', + pad_length=0): + super(DataFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + self.pad_length = pad_length + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.payload = payload[1:-f.pad_length] + else: + f.payload = payload + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('DATA frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += bytes(self.payload) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + return "payload: %s" % str(self.payload) + + +class HeadersFrame(Frame): + TYPE = 0x1 + VALID_FLAGS = [ + Frame.FLAG_END_STREAM, + Frame.FLAG_END_HEADERS, + Frame.FLAG_PADDED, + Frame.FLAG_PRIORITY] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b'', + pad_length=0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(HeadersFrame, self).__init__(state, length, flags, stream_id) + + self.header_block_fragment = header_block_fragment + self.pad_length = pad_length + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.header_block_fragment = payload[1:-f.pad_length] + else: + f.header_block_fragment = payload[0:] + + if f.flags & Frame.FLAG_PRIORITY: + f.stream_dependency, f.weight = struct.unpack( + '!LB', f.header_block_fragment[:5]) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + f.header_block_fragment = f.header_block_fragment[5:] + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('HEADERS frames MUST be 
associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + if self.flags & self.FLAG_PRIORITY: + b += struct.pack('!LB', + (int(self.exclusive) << 31) | self.stream_dependency, + self.weight) + + b += self.header_block_fragment + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PRIORITY: + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PriorityFrame(Frame): + TYPE = 0x2 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(PriorityFrame, self).__init__(state, length, flags, stream_id) + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.stream_dependency, f.weight = struct.unpack('!LB', payload) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PRIORITY frames MUST be associated with a stream.') + + if self.stream_dependency == 0x0: + raise ValueError('stream dependency is invalid.') + + return struct.pack( + '!LB', + (int( + self.exclusive) << 31) | self.stream_dependency, + self.weight) + + def payload_human_readable(self): + s = [] + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + return "\n".join(s) + + +class RstStreamFrame(Frame): + TYPE = 0x3 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + error_code=0x0): + super(RstStreamFrame, self).__init__(state, length, flags, stream_id) + self.error_code = error_code + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.error_code = struct.unpack('!L', payload)[0] + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'RST_STREAM frames MUST be associated with a stream.') + + return struct.pack('!L', self.error_code) + + def payload_human_readable(self): + return "error code: %#x" % self.error_code + + +class SettingsFrame(Frame): + TYPE = 0x4 + VALID_FLAGS = [Frame.FLAG_ACK] + + SETTINGS = utils.BiDi( + SETTINGS_HEADER_TABLE_SIZE=0x1, + SETTINGS_ENABLE_PUSH=0x2, + SETTINGS_MAX_CONCURRENT_STREAMS=0x3, + SETTINGS_INITIAL_WINDOW_SIZE=0x4, + SETTINGS_MAX_FRAME_SIZE=0x5, + SETTINGS_MAX_HEADER_LIST_SIZE=0x6, + ) + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings=None): + super(SettingsFrame, self).__init__(state, length, flags, stream_id) + + if settings is None: + settings = {} + + self.settings = settings + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + for i in 
xrange(0, len(payload), 6): + identifier, value = struct.unpack("!HL", payload[i:i + 6]) + f.settings[identifier] = value + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'SETTINGS frames MUST NOT be associated with a stream.') + + b = b'' + for identifier, value in self.settings.items(): + b += struct.pack("!HL", identifier & 0xFF, value) + + return b + + def payload_human_readable(self): + s = [] + + for identifier, value in self.settings.items(): + s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) + + if not s: + return "settings: None" + else: + return "\n".join(s) + + +class PushPromiseFrame(Frame): + TYPE = 0x5 + VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x0, + header_block_fragment=b'', + pad_length=0): + super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) + self.pad_length = pad_length + self.promised_stream = promised_stream + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) + f.header_block_fragment = payload[5:-f.pad_length] + else: + f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) + f.header_block_fragment = payload[4:] + + f.promised_stream &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PUSH_PROMISE frames MUST be associated with a stream.') + + if self.promised_stream == 0x0: + raise ValueError('Promised stream id not valid.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("promised stream: %#x" % self.promised_stream) + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PingFrame(Frame): + TYPE = 0x6 + VALID_FLAGS = [Frame.FLAG_ACK] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b''): + super(PingFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.payload = payload + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'PING frames MUST NOT be associated with a stream.') + + b = self.payload[0:8] + b += b'\0' * (8 - len(b)) + return b + + def payload_human_readable(self): + return "opaque data: %s" % str(self.payload) + + +class GoAwayFrame(Frame): + TYPE = 0x7 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x0, + error_code=0x0, + data=b''): + super(GoAwayFrame, self).__init__(state, length, flags, stream_id) + self.last_stream = last_stream + self.error_code = error_code + self.data = data + + @classmethod + def from_bytes(cls, state, length, 
flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) + f.last_stream &= 0x7FFFFFFF + f.data = payload[8:] + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'GOAWAY frames MUST NOT be associated with a stream.') + + b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) + b += bytes(self.data) + return b + + def payload_human_readable(self): + s = [] + s.append("last stream: %#x" % self.last_stream) + s.append("error code: %d" % self.error_code) + s.append("debug data: %s" % str(self.data)) + return "\n".join(s) + + +class WindowUpdateFrame(Frame): + TYPE = 0x8 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x0): + super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) + self.window_size_increment = window_size_increment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.window_size_increment = struct.unpack("!L", payload)[0] + f.window_size_increment &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: + raise ValueError( + 'Window Szie Increment MUST be greater than 0 and less than 2^31.') + + return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + + def payload_human_readable(self): + return "window size increment: %#x" % self.window_size_increment + + +class ContinuationFrame(Frame): + TYPE = 0x9 + VALID_FLAGS = [Frame.FLAG_END_HEADERS] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b''): + super(ContinuationFrame, self).__init__(state, length, flags, stream_id) + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.header_block_fragment = payload + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'CONTINUATION frames MUST be associated with a stream.') + + return self.header_block_fragment + + def payload_human_readable(self): + s = [] + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) + +_FRAME_CLASSES = [ + DataFrame, + HeadersFrame, + PriorityFrame, + RstStreamFrame, + SettingsFrame, + PushPromiseFrame, + PingFrame, + GoAwayFrame, + WindowUpdateFrame, + ContinuationFrame +] +FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} + + +HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, +} diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py new file mode 100644 index 00000000..8e5f5429 --- /dev/null +++ b/netlib/http/http2/protocol.py @@ -0,0 +1,240 @@ +from __future__ import (absolute_import, print_function, division) +import itertools + +from hpack.hpack import Encoder, Decoder +from .. import utils +from . 
import frame + + +class HTTP2Protocol(object): + + ERROR_CODES = utils.BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd + ) + + # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + CLIENT_CONNECTION_PREFACE =\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + + ALPN_PROTO_H2 = 'h2' + + def __init__(self, tcp_handler, is_server=False, dump_frames=False): + self.tcp_handler = tcp_handler + self.is_server = is_server + + self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() + self.current_stream_id = None + self.encoder = Encoder() + self.decoder = Decoder() + self.connection_preface_performed = False + self.dump_frames = dump_frames + + def check_alpn(self): + alp = self.tcp_handler.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "HTTP2Protocol can not handle unknown ALP: %s" % alp) + return True + + def _receive_settings(self, hide=False): + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break + + def _read_settings_ack(self, hide=False): # pragma no cover + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + assert frm.flags & frame.Frame.FLAG_ACK + assert len(frm.settings) == 0 + break + + def perform_server_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE + + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) + + def perform_client_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) + + def next_stream_id(self): + if self.current_stream_id is None: + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + + def send_frame(self, frm, hide=False): + raw_bytes = frm.to_bytes() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + def read_frame(self, hide=False): + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable("<<")) + if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + self._apply_settings(frm.settings, hide) + + return frm + + def _apply_settings(self, settings, hide=False): + for setting, value in settings.items(): + old_value = self.http2_settings[setting] + if not old_value: + old_value = '-' + self.http2_settings[setting] = value + + frm = frame.SettingsFrame( + state=self, + flags=frame.Frame.FLAG_ACK) + self.send_frame(frm, hide) + + # be liberal in what we expect from the other end + # to be more strict use: self._read_settings_ack(hide) + + def 
_create_headers(self, headers, stream_id, end_stream=True): + # TODO: implement max frame size checks and sending in chunks + + flags = frame.Frame.FLAG_END_HEADERS + if end_stream: + flags |= frame.Frame.FLAG_END_STREAM + + header_block_fragment = self.encoder.encode(headers) + + frm = frame.HeadersFrame( + state=self, + flags=flags, + stream_id=stream_id, + header_block_fragment=header_block_fragment) + + if self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + return [frm.to_bytes()] + + def _create_body(self, body, stream_id): + if body is None or len(body) == 0: + return b'' + + # TODO: implement max frame size checks and sending in chunks + # TODO: implement flow-control window + + frm = frame.DataFrame( + state=self, + flags=frame.Frame.FLAG_END_STREAM, + stream_id=stream_id, + payload=body) + + if self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + return [frm.to_bytes()] + + + def create_request(self, method, path, headers=None, body=None): + if headers is None: + headers = [] + + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + + headers = [ + (b':method', bytes(method)), + (b':path', bytes(path)), + (b':scheme', b'https'), + (b':authority', authority), + ] + headers + + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id))) + + def read_response(self): + stream_id_, headers, body = self._receive_transmission() + return headers[':status'], headers, body + + def read_request(self): + return self._receive_transmission() + + def _receive_transmission(self): + body_expected = True + + stream_id = 0 + header_block_fragment = b'' + body = b'' + + while True: + frm = self.read_frame() + if isinstance(frm, frame.HeadersFrame)\ + or isinstance(frm, frame.ContinuationFrame): + stream_id = frm.stream_id + header_block_fragment += frm.header_block_fragment + if frm.flags & frame.Frame.FLAG_END_STREAM: + body_expected = False + if frm.flags & frame.Frame.FLAG_END_HEADERS: + break + + while body_expected: + frm = self.read_frame() + if isinstance(frm, frame.DataFrame): + body += frm.payload + if frm.flags & frame.Frame.FLAG_END_STREAM: + break + # TODO: implement window update & flow + + headers = {} + for header, value in self.decoder.decode(header_block_fragment): + headers[header] = value + + return stream_id, headers, body + + def create_response(self, code, stream_id=None, headers=None, body=None): + if headers is None: + headers = [] + + headers = [(b':status', bytes(str(code)))] + headers + + if not stream_id: + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id), + )) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py new file mode 100644 index 00000000..e7e84fe3 --- /dev/null +++ b/netlib/http/semantics.py @@ -0,0 +1,94 @@ +from __future__ import (absolute_import, print_function, division) +import binascii +import collections +import string +import sys +import urlparse + +from .. 
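To show the shape of what create_request() returns, here is a hedged sketch using a hypothetical stand-in for a connected tcp.TCPClient. The FakeHandler and Address names below are invented for the example; only .sni and .address are touched because no body is sent.

    import collections
    from netlib.http.http2 import HTTP2Protocol

    Address = collections.namedtuple("Address", "host port")

    class FakeHandler(object):
        sni = None
        address = Address("example.com", 443)

    proto = HTTP2Protocol(FakeHandler())
    frames = proto.create_request("GET", "/", headers=[(b"x-test", b"1")])

    # Without a body this is a single HEADERS frame carrying
    # END_HEADERS | END_STREAM, already serialized to bytes.
    assert len(frames) == 1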
import utils + +class Response(object): + + def __init__( + self, + httpversion, + status_code, + msg, + headers, + content, + sslinfo=None, + ): + self.httpversion = httpversion + self.status_code = status_code + self.msg = msg + self.headers = headers + self.content = content + self.sslinfo = sslinfo + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return "Response(%s - %s)" % (self.status_code, self.msg) + + + +def is_valid_port(port): + if not 0 <= port <= 65535: + return False + return True + + +def is_valid_host(host): + try: + host.decode("idna") + except ValueError: + return False + if "\0" in host: + return None + return True + + + +def parse_url(url): + """ + Returns a (scheme, host, port, path) tuple, or None on error. + + Checks that: + port is an integer 0-65535 + host is a valid IDNA-encoded hostname with no null-bytes + path is valid ASCII + """ + try: + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + except ValueError: + return None + if not scheme: + return None + if '@' in netloc: + # FIXME: Consider what to do with the discarded credentials here Most + # probably we should extend the signature to return these as a separate + # value. + _, netloc = string.rsplit(netloc, '@', maxsplit=1) + if ':' in netloc: + host, port = string.rsplit(netloc, ':', maxsplit=1) + try: + port = int(port) + except ValueError: + return None + else: + host = netloc + if scheme == "https": + port = 443 + else: + port = 80 + path = urlparse.urlunparse(('', '', path, params, query, fragment)) + if not path.startswith("/"): + path = "/" + path + if not is_valid_host(host): + return None + if not utils.isascii(path): + return None + if not is_valid_port(port): + return None + return scheme, host, port, path diff --git a/netlib/http/status_codes.py b/netlib/http/status_codes.py new file mode 100644 index 00000000..dc09f465 --- /dev/null +++ b/netlib/http/status_codes.py @@ -0,0 +1,104 @@ +from __future__ import (absolute_import, print_function, division) + +CONTINUE = 100 +SWITCHING = 101 +OK = 200 +CREATED = 201 +ACCEPTED = 202 +NON_AUTHORITATIVE_INFORMATION = 203 +NO_CONTENT = 204 +RESET_CONTENT = 205 +PARTIAL_CONTENT = 206 +MULTI_STATUS = 207 + +MULTIPLE_CHOICE = 300 +MOVED_PERMANENTLY = 301 +FOUND = 302 +SEE_OTHER = 303 +NOT_MODIFIED = 304 +USE_PROXY = 305 +TEMPORARY_REDIRECT = 307 + +BAD_REQUEST = 400 +UNAUTHORIZED = 401 +PAYMENT_REQUIRED = 402 +FORBIDDEN = 403 +NOT_FOUND = 404 +NOT_ALLOWED = 405 +NOT_ACCEPTABLE = 406 +PROXY_AUTH_REQUIRED = 407 +REQUEST_TIMEOUT = 408 +CONFLICT = 409 +GONE = 410 +LENGTH_REQUIRED = 411 +PRECONDITION_FAILED = 412 +REQUEST_ENTITY_TOO_LARGE = 413 +REQUEST_URI_TOO_LONG = 414 +UNSUPPORTED_MEDIA_TYPE = 415 +REQUESTED_RANGE_NOT_SATISFIABLE = 416 +EXPECTATION_FAILED = 417 + +INTERNAL_SERVER_ERROR = 500 +NOT_IMPLEMENTED = 501 +BAD_GATEWAY = 502 +SERVICE_UNAVAILABLE = 503 +GATEWAY_TIMEOUT = 504 +HTTP_VERSION_NOT_SUPPORTED = 505 +INSUFFICIENT_STORAGE_SPACE = 507 +NOT_EXTENDED = 510 + +RESPONSES = { + # 100 + CONTINUE: "Continue", + SWITCHING: "Switching Protocols", + + # 200 + OK: "OK", + CREATED: "Created", + ACCEPTED: "Accepted", + NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", + NO_CONTENT: "No Content", + RESET_CONTENT: "Reset Content.", + PARTIAL_CONTENT: "Partial Content", + MULTI_STATUS: "Multi-Status", + + # 300 + MULTIPLE_CHOICE: "Multiple Choices", + MOVED_PERMANENTLY: "Moved Permanently", + FOUND: "Found", + SEE_OTHER: "See Other", + NOT_MODIFIED: "Not Modified", + USE_PROXY: 
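The contract of parse_url() in netlib/http/semantics.py above is easy to miss: it returns a (scheme, host, port, path) tuple on success and None for anything it rejects, filling in default ports and a leading slash. A short sketch with illustrative URLs:

    from netlib.http.semantics import parse_url

    assert parse_url("https://example.com/foo?x=1") == \
        ("https", "example.com", 443, "/foo?x=1")
    assert parse_url("http://example.com:8080") == \
        ("http", "example.com", 8080, "/")
    assert parse_url("example.com/no-scheme") is None  # a scheme is required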
"Use Proxy", + # 306 not defined?? + TEMPORARY_REDIRECT: "Temporary Redirect", + + # 400 + BAD_REQUEST: "Bad Request", + UNAUTHORIZED: "Unauthorized", + PAYMENT_REQUIRED: "Payment Required", + FORBIDDEN: "Forbidden", + NOT_FOUND: "Not Found", + NOT_ALLOWED: "Method Not Allowed", + NOT_ACCEPTABLE: "Not Acceptable", + PROXY_AUTH_REQUIRED: "Proxy Authentication Required", + REQUEST_TIMEOUT: "Request Time-out", + CONFLICT: "Conflict", + GONE: "Gone", + LENGTH_REQUIRED: "Length Required", + PRECONDITION_FAILED: "Precondition Failed", + REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", + REQUEST_URI_TOO_LONG: "Request-URI Too Long", + UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", + REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", + EXPECTATION_FAILED: "Expectation Failed", + + # 500 + INTERNAL_SERVER_ERROR: "Internal Server Error", + NOT_IMPLEMENTED: "Not Implemented", + BAD_GATEWAY: "Bad Gateway", + SERVICE_UNAVAILABLE: "Service Unavailable", + GATEWAY_TIMEOUT: "Gateway Time-out", + HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", + INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", + NOT_EXTENDED: "Not Extended" +} diff --git a/netlib/http/user_agents.py b/netlib/http/user_agents.py new file mode 100644 index 00000000..e8681908 --- /dev/null +++ b/netlib/http/user_agents.py @@ -0,0 +1,52 @@ +from __future__ import (absolute_import, print_function, division) + +""" + A small collection of useful user-agent header strings. These should be + kept reasonably current to reflect common usage. +""" + +# pylint: line-too-long + +# A collection of (name, shortcut, string) tuples. + +UASTRINGS = [ + ("android", + "a", + "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa + ("blackberry", + "l", + "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa + ("bingbot", + "b", + "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa + ("chrome", + "c", + "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa + ("firefox", + "f", + "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa + ("googlebot", + "g", + "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa + ("ie9", + "i", + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa + ("ipad", + "p", + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa + ("iphone", + "h", + "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa + ("safari", + "s", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa +] + + +def get_by_shortcut(s): + """ + Retrieve a user agent entry by shortcut. 
+ """ + for i in UASTRINGS: + if s == i[1]: + return i diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py deleted file mode 100644 index 5acf7696..00000000 --- a/netlib/http2/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from frame import * -from protocol import * diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py deleted file mode 100644 index f7e60471..00000000 --- a/netlib/http2/frame.py +++ /dev/null @@ -1,636 +0,0 @@ -import sys -import struct -from hpack.hpack import Encoder, Decoder - -from .. import utils - - -class FrameSizeError(Exception): - pass - - -class Frame(object): - - """ - Baseclass Frame - contains header - payload is defined in subclasses - """ - - FLAG_NO_FLAGS = 0x0 - FLAG_ACK = 0x1 - FLAG_END_STREAM = 0x1 - FLAG_END_HEADERS = 0x4 - FLAG_PADDED = 0x8 - FLAG_PRIORITY = 0x20 - - def __init__( - self, - state=None, - length=0, - flags=FLAG_NO_FLAGS, - stream_id=0x0): - valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) - if flags | valid_flags != valid_flags: - raise ValueError('invalid flags detected.') - - if state is None: - class State(object): - pass - - state = State() - state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() - state.encoder = Encoder() - state.decoder = Decoder() - - self.state = state - - self.length = length - self.type = self.TYPE - self.flags = flags - self.stream_id = stream_id - - @classmethod - def _check_frame_size(cls, length, state): - if state: - settings = state.http2_settings - else: - settings = HTTP2_DEFAULT_SETTINGS.copy() - - max_frame_size = settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - - if length > max_frame_size: - raise FrameSizeError( - "Frame size exceeded: %d, but only %d allowed." % ( - length, max_frame_size)) - - @classmethod - def from_file(cls, fp, state=None): - """ - read a HTTP/2 frame sent by a server or client - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - raw_header = fp.safe_read(9) - - fields = struct.unpack("!HBBBL", raw_header) - length = (fields[0] << 8) + fields[1] - flags = fields[3] - stream_id = fields[4] - - if raw_header[:4] == b'HTTP': # pragma no cover - print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" 
- - cls._check_frame_size(length, state) - - payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - payload) - - def to_bytes(self): - payload = self.payload_bytes() - self.length = len(payload) - - self._check_frame_size(self.length, self.state) - - b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) - b += struct.pack('!B', self.TYPE) - b += struct.pack('!B', self.flags) - b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) - b += payload - - return b - - def payload_bytes(self): # pragma: no cover - raise NotImplementedError() - - def payload_human_readable(self): # pragma: no cover - raise NotImplementedError() - - def human_readable(self, direction="-"): - self.length = len(self.payload_bytes()) - - return "\n".join([ - "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( - direction, self.__class__.__name__, self.length, self.flags, self.stream_id), - self.payload_human_readable(), - "===============================================================", - ]) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class DataFrame(Frame): - TYPE = 0x0 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'', - pad_length=0): - super(DataFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - self.pad_length = pad_length - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.payload = payload[1:-f.pad_length] - else: - f.payload = payload - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('DATA frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += bytes(self.payload) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - return "payload: %s" % str(self.payload) - - -class HeadersFrame(Frame): - TYPE = 0x1 - VALID_FLAGS = [ - Frame.FLAG_END_STREAM, - Frame.FLAG_END_HEADERS, - Frame.FLAG_PADDED, - Frame.FLAG_PRIORITY] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b'', - pad_length=0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(HeadersFrame, self).__init__(state, length, flags, stream_id) - - self.header_block_fragment = header_block_fragment - self.pad_length = pad_length - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] - else: - f.header_block_fragment = payload[0:] - - if f.flags & Frame.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack( - '!LB', f.header_block_fragment[:5]) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('HEADERS frames MUST be 
associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', - (int(self.exclusive) << 31) | self.stream_dependency, - self.weight) - - b += self.header_block_fragment - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PRIORITY: - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PriorityFrame(Frame): - TYPE = 0x2 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(PriorityFrame, self).__init__(state, length, flags, stream_id) - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.stream_dependency, f.weight = struct.unpack('!LB', payload) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PRIORITY frames MUST be associated with a stream.') - - if self.stream_dependency == 0x0: - raise ValueError('stream dependency is invalid.') - - return struct.pack( - '!LB', - (int( - self.exclusive) << 31) | self.stream_dependency, - self.weight) - - def payload_human_readable(self): - s = [] - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - return "\n".join(s) - - -class RstStreamFrame(Frame): - TYPE = 0x3 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - error_code=0x0): - super(RstStreamFrame, self).__init__(state, length, flags, stream_id) - self.error_code = error_code - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.error_code = struct.unpack('!L', payload)[0] - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'RST_STREAM frames MUST be associated with a stream.') - - return struct.pack('!L', self.error_code) - - def payload_human_readable(self): - return "error code: %#x" % self.error_code - - -class SettingsFrame(Frame): - TYPE = 0x4 - VALID_FLAGS = [Frame.FLAG_ACK] - - SETTINGS = utils.BiDi( - SETTINGS_HEADER_TABLE_SIZE=0x1, - SETTINGS_ENABLE_PUSH=0x2, - SETTINGS_MAX_CONCURRENT_STREAMS=0x3, - SETTINGS_INITIAL_WINDOW_SIZE=0x4, - SETTINGS_MAX_FRAME_SIZE=0x5, - SETTINGS_MAX_HEADER_LIST_SIZE=0x6, - ) - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings=None): - super(SettingsFrame, self).__init__(state, length, flags, stream_id) - - if settings is None: - settings = {} - - self.settings = settings - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - for i in 
xrange(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i + 6]) - f.settings[identifier] = value - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'SETTINGS frames MUST NOT be associated with a stream.') - - b = b'' - for identifier, value in self.settings.items(): - b += struct.pack("!HL", identifier & 0xFF, value) - - return b - - def payload_human_readable(self): - s = [] - - for identifier, value in self.settings.items(): - s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) - - if not s: - return "settings: None" - else: - return "\n".join(s) - - -class PushPromiseFrame(Frame): - TYPE = 0x5 - VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x0, - header_block_fragment=b'', - pad_length=0): - super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) - self.pad_length = pad_length - self.promised_stream = promised_stream - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) - f.header_block_fragment = payload[5:-f.pad_length] - else: - f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) - f.header_block_fragment = payload[4:] - - f.promised_stream &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PUSH_PROMISE frames MUST be associated with a stream.') - - if self.promised_stream == 0x0: - raise ValueError('Promised stream id not valid.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) - b += bytes(self.header_block_fragment) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append("promised stream: %#x" % self.promised_stream) - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PingFrame(Frame): - TYPE = 0x6 - VALID_FLAGS = [Frame.FLAG_ACK] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b''): - super(PingFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.payload = payload - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'PING frames MUST NOT be associated with a stream.') - - b = self.payload[0:8] - b += b'\0' * (8 - len(b)) - return b - - def payload_human_readable(self): - return "opaque data: %s" % str(self.payload) - - -class GoAwayFrame(Frame): - TYPE = 0x7 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x0, - error_code=0x0, - data=b''): - super(GoAwayFrame, self).__init__(state, length, flags, stream_id) - self.last_stream = last_stream - self.error_code = error_code - self.data = data - - @classmethod - def from_bytes(cls, state, length, 
flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) - f.last_stream &= 0x7FFFFFFF - f.data = payload[8:] - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'GOAWAY frames MUST NOT be associated with a stream.') - - b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) - b += bytes(self.data) - return b - - def payload_human_readable(self): - s = [] - s.append("last stream: %#x" % self.last_stream) - s.append("error code: %d" % self.error_code) - s.append("debug data: %s" % str(self.data)) - return "\n".join(s) - - -class WindowUpdateFrame(Frame): - TYPE = 0x8 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) - self.window_size_increment = window_size_increment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.window_size_increment = struct.unpack("!L", payload)[0] - f.window_size_increment &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError( - 'Window Szie Increment MUST be greater than 0 and less than 2^31.') - - return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) - - def payload_human_readable(self): - return "window size increment: %#x" % self.window_size_increment - - -class ContinuationFrame(Frame): - TYPE = 0x9 - VALID_FLAGS = [Frame.FLAG_END_HEADERS] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b''): - super(ContinuationFrame, self).__init__(state, length, flags, stream_id) - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.header_block_fragment = payload - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'CONTINUATION frames MUST be associated with a stream.') - - return self.header_block_fragment - - def payload_human_readable(self): - s = [] - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - return "\n".join(s) - -_FRAME_CLASSES = [ - DataFrame, - HeadersFrame, - PriorityFrame, - RstStreamFrame, - SettingsFrame, - PushPromiseFrame, - PingFrame, - GoAwayFrame, - WindowUpdateFrame, - ContinuationFrame -] -FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} - - -HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, -} diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py deleted file mode 100644 index 8e5f5429..00000000 --- a/netlib/http2/protocol.py +++ /dev/null @@ -1,240 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import itertools - -from hpack.hpack import Encoder, Decoder -from .. import utils -from . 
import frame - - -class HTTP2Protocol(object): - - ERROR_CODES = utils.BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd - ) - - # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE =\ - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') - - ALPN_PROTO_H2 = 'h2' - - def __init__(self, tcp_handler, is_server=False, dump_frames=False): - self.tcp_handler = tcp_handler - self.is_server = is_server - - self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() - self.current_stream_id = None - self.encoder = Encoder() - self.decoder = Decoder() - self.connection_preface_performed = False - self.dump_frames = dump_frames - - def check_alpn(self): - alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "HTTP2Protocol can not handle unknown ALP: %s" % alp) - return True - - def _receive_settings(self, hide=False): - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - break - - def _read_settings_ack(self, hide=False): # pragma no cover - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - assert frm.flags & frame.Frame.FLAG_ACK - assert len(frm.settings) == 0 - break - - def perform_server_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - magic_length = len(self.CLIENT_CONNECTION_PREFACE) - magic = self.tcp_handler.rfile.safe_read(magic_length) - assert magic == self.CLIENT_CONNECTION_PREFACE - - self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) - - def perform_client_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) - - self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) - - def next_stream_id(self): - if self.current_stream_id is None: - if self.is_server: - # servers must use even stream ids - self.current_stream_id = 2 - else: - # clients must use odd stream ids - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - - def send_frame(self, frm, hide=False): - raw_bytes = frm.to_bytes() - self.tcp_handler.wfile.write(raw_bytes) - self.tcp_handler.wfile.flush() - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - def read_frame(self, hide=False): - frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable("<<")) - if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: - self._apply_settings(frm.settings, hide) - - return frm - - def _apply_settings(self, settings, hide=False): - for setting, value in settings.items(): - old_value = self.http2_settings[setting] - if not old_value: - old_value = '-' - self.http2_settings[setting] = value - - frm = frame.SettingsFrame( - state=self, - flags=frame.Frame.FLAG_ACK) - self.send_frame(frm, hide) - - # be liberal in what we expect from the other end - # to be more strict use: self._read_settings_ack(hide) - - def 
_create_headers(self, headers, stream_id, end_stream=True): - # TODO: implement max frame size checks and sending in chunks - - flags = frame.Frame.FLAG_END_HEADERS - if end_stream: - flags |= frame.Frame.FLAG_END_STREAM - - header_block_fragment = self.encoder.encode(headers) - - frm = frame.HeadersFrame( - state=self, - flags=flags, - stream_id=stream_id, - header_block_fragment=header_block_fragment) - - if self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - return [frm.to_bytes()] - - def _create_body(self, body, stream_id): - if body is None or len(body) == 0: - return b'' - - # TODO: implement max frame size checks and sending in chunks - # TODO: implement flow-control window - - frm = frame.DataFrame( - state=self, - flags=frame.Frame.FLAG_END_STREAM, - stream_id=stream_id, - payload=body) - - if self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - return [frm.to_bytes()] - - - def create_request(self, method, path, headers=None, body=None): - if headers is None: - headers = [] - - authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host - if self.tcp_handler.address.port != 443: - authority += ":%d" % self.tcp_handler.address.port - - headers = [ - (b':method', bytes(method)), - (b':path', bytes(path)), - (b':scheme', b'https'), - (b':authority', authority), - ] + headers - - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) - - def read_response(self): - stream_id_, headers, body = self._receive_transmission() - return headers[':status'], headers, body - - def read_request(self): - return self._receive_transmission() - - def _receive_transmission(self): - body_expected = True - - stream_id = 0 - header_block_fragment = b'' - body = b'' - - while True: - frm = self.read_frame() - if isinstance(frm, frame.HeadersFrame)\ - or isinstance(frm, frame.ContinuationFrame): - stream_id = frm.stream_id - header_block_fragment += frm.header_block_fragment - if frm.flags & frame.Frame.FLAG_END_STREAM: - body_expected = False - if frm.flags & frame.Frame.FLAG_END_HEADERS: - break - - while body_expected: - frm = self.read_frame() - if isinstance(frm, frame.DataFrame): - body += frm.payload - if frm.flags & frame.Frame.FLAG_END_STREAM: - break - # TODO: implement window update & flow - - headers = {} - for header, value in self.decoder.decode(header_block_fragment): - headers[header] = value - - return stream_id, headers, body - - def create_response(self, code, stream_id=None, headers=None, body=None): - if headers is None: - headers = [] - - headers = [(b':status', bytes(str(code)))] + headers - - if not stream_id: - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id), - )) diff --git a/netlib/http_auth.py b/netlib/http_auth.py deleted file mode 100644 index adab4aed..00000000 --- a/netlib/http_auth.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -from argparse import Action, ArgumentTypeError -from . import http - - -class NullProxyAuth(object): - - """ - No proxy auth at all (returns empty challange headers) - """ - - def __init__(self, password_manager): - self.password_manager = password_manager - - def clean(self, headers_): - """ - Clean up authentication headers, so they're not passed upstream. 
- """ - pass - - def authenticate(self, headers_): - """ - Tests that the user is allowed to use the proxy - """ - return True - - def auth_challenge_headers(self): - """ - Returns a dictionary containing the headers require to challenge the user - """ - return {} - - -class BasicProxyAuth(NullProxyAuth): - CHALLENGE_HEADER = 'Proxy-Authenticate' - AUTH_HEADER = 'Proxy-Authorization' - - def __init__(self, password_manager, realm): - NullProxyAuth.__init__(self, password_manager) - self.realm = realm - - def clean(self, headers): - del headers[self.AUTH_HEADER] - - def authenticate(self, headers): - auth_value = headers.get(self.AUTH_HEADER, []) - if not auth_value: - return False - parts = http.parse_http_basic_auth(auth_value[0]) - if not parts: - return False - scheme, username, password = parts - if scheme.lower() != 'basic': - return False - if not self.password_manager.test(username, password): - return False - self.username = username - return True - - def auth_challenge_headers(self): - return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} - - -class PassMan(object): - - def test(self, username_, password_token_): - return False - - -class PassManNonAnon(PassMan): - - """ - Ensure the user specifies a username, accept any password. - """ - - def test(self, username, password_token_): - if username: - return True - return False - - -class PassManHtpasswd(PassMan): - - """ - Read usernames and passwords from an htpasswd file - """ - - def __init__(self, path): - """ - Raises ValueError if htpasswd file is invalid. - """ - import passlib.apache - self.htpasswd = passlib.apache.HtpasswdFile(path) - - def test(self, username, password_token): - return bool(self.htpasswd.check_password(username, password_token)) - - -class PassManSingleUser(PassMan): - - def __init__(self, username, password): - self.username, self.password = username, password - - def test(self, username, password_token): - return self.username == username and self.password == password_token - - -class AuthAction(Action): - - """ - Helper class to allow seamless integration int argparse. Example usage: - parser.add_argument( - "--nonanonymous", - action=NonanonymousAuthAction, nargs=0, - help="Allow access to any user long as a credentials are specified." - ) - """ - - def __call__(self, parser, namespace, values, option_string=None): - passman = self.getPasswordManager(values) - authenticator = BasicProxyAuth(passman, "mitmproxy") - setattr(namespace, self.dest, authenticator) - - def getPasswordManager(self, s): # pragma: nocover - raise NotImplementedError() - - -class SingleuserAuthAction(AuthAction): - - def getPasswordManager(self, s): - if len(s.split(':')) != 2: - raise ArgumentTypeError( - "Invalid single-user specification. Please use the format username:password" - ) - username, password = s.split(':') - return PassManSingleUser(username, password) - - -class NonanonymousAuthAction(AuthAction): - - def getPasswordManager(self, s): - return PassManNonAnon() - - -class HtpasswdAuthAction(AuthAction): - - def getPasswordManager(self, s): - return PassManHtpasswd(s) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py deleted file mode 100644 index e91ee5c0..00000000 --- a/netlib/http_cookies.py +++ /dev/null @@ -1,193 +0,0 @@ -""" -A flexible module for cookie parsing and manipulation. - -This module differs from usual standards-compliant cookie modules in a number -of ways. We try to be as permissive as possible, and to retain even mal-formed -information. 
Duplicate cookies are preserved in parsing, and can be set in -formatting. We do attempt to escape and quote values where needed, but will not -reject data that violate the specs. - -Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do -not parse the comma-separated variant of Set-Cookie that allows multiple -cookies to be set in a single header. Technically this should be feasible, but -it turns out that violations of RFC6265 that makes the parsing problem -indeterminate are much more common than genuine occurences of the multi-cookie -variants. Serialization follows RFC6265. - - http://tools.ietf.org/html/rfc6265 - http://tools.ietf.org/html/rfc2109 - http://tools.ietf.org/html/rfc2965 -""" - -# TODO -# - Disallow LHS-only Cookie values - -import re - -import odict - - -def _read_until(s, start, term): - """ - Read until one of the characters in term is reached. - """ - if start == len(s): - return "", start + 1 - for i in range(start, len(s)): - if s[i] in term: - return s[start:i], i - return s[start:i + 1], i + 1 - - -def _read_token(s, start): - """ - Read a token - the LHS of a token/value pair in a cookie. - """ - return _read_until(s, start, ";=") - - -def _read_quoted_string(s, start): - """ - start: offset to the first quote of the string to be read - - A sort of loose super-set of the various quoted string specifications. - - RFC6265 disallows backslashes or double quotes within quoted strings. - Prior RFCs use backslashes to escape. This leaves us free to apply - backslash escaping by default and be compatible with everything. - """ - escaping = False - ret = [] - # Skip the first quote - for i in range(start + 1, len(s)): - if escaping: - ret.append(s[i]) - escaping = False - elif s[i] == '"': - break - elif s[i] == "\\": - escaping = True - else: - ret.append(s[i]) - return "".join(ret), i + 1 - - -def _read_value(s, start, delims): - """ - Reads a value - the RHS of a token/value pair in a cookie. - - special: If the value is special, commas are premitted. Else comma - terminates. This helps us support old and new style values. - """ - if start >= len(s): - return "", start - elif s[start] == '"': - return _read_quoted_string(s, start) - else: - return _read_until(s, start, delims) - - -def _read_pairs(s, off=0): - """ - Read pairs of lhs=rhs values. - - off: start offset - specials: a lower-cased list of keys that may contain commas - """ - vals = [] - while True: - lhs, off = _read_token(s, off) - lhs = lhs.lstrip() - if lhs: - rhs = None - if off < len(s): - if s[off] == "=": - rhs, off = _read_value(s, off + 1, ";") - vals.append([lhs, rhs]) - off += 1 - if not off < len(s): - break - return vals, off - - -def _has_special(s): - for i in s: - if i in '",;\\': - return True - o = ord(i) - if o < 0x21 or o > 0x7e: - return True - return False - - -ESCAPE = re.compile(r"([\"\\])") - - -def _format_pairs(lst, specials=(), sep="; "): - """ - specials: A lower-cased list of keys that will not be quoted. - """ - vals = [] - for k, v in lst: - if v is None: - vals.append(k) - else: - if k.lower() not in specials and _has_special(v): - v = ESCAPE.sub(r"\\\1", v) - v = '"%s"' % v - vals.append("%s=%s" % (k, v)) - return sep.join(vals) - - -def _format_set_cookie_pairs(lst): - return _format_pairs( - lst, - specials=("expires", "path") - ) - - -def _parse_set_cookie_pairs(s): - """ - For Set-Cookie, we support multiple cookies as described in RFC2109. - This function therefore returns a list of lists. 
- """ - pairs, off_ = _read_pairs(s) - return pairs - - -def parse_set_cookie_header(line): - """ - Parse a Set-Cookie header value - - Returns a (name, value, attrs) tuple, or None, where attrs is an - ODictCaseless set of attributes. No attempt is made to parse attribute - values - they are treated purely as strings. - """ - pairs = _parse_set_cookie_pairs(line) - if pairs: - return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) - - -def format_set_cookie_header(name, value, attrs): - """ - Formats a Set-Cookie header value. - """ - pairs = [[name, value]] - pairs.extend(attrs.lst) - return _format_set_cookie_pairs(pairs) - - -def parse_cookie_header(line): - """ - Parse a Cookie header value. - Returns a (possibly empty) ODict object. - """ - pairs, off_ = _read_pairs(line) - return odict.ODict(pairs) - - -def format_cookie_header(od): - """ - Formats a Cookie header value. - """ - return _format_pairs(od.lst) diff --git a/netlib/http_semantics.py b/netlib/http_semantics.py deleted file mode 100644 index e8313e3c..00000000 --- a/netlib/http_semantics.py +++ /dev/null @@ -1,23 +0,0 @@ -class Response(object): - - def __init__( - self, - httpversion, - status_code, - msg, - headers, - content, - sslinfo=None, - ): - self.httpversion = httpversion - self.status_code = status_code - self.msg = msg - self.headers = headers - self.content = content - self.sslinfo = sslinfo - - def __eq__(self, other): - return self.__dict__ == other.__dict__ - - def __repr__(self): - return "Response(%s - %s)" % (self.status_code, self.msg) diff --git a/netlib/http_status.py b/netlib/http_status.py deleted file mode 100644 index dc09f465..00000000 --- a/netlib/http_status.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import (absolute_import, print_function, division) - -CONTINUE = 100 -SWITCHING = 101 -OK = 200 -CREATED = 201 -ACCEPTED = 202 -NON_AUTHORITATIVE_INFORMATION = 203 -NO_CONTENT = 204 -RESET_CONTENT = 205 -PARTIAL_CONTENT = 206 -MULTI_STATUS = 207 - -MULTIPLE_CHOICE = 300 -MOVED_PERMANENTLY = 301 -FOUND = 302 -SEE_OTHER = 303 -NOT_MODIFIED = 304 -USE_PROXY = 305 -TEMPORARY_REDIRECT = 307 - -BAD_REQUEST = 400 -UNAUTHORIZED = 401 -PAYMENT_REQUIRED = 402 -FORBIDDEN = 403 -NOT_FOUND = 404 -NOT_ALLOWED = 405 -NOT_ACCEPTABLE = 406 -PROXY_AUTH_REQUIRED = 407 -REQUEST_TIMEOUT = 408 -CONFLICT = 409 -GONE = 410 -LENGTH_REQUIRED = 411 -PRECONDITION_FAILED = 412 -REQUEST_ENTITY_TOO_LARGE = 413 -REQUEST_URI_TOO_LONG = 414 -UNSUPPORTED_MEDIA_TYPE = 415 -REQUESTED_RANGE_NOT_SATISFIABLE = 416 -EXPECTATION_FAILED = 417 - -INTERNAL_SERVER_ERROR = 500 -NOT_IMPLEMENTED = 501 -BAD_GATEWAY = 502 -SERVICE_UNAVAILABLE = 503 -GATEWAY_TIMEOUT = 504 -HTTP_VERSION_NOT_SUPPORTED = 505 -INSUFFICIENT_STORAGE_SPACE = 507 -NOT_EXTENDED = 510 - -RESPONSES = { - # 100 - CONTINUE: "Continue", - SWITCHING: "Switching Protocols", - - # 200 - OK: "OK", - CREATED: "Created", - ACCEPTED: "Accepted", - NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", - NO_CONTENT: "No Content", - RESET_CONTENT: "Reset Content.", - PARTIAL_CONTENT: "Partial Content", - MULTI_STATUS: "Multi-Status", - - # 300 - MULTIPLE_CHOICE: "Multiple Choices", - MOVED_PERMANENTLY: "Moved Permanently", - FOUND: "Found", - SEE_OTHER: "See Other", - NOT_MODIFIED: "Not Modified", - USE_PROXY: "Use Proxy", - # 306 not defined?? 
- TEMPORARY_REDIRECT: "Temporary Redirect", - - # 400 - BAD_REQUEST: "Bad Request", - UNAUTHORIZED: "Unauthorized", - PAYMENT_REQUIRED: "Payment Required", - FORBIDDEN: "Forbidden", - NOT_FOUND: "Not Found", - NOT_ALLOWED: "Method Not Allowed", - NOT_ACCEPTABLE: "Not Acceptable", - PROXY_AUTH_REQUIRED: "Proxy Authentication Required", - REQUEST_TIMEOUT: "Request Time-out", - CONFLICT: "Conflict", - GONE: "Gone", - LENGTH_REQUIRED: "Length Required", - PRECONDITION_FAILED: "Precondition Failed", - REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", - REQUEST_URI_TOO_LONG: "Request-URI Too Long", - UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", - REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", - EXPECTATION_FAILED: "Expectation Failed", - - # 500 - INTERNAL_SERVER_ERROR: "Internal Server Error", - NOT_IMPLEMENTED: "Not Implemented", - BAD_GATEWAY: "Bad Gateway", - SERVICE_UNAVAILABLE: "Service Unavailable", - GATEWAY_TIMEOUT: "Gateway Time-out", - HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", - INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", - NOT_EXTENDED: "Not Extended" -} diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py deleted file mode 100644 index e8681908..00000000 --- a/netlib/http_uastrings.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import (absolute_import, print_function, division) - -""" - A small collection of useful user-agent header strings. These should be - kept reasonably current to reflect common usage. -""" - -# pylint: line-too-long - -# A collection of (name, shortcut, string) tuples. - -UASTRINGS = [ - ("android", - "a", - "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa - ("blackberry", - "l", - "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa - ("bingbot", - "b", - "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa - ("chrome", - "c", - "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa - ("firefox", - "f", - "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa - ("googlebot", - "g", - "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa - ("ie9", - "i", - "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa - ("ipad", - "p", - "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa - ("iphone", - "h", - "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa - ("safari", - "s", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa -] - - -def get_by_shortcut(s): - """ - Retrieve a user agent entry by shortcut. - """ - for i in UASTRINGS: - if s == i[1]: - return i diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index d41059fa..49d8ee10 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -6,7 +6,7 @@ import struct import io from .protocol import Masker -from .. 
import utils, odict, tcp +from netlib import utils, odict, tcp DEFAULT = object() diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index dcab53fb..29b4db3d 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -5,7 +5,7 @@ import os import struct import io -from .. import utils, odict, tcp +from netlib import utils, odict, tcp # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. diff --git a/test/http/__init__.py b/test/http/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/http/http1/__init__.py b/test/http/http1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py new file mode 100644 index 00000000..05e82831 --- /dev/null +++ b/test/http/http1/test_protocol.py @@ -0,0 +1,445 @@ +import cStringIO +import textwrap +import binascii + +from netlib import http, odict, tcp +from netlib.http.http1 import protocol +from ... import tutils, tservers + + +def test_has_chunked_encoding(): + h = odict.ODictCaseless() + assert not protocol.has_chunked_encoding(h) + h["transfer-encoding"] = ["chunked"] + assert protocol.has_chunked_encoding(h) + + +def test_read_chunked(): + + h = odict.ODictCaseless() + h["transfer-encoding"] = ["chunked"] + s = cStringIO.StringIO("1\r\na\r\n0\r\n") + + tutils.raises( + "malformed chunked body", + protocol.read_http_body, + s, h, None, "GET", None, True + ) + + s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") + assert protocol.read_http_body(s, h, None, "GET", None, True) == "a" + + s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") + assert protocol.read_http_body(s, h, None, "GET", None, True) == "a" + + s = cStringIO.StringIO("\r\n") + tutils.raises( + "closed prematurely", + protocol.read_http_body, + s, h, None, "GET", None, True + ) + + s = cStringIO.StringIO("1\r\nfoo") + tutils.raises( + "malformed chunked body", + protocol.read_http_body, + s, h, None, "GET", None, True + ) + + s = cStringIO.StringIO("foo\r\nfoo") + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, None, "GET", None, True + ) + + s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") + tutils.raises("too large", protocol.read_http_body, s, h, 2, "GET", None, True) + + +def test_connection_close(): + h = odict.ODictCaseless() + assert protocol.connection_close((1, 0), h) + assert not protocol.connection_close((1, 1), h) + + h["connection"] = ["keep-alive"] + assert not protocol.connection_close((1, 1), h) + + h["connection"] = ["close"] + assert protocol.connection_close((1, 1), h) + + +def test_get_header_tokens(): + h = odict.ODictCaseless() + assert protocol.get_header_tokens(h, "foo") == [] + h["foo"] = ["bar"] + assert protocol.get_header_tokens(h, "foo") == ["bar"] + h["foo"] = ["bar, voing"] + assert protocol.get_header_tokens(h, "foo") == ["bar", "voing"] + h["foo"] = ["bar, voing", "oink"] + assert protocol.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] + + +def test_read_http_body_request(): + h = odict.ODictCaseless() + r = cStringIO.StringIO("testing") + assert protocol.read_http_body(r, h, None, "GET", None, True) == "" + + +def test_read_http_body_response(): + h = odict.ODictCaseless() + s = tcp.Reader(cStringIO.StringIO("testing")) + assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing" + + +def test_read_http_body(): + # test default case + h = odict.ODictCaseless() + 
h["content-length"] = [7] + s = cStringIO.StringIO("testing") + assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing" + + # test content length: invalid header + h["content-length"] = ["foo"] + s = cStringIO.StringIO("testing") + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, None, "GET", 200, False + ) + + # test content length: invalid header #2 + h["content-length"] = [-1] + s = cStringIO.StringIO("testing") + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, None, "GET", 200, False + ) + + # test content length: content length > actual content + h["content-length"] = [5] + s = cStringIO.StringIO("testing") + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, 4, "GET", 200, False + ) + + # test content length: content length < actual content + s = cStringIO.StringIO("testing") + assert len(protocol.read_http_body(s, h, None, "GET", 200, False)) == 5 + + # test no content length: limit > actual content + h = odict.ODictCaseless() + s = tcp.Reader(cStringIO.StringIO("testing")) + assert len(protocol.read_http_body(s, h, 100, "GET", 200, False)) == 7 + + # test no content length: limit < actual content + s = tcp.Reader(cStringIO.StringIO("testing")) + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, 4, "GET", 200, False + ) + + # test chunked + h = odict.ODictCaseless() + h["transfer-encoding"] = ["chunked"] + s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")) + assert protocol.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" + + +def test_expected_http_body_size(): + # gibber in the content-length field + h = odict.ODictCaseless() + h["content-length"] = ["foo"] + assert protocol.expected_http_body_size(h, False, "GET", 200) is None + # negative number in the content-length field + h = odict.ODictCaseless() + h["content-length"] = ["-7"] + assert protocol.expected_http_body_size(h, False, "GET", 200) is None + # explicit length + h = odict.ODictCaseless() + h["content-length"] = ["5"] + assert protocol.expected_http_body_size(h, False, "GET", 200) == 5 + # no length + h = odict.ODictCaseless() + assert protocol.expected_http_body_size(h, False, "GET", 200) == -1 + # no length request + h = odict.ODictCaseless() + assert protocol.expected_http_body_size(h, True, "GET", None) == 0 + + +def test_parse_http_protocol(): + assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1) + assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0) + assert not protocol.parse_http_protocol("HTTP/a.1") + assert not protocol.parse_http_protocol("HTTP/1.a") + assert not protocol.parse_http_protocol("foo/0.0") + assert not protocol.parse_http_protocol("HTTP/x") + + +def test_parse_init_connect(): + assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") + assert not protocol.parse_init_connect("bogus") + assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") + assert not protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0") + + +def test_parse_init_proxy(): + u = "GET http://foo.com:8888/test HTTP/1.1" + m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u) + assert m == 
"GET" + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + assert httpversion == (1, 1) + + u = "G\xfeET http://foo.com:8888/test HTTP/1.1" + assert not protocol.parse_init_proxy(u) + + assert not protocol.parse_init_proxy("invalid") + assert not protocol.parse_init_proxy("GET invalid HTTP/1.1") + assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + + +def test_parse_init_http(): + u = "GET /test HTTP/1.1" + m, u, httpversion = protocol.parse_init_http(u) + assert m == "GET" + assert u == "/test" + assert httpversion == (1, 1) + + u = "G\xfeET /test HTTP/1.1" + assert not protocol.parse_init_http(u) + + assert not protocol.parse_init_http("invalid") + assert not protocol.parse_init_http("GET invalid HTTP/1.1") + assert not protocol.parse_init_http("GET /test foo/1.1") + assert not protocol.parse_init_http("GET /test\xc0 HTTP/1.1") + + +class TestReadHeaders: + + def _read(self, data, verbatim=False): + if not verbatim: + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + return protocol.read_headers(s) + + def test_read_simple(self): + data = """ + Header: one + Header2: two + \r\n + """ + h = self._read(data) + assert h.lst == [["Header", "one"], ["Header2", "two"]] + + def test_read_multi(self): + data = """ + Header: one + Header: two + \r\n + """ + h = self._read(data) + assert h.lst == [["Header", "one"], ["Header", "two"]] + + def test_read_continued(self): + data = """ + Header: one + \ttwo + Header2: three + \r\n + """ + h = self._read(data) + assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]] + + def test_read_continued_err(self): + data = "\tfoo: bar\r\n" + assert self._read(data, True) is None + + def test_read_err(self): + data = """ + foo + """ + assert self._read(data) is None + + +class NoContentLengthHTTPHandler(tcp.BaseHandler): + + def handle(self): + self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") + self.wfile.flush() + + +class TestReadResponseNoContentLength(tservers.ServerTestBase): + handler = NoContentLengthHTTPHandler + + def test_no_content_length(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + resp = protocol.read_response(c.rfile, "GET", None) + assert resp.content == "bar\r\n\r\n" + + +def test_read_response(): + def tst(data, method, limit, include_body=True): + data = textwrap.dedent(data) + r = cStringIO.StringIO(data) + return protocol.read_response( + r, method, limit, include_body=include_body + ) + + tutils.raises("server disconnect", tst, "", "GET", None) + tutils.raises("invalid server response", tst, "foo", "GET", None) + data = """ + HTTP/1.1 200 OK + """ + assert tst(data, "GET", None) == http.Response( + (1, 1), 200, 'OK', odict.ODictCaseless(), '' + ) + data = """ + HTTP/1.1 200 + """ + assert tst(data, "GET", None) == http.Response( + (1, 1), 200, '', odict.ODictCaseless(), '' + ) + data = """ + HTTP/x 200 OK + """ + tutils.raises("invalid http version", tst, data, "GET", None) + data = """ + HTTP/1.1 xx OK + """ + tutils.raises("invalid server response", tst, data, "GET", None) + + data = """ + HTTP/1.1 100 CONTINUE + + HTTP/1.1 200 OK + """ + assert tst(data, "GET", None) == http.Response( + (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' + ) + + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert tst(data, "GET", None).content == 'foo' + assert tst(data, "HEAD", None).content == '' + + data = """ + HTTP/1.1 200 OK + \tContent-Length: 3 + + foo + """ + tutils.raises("invalid headers", 
tst, data, "GET", None) + + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert tst(data, "GET", None, include_body=False).content is None + + +def test_parse_http_basic_auth(): + vals = ("basic", "foo", "bar") + assert protocol.parse_http_basic_auth( + protocol.assemble_http_basic_auth(*vals) + ) == vals + assert not protocol.parse_http_basic_auth("") + assert not protocol.parse_http_basic_auth("foo bar") + v = "basic " + binascii.b2a_base64("foo") + assert not protocol.parse_http_basic_auth(v) + + +def test_get_request_line(): + r = cStringIO.StringIO("\nfoo") + assert protocol.get_request_line(r) == "foo" + assert not protocol.get_request_line(r) + + +class TestReadRequest(): + + def tst(self, data, **kwargs): + r = cStringIO.StringIO(data) + return protocol.read_request(r, **kwargs) + + def test_invalid(self): + tutils.raises( + "bad http request", + self.tst, + "xxx" + ) + tutils.raises( + "bad http request line", + self.tst, + "get /\xff HTTP/1.1" + ) + tutils.raises( + "invalid headers", + self.tst, + "get / HTTP/1.1\r\nfoo" + ) + tutils.raises( + tcp.NetLibDisconnect, + self.tst, + "\r\n" + ) + + def test_asterisk_form_in(self): + v = self.tst("OPTIONS * HTTP/1.1") + assert v.form_in == "relative" + assert v.method == "OPTIONS" + + def test_absolute_form_in(self): + tutils.raises( + "Bad HTTP request line", + self.tst, + "GET oops-no-protocol.com HTTP/1.1" + ) + v = self.tst("GET http://address:22/ HTTP/1.1") + assert v.form_in == "absolute" + assert v.port == 22 + assert v.host == "address" + assert v.scheme == "http" + + def test_connect(self): + tutils.raises( + "Bad HTTP request line", + self.tst, + "CONNECT oops-no-port.com HTTP/1.1" + ) + v = self.tst("CONNECT foo.com:443 HTTP/1.1") + assert v.form_in == "authority" + assert v.method == "CONNECT" + assert v.port == 443 + assert v.host == "foo.com" + + def test_expect(self): + w = cStringIO.StringIO() + r = cStringIO.StringIO( + "GET / HTTP/1.1\r\n" + "Content-Length: 3\r\n" + "Expect: 100-continue\r\n\r\n" + "foobar", + ) + v = protocol.read_request(r, wfile=w) + assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" + assert v.content == "foo" + assert r.read(3) == "bar" diff --git a/test/http/http2/__init__.py b/test/http/http2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/http/http2/test_frames.py b/test/http/http2/test_frames.py new file mode 100644 index 00000000..ee2edc39 --- /dev/null +++ b/test/http/http2/test_frames.py @@ -0,0 +1,704 @@ +import cStringIO +from test import tutils +from nose.tools import assert_equal +from netlib import tcp +from netlib.http.http2.frame import * + + +def hex_to_file(data): + data = data.decode('hex') + return tcp.Reader(cStringIO.StringIO(data)) + + +def test_invalid_flags(): + tutils.raises( + ValueError, + DataFrame, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + payload='foobar') + + +def test_frame_equality(): + a = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') + b = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') + assert_equal(a, b) + + +def test_too_large_frames(): + f = DataFrame( + length=9000, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar' * 3000) + tutils.raises(FrameSizeError, f.to_bytes) + + +def test_data_frame_to_bytes(): + f = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') + assert_equal(f.to_bytes().encode('hex'), 
'000006000101234567666f6f626172') + + f = DataFrame( + length=11, + flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), + stream_id=0x1234567, + payload='foobar', + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000a00090123456703666f6f626172000000') + + f = DataFrame( + length=6, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload='foobar') + tutils.raises(ValueError, f.to_bytes) + + +def test_data_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000006000101234567666f6f626172')) + assert isinstance(f, DataFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, DataFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_END_STREAM) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.payload, 'foobar') + + f = Frame.from_file(hex_to_file('00000a00090123456703666f6f626172000000')) + assert isinstance(f, DataFrame) + assert_equal(f.length, 10) + assert_equal(f.TYPE, DataFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.payload, 'foobar') + + +def test_data_frame_human_readable(): + f = DataFrame( + length=11, + flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), + stream_id=0x1234567, + payload='foobar', + pad_length=3) + assert f.human_readable() + + +def test_headers_frame_to_bytes(): + f = HeadersFrame( + length=6, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex')) + assert_equal(f.to_bytes().encode('hex'), '000007010001234567668594e75e31d9') + + f = HeadersFrame( + length=10, + flags=(HeadersFrame.FLAG_PADDED), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000b01080123456703668594e75e31d9000000') + + f = HeadersFrame( + length=10, + flags=(HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal( + f.to_bytes().encode('hex'), + '00000c012001234567876543212a668594e75e31d9') + + f = HeadersFrame( + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3, + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal( + f.to_bytes().encode('hex'), + '00001001280123456703876543212a668594e75e31d9000000') + + f = HeadersFrame( + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert_equal( + f.to_bytes().encode('hex'), + '00001001280123456703076543212a668594e75e31d9000000') + + f = HeadersFrame( + length=6, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment='668594e75e31d9'.decode('hex')) + tutils.raises(ValueError, f.to_bytes) + + +def test_headers_frame_from_bytes(): + f = Frame.from_file(hex_to_file( + '000007010001234567668594e75e31d9')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 7) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + + f = Frame.from_file(hex_to_file( + '00000b01080123456703668594e75e31d9000000')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 11) + assert_equal(f.TYPE, 
HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + + f = Frame.from_file(hex_to_file( + '00000c012001234567876543212a668594e75e31d9')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 12) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_file(hex_to_file( + '00001001280123456703876543212a668594e75e31d9000000')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 16) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_file(hex_to_file( + '00001001280123456703076543212a668594e75e31d9000000')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 16) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert_equal(f.exclusive, False) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + +def test_headers_frame_human_readable(): + f = HeadersFrame( + length=7, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment=b'', + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert f.human_readable() + + f = HeadersFrame( + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert f.human_readable() + + +def test_priority_frame_to_bytes(): + f = PriorityFrame( + length=5, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') + + f = PriorityFrame( + length=5, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, + exclusive=False, + stream_dependency=0x7654321, + weight=21) + assert_equal(f.to_bytes().encode('hex'), '0000050200012345670765432115') + + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + stream_dependency=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + stream_dependency=0x0) + tutils.raises(ValueError, f.to_bytes) + + +def test_priority_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000005020001234567876543212a')) + assert isinstance(f, PriorityFrame) + assert_equal(f.length, 5) + assert_equal(f.TYPE, PriorityFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_file(hex_to_file('0000050200012345670765432115')) + assert isinstance(f, 
PriorityFrame) + assert_equal(f.length, 5) + assert_equal(f.TYPE, PriorityFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.exclusive, False) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 21) + + +def test_priority_frame_human_readable(): + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + exclusive=False, + stream_dependency=0x7654321, + weight=21) + assert f.human_readable() + + +def test_rst_stream_frame_to_bytes(): + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + error_code=0x7654321) + assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321') + + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0) + tutils.raises(ValueError, f.to_bytes) + + +def test_rst_stream_frame_from_bytes(): + f = Frame.from_file(hex_to_file('00000403000123456707654321')) + assert isinstance(f, RstStreamFrame) + assert_equal(f.length, 4) + assert_equal(f.TYPE, RstStreamFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.error_code, 0x07654321) + + +def test_rst_stream_frame_human_readable(): + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + error_code=0x7654321) + assert f.human_readable() + + +def test_settings_frame_to_bytes(): + f = SettingsFrame( + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0) + assert_equal(f.to_bytes().encode('hex'), '000000040000000000') + + f = SettingsFrame( + length=0, + flags=SettingsFrame.FLAG_ACK, + stream_id=0x0) + assert_equal(f.to_bytes().encode('hex'), '000000040100000000') + + f = SettingsFrame( + length=6, + flags=SettingsFrame.FLAG_ACK, + stream_id=0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) + assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001') + + f = SettingsFrame( + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) + assert_equal( + f.to_bytes().encode('hex'), + '00000c040000000000000200000001000312345678') + + f = SettingsFrame( + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + +def test_settings_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000000040000000000')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 0) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + + f = Frame.from_file(hex_to_file('000000040100000000')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 0) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, SettingsFrame.FLAG_ACK) + assert_equal(f.stream_id, 0x0) + + f = Frame.from_file(hex_to_file('000006040100000000000200000001')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, SettingsFrame.FLAG_ACK, 0x0) + assert_equal(f.stream_id, 0x0) + assert_equal(len(f.settings), 1) + assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) + + f = Frame.from_file(hex_to_file( + '00000c040000000000000200000001000312345678')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 12) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + 
assert_equal(f.stream_id, 0x0) + assert_equal(len(f.settings), 2) + assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) + assert_equal( + f.settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], + 0x12345678) + + +def test_settings_frame_human_readable(): + f = SettingsFrame( + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings={}) + assert f.human_readable() + + f = SettingsFrame( + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) + assert f.human_readable() + + +def test_push_promise_frame_to_bytes(): + f = PushPromiseFrame( + length=10, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar') + assert_equal( + f.to_bytes().encode('hex'), + '00000a05000123456707654321666f6f626172') + + f = PushPromiseFrame( + length=14, + flags=HeadersFrame.FLAG_PADDED, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar', + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000e0508012345670307654321666f6f626172000000') + + f = PushPromiseFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + f = PushPromiseFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + promised_stream=0x0) + tutils.raises(ValueError, f.to_bytes) + + +def test_push_promise_frame_from_bytes(): + f = Frame.from_file(hex_to_file('00000a05000123456707654321666f6f626172')) + assert isinstance(f, PushPromiseFrame) + assert_equal(f.length, 10) + assert_equal(f.TYPE, PushPromiseFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + f = Frame.from_file(hex_to_file( + '00000e0508012345670307654321666f6f626172000000')) + assert isinstance(f, PushPromiseFrame) + assert_equal(f.length, 14) + assert_equal(f.TYPE, PushPromiseFrame.TYPE) + assert_equal(f.flags, PushPromiseFrame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + +def test_push_promise_frame_human_readable(): + f = PushPromiseFrame( + length=14, + flags=HeadersFrame.FLAG_PADDED, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar', + pad_length=3) + assert f.human_readable() + + +def test_ping_frame_to_bytes(): + f = PingFrame( + length=8, + flags=PingFrame.FLAG_ACK, + stream_id=0x0, + payload=b'foobar') + assert_equal( + f.to_bytes().encode('hex'), + '000008060100000000666f6f6261720000') + + f = PingFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'foobardeadbeef') + assert_equal( + f.to_bytes().encode('hex'), + '000008060000000000666f6f6261726465') + + f = PingFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + +def test_ping_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000008060100000000666f6f6261720000')) + assert isinstance(f, PingFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, PingFrame.TYPE) + assert_equal(f.flags, PingFrame.FLAG_ACK) + assert_equal(f.stream_id, 0x0) + assert_equal(f.payload, b'foobar\0\0') + + f = Frame.from_file(hex_to_file('000008060000000000666f6f6261726465')) + assert isinstance(f, PingFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, PingFrame.TYPE) + 
assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.payload, b'foobarde') + + +def test_ping_frame_human_readable(): + f = PingFrame( + length=8, + flags=PingFrame.FLAG_ACK, + stream_id=0x0, + payload=b'foobar') + assert f.human_readable() + + +def test_goaway_frame_to_bytes(): + f = GoAwayFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'') + assert_equal( + f.to_bytes().encode('hex'), + '0000080700000000000123456787654321') + + f = GoAwayFrame( + length=14, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'foobar') + assert_equal( + f.to_bytes().encode('hex'), + '00000e0700000000000123456787654321666f6f626172') + + f = GoAwayFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + last_stream=0x1234567, + error_code=0x87654321) + tutils.raises(ValueError, f.to_bytes) + + +def test_goaway_frame_from_bytes(): + f = Frame.from_file(hex_to_file( + '0000080700000000000123456787654321')) + assert isinstance(f, GoAwayFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, GoAwayFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.last_stream, 0x1234567) + assert_equal(f.error_code, 0x87654321) + assert_equal(f.data, b'') + + f = Frame.from_file(hex_to_file( + '00000e0700000000000123456787654321666f6f626172')) + assert isinstance(f, GoAwayFrame) + assert_equal(f.length, 14) + assert_equal(f.TYPE, GoAwayFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.last_stream, 0x1234567) + assert_equal(f.error_code, 0x87654321) + assert_equal(f.data, b'foobar') + + +def test_go_away_frame_human_readable(): + f = GoAwayFrame( + length=14, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'foobar') + assert f.human_readable() + + +def test_window_update_frame_to_bytes(): + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x1234567) + assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') + + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + window_size_increment=0x7654321) + assert_equal(f.to_bytes().encode('hex'), '00000408000123456707654321') + + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0xdeadbeef) + tutils.raises(ValueError, f.to_bytes) + + f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) + tutils.raises(ValueError, f.to_bytes) + + +def test_window_update_frame_from_bytes(): + f = Frame.from_file(hex_to_file('00000408000000000001234567')) + assert isinstance(f, WindowUpdateFrame) + assert_equal(f.length, 4) + assert_equal(f.TYPE, WindowUpdateFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.window_size_increment, 0x1234567) + + +def test_window_update_frame_human_readable(): + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + window_size_increment=0x7654321) + assert f.human_readable() + + +def test_continuation_frame_to_bytes(): + f = ContinuationFrame( + length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + header_block_fragment='foobar') + assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') + + f = ContinuationFrame( + 
length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x0, + header_block_fragment='foobar') + tutils.raises(ValueError, f.to_bytes) + + +def test_continuation_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000006090401234567666f6f626172')) + assert isinstance(f, ContinuationFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, ContinuationFrame.TYPE) + assert_equal(f.flags, ContinuationFrame.FLAG_END_HEADERS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + +def test_continuation_frame_human_readable(): + f = ContinuationFrame( + length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + header_block_fragment='foobar') + assert f.human_readable() diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py new file mode 100644 index 00000000..f607860e --- /dev/null +++ b/test/http/http2/test_protocol.py @@ -0,0 +1,325 @@ +import OpenSSL + +from netlib import tcp +from netlib.http import http2 +from netlib.http.http2.frame import * +from ... import tutils, tservers + + +class EchoHandler(tcp.BaseHandler): + sni = None + + def handle(self): + while True: + v = self.rfile.safe_read(1) + self.wfile.write(v) + self.wfile.flush() + + +class TestCheckALPNMatch(tservers.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, + ) + + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) + protocol = http2.HTTP2Protocol(c) + assert protocol.check_alpn() + + +class TestCheckALPNMismatch(tservers.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=None, + ) + + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) + protocol = http2.HTTP2Protocol(c) + tutils.raises(NotImplementedError, protocol.check_alpn) + + +class TestPerformServerConnectionPreface(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # send magic + self.wfile.write( + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')) + self.wfile.flush() + + # send empty settings frame + self.wfile.write('000000040000000000'.decode('hex')) + self.wfile.flush() + + # check empty settings frame + assert self.rfile.read(9) ==\ + '000000040000000000'.decode('hex') + + # check settings acknowledgement + assert self.rfile.read(9) == \ + '000000040100000000'.decode('hex') + + # send settings acknowledgement + self.wfile.write('000000040100000000'.decode('hex')) + self.wfile.flush() + + def test_perform_server_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + protocol = http2.HTTP2Protocol(c) + protocol.perform_server_connection_preface() + + +class TestPerformClientConnectionPreface(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # check magic + assert self.rfile.read(24) ==\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + + # check empty settings frame + assert self.rfile.read(9) ==\ + '000000040000000000'.decode('hex') + + # send empty settings frame + self.wfile.write('000000040000000000'.decode('hex')) + self.wfile.flush() + + # check settings acknowledgement + assert self.rfile.read(9) == \ + '000000040100000000'.decode('hex') + + # send settings 
acknowledgement + self.wfile.write('000000040100000000'.decode('hex')) + self.wfile.flush() + + def test_perform_client_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + protocol = http2.HTTP2Protocol(c) + protocol.perform_client_connection_preface() + + +class TestClientStreamIds(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c) + + def test_client_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol.next_stream_id() == 1 + assert self.protocol.current_stream_id == 1 + assert self.protocol.next_stream_id() == 3 + assert self.protocol.current_stream_id == 3 + assert self.protocol.next_stream_id() == 5 + assert self.protocol.current_stream_id == 5 + + +class TestServerStreamIds(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c, is_server=True) + + def test_server_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol.next_stream_id() == 2 + assert self.protocol.current_stream_id == 2 + assert self.protocol.next_stream_id() == 4 + assert self.protocol.current_stream_id == 4 + assert self.protocol.next_stream_id() == 6 + assert self.protocol.current_stream_id == 6 + + +class TestApplySettings(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # check settings acknowledgement + assert self.rfile.read(9) == '000000040100000000'.decode('hex') + self.wfile.write("OK") + self.wfile.flush() + + ssl = True + + def test_apply_settings(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + protocol._apply_settings({ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 'bar', + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 'deadbeef', + }) + + assert c.rfile.safe_read(2) == "OK" + + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 'foo' + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS] == 'bar' + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE] == 'deadbeef' + + +class TestCreateHeaders(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_headers(self): + headers = [ + (b':method', b'GET'), + (b':path', b'index.html'), + (b':scheme', b'https'), + (b'foo', b'bar')] + + bytes = http2.HTTP2Protocol(self.c)._create_headers( + headers, 1, end_stream=True) + assert b''.join(bytes) ==\ + '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ + .decode('hex') + + bytes = http2.HTTP2Protocol(self.c)._create_headers( + headers, 1, end_stream=False) + assert b''.join(bytes) ==\ + '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ + .decode('hex') + + # TODO: add test for too large header_block_fragments + + +class TestCreateBody(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c) + + def test_create_body_empty(self): + bytes = self.protocol._create_body(b'', 1) + assert b''.join(bytes) == ''.decode('hex') + + def test_create_body_single_frame(self): + bytes = self.protocol._create_body('foobar', 1) + assert b''.join(bytes) == '000006000100000001666f6f626172'.decode('hex') + + def test_create_body_multiple_frames(self): + pass + # bytes = self.protocol._create_body('foobar' * 3000, 1) + # TODO: add test for too large frames + + +class TestCreateRequest(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def 
test_create_request_simple(self): + bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') + assert len(bytes) == 1 + assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') + + def test_create_request_with_body(self): + bytes = http2.HTTP2Protocol(self.c).create_request( + 'GET', '/', [(b'foo', b'bar')], 'foobar') + assert len(bytes) == 2 + assert bytes[0] ==\ + '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') + + +class TestReadResponse(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'00000801040000000188628594e78c767f'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + status, headers, body = protocol.read_response() + + assert headers == {':status': '200', 'etag': 'foobar'} + assert status == "200" + assert body == b'foobar' + + +class TestReadEmptyResponse(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'00000801050000000188628594e78c767f'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_empty_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + status, headers, body = protocol.read_response() + + assert headers == {':status': '200', 'etag': 'foobar'} + assert status == "200" + assert body == b'' + + +class TestReadRequest(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'000003010400000001828487'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_request(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c, is_server=True) + + stream_id, headers, body = protocol.read_request() + + assert stream_id + assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} + assert body == b'foobar' + + +class TestCreateResponse(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_response_simple(self): + bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) + assert len(bytes) == 1 + assert bytes[0] ==\ + '00000101050000000288'.decode('hex') + + def test_create_response_with_body(self): + bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( + 200, 1, [(b'foo', b'bar')], 'foobar') + assert len(bytes) == 2 + assert bytes[0] ==\ + '00000901040000000188408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py new file mode 100644 index 00000000..c0dae1a2 --- /dev/null +++ b/test/http/test_authentication.py @@ -0,0 +1,110 @@ +from netlib import odict, http +from netlib.http import authentication +from .. 
import tutils + + +class TestPassManNonAnon: + + def test_simple(self): + p = authentication.PassManNonAnon() + assert not p.test("", "") + assert p.test("user", "") + + +class TestPassManHtpasswd: + + def test_file_errors(self): + tutils.raises( + "malformed htpasswd file", + authentication.PassManHtpasswd, + tutils.test_data.path("data/server.crt")) + + def test_simple(self): + pm = authentication.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) + + vals = ("basic", "test", "test") + http.http1.assemble_http_basic_auth(*vals) + assert pm.test("test", "test") + assert not pm.test("test", "foo") + assert not pm.test("foo", "test") + assert not pm.test("test", "") + assert not pm.test("", "") + + +class TestPassManSingleUser: + + def test_simple(self): + pm = authentication.PassManSingleUser("test", "test") + assert pm.test("test", "test") + assert not pm.test("test", "foo") + assert not pm.test("foo", "test") + + +class TestNullProxyAuth: + + def test_simple(self): + na = authentication.NullProxyAuth(authentication.PassManNonAnon()) + assert not na.auth_challenge_headers() + assert na.authenticate("foo") + na.clean({}) + + +class TestBasicProxyAuth: + + def test_simple(self): + ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") + h = odict.ODictCaseless() + assert ba.auth_challenge_headers() + assert not ba.authenticate(h) + + def test_authenticate_clean(self): + ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") + + hdrs = odict.ODictCaseless() + vals = ("basic", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + assert ba.authenticate(hdrs) + + ba.clean(hdrs) + assert not ba.AUTH_HEADER in hdrs + + hdrs[ba.AUTH_HEADER] = [""] + assert not ba.authenticate(hdrs) + + hdrs[ba.AUTH_HEADER] = ["foo"] + assert not ba.authenticate(hdrs) + + vals = ("foo", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + assert not ba.authenticate(hdrs) + + ba = authentication.BasicProxyAuth(authentication.PassMan(), "test") + vals = ("basic", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + assert not ba.authenticate(hdrs) + + +class Bunch: + pass + + +class TestAuthAction: + + def test_nonanonymous(self): + m = Bunch() + aa = authentication.NonanonymousAuthAction(None, "authenticator") + aa(None, m, None, None) + assert m.authenticator + + def test_singleuser(self): + m = Bunch() + aa = authentication.SingleuserAuthAction(None, "authenticator") + aa(None, m, "foo:bar", None) + assert m.authenticator + tutils.raises("invalid", aa, None, m, "foo", None) + + def test_httppasswd(self): + m = Bunch() + aa = authentication.HtpasswdAuthAction(None, "authenticator") + aa(None, m, tutils.test_data.path("data/htpasswd"), None) + assert m.authenticator diff --git a/test/http/test_cookies.py b/test/http/test_cookies.py new file mode 100644 index 00000000..4f99593a --- /dev/null +++ b/test/http/test_cookies.py @@ -0,0 +1,219 @@ +import nose.tools + +from netlib.http import cookies + + +def test_read_token(): + tokens = [ + [("foo", 0), ("foo", 3)], + [("foo", 1), ("oo", 3)], + [(" foo", 1), ("foo", 4)], + [(" foo;", 1), ("foo", 4)], + [(" foo=", 1), ("foo", 4)], + [(" foo=bar", 1), ("foo", 4)], + ] + for q, a in tokens: + nose.tools.eq_(cookies._read_token(*q), a) + + +def test_read_quoted_string(): + tokens = [ + [('"foo" x', 0), ("foo", 5)], + [('"f\oo" x', 0), ("foo", 6)], + [(r'"f\\o" x', 0), (r"f\o", 6)], + [(r'"f\\" x', 0), (r"f" + '\\', 5)], + 
[('"fo\\\"" x', 0), ("fo\"", 6)], + ] + for q, a in tokens: + nose.tools.eq_(cookies._read_quoted_string(*q), a) + + +def test_read_pairs(): + vals = [ + [ + "one", + [["one", None]] + ], + [ + "one=two", + [["one", "two"]] + ], + [ + "one=", + [["one", ""]] + ], + [ + 'one="two"', + [["one", "two"]] + ], + [ + 'one="two"; three=four', + [["one", "two"], ["three", "four"]] + ], + [ + 'one="two"; three=four; five', + [["one", "two"], ["three", "four"], ["five", None]] + ], + [ + 'one="\\"two"; three=four', + [["one", '"two'], ["three", "four"]] + ], + ] + for s, lst in vals: + ret, off = cookies._read_pairs(s) + nose.tools.eq_(ret, lst) + + +def test_pairs_roundtrips(): + pairs = [ + [ + "", + [] + ], + [ + "one=uno", + [["one", "uno"]] + ], + [ + "one", + [["one", None]] + ], + [ + "one=uno; two=due", + [["one", "uno"], ["two", "due"]] + ], + [ + 'one="uno"; two="\due"', + [["one", "uno"], ["two", "due"]] + ], + [ + 'one="un\\"o"', + [["one", 'un"o']] + ], + [ + 'one="uno,due"', + [["one", 'uno,due']] + ], + [ + "one=uno; two; three=tre", + [["one", "uno"], ["two", None], ["three", "tre"]] + ], + [ + "_lvs2=zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g=; " + "_rcc2=53VdltWl+Ov6ordflA==;", + [ + ["_lvs2", "zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g="], + ["_rcc2", "53VdltWl+Ov6ordflA=="] + ] + ] + ] + for s, lst in pairs: + ret, off = cookies._read_pairs(s) + nose.tools.eq_(ret, lst) + s2 = cookies._format_pairs(lst) + ret, off = cookies._read_pairs(s2) + nose.tools.eq_(ret, lst) + + +def test_cookie_roundtrips(): + pairs = [ + [ + "one=uno", + [["one", "uno"]] + ], + [ + "one=uno; two=due", + [["one", "uno"], ["two", "due"]] + ], + ] + for s, lst in pairs: + ret = cookies.parse_cookie_header(s) + nose.tools.eq_(ret.lst, lst) + s2 = cookies.format_cookie_header(ret) + ret = cookies.parse_cookie_header(s2) + nose.tools.eq_(ret.lst, lst) + + +def test_parse_set_cookie_pairs(): + pairs = [ + [ + "one=uno", + [ + ["one", "uno"] + ] + ], + [ + "one=un\x20", + [ + ["one", "un\x20"] + ] + ], + [ + "one=uno; foo", + [ + ["one", "uno"], + ["foo", None] + ] + ], + [ + "mun=1.390.f60; " + "expires=sun, 11-oct-2015 12:38:31 gmt; path=/; " + "domain=b.aol.com", + [ + ["mun", "1.390.f60"], + ["expires", "sun, 11-oct-2015 12:38:31 gmt"], + ["path", "/"], + ["domain", "b.aol.com"] + ] + ], + [ + r'rpb=190%3d1%2616726%3d1%2634832%3d1%2634874%3d1; ' + 'domain=.rubiconproject.com; ' + 'expires=mon, 11-may-2015 21:54:57 gmt; ' + 'path=/', + [ + ['rpb', r'190%3d1%2616726%3d1%2634832%3d1%2634874%3d1'], + ['domain', '.rubiconproject.com'], + ['expires', 'mon, 11-may-2015 21:54:57 gmt'], + ['path', '/'] + ] + ], + ] + for s, lst in pairs: + ret = cookies._parse_set_cookie_pairs(s) + nose.tools.eq_(ret, lst) + s2 = cookies._format_set_cookie_pairs(ret) + ret2 = cookies._parse_set_cookie_pairs(s2) + nose.tools.eq_(ret2, lst) + + +def test_parse_set_cookie_header(): + vals = [ + [ + "", None + ], + [ + ";", None + ], + [ + "one=uno", + ("one", "uno", []) + ], + [ + "one=uno; foo=bar", + ("one", "uno", [["foo", "bar"]]) + ] + ] + for s, expected in vals: + ret = cookies.parse_set_cookie_header(s) + if expected: + assert ret[0] == expected[0] + assert ret[1] == expected[1] + nose.tools.eq_(ret[2].lst, expected[2]) + s2 = cookies.format_set_cookie_header(*ret) + ret2 = cookies.parse_set_cookie_header(s2) + assert ret2[0] == expected[0] + assert ret2[1] == expected[1] + nose.tools.eq_(ret2[2].lst, expected[2]) + else: + assert ret is None diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py new file mode 100644 
index 00000000..c4605302 --- /dev/null +++ b/test/http/test_semantics.py @@ -0,0 +1,54 @@ +import cStringIO +import textwrap +import binascii + +from netlib import http, odict, tcp +from netlib.http import http1 +from .. import tutils, tservers + +def test_httperror(): + e = http.exceptions.HttpError(404, "Not found") + assert str(e) + + +def test_parse_url(): + assert not http.parse_url("") + + u = "http://foo.com:8888/test" + s, h, po, pa = http.parse_url(u) + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + + s, h, po, pa = http.parse_url("http://foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = http.parse_url("http://user:pass@foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = http.parse_url("http://foo") + assert pa == "/" + + s, h, po, pa = http.parse_url("https://foo") + assert po == 443 + + assert not http.parse_url("https://foo:bar") + assert not http.parse_url("https://foo:") + + # Invalid IDNA + assert not http.parse_url("http://\xfafoo") + # Invalid PATH + assert not http.parse_url("http:/\xc6/localhost:56121") + # Null byte in host + assert not http.parse_url("http://foo\0") + # Port out of range + assert not http.parse_url("http://foo:999999") + # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt + assert not http.parse_url('http://lo[calhost') diff --git a/test/http/test_user_agents.py b/test/http/test_user_agents.py new file mode 100644 index 00000000..0bf1bba7 --- /dev/null +++ b/test/http/test_user_agents.py @@ -0,0 +1,6 @@ +from netlib.http import user_agents + + +def test_get_shortcut(): + assert user_agents.get_by_shortcut("c")[0] == "chrome" + assert not user_agents.get_by_shortcut("_") diff --git a/test/http2/__init__.py b/test/http2/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/http2/test_frames.py b/test/http2/test_frames.py deleted file mode 100644 index 76a4b712..00000000 --- a/test/http2/test_frames.py +++ /dev/null @@ -1,704 +0,0 @@ -import cStringIO -from test import tutils -from nose.tools import assert_equal -from netlib import tcp -from netlib.http2.frame import * - - -def hex_to_file(data): - data = data.decode('hex') - return tcp.Reader(cStringIO.StringIO(data)) - - -def test_invalid_flags(): - tutils.raises( - ValueError, - DataFrame, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - payload='foobar') - - -def test_frame_equality(): - a = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - b = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - assert_equal(a, b) - - -def test_too_large_frames(): - f = DataFrame( - length=9000, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar' * 3000) - tutils.raises(FrameSizeError, f.to_bytes) - - -def test_data_frame_to_bytes(): - f = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172') - - f = DataFrame( - length=11, - flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), - stream_id=0x1234567, - payload='foobar', - pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000a00090123456703666f6f626172000000') - - f = DataFrame( - length=6, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload='foobar') - tutils.raises(ValueError, f.to_bytes) - - -def 
test_data_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000006000101234567666f6f626172')) - assert isinstance(f, DataFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, DataFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_END_STREAM) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.payload, 'foobar') - - f = Frame.from_file(hex_to_file('00000a00090123456703666f6f626172000000')) - assert isinstance(f, DataFrame) - assert_equal(f.length, 10) - assert_equal(f.TYPE, DataFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.payload, 'foobar') - - -def test_data_frame_human_readable(): - f = DataFrame( - length=11, - flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), - stream_id=0x1234567, - payload='foobar', - pad_length=3) - assert f.human_readable() - - -def test_headers_frame_to_bytes(): - f = HeadersFrame( - length=6, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex')) - assert_equal(f.to_bytes().encode('hex'), '000007010001234567668594e75e31d9') - - f = HeadersFrame( - length=10, - flags=(HeadersFrame.FLAG_PADDED), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000b01080123456703668594e75e31d9000000') - - f = HeadersFrame( - length=10, - flags=(HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00000c012001234567876543212a668594e75e31d9') - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00001001280123456703876543212a668594e75e31d9000000') - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00001001280123456703076543212a668594e75e31d9000000') - - f = HeadersFrame( - length=6, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment='668594e75e31d9'.decode('hex')) - tutils.raises(ValueError, f.to_bytes) - - -def test_headers_frame_from_bytes(): - f = Frame.from_file(hex_to_file( - '000007010001234567668594e75e31d9')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 7) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - - f = Frame.from_file(hex_to_file( - '00000b01080123456703668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 11) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - - f = Frame.from_file(hex_to_file( - '00000c012001234567876543212a668594e75e31d9')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 12) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, 
HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - f = Frame.from_file(hex_to_file( - '00001001280123456703876543212a668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 16) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - f = Frame.from_file(hex_to_file( - '00001001280123456703076543212a668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 16) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, False) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - -def test_headers_frame_human_readable(): - f = HeadersFrame( - length=7, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment=b'', - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert f.human_readable() - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert f.human_readable() - - -def test_priority_frame_to_bytes(): - f = PriorityFrame( - length=5, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') - - f = PriorityFrame( - length=5, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - exclusive=False, - stream_dependency=0x7654321, - weight=21) - assert_equal(f.to_bytes().encode('hex'), '0000050200012345670765432115') - - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - stream_dependency=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - stream_dependency=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_priority_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000005020001234567876543212a')) - assert isinstance(f, PriorityFrame) - assert_equal(f.length, 5) - assert_equal(f.TYPE, PriorityFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - f = Frame.from_file(hex_to_file('0000050200012345670765432115')) - assert isinstance(f, PriorityFrame) - assert_equal(f.length, 5) - assert_equal(f.TYPE, PriorityFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.exclusive, False) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 21) - - -def test_priority_frame_human_readable(): - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - 
stream_id=0x1234567, - exclusive=False, - stream_dependency=0x7654321, - weight=21) - assert f.human_readable() - - -def test_rst_stream_frame_to_bytes(): - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - error_code=0x7654321) - assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321') - - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_rst_stream_frame_from_bytes(): - f = Frame.from_file(hex_to_file('00000403000123456707654321')) - assert isinstance(f, RstStreamFrame) - assert_equal(f.length, 4) - assert_equal(f.TYPE, RstStreamFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.error_code, 0x07654321) - - -def test_rst_stream_frame_human_readable(): - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - error_code=0x7654321) - assert f.human_readable() - - -def test_settings_frame_to_bytes(): - f = SettingsFrame( - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0) - assert_equal(f.to_bytes().encode('hex'), '000000040000000000') - - f = SettingsFrame( - length=0, - flags=SettingsFrame.FLAG_ACK, - stream_id=0x0) - assert_equal(f.to_bytes().encode('hex'), '000000040100000000') - - f = SettingsFrame( - length=6, - flags=SettingsFrame.FLAG_ACK, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) - assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001') - - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) - assert_equal( - f.to_bytes().encode('hex'), - '00000c040000000000000200000001000312345678') - - f = SettingsFrame( - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - -def test_settings_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000000040000000000')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 0) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - - f = Frame.from_file(hex_to_file('000000040100000000')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 0) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, SettingsFrame.FLAG_ACK) - assert_equal(f.stream_id, 0x0) - - f = Frame.from_file(hex_to_file('000006040100000000000200000001')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, SettingsFrame.FLAG_ACK, 0x0) - assert_equal(f.stream_id, 0x0) - assert_equal(len(f.settings), 1) - assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) - - f = Frame.from_file(hex_to_file( - '00000c040000000000000200000001000312345678')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 12) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(len(f.settings), 2) - assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) - assert_equal( - f.settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], - 0x12345678) - - -def test_settings_frame_human_readable(): - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={}) - assert 
f.human_readable() - - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) - assert f.human_readable() - - -def test_push_promise_frame_to_bytes(): - f = PushPromiseFrame( - length=10, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar') - assert_equal( - f.to_bytes().encode('hex'), - '00000a05000123456707654321666f6f626172') - - f = PushPromiseFrame( - length=14, - flags=HeadersFrame.FLAG_PADDED, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar', - pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000e0508012345670307654321666f6f626172000000') - - f = PushPromiseFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - f = PushPromiseFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - promised_stream=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_push_promise_frame_from_bytes(): - f = Frame.from_file(hex_to_file('00000a05000123456707654321666f6f626172')) - assert isinstance(f, PushPromiseFrame) - assert_equal(f.length, 10) - assert_equal(f.TYPE, PushPromiseFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') - - f = Frame.from_file(hex_to_file( - '00000e0508012345670307654321666f6f626172000000')) - assert isinstance(f, PushPromiseFrame) - assert_equal(f.length, 14) - assert_equal(f.TYPE, PushPromiseFrame.TYPE) - assert_equal(f.flags, PushPromiseFrame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') - - -def test_push_promise_frame_human_readable(): - f = PushPromiseFrame( - length=14, - flags=HeadersFrame.FLAG_PADDED, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar', - pad_length=3) - assert f.human_readable() - - -def test_ping_frame_to_bytes(): - f = PingFrame( - length=8, - flags=PingFrame.FLAG_ACK, - stream_id=0x0, - payload=b'foobar') - assert_equal( - f.to_bytes().encode('hex'), - '000008060100000000666f6f6261720000') - - f = PingFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'foobardeadbeef') - assert_equal( - f.to_bytes().encode('hex'), - '000008060000000000666f6f6261726465') - - f = PingFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - -def test_ping_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000008060100000000666f6f6261720000')) - assert isinstance(f, PingFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, PingFrame.TYPE) - assert_equal(f.flags, PingFrame.FLAG_ACK) - assert_equal(f.stream_id, 0x0) - assert_equal(f.payload, b'foobar\0\0') - - f = Frame.from_file(hex_to_file('000008060000000000666f6f6261726465')) - assert isinstance(f, PingFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, PingFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.payload, b'foobarde') - - -def test_ping_frame_human_readable(): - f = PingFrame( - length=8, - flags=PingFrame.FLAG_ACK, - stream_id=0x0, - payload=b'foobar') - assert f.human_readable() - - -def test_goaway_frame_to_bytes(): - f = GoAwayFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - 
last_stream=0x1234567, - error_code=0x87654321, - data=b'') - assert_equal( - f.to_bytes().encode('hex'), - '0000080700000000000123456787654321') - - f = GoAwayFrame( - length=14, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x1234567, - error_code=0x87654321, - data=b'foobar') - assert_equal( - f.to_bytes().encode('hex'), - '00000e0700000000000123456787654321666f6f626172') - - f = GoAwayFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - last_stream=0x1234567, - error_code=0x87654321) - tutils.raises(ValueError, f.to_bytes) - - -def test_goaway_frame_from_bytes(): - f = Frame.from_file(hex_to_file( - '0000080700000000000123456787654321')) - assert isinstance(f, GoAwayFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, GoAwayFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.last_stream, 0x1234567) - assert_equal(f.error_code, 0x87654321) - assert_equal(f.data, b'') - - f = Frame.from_file(hex_to_file( - '00000e0700000000000123456787654321666f6f626172')) - assert isinstance(f, GoAwayFrame) - assert_equal(f.length, 14) - assert_equal(f.TYPE, GoAwayFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.last_stream, 0x1234567) - assert_equal(f.error_code, 0x87654321) - assert_equal(f.data, b'foobar') - - -def test_go_away_frame_human_readable(): - f = GoAwayFrame( - length=14, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x1234567, - error_code=0x87654321, - data=b'foobar') - assert f.human_readable() - - -def test_window_update_frame_to_bytes(): - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x1234567) - assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') - - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - window_size_increment=0x7654321) - assert_equal(f.to_bytes().encode('hex'), '00000408000123456707654321') - - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0xdeadbeef) - tutils.raises(ValueError, f.to_bytes) - - f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) - tutils.raises(ValueError, f.to_bytes) - - -def test_window_update_frame_from_bytes(): - f = Frame.from_file(hex_to_file('00000408000000000001234567')) - assert isinstance(f, WindowUpdateFrame) - assert_equal(f.length, 4) - assert_equal(f.TYPE, WindowUpdateFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.window_size_increment, 0x1234567) - - -def test_window_update_frame_human_readable(): - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - window_size_increment=0x7654321) - assert f.human_readable() - - -def test_continuation_frame_to_bytes(): - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - header_block_fragment='foobar') - assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') - - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x0, - header_block_fragment='foobar') - tutils.raises(ValueError, f.to_bytes) - - -def test_continuation_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000006090401234567666f6f626172')) - assert isinstance(f, ContinuationFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, ContinuationFrame.TYPE) - 
assert_equal(f.flags, ContinuationFrame.FLAG_END_HEADERS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') - - -def test_continuation_frame_human_readable(): - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - header_block_fragment='foobar') - assert f.human_readable() diff --git a/test/http2/test_protocol.py b/test/http2/test_protocol.py deleted file mode 100644 index 5e2af34e..00000000 --- a/test/http2/test_protocol.py +++ /dev/null @@ -1,326 +0,0 @@ -import OpenSSL - -from netlib import http2 -from netlib import tcp -from netlib.http2.frame import * -from test import tutils -from .. import tservers - - -class EchoHandler(tcp.BaseHandler): - sni = None - - def handle(self): - while True: - v = self.rfile.safe_read(1) - self.wfile.write(v) - self.wfile.flush() - - -class TestCheckALPNMatch(tservers.ServerTestBase): - handler = EchoHandler - ssl = dict( - alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, - ) - - if OpenSSL._util.lib.Cryptography_HAS_ALPN: - - def test_check_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) - assert protocol.check_alpn() - - -class TestCheckALPNMismatch(tservers.ServerTestBase): - handler = EchoHandler - ssl = dict( - alpn_select=None, - ) - - if OpenSSL._util.lib.Cryptography_HAS_ALPN: - - def test_check_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) - tutils.raises(NotImplementedError, protocol.check_alpn) - - -class TestPerformServerConnectionPreface(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # send magic - self.wfile.write( - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')) - self.wfile.flush() - - # send empty settings frame - self.wfile.write('000000040000000000'.decode('hex')) - self.wfile.flush() - - # check empty settings frame - assert self.rfile.read(9) ==\ - '000000040000000000'.decode('hex') - - # check settings acknowledgement - assert self.rfile.read(9) == \ - '000000040100000000'.decode('hex') - - # send settings acknowledgement - self.wfile.write('000000040100000000'.decode('hex')) - self.wfile.flush() - - def test_perform_server_connection_preface(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - protocol = http2.HTTP2Protocol(c) - protocol.perform_server_connection_preface() - - -class TestPerformClientConnectionPreface(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # check magic - assert self.rfile.read(24) ==\ - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') - - # check empty settings frame - assert self.rfile.read(9) ==\ - '000000040000000000'.decode('hex') - - # send empty settings frame - self.wfile.write('000000040000000000'.decode('hex')) - self.wfile.flush() - - # check settings acknowledgement - assert self.rfile.read(9) == \ - '000000040100000000'.decode('hex') - - # send settings acknowledgement - self.wfile.write('000000040100000000'.decode('hex')) - self.wfile.flush() - - def test_perform_client_connection_preface(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - protocol = http2.HTTP2Protocol(c) - protocol.perform_client_connection_preface() - - -class TestClientStreamIds(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = 
http2.HTTP2Protocol(c) - - def test_client_stream_ids(self): - assert self.protocol.current_stream_id is None - assert self.protocol.next_stream_id() == 1 - assert self.protocol.current_stream_id == 1 - assert self.protocol.next_stream_id() == 3 - assert self.protocol.current_stream_id == 3 - assert self.protocol.next_stream_id() == 5 - assert self.protocol.current_stream_id == 5 - - -class TestServerStreamIds(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c, is_server=True) - - def test_server_stream_ids(self): - assert self.protocol.current_stream_id is None - assert self.protocol.next_stream_id() == 2 - assert self.protocol.current_stream_id == 2 - assert self.protocol.next_stream_id() == 4 - assert self.protocol.current_stream_id == 4 - assert self.protocol.next_stream_id() == 6 - assert self.protocol.current_stream_id == 6 - - -class TestApplySettings(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # check settings acknowledgement - assert self.rfile.read(9) == '000000040100000000'.decode('hex') - self.wfile.write("OK") - self.wfile.flush() - - ssl = True - - def test_apply_settings(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - - protocol._apply_settings({ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 'bar', - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 'deadbeef', - }) - - assert c.rfile.safe_read(2) == "OK" - - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 'foo' - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS] == 'bar' - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE] == 'deadbeef' - - -class TestCreateHeaders(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_headers(self): - headers = [ - (b':method', b'GET'), - (b':path', b'index.html'), - (b':scheme', b'https'), - (b'foo', b'bar')] - - bytes = http2.HTTP2Protocol(self.c)._create_headers( - headers, 1, end_stream=True) - assert b''.join(bytes) ==\ - '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ - .decode('hex') - - bytes = http2.HTTP2Protocol(self.c)._create_headers( - headers, 1, end_stream=False) - assert b''.join(bytes) ==\ - '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ - .decode('hex') - - # TODO: add test for too large header_block_fragments - - -class TestCreateBody(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) - - def test_create_body_empty(self): - bytes = self.protocol._create_body(b'', 1) - assert b''.join(bytes) == ''.decode('hex') - - def test_create_body_single_frame(self): - bytes = self.protocol._create_body('foobar', 1) - assert b''.join(bytes) == '000006000100000001666f6f626172'.decode('hex') - - def test_create_body_multiple_frames(self): - pass - # bytes = self.protocol._create_body('foobar' * 3000, 1) - # TODO: add test for too large frames - - -class TestCreateRequest(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_request_simple(self): - bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') - assert len(bytes) == 1 - assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') - - def test_create_request_with_body(self): - bytes = http2.HTTP2Protocol(self.c).create_request( - 'GET', '/', [(b'foo', b'bar')], 'foobar') - assert len(bytes) == 2 - assert bytes[0] ==\ - 
'0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') - - -class TestReadResponse(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'00000801040000000188628594e78c767f'.decode('hex')) - self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - - status, headers, body = protocol.read_response() - - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'foobar' - - -class TestReadEmptyResponse(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'00000801050000000188628594e78c767f'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_empty_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - - status, headers, body = protocol.read_response() - - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'' - - -class TestReadRequest(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'000003010400000001828487'.decode('hex')) - self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_request(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c, is_server=True) - - stream_id, headers, body = protocol.read_request() - - assert stream_id - assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} - assert body == b'foobar' - - -class TestCreateResponse(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_response_simple(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) - assert len(bytes) == 1 - assert bytes[0] ==\ - '00000101050000000288'.decode('hex') - - def test_create_response_with_body(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( - 200, 1, [(b'foo', b'bar')], 'foobar') - assert len(bytes) == 2 - assert bytes[0] ==\ - '00000901040000000188408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') diff --git a/test/test_http.py b/test/test_http.py deleted file mode 100644 index bbc78847..00000000 --- a/test/test_http.py +++ /dev/null @@ -1,491 +0,0 @@ -import cStringIO -import textwrap -import binascii -from netlib import http, http_semantics, odict, tcp -from . 
import tutils, tservers - - -def test_httperror(): - e = http.HttpError(404, "Not found") - assert str(e) - - -def test_has_chunked_encoding(): - h = odict.ODictCaseless() - assert not http.has_chunked_encoding(h) - h["transfer-encoding"] = ["chunked"] - assert http.has_chunked_encoding(h) - - -def test_read_chunked(): - - h = odict.ODictCaseless() - h["transfer-encoding"] = ["chunked"] - s = cStringIO.StringIO("1\r\na\r\n0\r\n") - - tutils.raises( - "malformed chunked body", - http.read_http_body, - s, h, None, "GET", None, True - ) - - s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert http.read_http_body(s, h, None, "GET", None, True) == "a" - - s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") - assert http.read_http_body(s, h, None, "GET", None, True) == "a" - - s = cStringIO.StringIO("\r\n") - tutils.raises( - "closed prematurely", - http.read_http_body, - s, h, None, "GET", None, True - ) - - s = cStringIO.StringIO("1\r\nfoo") - tutils.raises( - "malformed chunked body", - http.read_http_body, - s, h, None, "GET", None, True - ) - - s = cStringIO.StringIO("foo\r\nfoo") - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, None, "GET", None, True - ) - - s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - tutils.raises("too large", http.read_http_body, s, h, 2, "GET", None, True) - - -def test_connection_close(): - h = odict.ODictCaseless() - assert http.connection_close((1, 0), h) - assert not http.connection_close((1, 1), h) - - h["connection"] = ["keep-alive"] - assert not http.connection_close((1, 1), h) - - h["connection"] = ["close"] - assert http.connection_close((1, 1), h) - - -def test_get_header_tokens(): - h = odict.ODictCaseless() - assert http.get_header_tokens(h, "foo") == [] - h["foo"] = ["bar"] - assert http.get_header_tokens(h, "foo") == ["bar"] - h["foo"] = ["bar, voing"] - assert http.get_header_tokens(h, "foo") == ["bar", "voing"] - h["foo"] = ["bar, voing", "oink"] - assert http.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] - - -def test_read_http_body_request(): - h = odict.ODictCaseless() - r = cStringIO.StringIO("testing") - assert http.read_http_body(r, h, None, "GET", None, True) == "" - - -def test_read_http_body_response(): - h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" - - -def test_read_http_body(): - # test default case - h = odict.ODictCaseless() - h["content-length"] = [7] - s = cStringIO.StringIO("testing") - assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" - - # test content length: invalid header - h["content-length"] = ["foo"] - s = cStringIO.StringIO("testing") - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, None, "GET", 200, False - ) - - # test content length: invalid header #2 - h["content-length"] = [-1] - s = cStringIO.StringIO("testing") - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, None, "GET", 200, False - ) - - # test content length: content length > actual content - h["content-length"] = [5] - s = cStringIO.StringIO("testing") - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, 4, "GET", 200, False - ) - - # test content length: content length < actual content - s = cStringIO.StringIO("testing") - assert len(http.read_http_body(s, h, None, "GET", 200, False)) == 5 - - # test no content length: limit > actual content - h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert len(http.read_http_body(s, h, 100, 
"GET", 200, False)) == 7 - - # test no content length: limit < actual content - s = tcp.Reader(cStringIO.StringIO("testing")) - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, 4, "GET", 200, False - ) - - # test chunked - h = odict.ODictCaseless() - h["transfer-encoding"] = ["chunked"] - s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")) - assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" - - -def test_expected_http_body_size(): - # gibber in the content-length field - h = odict.ODictCaseless() - h["content-length"] = ["foo"] - assert http.expected_http_body_size(h, False, "GET", 200) is None - # negative number in the content-length field - h = odict.ODictCaseless() - h["content-length"] = ["-7"] - assert http.expected_http_body_size(h, False, "GET", 200) is None - # explicit length - h = odict.ODictCaseless() - h["content-length"] = ["5"] - assert http.expected_http_body_size(h, False, "GET", 200) == 5 - # no length - h = odict.ODictCaseless() - assert http.expected_http_body_size(h, False, "GET", 200) == -1 - # no length request - h = odict.ODictCaseless() - assert http.expected_http_body_size(h, True, "GET", None) == 0 - - -def test_parse_http_protocol(): - assert http.parse_http_protocol("HTTP/1.1") == (1, 1) - assert http.parse_http_protocol("HTTP/0.0") == (0, 0) - assert not http.parse_http_protocol("HTTP/a.1") - assert not http.parse_http_protocol("HTTP/1.a") - assert not http.parse_http_protocol("foo/0.0") - assert not http.parse_http_protocol("HTTP/x") - - -def test_parse_init_connect(): - assert http.parse_init_connect("CONNECT host.com:443 HTTP/1.0") - assert not http.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") - assert not http.parse_init_connect("bogus") - assert not http.parse_init_connect("GET host.com:443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT host.com443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT host.com:443 foo/1.0") - assert not http.parse_init_connect("CONNECT host.com:foo HTTP/1.0") - - -def test_parse_init_proxy(): - u = "GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = http.parse_init_proxy(u) - assert m == "GET" - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - assert httpversion == (1, 1) - - u = "G\xfeET http://foo.com:8888/test HTTP/1.1" - assert not http.parse_init_proxy(u) - - assert not http.parse_init_proxy("invalid") - assert not http.parse_init_proxy("GET invalid HTTP/1.1") - assert not http.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") - - -def test_parse_init_http(): - u = "GET /test HTTP/1.1" - m, u, httpversion = http.parse_init_http(u) - assert m == "GET" - assert u == "/test" - assert httpversion == (1, 1) - - u = "G\xfeET /test HTTP/1.1" - assert not http.parse_init_http(u) - - assert not http.parse_init_http("invalid") - assert not http.parse_init_http("GET invalid HTTP/1.1") - assert not http.parse_init_http("GET /test foo/1.1") - assert not http.parse_init_http("GET /test\xc0 HTTP/1.1") - - -class TestReadHeaders: - - def _read(self, data, verbatim=False): - if not verbatim: - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - return http.read_headers(s) - - def test_read_simple(self): - data = """ - Header: one - Header2: two - \r\n - """ - h = self._read(data) - assert h.lst == [["Header", "one"], ["Header2", 
"two"]] - - def test_read_multi(self): - data = """ - Header: one - Header: two - \r\n - """ - h = self._read(data) - assert h.lst == [["Header", "one"], ["Header", "two"]] - - def test_read_continued(self): - data = """ - Header: one - \ttwo - Header2: three - \r\n - """ - h = self._read(data) - assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]] - - def test_read_continued_err(self): - data = "\tfoo: bar\r\n" - assert self._read(data, True) is None - - def test_read_err(self): - data = """ - foo - """ - assert self._read(data) is None - - -class NoContentLengthHTTPHandler(tcp.BaseHandler): - - def handle(self): - self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") - self.wfile.flush() - - -class TestReadResponseNoContentLength(tservers.ServerTestBase): - handler = NoContentLengthHTTPHandler - - def test_no_content_length(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - resp = http.read_response(c.rfile, "GET", None) - assert resp.content == "bar\r\n\r\n" - - -def test_read_response(): - def tst(data, method, limit, include_body=True): - data = textwrap.dedent(data) - r = cStringIO.StringIO(data) - return http.read_response( - r, method, limit, include_body=include_body - ) - - tutils.raises("server disconnect", tst, "", "GET", None) - tutils.raises("invalid server response", tst, "foo", "GET", None) - data = """ - HTTP/1.1 200 OK - """ - assert tst(data, "GET", None) == http_semantics.Response( - (1, 1), 200, 'OK', odict.ODictCaseless(), '' - ) - data = """ - HTTP/1.1 200 - """ - assert tst(data, "GET", None) == http_semantics.Response( - (1, 1), 200, '', odict.ODictCaseless(), '' - ) - data = """ - HTTP/x 200 OK - """ - tutils.raises("invalid http version", tst, data, "GET", None) - data = """ - HTTP/1.1 xx OK - """ - tutils.raises("invalid server response", tst, data, "GET", None) - - data = """ - HTTP/1.1 100 CONTINUE - - HTTP/1.1 200 OK - """ - assert tst(data, "GET", None) == http_semantics.Response( - (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' - ) - - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert tst(data, "GET", None).content == 'foo' - assert tst(data, "HEAD", None).content == '' - - data = """ - HTTP/1.1 200 OK - \tContent-Length: 3 - - foo - """ - tutils.raises("invalid headers", tst, data, "GET", None) - - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert tst(data, "GET", None, include_body=False).content is None - - -def test_parse_url(): - assert not http.parse_url("") - - u = "http://foo.com:8888/test" - s, h, po, pa = http.parse_url(u) - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - - s, h, po, pa = http.parse_url("http://foo/bar") - assert s == "http" - assert h == "foo" - assert po == 80 - assert pa == "/bar" - - s, h, po, pa = http.parse_url("http://user:pass@foo/bar") - assert s == "http" - assert h == "foo" - assert po == 80 - assert pa == "/bar" - - s, h, po, pa = http.parse_url("http://foo") - assert pa == "/" - - s, h, po, pa = http.parse_url("https://foo") - assert po == 443 - - assert not http.parse_url("https://foo:bar") - assert not http.parse_url("https://foo:") - - # Invalid IDNA - assert not http.parse_url("http://\xfafoo") - # Invalid PATH - assert not http.parse_url("http:/\xc6/localhost:56121") - # Null byte in host - assert not http.parse_url("http://foo\0") - # Port out of range - assert not http.parse_url("http://foo:999999") - # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt - assert not 
http.parse_url('http://lo[calhost') - - -def test_parse_http_basic_auth(): - vals = ("basic", "foo", "bar") - assert http.parse_http_basic_auth( - http.assemble_http_basic_auth(*vals) - ) == vals - assert not http.parse_http_basic_auth("") - assert not http.parse_http_basic_auth("foo bar") - v = "basic " + binascii.b2a_base64("foo") - assert not http.parse_http_basic_auth(v) - - -def test_get_request_line(): - r = cStringIO.StringIO("\nfoo") - assert http.get_request_line(r) == "foo" - assert not http.get_request_line(r) - - -class TestReadRequest(): - - def tst(self, data, **kwargs): - r = cStringIO.StringIO(data) - return http.read_request(r, **kwargs) - - def test_invalid(self): - tutils.raises( - "bad http request", - self.tst, - "xxx" - ) - tutils.raises( - "bad http request line", - self.tst, - "get /\xff HTTP/1.1" - ) - tutils.raises( - "invalid headers", - self.tst, - "get / HTTP/1.1\r\nfoo" - ) - tutils.raises( - tcp.NetLibDisconnect, - self.tst, - "\r\n" - ) - - def test_asterisk_form_in(self): - v = self.tst("OPTIONS * HTTP/1.1") - assert v.form_in == "relative" - assert v.method == "OPTIONS" - - def test_absolute_form_in(self): - tutils.raises( - "Bad HTTP request line", - self.tst, - "GET oops-no-protocol.com HTTP/1.1" - ) - v = self.tst("GET http://address:22/ HTTP/1.1") - assert v.form_in == "absolute" - assert v.port == 22 - assert v.host == "address" - assert v.scheme == "http" - - def test_connect(self): - tutils.raises( - "Bad HTTP request line", - self.tst, - "CONNECT oops-no-port.com HTTP/1.1" - ) - v = self.tst("CONNECT foo.com:443 HTTP/1.1") - assert v.form_in == "authority" - assert v.method == "CONNECT" - assert v.port == 443 - assert v.host == "foo.com" - - def test_expect(self): - w = cStringIO.StringIO() - r = cStringIO.StringIO( - "GET / HTTP/1.1\r\n" - "Content-Length: 3\r\n" - "Expect: 100-continue\r\n\r\n" - "foobar", - ) - v = http.read_request(r, wfile=w) - assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" - assert v.content == "foo" - assert r.read(3) == "bar" diff --git a/test/test_http_auth.py b/test/test_http_auth.py deleted file mode 100644 index c842925b..00000000 --- a/test/test_http_auth.py +++ /dev/null @@ -1,109 +0,0 @@ -from netlib import odict, http_auth, http -import tutils - - -class TestPassManNonAnon: - - def test_simple(self): - p = http_auth.PassManNonAnon() - assert not p.test("", "") - assert p.test("user", "") - - -class TestPassManHtpasswd: - - def test_file_errors(self): - tutils.raises( - "malformed htpasswd file", - http_auth.PassManHtpasswd, - tutils.test_data.path("data/server.crt")) - - def test_simple(self): - pm = http_auth.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) - - vals = ("basic", "test", "test") - http.assemble_http_basic_auth(*vals) - assert pm.test("test", "test") - assert not pm.test("test", "foo") - assert not pm.test("foo", "test") - assert not pm.test("test", "") - assert not pm.test("", "") - - -class TestPassManSingleUser: - - def test_simple(self): - pm = http_auth.PassManSingleUser("test", "test") - assert pm.test("test", "test") - assert not pm.test("test", "foo") - assert not pm.test("foo", "test") - - -class TestNullProxyAuth: - - def test_simple(self): - na = http_auth.NullProxyAuth(http_auth.PassManNonAnon()) - assert not na.auth_challenge_headers() - assert na.authenticate("foo") - na.clean({}) - - -class TestBasicProxyAuth: - - def test_simple(self): - ba = http_auth.BasicProxyAuth(http_auth.PassManNonAnon(), "test") - h = odict.ODictCaseless() - assert ba.auth_challenge_headers() - 
assert not ba.authenticate(h) - - def test_authenticate_clean(self): - ba = http_auth.BasicProxyAuth(http_auth.PassManNonAnon(), "test") - - hdrs = odict.ODictCaseless() - vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] - assert ba.authenticate(hdrs) - - ba.clean(hdrs) - assert not ba.AUTH_HEADER in hdrs - - hdrs[ba.AUTH_HEADER] = [""] - assert not ba.authenticate(hdrs) - - hdrs[ba.AUTH_HEADER] = ["foo"] - assert not ba.authenticate(hdrs) - - vals = ("foo", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] - assert not ba.authenticate(hdrs) - - ba = http_auth.BasicProxyAuth(http_auth.PassMan(), "test") - vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] - assert not ba.authenticate(hdrs) - - -class Bunch: - pass - - -class TestAuthAction: - - def test_nonanonymous(self): - m = Bunch() - aa = http_auth.NonanonymousAuthAction(None, "authenticator") - aa(None, m, None, None) - assert m.authenticator - - def test_singleuser(self): - m = Bunch() - aa = http_auth.SingleuserAuthAction(None, "authenticator") - aa(None, m, "foo:bar", None) - assert m.authenticator - tutils.raises("invalid", aa, None, m, "foo", None) - - def test_httppasswd(self): - m = Bunch() - aa = http_auth.HtpasswdAuthAction(None, "authenticator") - aa(None, m, tutils.test_data.path("data/htpasswd"), None) - assert m.authenticator diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py deleted file mode 100644 index 070849cf..00000000 --- a/test/test_http_cookies.py +++ /dev/null @@ -1,219 +0,0 @@ -import nose.tools - -from netlib import http_cookies - - -def test_read_token(): - tokens = [ - [("foo", 0), ("foo", 3)], - [("foo", 1), ("oo", 3)], - [(" foo", 1), ("foo", 4)], - [(" foo;", 1), ("foo", 4)], - [(" foo=", 1), ("foo", 4)], - [(" foo=bar", 1), ("foo", 4)], - ] - for q, a in tokens: - nose.tools.eq_(http_cookies._read_token(*q), a) - - -def test_read_quoted_string(): - tokens = [ - [('"foo" x', 0), ("foo", 5)], - [('"f\oo" x', 0), ("foo", 6)], - [(r'"f\\o" x', 0), (r"f\o", 6)], - [(r'"f\\" x', 0), (r"f" + '\\', 5)], - [('"fo\\\"" x', 0), ("fo\"", 6)], - ] - for q, a in tokens: - nose.tools.eq_(http_cookies._read_quoted_string(*q), a) - - -def test_read_pairs(): - vals = [ - [ - "one", - [["one", None]] - ], - [ - "one=two", - [["one", "two"]] - ], - [ - "one=", - [["one", ""]] - ], - [ - 'one="two"', - [["one", "two"]] - ], - [ - 'one="two"; three=four', - [["one", "two"], ["three", "four"]] - ], - [ - 'one="two"; three=four; five', - [["one", "two"], ["three", "four"], ["five", None]] - ], - [ - 'one="\\"two"; three=four', - [["one", '"two'], ["three", "four"]] - ], - ] - for s, lst in vals: - ret, off = http_cookies._read_pairs(s) - nose.tools.eq_(ret, lst) - - -def test_pairs_roundtrips(): - pairs = [ - [ - "", - [] - ], - [ - "one=uno", - [["one", "uno"]] - ], - [ - "one", - [["one", None]] - ], - [ - "one=uno; two=due", - [["one", "uno"], ["two", "due"]] - ], - [ - 'one="uno"; two="\due"', - [["one", "uno"], ["two", "due"]] - ], - [ - 'one="un\\"o"', - [["one", 'un"o']] - ], - [ - 'one="uno,due"', - [["one", 'uno,due']] - ], - [ - "one=uno; two; three=tre", - [["one", "uno"], ["two", None], ["three", "tre"]] - ], - [ - "_lvs2=zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g=; " - "_rcc2=53VdltWl+Ov6ordflA==;", - [ - ["_lvs2", "zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g="], - ["_rcc2", "53VdltWl+Ov6ordflA=="] - ] - ] - ] - for s, lst in pairs: - ret, off = http_cookies._read_pairs(s) - 
nose.tools.eq_(ret, lst) - s2 = http_cookies._format_pairs(lst) - ret, off = http_cookies._read_pairs(s2) - nose.tools.eq_(ret, lst) - - -def test_cookie_roundtrips(): - pairs = [ - [ - "one=uno", - [["one", "uno"]] - ], - [ - "one=uno; two=due", - [["one", "uno"], ["two", "due"]] - ], - ] - for s, lst in pairs: - ret = http_cookies.parse_cookie_header(s) - nose.tools.eq_(ret.lst, lst) - s2 = http_cookies.format_cookie_header(ret) - ret = http_cookies.parse_cookie_header(s2) - nose.tools.eq_(ret.lst, lst) - - -def test_parse_set_cookie_pairs(): - pairs = [ - [ - "one=uno", - [ - ["one", "uno"] - ] - ], - [ - "one=un\x20", - [ - ["one", "un\x20"] - ] - ], - [ - "one=uno; foo", - [ - ["one", "uno"], - ["foo", None] - ] - ], - [ - "mun=1.390.f60; " - "expires=sun, 11-oct-2015 12:38:31 gmt; path=/; " - "domain=b.aol.com", - [ - ["mun", "1.390.f60"], - ["expires", "sun, 11-oct-2015 12:38:31 gmt"], - ["path", "/"], - ["domain", "b.aol.com"] - ] - ], - [ - r'rpb=190%3d1%2616726%3d1%2634832%3d1%2634874%3d1; ' - 'domain=.rubiconproject.com; ' - 'expires=mon, 11-may-2015 21:54:57 gmt; ' - 'path=/', - [ - ['rpb', r'190%3d1%2616726%3d1%2634832%3d1%2634874%3d1'], - ['domain', '.rubiconproject.com'], - ['expires', 'mon, 11-may-2015 21:54:57 gmt'], - ['path', '/'] - ] - ], - ] - for s, lst in pairs: - ret = http_cookies._parse_set_cookie_pairs(s) - nose.tools.eq_(ret, lst) - s2 = http_cookies._format_set_cookie_pairs(ret) - ret2 = http_cookies._parse_set_cookie_pairs(s2) - nose.tools.eq_(ret2, lst) - - -def test_parse_set_cookie_header(): - vals = [ - [ - "", None - ], - [ - ";", None - ], - [ - "one=uno", - ("one", "uno", []) - ], - [ - "one=uno; foo=bar", - ("one", "uno", [["foo", "bar"]]) - ] - ] - for s, expected in vals: - ret = http_cookies.parse_set_cookie_header(s) - if expected: - assert ret[0] == expected[0] - assert ret[1] == expected[1] - nose.tools.eq_(ret[2].lst, expected[2]) - s2 = http_cookies.format_set_cookie_header(*ret) - ret2 = http_cookies.parse_set_cookie_header(s2) - assert ret2[0] == expected[0] - assert ret2[1] == expected[1] - nose.tools.eq_(ret2[2].lst, expected[2]) - else: - assert ret is None diff --git a/test/test_http_uastrings.py b/test/test_http_uastrings.py deleted file mode 100644 index 3fa4f359..00000000 --- a/test/test_http_uastrings.py +++ /dev/null @@ -1,6 +0,0 @@ -from netlib import http_uastrings - - -def test_get_shortcut(): - assert http_uastrings.get_by_shortcut("c")[0] == "chrome" - assert not http_uastrings.get_by_shortcut("_") diff --git a/test/test_websockets.py b/test/test_websockets.py deleted file mode 100644 index ae0a5e33..00000000 --- a/test/test_websockets.py +++ /dev/null @@ -1,261 +0,0 @@ -import os - -from nose.tools import raises - -from netlib import tcp, websockets, http -from . 
import tutils, tservers - - -class WebSocketsEchoHandler(tcp.BaseHandler): - - def __init__(self, connection, address, server): - super(WebSocketsEchoHandler, self).__init__( - connection, address, server - ) - self.protocol = websockets.WebsocketsProtocol() - self.handshake_done = False - - def handle(self): - while True: - if not self.handshake_done: - self.handshake() - else: - self.read_next_message() - - def read_next_message(self): - frame = websockets.Frame.from_file(self.rfile) - self.on_message(frame.payload) - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=False) - frame.to_file(self.wfile) - - def handshake(self): - req = http.read_request(self.rfile) - key = self.protocol.check_client_handshake(req.headers) - - self.wfile.write(http.response_preamble(101) + "\r\n") - headers = self.protocol.server_handshake_headers(key) - self.wfile.write(headers.format() + "\r\n") - self.wfile.flush() - self.handshake_done = True - - def on_message(self, message): - if message is not None: - self.send_message(message) - - -class WebSocketsClient(tcp.TCPClient): - - def __init__(self, address, source_address=None): - super(WebSocketsClient, self).__init__(address, source_address) - self.protocol = websockets.WebsocketsProtocol() - self.client_nonce = None - - def connect(self): - super(WebSocketsClient, self).connect() - - preamble = http.request_preamble("GET", "/") - self.wfile.write(preamble + "\r\n") - headers = self.protocol.client_handshake_headers() - self.client_nonce = headers.get_first("sec-websocket-key") - self.wfile.write(headers.format() + "\r\n") - self.wfile.flush() - - resp = http.read_response(self.rfile, "get", None) - server_nonce = self.protocol.check_server_handshake(resp.headers) - - if not server_nonce == self.protocol.create_server_nonce( - self.client_nonce): - self.close() - - def read_next_message(self): - return websockets.Frame.from_file(self.rfile).payload - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=True) - frame.to_file(self.wfile) - - -class TestWebSockets(tservers.ServerTestBase): - handler = WebSocketsEchoHandler - - def __init__(self): - self.protocol = websockets.WebsocketsProtocol() - - def random_bytes(self, n=100): - return os.urandom(n) - - def echo(self, msg): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message(msg) - response = client.read_next_message() - assert response == msg - - def test_simple_echo(self): - self.echo("hello I'm the client") - - def test_frame_sizes(self): - # length can fit in the the 7 bit payload length - small_msg = self.random_bytes(100) - # 50kb, sligthly larger than can fit in a 7 bit int - medium_msg = self.random_bytes(50000) - # 150kb, slightly larger than can fit in a 16 bit int - large_msg = self.random_bytes(150000) - - self.echo(small_msg) - self.echo(medium_msg) - self.echo(large_msg) - - def test_default_builder(self): - """ - default builder should always generate valid frames - """ - msg = self.random_bytes() - client_frame = websockets.Frame.default(msg, from_client=True) - server_frame = websockets.Frame.default(msg, from_client=False) - - def test_serialization_bijection(self): - """ - Ensure that various frame types can be serialized/deserialized back - and forth between to_bytes() and from_bytes() - """ - for is_client in [True, False]: - for num_bytes in [100, 50000, 150000]: - frame = websockets.Frame.default( - self.random_bytes(num_bytes), is_client - ) - frame2 = 
websockets.Frame.from_bytes( - frame.to_bytes() - ) - assert frame == frame2 - - bytes = b'\x81\x03cba' - assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes - - def test_check_server_handshake(self): - headers = self.protocol.server_handshake_headers("key") - assert self.protocol.check_server_handshake(headers) - headers["Upgrade"] = ["not_websocket"] - assert not self.protocol.check_server_handshake(headers) - - def test_check_client_handshake(self): - headers = self.protocol.client_handshake_headers("key") - assert self.protocol.check_client_handshake(headers) == "key" - headers["Upgrade"] = ["not_websocket"] - assert not self.protocol.check_client_handshake(headers) - - -class BadHandshakeHandler(WebSocketsEchoHandler): - - def handshake(self): - client_hs = http.read_request(self.rfile) - self.protocol.check_client_handshake(client_hs.headers) - - self.wfile.write(http.response_preamble(101) + "\r\n") - headers = self.protocol.server_handshake_headers("malformed key") - self.wfile.write(headers.format() + "\r\n") - self.wfile.flush() - self.handshake_done = True - - -class TestBadHandshake(tservers.ServerTestBase): - - """ - Ensure that the client disconnects if the server handshake is malformed - """ - handler = BadHandshakeHandler - - @raises(tcp.NetLibDisconnect) - def test(self): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message("hello") - - -class TestFrameHeader: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.FrameHeader(*args, **kwargs) - bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) - assert f == f2 - round() - round(fin=1) - round(rsv1=1) - round(rsv2=1) - round(rsv3=1) - round(payload_length=1) - round(payload_length=100) - round(payload_length=1000) - round(payload_length=10000) - round(opcode=websockets.OPCODE.PING) - round(masking_key="test") - - def test_human_readable(self): - f = websockets.FrameHeader( - masking_key="test", - fin=True, - payload_length=10 - ) - assert f.human_readable() - f = websockets.FrameHeader() - assert f.human_readable() - - def test_funky(self): - f = websockets.FrameHeader(masking_key="test", mask=False) - bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) - assert not f2.mask - - def test_violations(self): - tutils.raises("opcode", websockets.FrameHeader, opcode=17) - tutils.raises("masking key", websockets.FrameHeader, masking_key="x") - - def test_automask(self): - f = websockets.FrameHeader(mask=True) - assert f.masking_key - - f = websockets.FrameHeader(masking_key="foob") - assert f.mask - - f = websockets.FrameHeader(masking_key="foob", mask=0) - assert not f.mask - assert f.masking_key - - -class TestFrame: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.Frame(*args, **kwargs) - bytes = f.to_bytes() - f2 = websockets.Frame.from_file(tutils.treader(bytes)) - assert f == f2 - round("test") - round("test", fin=1) - round("test", rsv1=1) - round("test", opcode=websockets.OPCODE.PING) - round("test", masking_key="test") - - def test_human_readable(self): - f = websockets.Frame() - assert f.human_readable() - - -def test_masker(): - tests = [ - ["a"], - ["four"], - ["fourf"], - ["fourfive"], - ["a", "aasdfasdfa", "asdf"], - ["a" * 50, "aasdfasdfa", "asdf"], - ] - for i in tests: - m = websockets.Masker("abcd") - data = "".join([m(t) for t in i]) - data2 = websockets.Masker("abcd")(data) - assert data2 == "".join(i) diff --git a/test/websockets/__init__.py 
b/test/websockets/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py
new file mode 100644
index 00000000..07ad0452
--- /dev/null
+++ b/test/websockets/test_websockets.py
@@ -0,0 +1,262 @@
+import os
+
+from nose.tools import raises
+
+from netlib import tcp, http, websockets
+from netlib.http.exceptions import *
+from .. import tutils, tservers
+
+
+class WebSocketsEchoHandler(tcp.BaseHandler):
+
+    def __init__(self, connection, address, server):
+        super(WebSocketsEchoHandler, self).__init__(
+            connection, address, server
+        )
+        self.protocol = websockets.WebsocketsProtocol()
+        self.handshake_done = False
+
+    def handle(self):
+        while True:
+            if not self.handshake_done:
+                self.handshake()
+            else:
+                self.read_next_message()
+
+    def read_next_message(self):
+        frame = websockets.Frame.from_file(self.rfile)
+        self.on_message(frame.payload)
+
+    def send_message(self, message):
+        frame = websockets.Frame.default(message, from_client=False)
+        frame.to_file(self.wfile)
+
+    def handshake(self):
+        req = http.http1.read_request(self.rfile)
+        key = self.protocol.check_client_handshake(req.headers)
+
+        self.wfile.write(http.http1.response_preamble(101) + "\r\n")
+        headers = self.protocol.server_handshake_headers(key)
+        self.wfile.write(headers.format() + "\r\n")
+        self.wfile.flush()
+        self.handshake_done = True
+
+    def on_message(self, message):
+        if message is not None:
+            self.send_message(message)
+
+
+class WebSocketsClient(tcp.TCPClient):
+
+    def __init__(self, address, source_address=None):
+        super(WebSocketsClient, self).__init__(address, source_address)
+        self.protocol = websockets.WebsocketsProtocol()
+        self.client_nonce = None
+
+    def connect(self):
+        super(WebSocketsClient, self).connect()
+
+        preamble = http.http1.protocol.request_preamble("GET", "/")
+        self.wfile.write(preamble + "\r\n")
+        headers = self.protocol.client_handshake_headers()
+        self.client_nonce = headers.get_first("sec-websocket-key")
+        self.wfile.write(headers.format() + "\r\n")
+        self.wfile.flush()
+
+        resp = http.http1.protocol.read_response(self.rfile, "get", None)
+        server_nonce = self.protocol.check_server_handshake(resp.headers)
+
+        if not server_nonce == self.protocol.create_server_nonce(
+                self.client_nonce):
+            self.close()
+
+    def read_next_message(self):
+        return websockets.Frame.from_file(self.rfile).payload
+
+    def send_message(self, message):
+        frame = websockets.Frame.default(message, from_client=True)
+        frame.to_file(self.wfile)
+
+
+class TestWebSockets(tservers.ServerTestBase):
+    handler = WebSocketsEchoHandler
+
+    def __init__(self):
+        self.protocol = websockets.WebsocketsProtocol()
+
+    def random_bytes(self, n=100):
+        return os.urandom(n)
+
+    def echo(self, msg):
+        client = WebSocketsClient(("127.0.0.1", self.port))
+        client.connect()
+        client.send_message(msg)
+        response = client.read_next_message()
+        assert response == msg
+
+    def test_simple_echo(self):
+        self.echo("hello I'm the client")
+
+    def test_frame_sizes(self):
+        # length can fit in the 7 bit payload length
+        small_msg = self.random_bytes(100)
+        # 50kb, slightly larger than can fit in a 7 bit int
+        medium_msg = self.random_bytes(50000)
+        # 150kb, slightly larger than can fit in a 16 bit int
+        large_msg = self.random_bytes(150000)
+
+        self.echo(small_msg)
+        self.echo(medium_msg)
+        self.echo(large_msg)
+
+    def test_default_builder(self):
+        """
+        default builder should always generate valid frames
+        """
+        msg = self.random_bytes()
+        client_frame = websockets.Frame.default(msg, from_client=True)
+        server_frame = websockets.Frame.default(msg, from_client=False)
+
+    def test_serialization_bijection(self):
+        """
+        Ensure that various frame types can be serialized/deserialized back
+        and forth between to_bytes() and from_bytes()
+        """
+        for is_client in [True, False]:
+            for num_bytes in [100, 50000, 150000]:
+                frame = websockets.Frame.default(
+                    self.random_bytes(num_bytes), is_client
+                )
+                frame2 = websockets.Frame.from_bytes(
+                    frame.to_bytes()
+                )
+                assert frame == frame2
+
+        bytes = b'\x81\x03cba'
+        assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
+
+    def test_check_server_handshake(self):
+        headers = self.protocol.server_handshake_headers("key")
+        assert self.protocol.check_server_handshake(headers)
+        headers["Upgrade"] = ["not_websocket"]
+        assert not self.protocol.check_server_handshake(headers)
+
+    def test_check_client_handshake(self):
+        headers = self.protocol.client_handshake_headers("key")
+        assert self.protocol.check_client_handshake(headers) == "key"
+        headers["Upgrade"] = ["not_websocket"]
+        assert not self.protocol.check_client_handshake(headers)
+
+
+class BadHandshakeHandler(WebSocketsEchoHandler):
+
+    def handshake(self):
+        client_hs = http.http1.protocol.read_request(self.rfile)
+        self.protocol.check_client_handshake(client_hs.headers)
+
+        self.wfile.write(http.http1.protocol.response_preamble(101) + "\r\n")
+        headers = self.protocol.server_handshake_headers("malformed key")
+        self.wfile.write(headers.format() + "\r\n")
+        self.wfile.flush()
+        self.handshake_done = True
+
+
+class TestBadHandshake(tservers.ServerTestBase):
+
+    """
+    Ensure that the client disconnects if the server handshake is malformed
+    """
+    handler = BadHandshakeHandler
+
+    @raises(tcp.NetLibDisconnect)
+    def test(self):
+        client = WebSocketsClient(("127.0.0.1", self.port))
+        client.connect()
+        client.send_message("hello")
+
+
+class TestFrameHeader:
+
+    def test_roundtrip(self):
+        def round(*args, **kwargs):
+            f = websockets.FrameHeader(*args, **kwargs)
+            bytes = f.to_bytes()
+            f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
+            assert f == f2
+        round()
+        round(fin=1)
+        round(rsv1=1)
+        round(rsv2=1)
+        round(rsv3=1)
+        round(payload_length=1)
+        round(payload_length=100)
+        round(payload_length=1000)
+        round(payload_length=10000)
+        round(opcode=websockets.OPCODE.PING)
+        round(masking_key="test")
+
+    def test_human_readable(self):
+        f = websockets.FrameHeader(
+            masking_key="test",
+            fin=True,
+            payload_length=10
+        )
+        assert f.human_readable()
+        f = websockets.FrameHeader()
+        assert f.human_readable()
+
+    def test_funky(self):
+        f = websockets.FrameHeader(masking_key="test", mask=False)
+        bytes = f.to_bytes()
+        f2 = websockets.FrameHeader.from_file(tutils.treader(bytes))
+        assert not f2.mask
+
+    def test_violations(self):
+        tutils.raises("opcode", websockets.FrameHeader, opcode=17)
+        tutils.raises("masking key", websockets.FrameHeader, masking_key="x")
+
+    def test_automask(self):
+        f = websockets.FrameHeader(mask=True)
+        assert f.masking_key
+
+        f = websockets.FrameHeader(masking_key="foob")
+        assert f.mask
+
+        f = websockets.FrameHeader(masking_key="foob", mask=0)
+        assert not f.mask
+        assert f.masking_key
+
+
+class TestFrame:
+
+    def test_roundtrip(self):
+        def round(*args, **kwargs):
+            f = websockets.Frame(*args, **kwargs)
+            bytes = f.to_bytes()
+            f2 = websockets.Frame.from_file(tutils.treader(bytes))
+            assert f == f2
+        round("test")
+        round("test", fin=1)
+        round("test", rsv1=1)
+        round("test", opcode=websockets.OPCODE.PING)
+        round("test", masking_key="test")
+
+    def test_human_readable(self):
+        f = websockets.Frame()
+        assert f.human_readable()
+
+
+def test_masker():
+    tests = [
+        ["a"],
+        ["four"],
+        ["fourf"],
+        ["fourfive"],
+        ["a", "aasdfasdfa", "asdf"],
+        ["a" * 50, "aasdfasdfa", "asdf"],
+    ]
+    for i in tests:
+        m = websockets.Masker("abcd")
+        data = "".join([m(t) for t in i])
+        data2 = websockets.Masker("abcd")(data)
+        assert data2 == "".join(i)
-- 
cgit v1.2.3