-rw-r--r-- | netlib/http.py | 595
-rw-r--r-- | netlib/http/__init__.py | 3
-rw-r--r-- | netlib/http/authentication.py (renamed from netlib/http_auth.py) | 25
-rw-r--r-- | netlib/http/cookies.py (renamed from netlib/http_cookies.py) | 8
-rw-r--r-- | netlib/http/exceptions.py | 9
-rw-r--r-- | netlib/http/http1/__init__.py | 1
-rw-r--r-- | netlib/http/http1/protocol.py | 496
-rw-r--r-- | netlib/http/http2/__init__.py (renamed from netlib/http2/__init__.py) | 0
-rw-r--r-- | netlib/http/http2/frame.py (renamed from netlib/http2/frame.py) | 0
-rw-r--r-- | netlib/http/http2/protocol.py (renamed from netlib/http2/protocol.py) | 31
-rw-r--r-- | netlib/http/semantics.py | 159
-rw-r--r-- | netlib/http/status_codes.py (renamed from netlib/http_status.py) | 0
-rw-r--r-- | netlib/http/user_agents.py (renamed from netlib/http_uastrings.py) | 0
-rw-r--r-- | netlib/websockets/__init__.py | 2
-rw-r--r-- | netlib/websockets/frame.py (renamed from netlib/websockets.py) | 133
-rw-r--r-- | netlib/websockets/protocol.py | 111
-rw-r--r-- | netlib/wsgi.py | 6
-rw-r--r-- | test/http/__init__.py (renamed from test/http2/__init__.py) | 0
-rw-r--r-- | test/http/http1/__init__.py | 0
-rw-r--r-- | test/http/http1/test_protocol.py (renamed from test/test_http.py) | 274
-rw-r--r-- | test/http/http2/__init__.py | 0
-rw-r--r-- | test/http/http2/test_frames.py (renamed from test/http2/test_frames.py) | 2
-rw-r--r-- | test/http/http2/test_protocol.py (renamed from test/http2/test_protocol.py) | 38
-rw-r--r-- | test/http/test_authentication.py (renamed from test/test_http_auth.py) | 48
-rw-r--r-- | test/http/test_cookies.py (renamed from test/test_http_cookies.py) | 32
-rw-r--r-- | test/http/test_semantics.py | 54
-rw-r--r-- | test/http/test_user_agents.py | 6
-rw-r--r-- | test/test_http_uastrings.py | 6
-rw-r--r-- | test/websockets/__init__.py | 0
-rw-r--r-- | test/websockets/test_websockets.py (renamed from test/test_websockets.py) | 57
30 files changed, 1131 insertions, 965 deletions
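
The hunks below move netlib's HTTP handling into a netlib.http package (semantics, exceptions, http1, http2, status_codes, user_agents) and turn the former module-level parsing helpers into methods on an HTTP1Protocol class. As a rough orientation before the hunks, here is a minimal Python 2 sketch of the reorganized API; it uses only class and function names that appear in the new modules and tests below, while the raw request bytes themselves are just an illustrative example.

import cStringIO

from netlib import odict
from netlib.http.http1 import HTTP1Protocol

# The new tests drive HTTP1Protocol with in-memory file objects; the
# constructor accepts either a tcp handler or explicit rfile/wfile pairs.
raw = (
    "GET /test HTTP/1.1\r\n"
    "Content-Length: 3\r\n"
    "\r\n"
    "foo"
)
protocol = HTTP1Protocol(
    rfile=cStringIO.StringIO(raw),
    wfile=cStringIO.StringIO(),
)

# read_request() now returns a netlib.http.Request (from semantics.py)
# instead of the old namedtuple.
request = protocol.read_request()
assert request.method == "GET"
assert request.path == "/test"
assert request.body == "foo"

# Helpers that used to be free functions in netlib.http are classmethods now.
h = odict.ODictCaseless()
h["content-length"] = ["5"]
assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == 5
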
diff --git a/netlib/http.py b/netlib/http.py deleted file mode 100644 index a2af9e49..00000000 --- a/netlib/http.py +++ /dev/null @@ -1,595 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import collections -import string -import urlparse -import binascii -import sys -from . import odict, utils, tcp, http_status - - -class HttpError(Exception): - - def __init__(self, code, message): - super(HttpError, self).__init__(message) - self.code = code - - -class HttpErrorConnClosed(HttpError): - pass - - -def _is_valid_port(port): - if not 0 <= port <= 65535: - return False - return True - - -def _is_valid_host(host): - try: - host.decode("idna") - except ValueError: - return False - if "\0" in host: - return None - return True - - -def get_request_line(fp): - """ - Get a line, possibly preceded by a blank. - """ - line = fp.readline() - if line == "\r\n" or line == "\n": - # Possible leftover from previous message - line = fp.readline() - return line - - -def parse_url(url): - """ - Returns a (scheme, host, port, path) tuple, or None on error. - - Checks that: - port is an integer 0-65535 - host is a valid IDNA-encoded hostname with no null-bytes - path is valid ASCII - """ - try: - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - except ValueError: - return None - if not scheme: - return None - if '@' in netloc: - # FIXME: Consider what to do with the discarded credentials here Most - # probably we should extend the signature to return these as a separate - # value. - _, netloc = string.rsplit(netloc, '@', maxsplit=1) - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None - else: - host = netloc - if scheme == "https": - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path - if not _is_valid_host(host): - return None - if not utils.isascii(path): - return None - if not _is_valid_port(port): - return None - return scheme, host, port, path - - -def read_headers(fp): - """ - Read a set of headers from a file pointer. Stop once a blank line is - reached. Return a ODictCaseless object, or None if headers are invalid. - """ - ret = [] - name = '' - while True: - line = fp.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - if not ret: - return None - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i + 1:].strip() - ret.append([name, value]) - else: - return None - return odict.ODictCaseless(ret) - - -def read_chunked(fp, limit, is_request): - """ - Read a chunked HTTP body. - - May raise HttpError. - """ - # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - # 3.3 2. - total = 0 - code = 400 if is_request else 502 - while True: - line = fp.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n' and line != '\n': - try: - length = int(line, 16) - except ValueError: - raise HttpError( - code, - "Invalid chunked encoding length: %s" % line - ) - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large. 
Limit is %s," \ - " chunked content longer than %s" % (limit, total) - raise HttpError(code, msg) - chunk = fp.read(length) - suffix = fp.readline(5) - if suffix != '\r\n': - raise HttpError(code, "Malformed chunked body") - yield line, chunk, '\r\n' - if length == 0: - return - - -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. - """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks - - -def has_chunked_encoding(headers): - return "chunked" in [ - i.lower() for i in get_header_tokens(headers, "transfer-encoding") - ] - - -def parse_http_protocol(s): - """ - Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or - None. - """ - if not s.startswith("HTTP/"): - return None - _, version = s.split('/', 1) - if "." not in version: - return None - major, minor = version.split('.', 1) - try: - major = int(major) - minor = int(minor) - except ValueError: - return None - return major, minor - - -def parse_http_basic_auth(s): - words = s.split() - if len(words) != 2: - return None - scheme = words[0] - try: - user = binascii.a2b_base64(words[1]) - except binascii.Error: - return None - parts = user.split(':') - if len(parts) != 2: - return None - return scheme, parts[0], parts[1] - - -def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64(username + ":" + password) - return scheme + " " + v - - -def parse_init(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - if not utils.isascii(method): - return None - return method, url, httpversion - - -def parse_init_connect(line): - """ - Returns (host, port, httpversion) if line is a valid CONNECT line. - http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - if method.upper() != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - try: - port = int(port) - except ValueError: - return None - if not _is_valid_port(port): - return None - if not _is_valid_host(host): - return None - return host, port, httpversion - - -def parse_init_proxy(line): - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - parts = parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - return method, scheme, host, port, path, httpversion - - -def parse_init_http(line): - """ - Returns (method, url, httpversion) - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - if not utils.isascii(url): - return None - if not (url.startswith("/") or url == "*"): - return None - return method, url, httpversion - - -def connection_close(httpversion, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1 Note that a connection should be - closed as well if the response has been read until end of the stream. - """ - # At first, check if we have an explicit Connection header. 
- if "connection" in headers: - toks = get_header_tokens(headers, "connection") - if "close" in toks: - return True - elif "keep-alive" in toks: - return False - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - if httpversion == (1, 1): - return False - return True - - -def parse_response_line(line): - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if len(parts) != 3: - return None - proto, code, msg = parts - try: - code = int(code) - except ValueError: - return None - return (proto, code, msg) - - -def read_http_body(*args, **kwargs): - return "".join( - content for _, content, _ in read_http_body_chunked(*args, **kwargs) - ) - - -def read_http_body_chunked( - rfile, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None -): - """ - Read an HTTP message body: - - rfile: A file descriptor to read from - headers: An ODictCaseless object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = expected_http_body_size( - headers, is_request, request_method, response_code - ) - - if expected_size is None: - if has_chunked_encoding(headers): - # Python 3: yield from - for x in read_chunked(rfile, limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - yield "", rfile.read(chunk_size), "" - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if not content: - return - yield "", content, "" - bytes_left -= chunk_size - not_done = rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) - - -def expected_http_body_size(headers, is_request, request_method, response_code): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. 
- """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if has_chunked_encoding(headers): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"][0]) - if size < 0: - raise ValueError() - return size - except ValueError: - return None - if is_request: - return 0 - return -1 - - -Request = collections.namedtuple( - "Request", - [ - "form_in", - "method", - "scheme", - "host", - "port", - "path", - "httpversion", - "headers", - "content" - ] -) - - -def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): - """ - Parse an HTTP request from a file stream - - Args: - rfile (file): Input file to read from - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. - """ - httpversion, host, port, scheme, method, path, headers, content = ( - None, None, None, None, None, None, None, None) - - request_line = get_request_line(rfile) - if not request_line: - raise tcp.NetLibDisconnect() - - request_line_parts = parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - method, path, httpversion = request_line_parts - - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - elif method.upper() == 'CONNECT': - form_in = "authority" - r = parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - host, port, _ = r - path = None - else: - form_in = "absolute" - r = parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - _, scheme, host, port, path, _ = r - - headers = read_headers(rfile) - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): - wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' - ) - wfile.flush() - del headers['expect'] - - if include_body: - content = read_http_body( - rfile, headers, body_size_limit, method, None, True - ) - - return Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - content - ) - - -Response = collections.namedtuple( - "Response", - [ - "httpversion", - "code", - "msg", - "headers", - "content" - ] -) - - -def read_response(rfile, request_method, body_size_limit, include_body=True): - """ - Return an (httpversion, code, msg, headers, content) tuple. - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. 
it's a - response to a HEAD request) - """ - line = rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = read_headers(rfile) - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = read_http_body( - rfile, - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return Response(httpversion, code, msg, headers, content) - - -def request_preamble(method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) - - -def response_preamble(code, message=None, http_major="1", http_minor="1"): - if message is None: - message = http_status.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py new file mode 100644 index 00000000..b01afc6d --- /dev/null +++ b/netlib/http/__init__.py @@ -0,0 +1,3 @@ +from . import * +from exceptions import * +from semantics import * diff --git a/netlib/http_auth.py b/netlib/http/authentication.py index adab4aed..9a227010 100644 --- a/netlib/http_auth.py +++ b/netlib/http/authentication.py @@ -1,6 +1,27 @@ from __future__ import (absolute_import, print_function, division) from argparse import Action, ArgumentTypeError -from . import http +import binascii + +from .. import http + +def parse_http_basic_auth(s): + words = s.split() + if len(words) != 2: + return None + scheme = words[0] + try: + user = binascii.a2b_base64(words[1]) + except binascii.Error: + return None + parts = user.split(':') + if len(parts) != 2: + return None + return scheme, parts[0], parts[1] + + +def assemble_http_basic_auth(scheme, username, password): + v = binascii.b2a_base64(username + ":" + password) + return scheme + " " + v class NullProxyAuth(object): @@ -46,7 +67,7 @@ class BasicProxyAuth(NullProxyAuth): auth_value = headers.get(self.AUTH_HEADER, []) if not auth_value: return False - parts = http.parse_http_basic_auth(auth_value[0]) + parts = parse_http_basic_auth(auth_value[0]) if not parts: return False scheme, username, password = parts diff --git a/netlib/http_cookies.py b/netlib/http/cookies.py index e91ee5c0..b77e3503 100644 --- a/netlib/http_cookies.py +++ b/netlib/http/cookies.py @@ -1,3 +1,7 @@ +import re + +from .. import odict + """ A flexible module for cookie parsing and manipulation. @@ -22,10 +26,6 @@ variants. Serialization follows RFC6265. 
# TODO # - Disallow LHS-only Cookie values -import re - -import odict - def _read_until(s, start, term): """ diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py new file mode 100644 index 00000000..8a2bbebc --- /dev/null +++ b/netlib/http/exceptions.py @@ -0,0 +1,9 @@ +class HttpError(Exception): + + def __init__(self, code, message): + super(HttpError, self).__init__(message) + self.code = code + + +class HttpErrorConnClosed(HttpError): + pass diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py new file mode 100644 index 00000000..6b5043af --- /dev/null +++ b/netlib/http/http1/__init__.py @@ -0,0 +1 @@ +from protocol import * diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py new file mode 100644 index 00000000..e46ad7ab --- /dev/null +++ b/netlib/http/http1/protocol.py @@ -0,0 +1,496 @@ +from __future__ import (absolute_import, print_function, division) +import binascii +import collections +import string +import sys +import urlparse + +from netlib import odict, utils, tcp, http +from .. import status_codes +from ..exceptions import * + +class TCPHandler(object): + def __init__(self, rfile, wfile=None): + self.rfile = rfile + self.wfile = wfile + +class HTTP1Protocol(object): + + def __init__(self, tcp_handler=None, rfile=None, wfile=None): + if tcp_handler: + self.tcp_handler = tcp_handler + else: + self.tcp_handler = TCPHandler(rfile, wfile) + + + def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): + """ + Parse an HTTP request from a file stream + + Args: + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled + automatically, by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + Request: The HTTP request + + Raises: + HttpError: If the input is invalid. 
+ """ + httpversion, host, port, scheme, method, path, headers, body = ( + None, None, None, None, None, None, None, None) + + request_line = self._get_request_line() + if not request_line: + if allow_empty: + return http.EmptyRequest() + else: + raise tcp.NetLibDisconnect() + + request_line_parts = self._parse_init(request_line) + if not request_line_parts: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + method, path, httpversion = request_line_parts + + if path == '*' or path.startswith("/"): + form_in = "relative" + if not utils.isascii(path): + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + elif method == 'CONNECT': + form_in = "authority" + r = self._parse_init_connect(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + host, port, httpversion = r + path = None + else: + form_in = "absolute" + r = self._parse_init_proxy(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + _, scheme, host, port, path, _ = r + + headers = self.read_headers() + if headers is None: + raise HttpError(400, "Invalid headers") + + expect_header = headers.get_first("expect", "").lower() + if expect_header == "100-continue" and httpversion == (1, 1): + self.tcp_handler.wfile.write( + 'HTTP/1.1 100 Continue\r\n' + '\r\n' + ) + self.tcp_handler.wfile.flush() + del headers['expect'] + + if include_body: + body = self.read_http_body( + headers, + body_size_limit, + method, + None, + True + ) + + return http.Request( + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + body + ) + + + def read_response(self, request_method, body_size_limit, include_body=True): + """ + Returns an http.Response + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a + response to a HEAD request) + """ + + line = self.tcp_handler.rfile.readline() + # Possible leftover from previous message + if line == "\r\n" or line == "\n": + line = self.tcp_handler.rfile.readline() + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = self.parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = self._parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = self.read_headers() + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = self.read_http_body( + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return http.Response(httpversion, code, msg, headers, content) + + + def read_headers(self): + """ + Read a set of headers. + Stop once a blank line is reached. + + Return a ODictCaseless object, or None if headers are invalid. 
+ """ + ret = [] + name = '' + while True: + line = self.tcp_handler.rfile.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + if not ret: + return None + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() + else: + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i + 1:].strip() + ret.append([name, value]) + else: + return None + return odict.ODictCaseless(ret) + + + def read_http_body(self, *args, **kwargs): + return "".join( + content for _, content, _ in self.read_http_body_chunked(*args, **kwargs) + ) + + + def read_http_body_chunked( + self, + headers, + limit, + request_method, + response_code, + is_request, + max_chunk_size=None + ): + """ + Read an HTTP message body: + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False + otherwise + """ + if max_chunk_size is None: + max_chunk_size = limit or sys.maxsize + + expected_size = self.expected_http_body_size( + headers, is_request, request_method, response_code + ) + + if expected_size is None: + if self.has_chunked_encoding(headers): + # Python 3: yield from + for x in self._read_chunked(limit, is_request): + yield x + else: # pragma: nocover + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", self.tcp_handler.rfile.read(chunk_size), "" + bytes_left -= chunk_size + else: + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = self.tcp_handler.rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size + not_done = self.tcp_handler.rfile.read(1) + if not_done: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) + + + @classmethod + def expected_http_body_size(self, headers, is_request, request_method, response_code): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding or invalid + data) + - -1, if all data should be read until end of stream. + + May raise HttpError. 
+ """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if self.has_chunked_encoding(headers): + return None + if "content-length" in headers: + try: + size = int(headers["content-length"][0]) + if size < 0: + raise ValueError() + return size + except ValueError: + return None + if is_request: + return 0 + return -1 + + + @classmethod + def request_preamble(self, method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) + + + @classmethod + def response_preamble(self, code, message=None, http_major="1", http_minor="1"): + if message is None: + message = status_codes.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) + + + @classmethod + def has_chunked_encoding(self, headers): + return "chunked" in [ + i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") + ] + + + def _get_request_line(self): + """ + Get a line, possibly preceded by a blank. + """ + line = self.tcp_handler.rfile.readline() + if line == "\r\n" or line == "\n": + # Possible leftover from previous message + line = self.tcp_handler.rfile.readline() + return line + + + + def _read_chunked(self, limit, is_request): + """ + Read a chunked HTTP body. + + May raise HttpError. + """ + # FIXME: Should check if chunked is the final encoding in the headers + # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + # 3.3 2. + total = 0 + code = 400 if is_request else 502 + while True: + line = self.tcp_handler.rfile.readline(128) + if line == "": + raise HttpErrorConnClosed(code, "Connection closed prematurely") + if line != '\r\n' and line != '\n': + try: + length = int(line, 16) + except ValueError: + raise HttpError( + code, + "Invalid chunked encoding length: %s" % line + ) + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large. Limit is %s," \ + " chunked content longer than %s" % (limit, total) + raise HttpError(code, msg) + chunk = self.tcp_handler.rfile.read(length) + suffix = self.tcp_handler.rfile.readline(5) + if suffix != '\r\n': + raise HttpError(code, "Malformed chunked body") + yield line, chunk, '\r\n' + if length == 0: + return + + + @classmethod + def _parse_http_protocol(self, line): + """ + Parse an HTTP protocol declaration. + Returns a (major, minor) tuple, or None. + """ + if not line.startswith("HTTP/"): + return None + _, version = line.split('/', 1) + if "." not in version: + return None + major, minor = version.split('.', 1) + try: + major = int(major) + minor = int(minor) + except ValueError: + return None + return major, minor + + + @classmethod + def _parse_init(self, line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + httpversion = self._parse_http_protocol(protocol) + if not httpversion: + return None + if not utils.isascii(method): + return None + return method, url, httpversion + + + @classmethod + def _parse_init_connect(self, line): + """ + Returns (host, port, httpversion) if line is a valid CONNECT line. 
+ http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + """ + v = self._parse_init(line) + if not v: + return None + method, url, httpversion = v + + if method.upper() != 'CONNECT': + return None + try: + host, port = url.split(":") + except ValueError: + return None + try: + port = int(port) + except ValueError: + return None + if not http.is_valid_port(port): + return None + if not http.is_valid_host(host): + return None + return host, port, httpversion + + + @classmethod + def _parse_init_proxy(self, line): + v = self._parse_init(line) + if not v: + return None + method, url, httpversion = v + + parts = http.parse_url(url) + if not parts: + return None + scheme, host, port, path = parts + return method, scheme, host, port, path, httpversion + + + @classmethod + def _parse_init_http(self, line): + """ + Returns (method, url, httpversion) + """ + v = self._parse_init(line) + if not v: + return None + method, url, httpversion = v + if not utils.isascii(url): + return None + if not (url.startswith("/") or url == "*"): + return None + return method, url, httpversion + + + @classmethod + def connection_close(self, httpversion, headers): + """ + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1 Note that a connection should be + closed as well if the response has been read until end of the stream. + """ + # At first, check if we have an explicit Connection header. + if "connection" in headers: + toks = http.get_header_tokens(headers, "connection") + if "close" in toks: + return True + elif "keep-alive" in toks: + return False + + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent + return httpversion != (1, 1) + + + @classmethod + def parse_response_line(self, line): + parts = line.strip().split(" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append("") + if len(parts) != 3: + return None + proto, code, msg = parts + try: + code = int(code) + except ValueError: + return None + return (proto, code, msg) diff --git a/netlib/http2/__init__.py b/netlib/http/http2/__init__.py index 5acf7696..5acf7696 100644 --- a/netlib/http2/__init__.py +++ b/netlib/http/http2/__init__.py diff --git a/netlib/http2/frame.py b/netlib/http/http2/frame.py index f7e60471..f7e60471 100644 --- a/netlib/http2/frame.py +++ b/netlib/http/http2/frame.py diff --git a/netlib/http2/protocol.py b/netlib/http/http2/protocol.py index 8e5f5429..55b5ca76 100644 --- a/netlib/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -2,7 +2,7 @@ from __future__ import (absolute_import, print_function, division) import itertools from hpack.hpack import Encoder, Decoder -from .. import utils +from netlib import http, utils, odict from . 
import frame @@ -186,12 +186,27 @@ class HTTP2Protocol(object): self._create_headers(headers, stream_id, end_stream=(body is None)), self._create_body(body, stream_id))) - def read_response(self): - stream_id_, headers, body = self._receive_transmission() - return headers[':status'], headers, body + def read_response(self, *args): + stream_id, headers, body = self._receive_transmission() + + status = headers[':status'][0] + response = http.Response("HTTP/2", status, "", headers, body) + response.stream_id = stream_id + return response def read_request(self): - return self._receive_transmission() + stream_id, headers, body = self._receive_transmission() + + form_in = "" + method = headers.get(':method', [''])[0] + scheme = headers.get(':scheme', [''])[0] + host = headers.get(':host', [''])[0] + port = '' # TODO: parse port number? + path = headers.get(':path', [''])[0] + + request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body) + request.stream_id = stream_id + return request def _receive_transmission(self): body_expected = True @@ -219,15 +234,17 @@ class HTTP2Protocol(object): break # TODO: implement window update & flow - headers = {} + headers = odict.ODictCaseless() for header, value in self.decoder.decode(header_block_fragment): - headers[header] = value + headers.add(header, value) return stream_id, headers, body def create_response(self, code, stream_id=None, headers=None, body=None): if headers is None: headers = [] + if isinstance(headers, odict.ODict): + headers = headers.items() headers = [(b':status', bytes(str(code)))] + headers diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py new file mode 100644 index 00000000..9e13edaa --- /dev/null +++ b/netlib/http/semantics.py @@ -0,0 +1,159 @@ +from __future__ import (absolute_import, print_function, division) +import binascii +import collections +import string +import sys +import urlparse + +from .. 
import utils, odict + +class Request(object): + + def __init__( + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + body, + ): + self.form_in = form_in + self.method = method + self.scheme = scheme + self.host = host + self.port = port + self.path = path + self.httpversion = httpversion + self.headers = headers + self.body = body + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return "Request(%s - %s, %s)" % (self.method, self.host, self.path) + + @property + def content(self): + return self.body + + +class EmptyRequest(Request): + def __init__(self): + super(EmptyRequest, self).__init__( + form_in="", + method="", + scheme="", + host="", + port="", + path="", + httpversion=(0, 0), + headers=odict.ODictCaseless(), + body="", + ) + + +class Response(object): + + def __init__( + self, + httpversion, + status_code, + msg, + headers, + body, + sslinfo=None, + ): + self.httpversion = httpversion + self.status_code = status_code + self.msg = msg + self.headers = headers + self.body = body + self.sslinfo = sslinfo + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return "Response(%s - %s)" % (self.status_code, self.msg) + + @property + def content(self): + return self.body + + +def is_valid_port(port): + if not 0 <= port <= 65535: + return False + return True + + +def is_valid_host(host): + try: + host.decode("idna") + except ValueError: + return False + if "\0" in host: + return None + return True + + +def parse_url(url): + """ + Returns a (scheme, host, port, path) tuple, or None on error. + + Checks that: + port is an integer 0-65535 + host is a valid IDNA-encoded hostname with no null-bytes + path is valid ASCII + """ + try: + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + except ValueError: + return None + if not scheme: + return None + if '@' in netloc: + # FIXME: Consider what to do with the discarded credentials here Most + # probably we should extend the signature to return these as a separate + # value. + _, netloc = string.rsplit(netloc, '@', maxsplit=1) + if ':' in netloc: + host, port = string.rsplit(netloc, ':', maxsplit=1) + try: + port = int(port) + except ValueError: + return None + else: + host = netloc + if scheme == "https": + port = 443 + else: + port = 80 + path = urlparse.urlunparse(('', '', path, params, query, fragment)) + if not path.startswith("/"): + path = "/" + path + if not is_valid_host(host): + return None + if not utils.isascii(path): + return None + if not is_valid_port(port): + return None + return scheme, host, port, path + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. 
+ """ + toks = [] + for i in headers[key]: + for j in i.split(","): + toks.append(j.strip()) + return toks diff --git a/netlib/http_status.py b/netlib/http/status_codes.py index dc09f465..dc09f465 100644 --- a/netlib/http_status.py +++ b/netlib/http/status_codes.py diff --git a/netlib/http_uastrings.py b/netlib/http/user_agents.py index e8681908..e8681908 100644 --- a/netlib/http_uastrings.py +++ b/netlib/http/user_agents.py diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py new file mode 100644 index 00000000..5acf7696 --- /dev/null +++ b/netlib/websockets/__init__.py @@ -0,0 +1,2 @@ +from frame import * +from protocol import * diff --git a/netlib/websockets.py b/netlib/websockets/frame.py index c45db4df..49d8ee10 100644 --- a/netlib/websockets.py +++ b/netlib/websockets/frame.py @@ -5,26 +5,14 @@ import os import struct import io -from . import utils, odict, tcp - -# Colleciton of utility functions that implement small portions of the RFC6455 -# WebSockets Protocol Useful for building WebSocket clients and servers. -# -# Emphassis is on readabilty, simplicity and modularity, not performance or -# completeness -# -# This is a work in progress and does not yet contain all the utilites need to -# create fully complient client/servers # -# Spec: https://tools.ietf.org/html/rfc6455 - -# The magic sha that websocket servers must know to prove they understand -# RFC6455 -websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' -VERSION = "13" +from .protocol import Masker +from netlib import utils, odict, tcp + +DEFAULT = object() + MAX_16_BIT_INT = (1 << 16) MAX_64_BIT_INT = (1 << 64) - OPCODE = utils.BiDi( CONTINUE=0x00, TEXT=0x01, @@ -34,101 +22,6 @@ OPCODE = utils.BiDi( PONG=0x0a ) - -class Masker(object): - - """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns - - Servers do not have to mask data they send to the client. - https://tools.ietf.org/html/rfc6455#section-5.3 - """ - - def __init__(self, key): - self.key = key - self.masks = [utils.bytes_to_int(byte) for byte in key] - self.offset = 0 - - def mask(self, offset, data): - result = "" - for c in data: - result += chr(ord(c) ^ self.masks[offset % 4]) - offset += 1 - return result - - def __call__(self, data): - ret = self.mask(self.offset, data) - self.offset += len(ret) - return ret - - -def client_handshake_headers(key=None, version=VERSION): - """ - Create the headers for a valid HTTP upgrade request. If Key is not - specified, it is generated, and can be found in sec-websocket-key in - the returned header set. - - Returns an instance of ODictCaseless - """ - if not key: - key = base64.b64encode(os.urandom(16)).decode('utf-8') - return odict.ODictCaseless([ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Key', key), - ('Sec-WebSocket-Version', version) - ]) - - -def server_handshake_headers(key): - """ - The server response is a valid HTTP 101 response. 
- """ - return odict.ODictCaseless( - [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', create_server_nonce(key)) - ] - ) - - -def make_length_code(length): - """ - A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is - larger than 125 - """ - if length <= 125: - return length - elif length >= 126 and length <= 65535: - return 126 - else: - return 127 - - -def check_client_handshake(headers): - if headers.get_first("upgrade", None) != "websocket": - return - return headers.get_first('sec-websocket-key') - - -def check_server_handshake(headers): - if headers.get_first("upgrade", None) != "websocket": - return - return headers.get_first('sec-websocket-accept') - - -def create_server_nonce(client_nonce): - return base64.b64encode( - hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') - ) - - -DEFAULT = object() - - class FrameHeader(object): def __init__( @@ -153,7 +46,7 @@ class FrameHeader(object): self.rsv3 = rsv3 if length_code is DEFAULT: - self.length_code = make_length_code(self.payload_length) + self.length_code = self._make_length_code(self.payload_length) else: self.length_code = length_code @@ -173,6 +66,20 @@ class FrameHeader(object): if self.masking_key and len(self.masking_key) != 4: raise ValueError("Masking key must be 4 bytes.") + @classmethod + def _make_length_code(self, length): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is + larger than 125 + """ + if length <= 125: + return length + elif length >= 126 and length <= 65535: + return 126 + else: + return 127 + def human_readable(self): vals = [ "ws frame:", diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py new file mode 100644 index 00000000..29b4db3d --- /dev/null +++ b/netlib/websockets/protocol.py @@ -0,0 +1,111 @@ +from __future__ import absolute_import +import base64 +import hashlib +import os +import struct +import io + +from netlib import utils, odict, tcp + +# Colleciton of utility functions that implement small portions of the RFC6455 +# WebSockets Protocol Useful for building WebSocket clients and servers. +# +# Emphassis is on readabilty, simplicity and modularity, not performance or +# completeness +# +# This is a work in progress and does not yet contain all the utilites need to +# create fully complient client/servers # +# Spec: https://tools.ietf.org/html/rfc6455 + +# The magic sha that websocket servers must know to prove they understand +# RFC6455 +websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +VERSION = "13" + +HEADER_WEBSOCKET_KEY = 'sec-websocket-key' +HEADER_WEBSOCKET_ACCEPT = 'sec-websocket-accept' +HEADER_WEBSOCKET_VERSION = 'sec-websocket-version' + +class Masker(object): + + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + Servers do not have to mask data they send to the client. 
+ https://tools.ietf.org/html/rfc6455#section-5.3 + """ + + def __init__(self, key): + self.key = key + self.masks = [utils.bytes_to_int(byte) for byte in key] + self.offset = 0 + + def mask(self, offset, data): + result = "" + for c in data: + result += chr(ord(c) ^ self.masks[offset % 4]) + offset += 1 + return result + + def __call__(self, data): + ret = self.mask(self.offset, data) + self.offset += len(ret) + return ret + +class WebsocketsProtocol(object): + + def __init__(self): + pass + + @classmethod + def client_handshake_headers(self, key=None, version=VERSION): + """ + Create the headers for a valid HTTP upgrade request. If Key is not + specified, it is generated, and can be found in sec-websocket-key in + the returned header set. + + Returns an instance of ODictCaseless + """ + if not key: + key = base64.b64encode(os.urandom(16)).decode('utf-8') + return odict.ODictCaseless([ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + (HEADER_WEBSOCKET_KEY, key), + (HEADER_WEBSOCKET_VERSION, version) + ]) + + @classmethod + def server_handshake_headers(self, key): + """ + The server response is a valid HTTP 101 response. + """ + return odict.ODictCaseless( + [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + (HEADER_WEBSOCKET_ACCEPT, self.create_server_nonce(key)) + ] + ) + + + @classmethod + def check_client_handshake(self, headers): + if headers.get_first("upgrade", None) != "websocket": + return + return headers.get_first(HEADER_WEBSOCKET_KEY) + + + @classmethod + def check_server_handshake(self, headers): + if headers.get_first("upgrade", None) != "websocket": + return + return headers.get_first(HEADER_WEBSOCKET_ACCEPT) + + + @classmethod + def create_server_nonce(self, client_nonce): + return base64.b64encode( + hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') + ) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index ad43dc19..99afe00e 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -21,9 +21,9 @@ class Flow(object): class Request(object): - def __init__(self, scheme, method, path, headers, content): + def __init__(self, scheme, method, path, headers, body): self.scheme, self.method, self.path = scheme, method, path - self.headers, self.content = headers, content + self.headers, self.body = headers, body def date_time_string(): @@ -58,7 +58,7 @@ class WSGIAdaptor(object): environ = { 'wsgi.version': (1, 0), 'wsgi.url_scheme': flow.request.scheme, - 'wsgi.input': cStringIO.StringIO(flow.request.content), + 'wsgi.input': cStringIO.StringIO(flow.request.body or ""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, diff --git a/test/http2/__init__.py b/test/http/__init__.py index e69de29b..e69de29b 100644 --- a/test/http2/__init__.py +++ b/test/http/__init__.py diff --git a/test/http/http1/__init__.py b/test/http/http1/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/http/http1/__init__.py diff --git a/test/test_http.py b/test/http/http1/test_protocol.py index 2ad81d24..dcebbd5e 100644 --- a/test/test_http.py +++ b/test/http/http1/test_protocol.py @@ -1,75 +1,78 @@ import cStringIO import textwrap import binascii + from netlib import http, odict, tcp -from . import tutils, tservers +from netlib.http.http1 import HTTP1Protocol +from ... 
import tutils, tservers + +def mock_protocol(data='', chunked=False): + rfile = cStringIO.StringIO(data) + wfile = cStringIO.StringIO() + return HTTP1Protocol(rfile=rfile, wfile=wfile) -def test_httperror(): - e = http.HttpError(404, "Not found") - assert str(e) def test_has_chunked_encoding(): h = odict.ODictCaseless() - assert not http.has_chunked_encoding(h) + assert not HTTP1Protocol.has_chunked_encoding(h) h["transfer-encoding"] = ["chunked"] - assert http.has_chunked_encoding(h) + assert HTTP1Protocol.has_chunked_encoding(h) def test_read_chunked(): - h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] - s = cStringIO.StringIO("1\r\na\r\n0\r\n") + data = "1\r\na\r\n0\r\n" tutils.raises( "malformed chunked body", - http.read_http_body, - s, h, None, "GET", None, True + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert http.read_http_body(s, h, None, "GET", None, True) == "a" + data = "1\r\na\r\n0\r\n\r\n" + assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a" - s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") - assert http.read_http_body(s, h, None, "GET", None, True) == "a" + data = "\r\n\r\n1\r\na\r\n0\r\n\r\n" + assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a" - s = cStringIO.StringIO("\r\n") + data = "\r\n" tutils.raises( "closed prematurely", - http.read_http_body, - s, h, None, "GET", None, True + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("1\r\nfoo") + data = "1\r\nfoo" tutils.raises( "malformed chunked body", - http.read_http_body, - s, h, None, "GET", None, True + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("foo\r\nfoo") + data = "foo\r\nfoo" tutils.raises( http.HttpError, - http.read_http_body, - s, h, None, "GET", None, True + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - tutils.raises("too large", http.read_http_body, s, h, 2, "GET", None, True) + data = "5\r\naaaaa\r\n0\r\n\r\n" + tutils.raises("too large", mock_protocol(data).read_http_body, h, 2, "GET", None, True) def test_connection_close(): h = odict.ODictCaseless() - assert http.connection_close((1, 0), h) - assert not http.connection_close((1, 1), h) + assert HTTP1Protocol.connection_close((1, 0), h) + assert not HTTP1Protocol.connection_close((1, 1), h) h["connection"] = ["keep-alive"] - assert not http.connection_close((1, 1), h) + assert not HTTP1Protocol.connection_close((1, 1), h) h["connection"] = ["close"] - assert http.connection_close((1, 1), h) + assert HTTP1Protocol.connection_close((1, 1), h) def test_get_header_tokens(): @@ -85,119 +88,119 @@ def test_get_header_tokens(): def test_read_http_body_request(): h = odict.ODictCaseless() - r = cStringIO.StringIO("testing") - assert http.read_http_body(r, h, None, "GET", None, True) == "" + data = "testing" + assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "" def test_read_http_body_response(): h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" + data = "testing" + assert mock_protocol(data, chunked=True).read_http_body(h, None, "GET", 200, False) == "testing" def test_read_http_body(): # test default case h = odict.ODictCaseless() h["content-length"] = [7] - s = cStringIO.StringIO("testing") - assert http.read_http_body(s, h, None, 
"GET", 200, False) == "testing" + data = "testing" + assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing" # test content length: invalid header h["content-length"] = ["foo"] - s = cStringIO.StringIO("testing") + data = "testing" tutils.raises( http.HttpError, - http.read_http_body, - s, h, None, "GET", 200, False + mock_protocol(data).read_http_body, + h, None, "GET", 200, False ) # test content length: invalid header #2 h["content-length"] = [-1] - s = cStringIO.StringIO("testing") + data = "testing" tutils.raises( http.HttpError, - http.read_http_body, - s, h, None, "GET", 200, False + mock_protocol(data).read_http_body, + h, None, "GET", 200, False ) # test content length: content length > actual content h["content-length"] = [5] - s = cStringIO.StringIO("testing") + data = "testing" tutils.raises( http.HttpError, - http.read_http_body, - s, h, 4, "GET", 200, False + mock_protocol(data).read_http_body, + h, 4, "GET", 200, False ) # test content length: content length < actual content - s = cStringIO.StringIO("testing") - assert len(http.read_http_body(s, h, None, "GET", 200, False)) == 5 + data = "testing" + assert len(mock_protocol(data).read_http_body(h, None, "GET", 200, False)) == 5 # test no content length: limit > actual content h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert len(http.read_http_body(s, h, 100, "GET", 200, False)) == 7 + data = "testing" + assert len(mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False)) == 7 # test no content length: limit < actual content - s = tcp.Reader(cStringIO.StringIO("testing")) + data = "testing" tutils.raises( http.HttpError, - http.read_http_body, - s, h, 4, "GET", 200, False + mock_protocol(data, chunked=True).read_http_body, + h, 4, "GET", 200, False ) # test chunked h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] - s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")) - assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" + data = "5\r\naaaaa\r\n0\r\n\r\n" + assert mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False) == "aaaaa" def test_expected_http_body_size(): # gibber in the content-length field h = odict.ODictCaseless() h["content-length"] = ["foo"] - assert http.expected_http_body_size(h, False, "GET", 200) is None + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None # negative number in the content-length field h = odict.ODictCaseless() h["content-length"] = ["-7"] - assert http.expected_http_body_size(h, False, "GET", 200) is None + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None # explicit length h = odict.ODictCaseless() h["content-length"] = ["5"] - assert http.expected_http_body_size(h, False, "GET", 200) == 5 + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == 5 # no length h = odict.ODictCaseless() - assert http.expected_http_body_size(h, False, "GET", 200) == -1 + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == -1 # no length request h = odict.ODictCaseless() - assert http.expected_http_body_size(h, True, "GET", None) == 0 + assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0 def test_parse_http_protocol(): - assert http.parse_http_protocol("HTTP/1.1") == (1, 1) - assert http.parse_http_protocol("HTTP/0.0") == (0, 0) - assert not http.parse_http_protocol("HTTP/a.1") - assert not http.parse_http_protocol("HTTP/1.a") - assert not http.parse_http_protocol("foo/0.0") 
- assert not http.parse_http_protocol("HTTP/x") + assert HTTP1Protocol._parse_http_protocol("HTTP/1.1") == (1, 1) + assert HTTP1Protocol._parse_http_protocol("HTTP/0.0") == (0, 0) + assert not HTTP1Protocol._parse_http_protocol("HTTP/a.1") + assert not HTTP1Protocol._parse_http_protocol("HTTP/1.a") + assert not HTTP1Protocol._parse_http_protocol("foo/0.0") + assert not HTTP1Protocol._parse_http_protocol("HTTP/x") def test_parse_init_connect(): - assert http.parse_init_connect("CONNECT host.com:443 HTTP/1.0") - assert not http.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") - assert not http.parse_init_connect("bogus") - assert not http.parse_init_connect("GET host.com:443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT host.com443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT host.com:443 foo/1.0") - assert not http.parse_init_connect("CONNECT host.com:foo HTTP/1.0") + assert HTTP1Protocol._parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:444444 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("bogus") + assert not HTTP1Protocol._parse_init_connect("GET host.com:443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT host.com443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:443 foo/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:foo HTTP/1.0") def test_parse_init_proxy(): u = "GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = http.parse_init_proxy(u) + m, s, h, po, pa, httpversion = HTTP1Protocol._parse_init_proxy(u) assert m == "GET" assert s == "http" assert h == "foo.com" @@ -206,27 +209,27 @@ def test_parse_init_proxy(): assert httpversion == (1, 1) u = "G\xfeET http://foo.com:8888/test HTTP/1.1" - assert not http.parse_init_proxy(u) + assert not HTTP1Protocol._parse_init_proxy(u) - assert not http.parse_init_proxy("invalid") - assert not http.parse_init_proxy("GET invalid HTTP/1.1") - assert not http.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + assert not HTTP1Protocol._parse_init_proxy("invalid") + assert not HTTP1Protocol._parse_init_proxy("GET invalid HTTP/1.1") + assert not HTTP1Protocol._parse_init_proxy("GET http://foo.com:8888/test foo/1.1") def test_parse_init_http(): u = "GET /test HTTP/1.1" - m, u, httpversion = http.parse_init_http(u) + m, u, httpversion = HTTP1Protocol._parse_init_http(u) assert m == "GET" assert u == "/test" assert httpversion == (1, 1) u = "G\xfeET /test HTTP/1.1" - assert not http.parse_init_http(u) + assert not HTTP1Protocol._parse_init_http(u) - assert not http.parse_init_http("invalid") - assert not http.parse_init_http("GET invalid HTTP/1.1") - assert not http.parse_init_http("GET /test foo/1.1") - assert not http.parse_init_http("GET /test\xc0 HTTP/1.1") + assert not HTTP1Protocol._parse_init_http("invalid") + assert not HTTP1Protocol._parse_init_http("GET invalid HTTP/1.1") + assert not HTTP1Protocol._parse_init_http("GET /test foo/1.1") + assert not HTTP1Protocol._parse_init_http("GET /test\xc0 HTTP/1.1") class TestReadHeaders: @@ -235,8 +238,7 @@ class TestReadHeaders: if not verbatim: data = textwrap.dedent(data) data = data.strip() - 
s = cStringIO.StringIO(data) - return http.read_headers(s) + return mock_protocol(data).read_headers() def test_read_simple(self): data = """ @@ -290,16 +292,15 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase): def test_no_content_length(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - resp = http.read_response(c.rfile, "GET", None) - assert resp.content == "bar\r\n\r\n" + resp = HTTP1Protocol(c).read_response("GET", None) + assert resp.body == "bar\r\n\r\n" def test_read_response(): def tst(data, method, limit, include_body=True): data = textwrap.dedent(data) - r = cStringIO.StringIO(data) - return http.read_response( - r, method, limit, include_body=include_body + return mock_protocol(data).read_response( + method, limit, include_body=include_body ) tutils.raises("server disconnect", tst, "", "GET", None) @@ -307,13 +308,13 @@ def test_read_response(): data = """ HTTP/1.1 200 OK """ - assert tst(data, "GET", None) == ( + assert tst(data, "GET", None) == http.Response( (1, 1), 200, 'OK', odict.ODictCaseless(), '' ) data = """ HTTP/1.1 200 """ - assert tst(data, "GET", None) == ( + assert tst(data, "GET", None) == http.Response( (1, 1), 200, '', odict.ODictCaseless(), '' ) data = """ @@ -330,7 +331,7 @@ def test_read_response(): HTTP/1.1 200 OK """ - assert tst(data, "GET", None) == ( + assert tst(data, "GET", None) == http.Response( (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' ) @@ -340,8 +341,8 @@ def test_read_response(): foo """ - assert tst(data, "GET", None)[4] == 'foo' - assert tst(data, "HEAD", None)[4] == '' + assert tst(data, "GET", None).body == 'foo' + assert tst(data, "HEAD", None).body == '' data = """ HTTP/1.1 200 OK @@ -357,74 +358,20 @@ def test_read_response(): foo """ - assert tst(data, "GET", None, include_body=False)[4] is None - - -def test_parse_url(): - assert not http.parse_url("") - - u = "http://foo.com:8888/test" - s, h, po, pa = http.parse_url(u) - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - - s, h, po, pa = http.parse_url("http://foo/bar") - assert s == "http" - assert h == "foo" - assert po == 80 - assert pa == "/bar" - - s, h, po, pa = http.parse_url("http://user:pass@foo/bar") - assert s == "http" - assert h == "foo" - assert po == 80 - assert pa == "/bar" - - s, h, po, pa = http.parse_url("http://foo") - assert pa == "/" - - s, h, po, pa = http.parse_url("https://foo") - assert po == 443 - - assert not http.parse_url("https://foo:bar") - assert not http.parse_url("https://foo:") - - # Invalid IDNA - assert not http.parse_url("http://\xfafoo") - # Invalid PATH - assert not http.parse_url("http:/\xc6/localhost:56121") - # Null byte in host - assert not http.parse_url("http://foo\0") - # Port out of range - assert not http.parse_url("http://foo:999999") - # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt - assert not http.parse_url('http://lo[calhost') - - -def test_parse_http_basic_auth(): - vals = ("basic", "foo", "bar") - assert http.parse_http_basic_auth( - http.assemble_http_basic_auth(*vals) - ) == vals - assert not http.parse_http_basic_auth("") - assert not http.parse_http_basic_auth("foo bar") - v = "basic " + binascii.b2a_base64("foo") - assert not http.parse_http_basic_auth(v) + assert tst(data, "GET", None, include_body=False).body is None def test_get_request_line(): - r = cStringIO.StringIO("\nfoo") - assert http.get_request_line(r) == "foo" - assert not http.get_request_line(r) + data = "\nfoo" + p = mock_protocol(data) + assert p._get_request_line() 
== "foo" + assert not p._get_request_line() class TestReadRequest(): def tst(self, data, **kwargs): - r = cStringIO.StringIO(data) - return http.read_request(r, **kwargs) + return mock_protocol(data).read_request(**kwargs) def test_invalid(self): tutils.raises( @@ -478,14 +425,15 @@ class TestReadRequest(): assert v.host == "foo.com" def test_expect(self): - w = cStringIO.StringIO() - r = cStringIO.StringIO( + data = "".join( "GET / HTTP/1.1\r\n" "Content-Length: 3\r\n" "Expect: 100-continue\r\n\r\n" - "foobar", + "foobar" ) - v = http.read_request(r, wfile=w) - assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" - assert v.content == "foo" - assert r.read(3) == "bar" + + p = mock_protocol(data) + v = p.read_request() + assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" + assert v.body == "foo" + assert p.tcp_handler.rfile.read(3) == "bar" diff --git a/test/http/http2/__init__.py b/test/http/http2/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/http/http2/__init__.py diff --git a/test/http2/test_frames.py b/test/http/http2/test_frames.py index 76a4b712..ee2edc39 100644 --- a/test/http2/test_frames.py +++ b/test/http/http2/test_frames.py @@ -2,7 +2,7 @@ import cStringIO from test import tutils from nose.tools import assert_equal from netlib import tcp -from netlib.http2.frame import * +from netlib.http.http2.frame import * def hex_to_file(data): diff --git a/test/http2/test_protocol.py b/test/http/http2/test_protocol.py index 5e2af34e..d3040266 100644 --- a/test/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -1,10 +1,9 @@ import OpenSSL -from netlib import http2 -from netlib import tcp -from netlib.http2.frame import * -from test import tutils -from .. import tservers +from netlib import tcp, odict +from netlib.http import http2 +from netlib.http.http2.frame import * +from ... 
import tutils, tservers class EchoHandler(tcp.BaseHandler): @@ -252,11 +251,13 @@ class TestReadResponse(tservers.ServerTestBase): c.convert_to_ssl() protocol = http2.HTTP2Protocol(c) - status, headers, body = protocol.read_response() + resp = protocol.read_response() - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'foobar' + assert resp.httpversion == "HTTP/2" + assert resp.status_code == "200" + assert resp.msg == "" + assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] + assert resp.body == b'foobar' class TestReadEmptyResponse(tservers.ServerTestBase): @@ -275,11 +276,14 @@ class TestReadEmptyResponse(tservers.ServerTestBase): c.convert_to_ssl() protocol = http2.HTTP2Protocol(c) - status, headers, body = protocol.read_response() + resp = protocol.read_response() - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'' + assert resp.stream_id + assert resp.httpversion == "HTTP/2" + assert resp.status_code == "200" + assert resp.msg == "" + assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] + assert resp.body == b'' class TestReadRequest(tservers.ServerTestBase): @@ -300,11 +304,11 @@ class TestReadRequest(tservers.ServerTestBase): c.convert_to_ssl() protocol = http2.HTTP2Protocol(c, is_server=True) - stream_id, headers, body = protocol.read_request() + resp = protocol.read_request() - assert stream_id - assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} - assert body == b'foobar' + assert resp.stream_id + assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] + assert resp.body == b'foobar' class TestCreateResponse(): diff --git a/test/test_http_auth.py b/test/http/test_authentication.py index c842925b..8f231643 100644 --- a/test/test_http_auth.py +++ b/test/http/test_authentication.py @@ -1,11 +1,25 @@ -from netlib import odict, http_auth, http -import tutils +import binascii + +from netlib import odict, http +from netlib.http import authentication +from .. 
import tutils + + +def test_parse_http_basic_auth(): + vals = ("basic", "foo", "bar") + assert http.authentication.parse_http_basic_auth( + http.authentication.assemble_http_basic_auth(*vals) + ) == vals + assert not http.authentication.parse_http_basic_auth("") + assert not http.authentication.parse_http_basic_auth("foo bar") + v = "basic " + binascii.b2a_base64("foo") + assert not http.authentication.parse_http_basic_auth(v) class TestPassManNonAnon: def test_simple(self): - p = http_auth.PassManNonAnon() + p = authentication.PassManNonAnon() assert not p.test("", "") assert p.test("user", "") @@ -15,14 +29,14 @@ class TestPassManHtpasswd: def test_file_errors(self): tutils.raises( "malformed htpasswd file", - http_auth.PassManHtpasswd, + authentication.PassManHtpasswd, tutils.test_data.path("data/server.crt")) def test_simple(self): - pm = http_auth.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) + pm = authentication.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) vals = ("basic", "test", "test") - http.assemble_http_basic_auth(*vals) + authentication.assemble_http_basic_auth(*vals) assert pm.test("test", "test") assert not pm.test("test", "foo") assert not pm.test("foo", "test") @@ -33,7 +47,7 @@ class TestPassManHtpasswd: class TestPassManSingleUser: def test_simple(self): - pm = http_auth.PassManSingleUser("test", "test") + pm = authentication.PassManSingleUser("test", "test") assert pm.test("test", "test") assert not pm.test("test", "foo") assert not pm.test("foo", "test") @@ -42,7 +56,7 @@ class TestPassManSingleUser: class TestNullProxyAuth: def test_simple(self): - na = http_auth.NullProxyAuth(http_auth.PassManNonAnon()) + na = authentication.NullProxyAuth(authentication.PassManNonAnon()) assert not na.auth_challenge_headers() assert na.authenticate("foo") na.clean({}) @@ -51,17 +65,17 @@ class TestNullProxyAuth: class TestBasicProxyAuth: def test_simple(self): - ba = http_auth.BasicProxyAuth(http_auth.PassManNonAnon(), "test") + ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") h = odict.ODictCaseless() assert ba.auth_challenge_headers() assert not ba.authenticate(h) def test_authenticate_clean(self): - ba = http_auth.BasicProxyAuth(http_auth.PassManNonAnon(), "test") + ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") hdrs = odict.ODictCaseless() vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] + hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] assert ba.authenticate(hdrs) ba.clean(hdrs) @@ -74,12 +88,12 @@ class TestBasicProxyAuth: assert not ba.authenticate(hdrs) vals = ("foo", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] + hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] assert not ba.authenticate(hdrs) - ba = http_auth.BasicProxyAuth(http_auth.PassMan(), "test") + ba = authentication.BasicProxyAuth(authentication.PassMan(), "test") vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] + hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] assert not ba.authenticate(hdrs) @@ -91,19 +105,19 @@ class TestAuthAction: def test_nonanonymous(self): m = Bunch() - aa = http_auth.NonanonymousAuthAction(None, "authenticator") + aa = authentication.NonanonymousAuthAction(None, "authenticator") aa(None, m, None, None) assert m.authenticator def test_singleuser(self): m = Bunch() - aa = http_auth.SingleuserAuthAction(None, 
"authenticator") + aa = authentication.SingleuserAuthAction(None, "authenticator") aa(None, m, "foo:bar", None) assert m.authenticator tutils.raises("invalid", aa, None, m, "foo", None) def test_httppasswd(self): m = Bunch() - aa = http_auth.HtpasswdAuthAction(None, "authenticator") + aa = authentication.HtpasswdAuthAction(None, "authenticator") aa(None, m, tutils.test_data.path("data/htpasswd"), None) assert m.authenticator diff --git a/test/test_http_cookies.py b/test/http/test_cookies.py index 070849cf..4f99593a 100644 --- a/test/test_http_cookies.py +++ b/test/http/test_cookies.py @@ -1,6 +1,6 @@ import nose.tools -from netlib import http_cookies +from netlib.http import cookies def test_read_token(): @@ -13,7 +13,7 @@ def test_read_token(): [(" foo=bar", 1), ("foo", 4)], ] for q, a in tokens: - nose.tools.eq_(http_cookies._read_token(*q), a) + nose.tools.eq_(cookies._read_token(*q), a) def test_read_quoted_string(): @@ -25,7 +25,7 @@ def test_read_quoted_string(): [('"fo\\\"" x', 0), ("fo\"", 6)], ] for q, a in tokens: - nose.tools.eq_(http_cookies._read_quoted_string(*q), a) + nose.tools.eq_(cookies._read_quoted_string(*q), a) def test_read_pairs(): @@ -60,7 +60,7 @@ def test_read_pairs(): ], ] for s, lst in vals: - ret, off = http_cookies._read_pairs(s) + ret, off = cookies._read_pairs(s) nose.tools.eq_(ret, lst) @@ -108,10 +108,10 @@ def test_pairs_roundtrips(): ] ] for s, lst in pairs: - ret, off = http_cookies._read_pairs(s) + ret, off = cookies._read_pairs(s) nose.tools.eq_(ret, lst) - s2 = http_cookies._format_pairs(lst) - ret, off = http_cookies._read_pairs(s2) + s2 = cookies._format_pairs(lst) + ret, off = cookies._read_pairs(s2) nose.tools.eq_(ret, lst) @@ -127,10 +127,10 @@ def test_cookie_roundtrips(): ], ] for s, lst in pairs: - ret = http_cookies.parse_cookie_header(s) + ret = cookies.parse_cookie_header(s) nose.tools.eq_(ret.lst, lst) - s2 = http_cookies.format_cookie_header(ret) - ret = http_cookies.parse_cookie_header(s2) + s2 = cookies.format_cookie_header(ret) + ret = cookies.parse_cookie_header(s2) nose.tools.eq_(ret.lst, lst) @@ -180,10 +180,10 @@ def test_parse_set_cookie_pairs(): ], ] for s, lst in pairs: - ret = http_cookies._parse_set_cookie_pairs(s) + ret = cookies._parse_set_cookie_pairs(s) nose.tools.eq_(ret, lst) - s2 = http_cookies._format_set_cookie_pairs(ret) - ret2 = http_cookies._parse_set_cookie_pairs(s2) + s2 = cookies._format_set_cookie_pairs(ret) + ret2 = cookies._parse_set_cookie_pairs(s2) nose.tools.eq_(ret2, lst) @@ -205,13 +205,13 @@ def test_parse_set_cookie_header(): ] ] for s, expected in vals: - ret = http_cookies.parse_set_cookie_header(s) + ret = cookies.parse_set_cookie_header(s) if expected: assert ret[0] == expected[0] assert ret[1] == expected[1] nose.tools.eq_(ret[2].lst, expected[2]) - s2 = http_cookies.format_set_cookie_header(*ret) - ret2 = http_cookies.parse_set_cookie_header(s2) + s2 = cookies.format_set_cookie_header(*ret) + ret2 = cookies.parse_set_cookie_header(s2) assert ret2[0] == expected[0] assert ret2[1] == expected[1] nose.tools.eq_(ret2[2].lst, expected[2]) diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py new file mode 100644 index 00000000..c4605302 --- /dev/null +++ b/test/http/test_semantics.py @@ -0,0 +1,54 @@ +import cStringIO +import textwrap +import binascii + +from netlib import http, odict, tcp +from netlib.http import http1 +from .. 
import tutils, tservers + +def test_httperror(): + e = http.exceptions.HttpError(404, "Not found") + assert str(e) + + +def test_parse_url(): + assert not http.parse_url("") + + u = "http://foo.com:8888/test" + s, h, po, pa = http.parse_url(u) + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + + s, h, po, pa = http.parse_url("http://foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = http.parse_url("http://user:pass@foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = http.parse_url("http://foo") + assert pa == "/" + + s, h, po, pa = http.parse_url("https://foo") + assert po == 443 + + assert not http.parse_url("https://foo:bar") + assert not http.parse_url("https://foo:") + + # Invalid IDNA + assert not http.parse_url("http://\xfafoo") + # Invalid PATH + assert not http.parse_url("http:/\xc6/localhost:56121") + # Null byte in host + assert not http.parse_url("http://foo\0") + # Port out of range + assert not http.parse_url("http://foo:999999") + # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt + assert not http.parse_url('http://lo[calhost') diff --git a/test/http/test_user_agents.py b/test/http/test_user_agents.py new file mode 100644 index 00000000..0bf1bba7 --- /dev/null +++ b/test/http/test_user_agents.py @@ -0,0 +1,6 @@ +from netlib.http import user_agents + + +def test_get_shortcut(): + assert user_agents.get_by_shortcut("c")[0] == "chrome" + assert not user_agents.get_by_shortcut("_") diff --git a/test/test_http_uastrings.py b/test/test_http_uastrings.py deleted file mode 100644 index 3fa4f359..00000000 --- a/test/test_http_uastrings.py +++ /dev/null @@ -1,6 +0,0 @@ -from netlib import http_uastrings - - -def test_get_shortcut(): - assert http_uastrings.get_by_shortcut("c")[0] == "chrome" - assert not http_uastrings.get_by_shortcut("_") diff --git a/test/websockets/__init__.py b/test/websockets/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/websockets/__init__.py diff --git a/test/test_websockets.py b/test/websockets/test_websockets.py index 9956543b..fb7ba39a 100644 --- a/test/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -2,8 +2,10 @@ import os from nose.tools import raises -from netlib import tcp, websockets, http -from . import tutils, tservers +from netlib import tcp, http, websockets +from netlib.http.exceptions import * +from netlib.http.http1 import HTTP1Protocol +from .. 
import tutils, tservers class WebSocketsEchoHandler(tcp.BaseHandler): @@ -12,6 +14,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): super(WebSocketsEchoHandler, self).__init__( connection, address, server ) + self.protocol = websockets.WebsocketsProtocol() self.handshake_done = False def handle(self): @@ -30,11 +33,14 @@ class WebSocketsEchoHandler(tcp.BaseHandler): frame.to_file(self.wfile) def handshake(self): - req = http.read_request(self.rfile) - key = websockets.check_client_handshake(req.headers) + http1_protocol = HTTP1Protocol(self) - self.wfile.write(http.response_preamble(101) + "\r\n") - headers = websockets.server_handshake_headers(key) + req = http1_protocol.read_request() + key = self.protocol.check_client_handshake(req.headers) + + preamble = http1_protocol.response_preamble(101) + self.wfile.write(preamble + "\r\n") + headers = self.protocol.server_handshake_headers(key) self.wfile.write(headers.format() + "\r\n") self.wfile.flush() self.handshake_done = True @@ -48,22 +54,25 @@ class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) + self.protocol = websockets.WebsocketsProtocol() self.client_nonce = None def connect(self): super(WebSocketsClient, self).connect() - preamble = http.request_preamble("GET", "/") + http1_protocol = HTTP1Protocol(self) + + preamble = http1_protocol.request_preamble("GET", "/") self.wfile.write(preamble + "\r\n") - headers = websockets.client_handshake_headers() + headers = self.protocol.client_handshake_headers() self.client_nonce = headers.get_first("sec-websocket-key") self.wfile.write(headers.format() + "\r\n") self.wfile.flush() - resp = http.read_response(self.rfile, "get", None) - server_nonce = websockets.check_server_handshake(resp.headers) + resp = http1_protocol.read_response("get", None) + server_nonce = self.protocol.check_server_handshake(resp.headers) - if not server_nonce == websockets.create_server_nonce( + if not server_nonce == self.protocol.create_server_nonce( self.client_nonce): self.close() @@ -78,6 +87,9 @@ class WebSocketsClient(tcp.TCPClient): class TestWebSockets(tservers.ServerTestBase): handler = WebSocketsEchoHandler + def __init__(self): + self.protocol = websockets.WebsocketsProtocol() + def random_bytes(self, n=100): return os.urandom(n) @@ -130,26 +142,29 @@ class TestWebSockets(tservers.ServerTestBase): assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes def test_check_server_handshake(self): - headers = websockets.server_handshake_headers("key") - assert websockets.check_server_handshake(headers) + headers = self.protocol.server_handshake_headers("key") + assert self.protocol.check_server_handshake(headers) headers["Upgrade"] = ["not_websocket"] - assert not websockets.check_server_handshake(headers) + assert not self.protocol.check_server_handshake(headers) def test_check_client_handshake(self): - headers = websockets.client_handshake_headers("key") - assert websockets.check_client_handshake(headers) == "key" + headers = self.protocol.client_handshake_headers("key") + assert self.protocol.check_client_handshake(headers) == "key" headers["Upgrade"] = ["not_websocket"] - assert not websockets.check_client_handshake(headers) + assert not self.protocol.check_client_handshake(headers) class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): - client_hs = http.read_request(self.rfile) - websockets.check_client_handshake(client_hs.headers) + http1_protocol = HTTP1Protocol(self) + + client_hs = 
http1_protocol.read_request()
+        self.protocol.check_client_handshake(client_hs.headers)
 
-        self.wfile.write(http.response_preamble(101) + "\r\n")
-        headers = websockets.server_handshake_headers("malformed key")
+        preamble = http1_protocol.response_preamble(101)
+        self.wfile.write(preamble + "\r\n")
+        headers = self.protocol.server_handshake_headers("malformed key")
         self.wfile.write(headers.format() + "\r\n")
         self.wfile.flush()
         self.handshake_done = True
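
The http1 test hunks above call a mock_protocol() helper that is defined earlier in test/http/http1/test_protocol.py and does not appear in this diff. Below is a minimal sketch of how such a helper can drive the relocated HTTP1Protocol against canned data, assuming only what the tests themselves rely on (a positional handler argument exposed as p.tcp_handler with rfile/wfile attributes); the FakeTCPHandler name is illustrative and not part of the commit, and the real helper may differ in detail.

import cStringIO

from netlib.http.http1 import HTTP1Protocol


class FakeTCPHandler(object):
    # Bare-bones stand-in exposing the attributes the protocol reads and writes,
    # as used via p.tcp_handler.rfile / p.tcp_handler.wfile in the tests above.
    def __init__(self, data=""):
        self.rfile = cStringIO.StringIO(data)
        self.wfile = cStringIO.StringIO()


def mock_protocol(data=""):
    # Hypothetical reconstruction of the helper referenced by the tests.
    return HTTP1Protocol(FakeTCPHandler(data))


p = mock_protocol(
    "GET / HTTP/1.1\r\n"
    "Content-Length: 3\r\n"
    "\r\n"
    "foo"
)
req = p.read_request()
assert req.body == "foo"  # read_request() now returns a semantic Request object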
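
The websockets changes follow the same pattern: the module-level handshake helpers have moved onto a WebsocketsProtocol instance, and the HTTP plumbing now goes through HTTP1Protocol. A condensed sketch of the client-side handshake as the updated WebSocketsClient performs it; the address is a placeholder and error handling is omitted.

from netlib import tcp, websockets
from netlib.http.http1 import HTTP1Protocol

c = tcp.TCPClient(("127.0.0.1", 9999))  # placeholder address/port
c.connect()

proto = websockets.WebsocketsProtocol()
http1_protocol = HTTP1Protocol(c)

# Request line and handshake headers, written through the client's wfile.
c.wfile.write(http1_protocol.request_preamble("GET", "/") + "\r\n")
headers = proto.client_handshake_headers()
client_nonce = headers.get_first("sec-websocket-key")
c.wfile.write(headers.format() + "\r\n")
c.wfile.flush()

# Read the 101 response and verify the server's accept nonce.
resp = http1_protocol.read_response("GET", None)
server_nonce = proto.check_server_handshake(resp.headers)
assert server_nonce == proto.create_server_nonce(client_nonce)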
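
More broadly, the import relocations exercised across these tests are one-to-one moves of the old top-level modules (http_auth, http_cookies, http_uastrings, http2) under the new netlib.http package, while parse_url and the HTTP exceptions remain reachable from netlib.http itself. The following sketch collects the new spellings and a few smoke checks mirroring the renamed tests; it is a summary of the diff above, not code from the commit.

from netlib import http, websockets
from netlib.http import authentication, cookies, user_agents, http2
from netlib.http.http1 import HTTP1Protocol
from netlib.http.exceptions import HttpError

assert str(HttpError(404, "Not found"))
assert http.parse_url("http://foo.com:8888/test") == ("http", "foo.com", 8888, "/test")
assert user_agents.get_by_shortcut("c")[0] == "chrome"

token = authentication.assemble_http_basic_auth("basic", "foo", "bar")
assert authentication.parse_http_basic_auth(token) == ("basic", "foo", "bar")

# Cookie helpers keep their names, only the module moved.
assert cookies.parse_cookie_header("one=uno").lst == [["one", "uno"]]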