From 106f7046d3862cb0e3cbb4f38335af0330b4e7e3 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 26 Sep 2015 00:39:04 +0200 Subject: refactor request model --- netlib/http/__init__.py | 5 +- netlib/http/headers.py | 2 +- netlib/http/http1/assemble.py | 65 ++++--- netlib/http/http1/read.py | 8 +- netlib/http/message.py | 146 +++++++++++++++ netlib/http/models.py | 233 ------------------------ netlib/http/request.py | 351 +++++++++++++++++++++++++++++++++++++ netlib/http/response.py | 3 + netlib/tutils.py | 4 +- netlib/utils.py | 15 +- test/http/http1/test_assemble.py | 12 +- test/http/http1/test_read.py | 8 +- test/http/test_models.py | 75 +++----- test/http/test_request.py | 3 + test/http/test_response.py | 3 + test/test_utils.py | 8 +- test/websockets/test_websockets.py | 2 +- 17 files changed, 598 insertions(+), 345 deletions(-) create mode 100644 netlib/http/message.py create mode 100644 netlib/http/request.py create mode 100644 netlib/http/response.py create mode 100644 test/http/test_request.py create mode 100644 test/http/test_response.py diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 0ccf6b32..e8c7ba20 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,12 +1,15 @@ from __future__ import absolute_import, print_function, division from .headers import Headers -from .models import Request, Response +from .message import decoded +from .request import Request +from .models import Response from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2 from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING from . import http1, http2 __all__ = [ "Headers", + "decoded", "Request", "Response", "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2", "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 613beb4f..47ea923b 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -27,7 +27,7 @@ else: _always_byte_args = always_byte_args("utf-8", "surrogateescape") -class Headers(MutableMapping, object): +class Headers(MutableMapping): """ Header class which allows both convenient access to individual headers as well as direct access to the underlying raw data. Provides a full dictionary interface. diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 88aeac05..864f6017 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -7,24 +7,24 @@ from .. import CONTENT_MISSING def assemble_request(request): - if request.body == CONTENT_MISSING: + if request.content == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_request_head(request) - body = b"".join(assemble_body(request.headers, [request.body])) + body = b"".join(assemble_body(request.headers, [request.data.content])) return head + body def assemble_request_head(request): - first_line = _assemble_request_line(request) - headers = _assemble_request_headers(request) + first_line = _assemble_request_line(request.data) + headers = _assemble_request_headers(request.data) return b"%s\r\n%s\r\n" % (first_line, headers) def assemble_response(response): - if response.body == CONTENT_MISSING: + if response.content == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_response_head(response) - body = b"".join(assemble_body(response.headers, [response.body])) + body = b"".join(assemble_body(response.headers, [response.content])) return head + body @@ -45,42 +45,49 @@ def assemble_body(headers, body_chunks): yield chunk -def _assemble_request_line(request, form=None): - if form is None: - form = request.form_out +def _assemble_request_line(request_data): + """ + Args: + request_data (netlib.http.request.RequestData) + """ + form = request_data.first_line_format if form == "relative": return b"%s %s %s" % ( - request.method, - request.path, - request.http_version + request_data.method, + request_data.path, + request_data.http_version ) elif form == "authority": return b"%s %s:%d %s" % ( - request.method, - request.host, - request.port, - request.http_version + request_data.method, + request_data.host, + request_data.port, + request_data.http_version ) elif form == "absolute": return b"%s %s://%s:%d%s %s" % ( - request.method, - request.scheme, - request.host, - request.port, - request.path, - request.http_version + request_data.method, + request_data.scheme, + request_data.host, + request_data.port, + request_data.path, + request_data.http_version ) - else: # pragma: nocover + else: raise RuntimeError("Invalid request form") -def _assemble_request_headers(request): - headers = request.headers.copy() - if "host" not in headers and request.scheme and request.host and request.port: +def _assemble_request_headers(request_data): + """ + Args: + request_data (netlib.http.request.RequestData) + """ + headers = request_data.headers.copy() + if "host" not in headers and request_data.scheme and request_data.host and request_data.port: headers["host"] = utils.hostport( - request.scheme, - request.host, - request.port + request_data.scheme, + request_data.host, + request_data.port ) return bytes(headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 4c898348..76721e06 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -11,7 +11,7 @@ from .. import Request, Response, Headers def read_request(rfile, body_size_limit=None): request = read_request_head(rfile) expected_body_size = expected_http_body_size(request) - request._body = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) + request.data.content = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) request.timestamp_end = time.time() return request @@ -155,7 +155,7 @@ def connection_close(http_version, headers): # If we don't have a Connection header, HTTP 1.1 connections are assumed to # be persistent - return http_version != b"HTTP/1.1" + return http_version != "HTTP/1.1" and http_version != b"HTTP/1.1" # FIXME: Remove one case. def expected_http_body_size(request, response=None): @@ -184,11 +184,11 @@ def expected_http_body_size(request, response=None): if headers.get("expect", "").lower() == "100-continue": return 0 else: - if request.method.upper() == b"HEAD": + if request.method.upper() == "HEAD": return 0 if 100 <= response_code <= 199: return 0 - if response_code == 200 and request.method.upper() == b"CONNECT": + if response_code == 200 and request.method.upper() == "CONNECT": return 0 if response_code in (204, 304): return 0 diff --git a/netlib/http/message.py b/netlib/http/message.py new file mode 100644 index 00000000..20497bd5 --- /dev/null +++ b/netlib/http/message.py @@ -0,0 +1,146 @@ +from __future__ import absolute_import, print_function, division + +import warnings + +import six + +from .. import encoding, utils + +if six.PY2: + _native = lambda x: x + _always_bytes = lambda x: x +else: + # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. + _native = lambda x: x.decode("utf-8", "surrogateescape") + _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") + + +class Message(object): + def __init__(self, data): + self.data = data + + def __eq__(self, other): + if isinstance(other, Message): + return self.data == other.data + return False + + def __ne__(self, other): + return not self.__eq__(other) + + @property + def http_version(self): + """ + Version string, e.g. "HTTP/1.1" + """ + return _native(self.data.http_version) + + @http_version.setter + def http_version(self, http_version): + self.data.http_version = _always_bytes(http_version) + + @property + def headers(self): + """ + Message headers object + + Returns: + netlib.http.Headers + """ + return self.data.headers + + @headers.setter + def headers(self, h): + self.data.headers = h + + @property + def timestamp_start(self): + """ + First byte timestamp + """ + return self.data.timestamp_start + + @timestamp_start.setter + def timestamp_start(self, timestamp_start): + self.data.timestamp_start = timestamp_start + + @property + def timestamp_end(self): + """ + Last byte timestamp + """ + return self.data.timestamp_end + + @timestamp_end.setter + def timestamp_end(self, timestamp_end): + self.data.timestamp_end = timestamp_end + + @property + def content(self): + """ + The raw (encoded) HTTP message body + + See also: :py:attr:`text` + """ + return self.data.content + + @content.setter + def content(self, content): + self.data.content = content + if isinstance(content, bytes): + self.headers["content-length"] = str(len(content)) + + @property + def text(self): + """ + The decoded HTTP message body. + Decoded contents are not cached, so this method is relatively expensive to call. + + See also: :py:attr:`content`, :py:class:`decoded` + """ + # This attribute should be called text, because that's what requests does. + raise NotImplementedError() + + @text.setter + def text(self, text): + raise NotImplementedError() + + @property + def body(self): + warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) + return self.content + + @body.setter + def body(self, body): + warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) + self.content = body + + +class decoded(object): + """ + A context manager that decodes a request or response, and then + re-encodes it with the same encoding after execution of the block. + + Example: + + .. code-block:: python + + with decoded(request): + request.content = request.content.replace("foo", "bar") + """ + + def __init__(self, message): + self.message = message + ce = message.headers.get("content-encoding") + if ce in encoding.ENCODINGS: + self.ce = ce + else: + self.ce = None + + def __enter__(self): + if self.ce: + if not self.message.decode(): + self.ce = None + + def __exit__(self, type, value, tb): + if self.ce: + self.message.encode(self.ce) \ No newline at end of file diff --git a/netlib/http/models.py b/netlib/http/models.py index 55664533..40f6e98c 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -47,239 +47,6 @@ class Message(object): return False -class Request(Message): - def __init__( - self, - form_in, - method, - scheme, - host, - port, - path, - http_version, - headers=None, - body=None, - timestamp_start=None, - timestamp_end=None, - form_out=None - ): - super(Request, self).__init__(http_version, headers, body, timestamp_start, timestamp_end) - - self.form_in = form_in - self.method = method - self.scheme = scheme - self.host = host - self.port = port - self.path = path - self.form_out = form_out or form_in - - def __repr__(self): - if self.host and self.port: - hostport = "{}:{}".format(native(self.host,"idna"), self.port) - else: - hostport = "" - path = self.path or "" - return "HTTPRequest({} {}{})".format( - self.method, hostport, path - ) - - def anticache(self): - """ - Modifies this request to remove headers that might produce a cached - response. That is, we remove ETags and If-Modified-Since headers. - """ - delheaders = [ - "if-modified-since", - "if-none-match", - ] - for i in delheaders: - self.headers.pop(i, None) - - def anticomp(self): - """ - Modifies this request to remove headers that will compress the - resource's data. - """ - self.headers["accept-encoding"] = "identity" - - def constrain_encoding(self): - """ - Limits the permissible Accept-Encoding values, based on what we can - decode appropriately. - """ - accept_encoding = self.headers.get("accept-encoding") - if accept_encoding: - self.headers["accept-encoding"] = ( - ', '.join( - e - for e in encoding.ENCODINGS - if e in accept_encoding - ) - ) - - def update_host_header(self): - """ - Update the host header to reflect the current target. - """ - self.headers["host"] = self.host - - def get_form(self): - """ - Retrieves the URL-encoded or multipart form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. - """ - if self.body: - if HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower(): - return self.get_form_urlencoded() - elif HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower(): - return self.get_form_multipart() - return ODict([]) - - def get_form_urlencoded(self): - """ - Retrieves the URL-encoded form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. - """ - if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower(): - return ODict(utils.urldecode(self.body)) - return ODict([]) - - def get_form_multipart(self): - if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower(): - return ODict( - utils.multipartdecode( - self.headers, - self.body)) - return ODict([]) - - def set_form_urlencoded(self, odict): - """ - Sets the body to the URL-encoded form data, and adds the - appropriate content-type header. Note that this will destory the - existing body if there is one. - """ - # FIXME: If there's an existing content-type header indicating a - # url-encoded form, leave it alone. - self.headers["content-type"] = HDR_FORM_URLENCODED - self.body = utils.urlencode(odict.lst) - - def get_path_components(self): - """ - Returns the path components of the URL as a list of strings. - - Components are unquoted. - """ - _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(native(i,"ascii")) for i in path.split(b"/") if i] - - def set_path_components(self, lst): - """ - Takes a list of strings, and sets the path component of the URL. - - Components are quoted. - """ - lst = [urllib.parse.quote(i, safe="") for i in lst] - path = always_bytes("/" + "/".join(lst)) - scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) - self.url = urllib.parse.urlunparse( - [scheme, netloc, path, params, query, fragment] - ) - - def get_query(self): - """ - Gets the request query string. Returns an ODict object. - """ - _, _, _, _, query, _ = urllib.parse.urlparse(self.url) - if query: - return ODict(utils.urldecode(query)) - return ODict([]) - - def set_query(self, odict): - """ - Takes an ODict object, and sets the request query string. - """ - scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) - query = utils.urlencode(odict.lst) - self.url = urllib.parse.urlunparse( - [scheme, netloc, path, params, query, fragment] - ) - - def pretty_host(self, hostheader): - """ - Heuristic to get the host of the request. - - Note that pretty_host() does not always return the TCP destination - of the request, e.g. if an upstream proxy is in place - - If hostheader is set to True, the Host: header will be used as - additional (and preferred) data source. This is handy in - transparent mode, where only the IO of the destination is known, - but not the resolved name. This is disabled by default, as an - attacker may spoof the host header to confuse an analyst. - """ - if hostheader and "host" in self.headers: - try: - return self.headers["host"] - except ValueError: - pass - if self.host: - return self.host.decode("idna") - - def pretty_url(self, hostheader): - if self.form_out == "authority": # upstream proxy mode - return b"%s:%d" % (always_bytes(self.pretty_host(hostheader)), self.port) - return utils.unparse_url(self.scheme, - self.pretty_host(hostheader), - self.port, - self.path) - - def get_cookies(self): - """ - Returns a possibly empty netlib.odict.ODict object. - """ - ret = ODict() - for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) - return ret - - def set_cookies(self, odict): - """ - Takes an netlib.odict.ODict object. Over-writes any existing Cookie - headers. - """ - v = cookies.format_cookie_header(odict) - self.headers["cookie"] = v - - @property - def url(self): - """ - Returns a URL string, constructed from the Request's URL components. - """ - return utils.unparse_url( - self.scheme, - self.host, - self.port, - self.path - ) - - @url.setter - def url(self, url): - """ - Parses a URL specification, and updates the Request's information - accordingly. - - Raises: - ValueError if the URL was invalid - """ - # TODO: Should handle incoming unicode here. - parts = utils.parse_url(url) - if not parts: - raise ValueError("Invalid URL: %s" % url) - self.scheme, self.host, self.port, self.path = parts - - class Response(Message): def __init__( self, diff --git a/netlib/http/request.py b/netlib/http/request.py new file mode 100644 index 00000000..6830ca40 --- /dev/null +++ b/netlib/http/request.py @@ -0,0 +1,351 @@ +from __future__ import absolute_import, print_function, division + +import warnings + +import six +from six.moves import urllib + +from netlib import utils +from netlib.http import cookies +from netlib.odict import ODict +from .. import encoding +from .headers import Headers +from .message import Message, _native, _always_bytes + + +class RequestData(object): + def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, + timestamp_start=None, timestamp_end=None): + if not headers: + headers = Headers() + assert isinstance(headers, Headers) + + self.first_line_format = first_line_format + self.method = method + self.scheme = scheme + self.host = host + self.port = port + self.path = path + self.http_version = http_version + self.headers = headers + self.content = content + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + def __eq__(self, other): + if isinstance(other, RequestData): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other): + return not self.__eq__(other) + + +class Request(Message): + """ + An HTTP request. + """ + def __init__(self, *args, **kwargs): + data = RequestData(*args, **kwargs) + super(Request, self).__init__(data) + + def __repr__(self): + if self.host and self.port: + hostport = "{}:{}".format(self.host, self.port) + else: + hostport = "" + path = self.path or "" + return "HTTPRequest({} {}{})".format( + self.method, hostport, path + ) + + @property + def first_line_format(self): + """ + HTTP request form as defined in `RFC7230 `_. + + origin-form and asterisk-form are subsumed as "relative". + """ + return self.data.first_line_format + + @first_line_format.setter + def first_line_format(self, first_line_format): + self.data.first_line_format = first_line_format + + @property + def method(self): + """ + HTTP request method, e.g. "GET". + """ + return _native(self.data.method) + + @method.setter + def method(self, method): + self.data.method = _always_bytes(method) + + @property + def scheme(self): + """ + HTTP request scheme, which should be "http" or "https". + """ + return _native(self.data.scheme) + + @scheme.setter + def scheme(self, scheme): + self.data.scheme = _always_bytes(scheme) + + @property + def host(self): + """ + Target host for the request. This may be directly taken in the request (e.g. "GET http://example.com/ HTTP/1.1") + or inferred from the proxy mode (e.g. an IP in transparent mode). + """ + + if six.PY2: + return self.data.host + + if not self.data.host: + return self.data.host + try: + return self.data.host.decode("idna") + except UnicodeError: + return self.data.host.decode("utf8", "surrogateescape") + + @host.setter + def host(self, host): + if isinstance(host, six.text_type): + try: + # There's no non-strict mode for IDNA encoding. + # We don't want this operation to fail though, so we try + # utf8 as a last resort. + host = host.encode("idna", "strict") + except UnicodeError: + host = host.encode("utf8", "surrogateescape") + + self.data.host = host + + # Update host header + if "host" in self.headers: + if host: + self.headers["host"] = host + else: + self.headers.pop("host") + + @property + def port(self): + """ + Target port + """ + return self.data.port + + @port.setter + def port(self, port): + self.data.port = port + + @property + def path(self): + """ + HTTP request path, e.g. "/index.html". + Guaranteed to start with a slash. + """ + return _native(self.data.path) + + @path.setter + def path(self, path): + self.data.path = _always_bytes(path) + + def anticache(self): + """ + Modifies this request to remove headers that might produce a cached + response. That is, we remove ETags and If-Modified-Since headers. + """ + delheaders = [ + "if-modified-since", + "if-none-match", + ] + for i in delheaders: + self.headers.pop(i, None) + + def anticomp(self): + """ + Modifies this request to remove headers that will compress the + resource's data. + """ + self.headers["accept-encoding"] = "identity" + + def constrain_encoding(self): + """ + Limits the permissible Accept-Encoding values, based on what we can + decode appropriately. + """ + accept_encoding = self.headers.get("accept-encoding") + if accept_encoding: + self.headers["accept-encoding"] = ( + ', '.join( + e + for e in encoding.ENCODINGS + if e in accept_encoding + ) + ) + + @property + def urlencoded_form(self): + """ + The URL-encoded form data as an ODict object. + None if there is no data or the content-type indicates non-form data. + """ + is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() + if self.content and is_valid_content_type: + return ODict(utils.urldecode(self.content)) + return None + + @urlencoded_form.setter + def urlencoded_form(self, odict): + """ + Sets the body to the URL-encoded form data, and adds the appropriate content-type header. + This will overwrite the existing content if there is one. + """ + self.headers["content-type"] = "application/x-www-form-urlencoded" + self.content = utils.urlencode(odict.lst) + + @property + def multipart_form(self): + """ + The multipart form data as an ODict object. + None if there is no data or the content-type indicates non-form data. + """ + is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() + if self.content and is_valid_content_type: + return ODict(utils.multipartdecode(self.headers,self.content)) + return None + + @multipart_form.setter + def multipart_form(self): + raise NotImplementedError() + + @property + def path_components(self): + """ + The URL's path components as a list of strings. + Components are unquoted. + """ + _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + return [urllib.parse.unquote(i) for i in path.split("/") if i] + + @path_components.setter + def path_components(self, components): + components = map(lambda x: urllib.parse.quote(x, safe=""), components) + path = "/" + "/".join(components) + scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + + @property + def query(self): + """ + The request query string as an ODict object. + None, if there is no query. + """ + _, _, _, _, query, _ = urllib.parse.urlparse(self.url) + if query: + return ODict(utils.urldecode(query)) + return None + + @query.setter + def query(self, odict): + query = utils.urlencode(odict.lst) + scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + + @property + def cookies(self): + """ + The request cookies. + An empty ODict object if the cookie monster ate them all. + """ + ret = ODict() + for i in self.headers.get_all("Cookie"): + ret.extend(cookies.parse_cookie_header(i)) + return ret + + @cookies.setter + def cookies(self, odict): + self.headers["cookie"] = cookies.format_cookie_header(odict) + + @property + def url(self): + """ + The URL string, constructed from the request's URL components + """ + return utils.unparse_url(self.scheme, self.host, self.port, self.path) + + @url.setter + def url(self, url): + self.scheme, self.host, self.port, self.path = utils.parse_url(url) + + @property + def pretty_host(self): + return self.headers.get("host", self.host) + + @property + def pretty_url(self): + if self.first_line_format == "authority": + return "%s:%d" % (self.pretty_host, self.port) + return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path) + + # Legacy + + def get_cookies(self): + warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) + return self.cookies + + def set_cookies(self, odict): + warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) + self.cookies = odict + + def get_query(self): + warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) + return self.query or ODict([]) + + def set_query(self, odict): + warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) + self.query = odict + + def get_path_components(self): + warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) + return self.path_components + + def set_path_components(self, lst): + warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) + self.path_components = lst + + def get_form_urlencoded(self): + warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) + return self.urlencoded_form or ODict([]) + + def set_form_urlencoded(self, odict): + warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) + self.urlencoded_form = odict + + def get_form_multipart(self): + warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) + return self.multipart_form or ODict([]) + + @property + def form_in(self): + warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) + return self.first_line_format + + @form_in.setter + def form_in(self, form_in): + warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) + self.first_line_format = form_in + + @property + def form_out(self): + warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) + return self.first_line_format + + @form_out.setter + def form_out(self, form_out): + warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) + self.first_line_format = form_out \ No newline at end of file diff --git a/netlib/http/response.py b/netlib/http/response.py new file mode 100644 index 00000000..02fac3df --- /dev/null +++ b/netlib/http/response.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import, print_function, division + +# TODO \ No newline at end of file diff --git a/netlib/tutils.py b/netlib/tutils.py index 1665a792..ff63c33c 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -98,7 +98,7 @@ def treq(**kwargs): netlib.http.Request """ default = dict( - form_in="relative", + first_line_format="relative", method=b"GET", scheme=b"http", host=b"address", @@ -106,7 +106,7 @@ def treq(**kwargs): path=b"/path", http_version=b"HTTP/1.1", headers=Headers(header="qvalue"), - body=b"content" + content=b"content" ) default.update(kwargs) return Request(**default) diff --git a/netlib/utils.py b/netlib/utils.py index 6f6d1ea0..3ec60890 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -273,22 +273,27 @@ def get_header_tokens(headers, key): return [token.strip() for token in tokens] -@always_byte_args() def hostport(scheme, host, port): """ Returns the host component, with a port specifcation if needed. """ - if (port, scheme) in [(80, b"http"), (443, b"https")]: + if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]: return host else: - return b"%s:%d" % (host, port) + if isinstance(host, six.binary_type): + return b"%s:%d" % (host, port) + else: + return "%s:%d" % (host, port) def unparse_url(scheme, host, port, path=""): """ - Returns a URL string, constructed from the specified compnents. + Returns a URL string, constructed from the specified components. + + Args: + All args must be str. """ - return b"%s://%s%s" % (scheme, hostport(scheme, host, port), path) + return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) def urlencode(s): diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py index 963e7549..47d11d33 100644 --- a/test/http/http1/test_assemble.py +++ b/test/http/http1/test_assemble.py @@ -20,7 +20,7 @@ def test_assemble_request(): ) with raises(HttpException): - assemble_request(treq(body=CONTENT_MISSING)) + assemble_request(treq(content=CONTENT_MISSING)) def test_assemble_request_head(): @@ -62,21 +62,21 @@ def test_assemble_body(): def test_assemble_request_line(): - assert _assemble_request_line(treq()) == b"GET /path HTTP/1.1" + assert _assemble_request_line(treq().data) == b"GET /path HTTP/1.1" - authority_request = treq(method=b"CONNECT", form_in="authority") + authority_request = treq(method=b"CONNECT", first_line_format="authority").data assert _assemble_request_line(authority_request) == b"CONNECT address:22 HTTP/1.1" - absolute_request = treq(form_in="absolute") + absolute_request = treq(first_line_format="absolute").data assert _assemble_request_line(absolute_request) == b"GET http://address:22/path HTTP/1.1" with raises(RuntimeError): - _assemble_request_line(treq(), "invalid_form") + _assemble_request_line(treq(first_line_format="invalid_form").data) def test_assemble_request_headers(): # https://github.com/mitmproxy/mitmproxy/issues/186 - r = treq(body=b"") + r = treq(content=b"") r.headers["Transfer-Encoding"] = "chunked" c = _assemble_request_headers(r) assert b"Transfer-Encoding" in c diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index 9eb02a24..c3f744bf 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -16,8 +16,8 @@ from netlib.tutils import treq, tresp, raises def test_read_request(): rfile = BytesIO(b"GET / HTTP/1.1\r\n\r\nskip") r = read_request(rfile) - assert r.method == b"GET" - assert r.body == b"" + assert r.method == "GET" + assert r.content == b"" assert r.timestamp_end assert rfile.read() == b"skip" @@ -32,7 +32,7 @@ def test_read_request_head(): rfile.reset_timestamps = Mock() rfile.first_byte_timestamp = 42 r = read_request_head(rfile) - assert r.method == b"GET" + assert r.method == "GET" assert r.headers["Content-Length"] == "4" assert r.body is None assert rfile.reset_timestamps.called @@ -283,7 +283,7 @@ class TestReadHeaders(object): def test_read_chunked(): - req = treq(body=None) + req = treq(content=None) req.headers["Transfer-Encoding"] = "chunked" data = b"1\r\na\r\n0\r\n" diff --git a/test/http/test_models.py b/test/http/test_models.py index 10e0795a..3c196847 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -39,6 +39,7 @@ class TestRequest(object): a = tutils.treq(timestamp_start=42, timestamp_end=43) b = tutils.treq(timestamp_start=42, timestamp_end=43) assert a == b + assert not a != b assert not a == 'foo' assert not b == 'foo' @@ -70,45 +71,17 @@ class TestRequest(object): req = tutils.treq() req.headers["Host"] = "" req.host = "foobar" - req.update_host_header() assert req.headers["Host"] == "foobar" - def test_get_form(self): - req = tutils.treq() - assert req.get_form() == ODict() - - @mock.patch("netlib.http.Request.get_form_multipart") - @mock.patch("netlib.http.Request.get_form_urlencoded") - def test_get_form_with_url_encoded(self, mock_method_urlencoded, mock_method_multipart): - req = tutils.treq() - assert req.get_form() == ODict() - - req = tutils.treq() - req.body = "foobar" - req.headers["Content-Type"] = HDR_FORM_URLENCODED - req.get_form() - assert req.get_form_urlencoded.called - assert not req.get_form_multipart.called - - @mock.patch("netlib.http.Request.get_form_multipart") - @mock.patch("netlib.http.Request.get_form_urlencoded") - def test_get_form_with_multipart(self, mock_method_urlencoded, mock_method_multipart): - req = tutils.treq() - req.body = "foobar" - req.headers["Content-Type"] = HDR_FORM_MULTIPART - req.get_form() - assert not req.get_form_urlencoded.called - assert req.get_form_multipart.called - def test_get_form_urlencoded(self): - req = tutils.treq(body="foobar") + req = tutils.treq(content="foobar") assert req.get_form_urlencoded() == ODict() req.headers["Content-Type"] = HDR_FORM_URLENCODED assert req.get_form_urlencoded() == ODict(utils.urldecode(req.body)) def test_get_form_multipart(self): - req = tutils.treq(body="foobar") + req = tutils.treq(content="foobar") assert req.get_form_multipart() == ODict() req.headers["Content-Type"] = HDR_FORM_MULTIPART @@ -140,7 +113,7 @@ class TestRequest(object): assert req.get_query().lst == [] req.url = "http://localhost:80/foo?bar=42" - assert req.get_query().lst == [(b"bar", b"42")] + assert req.get_query().lst == [("bar", "42")] def test_set_query(self): req = tutils.treq() @@ -148,31 +121,23 @@ class TestRequest(object): def test_pretty_host(self): r = tutils.treq() - assert r.pretty_host(True) == "address" - assert r.pretty_host(False) == "address" + assert r.pretty_host == "address" + assert r.host == "address" r.headers["host"] = "other" - assert r.pretty_host(True) == "other" - assert r.pretty_host(False) == "address" + assert r.pretty_host == "other" + assert r.host == "address" r.host = None - assert r.pretty_host(True) == "other" - assert r.pretty_host(False) is None - del r.headers["host"] - assert r.pretty_host(True) is None - assert r.pretty_host(False) is None + assert r.pretty_host is None + assert r.host is None # Invalid IDNA r.headers["host"] = ".disqus.com" - assert r.pretty_host(True) == ".disqus.com" + assert r.pretty_host == ".disqus.com" def test_pretty_url(self): - req = tutils.treq() - req.form_out = "authority" - assert req.pretty_url(True) == b"address:22" - assert req.pretty_url(False) == b"address:22" - - req.form_out = "relative" - assert req.pretty_url(True) == b"http://address:22/path" - assert req.pretty_url(False) == b"http://address:22/path" + req = tutils.treq(first_line_format="relative") + assert req.pretty_url == "http://address:22/path" + assert req.url == "http://address:22/path" def test_get_cookies_none(self): headers = Headers() @@ -212,12 +177,12 @@ class TestRequest(object): assert r.get_cookies()["cookiename"] == ["foo"] def test_set_url(self): - r = tutils.treq(form_in="absolute") + r = tutils.treq(first_line_format="absolute") r.url = b"https://otheraddress:42/ORLY" - assert r.scheme == b"https" - assert r.host == b"otheraddress" + assert r.scheme == "https" + assert r.host == "otheraddress" assert r.port == 42 - assert r.path == b"/ORLY" + assert r.path == "/ORLY" try: r.url = "//localhost:80/foo@bar" @@ -230,7 +195,7 @@ class TestRequest(object): # protocol = mock_protocol("OPTIONS * HTTP/1.1") # f.request = HTTPRequest.from_protocol(protocol) # - # assert f.request.form_in == "relative" + # assert f.request.first_line_format == "relative" # f.request.host = f.server_conn.address.host # f.request.port = f.server_conn.address.port # f.request.scheme = "http" @@ -266,7 +231,7 @@ class TestRequest(object): # "CONNECT address:22 HTTP/1.1\r\n" # "Host: address:22\r\n" # "Content-Length: 0\r\n\r\n") - # assert r.pretty_url(False) == "address:22" + # assert r.pretty_url == "address:22" # # def test_absolute_form_in(self): # protocol = mock_protocol("GET oops-no-protocol.com HTTP/1.1") diff --git a/test/http/test_request.py b/test/http/test_request.py new file mode 100644 index 00000000..02fac3df --- /dev/null +++ b/test/http/test_request.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import, print_function, division + +# TODO \ No newline at end of file diff --git a/test/http/test_response.py b/test/http/test_response.py new file mode 100644 index 00000000..02fac3df --- /dev/null +++ b/test/http/test_response.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import, print_function, division + +# TODO \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index 17636cc4..b096e5bc 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -84,10 +84,10 @@ def test_parse_url(): def test_unparse_url(): - assert utils.unparse_url(b"http", b"foo.com", 99, b"") == b"http://foo.com:99" - assert utils.unparse_url(b"http", b"foo.com", 80, b"/bar") == b"http://foo.com/bar" - assert utils.unparse_url(b"https", b"foo.com", 80, b"") == b"https://foo.com:80" - assert utils.unparse_url(b"https", b"foo.com", 443, b"") == b"https://foo.com" + assert utils.unparse_url("http", "foo.com", 99, "") == "http://foo.com:99" + assert utils.unparse_url("http", "foo.com", 80, "/bar") == "http://foo.com/bar" + assert utils.unparse_url("https", "foo.com", 80, "") == "https://foo.com:80" + assert utils.unparse_url("https", "foo.com", 443, "") == "https://foo.com" def test_urlencode(): diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 4ae4cf45..9a1e5d3d 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -68,7 +68,7 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(bytes(headers) + b"\r\n") self.wfile.flush() - resp = read_response(self.rfile, treq(method="GET")) + resp = read_response(self.rfile, treq(method=b"GET")) server_nonce = self.protocol.check_server_handshake(resp.headers) if not server_nonce == self.protocol.create_server_nonce(self.client_nonce): -- cgit v1.2.3 From 49ea8fc0ebcfe4861f099200044a553f092faec7 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 26 Sep 2015 17:39:50 +0200 Subject: refactor response model --- netlib/http/__init__.py | 15 ++-- netlib/http/headers.py | 26 +++---- netlib/http/http1/assemble.py | 16 ++-- netlib/http/http1/read.py | 2 +- netlib/http/http2/connections.py | 4 +- netlib/http/http2/frame.py | 3 - netlib/http/message.py | 64 +++++++++------- netlib/http/models.py | 112 ---------------------------- netlib/http/request.py | 155 +++++++++++++++++++++------------------ netlib/http/response.py | 124 ++++++++++++++++++++++++++++++- netlib/tutils.py | 6 +- netlib/wsgi.py | 6 +- test/http/http1/test_assemble.py | 4 +- test/http/http1/test_read.py | 6 +- test/http/http2/test_protocol.py | 12 +-- test/http/test_models.py | 12 ++- 16 files changed, 293 insertions(+), 274 deletions(-) delete mode 100644 netlib/http/models.py diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index e8c7ba20..fd632cd5 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,17 +1,14 @@ from __future__ import absolute_import, print_function, division -from .headers import Headers -from .message import decoded from .request import Request -from .models import Response -from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2 -from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING +from .response import Response +from .headers import Headers +from .message import decoded, CONTENT_MISSING from . import http1, http2 __all__ = [ + "Request", + "Response", "Headers", - "decoded", - "Request", "Response", - "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2", - "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING", + "decoded", "CONTENT_MISSING", "http1", "http2", ] diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 47ea923b..c79c3344 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -36,12 +36,8 @@ class Headers(MutableMapping): .. code-block:: python - # Create header from a list of (header_name, header_value) tuples - >>> h = Headers([ - ["Host","example.com"], - ["Accept","text/html"], - ["accept","application/xml"] - ]) + # Create headers with keyword arguments + >>> h = Headers(host="example.com", content_type="application/xml") # Headers mostly behave like a normal dict. >>> h["Host"] @@ -51,6 +47,13 @@ class Headers(MutableMapping): >>> h["host"] "example.com" + # Headers can also be creatd from a list of raw (header_name, header_value) byte tuples + >>> h = Headers([ + [b"Host",b"example.com"], + [b"Accept",b"text/html"], + [b"accept",b"application/xml"] + ]) + # Multiple headers are folded into a single header as per RFC7230 >>> h["Accept"] "text/html, application/xml" @@ -60,17 +63,14 @@ class Headers(MutableMapping): >>> h["Accept"] "application/text" - # str(h) returns a HTTP1 header block. - >>> print(h) + # bytes(h) returns a HTTP1 header block. + >>> print(bytes(h)) Host: example.com Accept: application/text # For full control, the raw header fields can be accessed >>> h.fields - # Headers can also be crated from keyword arguments - >>> h = Headers(host="example.com", content_type="application/xml") - Caveats: For use with the "Set-Cookie" header, see :py:meth:`get_all`. """ @@ -79,8 +79,8 @@ class Headers(MutableMapping): def __init__(self, fields=None, **headers): """ Args: - fields: (optional) list of ``(name, value)`` header tuples, - e.g. ``[("Host","example.com")]``. All names and values must be bytes. + fields: (optional) list of ``(name, value)`` header byte tuples, + e.g. ``[(b"Host", b"example.com")]``. All names and values must be bytes. **headers: Additional headers to set. Will overwrite existing values from `fields`. For convenience, underscores in header names will be transformed to dashes - this behaviour does not extend to other methods. diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 864f6017..785ee8d3 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -10,7 +10,7 @@ def assemble_request(request): if request.content == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_request_head(request) - body = b"".join(assemble_body(request.headers, [request.data.content])) + body = b"".join(assemble_body(request.data.headers, [request.data.content])) return head + body @@ -24,13 +24,13 @@ def assemble_response(response): if response.content == CONTENT_MISSING: raise HttpException("Cannot assemble flow with CONTENT_MISSING") head = assemble_response_head(response) - body = b"".join(assemble_body(response.headers, [response.content])) + body = b"".join(assemble_body(response.data.headers, [response.data.content])) return head + body def assemble_response_head(response): - first_line = _assemble_response_line(response) - headers = _assemble_response_headers(response) + first_line = _assemble_response_line(response.data) + headers = _assemble_response_headers(response.data) return b"%s\r\n%s\r\n" % (first_line, headers) @@ -92,11 +92,11 @@ def _assemble_request_headers(request_data): return bytes(headers) -def _assemble_response_line(response): +def _assemble_response_line(response_data): return b"%s %d %s" % ( - response.http_version, - response.status_code, - response.msg, + response_data.http_version, + response_data.status_code, + response_data.reason, ) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 76721e06..0d5e7f4b 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -50,7 +50,7 @@ def read_request_head(rfile): def read_response(rfile, request, body_size_limit=None): response = read_response_head(rfile) expected_body_size = expected_http_body_size(request, response) - response._body = b"".join(read_body(rfile, expected_body_size, body_size_limit)) + response.data.content = b"".join(read_body(rfile, expected_body_size, body_size_limit)) response.timestamp_end = time.time() return response diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index 5220d5d2..c493abe6 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -4,7 +4,7 @@ import time from hpack.hpack import Encoder, Decoder from ... import utils -from .. import Headers, Response, Request, ALPN_PROTO_H2 +from .. import Headers, Response, Request from . import frame @@ -283,7 +283,7 @@ class HTTP2Protocol(object): def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != ALPN_PROTO_H2: + if alp != b'h2': raise NotImplementedError( "HTTP2Protocol can not handle unknown ALP: %s" % alp) return True diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py index cb2cde99..188629d4 100644 --- a/netlib/http/http2/frame.py +++ b/netlib/http/http2/frame.py @@ -25,9 +25,6 @@ ERROR_CODES = BiDi( CLIENT_CONNECTION_PREFACE = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" -ALPN_PROTO_H2 = b'h2' - - class Frame(object): """ diff --git a/netlib/http/message.py b/netlib/http/message.py index 20497bd5..ee138746 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -6,11 +6,14 @@ import six from .. import encoding, utils + +CONTENT_MISSING = 0 + if six.PY2: _native = lambda x: x _always_bytes = lambda x: x else: - # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. + # While the HTTP head _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. _native = lambda x: x.decode("utf-8", "surrogateescape") _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") @@ -27,17 +30,6 @@ class Message(object): def __ne__(self, other): return not self.__eq__(other) - @property - def http_version(self): - """ - Version string, e.g. "HTTP/1.1" - """ - return _native(self.data.http_version) - - @http_version.setter - def http_version(self, http_version): - self.data.http_version = _always_bytes(http_version) - @property def headers(self): """ @@ -52,6 +44,32 @@ class Message(object): def headers(self, h): self.data.headers = h + @property + def content(self): + """ + The raw (encoded) HTTP message body + + See also: :py:attr:`text` + """ + return self.data.content + + @content.setter + def content(self, content): + self.data.content = content + if isinstance(content, bytes): + self.headers["content-length"] = str(len(content)) + + @property + def http_version(self): + """ + Version string, e.g. "HTTP/1.1" + """ + return _native(self.data.http_version) + + @http_version.setter + def http_version(self, http_version): + self.data.http_version = _always_bytes(http_version) + @property def timestamp_start(self): """ @@ -74,26 +92,14 @@ class Message(object): def timestamp_end(self, timestamp_end): self.data.timestamp_end = timestamp_end - @property - def content(self): - """ - The raw (encoded) HTTP message body - - See also: :py:attr:`text` - """ - return self.data.content - - @content.setter - def content(self, content): - self.data.content = content - if isinstance(content, bytes): - self.headers["content-length"] = str(len(content)) - @property def text(self): """ The decoded HTTP message body. - Decoded contents are not cached, so this method is relatively expensive to call. + Decoded contents are not cached, so accessing this attribute repeatedly is relatively expensive. + + .. note:: + This is not implemented yet. See also: :py:attr:`content`, :py:class:`decoded` """ @@ -104,6 +110,8 @@ class Message(object): def text(self, text): raise NotImplementedError() + # Legacy + @property def body(self): warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) diff --git a/netlib/http/models.py b/netlib/http/models.py deleted file mode 100644 index 40f6e98c..00000000 --- a/netlib/http/models.py +++ /dev/null @@ -1,112 +0,0 @@ - - -from ..odict import ODict -from .. import utils, encoding -from ..utils import always_bytes, native -from . import cookies -from .headers import Headers - -from six.moves import urllib - -# TODO: Move somewhere else? -ALPN_PROTO_HTTP1 = b'http/1.1' -ALPN_PROTO_H2 = b'h2' -HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" -HDR_FORM_MULTIPART = "multipart/form-data" - -CONTENT_MISSING = 0 - - -class Message(object): - def __init__(self, http_version, headers, body, timestamp_start, timestamp_end): - self.http_version = http_version - if not headers: - headers = Headers() - assert isinstance(headers, Headers) - self.headers = headers - - self._body = body - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - - @property - def body(self): - return self._body - - @body.setter - def body(self, body): - self._body = body - if isinstance(body, bytes): - self.headers["content-length"] = str(len(body)).encode() - - content = body - - def __eq__(self, other): - if isinstance(other, Message): - return self.__dict__ == other.__dict__ - return False - - -class Response(Message): - def __init__( - self, - http_version, - status_code, - msg=None, - headers=None, - body=None, - timestamp_start=None, - timestamp_end=None, - ): - super(Response, self).__init__(http_version, headers, body, timestamp_start, timestamp_end) - self.status_code = status_code - self.msg = msg - - def __repr__(self): - # return "Response(%s - %s)" % (self.status_code, self.msg) - - if self.body: - size = utils.pretty_size(len(self.body)) - else: - size = "content missing" - # TODO: Remove "(unknown content type, content missing)" edge-case - return "".format( - status_code=self.status_code, - msg=self.msg, - contenttype=self.headers.get("content-type", "unknown content type"), - size=size) - - def get_cookies(self): - """ - Get the contents of all Set-Cookie headers. - - Returns a possibly empty ODict, where keys are cookie name strings, - and values are [value, attr] lists. Value is a string, and attr is - an ODictCaseless containing cookie attributes. Within attrs, unary - attributes (e.g. HTTPOnly) are indicated by a Null value. - """ - ret = [] - for header in self.headers.get_all("set-cookie"): - v = cookies.parse_set_cookie_header(header) - if v: - name, value, attrs = v - ret.append([name, [value, attrs]]) - return ODict(ret) - - def set_cookies(self, odict): - """ - Set the Set-Cookie headers on this response, over-writing existing - headers. - - Accepts an ODict of the same format as that returned by get_cookies. - """ - values = [] - for i in odict.lst: - values.append( - cookies.format_set_cookie_header( - i[0], - i[1][0], - i[1][1] - ) - ) - self.headers.set_all("set-cookie", values) diff --git a/netlib/http/request.py b/netlib/http/request.py index 6830ca40..f8a3b5b9 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -55,7 +55,7 @@ class Request(Message): else: hostport = "" path = self.path or "" - return "HTTPRequest({} {}{})".format( + return "Request({} {}{})".format( self.method, hostport, path ) @@ -97,7 +97,8 @@ class Request(Message): @property def host(self): """ - Target host for the request. This may be directly taken in the request (e.g. "GET http://example.com/ HTTP/1.1") + Target host. This may be parsed from the raw request + (e.g. from a ``GET http://example.com/ HTTP/1.1`` request line) or inferred from the proxy mode (e.g. an IP in transparent mode). """ @@ -154,6 +155,83 @@ class Request(Message): def path(self, path): self.data.path = _always_bytes(path) + @property + def url(self): + """ + The URL string, constructed from the request's URL components + """ + return utils.unparse_url(self.scheme, self.host, self.port, self.path) + + @url.setter + def url(self, url): + self.scheme, self.host, self.port, self.path = utils.parse_url(url) + + @property + def pretty_host(self): + """ + Similar to :py:attr:`host`, but using the Host headers as an additional preferred data source. + This is useful in transparent mode where :py:attr:`host` is only an IP address, + but may not reflect the actual destination as the Host header could be spoofed. + """ + return self.headers.get("host", self.host) + + @property + def pretty_url(self): + """ + Like :py:attr:`url`, but using :py:attr:`pretty_host` instead of :py:attr:`host`. + """ + if self.first_line_format == "authority": + return "%s:%d" % (self.pretty_host, self.port) + return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path) + + @property + def query(self): + """ + The request query string as an :py:class:`ODict` object. + None, if there is no query. + """ + _, _, _, _, query, _ = urllib.parse.urlparse(self.url) + if query: + return ODict(utils.urldecode(query)) + return None + + @query.setter + def query(self, odict): + query = utils.urlencode(odict.lst) + scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + + @property + def cookies(self): + """ + The request cookies. + An empty :py:class:`ODict` object if the cookie monster ate them all. + """ + ret = ODict() + for i in self.headers.get_all("Cookie"): + ret.extend(cookies.parse_cookie_header(i)) + return ret + + @cookies.setter + def cookies(self, odict): + self.headers["cookie"] = cookies.format_cookie_header(odict) + + @property + def path_components(self): + """ + The URL's path components as a list of strings. + Components are unquoted. + """ + _, _, path, _, _, _ = urllib.parse.urlparse(self.url) + return [urllib.parse.unquote(i) for i in path.split("/") if i] + + @path_components.setter + def path_components(self, components): + components = map(lambda x: urllib.parse.quote(x, safe=""), components) + path = "/" + "/".join(components) + scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) + self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) + def anticache(self): """ Modifies this request to remove headers that might produce a cached @@ -191,7 +269,7 @@ class Request(Message): @property def urlencoded_form(self): """ - The URL-encoded form data as an ODict object. + The URL-encoded form data as an :py:class:`ODict` object. None if there is no data or the content-type indicates non-form data. """ is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() @@ -211,7 +289,7 @@ class Request(Message): @property def multipart_form(self): """ - The multipart form data as an ODict object. + The multipart form data as an :py:class:`ODict` object. None if there is no data or the content-type indicates non-form data. """ is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() @@ -223,75 +301,6 @@ class Request(Message): def multipart_form(self): raise NotImplementedError() - @property - def path_components(self): - """ - The URL's path components as a list of strings. - Components are unquoted. - """ - _, _, path, _, _, _ = urllib.parse.urlparse(self.url) - return [urllib.parse.unquote(i) for i in path.split("/") if i] - - @path_components.setter - def path_components(self, components): - components = map(lambda x: urllib.parse.quote(x, safe=""), components) - path = "/" + "/".join(components) - scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url) - self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) - - @property - def query(self): - """ - The request query string as an ODict object. - None, if there is no query. - """ - _, _, _, _, query, _ = urllib.parse.urlparse(self.url) - if query: - return ODict(utils.urldecode(query)) - return None - - @query.setter - def query(self, odict): - query = utils.urlencode(odict.lst) - scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url) - self.url = urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]) - - @property - def cookies(self): - """ - The request cookies. - An empty ODict object if the cookie monster ate them all. - """ - ret = ODict() - for i in self.headers.get_all("Cookie"): - ret.extend(cookies.parse_cookie_header(i)) - return ret - - @cookies.setter - def cookies(self, odict): - self.headers["cookie"] = cookies.format_cookie_header(odict) - - @property - def url(self): - """ - The URL string, constructed from the request's URL components - """ - return utils.unparse_url(self.scheme, self.host, self.port, self.path) - - @url.setter - def url(self, url): - self.scheme, self.host, self.port, self.path = utils.parse_url(url) - - @property - def pretty_host(self): - return self.headers.get("host", self.host) - - @property - def pretty_url(self): - if self.first_line_format == "authority": - return "%s:%d" % (self.pretty_host, self.port) - return utils.unparse_url(self.scheme, self.pretty_host, self.port, self.path) - # Legacy def get_cookies(self): diff --git a/netlib/http/response.py b/netlib/http/response.py index 02fac3df..7d64243d 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -1,3 +1,125 @@ from __future__ import absolute_import, print_function, division -# TODO \ No newline at end of file +import warnings + +from . import cookies +from .headers import Headers +from .message import Message, _native, _always_bytes +from .. import utils +from ..odict import ODict + + +class ResponseData(object): + def __init__(self, http_version, status_code, reason=None, headers=None, content=None, + timestamp_start=None, timestamp_end=None): + if not headers: + headers = Headers() + assert isinstance(headers, Headers) + + self.http_version = http_version + self.status_code = status_code + self.reason = reason + self.headers = headers + self.content = content + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end + + def __eq__(self, other): + if isinstance(other, ResponseData): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other): + return not self.__eq__(other) + + +class Response(Message): + """ + An HTTP response. + """ + def __init__(self, *args, **kwargs): + data = ResponseData(*args, **kwargs) + super(Response, self).__init__(data) + + def __repr__(self): + if self.content: + details = "{}, {}".format( + self.headers.get("content-type", "unknown content type"), + utils.pretty_size(len(self.content)) + ) + else: + details = "content missing" + return "Response({status_code} {reason}, {details})".format( + status_code=self.status_code, + reason=self.reason, + details=details + ) + + @property + def status_code(self): + """ + HTTP Status Code, e.g. ``200``. + """ + return self.data.status_code + + @status_code.setter + def status_code(self, status_code): + self.data.status_code = status_code + + @property + def reason(self): + """ + HTTP Reason Phrase, e.g. "Not Found". + This is always :py:obj:`None` for HTTP2 requests, because HTTP2 responses do not contain a reason phrase. + """ + return _native(self.data.reason) + + @reason.setter + def reason(self, reason): + self.data.reason = _always_bytes(reason) + + @property + def cookies(self): + """ + Get the contents of all Set-Cookie headers. + + A possibly empty :py:class:`ODict`, where keys are cookie name strings, + and values are [value, attr] lists. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary + attributes (e.g. HTTPOnly) are indicated by a Null value. + """ + ret = [] + for header in self.headers.get_all("set-cookie"): + v = cookies.parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append([name, [value, attrs]]) + return ODict(ret) + + @cookies.setter + def cookies(self, odict): + values = [] + for i in odict.lst: + header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1]) + values.append(header) + self.headers.set_all("set-cookie", values) + + # Legacy + + def get_cookies(self): + warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) + return self.cookies + + def set_cookies(self, odict): + warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) + self.cookies = odict + + @property + def msg(self): + warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) + return self.reason + + @msg.setter + def msg(self, reason): + warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) + self.reason = reason diff --git a/netlib/tutils.py b/netlib/tutils.py index ff63c33c..e16f1a76 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -120,9 +120,9 @@ def tresp(**kwargs): default = dict( http_version=b"HTTP/1.1", status_code=200, - msg=b"OK", - headers=Headers(header_response=b"svalue"), - body=b"message", + reason=b"OK", + headers=Headers(header_response="svalue"), + content=b"message", timestamp_start=time.time(), timestamp_end=time.time(), ) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 4fcd5178..df248a19 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -25,9 +25,9 @@ class Flow(object): class Request(object): - def __init__(self, scheme, method, path, http_version, headers, body): + def __init__(self, scheme, method, path, http_version, headers, content): self.scheme, self.method, self.path = scheme, method, path - self.headers, self.body = headers, body + self.headers, self.content = headers, content self.http_version = http_version @@ -64,7 +64,7 @@ class WSGIAdaptor(object): environ = { 'wsgi.version': (1, 0), 'wsgi.url_scheme': native(flow.request.scheme, "latin-1"), - 'wsgi.input': BytesIO(flow.request.body or b""), + 'wsgi.input': BytesIO(flow.request.content or b""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py index 47d11d33..460e22c5 100644 --- a/test/http/http1/test_assemble.py +++ b/test/http/http1/test_assemble.py @@ -40,7 +40,7 @@ def test_assemble_response(): ) with raises(HttpException): - assemble_response(tresp(body=CONTENT_MISSING)) + assemble_response(tresp(content=CONTENT_MISSING)) def test_assemble_response_head(): @@ -86,7 +86,7 @@ def test_assemble_request_headers(): def test_assemble_response_headers(): # https://github.com/mitmproxy/mitmproxy/issues/186 - r = tresp(body=b"") + r = tresp(content=b"") r.headers["Transfer-Encoding"] = "chunked" c = _assemble_response_headers(r) assert b"Transfer-Encoding" in c diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index c3f744bf..fadfe446 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -34,7 +34,7 @@ def test_read_request_head(): r = read_request_head(rfile) assert r.method == "GET" assert r.headers["Content-Length"] == "4" - assert r.body is None + assert r.content is None assert rfile.reset_timestamps.called assert r.timestamp_start == 42 assert rfile.read() == b"skip" @@ -45,7 +45,7 @@ def test_read_response(): rfile = BytesIO(b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody") r = read_response(rfile, req) assert r.status_code == 418 - assert r.body == b"body" + assert r.content == b"body" assert r.timestamp_end @@ -61,7 +61,7 @@ def test_read_response_head(): r = read_response_head(rfile) assert r.status_code == 418 assert r.headers["Content-Length"] == "4" - assert r.body is None + assert r.content is None assert rfile.reset_timestamps.called assert r.timestamp_start == 42 assert rfile.read() == b"skip" diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index a55941e0..6bda96f5 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -65,7 +65,7 @@ class TestProtocol: class TestCheckALPNMatch(tservers.ServerTestBase): handler = EchoHandler ssl = dict( - alpn_select=ALPN_PROTO_H2, + alpn_select=b'h2', ) if OpenSSL._util.lib.Cryptography_HAS_ALPN: @@ -73,7 +73,7 @@ class TestCheckALPNMatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[ALPN_PROTO_H2]) + c.convert_to_ssl(alpn_protos=[b'h2']) protocol = HTTP2Protocol(c) assert protocol.check_alpn() @@ -89,7 +89,7 @@ class TestCheckALPNMismatch(tservers.ServerTestBase): def test_check_alpn(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - c.convert_to_ssl(alpn_protos=[ALPN_PROTO_H2]) + c.convert_to_ssl(alpn_protos=[b'h2']) protocol = HTTP2Protocol(c) tutils.raises(NotImplementedError, protocol.check_alpn) @@ -311,7 +311,7 @@ class TestReadRequest(tservers.ServerTestBase): assert req.stream_id assert req.headers.fields == [[':method', 'GET'], [':path', '/'], [':scheme', 'https']] - assert req.body == b'foobar' + assert req.content == b'foobar' class TestReadRequestRelative(tservers.ServerTestBase): @@ -417,7 +417,7 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.status_code == 200 assert resp.msg == "" assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] - assert resp.body == b'foobar' + assert resp.content == b'foobar' assert resp.timestamp_end @@ -444,7 +444,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): assert resp.status_code == 200 assert resp.msg == "" assert resp.headers.fields == [[':status', '200'], ['etag', 'foobar']] - assert resp.body == b'' + assert resp.content == b'' class TestAssembleRequest(object): diff --git a/test/http/test_models.py b/test/http/test_models.py index 3c196847..aa267944 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -3,9 +3,7 @@ import mock from netlib import tutils from netlib import utils from netlib.odict import ODict, ODictCaseless -from netlib.http import Request, Response, Headers, CONTENT_MISSING, HDR_FORM_URLENCODED, \ - HDR_FORM_MULTIPART - +from netlib.http import Request, Response, Headers, CONTENT_MISSING class TestRequest(object): def test_repr(self): @@ -77,14 +75,14 @@ class TestRequest(object): req = tutils.treq(content="foobar") assert req.get_form_urlencoded() == ODict() - req.headers["Content-Type"] = HDR_FORM_URLENCODED + req.headers["Content-Type"] = "application/x-www-form-urlencoded" assert req.get_form_urlencoded() == ODict(utils.urldecode(req.body)) def test_get_form_multipart(self): req = tutils.treq(content="foobar") assert req.get_form_multipart() == ODict() - req.headers["Content-Type"] = HDR_FORM_MULTIPART + req.headers["Content-Type"] = "multipart/form-data" assert req.get_form_multipart() == ODict( utils.multipartdecode( req.headers, @@ -95,7 +93,7 @@ class TestRequest(object): def test_set_form_urlencoded(self): req = tutils.treq() req.set_form_urlencoded(ODict([('foo', 'bar'), ('rab', 'oof')])) - assert req.headers["Content-Type"] == HDR_FORM_URLENCODED + assert req.headers["Content-Type"] == "application/x-www-form-urlencoded" assert req.body def test_get_path_components(self): @@ -298,7 +296,7 @@ class TestResponse(object): assert "unknown content type" in repr(r) r.headers["content-type"] = "foo" assert "foo" in repr(r) - assert repr(tutils.tresp(body=CONTENT_MISSING)) + assert repr(tutils.tresp(content=CONTENT_MISSING)) def test_get_cookies_none(self): resp = tutils.tresp() -- cgit v1.2.3 From 466888b01a361e46fb3d4e66afa2c6a0fd168c8e Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 26 Sep 2015 20:07:11 +0200 Subject: improve request tests, coverage++ --- netlib/encoding.py | 4 + netlib/http/headers.py | 8 +- netlib/http/message.py | 42 ++++++- netlib/http/request.py | 28 ++--- netlib/http/response.py | 8 +- netlib/http/status_codes.py | 4 +- test/http/http1/test_read.py | 17 ++- test/http/test_headers.py | 3 + test/http/test_message.py | 136 +++++++++++++++++++++ test/http/test_models.py | 266 +---------------------------------------- test/http/test_request.py | 229 ++++++++++++++++++++++++++++++++++- test/http/test_status_codes.py | 6 + 12 files changed, 455 insertions(+), 296 deletions(-) create mode 100644 test/http/test_message.py create mode 100644 test/http/test_status_codes.py diff --git a/netlib/encoding.py b/netlib/encoding.py index 4c11273b..14479e00 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -12,6 +12,8 @@ ENCODINGS = {"identity", "gzip", "deflate"} def decode(e, content): + if not isinstance(content, bytes): + return None encoding_map = { "identity": identity, "gzip": decode_gzip, @@ -23,6 +25,8 @@ def decode(e, content): def encode(e, content): + if not isinstance(content, bytes): + return None encoding_map = { "identity": identity, "gzip": encode_gzip, diff --git a/netlib/http/headers.py b/netlib/http/headers.py index c79c3344..f64e6200 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -8,15 +8,15 @@ from __future__ import absolute_import, print_function, division import copy try: from collections.abc import MutableMapping -except ImportError: # Workaround for Python < 3.3 - from collections import MutableMapping +except ImportError: # pragma: nocover + from collections import MutableMapping # Workaround for Python < 3.3 import six from netlib.utils import always_byte_args, always_bytes -if six.PY2: +if six.PY2: # pragma: nocover _native = lambda x: x _always_bytes = lambda x: x _always_byte_args = lambda x: x @@ -106,7 +106,7 @@ class Headers(MutableMapping): else: return b"" - if six.PY2: + if six.PY2: # pragma: nocover __str__ = __bytes__ @_always_byte_args diff --git a/netlib/http/message.py b/netlib/http/message.py index ee138746..7cb18f52 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -9,7 +9,7 @@ from .. import encoding, utils CONTENT_MISSING = 0 -if six.PY2: +if six.PY2: # pragma: nocover _native = lambda x: x _always_bytes = lambda x: x else: @@ -110,15 +110,48 @@ class Message(object): def text(self, text): raise NotImplementedError() + def decode(self): + """ + Decodes body based on the current Content-Encoding header, then + removes the header. If there is no Content-Encoding header, no + action is taken. + + Returns: + True, if decoding succeeded. + False, otherwise. + """ + ce = self.headers.get("content-encoding") + data = encoding.decode(ce, self.content) + if data is None: + return False + self.content = data + self.headers.pop("content-encoding", None) + return True + + def encode(self, e): + """ + Encodes body with the encoding e, where e is "gzip", "deflate" or "identity". + + Returns: + True, if decoding succeeded. + False, otherwise. + """ + data = encoding.encode(e, self.content) + if data is None: + return False + self.content = data + self.headers["content-encoding"] = e + return True + # Legacy @property - def body(self): + def body(self): # pragma: nocover warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) return self.content @body.setter - def body(self, body): + def body(self, body): # pragma: nocover warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) self.content = body @@ -146,8 +179,7 @@ class decoded(object): def __enter__(self): if self.ce: - if not self.message.decode(): - self.ce = None + self.message.decode() def __exit__(self, type, value, tb): if self.ce: diff --git a/netlib/http/request.py b/netlib/http/request.py index f8a3b5b9..325c0080 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -102,7 +102,7 @@ class Request(Message): or inferred from the proxy mode (e.g. an IP in transparent mode). """ - if six.PY2: + if six.PY2: # pragma: nocover return self.data.host if not self.data.host: @@ -303,58 +303,58 @@ class Request(Message): # Legacy - def get_cookies(self): + def get_cookies(self): # pragma: nocover warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) return self.cookies - def set_cookies(self, odict): + def set_cookies(self, odict): # pragma: nocover warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) self.cookies = odict - def get_query(self): + def get_query(self): # pragma: nocover warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) return self.query or ODict([]) - def set_query(self, odict): + def set_query(self, odict): # pragma: nocover warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) self.query = odict - def get_path_components(self): + def get_path_components(self): # pragma: nocover warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) return self.path_components - def set_path_components(self, lst): + def set_path_components(self, lst): # pragma: nocover warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) self.path_components = lst - def get_form_urlencoded(self): + def get_form_urlencoded(self): # pragma: nocover warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) return self.urlencoded_form or ODict([]) - def set_form_urlencoded(self, odict): + def set_form_urlencoded(self, odict): # pragma: nocover warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) self.urlencoded_form = odict - def get_form_multipart(self): + def get_form_multipart(self): # pragma: nocover warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) return self.multipart_form or ODict([]) @property - def form_in(self): + def form_in(self): # pragma: nocover warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) return self.first_line_format @form_in.setter - def form_in(self, form_in): + def form_in(self, form_in): # pragma: nocover warnings.warn(".form_in is deprecated, use .first_line_format instead.", DeprecationWarning) self.first_line_format = form_in @property - def form_out(self): + def form_out(self): # pragma: nocover warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) return self.first_line_format @form_out.setter - def form_out(self, form_out): + def form_out(self, form_out): # pragma: nocover warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) self.first_line_format = form_out \ No newline at end of file diff --git a/netlib/http/response.py b/netlib/http/response.py index 7d64243d..db31d2b9 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -106,20 +106,20 @@ class Response(Message): # Legacy - def get_cookies(self): + def get_cookies(self): # pragma: nocover warnings.warn(".get_cookies is deprecated, use .cookies instead.", DeprecationWarning) return self.cookies - def set_cookies(self, odict): + def set_cookies(self, odict): # pragma: nocover warnings.warn(".set_cookies is deprecated, use .cookies instead.", DeprecationWarning) self.cookies = odict @property - def msg(self): + def msg(self): # pragma: nocover warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) return self.reason @msg.setter - def msg(self, reason): + def msg(self, reason): # pragma: nocover warnings.warn(".msg is deprecated, use .reason instead.", DeprecationWarning) self.reason = reason diff --git a/netlib/http/status_codes.py b/netlib/http/status_codes.py index dc09f465..8a4dc1f5 100644 --- a/netlib/http/status_codes.py +++ b/netlib/http/status_codes.py @@ -1,4 +1,4 @@ -from __future__ import (absolute_import, print_function, division) +from __future__ import absolute_import, print_function, division CONTINUE = 100 SWITCHING = 101 @@ -37,6 +37,7 @@ REQUEST_URI_TOO_LONG = 414 UNSUPPORTED_MEDIA_TYPE = 415 REQUESTED_RANGE_NOT_SATISFIABLE = 416 EXPECTATION_FAILED = 417 +IM_A_TEAPOT = 418 INTERNAL_SERVER_ERROR = 500 NOT_IMPLEMENTED = 501 @@ -91,6 +92,7 @@ RESPONSES = { UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", EXPECTATION_FAILED: "Expectation Failed", + IM_A_TEAPOT: "I'm a teapot", # 500 INTERNAL_SERVER_ERROR: "Internal Server Error", diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index fadfe446..a0085db9 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -2,7 +2,7 @@ from __future__ import absolute_import, print_function, division from io import BytesIO import textwrap from mock import Mock -from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect +from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect, TcpDisconnect from netlib.http import Headers from netlib.http.http1.read import ( read_request, read_response, read_request_head, @@ -100,6 +100,11 @@ class TestReadBody(object): with raises(HttpException): b"".join(read_body(rfile, -1, 3)) + def test_max_chunk_size(self): + rfile = BytesIO(b"123456") + assert list(read_body(rfile, -1, max_chunk_size=None)) == [b"123456"] + rfile = BytesIO(b"123456") + assert list(read_body(rfile, -1, max_chunk_size=1)) == [b"1", b"2", b"3", b"4", b"5", b"6"] def test_connection_close(): headers = Headers() @@ -169,6 +174,11 @@ def test_get_first_line(): rfile = BytesIO(b"") _get_first_line(rfile) + with raises(HttpReadDisconnect): + rfile = Mock() + rfile.readline.side_effect = TcpDisconnect + _get_first_line(rfile) + with raises(HttpSyntaxException): rfile = BytesIO(b"GET /\xff HTTP/1.1") _get_first_line(rfile) @@ -191,7 +201,8 @@ def test_read_request_line(): t(b"GET / WTF/1.1") with raises(HttpSyntaxException): t(b"this is not http") - + with raises(HttpReadDisconnect): + t(b"") def test_parse_authority_form(): assert _parse_authority_form(b"foo:42") == (b"foo", 42) @@ -218,6 +229,8 @@ def test_read_response_line(): t(b"HTTP/1.1 OK OK") with raises(HttpSyntaxException): t(b"WTF/1.1 200 OK") + with raises(HttpReadDisconnect): + t(b"") def test_check_http_version(): diff --git a/test/http/test_headers.py b/test/http/test_headers.py index f1af1feb..8bddc0b2 100644 --- a/test/http/test_headers.py +++ b/test/http/test_headers.py @@ -38,6 +38,9 @@ class TestHeaders(object): assert headers["Host"] == "example.com" assert headers["Accept"] == "text/plain" + with raises(ValueError): + Headers([[b"Host", u"not-bytes"]]) + def test_getitem(self): headers = Headers(Host="example.com") assert headers["Host"] == "example.com" diff --git a/test/http/test_message.py b/test/http/test_message.py new file mode 100644 index 00000000..b0b7e27f --- /dev/null +++ b/test/http/test_message.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- +from __future__ import absolute_import, print_function, division + +from netlib.http import decoded +from netlib.tutils import tresp + + +def _test_passthrough_attr(message, attr): + def t(self=None): + assert getattr(message, attr) == getattr(message.data, attr) + setattr(message, attr, "foo") + assert getattr(message.data, attr) == "foo" + return t + + +def _test_decoded_attr(message, attr): + def t(self=None): + assert getattr(message, attr) == getattr(message.data, attr).decode("utf8") + # Set str, get raw bytes + setattr(message, attr, "foo") + assert getattr(message.data, attr) == b"foo" + # Set raw bytes, get decoded + setattr(message.data, attr, b"bar") + assert getattr(message, attr) == "bar" + # Set bytes, get raw bytes + setattr(message, attr, b"baz") + assert getattr(message.data, attr) == b"baz" + + # Set UTF8 + setattr(message, attr, "Non-Autorisé") + assert getattr(message.data, attr) == b"Non-Autoris\xc3\xa9" + # Don't fail on garbage + setattr(message.data, attr, b"foo\xFF\x00bar") + assert getattr(message, attr).startswith("foo") + assert getattr(message, attr).endswith("bar") + # foo.bar = foo.bar should not cause any side effects. + d = getattr(message, attr) + setattr(message, attr, d) + assert getattr(message.data, attr) == b"foo\xFF\x00bar" + return t + + +class TestMessage(object): + + def test_init(self): + resp = tresp() + assert resp.data + + def test_eq_ne(self): + resp = tresp(timestamp_start=42, timestamp_end=42) + same = tresp(timestamp_start=42, timestamp_end=42) + assert resp == same + assert not resp != same + + other = tresp(timestamp_start=0, timestamp_end=0) + assert not resp == other + assert resp != other + + assert resp != 0 + + def test_content_length_update(self): + resp = tresp() + resp.content = b"foo" + assert resp.data.content == b"foo" + assert resp.headers["content-length"] == "3" + resp.content = b"" + assert resp.data.content == b"" + assert resp.headers["content-length"] == "0" + + test_content_basic = _test_passthrough_attr(tresp(), "content") + test_headers = _test_passthrough_attr(tresp(), "headers") + test_timestamp_start = _test_passthrough_attr(tresp(), "timestamp_start") + test_timestamp_end = _test_passthrough_attr(tresp(), "timestamp_end") + + test_http_version = _test_decoded_attr(tresp(), "http_version") + + +class TestDecodedDecorator(object): + + def test_simple(self): + r = tresp() + assert r.content == b"message" + assert "content-encoding" not in r.headers + assert r.encode("gzip") + + assert r.headers["content-encoding"] + assert r.content != b"message" + with decoded(r): + assert "content-encoding" not in r.headers + assert r.content == b"message" + assert r.headers["content-encoding"] + assert r.content != b"message" + + def test_modify(self): + r = tresp() + assert "content-encoding" not in r.headers + assert r.encode("gzip") + + with decoded(r): + r.content = b"foo" + + assert r.content != b"foo" + r.decode() + assert r.content == b"foo" + + def test_unknown_ce(self): + r = tresp() + r.headers["content-encoding"] = "zopfli" + r.content = b"foo" + with decoded(r): + assert r.headers["content-encoding"] + assert r.content == b"foo" + assert r.headers["content-encoding"] + assert r.content == b"foo" + + def test_cannot_decode(self): + r = tresp() + assert r.encode("gzip") + r.content = b"foo" + with decoded(r): + assert r.headers["content-encoding"] + assert r.content == b"foo" + assert r.headers["content-encoding"] + assert r.content != b"foo" + r.decode() + assert r.content == b"foo" + + def test_cannot_encode(self): + r = tresp() + assert r.encode("gzip") + with decoded(r): + r.content = None + + assert "content-encoding" not in r.headers + assert r.content is None + diff --git a/test/http/test_models.py b/test/http/test_models.py index aa267944..76a05446 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -1,271 +1,7 @@ -import mock from netlib import tutils -from netlib import utils from netlib.odict import ODict, ODictCaseless -from netlib.http import Request, Response, Headers, CONTENT_MISSING - -class TestRequest(object): - def test_repr(self): - r = tutils.treq() - assert repr(r) - - def test_headers(self): - tutils.raises(AssertionError, Request, - 'form_in', - 'method', - 'scheme', - 'host', - 'port', - 'path', - b"HTTP/1.1", - 'foobar', - ) - - req = Request( - 'form_in', - 'method', - 'scheme', - 'host', - 'port', - 'path', - b"HTTP/1.1", - ) - assert isinstance(req.headers, Headers) - - def test_equal(self): - a = tutils.treq(timestamp_start=42, timestamp_end=43) - b = tutils.treq(timestamp_start=42, timestamp_end=43) - assert a == b - assert not a != b - - assert not a == 'foo' - assert not b == 'foo' - assert not 'foo' == a - assert not 'foo' == b - - - def test_anticache(self): - req = tutils.treq() - req.headers["If-Modified-Since"] = "foo" - req.headers["If-None-Match"] = "bar" - req.anticache() - assert "If-Modified-Since" not in req.headers - assert "If-None-Match" not in req.headers - - def test_anticomp(self): - req = tutils.treq() - req.headers["Accept-Encoding"] = "foobar" - req.anticomp() - assert req.headers["Accept-Encoding"] == "identity" - - def test_constrain_encoding(self): - req = tutils.treq() - req.headers["Accept-Encoding"] = "identity, gzip, foo" - req.constrain_encoding() - assert "foo" not in req.headers["Accept-Encoding"] - - def test_update_host(self): - req = tutils.treq() - req.headers["Host"] = "" - req.host = "foobar" - assert req.headers["Host"] == "foobar" - - def test_get_form_urlencoded(self): - req = tutils.treq(content="foobar") - assert req.get_form_urlencoded() == ODict() - - req.headers["Content-Type"] = "application/x-www-form-urlencoded" - assert req.get_form_urlencoded() == ODict(utils.urldecode(req.body)) - - def test_get_form_multipart(self): - req = tutils.treq(content="foobar") - assert req.get_form_multipart() == ODict() - - req.headers["Content-Type"] = "multipart/form-data" - assert req.get_form_multipart() == ODict( - utils.multipartdecode( - req.headers, - req.body - ) - ) - - def test_set_form_urlencoded(self): - req = tutils.treq() - req.set_form_urlencoded(ODict([('foo', 'bar'), ('rab', 'oof')])) - assert req.headers["Content-Type"] == "application/x-www-form-urlencoded" - assert req.body - - def test_get_path_components(self): - req = tutils.treq() - assert req.get_path_components() - # TODO: add meaningful assertions - - def test_set_path_components(self): - req = tutils.treq() - req.set_path_components([b"foo", b"bar"]) - # TODO: add meaningful assertions - - def test_get_query(self): - req = tutils.treq() - assert req.get_query().lst == [] - - req.url = "http://localhost:80/foo?bar=42" - assert req.get_query().lst == [("bar", "42")] - - def test_set_query(self): - req = tutils.treq() - req.set_query(ODict([])) - - def test_pretty_host(self): - r = tutils.treq() - assert r.pretty_host == "address" - assert r.host == "address" - r.headers["host"] = "other" - assert r.pretty_host == "other" - assert r.host == "address" - r.host = None - assert r.pretty_host is None - assert r.host is None - - # Invalid IDNA - r.headers["host"] = ".disqus.com" - assert r.pretty_host == ".disqus.com" - - def test_pretty_url(self): - req = tutils.treq(first_line_format="relative") - assert req.pretty_url == "http://address:22/path" - assert req.url == "http://address:22/path" - - def test_get_cookies_none(self): - headers = Headers() - r = tutils.treq() - r.headers = headers - assert len(r.get_cookies()) == 0 - - def test_get_cookies_single(self): - r = tutils.treq() - r.headers = Headers(cookie="cookiename=cookievalue") - result = r.get_cookies() - assert len(result) == 1 - assert result['cookiename'] == ['cookievalue'] - - def test_get_cookies_double(self): - r = tutils.treq() - r.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") - result = r.get_cookies() - assert len(result) == 2 - assert result['cookiename'] == ['cookievalue'] - assert result['othercookiename'] == ['othercookievalue'] - - def test_get_cookies_withequalsign(self): - r = tutils.treq() - r.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") - result = r.get_cookies() - assert len(result) == 2 - assert result['cookiename'] == ['coo=kievalue'] - assert result['othercookiename'] == ['othercookievalue'] - - def test_set_cookies(self): - r = tutils.treq() - r.headers = Headers(cookie="cookiename=cookievalue") - result = r.get_cookies() - result["cookiename"] = ["foo"] - r.set_cookies(result) - assert r.get_cookies()["cookiename"] == ["foo"] - - def test_set_url(self): - r = tutils.treq(first_line_format="absolute") - r.url = b"https://otheraddress:42/ORLY" - assert r.scheme == "https" - assert r.host == "otheraddress" - assert r.port == 42 - assert r.path == "/ORLY" - - try: - r.url = "//localhost:80/foo@bar" - assert False - except: - assert True - - # def test_asterisk_form_in(self): - # f = tutils.tflow(req=None) - # protocol = mock_protocol("OPTIONS * HTTP/1.1") - # f.request = HTTPRequest.from_protocol(protocol) - # - # assert f.request.first_line_format == "relative" - # f.request.host = f.server_conn.address.host - # f.request.port = f.server_conn.address.port - # f.request.scheme = "http" - # assert protocol.assemble(f.request) == ( - # "OPTIONS * HTTP/1.1\r\n" - # "Host: address:22\r\n" - # "Content-Length: 0\r\n\r\n") - # - # def test_relative_form_in(self): - # protocol = mock_protocol("GET /foo\xff HTTP/1.1") - # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) - # - # protocol = mock_protocol("GET /foo HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: h2c") - # r = HTTPRequest.from_protocol(protocol) - # assert r.headers["Upgrade"] == ["h2c"] - # - # def test_expect_header(self): - # protocol = mock_protocol( - # "GET / HTTP/1.1\r\nContent-Length: 3\r\nExpect: 100-continue\r\n\r\nfoobar") - # r = HTTPRequest.from_protocol(protocol) - # assert protocol.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" - # assert r.content == "foo" - # assert protocol.tcp_handler.rfile.read(3) == "bar" - # - # def test_authority_form_in(self): - # protocol = mock_protocol("CONNECT oops-no-port.com HTTP/1.1") - # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) - # - # protocol = mock_protocol("CONNECT address:22 HTTP/1.1") - # r = HTTPRequest.from_protocol(protocol) - # r.scheme, r.host, r.port = "http", "address", 22 - # assert protocol.assemble(r) == ( - # "CONNECT address:22 HTTP/1.1\r\n" - # "Host: address:22\r\n" - # "Content-Length: 0\r\n\r\n") - # assert r.pretty_url == "address:22" - # - # def test_absolute_form_in(self): - # protocol = mock_protocol("GET oops-no-protocol.com HTTP/1.1") - # tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) - # - # protocol = mock_protocol("GET http://address:22/ HTTP/1.1") - # r = HTTPRequest.from_protocol(protocol) - # assert protocol.assemble(r) == ( - # "GET http://address:22/ HTTP/1.1\r\n" - # "Host: address:22\r\n" - # "Content-Length: 0\r\n\r\n") - # - # def test_http_options_relative_form_in(self): - # """ - # Exercises fix for Issue #392. - # """ - # protocol = mock_protocol("OPTIONS /secret/resource HTTP/1.1") - # r = HTTPRequest.from_protocol(protocol) - # r.host = 'address' - # r.port = 80 - # r.scheme = "http" - # assert protocol.assemble(r) == ( - # "OPTIONS /secret/resource HTTP/1.1\r\n" - # "Host: address\r\n" - # "Content-Length: 0\r\n\r\n") - # - # def test_http_options_absolute_form_in(self): - # protocol = mock_protocol("OPTIONS http://address/secret/resource HTTP/1.1") - # r = HTTPRequest.from_protocol(protocol) - # r.host = 'address' - # r.port = 80 - # r.scheme = "http" - # assert protocol.assemble(r) == ( - # "OPTIONS http://address:80/secret/resource HTTP/1.1\r\n" - # "Host: address\r\n" - # "Content-Length: 0\r\n\r\n") +from netlib.http import Response, Headers, CONTENT_MISSING class TestResponse(object): def test_headers(self): diff --git a/test/http/test_request.py b/test/http/test_request.py index 02fac3df..15bdd3e3 100644 --- a/test/http/test_request.py +++ b/test/http/test_request.py @@ -1,3 +1,230 @@ +# -*- coding: utf-8 -*- from __future__ import absolute_import, print_function, division -# TODO \ No newline at end of file +import six + +from netlib import utils +from netlib.http import Headers +from netlib.odict import ODict +from netlib.tutils import treq, raises +from .test_message import _test_decoded_attr, _test_passthrough_attr + + +class TestRequestData(object): + def test_init(self): + with raises(AssertionError): + treq(headers="foobar") + + assert isinstance(treq(headers=None).headers, Headers) + + def test_eq_ne(self): + request_data = treq().data + same = treq().data + assert request_data == same + assert not request_data != same + + other = treq(content=b"foo").data + assert not request_data == other + assert request_data != other + + assert request_data != 0 + + +class TestRequestCore(object): + def test_repr(self): + request = treq() + assert repr(request) == "Request(GET address:22/path)" + request.host = None + assert repr(request) == "Request(GET /path)" + + test_first_line_format = _test_passthrough_attr(treq(), "first_line_format") + test_method = _test_decoded_attr(treq(), "method") + test_scheme = _test_decoded_attr(treq(), "scheme") + test_port = _test_passthrough_attr(treq(), "port") + test_path = _test_decoded_attr(treq(), "path") + + def test_host(self): + if six.PY2: + from unittest import SkipTest + raise SkipTest() + + request = treq() + assert request.host == request.data.host.decode("idna") + + # Test IDNA encoding + # Set str, get raw bytes + request.host = "ídna.example" + assert request.data.host == b"xn--dna-qma.example" + # Set raw bytes, get decoded + request.data.host = b"xn--idn-gla.example" + assert request.host == "idná.example" + # Set bytes, get raw bytes + request.host = b"xn--dn-qia9b.example" + assert request.data.host == b"xn--dn-qia9b.example" + # IDNA encoding is not bijective + request.host = "fußball" + assert request.host == "fussball" + + # Don't fail on garbage + request.data.host = b"foo\xFF\x00bar" + assert request.host.startswith("foo") + assert request.host.endswith("bar") + # foo.bar = foo.bar should not cause any side effects. + d = request.host + request.host = d + assert request.data.host == b"foo\xFF\x00bar" + + def test_host_header_update(self): + request = treq() + assert "host" not in request.headers + request.host = "example.com" + assert "host" not in request.headers + + request.headers["Host"] = "foo" + request.host = "example.org" + assert request.headers["Host"] == "example.org" + + +class TestRequestUtils(object): + def test_url(self): + request = treq() + assert request.url == "http://address:22/path" + + request.url = "https://otheraddress:42/foo" + assert request.scheme == "https" + assert request.host == "otheraddress" + assert request.port == 42 + assert request.path == "/foo" + + with raises(ValueError): + request.url = "not-a-url" + + def test_pretty_host(self): + request = treq() + assert request.pretty_host == "address" + assert request.host == "address" + request.headers["host"] = "other" + assert request.pretty_host == "other" + assert request.host == "address" + request.host = None + assert request.pretty_host is None + assert request.host is None + + # Invalid IDNA + request.headers["host"] = ".disqus.com" + assert request.pretty_host == ".disqus.com" + + def test_pretty_url(self): + request = treq() + assert request.url == "http://address:22/path" + assert request.pretty_url == "http://address:22/path" + request.headers["host"] = "other" + assert request.pretty_url == "http://other:22/path" + + def test_pretty_url_authority(self): + request = treq(first_line_format="authority") + assert request.pretty_url == "address:22" + + def test_get_query(self): + request = treq() + assert request.query is None + + request.url = "http://localhost:80/foo?bar=42" + assert request.query.lst == [("bar", "42")] + + def test_set_query(self): + request = treq() + request.query = ODict([]) + + def test_get_cookies_none(self): + request = treq() + request.headers = Headers() + assert len(request.cookies) == 0 + + def test_get_cookies_single(self): + request = treq() + request.headers = Headers(cookie="cookiename=cookievalue") + result = request.cookies + assert len(result) == 1 + assert result['cookiename'] == ['cookievalue'] + + def test_get_cookies_double(self): + request = treq() + request.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue") + result = request.cookies + assert len(result) == 2 + assert result['cookiename'] == ['cookievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_get_cookies_withequalsign(self): + request = treq() + request.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue") + result = request.cookies + assert len(result) == 2 + assert result['cookiename'] == ['coo=kievalue'] + assert result['othercookiename'] == ['othercookievalue'] + + def test_set_cookies(self): + request = treq() + request.headers = Headers(cookie="cookiename=cookievalue") + result = request.cookies + result["cookiename"] = ["foo"] + request.cookies = result + assert request.cookies["cookiename"] == ["foo"] + + def test_get_path_components(self): + request = treq(path=b"/foo/bar") + assert request.path_components == ["foo", "bar"] + + def test_set_path_components(self): + request = treq() + request.path_components = ["foo", "baz"] + assert request.path == "/foo/baz" + request.path_components = [] + assert request.path == "/" + + def test_anticache(self): + request = treq() + request.headers["If-Modified-Since"] = "foo" + request.headers["If-None-Match"] = "bar" + request.anticache() + assert "If-Modified-Since" not in request.headers + assert "If-None-Match" not in request.headers + + def test_anticomp(self): + request = treq() + request.headers["Accept-Encoding"] = "foobar" + request.anticomp() + assert request.headers["Accept-Encoding"] == "identity" + + def test_constrain_encoding(self): + request = treq() + request.headers["Accept-Encoding"] = "identity, gzip, foo" + request.constrain_encoding() + assert "foo" not in request.headers["Accept-Encoding"] + assert "gzip" in request.headers["Accept-Encoding"] + + def test_get_urlencoded_form(self): + request = treq(content="foobar") + assert request.urlencoded_form is None + + request.headers["Content-Type"] = "application/x-www-form-urlencoded" + assert request.urlencoded_form == ODict(utils.urldecode(request.content)) + + def test_set_urlencoded_form(self): + request = treq() + request.urlencoded_form = ODict([('foo', 'bar'), ('rab', 'oof')]) + assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" + assert request.content + + def test_get_multipart_form(self): + request = treq(content="foobar") + assert request.multipart_form is None + + request.headers["Content-Type"] = "multipart/form-data" + assert request.multipart_form == ODict( + utils.multipartdecode( + request.headers, + request.content + ) + ) diff --git a/test/http/test_status_codes.py b/test/http/test_status_codes.py new file mode 100644 index 00000000..9fea6b70 --- /dev/null +++ b/test/http/test_status_codes.py @@ -0,0 +1,6 @@ +from netlib.http import status_codes + + +def test_simple(): + assert status_codes.IM_A_TEAPOT == 418 + assert status_codes.RESPONSES[418] == "I'm a teapot" -- cgit v1.2.3 From 23d13e4c1282bc46c54222479c3b83032dad3335 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sun, 27 Sep 2015 00:49:41 +0200 Subject: test response model, push coverage to 100% branch cov --- netlib/http/cookies.py | 1 + netlib/http/message.py | 10 ++++ netlib/http/request.py | 12 +---- netlib/http/response.py | 14 ++---- test/http/http1/test_assemble.py | 13 +++++- test/http/http1/test_read.py | 3 ++ test/http/test_cookies.py | 1 + test/http/test_message.py | 91 +++++++++++++++++++++--------------- test/http/test_models.py | 94 -------------------------------------- test/http/test_request.py | 42 ++++++++++------- test/http/test_response.py | 99 +++++++++++++++++++++++++++++++++++++++- 11 files changed, 208 insertions(+), 172 deletions(-) delete mode 100644 test/http/test_models.py diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index 78b03a83..18544b5e 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -58,6 +58,7 @@ def _read_quoted_string(s, start): escaping = False ret = [] # Skip the first quote + i = start # initialize in case the loop doesn't run. for i in range(start + 1, len(s)): if escaping: ret.append(s[i]) diff --git a/netlib/http/message.py b/netlib/http/message.py index 7cb18f52..e4e799ca 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -18,6 +18,16 @@ else: _always_bytes = lambda x: utils.always_bytes(x, "utf-8", "surrogateescape") +class MessageData(object): + def __eq__(self, other): + if isinstance(other, MessageData): + return self.__dict__ == other.__dict__ + return False + + def __ne__(self, other): + return not self.__eq__(other) + + class Message(object): def __init__(self, data): self.data = data diff --git a/netlib/http/request.py b/netlib/http/request.py index 325c0080..095b5945 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -10,10 +10,10 @@ from netlib.http import cookies from netlib.odict import ODict from .. import encoding from .headers import Headers -from .message import Message, _native, _always_bytes +from .message import Message, _native, _always_bytes, MessageData -class RequestData(object): +class RequestData(MessageData): def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=None, content=None, timestamp_start=None, timestamp_end=None): if not headers: @@ -32,14 +32,6 @@ class RequestData(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): - if isinstance(other, RequestData): - return self.__dict__ == other.__dict__ - return False - - def __ne__(self, other): - return not self.__eq__(other) - class Request(Message): """ diff --git a/netlib/http/response.py b/netlib/http/response.py index db31d2b9..66e5ded6 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -4,12 +4,12 @@ import warnings from . import cookies from .headers import Headers -from .message import Message, _native, _always_bytes +from .message import Message, _native, _always_bytes, MessageData from .. import utils from ..odict import ODict -class ResponseData(object): +class ResponseData(MessageData): def __init__(self, http_version, status_code, reason=None, headers=None, content=None, timestamp_start=None, timestamp_end=None): if not headers: @@ -24,14 +24,6 @@ class ResponseData(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end - def __eq__(self, other): - if isinstance(other, ResponseData): - return self.__dict__ == other.__dict__ - return False - - def __ne__(self, other): - return not self.__eq__(other) - class Response(Message): """ @@ -48,7 +40,7 @@ class Response(Message): utils.pretty_size(len(self.content)) ) else: - details = "content missing" + details = "no content" return "Response({status_code} {reason}, {details})".format( status_code=self.status_code, reason=self.reason, diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py index 460e22c5..ed94292d 100644 --- a/test/http/http1/test_assemble.py +++ b/test/http/http1/test_assemble.py @@ -78,10 +78,19 @@ def test_assemble_request_headers(): # https://github.com/mitmproxy/mitmproxy/issues/186 r = treq(content=b"") r.headers["Transfer-Encoding"] = "chunked" - c = _assemble_request_headers(r) + c = _assemble_request_headers(r.data) assert b"Transfer-Encoding" in c - assert b"host" in _assemble_request_headers(treq(headers=Headers())) + +def test_assemble_request_headers_host_header(): + r = treq() + r.headers = Headers() + c = _assemble_request_headers(r.data) + assert b"host" in c + + r.host = None + c = _assemble_request_headers(r.data) + assert b"host" not in c def test_assemble_response_headers(): diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index a0085db9..84a43f8b 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -117,6 +117,9 @@ def test_connection_close(): headers["connection"] = "close" assert connection_close(b"HTTP/1.1", headers) + headers["connection"] = "foobar" + assert connection_close(b"HTTP/1.0", headers) + assert not connection_close(b"HTTP/1.1", headers) def test_expected_http_body_size(): # Expect: 100-continue diff --git a/test/http/test_cookies.py b/test/http/test_cookies.py index 413b6241..34bb64f2 100644 --- a/test/http/test_cookies.py +++ b/test/http/test_cookies.py @@ -21,6 +21,7 @@ def test_read_quoted_string(): [(r'"f\\o" x', 0), (r"f\o", 6)], [(r'"f\\" x', 0), (r"f" + '\\', 5)], [('"fo\\\"" x', 0), ("fo\"", 6)], + [('"foo" x', 7), ("", 8)], ] for q, a in tokens: assert cookies._read_quoted_string(*q) == a diff --git a/test/http/test_message.py b/test/http/test_message.py index b0b7e27f..2c37dc3e 100644 --- a/test/http/test_message.py +++ b/test/http/test_message.py @@ -1,43 +1,53 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function, division -from netlib.http import decoded -from netlib.tutils import tresp +from netlib.http import decoded, Headers +from netlib.tutils import tresp, raises def _test_passthrough_attr(message, attr): - def t(self=None): - assert getattr(message, attr) == getattr(message.data, attr) - setattr(message, attr, "foo") - assert getattr(message.data, attr) == "foo" - return t + assert getattr(message, attr) == getattr(message.data, attr) + setattr(message, attr, "foo") + assert getattr(message.data, attr) == "foo" def _test_decoded_attr(message, attr): - def t(self=None): - assert getattr(message, attr) == getattr(message.data, attr).decode("utf8") - # Set str, get raw bytes - setattr(message, attr, "foo") - assert getattr(message.data, attr) == b"foo" - # Set raw bytes, get decoded - setattr(message.data, attr, b"bar") - assert getattr(message, attr) == "bar" - # Set bytes, get raw bytes - setattr(message, attr, b"baz") - assert getattr(message.data, attr) == b"baz" - - # Set UTF8 - setattr(message, attr, "Non-Autorisé") - assert getattr(message.data, attr) == b"Non-Autoris\xc3\xa9" - # Don't fail on garbage - setattr(message.data, attr, b"foo\xFF\x00bar") - assert getattr(message, attr).startswith("foo") - assert getattr(message, attr).endswith("bar") - # foo.bar = foo.bar should not cause any side effects. - d = getattr(message, attr) - setattr(message, attr, d) - assert getattr(message.data, attr) == b"foo\xFF\x00bar" - return t + assert getattr(message, attr) == getattr(message.data, attr).decode("utf8") + # Set str, get raw bytes + setattr(message, attr, "foo") + assert getattr(message.data, attr) == b"foo" + # Set raw bytes, get decoded + setattr(message.data, attr, b"bar") + assert getattr(message, attr) == "bar" + # Set bytes, get raw bytes + setattr(message, attr, b"baz") + assert getattr(message.data, attr) == b"baz" + + # Set UTF8 + setattr(message, attr, "Non-Autorisé") + assert getattr(message.data, attr) == b"Non-Autoris\xc3\xa9" + # Don't fail on garbage + setattr(message.data, attr, b"foo\xFF\x00bar") + assert getattr(message, attr).startswith("foo") + assert getattr(message, attr).endswith("bar") + # foo.bar = foo.bar should not cause any side effects. + d = getattr(message, attr) + setattr(message, attr, d) + assert getattr(message.data, attr) == b"foo\xFF\x00bar" + + +class TestMessageData(object): + def test_eq_ne(self): + data = tresp(timestamp_start=42, timestamp_end=42).data + same = tresp(timestamp_start=42, timestamp_end=42).data + assert data == same + assert not data != same + + other = tresp(content=b"foo").data + assert not data == other + assert data != other + + assert data != 0 class TestMessage(object): @@ -67,12 +77,20 @@ class TestMessage(object): assert resp.data.content == b"" assert resp.headers["content-length"] == "0" - test_content_basic = _test_passthrough_attr(tresp(), "content") - test_headers = _test_passthrough_attr(tresp(), "headers") - test_timestamp_start = _test_passthrough_attr(tresp(), "timestamp_start") - test_timestamp_end = _test_passthrough_attr(tresp(), "timestamp_end") + def test_content_basic(self): + _test_passthrough_attr(tresp(), "content") + + def test_headers(self): + _test_passthrough_attr(tresp(), "headers") - test_http_version = _test_decoded_attr(tresp(), "http_version") + def test_timestamp_start(self): + _test_passthrough_attr(tresp(), "timestamp_start") + + def test_timestamp_end(self): + _test_passthrough_attr(tresp(), "timestamp_end") + + def teste_http_version(self): + _test_decoded_attr(tresp(), "http_version") class TestDecodedDecorator(object): @@ -133,4 +151,3 @@ class TestDecodedDecorator(object): assert "content-encoding" not in r.headers assert r.content is None - diff --git a/test/http/test_models.py b/test/http/test_models.py deleted file mode 100644 index 76a05446..00000000 --- a/test/http/test_models.py +++ /dev/null @@ -1,94 +0,0 @@ - -from netlib import tutils -from netlib.odict import ODict, ODictCaseless -from netlib.http import Response, Headers, CONTENT_MISSING - -class TestResponse(object): - def test_headers(self): - tutils.raises(AssertionError, Response, - b"HTTP/1.1", - 200, - headers='foobar', - ) - - resp = Response( - b"HTTP/1.1", - 200, - ) - assert isinstance(resp.headers, Headers) - - def test_equal(self): - a = tutils.tresp(timestamp_start=42, timestamp_end=43) - b = tutils.tresp(timestamp_start=42, timestamp_end=43) - assert a == b - - assert not a == 'foo' - assert not b == 'foo' - assert not 'foo' == a - assert not 'foo' == b - - def test_repr(self): - r = tutils.tresp() - assert "unknown content type" in repr(r) - r.headers["content-type"] = "foo" - assert "foo" in repr(r) - assert repr(tutils.tresp(content=CONTENT_MISSING)) - - def test_get_cookies_none(self): - resp = tutils.tresp() - resp.headers = Headers() - assert not resp.get_cookies() - - def test_get_cookies_simple(self): - resp = tutils.tresp() - resp.headers = Headers(set_cookie="cookiename=cookievalue") - result = resp.get_cookies() - assert len(result) == 1 - assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", ODict()] - - def test_get_cookies_with_parameters(self): - resp = tutils.tresp() - resp.headers = Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly") - result = resp.get_cookies() - assert len(result) == 1 - assert "cookiename" in result - assert result["cookiename"][0][0] == "cookievalue" - attrs = result["cookiename"][0][1] - assert len(attrs) == 4 - assert attrs["domain"] == ["example.com"] - assert attrs["expires"] == ["Wed Oct 21 16:29:41 2015"] - assert attrs["path"] == ["/"] - assert attrs["httponly"] == [None] - - def test_get_cookies_no_value(self): - resp = tutils.tresp() - resp.headers = Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/") - result = resp.get_cookies() - assert len(result) == 1 - assert "cookiename" in result - assert result["cookiename"][0][0] == "" - assert len(result["cookiename"][0][1]) == 2 - - def test_get_cookies_twocookies(self): - resp = tutils.tresp() - resp.headers = Headers([ - [b"Set-Cookie", b"cookiename=cookievalue"], - [b"Set-Cookie", b"othercookie=othervalue"] - ]) - result = resp.get_cookies() - assert len(result) == 2 - assert "cookiename" in result - assert result["cookiename"][0] == ["cookievalue", ODict()] - assert "othercookie" in result - assert result["othercookie"][0] == ["othervalue", ODict()] - - def test_set_cookies(self): - resp = tutils.tresp() - v = resp.get_cookies() - v.add("foo", ["bar", ODictCaseless()]) - resp.set_cookies(v) - - v = resp.get_cookies() - assert len(v) == 1 - assert v["foo"] == [["bar", ODictCaseless()]] diff --git a/test/http/test_request.py b/test/http/test_request.py index 15bdd3e3..8cf69ffe 100644 --- a/test/http/test_request.py +++ b/test/http/test_request.py @@ -17,31 +17,31 @@ class TestRequestData(object): assert isinstance(treq(headers=None).headers, Headers) - def test_eq_ne(self): - request_data = treq().data - same = treq().data - assert request_data == same - assert not request_data != same - - other = treq(content=b"foo").data - assert not request_data == other - assert request_data != other - - assert request_data != 0 - class TestRequestCore(object): + """ + Tests for builtins and the attributes that are directly proxied from the data structure + """ def test_repr(self): request = treq() assert repr(request) == "Request(GET address:22/path)" request.host = None assert repr(request) == "Request(GET /path)" - test_first_line_format = _test_passthrough_attr(treq(), "first_line_format") - test_method = _test_decoded_attr(treq(), "method") - test_scheme = _test_decoded_attr(treq(), "scheme") - test_port = _test_passthrough_attr(treq(), "port") - test_path = _test_decoded_attr(treq(), "path") + def test_first_line_format(self): + _test_passthrough_attr(treq(), "first_line_format") + + def test_method(self): + _test_decoded_attr(treq(), "method") + + def test_scheme(self): + _test_decoded_attr(treq(), "scheme") + + def test_port(self): + _test_passthrough_attr(treq(), "port") + + def test_path(self): + _test_decoded_attr(treq(), "path") def test_host(self): if six.PY2: @@ -86,6 +86,9 @@ class TestRequestCore(object): class TestRequestUtils(object): + """ + Tests for additional convenience methods. + """ def test_url(self): request = treq() assert request.url == "http://address:22/path" @@ -199,6 +202,11 @@ class TestRequestUtils(object): def test_constrain_encoding(self): request = treq() + + h = request.headers.copy() + request.constrain_encoding() # no-op if there is no accept_encoding header. + assert request.headers == h + request.headers["Accept-Encoding"] = "identity, gzip, foo" request.constrain_encoding() assert "foo" not in request.headers["Accept-Encoding"] diff --git a/test/http/test_response.py b/test/http/test_response.py index 02fac3df..a1f4abd7 100644 --- a/test/http/test_response.py +++ b/test/http/test_response.py @@ -1,3 +1,100 @@ from __future__ import absolute_import, print_function, division -# TODO \ No newline at end of file +from netlib.http import Headers +from netlib.odict import ODict, ODictCaseless +from netlib.tutils import raises, tresp +from .test_message import _test_passthrough_attr, _test_decoded_attr + + +class TestResponseData(object): + def test_init(self): + with raises(AssertionError): + tresp(headers="foobar") + + assert isinstance(tresp(headers=None).headers, Headers) + + +class TestResponseCore(object): + """ + Tests for builtins and the attributes that are directly proxied from the data structure + """ + def test_repr(self): + response = tresp() + assert repr(response) == "Response(200 OK, unknown content type, 7B)" + response.content = None + assert repr(response) == "Response(200 OK, no content)" + + def test_status_code(self): + _test_passthrough_attr(tresp(), "status_code") + + def test_reason(self): + _test_decoded_attr(tresp(), "reason") + + +class TestResponseUtils(object): + """ + Tests for additional convenience methods. + """ + def test_get_cookies_none(self): + resp = tresp() + resp.headers = Headers() + assert not resp.cookies + + def test_get_cookies_empty(self): + resp = tresp() + resp.headers = Headers(set_cookie="") + assert not resp.cookies + + def test_get_cookies_simple(self): + resp = tresp() + resp.headers = Headers(set_cookie="cookiename=cookievalue") + result = resp.cookies + assert len(result) == 1 + assert "cookiename" in result + assert result["cookiename"][0] == ["cookievalue", ODict()] + + def test_get_cookies_with_parameters(self): + resp = tresp() + resp.headers = Headers(set_cookie="cookiename=cookievalue;domain=example.com;expires=Wed Oct 21 16:29:41 2015;path=/; HttpOnly") + result = resp.cookies + assert len(result) == 1 + assert "cookiename" in result + assert result["cookiename"][0][0] == "cookievalue" + attrs = result["cookiename"][0][1] + assert len(attrs) == 4 + assert attrs["domain"] == ["example.com"] + assert attrs["expires"] == ["Wed Oct 21 16:29:41 2015"] + assert attrs["path"] == ["/"] + assert attrs["httponly"] == [None] + + def test_get_cookies_no_value(self): + resp = tresp() + resp.headers = Headers(set_cookie="cookiename=; Expires=Thu, 01-Jan-1970 00:00:01 GMT; path=/") + result = resp.cookies + assert len(result) == 1 + assert "cookiename" in result + assert result["cookiename"][0][0] == "" + assert len(result["cookiename"][0][1]) == 2 + + def test_get_cookies_twocookies(self): + resp = tresp() + resp.headers = Headers([ + [b"Set-Cookie", b"cookiename=cookievalue"], + [b"Set-Cookie", b"othercookie=othervalue"] + ]) + result = resp.cookies + assert len(result) == 2 + assert "cookiename" in result + assert result["cookiename"][0] == ["cookievalue", ODict()] + assert "othercookie" in result + assert result["othercookie"][0] == ["othervalue", ODict()] + + def test_set_cookies(self): + resp = tresp() + v = resp.cookies + v.add("foo", ["bar", ODictCaseless()]) + resp.set_cookies(v) + + v = resp.cookies + assert len(v) == 1 + assert v["foo"] == [["bar", ODictCaseless()]] -- cgit v1.2.3