diff options
author | Maximilian Hils <git@maximilianhils.com> | 2015-09-05 18:15:47 +0200 |
---|---|---|
committer | Maximilian Hils <git@maximilianhils.com> | 2015-09-05 18:15:47 +0200 |
commit | 66ee1f465f6c492d5a4ff5659e6f0346fb243d67 (patch) | |
tree | 81599af6bf38402059dcf6f387dfcf9b599c375e /netlib | |
parent | 3718e59308745e4582f4e8061b4ff6113d9dfc74 (diff) | |
download | mitmproxy-66ee1f465f6c492d5a4ff5659e6f0346fb243d67.tar.gz mitmproxy-66ee1f465f6c492d5a4ff5659e6f0346fb243d67.tar.bz2 mitmproxy-66ee1f465f6c492d5a4ff5659e6f0346fb243d67.zip |
headers: adjust everything
Diffstat (limited to 'netlib')
-rw-r--r-- | netlib/http/authentication.py | 4 | ||||
-rw-r--r-- | netlib/http/exceptions.py | 18 | ||||
-rw-r--r-- | netlib/http/http1/protocol.py | 41 | ||||
-rw-r--r-- | netlib/http/http2/protocol.py | 44 | ||||
-rw-r--r-- | netlib/http/semantics.py | 148 | ||||
-rw-r--r-- | netlib/tutils.py | 10 | ||||
-rw-r--r-- | netlib/utils.py | 13 | ||||
-rw-r--r-- | netlib/websockets/protocol.py | 28 | ||||
-rw-r--r-- | netlib/wsgi.py | 22 |
9 files changed, 155 insertions, 173 deletions
diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 29b9eb3c..fe1f0d14 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -62,10 +62,10 @@ class BasicProxyAuth(NullProxyAuth): del headers[self.AUTH_HEADER] def authenticate(self, headers): - auth_value = headers.get(self.AUTH_HEADER, []) + auth_value = headers.get(self.AUTH_HEADER) if not auth_value: return False - parts = parse_http_basic_auth(auth_value[0]) + parts = parse_http_basic_auth(auth_value) if not parts: return False scheme, username, password = parts diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py index 987a7908..8a2bbebc 100644 --- a/netlib/http/exceptions.py +++ b/netlib/http/exceptions.py @@ -1,6 +1,3 @@ -from netlib import odict - - class HttpError(Exception): def __init__(self, code, message): @@ -10,18 +7,3 @@ class HttpError(Exception): class HttpErrorConnClosed(HttpError): pass - - -class HttpAuthenticationError(Exception): - - def __init__(self, auth_headers=None): - super(HttpAuthenticationError, self).__init__( - "Proxy Authentication Required" - ) - if isinstance(auth_headers, dict): - auth_headers = odict.ODictCaseless(auth_headers.items()) - self.headers = auth_headers - self.code = 407 - - def __repr__(self): - return "Proxy Authentication Required" diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 50975818..bf33a18e 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -3,8 +3,8 @@ import string import sys import time -from netlib import odict, utils, tcp, http -from netlib.http import semantics +from ... import utils, tcp, http +from .. import semantics, Headers from ..exceptions import * @@ -96,7 +96,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): if headers is None: raise HttpError(400, "Invalid headers") - expect_header = headers.get_first("expect", "").lower() + expect_header = headers.get("expect", "").lower() if expect_header == "100-continue" and httpversion == (1, 1): self.tcp_handler.wfile.write( 'HTTP/1.1 100 Continue\r\n' @@ -232,10 +232,9 @@ class HTTP1Protocol(semantics.ProtocolMixin): Read a set of headers. Stop once a blank line is reached. - Return a ODictCaseless object, or None if headers are invalid. + Return a Header object, or None if headers are invalid. """ ret = [] - name = '' while True: line = self.tcp_handler.rfile.readline() if not line or line == '\r\n' or line == '\n': @@ -254,7 +253,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): ret.append([name, value]) else: return None - return odict.ODictCaseless(ret) + return Headers(ret) def read_http_body(self, *args, **kwargs): @@ -272,7 +271,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): ): """ Read an HTTP message body: - headers: An ODictCaseless object + headers: A Header object limit: Size limit. is_request: True if the body to read belongs to a request, False otherwise @@ -356,7 +355,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): return None if "content-length" in headers: try: - size = int(headers["content-length"][0]) + size = int(headers["content-length"]) if size < 0: raise ValueError() return size @@ -369,9 +368,7 @@ class HTTP1Protocol(semantics.ProtocolMixin): @classmethod def has_chunked_encoding(self, headers): - return "chunked" in [ - i.lower() for i in utils.get_header_tokens(headers, "transfer-encoding") - ] + return "chunked" in headers.get("transfer-encoding", "").lower() def _get_request_line(self): @@ -547,18 +544,20 @@ class HTTP1Protocol(semantics.ProtocolMixin): def _assemble_request_headers(self, request): headers = request.headers.copy() for k in request._headers_to_strip_off: - del headers[k] + headers.pop(k, None) if 'host' not in headers and request.scheme and request.host and request.port: - headers["Host"] = [utils.hostport(request.scheme, - request.host, - request.port)] + headers["Host"] = utils.hostport( + request.scheme, + request.host, + request.port + ) # If content is defined (i.e. not None or CONTENT_MISSING), we always # add a content-length header. if request.body or request.body == "": - headers["Content-Length"] = [str(len(request.body))] + headers["Content-Length"] = str(len(request.body)) - return headers.format() + return str(headers) def _assemble_response_first_line(self, response): return 'HTTP/%s.%s %s %s' % ( @@ -575,13 +574,13 @@ class HTTP1Protocol(semantics.ProtocolMixin): ): headers = response.headers.copy() for k in response._headers_to_strip_off: - del headers[k] + headers.pop(k, None) if not preserve_transfer_encoding: - del headers['Transfer-Encoding'] + headers.pop('Transfer-Encoding', None) # If body is defined (i.e. not None or CONTENT_MISSING), we always # add a content-length header. if response.body or response.body == "": - headers["Content-Length"] = [str(len(response.body))] + headers["Content-Length"] = str(len(response.body)) - return headers.format() + return str(headers) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index b297e0b8..f3254caa 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -3,7 +3,7 @@ import itertools import time from hpack.hpack import Encoder, Decoder -from netlib import http, utils, odict +from netlib import http, utils from netlib.http import semantics from . import frame @@ -85,10 +85,10 @@ class HTTP2Protocol(semantics.ProtocolMixin): timestamp_end = time.time() - authority = headers.get_first(':authority', '') - method = headers.get_first(':method', 'GET') - scheme = headers.get_first(':scheme', 'https') - path = headers.get_first(':path', '/') + authority = headers.get(':authority', '') + method = headers.get(':method', 'GET') + scheme = headers.get(':scheme', 'https') + path = headers.get(':path', '/') host = None port = None @@ -161,7 +161,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): response = http.Response( (2, 0), - int(headers.get_first(':status')), + int(headers.get(':status', 502)), "", headers, body, @@ -181,16 +181,14 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = request.headers.copy() - if ':authority' not in headers.keys(): - headers.add(':authority', bytes(authority), prepend=True) - if ':scheme' not in headers.keys(): - headers.add(':scheme', bytes(request.scheme), prepend=True) - if ':path' not in headers.keys(): - headers.add(':path', bytes(request.path), prepend=True) - if ':method' not in headers.keys(): - headers.add(':method', bytes(request.method), prepend=True) - - headers = headers.items() + if ':authority' not in headers: + headers.fields.insert(0, (':authority', bytes(authority))) + if ':scheme' not in headers: + headers.fields.insert(0, (':scheme', bytes(request.scheme))) + if ':path' not in headers: + headers.fields.insert(0, (':path', bytes(request.path))) + if ':method' not in headers: + headers.fields.insert(0, (':method', bytes(request.method))) if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -206,10 +204,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): headers = response.headers.copy() - if ':status' not in headers.keys(): - headers.add(':status', bytes(str(response.status_code)), prepend=True) - - headers = headers.items() + if ':status' not in headers: + headers.fields.insert(0, (':status', bytes(str(response.status_code)))) if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -329,7 +325,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): else: yield frame.ContinuationFrame, i - header_block_fragment = self.encoder.encode(headers) + header_block_fragment = self.encoder.encode(headers.fields) chunk_size = self.http2_settings[frame.SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] chunks = range(0, len(header_block_fragment), chunk_size) @@ -402,8 +398,8 @@ class HTTP2Protocol(semantics.ProtocolMixin): else: self._handle_unexpected_frame(frm) - headers = odict.ODictCaseless() - for header, value in self.decoder.decode(header_block_fragment): - headers.add(header, value) + headers = http.Headers( + [[str(k), str(v)] for k, v in self.decoder.decode(header_block_fragment)] + ) return stream_id, headers, body diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 2fadf2c4..edf5fc07 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -1,9 +1,10 @@ from __future__ import (absolute_import, print_function, division) import UserDict +import copy import urllib import urlparse -from .. import utils, odict +from .. import odict from . import cookies, exceptions from netlib import utils, encoding @@ -77,11 +78,11 @@ class Headers(UserDict.DictMixin): headers = { name.replace("_", "-"): value for name, value in headers.iteritems() - } + } self.update(headers) def __str__(self): - return "\r\n".join(": ".join(field) for field in self.fields) + return "\r\n".join(": ".join(field) for field in self.fields) + "\r\n" def __getitem__(self, name): values = self.get_all(name) @@ -107,7 +108,7 @@ class Headers(UserDict.DictMixin): self.fields = [ field for field in self.fields if name != field[0].lower() - ] + ] def _index(self, name): name = name.lower() @@ -134,7 +135,7 @@ class Headers(UserDict.DictMixin): def __ne__(self, other): return not self.__eq__(other) - def get_all(self, name, default=None): + def get_all(self, name, default=[]): """ Like :py:meth:`get`, but does not fold multiple headers into a single one. This is useful for Set-Cookie headers, which do not support folding. @@ -156,6 +157,9 @@ class Headers(UserDict.DictMixin): [name, value] for value in values ) + def copy(self): + return Headers(copy.copy(self.fields)) + # Implement the StateObject protocol from mitmproxy def get_state(self, short=False): return tuple(tuple(field) for field in self.fields) @@ -202,23 +206,23 @@ class Request(object): ] def __init__( - self, - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers=None, - body=None, - timestamp_start=None, - timestamp_end=None, - form_out=None + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers=None, + body=None, + timestamp_start=None, + timestamp_end=None, + form_out=None ): if not headers: - headers = odict.ODictCaseless() - assert isinstance(headers, odict.ODictCaseless) + headers = Headers() + assert isinstance(headers, Headers) self.form_in = form_in self.method = method @@ -235,8 +239,10 @@ class Request(object): def __eq__(self, other): try: - self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + self_d = [self.__dict__[k] for k in self.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] return self_d == other_d except: return False @@ -289,30 +295,35 @@ class Request(object): "if-none-match", ] for i in delheaders: - del self.headers[i] + self.headers.pop(i, None) def anticomp(self): """ Modifies this request to remove headers that will compress the resource's data. """ - self.headers["accept-encoding"] = ["identity"] + self.headers["accept-encoding"] = "identity" def constrain_encoding(self): """ Limits the permissible Accept-Encoding values, based on what we can decode appropriately. """ - if self.headers["accept-encoding"]: - self.headers["accept-encoding"] = [ + accept_encoding = self.headers.get("accept-encoding") + if accept_encoding: + self.headers["accept-encoding"] = ( ', '.join( - e for e in encoding.ENCODINGS if e in self.headers.get_first("accept-encoding"))] + e + for e in encoding.ENCODINGS + if e in accept_encoding + ) + ) def update_host_header(self): """ Update the host header to reflect the current target. """ - self.headers["Host"] = [self.host] + self.headers["Host"] = self.host def get_form(self): """ @@ -321,9 +332,9 @@ class Request(object): indicates non-form data. """ if self.body: - if self.headers.in_any("content-type", HDR_FORM_URLENCODED, True): + if HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): return self.get_form_urlencoded() - elif self.headers.in_any("content-type", HDR_FORM_MULTIPART, True): + elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): return self.get_form_multipart() return odict.ODict([]) @@ -333,18 +344,12 @@ class Request(object): Returns an empty ODict if there is no data or the content-type indicates non-form data. """ - if self.body and self.headers.in_any( - "content-type", - HDR_FORM_URLENCODED, - True): + if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): return odict.ODict(utils.urldecode(self.body)) return odict.ODict([]) def get_form_multipart(self): - if self.body and self.headers.in_any( - "content-type", - HDR_FORM_MULTIPART, - True): + if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): return odict.ODict( utils.multipartdecode( self.headers, @@ -359,7 +364,7 @@ class Request(object): """ # FIXME: If there's an existing content-type header indicating a # url-encoded form, leave it alone. - self.headers["Content-Type"] = [HDR_FORM_URLENCODED] + self.headers["Content-Type"] = HDR_FORM_URLENCODED self.body = utils.urlencode(odict.lst) def get_path_components(self): @@ -418,7 +423,7 @@ class Request(object): """ host = None if hostheader: - host = self.headers.get_first("host") + host = self.headers.get("Host") if not host: host = self.host if host: @@ -442,7 +447,7 @@ class Request(object): Returns a possibly empty netlib.odict.ODict object. """ ret = odict.ODict() - for i in self.headers["cookie"]: + for i in self.headers.get_all("cookie"): ret.extend(cookies.parse_cookie_header(i)) return ret @@ -452,7 +457,7 @@ class Request(object): headers. """ v = cookies.format_cookie_header(odict) - self.headers["Cookie"] = [v] + self.headers["Cookie"] = v @property def url(self): @@ -491,18 +496,17 @@ class Request(object): class EmptyRequest(Request): - def __init__( - self, - form_in="", - method="", - scheme="", - host="", - port="", - path="", - httpversion=(0, 0), - headers=None, - body="" + self, + form_in="", + method="", + scheme="", + host="", + port="", + path="", + httpversion=(0, 0), + headers=None, + body="" ): super(EmptyRequest, self).__init__( form_in=form_in, @@ -512,7 +516,7 @@ class EmptyRequest(Request): port=port, path=path, httpversion=httpversion, - headers=(headers or odict.ODictCaseless()), + headers=headers, body=body, ) @@ -525,19 +529,19 @@ class Response(object): ] def __init__( - self, - httpversion, - status_code, - msg=None, - headers=None, - body=None, - sslinfo=None, - timestamp_start=None, - timestamp_end=None, + self, + httpversion, + status_code, + msg=None, + headers=None, + body=None, + sslinfo=None, + timestamp_start=None, + timestamp_end=None, ): if not headers: - headers = odict.ODictCaseless() - assert isinstance(headers, odict.ODictCaseless) + headers = Headers() + assert isinstance(headers, Headers) self.httpversion = httpversion self.status_code = status_code @@ -550,8 +554,10 @@ class Response(object): def __eq__(self, other): try: - self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] - other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + self_d = [self.__dict__[k] for k in self.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if + k not in ('timestamp_start', 'timestamp_end')] return self_d == other_d except: return False @@ -567,9 +573,7 @@ class Response(object): return "<Response: {status_code} {msg} ({contenttype}, {size})>".format( status_code=self.status_code, msg=self.msg, - contenttype=self.headers.get_first( - "content-type", - "unknown content type"), + contenttype=self.headers.get("content-type", "unknown content type"), size=size) def get_cookies(self): @@ -582,7 +586,7 @@ class Response(object): attributes (e.g. HTTPOnly) are indicated by a Null value. """ ret = [] - for header in self.headers["set-cookie"]: + for header in self.headers.get_all("set-cookie"): v = cookies.parse_set_cookie_header(header) if v: name, value, attrs = v @@ -605,7 +609,7 @@ class Response(object): i[1][1] ) ) - self.headers["Set-Cookie"] = values + self.headers.set_all("Set-Cookie", values) @property def content(self): # pragma: no cover diff --git a/netlib/tutils.py b/netlib/tutils.py index 7434c108..951ef3d9 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -5,7 +5,7 @@ import time import shutil from contextlib import contextmanager -from netlib import tcp, utils, odict, http +from netlib import tcp, utils, http def treader(bytes): @@ -73,8 +73,8 @@ def treq(content="content", scheme="http", host="address", port=22): """ @return: libmproxy.protocol.http.HTTPRequest """ - headers = odict.ODictCaseless() - headers["header"] = ["qvalue"] + headers = http.Headers() + headers["header"] = "qvalue" req = http.Request( "relative", "GET", @@ -108,8 +108,8 @@ def tresp(content="message"): @return: libmproxy.protocol.http.HTTPResponse """ - headers = odict.ODictCaseless() - headers["header_response"] = ["svalue"] + headers = http.Headers() + headers["header_response"] = "svalue" resp = http.semantics.Response( (1, 1), diff --git a/netlib/utils.py b/netlib/utils.py index d6190673..aae187da 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -204,11 +204,10 @@ def get_header_tokens(headers, key): follow a pattern where each header line can containe comma-separated tokens, and headers can be set multiple times. """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks + if key not in headers: + return [] + tokens = headers[key].split(",") + return [token.strip() for token in tokens] def hostport(scheme, host, port): @@ -270,11 +269,11 @@ def parse_content_type(c): return ts[0].lower(), ts[1].lower(), d -def multipartdecode(hdrs, content): +def multipartdecode(headers, content): """ Takes a multipart boundary encoded string and returns list of (key, value) tuples. """ - v = hdrs.get_first("content-type") + v = headers.get("content-type") if v: v = parse_content_type(v) if not v: diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 6ce32eac..46c02875 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -1,10 +1,5 @@ -from __future__ import absolute_import -import base64 -import hashlib -import os -from netlib import odict -from netlib import utils + # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. @@ -18,6 +13,13 @@ from netlib import utils # The magic sha that websocket servers must know to prove they understand # RFC6455 +from __future__ import absolute_import +import base64 +import hashlib +import os +from ..http import Headers +from .. import utils + websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' VERSION = "13" @@ -66,11 +68,11 @@ class WebsocketsProtocol(object): specified, it is generated, and can be found in sec-websocket-key in the returned header set. - Returns an instance of ODictCaseless + Returns an instance of Headers """ if not key: key = base64.b64encode(os.urandom(16)).decode('utf-8') - return odict.ODictCaseless([ + return Headers([ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), (HEADER_WEBSOCKET_KEY, key), @@ -82,7 +84,7 @@ class WebsocketsProtocol(object): """ The server response is a valid HTTP 101 response. """ - return odict.ODictCaseless( + return Headers( [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), @@ -93,16 +95,16 @@ class WebsocketsProtocol(object): @classmethod def check_client_handshake(self, headers): - if headers.get_first("upgrade", None) != "websocket": + if headers.get("upgrade") != "websocket": return - return headers.get_first(HEADER_WEBSOCKET_KEY) + return headers.get(HEADER_WEBSOCKET_KEY) @classmethod def check_server_handshake(self, headers): - if headers.get_first("upgrade", None) != "websocket": + if headers.get("upgrade") != "websocket": return - return headers.get_first(HEADER_WEBSOCKET_ACCEPT) + return headers.get(HEADER_WEBSOCKET_ACCEPT) @classmethod diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 99afe00e..8a98884a 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -3,7 +3,7 @@ import cStringIO import urllib import time import traceback -from . import odict, tcp +from . import http, tcp class ClientConn(object): @@ -68,8 +68,8 @@ class WSGIAdaptor(object): 'SCRIPT_NAME': '', 'PATH_INFO': urllib.unquote(path_info), 'QUERY_STRING': query, - 'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], - 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], + 'CONTENT_TYPE': flow.request.headers.get('Content-Type', ''), + 'CONTENT_LENGTH': flow.request.headers.get('Content-Length', ''), 'SERVER_NAME': self.domain, 'SERVER_PORT': str(self.port), # FIXME: We need to pick up the protocol read from the request. @@ -115,12 +115,12 @@ class WSGIAdaptor(object): def write(data): if not state["headers_sent"]: soc.write("HTTP/1.1 %s\r\n" % state["status"]) - h = state["headers"] - if 'server' not in h: - h["Server"] = [self.sversion] - if 'date' not in h: - h["Date"] = [date_time_string()] - soc.write(h.format()) + headers = state["headers"] + if 'server' not in headers: + headers["Server"] = self.sversion + if 'date' not in headers: + headers["Date"] = date_time_string() + soc.write(str(headers)) soc.write("\r\n") state["headers_sent"] = True if data: @@ -137,7 +137,7 @@ class WSGIAdaptor(object): elif state["status"]: raise AssertionError('Response already started') state["status"] = status - state["headers"] = odict.ODictCaseless(headers) + state["headers"] = http.Headers(headers) return write errs = cStringIO.StringIO() @@ -149,7 +149,7 @@ class WSGIAdaptor(object): write(i) if not state["headers_sent"]: write("") - except Exception: + except Exception as e: try: s = traceback.format_exc() errs.write(s) |