From 89f22f735944989912a7a0394dd7e80d420cb0f3 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 27 Jul 2015 11:46:49 +0200 Subject: refactor connection & protocol handling --- libmproxy/protocol/http.py | 911 +----------------------------------- libmproxy/protocol/http_wrappers.py | 829 ++++++++++++++++++++++++++++++++ libmproxy/proxy/connection.py | 15 +- test/test_protocol_http.py | 76 +-- 4 files changed, 906 insertions(+), 925 deletions(-) create mode 100644 libmproxy/protocol/http_wrappers.py diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index f2ac5acc..e0deadd5 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -16,6 +16,8 @@ from .primitives import KILL, ProtocolHandler, Flow, Error from ..proxy.connection import ServerConnection from .. import encoding, utils, controller, stateobject, proxy +from .http_wrappers import decoded, HTTPRequest, HTTPResponse + HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" HDR_FORM_MULTIPART = "multipart/form-data" CONTENT_MISSING = 0 @@ -38,9 +40,10 @@ def send_connect_request(conn, host, port, update_state=True): "" ) conn.send(upstream_request.assemble()) - resp = HTTPResponse.from_stream(conn.rfile, upstream_request.method) - if resp.code != 200: - raise proxy.ProxyError(resp.code, + protocol = http.http1.HTTP1Protocol(conn) + resp = HTTPResponse.from_protocol(protocol, upstream_request.method) + if resp.status_code != 200: + raise proxy.ProxyError(resp.status_code, "Cannot establish SSL " + "connection with upstream proxy: \r\n" + str(resp.assemble())) @@ -53,884 +56,6 @@ def send_connect_request(conn, host, port, update_state=True): return resp -class decoded(object): - """ - A context manager that decodes a request or response, and then - re-encodes it with the same encoding after execution of the block. 
- - Example: - with decoded(request): - request.content = request.content.replace("foo", "bar") - """ - - def __init__(self, o): - self.o = o - ce = o.headers.get_first("content-encoding") - if ce in encoding.ENCODINGS: - self.ce = ce - else: - self.ce = None - - def __enter__(self): - if self.ce: - self.o.decode() - - def __exit__(self, type, value, tb): - if self.ce: - self.o.encode(self.ce) - - -class HTTPMessage(stateobject.StateObject): - """ - Base class for HTTPRequest and HTTPResponse - """ - - def __init__(self, httpversion, headers, content, timestamp_start=None, - timestamp_end=None): - self.httpversion = httpversion - self.headers = headers - """@type: odict.ODictCaseless""" - self.content = content - - self.timestamp_start = timestamp_start - self.timestamp_end = timestamp_end - - _stateobject_attributes = dict( - httpversion=tuple, - headers=odict.ODictCaseless, - content=str, - timestamp_start=float, - timestamp_end=float - ) - _stateobject_long_attributes = {"content"} - - def get_state(self, short=False): - ret = super(HTTPMessage, self).get_state(short) - if short: - if self.content: - ret["contentLength"] = len(self.content) - elif self.content == CONTENT_MISSING: - ret["contentLength"] = None - else: - ret["contentLength"] = 0 - return ret - - def get_decoded_content(self): - """ - Returns the decoded content based on the current Content-Encoding - header. - Doesn't change the message iteself or its headers. - """ - ce = self.headers.get_first("content-encoding") - if not self.content or ce not in encoding.ENCODINGS: - return self.content - return encoding.decode(ce, self.content) - - def decode(self): - """ - Decodes content based on the current Content-Encoding header, then - removes the header. If there is no Content-Encoding header, no - action is taken. - - Returns True if decoding succeeded, False otherwise. 
- """ - ce = self.headers.get_first("content-encoding") - if not self.content or ce not in encoding.ENCODINGS: - return False - data = encoding.decode(ce, self.content) - if data is None: - return False - self.content = data - del self.headers["content-encoding"] - return True - - def encode(self, e): - """ - Encodes content with the encoding e, where e is "gzip", "deflate" - or "identity". - """ - # FIXME: Error if there's an existing encoding header? - self.content = encoding.encode(e, self.content) - self.headers["content-encoding"] = [e] - - def size(self, **kwargs): - """ - Size in bytes of a fully rendered message, including headers and - HTTP lead-in. - """ - hl = len(self._assemble_head(**kwargs)) - if self.content: - return hl + len(self.content) - else: - return hl - - def copy(self): - c = copy.copy(self) - c.headers = self.headers.copy() - return c - - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in both the headers - and the body of the message. Encoded content will be decoded - before replacement, and re-encoded afterwards. - - Returns the number of replacements made. - """ - with decoded(self): - self.content, c = utils.safe_subn( - pattern, repl, self.content, *args, **kwargs - ) - c += self.headers.replace(pattern, repl, *args, **kwargs) - return c - - def _assemble_first_line(self): - """ - Returns the assembled request/response line - """ - raise NotImplementedError() # pragma: nocover - - def _assemble_headers(self): - """ - Returns the assembled headers - """ - raise NotImplementedError() # pragma: nocover - - def _assemble_head(self): - """ - Returns the assembled request/response line plus headers - """ - raise NotImplementedError() # pragma: nocover - - def assemble(self): - """ - Returns the assembled request/response - """ - raise NotImplementedError() # pragma: nocover - - -class HTTPRequest(HTTPMessage): - """ - An HTTP request. 
- - Exposes the following attributes: - - method: HTTP method - - scheme: URL scheme (http/https) - - host: Target hostname of the request. This is not neccessarily the - directy upstream server (which could be another proxy), but it's always - the target server we want to reach at the end. This attribute is either - inferred from the request itself (absolute-form, authority-form) or from - the connection metadata (e.g. the host in reverse proxy mode). - - port: Destination port - - path: Path portion of the URL (not present in authority-form) - - httpversion: HTTP version tuple, e.g. (1,1) - - headers: odict.ODictCaseless object - - content: Content of the request, None, or CONTENT_MISSING if there - is content associated, but not present. CONTENT_MISSING evaluates - to False to make checking for the presence of content natural. - - form_in: The request form which mitmproxy has received. The following - values are possible: - - - relative (GET /index.html, OPTIONS *) (covers origin form and - asterisk form) - - absolute (GET http://example.com:80/index.html) - - authority-form (CONNECT example.com:443) - Details: http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-25#section-5.3 - - form_out: The request form which mitmproxy will send out to the - destination - - timestamp_start: Timestamp indicating when request transmission started - - timestamp_end: Timestamp indicating when request transmission ended - """ - - def __init__( - self, - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - content, - timestamp_start=None, - timestamp_end=None, - form_out=None - ): - assert isinstance(headers, odict.ODictCaseless) or not headers - HTTPMessage.__init__( - self, - httpversion, - headers, - content, - timestamp_start, - timestamp_end - ) - self.form_in = form_in - self.method = method - self.scheme = scheme - self.host = host - self.port = port - self.path = path - self.httpversion = httpversion - self.form_out = form_out or form_in 
- - # Have this request's cookies been modified by sticky cookies or auth? - self.stickycookie = False - self.stickyauth = False - # Is this request replayed? - self.is_replay = False - - _stateobject_attributes = HTTPMessage._stateobject_attributes.copy() - _stateobject_attributes.update( - form_in=str, - method=str, - scheme=str, - host=str, - port=int, - path=str, - form_out=str, - is_replay=bool - ) - - @property - def body(self): - return self.content - - @classmethod - def from_state(cls, state): - f = cls( - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None) - f.load_state(state) - return f - - def __repr__(self): - return "".format( - self._assemble_first_line(self.form_in)[:-9] - ) - - @classmethod - def from_stream( - cls, - rfile, - include_body=True, - body_size_limit=None, - wfile=None): - """ - Parse an HTTP request from a file stream - - Args: - rfile (file): Input file to read from - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled automatically. - by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - HTTPRequest: The HTTP request - - Raises: - HttpError: If the input is invalid. 
- """ - timestamp_start, timestamp_end = None, None - - timestamp_start = utils.timestamp() - if hasattr(rfile, "reset_timestamps"): - rfile.reset_timestamps() - - protocol = http1.HTTP1Protocol(rfile=rfile, wfile=wfile) - req = protocol.read_request( - include_body = include_body, - body_size_limit = body_size_limit, - ) - - if hasattr(rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = rfile.first_byte_timestamp - - timestamp_end = utils.timestamp() - return HTTPRequest( - req.form_in, - req.method, - req.scheme, - req.host, - req.port, - req.path, - req.httpversion, - req.headers, - req.body, - timestamp_start, - timestamp_end - ) - - def _assemble_first_line(self, form=None): - form = form or self.form_out - - if form == "relative": - request_line = '%s %s HTTP/%s.%s' % ( - self.method, self.path, self.httpversion[0], self.httpversion[1] - ) - elif form == "authority": - request_line = '%s %s:%s HTTP/%s.%s' % ( - self.method, self.host, self.port, self.httpversion[0], - self.httpversion[1] - ) - elif form == "absolute": - request_line = '%s %s://%s:%s%s HTTP/%s.%s' % ( - self.method, self.scheme, self.host, - self.port, self.path, self.httpversion[0], - self.httpversion[1] - ) - else: - raise http.HttpError(400, "Invalid request form") - return request_line - - # This list is adopted legacy code. - # We probably don't need to strip off keep-alive. - _headers_to_strip_off = ['Proxy-Connection', - 'Keep-Alive', - 'Connection', - 'Transfer-Encoding', - 'Upgrade'] - - def _assemble_headers(self): - headers = self.headers.copy() - for k in self._headers_to_strip_off: - del headers[k] - if 'host' not in headers and self.scheme and self.host and self.port: - headers["Host"] = [utils.hostport(self.scheme, - self.host, - self.port)] - - # If content is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. 
- if self.content or self.content == "": - headers["Content-Length"] = [str(len(self.content))] - - return headers.format() - - def _assemble_head(self, form=None): - return "%s\r\n%s\r\n" % ( - self._assemble_first_line(form), self._assemble_headers() - ) - - def assemble(self, form=None): - """ - Assembles the request for transmission to the server. We make some - modifications to make sure interception works properly. - - Raises an Exception if the request cannot be assembled. - """ - if self.content == CONTENT_MISSING: - raise proxy.ProxyError( - 502, - "Cannot assemble flow with CONTENT_MISSING" - ) - head = self._assemble_head(form) - if self.content: - return head + self.content - else: - return head - - def __hash__(self): - return id(self) - - def anticache(self): - """ - Modifies this request to remove headers that might produce a cached - response. That is, we remove ETags and If-Modified-Since headers. - """ - delheaders = [ - "if-modified-since", - "if-none-match", - ] - for i in delheaders: - del self.headers[i] - - def anticomp(self): - """ - Modifies this request to remove headers that will compress the - resource's data. - """ - self.headers["accept-encoding"] = ["identity"] - - def constrain_encoding(self): - """ - Limits the permissible Accept-Encoding values, based on what we can - decode appropriately. - """ - if self.headers["accept-encoding"]: - self.headers["accept-encoding"] = [ - ', '.join( - e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0])] - - def update_host_header(self): - """ - Update the host header to reflect the current target. - """ - self.headers["Host"] = [self.host] - - def get_form(self): - """ - Retrieves the URL-encoded or multipart form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. 
- """ - if self.content: - if self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): - return self.get_form_urlencoded() - elif self.headers.in_any("content-type", HDR_FORM_MULTIPART, True): - return self.get_form_multipart() - return odict.ODict([]) - - def get_form_urlencoded(self): - """ - Retrieves the URL-encoded form data, returning an ODict object. - Returns an empty ODict if there is no data or the content-type - indicates non-form data. - """ - if self.content and self.headers.in_any( - "content-type", - HDR_FORM_URLENCODED, - True): - return odict.ODict(utils.urldecode(self.content)) - return odict.ODict([]) - - def get_form_multipart(self): - if self.content and self.headers.in_any( - "content-type", - HDR_FORM_MULTIPART, - True): - return odict.ODict( - utils.multipartdecode( - self.headers, - self.content)) - return odict.ODict([]) - - def set_form_urlencoded(self, odict): - """ - Sets the body to the URL-encoded form data, and adds the - appropriate content-type header. Note that this will destory the - existing body if there is one. - """ - # FIXME: If there's an existing content-type header indicating a - # url-encoded form, leave it alone. - self.headers["Content-Type"] = [HDR_FORM_URLENCODED] - self.content = utils.urlencode(odict.lst) - - def get_path_components(self): - """ - Returns the path components of the URL as a list of strings. - - Components are unquoted. - """ - _, _, path, _, _, _ = urlparse.urlparse(self.url) - return [urllib.unquote(i) for i in path.split("/") if i] - - def set_path_components(self, lst): - """ - Takes a list of strings, and sets the path component of the URL. - - Components are quoted. - """ - lst = [urllib.quote(i, safe="") for i in lst] - path = "/" + "/".join(lst) - scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.url) - self.url = urlparse.urlunparse( - [scheme, netloc, path, params, query, fragment] - ) - - def get_query(self): - """ - Gets the request query string. 
Returns an ODict object. - """ - _, _, _, _, query, _ = urlparse.urlparse(self.url) - if query: - return odict.ODict(utils.urldecode(query)) - return odict.ODict([]) - - def set_query(self, odict): - """ - Takes an ODict object, and sets the request query string. - """ - scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.url) - query = utils.urlencode(odict.lst) - self.url = urlparse.urlunparse( - [scheme, netloc, path, params, query, fragment] - ) - - def pretty_host(self, hostheader): - """ - Heuristic to get the host of the request. - - Note that pretty_host() does not always return the TCP destination - of the request, e.g. if an upstream proxy is in place - - If hostheader is set to True, the Host: header will be used as - additional (and preferred) data source. This is handy in - transparent mode, where only the IO of the destination is known, - but not the resolved name. This is disabled by default, as an - attacker may spoof the host header to confuse an analyst. - """ - host = None - if hostheader: - host = self.headers.get_first("host") - if not host: - host = self.host - if host: - try: - return host.encode("idna") - except ValueError: - return host - else: - return None - - def pretty_url(self, hostheader): - if self.form_out == "authority": # upstream proxy mode - return "%s:%s" % (self.pretty_host(hostheader), self.port) - return utils.unparse_url(self.scheme, - self.pretty_host(hostheader), - self.port, - self.path).encode('ascii') - - @property - def url(self): - """ - Returns a URL string, constructed from the Request's URL components. - """ - return utils.unparse_url( - self.scheme, - self.host, - self.port, - self.path - ).encode('ascii') - - @url.setter - def url(self, url): - """ - Parses a URL specification, and updates the Request's information - accordingly. - - Returns False if the URL was invalid, True if the request succeeded. 
- """ - parts = http.parse_url(url) - if not parts: - raise ValueError("Invalid URL: %s" % url) - self.scheme, self.host, self.port, self.path = parts - - def get_cookies(self): - """ - - Returns a possibly empty netlib.odict.ODict object. - """ - ret = odict.ODict() - for i in self.headers["cookie"]: - ret.extend(cookies.parse_cookie_header(i)) - return ret - - def set_cookies(self, odict): - """ - Takes an netlib.odict.ODict object. Over-writes any existing Cookie - headers. - """ - v = cookies.format_cookie_header(odict) - self.headers["Cookie"] = [v] - - def replace(self, pattern, repl, *args, **kwargs): - """ - Replaces a regular expression pattern with repl in the headers, the - request path and the body of the request. Encoded content will be - decoded before replacement, and re-encoded afterwards. - - Returns the number of replacements made. - """ - c = HTTPMessage.replace(self, pattern, repl, *args, **kwargs) - self.path, pc = utils.safe_subn( - pattern, repl, self.path, *args, **kwargs - ) - c += pc - return c - - -class HTTPResponse(HTTPMessage): - """ - An HTTP response. - - Exposes the following attributes: - - httpversion: HTTP version tuple, e.g. (1,1) - - code: HTTP response code - - msg: HTTP response message - - headers: ODict object - - content: Content of the request, None, or CONTENT_MISSING if there - is content associated, but not present. CONTENT_MISSING evaluates - to False to make checking for the presence of content natural. 
- - timestamp_start: Timestamp indicating when request transmission started - - timestamp_end: Timestamp indicating when request transmission ended - """ - - def __init__( - self, - httpversion, - code, - msg, - headers, - content, - timestamp_start=None, - timestamp_end=None): - assert isinstance(headers, odict.ODictCaseless) or headers is None - HTTPMessage.__init__( - self, - httpversion, - headers, - content, - timestamp_start, - timestamp_end - ) - - self.code = code - self.msg = msg - - # Is this request replayed? - self.is_replay = False - self.stream = False - - _stateobject_attributes = HTTPMessage._stateobject_attributes.copy() - _stateobject_attributes.update( - code=int, - msg=str - ) - - - @property - def body(self): - return self.content - - - @classmethod - def from_state(cls, state): - f = cls(None, None, None, None, None) - f.load_state(state) - return f - - def __repr__(self): - if self.content: - size = netlib.utils.pretty_size(len(self.content)) - else: - size = "content missing" - return "".format( - code=self.code, - msg=self.msg, - contenttype=self.headers.get_first( - "content-type", "unknown content type" - ), - size=size - ) - - @classmethod - def from_stream( - cls, - rfile, - request_method, - include_body=True, - body_size_limit=None): - """ - Parse an HTTP response from a file stream - """ - - timestamp_start = utils.timestamp() - - if hasattr(rfile, "reset_timestamps"): - rfile.reset_timestamps() - - protocol = http1.HTTP1Protocol(rfile=rfile) - resp = protocol.read_response( - request_method, - body_size_limit, - include_body=include_body - ) - - if hasattr(rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = rfile.first_byte_timestamp - - if include_body: - timestamp_end = utils.timestamp() - else: - timestamp_end = None - - return HTTPResponse( - resp.httpversion, - resp.status_code, - resp.msg, - resp.headers, - resp.body, - timestamp_start, - timestamp_end - ) - - def _assemble_first_line(self): - 
return 'HTTP/%s.%s %s %s' % \ - (self.httpversion[0], self.httpversion[1], self.code, self.msg) - - _headers_to_strip_off = ['Proxy-Connection', - 'Alternate-Protocol', - 'Alt-Svc'] - - def _assemble_headers(self, preserve_transfer_encoding=False): - headers = self.headers.copy() - for k in self._headers_to_strip_off: - del headers[k] - if not preserve_transfer_encoding: - del headers['Transfer-Encoding'] - - # If content is defined (i.e. not None or CONTENT_MISSING), we always - # add a content-length header. - if self.content or self.content == "": - headers["Content-Length"] = [str(len(self.content))] - - return headers.format() - - def _assemble_head(self, preserve_transfer_encoding=False): - return '%s\r\n%s\r\n' % ( - self._assemble_first_line(), - self._assemble_headers( - preserve_transfer_encoding=preserve_transfer_encoding - ) - ) - - def assemble(self): - """ - Assembles the response for transmission to the client. We make some - modifications to make sure interception works properly. - - Raises an Exception if the request cannot be assembled. - """ - if self.content == CONTENT_MISSING: - raise proxy.ProxyError( - 502, - "Cannot assemble flow with CONTENT_MISSING" - ) - head = self._assemble_head() - if self.content: - return head + self.content - else: - return head - - def _refresh_cookie(self, c, delta): - """ - Takes a cookie string c and a time delta in seconds, and returns - a refreshed cookie string. - """ - c = Cookie.SimpleCookie(str(c)) - for i in c.values(): - if "expires" in i: - d = parsedate_tz(i["expires"]) - if d: - d = mktime_tz(d) + delta - i["expires"] = formatdate(d) - else: - # This can happen when the expires tag is invalid. - # reddit.com sends a an expires tag like this: "Thu, 31 Dec - # 2037 23:59:59 GMT", which is valid RFC 1123, but not - # strictly correct according to the cookie spec. Browsers - # appear to parse this tolerantly - maybe we should too. - # For now, we just ignore this. 
- del i["expires"] - return c.output(header="").strip() - - def refresh(self, now=None): - """ - This fairly complex and heuristic function refreshes a server - response for replay. - - - It adjusts date, expires and last-modified headers. - - It adjusts cookie expiration. - """ - if not now: - now = time.time() - delta = now - self.timestamp_start - refresh_headers = [ - "date", - "expires", - "last-modified", - ] - for i in refresh_headers: - if i in self.headers: - d = parsedate_tz(self.headers[i][0]) - if d: - new = mktime_tz(d) + delta - self.headers[i] = [formatdate(new)] - c = [] - for i in self.headers["set-cookie"]: - c.append(self._refresh_cookie(i, delta)) - if c: - self.headers["set-cookie"] = c - - def get_cookies(self): - """ - Get the contents of all Set-Cookie headers. - - Returns a possibly empty ODict, where keys are cookie name strings, - and values are [value, attr] lists. Value is a string, and attr is - an ODictCaseless containing cookie attributes. Within attrs, unary - attributes (e.g. HTTPOnly) are indicated by a Null value. - """ - ret = [] - for header in self.headers["set-cookie"]: - v = http.cookies.parse_set_cookie_header(header) - if v: - name, value, attrs = v - ret.append([name, [value, attrs]]) - return odict.ODict(ret) - - def set_cookies(self, odict): - """ - Set the Set-Cookie headers on this response, over-writing existing - headers. - - Accepts an ODict of the same format as that returned by get_cookies. - """ - values = [] - for i in odict.lst: - values.append( - http.cookies.format_set_cookie_header( - i[0], - i[1][0], - i[1][1] - ) - ) - self.headers["Set-Cookie"] = values - - class HTTPFlow(Flow): """ A HTTPFlow is a collection of objects representing a single HTTP @@ -1054,9 +179,11 @@ class HTTPHandler(ProtocolHandler): for attempt in (0, 1): try: self.c.server_conn.send(request_raw) + # Only get the headers at first... 
- flow.response = HTTPResponse.from_stream( - self.c.server_conn.rfile, + protocol = http.http1.HTTP1Protocol(self.c.server_conn) + flow.response = HTTPResponse.from_protocol( + protocol, flow.request.method, body_size_limit=self.c.config.body_size_limit, include_body=False @@ -1094,7 +221,7 @@ class HTTPHandler(ProtocolHandler): if flow.response.stream: flow.response.content = CONTENT_MISSING else: - protocol = http1.HTTP1Protocol(rfile=self.c.server_conn.rfile) + protocol = http1.HTTP1Protocol(self.c.server_conn) flow.response.content = protocol.read_http_body( flow.response.headers, self.c.config.body_size_limit, @@ -1108,10 +235,10 @@ class HTTPHandler(ProtocolHandler): flow = HTTPFlow(self.c.client_conn, self.c.server_conn, self.live) try: try: - req = HTTPRequest.from_stream( - self.c.client_conn.rfile, - body_size_limit=self.c.config.body_size_limit, - wfile=self.c.client_conn.wfile + protocol = http.http1.HTTP1Protocol(self.c.client_conn) + req = HTTPRequest.from_protocol( + protocol, + body_size_limit=self.c.config.body_size_limit ) except tcp.NetLibError: # don't throw an error for disconnects that happen @@ -1601,10 +728,12 @@ class RequestReplayThread(threading.Thread): r.form_out = "relative" server.send(r.assemble()) self.flow.server_conn = server - self.flow.response = HTTPResponse.from_stream( - server.rfile, + + protocol = http.http1.HTTP1Protocol(server) + self.flow.response = HTTPResponse.from_protocol( + protocol, r.method, - body_size_limit=self.config.body_size_limit + body_size_limit=self.config.body_size_limit, ) if self.channel: response_reply = self.channel.ask("response", self.flow) diff --git a/libmproxy/protocol/http_wrappers.py b/libmproxy/protocol/http_wrappers.py new file mode 100644 index 00000000..7d3e3706 --- /dev/null +++ b/libmproxy/protocol/http_wrappers.py @@ -0,0 +1,829 @@ +from __future__ import absolute_import +import Cookie +import copy +import threading +import time +import urllib +import urlparse +from email.utils import 
parsedate_tz, formatdate, mktime_tz + +import netlib +from netlib import http, tcp, odict, utils +from netlib.http import cookies, semantics, http1 + +from .tcp import TCPHandler +from .primitives import KILL, ProtocolHandler, Flow, Error +from ..proxy.connection import ServerConnection +from .. import encoding, utils, controller, stateobject, proxy + + +HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" +HDR_FORM_MULTIPART = "multipart/form-data" +CONTENT_MISSING = 0 + + +class decoded(object): + """ + A context manager that decodes a request or response, and then + re-encodes it with the same encoding after execution of the block. + + Example: + with decoded(request): + request.content = request.content.replace("foo", "bar") + """ + + def __init__(self, o): + self.o = o + ce = o.headers.get_first("content-encoding") + if ce in encoding.ENCODINGS: + self.ce = ce + else: + self.ce = None + + def __enter__(self): + if self.ce: + self.o.decode() + + def __exit__(self, type, value, tb): + if self.ce: + self.o.encode(self.ce) + + +class MessageMixin(stateobject.StateObject): + _stateobject_attributes = dict( + httpversion=tuple, + headers=odict.ODictCaseless, + body=str, + timestamp_start=float, + timestamp_end=float + ) + _stateobject_long_attributes = {"body"} + + def get_state(self, short=False): + ret = super(MessageMixin, self).get_state(short) + if short: + if self.body: + ret["contentLength"] = len(self.body) + elif self.body == CONTENT_MISSING: + ret["contentLength"] = None + else: + ret["contentLength"] = 0 + return ret + + def get_decoded_content(self): + """ + Returns the decoded content based on the current Content-Encoding + header. + Doesn't change the message iteself or its headers. 
+ """ + ce = self.headers.get_first("content-encoding") + if not self.body or ce not in encoding.ENCODINGS: + return self.body + return encoding.decode(ce, self.body) + + def decode(self): + """ + Decodes body based on the current Content-Encoding header, then + removes the header. If there is no Content-Encoding header, no + action is taken. + + Returns True if decoding succeeded, False otherwise. + """ + ce = self.headers.get_first("content-encoding") + if not self.body or ce not in encoding.ENCODINGS: + return False + data = encoding.decode(ce, self.body) + if data is None: + return False + self.body = data + del self.headers["content-encoding"] + return True + + def encode(self, e): + """ + Encodes body with the encoding e, where e is "gzip", "deflate" + or "identity". + """ + # FIXME: Error if there's an existing encoding header? + self.body = encoding.encode(e, self.body) + self.headers["content-encoding"] = [e] + + def size(self, **kwargs): + """ + Size in bytes of a fully rendered message, including headers and + HTTP lead-in. + """ + hl = len(self._assemble_head(**kwargs)) + if self.body: + return hl + len(self.body) + else: + return hl + + def copy(self): + c = copy.copy(self) + c.headers = self.headers.copy() + return c + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in both the headers + and the body of the message. Encoded body will be decoded + before replacement, and re-encoded afterwards. + + Returns the number of replacements made. 
+ """ + with decoded(self): + self.body, c = utils.safe_subn( + pattern, repl, self.body, *args, **kwargs + ) + c += self.headers.replace(pattern, repl, *args, **kwargs) + return c + + def _assemble_first_line(self): + """ + Returns the assembled request/response line + """ + raise NotImplementedError() # pragma: nocover + + def _assemble_headers(self): + """ + Returns the assembled headers + """ + raise NotImplementedError() # pragma: nocover + + def _assemble_head(self): + """ + Returns the assembled request/response line plus headers + """ + raise NotImplementedError() # pragma: nocover + + def assemble(self): + """ + Returns the assembled request/response + """ + raise NotImplementedError() # pragma: nocover + + +class HTTPRequest(MessageMixin, semantics.Request): + """ + An HTTP request. + + Exposes the following attributes: + + method: HTTP method + + scheme: URL scheme (http/https) + + host: Target hostname of the request. This is not neccessarily the + directy upstream server (which could be another proxy), but it's always + the target server we want to reach at the end. This attribute is either + inferred from the request itself (absolute-form, authority-form) or from + the connection metadata (e.g. the host in reverse proxy mode). + + port: Destination port + + path: Path portion of the URL (not present in authority-form) + + httpversion: HTTP version tuple, e.g. (1,1) + + headers: odict.ODictCaseless object + + content: Content of the request, None, or CONTENT_MISSING if there + is content associated, but not present. CONTENT_MISSING evaluates + to False to make checking for the presence of content natural. + + form_in: The request form which mitmproxy has received. 
The following + values are possible: + + - relative (GET /index.html, OPTIONS *) (covers origin form and + asterisk form) + - absolute (GET http://example.com:80/index.html) + - authority-form (CONNECT example.com:443) + Details: http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-25#section-5.3 + + form_out: The request form which mitmproxy will send out to the + destination + + timestamp_start: Timestamp indicating when request transmission started + + timestamp_end: Timestamp indicating when request transmission ended + """ + + def __init__( + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + body, + timestamp_start=None, + timestamp_end=None, + form_out=None, + ): + semantics.Request.__init__( + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + body, + timestamp_start, + timestamp_end, + ) + self.form_out = form_out or form_in + + # Have this request's cookies been modified by sticky cookies or auth? + self.stickycookie = False + self.stickyauth = False + + # Is this request replayed? + self.is_replay = False + + _stateobject_attributes = MessageMixin._stateobject_attributes.copy() + _stateobject_attributes.update( + form_in=str, + method=str, + scheme=str, + host=str, + port=int, + path=str, + form_out=str, + is_replay=bool + ) + + # This list is adopted legacy code. + # We probably don't need to strip off keep-alive. 
+    _headers_to_strip_off = ['Proxy-Connection',
+                             'Keep-Alive',
+                             'Connection',
+                             'Transfer-Encoding',
+                             'Upgrade']
+
+    @classmethod
+    def from_state(cls, state):
+        f = cls(
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None)
+        f.load_state(state)
+        return f
+
+    def __repr__(self):
+        return "<HTTPRequest: {0}>".format(
+            self._assemble_first_line(self.form_in)[:-9]
+        )
+
+    @classmethod
+    def from_protocol(
+            self,
+            protocol,
+            include_body=True,
+            body_size_limit=None,
+    ):
+        req = protocol.read_request(
+            include_body = include_body,
+            body_size_limit = body_size_limit,
+        )
+
+        return HTTPRequest(
+            req.form_in,
+            req.method,
+            req.scheme,
+            req.host,
+            req.port,
+            req.path,
+            req.httpversion,
+            req.headers,
+            req.body,
+            req.timestamp_start,
+            req.timestamp_end,
+        )
+
+    def _assemble_first_line(self, form=None):
+        form = form or self.form_out
+
+        if form == "relative":
+            request_line = '%s %s HTTP/%s.%s' % (
+                self.method, self.path, self.httpversion[0], self.httpversion[1]
+            )
+        elif form == "authority":
+            request_line = '%s %s:%s HTTP/%s.%s' % (
+                self.method, self.host, self.port, self.httpversion[0],
+                self.httpversion[1]
+            )
+        elif form == "absolute":
+            request_line = '%s %s://%s:%s%s HTTP/%s.%s' % (
+                self.method, self.scheme, self.host,
+                self.port, self.path, self.httpversion[0],
+                self.httpversion[1]
+            )
+        else:
+            raise http.HttpError(400, "Invalid request form")
+        return request_line
+
+    def _assemble_headers(self):
+        headers = self.headers.copy()
+        for k in self._headers_to_strip_off:
+            del headers[k]
+        if 'host' not in headers and self.scheme and self.host and self.port:
+            headers["Host"] = [utils.hostport(self.scheme,
+                                              self.host,
+                                              self.port)]
+
+        # If content is defined (i.e. not None or CONTENT_MISSING), we always
+        # add a content-length header. 
+ if self.body or self.body == "": + headers["Content-Length"] = [str(len(self.body))] + + return headers.format() + + def _assemble_head(self, form=None): + return "%s\r\n%s\r\n" % ( + self._assemble_first_line(form), self._assemble_headers() + ) + + def assemble(self, form=None): + """ + Assembles the request for transmission to the server. We make some + modifications to make sure interception works properly. + + Raises an Exception if the request cannot be assembled. + """ + if self.body == CONTENT_MISSING: + raise proxy.ProxyError( + 502, + "Cannot assemble flow with CONTENT_MISSING" + ) + head = self._assemble_head(form) + if self.body: + return head + self.body + else: + return head + + def __hash__(self): + return id(self) + + def anticache(self): + """ + Modifies this request to remove headers that might produce a cached + response. That is, we remove ETags and If-Modified-Since headers. + """ + delheaders = [ + "if-modified-since", + "if-none-match", + ] + for i in delheaders: + del self.headers[i] + + def anticomp(self): + """ + Modifies this request to remove headers that will compress the + resource's data. + """ + self.headers["accept-encoding"] = ["identity"] + + def constrain_encoding(self): + """ + Limits the permissible Accept-Encoding values, based on what we can + decode appropriately. + """ + if self.headers["accept-encoding"]: + self.headers["accept-encoding"] = [ + ', '.join( + e for e in encoding.ENCODINGS if e in self.headers["accept-encoding"][0])] + + def update_host_header(self): + """ + Update the host header to reflect the current target. + """ + self.headers["Host"] = [self.host] + + def get_form(self): + """ + Retrieves the URL-encoded or multipart form data, returning an ODict object. + Returns an empty ODict if there is no data or the content-type + indicates non-form data. 
+ """ + if self.body: + if self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): + return self.get_form_urlencoded() + elif self.headers.in_any("content-type", HDR_FORM_MULTIPART, True): + return self.get_form_multipart() + return odict.ODict([]) + + def get_form_urlencoded(self): + """ + Retrieves the URL-encoded form data, returning an ODict object. + Returns an empty ODict if there is no data or the content-type + indicates non-form data. + """ + if self.body and self.headers.in_any( + "content-type", + HDR_FORM_URLENCODED, + True): + return odict.ODict(utils.urldecode(self.body)) + return odict.ODict([]) + + def get_form_multipart(self): + if self.body and self.headers.in_any( + "content-type", + HDR_FORM_MULTIPART, + True): + return odict.ODict( + utils.multipartdecode( + self.headers, + self.body)) + return odict.ODict([]) + + def set_form_urlencoded(self, odict): + """ + Sets the body to the URL-encoded form data, and adds the + appropriate content-type header. Note that this will destory the + existing body if there is one. + """ + # FIXME: If there's an existing content-type header indicating a + # url-encoded form, leave it alone. + self.headers["Content-Type"] = [HDR_FORM_URLENCODED] + self.body = utils.urlencode(odict.lst) + + def get_path_components(self): + """ + Returns the path components of the URL as a list of strings. + + Components are unquoted. + """ + _, _, path, _, _, _ = urlparse.urlparse(self.url) + return [urllib.unquote(i) for i in path.split("/") if i] + + def set_path_components(self, lst): + """ + Takes a list of strings, and sets the path component of the URL. + + Components are quoted. + """ + lst = [urllib.quote(i, safe="") for i in lst] + path = "/" + "/".join(lst) + scheme, netloc, _, params, query, fragment = urlparse.urlparse(self.url) + self.url = urlparse.urlunparse( + [scheme, netloc, path, params, query, fragment] + ) + + def get_query(self): + """ + Gets the request query string. Returns an ODict object. 
+ """ + _, _, _, _, query, _ = urlparse.urlparse(self.url) + if query: + return odict.ODict(utils.urldecode(query)) + return odict.ODict([]) + + def set_query(self, odict): + """ + Takes an ODict object, and sets the request query string. + """ + scheme, netloc, path, params, _, fragment = urlparse.urlparse(self.url) + query = utils.urlencode(odict.lst) + self.url = urlparse.urlunparse( + [scheme, netloc, path, params, query, fragment] + ) + + def pretty_host(self, hostheader): + """ + Heuristic to get the host of the request. + + Note that pretty_host() does not always return the TCP destination + of the request, e.g. if an upstream proxy is in place + + If hostheader is set to True, the Host: header will be used as + additional (and preferred) data source. This is handy in + transparent mode, where only the IO of the destination is known, + but not the resolved name. This is disabled by default, as an + attacker may spoof the host header to confuse an analyst. + """ + host = None + if hostheader: + host = self.headers.get_first("host") + if not host: + host = self.host + if host: + try: + return host.encode("idna") + except ValueError: + return host + else: + return None + + def pretty_url(self, hostheader): + if self.form_out == "authority": # upstream proxy mode + return "%s:%s" % (self.pretty_host(hostheader), self.port) + return utils.unparse_url(self.scheme, + self.pretty_host(hostheader), + self.port, + self.path).encode('ascii') + + @property + def url(self): + """ + Returns a URL string, constructed from the Request's URL components. + """ + return utils.unparse_url( + self.scheme, + self.host, + self.port, + self.path + ).encode('ascii') + + @url.setter + def url(self, url): + """ + Parses a URL specification, and updates the Request's information + accordingly. + + Returns False if the URL was invalid, True if the request succeeded. 
+ """ + parts = http.parse_url(url) + if not parts: + raise ValueError("Invalid URL: %s" % url) + self.scheme, self.host, self.port, self.path = parts + + def get_cookies(self): + """ + + Returns a possibly empty netlib.odict.ODict object. + """ + ret = odict.ODict() + for i in self.headers["cookie"]: + ret.extend(cookies.parse_cookie_header(i)) + return ret + + def set_cookies(self, odict): + """ + Takes an netlib.odict.ODict object. Over-writes any existing Cookie + headers. + """ + v = cookies.format_cookie_header(odict) + self.headers["Cookie"] = [v] + + def replace(self, pattern, repl, *args, **kwargs): + """ + Replaces a regular expression pattern with repl in the headers, the + request path and the body of the request. Encoded content will be + decoded before replacement, and re-encoded afterwards. + + Returns the number of replacements made. + """ + c = MessageMixin.replace(self, pattern, repl, *args, **kwargs) + self.path, pc = utils.safe_subn( + pattern, repl, self.path, *args, **kwargs + ) + c += pc + return c + + +class HTTPResponse(MessageMixin, semantics.Response): + """ + An HTTP response. + + Exposes the following attributes: + + httpversion: HTTP version tuple, e.g. (1, 0), (1, 1), or (2, 0) + + status_code: HTTP response status code + + msg: HTTP response message + + headers: ODict Caseless object + + content: Content of the request, None, or CONTENT_MISSING if there + is content associated, but not present. CONTENT_MISSING evaluates + to False to make checking for the presence of content natural. 
+
+        timestamp_start: Timestamp indicating when request transmission started
+
+        timestamp_end: Timestamp indicating when request transmission ended
+    """
+
+    def __init__(
+            self,
+            httpversion,
+            status_code,
+            msg,
+            headers,
+            body,
+            timestamp_start=None,
+            timestamp_end=None,
+    ):
+        semantics.Response.__init__(
+            self,
+            httpversion,
+            status_code,
+            msg,
+            headers,
+            body,
+            timestamp_start=timestamp_start,
+            timestamp_end=timestamp_end,
+        )
+
+        # Is this request replayed?
+        self.is_replay = False
+        self.stream = False
+
+    _stateobject_attributes = MessageMixin._stateobject_attributes.copy()
+    _stateobject_attributes.update(
+        code=int,
+        msg=str
+    )
+
+    _headers_to_strip_off = ['Proxy-Connection',
+                             'Alternate-Protocol',
+                             'Alt-Svc']
+
+
+    @classmethod
+    def from_state(cls, state):
+        f = cls(None, None, None, None, None)
+        f.load_state(state)
+        return f
+
+    def __repr__(self):
+        if self.body:
+            size = netlib.utils.pretty_size(len(self.body))
+        else:
+            size = "content missing"
+        return "<HTTPResponse: {status_code} {msg} ({contenttype}, {size})>".format(
+            status_code=self.status_code,
+            msg=self.msg,
+            contenttype=self.headers.get_first(
+                "content-type", "unknown content type"
+            ),
+            size=size
+        )
+
+    @classmethod
+    def from_protocol(
+            self,
+            protocol,
+            request_method,
+            include_body=True,
+            body_size_limit=None
+    ):
+        resp = protocol.read_response(
+            request_method,
+            body_size_limit,
+            include_body=include_body
+        )
+
+        return HTTPResponse(
+            resp.httpversion,
+            resp.status_code,
+            resp.msg,
+            resp.headers,
+            resp.body,
+            resp.timestamp_start,
+            resp.timestamp_end,
+        )
+
+    def _assemble_first_line(self):
+        return 'HTTP/%s.%s %s %s' % \
+            (self.httpversion[0], self.httpversion[1], self.code, self.msg)
+
+    def _assemble_headers(self, preserve_transfer_encoding=False):
+        headers = self.headers.copy()
+        for k in self._headers_to_strip_off:
+            del headers[k]
+        if not preserve_transfer_encoding:
+            del headers['Transfer-Encoding']
+
+        # If body is defined (i.e. 
not None or CONTENT_MISSING), we always
+        # add a content-length header.
+        if self.body or self.body == "":
+            headers["Content-Length"] = [str(len(self.body))]
+
+        return headers.format()
+
+    def _assemble_head(self, preserve_transfer_encoding=False):
+        return '%s\r\n%s\r\n' % (
+            self._assemble_first_line(),
+            self._assemble_headers(
+                preserve_transfer_encoding=preserve_transfer_encoding
+            )
+        )
+
+    def assemble(self):
+        """
+        Assembles the response for transmission to the client. We make some
+        modifications to make sure interception works properly.
+
+        Raises an Exception if the request cannot be assembled.
+        """
+        if self.body == CONTENT_MISSING:
+            raise proxy.ProxyError(
+                502,
+                "Cannot assemble flow with CONTENT_MISSING"
+            )
+        head = self._assemble_head()
+        if self.body:
+            return head + self.body
+        else:
+            return head
+
+    def _refresh_cookie(self, c, delta):
+        """
+        Takes a cookie string c and a time delta in seconds, and returns
+        a refreshed cookie string.
+        """
+        c = Cookie.SimpleCookie(str(c))
+        for i in c.values():
+            if "expires" in i:
+                d = parsedate_tz(i["expires"])
+                if d:
+                    d = mktime_tz(d) + delta
+                    i["expires"] = formatdate(d)
+                else:
+                    # This can happen when the expires tag is invalid.
+                    # reddit.com sends an expires tag like this: "Thu, 31 Dec
+                    # 2037 23:59:59 GMT", which is valid RFC 1123, but not
+                    # strictly correct according to the cookie spec. Browsers
+                    # appear to parse this tolerantly - maybe we should too.
+                    # For now, we just ignore this.
+                    del i["expires"]
+        return c.output(header="").strip()
+
+    def refresh(self, now=None):
+        """
+        This fairly complex and heuristic function refreshes a server
+        response for replay.
+
+        - It adjusts date, expires and last-modified headers.
+        - It adjusts cookie expiration. 
+ """ + if not now: + now = time.time() + delta = now - self.timestamp_start + refresh_headers = [ + "date", + "expires", + "last-modified", + ] + for i in refresh_headers: + if i in self.headers: + d = parsedate_tz(self.headers[i][0]) + if d: + new = mktime_tz(d) + delta + self.headers[i] = [formatdate(new)] + c = [] + for i in self.headers["set-cookie"]: + c.append(self._refresh_cookie(i, delta)) + if c: + self.headers["set-cookie"] = c + + def get_cookies(self): + """ + Get the contents of all Set-Cookie headers. + + Returns a possibly empty ODict, where keys are cookie name strings, + and values are [value, attr] lists. Value is a string, and attr is + an ODictCaseless containing cookie attributes. Within attrs, unary + attributes (e.g. HTTPOnly) are indicated by a Null value. + """ + ret = [] + for header in self.headers["set-cookie"]: + v = http.cookies.parse_set_cookie_header(header) + if v: + name, value, attrs = v + ret.append([name, [value, attrs]]) + return odict.ODict(ret) + + def set_cookies(self, odict): + """ + Set the Set-Cookie headers on this response, over-writing existing + headers. + + Accepts an ODict of the same format as that returned by get_cookies. 
+ """ + values = [] + for i in odict.lst: + values.append( + http.cookies.format_set_cookie_header( + i[0], + i[1][0], + i[1][1] + ) + ) + self.headers["Set-Cookie"] = values diff --git a/libmproxy/proxy/connection.py b/libmproxy/proxy/connection.py index 5219023b..54b3688e 100644 --- a/libmproxy/proxy/connection.py +++ b/libmproxy/proxy/connection.py @@ -68,7 +68,15 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): return f def convert_to_ssl(self, *args, **kwargs): - tcp.BaseHandler.convert_to_ssl(self, *args, **kwargs) + def alpn_select_callback(conn_, options): + if alpn_select in options: + return bytes(alpn_select) + else: # pragma no cover + return options[0] + + # TODO: read ALPN from server and select same proto for client conn + + tcp.BaseHandler.convert_to_ssl(self, alpn_select=alpn_select_callback, *args, **kwargs) self.timestamp_ssl_setup = utils.timestamp() def finish(self): @@ -160,7 +168,10 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): self.address.host.encode("idna")) + ".pem" if os.path.exists(path): clientcert = path - self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs) + + # TODO: read ALPN from client and use same list for server conn + + self.convert_to_ssl(cert=clientcert, sni=sni, alpn_protos=['h2'], **kwargs) self.sni = sni self.timestamp_ssl_setup = utils.timestamp() diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py index 747fdc1e..18238593 100644 --- a/test/test_protocol_http.py +++ b/test/test_protocol_http.py @@ -1,13 +1,21 @@ +import cStringIO from cStringIO import StringIO from mock import MagicMock from libmproxy.protocol.http import * from netlib import odict +from netlib.http import http1 import tutils import tservers +def mock_protocol(data='', chunked=False): + rfile = cStringIO.StringIO(data) + wfile = cStringIO.StringIO() + return http1.HTTP1Protocol(rfile=rfile, wfile=wfile) + + def test_HttpAuthenticationError(): x = HttpAuthenticationError({"foo": "bar"}) @@ 
-30,9 +38,10 @@ def test_stripped_chunked_encoding_no_content(): class TestHTTPRequest: def test_asterisk_form_in(self): - s = StringIO("OPTIONS * HTTP/1.1") f = tutils.tflow(req=None) - f.request = HTTPRequest.from_stream(s) + protocol = mock_protocol("OPTIONS * HTTP/1.1") + f.request = HTTPRequest.from_protocol(protocol) + assert f.request.form_in == "relative" f.request.host = f.server_conn.address.host f.request.port = f.server_conn.address.port @@ -42,10 +51,11 @@ class TestHTTPRequest: "Content-Length: 0\r\n\r\n") def test_relative_form_in(self): - s = StringIO("GET /foo\xff HTTP/1.1") - tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) - s = StringIO("GET /foo HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: h2c") - r = HTTPRequest.from_stream(s) + protocol = mock_protocol("GET /foo\xff HTTP/1.1") + tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) + + protocol = mock_protocol("GET /foo HTTP/1.1\r\nConnection: Upgrade\r\nUpgrade: h2c") + r = HTTPRequest.from_protocol(protocol) assert r.headers["Upgrade"] == ["h2c"] raw = r._assemble_headers() @@ -61,19 +71,19 @@ class TestHTTPRequest: assert "Host" in r.headers def test_expect_header(self): - s = StringIO( + protocol = mock_protocol( "GET / HTTP/1.1\r\nContent-Length: 3\r\nExpect: 100-continue\r\n\r\nfoobar") - w = StringIO() - r = HTTPRequest.from_stream(s, wfile=w) - assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" + r = HTTPRequest.from_protocol(protocol) + assert protocol.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" assert r.content == "foo" - assert s.read(3) == "bar" + assert protocol.tcp_handler.rfile.read(3) == "bar" def test_authority_form_in(self): - s = StringIO("CONNECT oops-no-port.com HTTP/1.1") - tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) - s = StringIO("CONNECT address:22 HTTP/1.1") - r = HTTPRequest.from_stream(s) + protocol = mock_protocol("CONNECT oops-no-port.com HTTP/1.1") + tutils.raises("Bad HTTP 
request line", HTTPRequest.from_protocol, protocol) + + protocol = mock_protocol("CONNECT address:22 HTTP/1.1") + r = HTTPRequest.from_protocol(protocol) r.scheme, r.host, r.port = "http", "address", 22 assert r.assemble() == ("CONNECT address:22 HTTP/1.1\r\n" "Host: address:22\r\n" @@ -81,10 +91,11 @@ class TestHTTPRequest: assert r.pretty_url(False) == "address:22" def test_absolute_form_in(self): - s = StringIO("GET oops-no-protocol.com HTTP/1.1") - tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) - s = StringIO("GET http://address:22/ HTTP/1.1") - r = HTTPRequest.from_stream(s) + protocol = mock_protocol("GET oops-no-protocol.com HTTP/1.1") + tutils.raises("Bad HTTP request line", HTTPRequest.from_protocol, protocol) + + protocol = mock_protocol("GET http://address:22/ HTTP/1.1") + r = HTTPRequest.from_protocol(protocol) assert r.assemble( ) == "GET http://address:22/ HTTP/1.1\r\nHost: address:22\r\nContent-Length: 0\r\n\r\n" @@ -92,8 +103,8 @@ class TestHTTPRequest: """ Exercises fix for Issue #392. 
""" - s = StringIO("OPTIONS /secret/resource HTTP/1.1") - r = HTTPRequest.from_stream(s) + protocol = mock_protocol("OPTIONS /secret/resource HTTP/1.1") + r = HTTPRequest.from_protocol(protocol) r.host = 'address' r.port = 80 r.scheme = "http" @@ -102,8 +113,8 @@ class TestHTTPRequest: "Content-Length: 0\r\n\r\n") def test_http_options_absolute_form_in(self): - s = StringIO("OPTIONS http://address/secret/resource HTTP/1.1") - r = HTTPRequest.from_stream(s) + protocol = mock_protocol("OPTIONS http://address/secret/resource HTTP/1.1") + r = HTTPRequest.from_protocol(protocol) r.host = 'address' r.port = 80 r.scheme = "http" @@ -216,26 +227,27 @@ class TestHTTPRequest: class TestHTTPResponse: def test_read_from_stringio(self): - _s = "HTTP/1.1 200 OK\r\n" \ + s = "HTTP/1.1 200 OK\r\n" \ "Content-Length: 7\r\n" \ "\r\n"\ "content\r\n" \ "HTTP/1.1 204 OK\r\n" \ "\r\n" - s = StringIO(_s) - r = HTTPResponse.from_stream(s, "GET") - assert r.code == 200 + + protocol = mock_protocol(s) + r = HTTPResponse.from_protocol(protocol, "GET") + assert r.status_code == 200 assert r.content == "content" - assert HTTPResponse.from_stream(s, "GET").code == 204 + assert HTTPResponse.from_protocol(protocol, "GET").status_code == 204 - s = StringIO(_s) + protocol = mock_protocol(s) # HEAD must not have content by spec. We should leave it on the pipe. - r = HTTPResponse.from_stream(s, "HEAD") - assert r.code == 200 + r = HTTPResponse.from_protocol(protocol, "HEAD") + assert r.status_code == 200 assert r.content == "" tutils.raises( "Invalid server response: 'content", - HTTPResponse.from_stream, s, "GET" + HTTPResponse.from_protocol, protocol, "GET" ) def test_repr(self): -- cgit v1.2.3