diff options
Diffstat (limited to 'netlib')
| -rw-r--r-- | netlib/__init__.py | 0 | ||||
| -rw-r--r-- | netlib/check.py | 22 | ||||
| -rw-r--r-- | netlib/http/__init__.py | 15 | ||||
| -rw-r--r-- | netlib/http/authentication.py | 176 | ||||
| -rw-r--r-- | netlib/http/cookies.py | 384 | ||||
| -rw-r--r-- | netlib/http/encoding.py | 175 | ||||
| -rw-r--r-- | netlib/http/headers.py | 221 | ||||
| -rw-r--r-- | netlib/http/http1/__init__.py | 24 | ||||
| -rw-r--r-- | netlib/http/http1/assemble.py | 100 | ||||
| -rw-r--r-- | netlib/http/http1/read.py | 377 | ||||
| -rw-r--r-- | netlib/http/http2/__init__.py | 8 | ||||
| -rw-r--r-- | netlib/http/http2/framereader.py | 25 | ||||
| -rw-r--r-- | netlib/http/http2/utils.py | 37 | ||||
| -rw-r--r-- | netlib/http/message.py | 300 | ||||
| -rw-r--r-- | netlib/http/multipart.py | 32 | ||||
| -rw-r--r-- | netlib/http/request.py | 405 | ||||
| -rw-r--r-- | netlib/http/response.py | 192 | ||||
| -rw-r--r-- | netlib/http/status_codes.py | 104 | ||||
| -rw-r--r-- | netlib/http/url.py | 127 | ||||
| -rw-r--r-- | netlib/http/user_agents.py | 50 | ||||
| -rw-r--r-- | netlib/socks.py | 234 | ||||
| -rw-r--r-- | netlib/tcp.py | 989 | ||||
| -rw-r--r-- | netlib/websockets/__init__.py | 35 | ||||
| -rw-r--r-- | netlib/websockets/frame.py | 274 | ||||
| -rw-r--r-- | netlib/websockets/masker.py | 25 | ||||
| -rw-r--r-- | netlib/websockets/utils.py | 90 | ||||
| -rw-r--r-- | netlib/wsgi.py | 166 | 
27 files changed, 0 insertions, 4587 deletions
| diff --git a/netlib/__init__.py b/netlib/__init__.py deleted file mode 100644 index e69de29b..00000000 --- a/netlib/__init__.py +++ /dev/null diff --git a/netlib/check.py b/netlib/check.py deleted file mode 100644 index 7b007cb5..00000000 --- a/netlib/check.py +++ /dev/null @@ -1,22 +0,0 @@ -import re - -_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE) - - -def is_valid_host(host: bytes) -> bool: -    """ -        Checks if a hostname is valid. -    """ -    try: -        host.decode("idna") -    except ValueError: -        return False -    if len(host) > 255: -        return False -    if host and host[-1:] == b".": -        host = host[:-1] -    return all(_label_valid.match(x) for x in host.split(b".")) - - -def is_valid_port(port): -    return 0 <= port <= 65535 diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py deleted file mode 100644 index 315f61ac..00000000 --- a/netlib/http/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from netlib.http.request import Request -from netlib.http.response import Response -from netlib.http.message import Message -from netlib.http.headers import Headers, parse_content_type -from netlib.http.message import decoded -from netlib.http import http1, http2, status_codes, multipart - -__all__ = [ -    "Request", -    "Response", -    "Message", -    "Headers", "parse_content_type", -    "decoded", -    "http1", "http2", "status_codes", "multipart", -] diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py deleted file mode 100644 index a65279e4..00000000 --- a/netlib/http/authentication.py +++ /dev/null @@ -1,176 +0,0 @@ -import argparse -import binascii - - -def parse_http_basic_auth(s): -    words = s.split() -    if len(words) != 2: -        return None -    scheme = words[0] -    try: -        user = binascii.a2b_base64(words[1]).decode("utf8", "replace") -    except binascii.Error: -        return None -    parts = user.split(':') -    if len(parts) != 2: -        return None -    return scheme, parts[0], parts[1] - - -def assemble_http_basic_auth(scheme, username, password): -    v = binascii.b2a_base64((username + ":" + password).encode("utf8")).decode("ascii") -    return scheme + " " + v - - -class NullProxyAuth: - -    """ -        No proxy auth at all (returns empty challange headers) -    """ - -    def __init__(self, password_manager): -        self.password_manager = password_manager - -    def clean(self, headers_): -        """ -            Clean up authentication headers, so they're not passed upstream. -        """ - -    def authenticate(self, headers_): -        """ -            Tests that the user is allowed to use the proxy -        """ -        return True - -    def auth_challenge_headers(self): -        """ -            Returns a dictionary containing the headers require to challenge the user -        """ -        return {} - - -class BasicAuth(NullProxyAuth): -    CHALLENGE_HEADER = None -    AUTH_HEADER = None - -    def __init__(self, password_manager, realm): -        NullProxyAuth.__init__(self, password_manager) -        self.realm = realm - -    def clean(self, headers): -        del headers[self.AUTH_HEADER] - -    def authenticate(self, headers): -        auth_value = headers.get(self.AUTH_HEADER) -        if not auth_value: -            return False -        parts = parse_http_basic_auth(auth_value) -        if not parts: -            return False -        scheme, username, password = parts -        if scheme.lower() != 'basic': -            return False -        if not self.password_manager.test(username, password): -            return False -        self.username = username -        return True - -    def auth_challenge_headers(self): -        return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} - - -class BasicWebsiteAuth(BasicAuth): -    CHALLENGE_HEADER = 'WWW-Authenticate' -    AUTH_HEADER = 'Authorization' - - -class BasicProxyAuth(BasicAuth): -    CHALLENGE_HEADER = 'Proxy-Authenticate' -    AUTH_HEADER = 'Proxy-Authorization' - - -class PassMan: - -    def test(self, username_, password_token_): -        return False - - -class PassManNonAnon(PassMan): - -    """ -        Ensure the user specifies a username, accept any password. -    """ - -    def test(self, username, password_token_): -        if username: -            return True -        return False - - -class PassManHtpasswd(PassMan): - -    """ -        Read usernames and passwords from an htpasswd file -    """ - -    def __init__(self, path): -        """ -            Raises ValueError if htpasswd file is invalid. -        """ -        import passlib.apache -        self.htpasswd = passlib.apache.HtpasswdFile(path) - -    def test(self, username, password_token): -        return bool(self.htpasswd.check_password(username, password_token)) - - -class PassManSingleUser(PassMan): - -    def __init__(self, username, password): -        self.username, self.password = username, password - -    def test(self, username, password_token): -        return self.username == username and self.password == password_token - - -class AuthAction(argparse.Action): - -    """ -    Helper class to allow seamless integration int argparse. Example usage: -    parser.add_argument( -        "--nonanonymous", -        action=NonanonymousAuthAction, nargs=0, -        help="Allow access to any user long as a credentials are specified." -    ) -    """ - -    def __call__(self, parser, namespace, values, option_string=None): -        passman = self.getPasswordManager(values) -        authenticator = BasicProxyAuth(passman, "mitmproxy") -        setattr(namespace, self.dest, authenticator) - -    def getPasswordManager(self, s):  # pragma: no cover -        raise NotImplementedError() - - -class SingleuserAuthAction(AuthAction): - -    def getPasswordManager(self, s): -        if len(s.split(':')) != 2: -            raise argparse.ArgumentTypeError( -                "Invalid single-user specification. Please use the format username:password" -            ) -        username, password = s.split(':') -        return PassManSingleUser(username, password) - - -class NonanonymousAuthAction(AuthAction): - -    def getPasswordManager(self, s): -        return PassManNonAnon() - - -class HtpasswdAuthAction(AuthAction): - -    def getPasswordManager(self, s): -        return PassManHtpasswd(s) diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py deleted file mode 100644 index 9f32fa5e..00000000 --- a/netlib/http/cookies.py +++ /dev/null @@ -1,384 +0,0 @@ -import collections -import email.utils -import re -import time - -from mitmproxy.types import multidict - -""" -A flexible module for cookie parsing and manipulation. - -This module differs from usual standards-compliant cookie modules in a number -of ways. We try to be as permissive as possible, and to retain even mal-formed -information. Duplicate cookies are preserved in parsing, and can be set in -formatting. We do attempt to escape and quote values where needed, but will not -reject data that violate the specs. - -Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We -also parse the comma-separated variant of Set-Cookie that allows multiple -cookies to be set in a single header. Serialization follows RFC6265. - -    http://tools.ietf.org/html/rfc6265 -    http://tools.ietf.org/html/rfc2109 -    http://tools.ietf.org/html/rfc2965 -""" - -_cookie_params = set(( -    'expires', 'path', 'comment', 'max-age', -    'secure', 'httponly', 'version', -)) - -ESCAPE = re.compile(r"([\"\\])") - - -class CookieAttrs(multidict.ImmutableMultiDict): -    @staticmethod -    def _kconv(key): -        return key.lower() - -    @staticmethod -    def _reduce_values(values): -        # See the StickyCookieTest for a weird cookie that only makes sense -        # if we take the last part. -        return values[-1] - -SetCookie = collections.namedtuple("SetCookie", ["value", "attrs"]) - - -def _read_until(s, start, term): -    """ -        Read until one of the characters in term is reached. -    """ -    if start == len(s): -        return "", start + 1 -    for i in range(start, len(s)): -        if s[i] in term: -            return s[start:i], i -    return s[start:i + 1], i + 1 - - -def _read_quoted_string(s, start): -    """ -        start: offset to the first quote of the string to be read - -        A sort of loose super-set of the various quoted string specifications. - -        RFC6265 disallows backslashes or double quotes within quoted strings. -        Prior RFCs use backslashes to escape. This leaves us free to apply -        backslash escaping by default and be compatible with everything. -    """ -    escaping = False -    ret = [] -    # Skip the first quote -    i = start  # initialize in case the loop doesn't run. -    for i in range(start + 1, len(s)): -        if escaping: -            ret.append(s[i]) -            escaping = False -        elif s[i] == '"': -            break -        elif s[i] == "\\": -            escaping = True -        else: -            ret.append(s[i]) -    return "".join(ret), i + 1 - - -def _read_key(s, start, delims=";="): -    """ -        Read a key - the LHS of a token/value pair in a cookie. -    """ -    return _read_until(s, start, delims) - - -def _read_value(s, start, delims): -    """ -        Reads a value - the RHS of a token/value pair in a cookie. -    """ -    if start >= len(s): -        return "", start -    elif s[start] == '"': -        return _read_quoted_string(s, start) -    else: -        return _read_until(s, start, delims) - - -def _read_cookie_pairs(s, off=0): -    """ -        Read pairs of lhs=rhs values from Cookie headers. - -        off: start offset -    """ -    pairs = [] - -    while True: -        lhs, off = _read_key(s, off) -        lhs = lhs.lstrip() - -        if lhs: -            rhs = None -            if off < len(s) and s[off] == "=": -                rhs, off = _read_value(s, off + 1, ";") - -            pairs.append([lhs, rhs]) - -        off += 1 - -        if not off < len(s): -            break - -    return pairs, off - - -def _read_set_cookie_pairs(s, off=0): -    """ -        Read pairs of lhs=rhs values from SetCookie headers while handling multiple cookies. - -        off: start offset -        specials: attributes that are treated specially -    """ -    cookies = [] -    pairs = [] - -    while True: -        lhs, off = _read_key(s, off, ";=,") -        lhs = lhs.lstrip() - -        if lhs: -            rhs = None -            if off < len(s) and s[off] == "=": -                rhs, off = _read_value(s, off + 1, ";,") - -                # Special handliing of attributes -                if lhs.lower() == "expires": -                    # 'expires' values can contain commas in them so they need to -                    # be handled separately. - -                    # We actually bank on the fact that the expires value WILL -                    # contain a comma. Things will fail, if they don't. - -                    # '3' is just a heuristic we use to determine whether we've -                    # only read a part of the expires value and we should read more. -                    if len(rhs) <= 3: -                        trail, off = _read_value(s, off + 1, ";,") -                        rhs = rhs + "," + trail - -            pairs.append([lhs, rhs]) - -            # comma marks the beginning of a new cookie -            if off < len(s) and s[off] == ",": -                cookies.append(pairs) -                pairs = [] - -        off += 1 - -        if not off < len(s): -            break - -    if pairs or not cookies: -        cookies.append(pairs) - -    return cookies, off - - -def _has_special(s): -    for i in s: -        if i in '",;\\': -            return True -        o = ord(i) -        if o < 0x21 or o > 0x7e: -            return True -    return False - - -def _format_pairs(pairs, specials=(), sep="; "): -    """ -        specials: A lower-cased list of keys that will not be quoted. -    """ -    vals = [] -    for k, v in pairs: -        if v is None: -            vals.append(k) -        else: -            if k.lower() not in specials and _has_special(v): -                v = ESCAPE.sub(r"\\\1", v) -                v = '"%s"' % v -            vals.append("%s=%s" % (k, v)) -    return sep.join(vals) - - -def _format_set_cookie_pairs(lst): -    return _format_pairs( -        lst, -        specials=("expires", "path") -    ) - - -def parse_cookie_header(line): -    """ -        Parse a Cookie header value. -        Returns a list of (lhs, rhs) tuples. -    """ -    pairs, off_ = _read_cookie_pairs(line) -    return pairs - - -def parse_cookie_headers(cookie_headers): -    cookie_list = [] -    for header in cookie_headers: -        cookie_list.extend(parse_cookie_header(header)) -    return cookie_list - - -def format_cookie_header(lst): -    """ -        Formats a Cookie header value. -    """ -    return _format_pairs(lst) - - -def parse_set_cookie_header(line): -    """ -        Parse a Set-Cookie header value - -        Returns a list of (name, value, attrs) tuples, where attrs is a -        CookieAttrs dict of attributes. No attempt is made to parse attribute -        values - they are treated purely as strings. -    """ -    cookie_pairs, off = _read_set_cookie_pairs(line) -    cookies = [ -        (pairs[0][0], pairs[0][1], CookieAttrs(tuple(x) for x in pairs[1:])) -        for pairs in cookie_pairs if pairs -    ] -    return cookies - - -def parse_set_cookie_headers(headers): -    rv = [] -    for header in headers: -        cookies = parse_set_cookie_header(header) -        if cookies: -            for name, value, attrs in cookies: -                rv.append((name, SetCookie(value, attrs))) -    return rv - - -def format_set_cookie_header(set_cookies): -    """ -        Formats a Set-Cookie header value. -    """ - -    rv = [] - -    for set_cookie in set_cookies: -        name, value, attrs = set_cookie - -        pairs = [(name, value)] -        pairs.extend( -            attrs.fields if hasattr(attrs, "fields") else attrs -        ) - -        rv.append(_format_set_cookie_pairs(pairs)) - -    return ", ".join(rv) - - -def refresh_set_cookie_header(c, delta): -    """ -    Args: -        c: A Set-Cookie string -        delta: Time delta in seconds -    Returns: -        A refreshed Set-Cookie string -    """ - -    name, value, attrs = parse_set_cookie_header(c)[0] -    if not name or not value: -        raise ValueError("Invalid Cookie") - -    if "expires" in attrs: -        e = email.utils.parsedate_tz(attrs["expires"]) -        if e: -            f = email.utils.mktime_tz(e) + delta -            attrs = attrs.with_set_all("expires", [email.utils.formatdate(f)]) -        else: -            # This can happen when the expires tag is invalid. -            # reddit.com sends a an expires tag like this: "Thu, 31 Dec -            # 2037 23:59:59 GMT", which is valid RFC 1123, but not -            # strictly correct according to the cookie spec. Browsers -            # appear to parse this tolerantly - maybe we should too. -            # For now, we just ignore this. -            attrs = attrs.with_delitem("expires") - -    rv = format_set_cookie_header([(name, value, attrs)]) -    if not rv: -        raise ValueError("Invalid Cookie") -    return rv - - -def get_expiration_ts(cookie_attrs): -    """ -        Determines the time when the cookie will be expired. - -        Considering both 'expires' and 'max-age' parameters. - -        Returns: timestamp of when the cookie will expire. -                 None, if no expiration time is set. -    """ -    if 'expires' in cookie_attrs: -        e = email.utils.parsedate_tz(cookie_attrs["expires"]) -        if e: -            return email.utils.mktime_tz(e) - -    elif 'max-age' in cookie_attrs: -        try: -            max_age = int(cookie_attrs['Max-Age']) -        except ValueError: -            pass -        else: -            now_ts = time.time() -            return now_ts + max_age - -    return None - - -def is_expired(cookie_attrs): -    """ -        Determines whether a cookie has expired. - -        Returns: boolean -    """ - -    exp_ts = get_expiration_ts(cookie_attrs) -    now_ts = time.time() - -    # If no expiration information was provided with the cookie -    if exp_ts is None: -        return False -    else: -        return exp_ts <= now_ts - - -def group_cookies(pairs): -    """ -    Converts a list of pairs to a (name, value, attrs) for each cookie. -    """ - -    if not pairs: -        return [] - -    cookie_list = [] - -    # First pair is always a new cookie -    name, value = pairs[0] -    attrs = [] - -    for k, v in pairs[1:]: -        if k.lower() in _cookie_params: -            attrs.append((k, v)) -        else: -            cookie_list.append((name, value, CookieAttrs(attrs))) -            name, value, attrs = k, v, [] - -    cookie_list.append((name, value, CookieAttrs(attrs))) -    return cookie_list diff --git a/netlib/http/encoding.py b/netlib/http/encoding.py deleted file mode 100644 index e123a033..00000000 --- a/netlib/http/encoding.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -Utility functions for decoding response bodies. -""" - -import codecs -import collections -from io import BytesIO - -import gzip -import zlib -import brotli - -from typing import Union - - -# We have a shared single-element cache for encoding and decoding. -# This is quite useful in practice, e.g. -# flow.request.content = flow.request.content.replace(b"foo", b"bar") -# does not require an .encode() call if content does not contain b"foo" -CachedDecode = collections.namedtuple("CachedDecode", "encoded encoding errors decoded") -_cache = CachedDecode(None, None, None, None) - - -def decode(encoded: Union[str, bytes], encoding: str, errors: str='strict') -> Union[str, bytes]: -    """ -    Decode the given input object - -    Returns: -        The decoded value - -    Raises: -        ValueError, if decoding fails. -    """ -    if len(encoded) == 0: -        return encoded - -    global _cache -    cached = ( -        isinstance(encoded, bytes) and -        _cache.encoded == encoded and -        _cache.encoding == encoding and -        _cache.errors == errors -    ) -    if cached: -        return _cache.decoded -    try: -        try: -            decoded = custom_decode[encoding](encoded) -        except KeyError: -            decoded = codecs.decode(encoded, encoding, errors) -        if encoding in ("gzip", "deflate", "br"): -            _cache = CachedDecode(encoded, encoding, errors, decoded) -        return decoded -    except TypeError: -        raise -    except Exception as e: -        raise ValueError("{} when decoding {} with {}: {}".format( -            type(e).__name__, -            repr(encoded)[:10], -            repr(encoding), -            repr(e), -        )) - - -def encode(decoded: Union[str, bytes], encoding: str, errors: str='strict') -> Union[str, bytes]: -    """ -    Encode the given input object - -    Returns: -        The encoded value - -    Raises: -        ValueError, if encoding fails. -    """ -    if len(decoded) == 0: -        return decoded - -    global _cache -    cached = ( -        isinstance(decoded, bytes) and -        _cache.decoded == decoded and -        _cache.encoding == encoding and -        _cache.errors == errors -    ) -    if cached: -        return _cache.encoded -    try: -        try: -            value = decoded -            if isinstance(value, str): -                value = decoded.encode() -            encoded = custom_encode[encoding](value) -        except KeyError: -            encoded = codecs.encode(decoded, encoding, errors) -        if encoding in ("gzip", "deflate", "br"): -            _cache = CachedDecode(encoded, encoding, errors, decoded) -        return encoded -    except TypeError: -        raise -    except Exception as e: -        raise ValueError("{} when encoding {} with {}: {}".format( -            type(e).__name__, -            repr(decoded)[:10], -            repr(encoding), -            repr(e), -        )) - - -def identity(content): -    """ -        Returns content unchanged. Identity is the default value of -        Accept-Encoding headers. -    """ -    return content - - -def decode_gzip(content): -    gfile = gzip.GzipFile(fileobj=BytesIO(content)) -    return gfile.read() - - -def encode_gzip(content): -    s = BytesIO() -    gf = gzip.GzipFile(fileobj=s, mode='wb') -    gf.write(content) -    gf.close() -    return s.getvalue() - - -def decode_brotli(content): -    return brotli.decompress(content) - - -def encode_brotli(content): -    return brotli.compress(content) - - -def decode_deflate(content): -    """ -        Returns decompressed data for DEFLATE. Some servers may respond with -        compressed data without a zlib header or checksum. An undocumented -        feature of zlib permits the lenient decompression of data missing both -        values. - -        http://bugs.python.org/issue5784 -    """ -    try: -        return zlib.decompress(content) -    except zlib.error: -        return zlib.decompress(content, -15) - - -def encode_deflate(content): -    """ -        Returns compressed content, always including zlib header and checksum. -    """ -    return zlib.compress(content) - - -custom_decode = { -    "none": identity, -    "identity": identity, -    "gzip": decode_gzip, -    "deflate": decode_deflate, -    "br": decode_brotli, -} -custom_encode = { -    "none": identity, -    "identity": identity, -    "gzip": encode_gzip, -    "deflate": encode_deflate, -    "br": encode_brotli, -} - -__all__ = ["encode", "decode"] diff --git a/netlib/http/headers.py b/netlib/http/headers.py deleted file mode 100644 index 8fc0cd43..00000000 --- a/netlib/http/headers.py +++ /dev/null @@ -1,221 +0,0 @@ -import re - -import collections -from mitmproxy.types import multidict -from mitmproxy.utils import strutils - -# See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ - - -# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. -def _native(x): -    return x.decode("utf-8", "surrogateescape") - - -def _always_bytes(x): -    return strutils.always_bytes(x, "utf-8", "surrogateescape") - - -class Headers(multidict.MultiDict): -    """ -    Header class which allows both convenient access to individual headers as well as -    direct access to the underlying raw data. Provides a full dictionary interface. - -    Example: - -    .. code-block:: python - -        # Create headers with keyword arguments -        >>> h = Headers(host="example.com", content_type="application/xml") - -        # Headers mostly behave like a normal dict. -        >>> h["Host"] -        "example.com" - -        # HTTP Headers are case insensitive -        >>> h["host"] -        "example.com" - -        # Headers can also be created from a list of raw (header_name, header_value) byte tuples -        >>> h = Headers([ -            (b"Host",b"example.com"), -            (b"Accept",b"text/html"), -            (b"accept",b"application/xml") -        ]) - -        # Multiple headers are folded into a single header as per RFC7230 -        >>> h["Accept"] -        "text/html, application/xml" - -        # Setting a header removes all existing headers with the same name. -        >>> h["Accept"] = "application/text" -        >>> h["Accept"] -        "application/text" - -        # bytes(h) returns a HTTP1 header block. -        >>> print(bytes(h)) -        Host: example.com -        Accept: application/text - -        # For full control, the raw header fields can be accessed -        >>> h.fields - -    Caveats: -        For use with the "Set-Cookie" header, see :py:meth:`get_all`. -    """ - -    def __init__(self, fields=(), **headers): -        """ -        Args: -            fields: (optional) list of ``(name, value)`` header byte tuples, -                e.g. ``[(b"Host", b"example.com")]``. All names and values must be bytes. -            **headers: Additional headers to set. Will overwrite existing values from `fields`. -                For convenience, underscores in header names will be transformed to dashes - -                this behaviour does not extend to other methods. -                If ``**headers`` contains multiple keys that have equal ``.lower()`` s, -                the behavior is undefined. -        """ -        super().__init__(fields) - -        for key, value in self.fields: -            if not isinstance(key, bytes) or not isinstance(value, bytes): -                raise TypeError("Header fields must be bytes.") - -        # content_type -> content-type -        headers = { -            _always_bytes(name).replace(b"_", b"-"): _always_bytes(value) -            for name, value in headers.items() -        } -        self.update(headers) - -    @staticmethod -    def _reduce_values(values): -        # Headers can be folded -        return ", ".join(values) - -    @staticmethod -    def _kconv(key): -        # Headers are case-insensitive -        return key.lower() - -    def __bytes__(self): -        if self.fields: -            return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" -        else: -            return b"" - -    def __delitem__(self, key): -        key = _always_bytes(key) -        super().__delitem__(key) - -    def __iter__(self): -        for x in super().__iter__(): -            yield _native(x) - -    def get_all(self, name): -        """ -        Like :py:meth:`get`, but does not fold multiple headers into a single one. -        This is useful for Set-Cookie headers, which do not support folding. -        See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 -        """ -        name = _always_bytes(name) -        return [ -            _native(x) for x in -            super().get_all(name) -        ] - -    def set_all(self, name, values): -        """ -        Explicitly set multiple headers for the given key. -        See: :py:meth:`get_all` -        """ -        name = _always_bytes(name) -        values = [_always_bytes(x) for x in values] -        return super().set_all(name, values) - -    def insert(self, index, key, value): -        key = _always_bytes(key) -        value = _always_bytes(value) -        super().insert(index, key, value) - -    def items(self, multi=False): -        if multi: -            return ( -                (_native(k), _native(v)) -                for k, v in self.fields -            ) -        else: -            return super().items() - -    def replace(self, pattern, repl, flags=0, count=0): -        """ -        Replaces a regular expression pattern with repl in each "name: value" -        header line. - -        Returns: -            The number of replacements made. -        """ -        if isinstance(pattern, str): -            pattern = strutils.escaped_str_to_bytes(pattern) -        if isinstance(repl, str): -            repl = strutils.escaped_str_to_bytes(repl) -        pattern = re.compile(pattern, flags) -        replacements = 0 -        flag_count = count > 0 -        fields = [] -        for name, value in self.fields: -            line, n = pattern.subn(repl, name + b": " + value, count=count) -            try: -                name, value = line.split(b": ", 1) -            except ValueError: -                # We get a ValueError if the replacement removed the ": " -                # There's not much we can do about this, so we just keep the header as-is. -                pass -            else: -                replacements += n -                if flag_count: -                    count -= n -                    if count == 0: -                        break -            fields.append((name, value)) -        self.fields = tuple(fields) -        return replacements - - -def parse_content_type(c): -    """ -        A simple parser for content-type values. Returns a (type, subtype, -        parameters) tuple, where type and subtype are strings, and parameters -        is a dict. If the string could not be parsed, return None. - -        E.g. the following string: - -            text/html; charset=UTF-8 - -        Returns: - -            ("text", "html", {"charset": "UTF-8"}) -    """ -    parts = c.split(";", 1) -    ts = parts[0].split("/", 1) -    if len(ts) != 2: -        return None -    d = collections.OrderedDict() -    if len(parts) == 2: -        for i in parts[1].split(";"): -            clause = i.split("=", 1) -            if len(clause) == 2: -                d[clause[0].strip()] = clause[1].strip() -    return ts[0].lower(), ts[1].lower(), d - - -def assemble_content_type(type, subtype, parameters): -    if not parameters: -        return "{}/{}".format(type, subtype) -    params = "; ".join( -        "{}={}".format(k, v) -        for k, v in parameters.items() -    ) -    return "{}/{}; {}".format( -        type, subtype, params -    ) diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py deleted file mode 100644 index e4bf01c5..00000000 --- a/netlib/http/http1/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from .read import ( -    read_request, read_request_head, -    read_response, read_response_head, -    read_body, -    connection_close, -    expected_http_body_size, -) -from .assemble import ( -    assemble_request, assemble_request_head, -    assemble_response, assemble_response_head, -    assemble_body, -) - - -__all__ = [ -    "read_request", "read_request_head", -    "read_response", "read_response_head", -    "read_body", -    "connection_close", -    "expected_http_body_size", -    "assemble_request", "assemble_request_head", -    "assemble_response", "assemble_response_head", -    "assemble_body", -] diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py deleted file mode 100644 index e0a91ad8..00000000 --- a/netlib/http/http1/assemble.py +++ /dev/null @@ -1,100 +0,0 @@ -import netlib.http.url -from mitmproxy import exceptions - - -def assemble_request(request): -    if request.data.content is None: -        raise exceptions.HttpException("Cannot assemble flow with missing content") -    head = assemble_request_head(request) -    body = b"".join(assemble_body(request.data.headers, [request.data.content])) -    return head + body - - -def assemble_request_head(request): -    first_line = _assemble_request_line(request.data) -    headers = _assemble_request_headers(request.data) -    return b"%s\r\n%s\r\n" % (first_line, headers) - - -def assemble_response(response): -    if response.data.content is None: -        raise exceptions.HttpException("Cannot assemble flow with missing content") -    head = assemble_response_head(response) -    body = b"".join(assemble_body(response.data.headers, [response.data.content])) -    return head + body - - -def assemble_response_head(response): -    first_line = _assemble_response_line(response.data) -    headers = _assemble_response_headers(response.data) -    return b"%s\r\n%s\r\n" % (first_line, headers) - - -def assemble_body(headers, body_chunks): -    if "chunked" in headers.get("transfer-encoding", "").lower(): -        for chunk in body_chunks: -            if chunk: -                yield b"%x\r\n%s\r\n" % (len(chunk), chunk) -        yield b"0\r\n\r\n" -    else: -        for chunk in body_chunks: -            yield chunk - - -def _assemble_request_line(request_data): -    """ -    Args: -        request_data (netlib.http.request.RequestData) -    """ -    form = request_data.first_line_format -    if form == "relative": -        return b"%s %s %s" % ( -            request_data.method, -            request_data.path, -            request_data.http_version -        ) -    elif form == "authority": -        return b"%s %s:%d %s" % ( -            request_data.method, -            request_data.host, -            request_data.port, -            request_data.http_version -        ) -    elif form == "absolute": -        return b"%s %s://%s:%d%s %s" % ( -            request_data.method, -            request_data.scheme, -            request_data.host, -            request_data.port, -            request_data.path, -            request_data.http_version -        ) -    else: -        raise RuntimeError("Invalid request form") - - -def _assemble_request_headers(request_data): -    """ -    Args: -        request_data (netlib.http.request.RequestData) -    """ -    headers = request_data.headers.copy() -    if "host" not in headers and request_data.scheme and request_data.host and request_data.port: -        headers["host"] = netlib.http.url.hostport( -            request_data.scheme, -            request_data.host, -            request_data.port -        ) -    return bytes(headers) - - -def _assemble_response_line(response_data): -    return b"%s %d %s" % ( -        response_data.http_version, -        response_data.status_code, -        response_data.reason, -    ) - - -def _assemble_response_headers(response): -    return bytes(response.headers) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py deleted file mode 100644 index e6b22863..00000000 --- a/netlib/http/http1/read.py +++ /dev/null @@ -1,377 +0,0 @@ -import time -import sys -import re - -from netlib.http import request -from netlib.http import response -from netlib.http import headers -from netlib.http import url -from netlib import check -from mitmproxy import exceptions - - -def get_header_tokens(headers, key): -    """ -        Retrieve all tokens for a header key. A number of different headers -        follow a pattern where each header line can containe comma-separated -        tokens, and headers can be set multiple times. -    """ -    if key not in headers: -        return [] -    tokens = headers[key].split(",") -    return [token.strip() for token in tokens] - - -def read_request(rfile, body_size_limit=None): -    request = read_request_head(rfile) -    expected_body_size = expected_http_body_size(request) -    request.data.content = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit)) -    request.timestamp_end = time.time() -    return request - - -def read_request_head(rfile): -    """ -    Parse an HTTP request head (request line + headers) from an input stream - -    Args: -        rfile: The input stream - -    Returns: -        The HTTP request object (without body) - -    Raises: -        exceptions.HttpReadDisconnect: No bytes can be read from rfile. -        exceptions.HttpSyntaxException: The input is malformed HTTP. -        exceptions.HttpException: Any other error occured. -    """ -    timestamp_start = time.time() -    if hasattr(rfile, "reset_timestamps"): -        rfile.reset_timestamps() - -    form, method, scheme, host, port, path, http_version = _read_request_line(rfile) -    headers = _read_headers(rfile) - -    if hasattr(rfile, "first_byte_timestamp"): -        # more accurate timestamp_start -        timestamp_start = rfile.first_byte_timestamp - -    return request.Request( -        form, method, scheme, host, port, path, http_version, headers, None, timestamp_start -    ) - - -def read_response(rfile, request, body_size_limit=None): -    response = read_response_head(rfile) -    expected_body_size = expected_http_body_size(request, response) -    response.data.content = b"".join(read_body(rfile, expected_body_size, body_size_limit)) -    response.timestamp_end = time.time() -    return response - - -def read_response_head(rfile): -    """ -    Parse an HTTP response head (response line + headers) from an input stream - -    Args: -        rfile: The input stream - -    Returns: -        The HTTP request object (without body) - -    Raises: -        exceptions.HttpReadDisconnect: No bytes can be read from rfile. -        exceptions.HttpSyntaxException: The input is malformed HTTP. -        exceptions.HttpException: Any other error occured. -    """ - -    timestamp_start = time.time() -    if hasattr(rfile, "reset_timestamps"): -        rfile.reset_timestamps() - -    http_version, status_code, message = _read_response_line(rfile) -    headers = _read_headers(rfile) - -    if hasattr(rfile, "first_byte_timestamp"): -        # more accurate timestamp_start -        timestamp_start = rfile.first_byte_timestamp - -    return response.Response(http_version, status_code, message, headers, None, timestamp_start) - - -def read_body(rfile, expected_size, limit=None, max_chunk_size=4096): -    """ -        Read an HTTP message body - -        Args: -            rfile: The input stream -            expected_size: The expected body size (see :py:meth:`expected_body_size`) -            limit: Maximum body size -            max_chunk_size: Maximium chunk size that gets yielded - -        Returns: -            A generator that yields byte chunks of the content. - -        Raises: -            exceptions.HttpException, if an error occurs - -        Caveats: -            max_chunk_size is not considered if the transfer encoding is chunked. -    """ -    if not limit or limit < 0: -        limit = sys.maxsize -    if not max_chunk_size: -        max_chunk_size = limit - -    if expected_size is None: -        for x in _read_chunked(rfile, limit): -            yield x -    elif expected_size >= 0: -        if limit is not None and expected_size > limit: -            raise exceptions.HttpException( -                "HTTP Body too large. " -                "Limit is {}, content length was advertised as {}".format(limit, expected_size) -            ) -        bytes_left = expected_size -        while bytes_left: -            chunk_size = min(bytes_left, max_chunk_size) -            content = rfile.read(chunk_size) -            if len(content) < chunk_size: -                raise exceptions.HttpException("Unexpected EOF") -            yield content -            bytes_left -= chunk_size -    else: -        bytes_left = limit -        while bytes_left: -            chunk_size = min(bytes_left, max_chunk_size) -            content = rfile.read(chunk_size) -            if not content: -                return -            yield content -            bytes_left -= chunk_size -        not_done = rfile.read(1) -        if not_done: -            raise exceptions.HttpException("HTTP body too large. Limit is {}.".format(limit)) - - -def connection_close(http_version, headers): -    """ -        Checks the message to see if the client connection should be closed -        according to RFC 2616 Section 8.1. -    """ -    # At first, check if we have an explicit Connection header. -    if "connection" in headers: -        tokens = get_header_tokens(headers, "connection") -        if "close" in tokens: -            return True -        elif "keep-alive" in tokens: -            return False - -    # If we don't have a Connection header, HTTP 1.1 connections are assumed to -    # be persistent -    return http_version != "HTTP/1.1" and http_version != b"HTTP/1.1"  # FIXME: Remove one case. - - -def expected_http_body_size(request, response=None): -    """ -        Returns: -            The expected body length: -            - a positive integer, if the size is known in advance -            - None, if the size in unknown in advance (chunked encoding) -            - -1, if all data should be read until end of stream. - -        Raises: -            exceptions.HttpSyntaxException, if the content length header is invalid -    """ -    # Determine response size according to -    # http://tools.ietf.org/html/rfc7230#section-3.3 -    if not response: -        headers = request.headers -        response_code = None -        is_request = True -    else: -        headers = response.headers -        response_code = response.status_code -        is_request = False - -    if is_request: -        if headers.get("expect", "").lower() == "100-continue": -            return 0 -    else: -        if request.method.upper() == "HEAD": -            return 0 -        if 100 <= response_code <= 199: -            return 0 -        if response_code == 200 and request.method.upper() == "CONNECT": -            return 0 -        if response_code in (204, 304): -            return 0 - -    if "chunked" in headers.get("transfer-encoding", "").lower(): -        return None -    if "content-length" in headers: -        try: -            size = int(headers["content-length"]) -            if size < 0: -                raise ValueError() -            return size -        except ValueError: -            raise exceptions.HttpSyntaxException("Unparseable Content Length") -    if is_request: -        return 0 -    return -1 - - -def _get_first_line(rfile): -    try: -        line = rfile.readline() -        if line == b"\r\n" or line == b"\n": -            # Possible leftover from previous message -            line = rfile.readline() -    except exceptions.TcpDisconnect: -        raise exceptions.HttpReadDisconnect("Remote disconnected") -    if not line: -        raise exceptions.HttpReadDisconnect("Remote disconnected") -    return line.strip() - - -def _read_request_line(rfile): -    try: -        line = _get_first_line(rfile) -    except exceptions.HttpReadDisconnect: -        # We want to provide a better error message. -        raise exceptions.HttpReadDisconnect("Client disconnected") - -    try: -        method, path, http_version = line.split() - -        if path == b"*" or path.startswith(b"/"): -            form = "relative" -            scheme, host, port = None, None, None -        elif method == b"CONNECT": -            form = "authority" -            host, port = _parse_authority_form(path) -            scheme, path = None, None -        else: -            form = "absolute" -            scheme, host, port, path = url.parse(path) - -        _check_http_version(http_version) -    except ValueError: -        raise exceptions.HttpSyntaxException("Bad HTTP request line: {}".format(line)) - -    return form, method, scheme, host, port, path, http_version - - -def _parse_authority_form(hostport): -    """ -        Returns (host, port) if hostport is a valid authority-form host specification. -        http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - -        Raises: -            ValueError, if the input is malformed -    """ -    try: -        host, port = hostport.split(b":") -        port = int(port) -        if not check.is_valid_host(host) or not check.is_valid_port(port): -            raise ValueError() -    except ValueError: -        raise exceptions.HttpSyntaxException("Invalid host specification: {}".format(hostport)) - -    return host, port - - -def _read_response_line(rfile): -    try: -        line = _get_first_line(rfile) -    except exceptions.HttpReadDisconnect: -        # We want to provide a better error message. -        raise exceptions.HttpReadDisconnect("Server disconnected") - -    try: -        parts = line.split(None, 2) -        if len(parts) == 2:  # handle missing message gracefully -            parts.append(b"") - -        http_version, status_code, message = parts -        status_code = int(status_code) -        _check_http_version(http_version) - -    except ValueError: -        raise exceptions.HttpSyntaxException("Bad HTTP response line: {}".format(line)) - -    return http_version, status_code, message - - -def _check_http_version(http_version): -    if not re.match(br"^HTTP/\d\.\d$", http_version): -        raise exceptions.HttpSyntaxException("Unknown HTTP version: {}".format(http_version)) - - -def _read_headers(rfile): -    """ -        Read a set of headers. -        Stop once a blank line is reached. - -        Returns: -            A headers object - -        Raises: -            exceptions.HttpSyntaxException -    """ -    ret = [] -    while True: -        line = rfile.readline() -        if not line or line == b"\r\n" or line == b"\n": -            break -        if line[0] in b" \t": -            if not ret: -                raise exceptions.HttpSyntaxException("Invalid headers") -            # continued header -            ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip()) -        else: -            try: -                name, value = line.split(b":", 1) -                value = value.strip() -                if not name: -                    raise ValueError() -                ret.append((name, value)) -            except ValueError: -                raise exceptions.HttpSyntaxException( -                    "Invalid header line: %s" % repr(line) -                ) -    return headers.Headers(ret) - - -def _read_chunked(rfile, limit=sys.maxsize): -    """ -    Read a HTTP body with chunked transfer encoding. - -    Args: -        rfile: the input file -        limit: A positive integer -    """ -    total = 0 -    while True: -        line = rfile.readline(128) -        if line == b"": -            raise exceptions.HttpException("Connection closed prematurely") -        if line != b"\r\n" and line != b"\n": -            try: -                length = int(line, 16) -            except ValueError: -                raise exceptions.HttpSyntaxException("Invalid chunked encoding length: {}".format(line)) -            total += length -            if total > limit: -                raise exceptions.HttpException( -                    "HTTP Body too large. Limit is {}, " -                    "chunked content longer than {}".format(limit, total) -                ) -            chunk = rfile.read(length) -            suffix = rfile.readline(5) -            if suffix != b"\r\n": -                raise exceptions.HttpSyntaxException("Malformed chunked body") -            if length == 0: -                return -            yield chunk diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py deleted file mode 100644 index 20cc63a0..00000000 --- a/netlib/http/http2/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from netlib.http.http2.framereader import read_raw_frame, parse_frame -from netlib.http.http2.utils import parse_headers - -__all__ = [ -    "read_raw_frame", -    "parse_frame", -    "parse_headers", -] diff --git a/netlib/http/http2/framereader.py b/netlib/http/http2/framereader.py deleted file mode 100644 index 6a164919..00000000 --- a/netlib/http/http2/framereader.py +++ /dev/null @@ -1,25 +0,0 @@ -import codecs - -import hyperframe -from mitmproxy import exceptions - - -def read_raw_frame(rfile): -    header = rfile.safe_read(9) -    length = int(codecs.encode(header[:3], 'hex_codec'), 16) - -    if length == 4740180: -        raise exceptions.HttpException("Length field looks more like HTTP/1.1:\n{}".format(rfile.read(-1))) - -    body = rfile.safe_read(length) -    return [header, body] - - -def parse_frame(header, body=None): -    if body is None: -        body = header[9:] -        header = header[:9] - -    frame, length = hyperframe.frame.Frame.parse_frame_header(header) -    frame.parse_body(memoryview(body)) -    return frame diff --git a/netlib/http/http2/utils.py b/netlib/http/http2/utils.py deleted file mode 100644 index 164bacc8..00000000 --- a/netlib/http/http2/utils.py +++ /dev/null @@ -1,37 +0,0 @@ -from netlib.http import url - - -def parse_headers(headers): -    authority = headers.get(':authority', '').encode() -    method = headers.get(':method', 'GET').encode() -    scheme = headers.get(':scheme', 'https').encode() -    path = headers.get(':path', '/').encode() - -    headers.pop(":method", None) -    headers.pop(":scheme", None) -    headers.pop(":path", None) - -    host = None -    port = None - -    if path == b'*' or path.startswith(b"/"): -        first_line_format = "relative" -    elif method == b'CONNECT':  # pragma: no cover -        raise NotImplementedError("CONNECT over HTTP/2 is not implemented.") -    else:  # pragma: no cover -        first_line_format = "absolute" -        # FIXME: verify if path or :host contains what we need -        scheme, host, port, _ = url.parse(path) - -    if authority: -        host, _, port = authority.partition(b':') - -    if not host: -        host = b'localhost' - -    if not port: -        port = 443 if scheme == b'https' else 80 - -    port = int(port) - -    return first_line_format, method, scheme, host, port, path diff --git a/netlib/http/message.py b/netlib/http/message.py deleted file mode 100644 index 772a124e..00000000 --- a/netlib/http/message.py +++ /dev/null @@ -1,300 +0,0 @@ -import re -import warnings -from typing import Optional - -from mitmproxy.utils import strutils -from netlib.http import encoding -from mitmproxy.types import serializable -from netlib.http import headers - - -# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. -def _native(x): -    return x.decode("utf-8", "surrogateescape") - - -def _always_bytes(x): -    return strutils.always_bytes(x, "utf-8", "surrogateescape") - - -class MessageData(serializable.Serializable): -    def __eq__(self, other): -        if isinstance(other, MessageData): -            return self.__dict__ == other.__dict__ -        return False - -    def __ne__(self, other): -        return not self.__eq__(other) - -    def set_state(self, state): -        for k, v in state.items(): -            if k == "headers": -                v = headers.Headers.from_state(v) -            setattr(self, k, v) - -    def get_state(self): -        state = vars(self).copy() -        state["headers"] = state["headers"].get_state() -        return state - -    @classmethod -    def from_state(cls, state): -        state["headers"] = headers.Headers.from_state(state["headers"]) -        return cls(**state) - - -class Message(serializable.Serializable): -    def __eq__(self, other): -        if isinstance(other, Message): -            return self.data == other.data -        return False - -    def __ne__(self, other): -        return not self.__eq__(other) - -    def get_state(self): -        return self.data.get_state() - -    def set_state(self, state): -        self.data.set_state(state) - -    @classmethod -    def from_state(cls, state): -        state["headers"] = headers.Headers.from_state(state["headers"]) -        return cls(**state) - -    @property -    def headers(self): -        """ -        Message headers object - -        Returns: -            netlib.http.Headers -        """ -        return self.data.headers - -    @headers.setter -    def headers(self, h): -        self.data.headers = h - -    @property -    def raw_content(self) -> bytes: -        """ -        The raw (encoded) HTTP message body - -        See also: :py:attr:`content`, :py:class:`text` -        """ -        return self.data.content - -    @raw_content.setter -    def raw_content(self, content): -        self.data.content = content - -    def get_content(self, strict: bool=True) -> bytes: -        """ -        The HTTP message body decoded with the content-encoding header (e.g. gzip) - -        Raises: -            ValueError, when the content-encoding is invalid and strict is True. - -        See also: :py:class:`raw_content`, :py:attr:`text` -        """ -        if self.raw_content is None: -            return None -        ce = self.headers.get("content-encoding") -        if ce: -            try: -                return encoding.decode(self.raw_content, ce) -            except ValueError: -                if strict: -                    raise -                return self.raw_content -        else: -            return self.raw_content - -    def set_content(self, value): -        if value is None: -            self.raw_content = None -            return -        if not isinstance(value, bytes): -            raise TypeError( -                "Message content must be bytes, not {}. " -                "Please use .text if you want to assign a str." -                .format(type(value).__name__) -            ) -        ce = self.headers.get("content-encoding") -        try: -            self.raw_content = encoding.encode(value, ce or "identity") -        except ValueError: -            # So we have an invalid content-encoding? -            # Let's remove it! -            del self.headers["content-encoding"] -            self.raw_content = value -        self.headers["content-length"] = str(len(self.raw_content)) - -    content = property(get_content, set_content) - -    @property -    def http_version(self): -        """ -        Version string, e.g. "HTTP/1.1" -        """ -        return _native(self.data.http_version) - -    @http_version.setter -    def http_version(self, http_version): -        self.data.http_version = _always_bytes(http_version) - -    @property -    def timestamp_start(self): -        """ -        First byte timestamp -        """ -        return self.data.timestamp_start - -    @timestamp_start.setter -    def timestamp_start(self, timestamp_start): -        self.data.timestamp_start = timestamp_start - -    @property -    def timestamp_end(self): -        """ -        Last byte timestamp -        """ -        return self.data.timestamp_end - -    @timestamp_end.setter -    def timestamp_end(self, timestamp_end): -        self.data.timestamp_end = timestamp_end - -    def _get_content_type_charset(self) -> Optional[str]: -        ct = headers.parse_content_type(self.headers.get("content-type", "")) -        if ct: -            return ct[2].get("charset") - -    def _guess_encoding(self) -> str: -        enc = self._get_content_type_charset() -        if enc: -            return enc - -        if "json" in self.headers.get("content-type", ""): -            return "utf8" -        else: -            # We may also want to check for HTML meta tags here at some point. -            return "latin-1" - -    def get_text(self, strict: bool=True) -> str: -        """ -        The HTTP message body decoded with both content-encoding header (e.g. gzip) -        and content-type header charset. - -        Raises: -            ValueError, when either content-encoding or charset is invalid and strict is True. - -        See also: :py:attr:`content`, :py:class:`raw_content` -        """ -        if self.raw_content is None: -            return None -        enc = self._guess_encoding() - -        content = self.get_content(strict) -        try: -            return encoding.decode(content, enc) -        except ValueError: -            if strict: -                raise -            return content.decode("utf8", "surrogateescape") - -    def set_text(self, text): -        if text is None: -            self.content = None -            return -        enc = self._guess_encoding() - -        try: -            self.content = encoding.encode(text, enc) -        except ValueError: -            # Fall back to UTF-8 and update the content-type header. -            ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {}) -            ct[2]["charset"] = "utf-8" -            self.headers["content-type"] = headers.assemble_content_type(*ct) -            enc = "utf8" -            self.content = text.encode(enc, "surrogateescape") - -    text = property(get_text, set_text) - -    def decode(self, strict=True): -        """ -        Decodes body based on the current Content-Encoding header, then -        removes the header. If there is no Content-Encoding header, no -        action is taken. - -        Raises: -            ValueError, when the content-encoding is invalid and strict is True. -        """ -        self.raw_content = self.get_content(strict) -        self.headers.pop("content-encoding", None) - -    def encode(self, e): -        """ -        Encodes body with the encoding e, where e is "gzip", "deflate", "identity", or "br". -        Any existing content-encodings are overwritten, -        the content is not decoded beforehand. - -        Raises: -            ValueError, when the specified content-encoding is invalid. -        """ -        self.headers["content-encoding"] = e -        self.content = self.raw_content -        if "content-encoding" not in self.headers: -            raise ValueError("Invalid content encoding {}".format(repr(e))) - -    def replace(self, pattern, repl, flags=0, count=0): -        """ -        Replaces a regular expression pattern with repl in both the headers -        and the body of the message. Encoded body will be decoded -        before replacement, and re-encoded afterwards. - -        Returns: -            The number of replacements made. -        """ -        if isinstance(pattern, str): -            pattern = strutils.escaped_str_to_bytes(pattern) -        if isinstance(repl, str): -            repl = strutils.escaped_str_to_bytes(repl) -        replacements = 0 -        if self.content: -            self.content, replacements = re.subn( -                pattern, repl, self.content, flags=flags, count=count -            ) -        replacements += self.headers.replace(pattern, repl, flags=flags, count=count) -        return replacements - -    # Legacy - -    @property -    def body(self):  # pragma: no cover -        warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) -        return self.content - -    @body.setter -    def body(self, body):  # pragma: no cover -        warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning) -        self.content = body - - -class decoded: -    """ -    Deprecated: You can now directly use :py:attr:`content`. -    :py:attr:`raw_content` has the encoded content. -    """ - -    def __init__(self, message):  # pragma no cover -        warnings.warn("decoded() is deprecated, you can now directly use .content instead. " -                      ".raw_content has the encoded content.", DeprecationWarning) - -    def __enter__(self):  # pragma no cover -        pass - -    def __exit__(self, type, value, tb):  # pragma no cover -        pass diff --git a/netlib/http/multipart.py b/netlib/http/multipart.py deleted file mode 100644 index 536b2809..00000000 --- a/netlib/http/multipart.py +++ /dev/null @@ -1,32 +0,0 @@ -import re - -from netlib.http import headers - - -def decode(hdrs, content): -    """ -        Takes a multipart boundary encoded string and returns list of (key, value) tuples. -    """ -    v = hdrs.get("content-type") -    if v: -        v = headers.parse_content_type(v) -        if not v: -            return [] -        try: -            boundary = v[2]["boundary"].encode("ascii") -        except (KeyError, UnicodeError): -            return [] - -        rx = re.compile(br'\bname="([^"]+)"') -        r = [] - -        for i in content.split(b"--" + boundary): -            parts = i.splitlines() -            if len(parts) > 1 and parts[0][0:2] != b"--": -                match = rx.search(parts[1]) -                if match: -                    key = match.group(1) -                    value = b"".join(parts[3 + parts[2:].index(b""):]) -                    r.append((key, value)) -        return r -    return [] diff --git a/netlib/http/request.py b/netlib/http/request.py deleted file mode 100644 index 16b0c986..00000000 --- a/netlib/http/request.py +++ /dev/null @@ -1,405 +0,0 @@ -import re -import urllib - -from mitmproxy.types import multidict -from mitmproxy.utils import strutils -from netlib.http import multipart -from netlib.http import cookies -from netlib.http import headers as nheaders -from netlib.http import message -import netlib.http.url - -# This regex extracts & splits the host header into host and port. -# Handles the edge case of IPv6 addresses containing colons. -# https://bugzilla.mozilla.org/show_bug.cgi?id=45891 -host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$") - - -class RequestData(message.MessageData): -    def __init__( -        self, -        first_line_format, -        method, -        scheme, -        host, -        port, -        path, -        http_version, -        headers=(), -        content=None, -        timestamp_start=None, -        timestamp_end=None -    ): -        if isinstance(method, str): -            method = method.encode("ascii", "strict") -        if isinstance(scheme, str): -            scheme = scheme.encode("ascii", "strict") -        if isinstance(host, str): -            host = host.encode("idna", "strict") -        if isinstance(path, str): -            path = path.encode("ascii", "strict") -        if isinstance(http_version, str): -            http_version = http_version.encode("ascii", "strict") -        if not isinstance(headers, nheaders.Headers): -            headers = nheaders.Headers(headers) -        if isinstance(content, str): -            raise ValueError("Content must be bytes, not {}".format(type(content).__name__)) - -        self.first_line_format = first_line_format -        self.method = method -        self.scheme = scheme -        self.host = host -        self.port = port -        self.path = path -        self.http_version = http_version -        self.headers = headers -        self.content = content -        self.timestamp_start = timestamp_start -        self.timestamp_end = timestamp_end - - -class Request(message.Message): -    """ -    An HTTP request. -    """ -    def __init__(self, *args, **kwargs): -        super().__init__() -        self.data = RequestData(*args, **kwargs) - -    def __repr__(self): -        if self.host and self.port: -            hostport = "{}:{}".format(self.host, self.port) -        else: -            hostport = "" -        path = self.path or "" -        return "Request({} {}{})".format( -            self.method, hostport, path -        ) - -    def replace(self, pattern, repl, flags=0, count=0): -        """ -            Replaces a regular expression pattern with repl in the headers, the -            request path and the body of the request. Encoded content will be -            decoded before replacement, and re-encoded afterwards. - -            Returns: -                The number of replacements made. -        """ -        if isinstance(pattern, str): -            pattern = strutils.escaped_str_to_bytes(pattern) -        if isinstance(repl, str): -            repl = strutils.escaped_str_to_bytes(repl) - -        c = super().replace(pattern, repl, flags, count) -        self.path, pc = re.subn( -            pattern, repl, self.data.path, flags=flags, count=count -        ) -        c += pc -        return c - -    @property -    def first_line_format(self): -        """ -        HTTP request form as defined in `RFC7230 <https://tools.ietf.org/html/rfc7230#section-5.3>`_. - -        origin-form and asterisk-form are subsumed as "relative". -        """ -        return self.data.first_line_format - -    @first_line_format.setter -    def first_line_format(self, first_line_format): -        self.data.first_line_format = first_line_format - -    @property -    def method(self): -        """ -        HTTP request method, e.g. "GET". -        """ -        return message._native(self.data.method).upper() - -    @method.setter -    def method(self, method): -        self.data.method = message._always_bytes(method) - -    @property -    def scheme(self): -        """ -        HTTP request scheme, which should be "http" or "https". -        """ -        if not self.data.scheme: -            return self.data.scheme -        return message._native(self.data.scheme) - -    @scheme.setter -    def scheme(self, scheme): -        self.data.scheme = message._always_bytes(scheme) - -    @property -    def host(self): -        """ -        Target host. This may be parsed from the raw request -        (e.g. from a ``GET http://example.com/ HTTP/1.1`` request line) -        or inferred from the proxy mode (e.g. an IP in transparent mode). - -        Setting the host attribute also updates the host header, if present. -        """ -        if not self.data.host: -            return self.data.host -        try: -            return self.data.host.decode("idna") -        except UnicodeError: -            return self.data.host.decode("utf8", "surrogateescape") - -    @host.setter -    def host(self, host): -        if isinstance(host, str): -            try: -                # There's no non-strict mode for IDNA encoding. -                # We don't want this operation to fail though, so we try -                # utf8 as a last resort. -                host = host.encode("idna", "strict") -            except UnicodeError: -                host = host.encode("utf8", "surrogateescape") - -        self.data.host = host - -        # Update host header -        if "host" in self.headers: -            if host: -                self.headers["host"] = host -            else: -                self.headers.pop("host") - -    @property -    def port(self): -        """ -        Target port -        """ -        return self.data.port - -    @port.setter -    def port(self, port): -        self.data.port = port - -    @property -    def path(self): -        """ -        HTTP request path, e.g. "/index.html". -        Guaranteed to start with a slash, except for OPTIONS requests, which may just be "*". -        """ -        if self.data.path is None: -            return None -        else: -            return message._native(self.data.path) - -    @path.setter -    def path(self, path): -        self.data.path = message._always_bytes(path) - -    @property -    def url(self): -        """ -        The URL string, constructed from the request's URL components -        """ -        if self.first_line_format == "authority": -            return "%s:%d" % (self.host, self.port) -        return netlib.http.url.unparse(self.scheme, self.host, self.port, self.path) - -    @url.setter -    def url(self, url): -        self.scheme, self.host, self.port, self.path = netlib.http.url.parse(url) - -    def _parse_host_header(self): -        """Extract the host and port from Host header""" -        if "host" not in self.headers: -            return None, None -        host, port = self.headers["host"], None -        m = host_header_re.match(host) -        if m: -            host = m.group("host").strip("[]") -            if m.group("port"): -                port = int(m.group("port")) -        return host, port - -    @property -    def pretty_host(self): -        """ -        Similar to :py:attr:`host`, but using the Host headers as an additional preferred data source. -        This is useful in transparent mode where :py:attr:`host` is only an IP address, -        but may not reflect the actual destination as the Host header could be spoofed. -        """ -        host, port = self._parse_host_header() -        if not host: -            return self.host -        if not port: -            port = 443 if self.scheme == 'https' else 80 -        # Prefer the original address if host header has an unexpected form -        return host if port == self.port else self.host - -    @property -    def pretty_url(self): -        """ -        Like :py:attr:`url`, but using :py:attr:`pretty_host` instead of :py:attr:`host`. -        """ -        if self.first_line_format == "authority": -            return "%s:%d" % (self.pretty_host, self.port) -        return netlib.http.url.unparse(self.scheme, self.pretty_host, self.port, self.path) - -    @property -    def query(self) -> multidict.MultiDictView: -        """ -        The request query string as an :py:class:`~netlib.multidict.MultiDictView` object. -        """ -        return multidict.MultiDictView( -            self._get_query, -            self._set_query -        ) - -    def _get_query(self): -        query = urllib.parse.urlparse(self.url).query -        return tuple(netlib.http.url.decode(query)) - -    def _set_query(self, query_data): -        query = netlib.http.url.encode(query_data) -        _, _, path, params, _, fragment = urllib.parse.urlparse(self.url) -        self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment]) - -    @query.setter -    def query(self, value): -        self._set_query(value) - -    @property -    def cookies(self) -> multidict.MultiDictView: -        """ -        The request cookies. - -        An empty :py:class:`~netlib.multidict.MultiDictView` object if the cookie monster ate them all. -        """ -        return multidict.MultiDictView( -            self._get_cookies, -            self._set_cookies -        ) - -    def _get_cookies(self): -        h = self.headers.get_all("Cookie") -        return tuple(cookies.parse_cookie_headers(h)) - -    def _set_cookies(self, value): -        self.headers["cookie"] = cookies.format_cookie_header(value) - -    @cookies.setter -    def cookies(self, value): -        self._set_cookies(value) - -    @property -    def path_components(self): -        """ -        The URL's path components as a tuple of strings. -        Components are unquoted. -        """ -        path = urllib.parse.urlparse(self.url).path -        # This needs to be a tuple so that it's immutable. -        # Otherwise, this would fail silently: -        #   request.path_components.append("foo") -        return tuple(netlib.http.url.unquote(i) for i in path.split("/") if i) - -    @path_components.setter -    def path_components(self, components): -        components = map(lambda x: netlib.http.url.quote(x, safe=""), components) -        path = "/" + "/".join(components) -        _, _, _, params, query, fragment = urllib.parse.urlparse(self.url) -        self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment]) - -    def anticache(self): -        """ -        Modifies this request to remove headers that might produce a cached -        response. That is, we remove ETags and If-Modified-Since headers. -        """ -        delheaders = [ -            "if-modified-since", -            "if-none-match", -        ] -        for i in delheaders: -            self.headers.pop(i, None) - -    def anticomp(self): -        """ -        Modifies this request to remove headers that will compress the -        resource's data. -        """ -        self.headers["accept-encoding"] = "identity" - -    def constrain_encoding(self): -        """ -        Limits the permissible Accept-Encoding values, based on what we can -        decode appropriately. -        """ -        accept_encoding = self.headers.get("accept-encoding") -        if accept_encoding: -            self.headers["accept-encoding"] = ( -                ', '.join( -                    e -                    for e in {"gzip", "identity", "deflate", "br"} -                    if e in accept_encoding -                ) -            ) - -    @property -    def urlencoded_form(self): -        """ -        The URL-encoded form data as an :py:class:`~netlib.multidict.MultiDictView` object. -        An empty multidict.MultiDictView if the content-type indicates non-form data -        or the content could not be parsed. -        """ -        return multidict.MultiDictView( -            self._get_urlencoded_form, -            self._set_urlencoded_form -        ) - -    def _get_urlencoded_form(self): -        is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() -        if is_valid_content_type: -            try: -                return tuple(netlib.http.url.decode(self.content)) -            except ValueError: -                pass -        return () - -    def _set_urlencoded_form(self, form_data): -        """ -        Sets the body to the URL-encoded form data, and adds the appropriate content-type header. -        This will overwrite the existing content if there is one. -        """ -        self.headers["content-type"] = "application/x-www-form-urlencoded" -        self.content = netlib.http.url.encode(form_data).encode() - -    @urlencoded_form.setter -    def urlencoded_form(self, value): -        self._set_urlencoded_form(value) - -    @property -    def multipart_form(self): -        """ -        The multipart form data as an :py:class:`~netlib.multidict.MultiDictView` object. -        None if the content-type indicates non-form data. -        """ -        return multidict.MultiDictView( -            self._get_multipart_form, -            self._set_multipart_form -        ) - -    def _get_multipart_form(self): -        is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() -        if is_valid_content_type: -            try: -                return multipart.decode(self.headers, self.content) -            except ValueError: -                pass -        return () - -    def _set_multipart_form(self, value): -        raise NotImplementedError() - -    @multipart_form.setter -    def multipart_form(self, value): -        self._set_multipart_form(value) diff --git a/netlib/http/response.py b/netlib/http/response.py deleted file mode 100644 index 4d1d5d24..00000000 --- a/netlib/http/response.py +++ /dev/null @@ -1,192 +0,0 @@ -import time -from email.utils import parsedate_tz, formatdate, mktime_tz -from mitmproxy.utils import human -from mitmproxy.types import multidict -from netlib.http import cookies -from netlib.http import headers as nheaders -from netlib.http import message -from netlib.http import status_codes -from typing import AnyStr -from typing import Dict -from typing import Iterable -from typing import Tuple -from typing import Union - - -class ResponseData(message.MessageData): -    def __init__( -        self, -        http_version, -        status_code, -        reason=None, -        headers=(), -        content=None, -        timestamp_start=None, -        timestamp_end=None -    ): -        if isinstance(http_version, str): -            http_version = http_version.encode("ascii", "strict") -        if isinstance(reason, str): -            reason = reason.encode("ascii", "strict") -        if not isinstance(headers, nheaders.Headers): -            headers = nheaders.Headers(headers) -        if isinstance(content, str): -            raise ValueError("Content must be bytes, not {}".format(type(content).__name__)) - -        self.http_version = http_version -        self.status_code = status_code -        self.reason = reason -        self.headers = headers -        self.content = content -        self.timestamp_start = timestamp_start -        self.timestamp_end = timestamp_end - - -class Response(message.Message): -    """ -    An HTTP response. -    """ -    def __init__(self, *args, **kwargs): -        super().__init__() -        self.data = ResponseData(*args, **kwargs) - -    def __repr__(self): -        if self.raw_content: -            details = "{}, {}".format( -                self.headers.get("content-type", "unknown content type"), -                human.pretty_size(len(self.raw_content)) -            ) -        else: -            details = "no content" -        return "Response({status_code} {reason}, {details})".format( -            status_code=self.status_code, -            reason=self.reason, -            details=details -        ) - -    @classmethod -    def make( -            cls, -            status_code: int=200, -            content: AnyStr=b"", -            headers: Union[Dict[AnyStr, AnyStr], Iterable[Tuple[bytes, bytes]]]=() -    ): -        """ -        Simplified API for creating response objects. -        """ -        resp = cls( -            b"HTTP/1.1", -            status_code, -            status_codes.RESPONSES.get(status_code, "").encode(), -            (), -            None -        ) - -        # Headers can be list or dict, we differentiate here. -        if isinstance(headers, dict): -            resp.headers = nheaders.Headers(**headers) -        elif isinstance(headers, Iterable): -            resp.headers = nheaders.Headers(headers) -        else: -            raise TypeError("Expected headers to be an iterable or dict, but is {}.".format( -                type(headers).__name__ -            )) - -        # Assign this manually to update the content-length header. -        if isinstance(content, bytes): -            resp.content = content -        elif isinstance(content, str): -            resp.text = content -        else: -            raise TypeError("Expected content to be str or bytes, but is {}.".format( -                type(content).__name__ -            )) - -        return resp - -    @property -    def status_code(self): -        """ -        HTTP Status Code, e.g. ``200``. -        """ -        return self.data.status_code - -    @status_code.setter -    def status_code(self, status_code): -        self.data.status_code = status_code - -    @property -    def reason(self): -        """ -        HTTP Reason Phrase, e.g. "Not Found". -        This is always :py:obj:`None` for HTTP2 requests, because HTTP2 responses do not contain a reason phrase. -        """ -        return message._native(self.data.reason) - -    @reason.setter -    def reason(self, reason): -        self.data.reason = message._always_bytes(reason) - -    @property -    def cookies(self) -> multidict.MultiDictView: -        """ -        The response cookies. A possibly empty -        :py:class:`~netlib.multidict.MultiDictView`, where the keys are cookie -        name strings, and values are (value, attr) tuples. Value is a string, -        and attr is an MultiDictView containing cookie attributes. Within -        attrs, unary attributes (e.g. HTTPOnly) are indicated by a Null value. - -        Caveats: -            Updating the attr -        """ -        return multidict.MultiDictView( -            self._get_cookies, -            self._set_cookies -        ) - -    def _get_cookies(self): -        h = self.headers.get_all("set-cookie") -        return tuple(cookies.parse_set_cookie_headers(h)) - -    def _set_cookies(self, value): -        cookie_headers = [] -        for k, v in value: -            header = cookies.format_set_cookie_header([(k, v[0], v[1])]) -            cookie_headers.append(header) -        self.headers.set_all("set-cookie", cookie_headers) - -    @cookies.setter -    def cookies(self, value): -        self._set_cookies(value) - -    def refresh(self, now=None): -        """ -        This fairly complex and heuristic function refreshes a server -        response for replay. - -            - It adjusts date, expires and last-modified headers. -            - It adjusts cookie expiration. -        """ -        if not now: -            now = time.time() -        delta = now - self.timestamp_start -        refresh_headers = [ -            "date", -            "expires", -            "last-modified", -        ] -        for i in refresh_headers: -            if i in self.headers: -                d = parsedate_tz(self.headers[i]) -                if d: -                    new = mktime_tz(d) + delta -                    self.headers[i] = formatdate(new) -        c = [] -        for set_cookie_header in self.headers.get_all("set-cookie"): -            try: -                refreshed = cookies.refresh_set_cookie_header(set_cookie_header, delta) -            except ValueError: -                refreshed = set_cookie_header -            c.append(refreshed) -        if c: -            self.headers.set_all("set-cookie", c) diff --git a/netlib/http/status_codes.py b/netlib/http/status_codes.py deleted file mode 100644 index 5a83cd73..00000000 --- a/netlib/http/status_codes.py +++ /dev/null @@ -1,104 +0,0 @@ -CONTINUE = 100 -SWITCHING = 101 -OK = 200 -CREATED = 201 -ACCEPTED = 202 -NON_AUTHORITATIVE_INFORMATION = 203 -NO_CONTENT = 204 -RESET_CONTENT = 205 -PARTIAL_CONTENT = 206 -MULTI_STATUS = 207 - -MULTIPLE_CHOICE = 300 -MOVED_PERMANENTLY = 301 -FOUND = 302 -SEE_OTHER = 303 -NOT_MODIFIED = 304 -USE_PROXY = 305 -TEMPORARY_REDIRECT = 307 - -BAD_REQUEST = 400 -UNAUTHORIZED = 401 -PAYMENT_REQUIRED = 402 -FORBIDDEN = 403 -NOT_FOUND = 404 -NOT_ALLOWED = 405 -NOT_ACCEPTABLE = 406 -PROXY_AUTH_REQUIRED = 407 -REQUEST_TIMEOUT = 408 -CONFLICT = 409 -GONE = 410 -LENGTH_REQUIRED = 411 -PRECONDITION_FAILED = 412 -REQUEST_ENTITY_TOO_LARGE = 413 -REQUEST_URI_TOO_LONG = 414 -UNSUPPORTED_MEDIA_TYPE = 415 -REQUESTED_RANGE_NOT_SATISFIABLE = 416 -EXPECTATION_FAILED = 417 -IM_A_TEAPOT = 418 - -INTERNAL_SERVER_ERROR = 500 -NOT_IMPLEMENTED = 501 -BAD_GATEWAY = 502 -SERVICE_UNAVAILABLE = 503 -GATEWAY_TIMEOUT = 504 -HTTP_VERSION_NOT_SUPPORTED = 505 -INSUFFICIENT_STORAGE_SPACE = 507 -NOT_EXTENDED = 510 - -RESPONSES = { -    # 100 -    CONTINUE: "Continue", -    SWITCHING: "Switching Protocols", - -    # 200 -    OK: "OK", -    CREATED: "Created", -    ACCEPTED: "Accepted", -    NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", -    NO_CONTENT: "No Content", -    RESET_CONTENT: "Reset Content.", -    PARTIAL_CONTENT: "Partial Content", -    MULTI_STATUS: "Multi-Status", - -    # 300 -    MULTIPLE_CHOICE: "Multiple Choices", -    MOVED_PERMANENTLY: "Moved Permanently", -    FOUND: "Found", -    SEE_OTHER: "See Other", -    NOT_MODIFIED: "Not Modified", -    USE_PROXY: "Use Proxy", -    # 306 not defined?? -    TEMPORARY_REDIRECT: "Temporary Redirect", - -    # 400 -    BAD_REQUEST: "Bad Request", -    UNAUTHORIZED: "Unauthorized", -    PAYMENT_REQUIRED: "Payment Required", -    FORBIDDEN: "Forbidden", -    NOT_FOUND: "Not Found", -    NOT_ALLOWED: "Method Not Allowed", -    NOT_ACCEPTABLE: "Not Acceptable", -    PROXY_AUTH_REQUIRED: "Proxy Authentication Required", -    REQUEST_TIMEOUT: "Request Time-out", -    CONFLICT: "Conflict", -    GONE: "Gone", -    LENGTH_REQUIRED: "Length Required", -    PRECONDITION_FAILED: "Precondition Failed", -    REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", -    REQUEST_URI_TOO_LONG: "Request-URI Too Long", -    UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", -    REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", -    EXPECTATION_FAILED: "Expectation Failed", -    IM_A_TEAPOT: "I'm a teapot", - -    # 500 -    INTERNAL_SERVER_ERROR: "Internal Server Error", -    NOT_IMPLEMENTED: "Not Implemented", -    BAD_GATEWAY: "Bad Gateway", -    SERVICE_UNAVAILABLE: "Service Unavailable", -    GATEWAY_TIMEOUT: "Gateway Time-out", -    HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", -    INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", -    NOT_EXTENDED: "Not Extended" -} diff --git a/netlib/http/url.py b/netlib/http/url.py deleted file mode 100644 index 3ca58120..00000000 --- a/netlib/http/url.py +++ /dev/null @@ -1,127 +0,0 @@ -import urllib -from typing import Sequence -from typing import Tuple - -from netlib import check - - -# PY2 workaround -def decode_parse_result(result, enc): -    if hasattr(result, "decode"): -        return result.decode(enc) -    else: -        return urllib.parse.ParseResult(*[x.decode(enc) for x in result]) - - -# PY2 workaround -def encode_parse_result(result, enc): -    if hasattr(result, "encode"): -        return result.encode(enc) -    else: -        return urllib.parse.ParseResult(*[x.encode(enc) for x in result]) - - -def parse(url): -    """ -        URL-parsing function that checks that -            - port is an integer 0-65535 -            - host is a valid IDNA-encoded hostname with no null-bytes -            - path is valid ASCII - -        Args: -            A URL (as bytes or as unicode) - -        Returns: -            A (scheme, host, port, path) tuple - -        Raises: -            ValueError, if the URL is not properly formatted. -    """ -    parsed = urllib.parse.urlparse(url) - -    if not parsed.hostname: -        raise ValueError("No hostname given") - -    if isinstance(url, bytes): -        host = parsed.hostname - -        # this should not raise a ValueError, -        # but we try to be very forgiving here and accept just everything. -        # decode_parse_result(parsed, "ascii") -    else: -        host = parsed.hostname.encode("idna") -        parsed = encode_parse_result(parsed, "ascii") - -    port = parsed.port -    if not port: -        port = 443 if parsed.scheme == b"https" else 80 - -    full_path = urllib.parse.urlunparse( -        (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment) -    ) -    if not full_path.startswith(b"/"): -        full_path = b"/" + full_path - -    if not check.is_valid_host(host): -        raise ValueError("Invalid Host") -    if not check.is_valid_port(port): -        raise ValueError("Invalid Port") - -    return parsed.scheme, host, port, full_path - - -def unparse(scheme, host, port, path=""): -    """ -    Returns a URL string, constructed from the specified components. - -    Args: -        All args must be str. -    """ -    if path == "*": -        path = "" -    return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) - - -def encode(s: Sequence[Tuple[str, str]]) -> str: -    """ -        Takes a list of (key, value) tuples and returns a urlencoded string. -    """ -    return urllib.parse.urlencode(s, False, errors="surrogateescape") - - -def decode(s): -    """ -        Takes a urlencoded string and returns a list of surrogate-escaped (key, value) tuples. -    """ -    return urllib.parse.parse_qsl(s, keep_blank_values=True, errors='surrogateescape') - - -def quote(b: str, safe: str="/") -> str: -    """ -    Returns: -        An ascii-encodable str. -    """ -    return urllib.parse.quote(b, safe=safe, errors="surrogateescape") - - -def unquote(s: str) -> str: -    """ -    Args: -        s: A surrogate-escaped str -    Returns: -        A surrogate-escaped str -    """ -    return urllib.parse.unquote(s, errors="surrogateescape") - - -def hostport(scheme, host, port): -    """ -        Returns the host component, with a port specifcation if needed. -    """ -    if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]: -        return host -    else: -        if isinstance(host, bytes): -            return b"%s:%d" % (host, port) -        else: -            return "%s:%d" % (host, port) diff --git a/netlib/http/user_agents.py b/netlib/http/user_agents.py deleted file mode 100644 index d0ca2f21..00000000 --- a/netlib/http/user_agents.py +++ /dev/null @@ -1,50 +0,0 @@ -""" -    A small collection of useful user-agent header strings. These should be -    kept reasonably current to reflect common usage. -""" - -# pylint: line-too-long - -# A collection of (name, shortcut, string) tuples. - -UASTRINGS = [ -    ("android", -     "a", -     "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"),  # noqa -    ("blackberry", -     "l", -     "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"),  # noqa -    ("bingbot", -     "b", -     "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"),  # noqa -    ("chrome", -     "c", -     "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"),  # noqa -    ("firefox", -     "f", -     "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"),  # noqa -    ("googlebot", -     "g", -     "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"),  # noqa -    ("ie9", -     "i", -     "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"),  # noqa -    ("ipad", -     "p", -     "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"),  # noqa -    ("iphone", -     "h", -     "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"),  # noqa -    ("safari", -     "s", -     "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"),  # noqa -] - - -def get_by_shortcut(s): -    """ -        Retrieve a user agent entry by shortcut. -    """ -    for i in UASTRINGS: -        if s == i[1]: -            return i diff --git a/netlib/socks.py b/netlib/socks.py deleted file mode 100644 index 377308a8..00000000 --- a/netlib/socks.py +++ /dev/null @@ -1,234 +0,0 @@ -import struct -import array -import ipaddress - -from netlib import tcp -from netlib import check -from mitmproxy.types import bidi - - -class SocksError(Exception): -    def __init__(self, code, message): -        super().__init__(message) -        self.code = code - -VERSION = bidi.BiDi( -    SOCKS4=0x04, -    SOCKS5=0x05 -) - -CMD = bidi.BiDi( -    CONNECT=0x01, -    BIND=0x02, -    UDP_ASSOCIATE=0x03 -) - -ATYP = bidi.BiDi( -    IPV4_ADDRESS=0x01, -    DOMAINNAME=0x03, -    IPV6_ADDRESS=0x04 -) - -REP = bidi.BiDi( -    SUCCEEDED=0x00, -    GENERAL_SOCKS_SERVER_FAILURE=0x01, -    CONNECTION_NOT_ALLOWED_BY_RULESET=0x02, -    NETWORK_UNREACHABLE=0x03, -    HOST_UNREACHABLE=0x04, -    CONNECTION_REFUSED=0x05, -    TTL_EXPIRED=0x06, -    COMMAND_NOT_SUPPORTED=0x07, -    ADDRESS_TYPE_NOT_SUPPORTED=0x08, -) - -METHOD = bidi.BiDi( -    NO_AUTHENTICATION_REQUIRED=0x00, -    GSSAPI=0x01, -    USERNAME_PASSWORD=0x02, -    NO_ACCEPTABLE_METHODS=0xFF -) - -USERNAME_PASSWORD_VERSION = bidi.BiDi( -    DEFAULT=0x01 -) - - -class ClientGreeting: -    __slots__ = ("ver", "methods") - -    def __init__(self, ver, methods): -        self.ver = ver -        self.methods = array.array("B") -        self.methods.extend(methods) - -    def assert_socks5(self): -        if self.ver != VERSION.SOCKS5: -            if self.ver == ord("G") and len(self.methods) == ord("E"): -                guess = "Probably not a SOCKS request but a regular HTTP request. " -            else: -                guess = "" - -            raise SocksError( -                REP.GENERAL_SOCKS_SERVER_FAILURE, -                guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver -            ) - -    @classmethod -    def from_file(cls, f, fail_early=False): -        """ -        :param fail_early: If true, a SocksError will be raised if the first byte does not indicate socks5. -        """ -        ver, nmethods = struct.unpack("!BB", f.safe_read(2)) -        client_greeting = cls(ver, []) -        if fail_early: -            client_greeting.assert_socks5() -        client_greeting.methods.fromstring(f.safe_read(nmethods)) -        return client_greeting - -    def to_file(self, f): -        f.write(struct.pack("!BB", self.ver, len(self.methods))) -        f.write(self.methods.tostring()) - - -class ServerGreeting: -    __slots__ = ("ver", "method") - -    def __init__(self, ver, method): -        self.ver = ver -        self.method = method - -    def assert_socks5(self): -        if self.ver != VERSION.SOCKS5: -            if self.ver == ord("H") and self.method == ord("T"): -                guess = "Probably not a SOCKS request but a regular HTTP response. " -            else: -                guess = "" - -            raise SocksError( -                REP.GENERAL_SOCKS_SERVER_FAILURE, -                guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver -            ) - -    @classmethod -    def from_file(cls, f): -        ver, method = struct.unpack("!BB", f.safe_read(2)) -        return cls(ver, method) - -    def to_file(self, f): -        f.write(struct.pack("!BB", self.ver, self.method)) - - -class UsernamePasswordAuth: -    __slots__ = ("ver", "username", "password") - -    def __init__(self, ver, username, password): -        self.ver = ver -        self.username = username -        self.password = password - -    def assert_authver1(self): -        if self.ver != USERNAME_PASSWORD_VERSION.DEFAULT: -            raise SocksError( -                0, -                "Invalid auth version. Expected 0x01, got 0x%x" % self.ver -            ) - -    @classmethod -    def from_file(cls, f): -        ver, ulen = struct.unpack("!BB", f.safe_read(2)) -        username = f.safe_read(ulen) -        plen, = struct.unpack("!B", f.safe_read(1)) -        password = f.safe_read(plen) -        return cls(ver, username.decode(), password.decode()) - -    def to_file(self, f): -        f.write(struct.pack("!BB", self.ver, len(self.username))) -        f.write(self.username.encode()) -        f.write(struct.pack("!B", len(self.password))) -        f.write(self.password.encode()) - - -class UsernamePasswordAuthResponse: -    __slots__ = ("ver", "status") - -    def __init__(self, ver, status): -        self.ver = ver -        self.status = status - -    def assert_authver1(self): -        if self.ver != USERNAME_PASSWORD_VERSION.DEFAULT: -            raise SocksError( -                0, -                "Invalid auth version. Expected 0x01, got 0x%x" % self.ver -            ) - -    @classmethod -    def from_file(cls, f): -        ver, status = struct.unpack("!BB", f.safe_read(2)) -        return cls(ver, status) - -    def to_file(self, f): -        f.write(struct.pack("!BB", self.ver, self.status)) - - -class Message: -    __slots__ = ("ver", "msg", "atyp", "addr") - -    def __init__(self, ver, msg, atyp, addr): -        self.ver = ver -        self.msg = msg -        self.atyp = atyp -        self.addr = tcp.Address.wrap(addr) - -    def assert_socks5(self): -        if self.ver != VERSION.SOCKS5: -            raise SocksError( -                REP.GENERAL_SOCKS_SERVER_FAILURE, -                "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver -            ) - -    @classmethod -    def from_file(cls, f): -        ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4)) -        if rsv != 0x00: -            raise SocksError( -                REP.GENERAL_SOCKS_SERVER_FAILURE, -                "Socks Request: Invalid reserved byte: %s" % rsv -            ) -        if atyp == ATYP.IPV4_ADDRESS: -            # We use tnoa here as ntop is not commonly available on Windows. -            host = ipaddress.IPv4Address(f.safe_read(4)).compressed -            use_ipv6 = False -        elif atyp == ATYP.IPV6_ADDRESS: -            host = ipaddress.IPv6Address(f.safe_read(16)).compressed -            use_ipv6 = True -        elif atyp == ATYP.DOMAINNAME: -            length, = struct.unpack("!B", f.safe_read(1)) -            host = f.safe_read(length) -            if not check.is_valid_host(host): -                raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Invalid hostname: %s" % host) -            host = host.decode("idna") -            use_ipv6 = False -        else: -            raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED, -                             "Socks Request: Unknown ATYP: %s" % atyp) - -        port, = struct.unpack("!H", f.safe_read(2)) -        addr = tcp.Address((host, port), use_ipv6=use_ipv6) -        return cls(ver, msg, atyp, addr) - -    def to_file(self, f): -        f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp)) -        if self.atyp == ATYP.IPV4_ADDRESS: -            f.write(ipaddress.IPv4Address(self.addr.host).packed) -        elif self.atyp == ATYP.IPV6_ADDRESS: -            f.write(ipaddress.IPv6Address(self.addr.host).packed) -        elif self.atyp == ATYP.DOMAINNAME: -            f.write(struct.pack("!B", len(self.addr.host))) -            f.write(self.addr.host.encode("idna")) -        else: -            raise SocksError( -                REP.ADDRESS_TYPE_NOT_SUPPORTED, -                "Unknown ATYP: %s" % self.atyp -            ) -        f.write(struct.pack("!H", self.addr.port)) diff --git a/netlib/tcp.py b/netlib/tcp.py deleted file mode 100644 index ac368a9c..00000000 --- a/netlib/tcp.py +++ /dev/null @@ -1,989 +0,0 @@ -import os -import select -import socket -import sys -import threading -import time -import traceback - -import binascii - -from typing import Optional  # noqa - -from mitmproxy.utils import strutils - -import certifi -from backports import ssl_match_hostname -import OpenSSL -from OpenSSL import SSL - -from mitmproxy import certs -from mitmproxy.utils import version_check -from mitmproxy.types import serializable -from mitmproxy import exceptions -from mitmproxy.types import basethread - -# This is a rather hackish way to make sure that -# the latest version of pyOpenSSL is actually installed. -version_check.check_pyopenssl_version() - -socket_fileobject = socket.SocketIO - -EINTR = 4 -if os.environ.get("NO_ALPN"): -    HAS_ALPN = False -else: -    HAS_ALPN = SSL._lib.Cryptography_HAS_ALPN - -# To enable all SSL methods use: SSLv23 -# then add options to disable certain methods -# https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 -SSL_BASIC_OPTIONS = ( -    SSL.OP_CIPHER_SERVER_PREFERENCE -) -if hasattr(SSL, "OP_NO_COMPRESSION"): -    SSL_BASIC_OPTIONS |= SSL.OP_NO_COMPRESSION - -SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD -SSL_DEFAULT_OPTIONS = ( -    SSL.OP_NO_SSLv2 | -    SSL.OP_NO_SSLv3 | -    SSL_BASIC_OPTIONS -) -if hasattr(SSL, "OP_NO_COMPRESSION"): -    SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION - -""" -Map a reasonable SSL version specification into the format OpenSSL expects. -Don't ask... -https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3 -""" -sslversion_choices = { -    "all": (SSL.SSLv23_METHOD, SSL_BASIC_OPTIONS), -    # SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+ -    # TLSv1_METHOD would be TLS 1.0 only -    "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL_BASIC_OPTIONS)), -    "SSLv2": (SSL.SSLv2_METHOD, SSL_BASIC_OPTIONS), -    "SSLv3": (SSL.SSLv3_METHOD, SSL_BASIC_OPTIONS), -    "TLSv1": (SSL.TLSv1_METHOD, SSL_BASIC_OPTIONS), -    "TLSv1_1": (SSL.TLSv1_1_METHOD, SSL_BASIC_OPTIONS), -    "TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS), -} - - -class SSLKeyLogger: - -    def __init__(self, filename): -        self.filename = filename -        self.f = None -        self.lock = threading.Lock() - -    # required for functools.wraps, which pyOpenSSL uses. -    __name__ = "SSLKeyLogger" - -    def __call__(self, connection, where, ret): -        if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1: -            with self.lock: -                if not self.f: -                    d = os.path.dirname(self.filename) -                    if not os.path.isdir(d): -                        os.makedirs(d) -                    self.f = open(self.filename, "ab") -                    self.f.write(b"\r\n") -                client_random = binascii.hexlify(connection.client_random()) -                masterkey = binascii.hexlify(connection.master_key()) -                self.f.write(b"CLIENT_RANDOM %s %s\r\n" % (client_random, masterkey)) -                self.f.flush() - -    def close(self): -        with self.lock: -            if self.f: -                self.f.close() - -    @staticmethod -    def create_logfun(filename): -        if filename: -            return SSLKeyLogger(filename) -        return False - -log_ssl_key = SSLKeyLogger.create_logfun( -    os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE")) - - -class _FileLike: -    BLOCKSIZE = 1024 * 32 - -    def __init__(self, o): -        self.o = o -        self._log = None -        self.first_byte_timestamp = None - -    def set_descriptor(self, o): -        self.o = o - -    def __getattr__(self, attr): -        return getattr(self.o, attr) - -    def start_log(self): -        """ -            Starts or resets the log. - -            This will store all bytes read or written. -        """ -        self._log = [] - -    def stop_log(self): -        """ -            Stops the log. -        """ -        self._log = None - -    def is_logging(self): -        return self._log is not None - -    def get_log(self): -        """ -            Returns the log as a string. -        """ -        if not self.is_logging(): -            raise ValueError("Not logging!") -        return b"".join(self._log) - -    def add_log(self, v): -        if self.is_logging(): -            self._log.append(v) - -    def reset_timestamps(self): -        self.first_byte_timestamp = None - - -class Writer(_FileLike): - -    def flush(self): -        """ -            May raise exceptions.TcpDisconnect -        """ -        if hasattr(self.o, "flush"): -            try: -                self.o.flush() -            except (socket.error, IOError) as v: -                raise exceptions.TcpDisconnect(str(v)) - -    def write(self, v): -        """ -            May raise exceptions.TcpDisconnect -        """ -        if v: -            self.first_byte_timestamp = self.first_byte_timestamp or time.time() -            try: -                if hasattr(self.o, "sendall"): -                    self.add_log(v) -                    return self.o.sendall(v) -                else: -                    r = self.o.write(v) -                    self.add_log(v[:r]) -                    return r -            except (SSL.Error, socket.error) as e: -                raise exceptions.TcpDisconnect(str(e)) - - -class Reader(_FileLike): - -    def read(self, length): -        """ -            If length is -1, we read until connection closes. -        """ -        result = b'' -        start = time.time() -        while length == -1 or length > 0: -            if length == -1 or length > self.BLOCKSIZE: -                rlen = self.BLOCKSIZE -            else: -                rlen = length -            try: -                data = self.o.read(rlen) -            except SSL.ZeroReturnError: -                # TLS connection was shut down cleanly -                break -            except (SSL.WantWriteError, SSL.WantReadError): -                # From the OpenSSL docs: -                # If the underlying BIO is non-blocking, SSL_read() will also return when the -                # underlying BIO could not satisfy the needs of SSL_read() to continue the -                # operation. In this case a call to SSL_get_error with the return value of -                # SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE. -                if (time.time() - start) < self.o.gettimeout(): -                    time.sleep(0.1) -                    continue -                else: -                    raise exceptions.TcpTimeout() -            except socket.timeout: -                raise exceptions.TcpTimeout() -            except socket.error as e: -                raise exceptions.TcpDisconnect(str(e)) -            except SSL.SysCallError as e: -                if e.args == (-1, 'Unexpected EOF'): -                    break -                raise exceptions.TlsException(str(e)) -            except SSL.Error as e: -                raise exceptions.TlsException(str(e)) -            self.first_byte_timestamp = self.first_byte_timestamp or time.time() -            if not data: -                break -            result += data -            if length != -1: -                length -= len(data) -        self.add_log(result) -        return result - -    def readline(self, size=None): -        result = b'' -        bytes_read = 0 -        while True: -            if size is not None and bytes_read >= size: -                break -            ch = self.read(1) -            bytes_read += 1 -            if not ch: -                break -            else: -                result += ch -                if ch == b'\n': -                    break -        return result - -    def safe_read(self, length): -        """ -            Like .read, but is guaranteed to either return length bytes, or -            raise an exception. -        """ -        result = self.read(length) -        if length != -1 and len(result) != length: -            if not result: -                raise exceptions.TcpDisconnect() -            else: -                raise exceptions.TcpReadIncomplete( -                    "Expected %s bytes, got %s" % (length, len(result)) -                ) -        return result - -    def peek(self, length): -        """ -        Tries to peek into the underlying file object. - -        Returns: -            Up to the next N bytes if peeking is successful. - -        Raises: -            exceptions.TcpException if there was an error with the socket -            TlsException if there was an error with pyOpenSSL. -            NotImplementedError if the underlying file object is not a [pyOpenSSL] socket -        """ -        if isinstance(self.o, socket_fileobject): -            try: -                return self.o._sock.recv(length, socket.MSG_PEEK) -            except socket.error as e: -                raise exceptions.TcpException(repr(e)) -        elif isinstance(self.o, SSL.Connection): -            try: -                return self.o.recv(length, socket.MSG_PEEK) -            except SSL.Error as e: -                raise exceptions.TlsException(str(e)) -        else: -            raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") - - -class Address(serializable.Serializable): - -    """ -        This class wraps an IPv4/IPv6 tuple to provide named attributes and -        ipv6 information. -    """ - -    def __init__(self, address, use_ipv6=False): -        self.address = tuple(address) -        self.use_ipv6 = use_ipv6 - -    def get_state(self): -        return { -            "address": self.address, -            "use_ipv6": self.use_ipv6 -        } - -    def set_state(self, state): -        self.address = state["address"] -        self.use_ipv6 = state["use_ipv6"] - -    @classmethod -    def from_state(cls, state): -        return Address(**state) - -    @classmethod -    def wrap(cls, t): -        if isinstance(t, cls): -            return t -        else: -            return cls(t) - -    def __call__(self): -        return self.address - -    @property -    def host(self): -        return self.address[0] - -    @property -    def port(self): -        return self.address[1] - -    @property -    def use_ipv6(self): -        return self.family == socket.AF_INET6 - -    @use_ipv6.setter -    def use_ipv6(self, b): -        self.family = socket.AF_INET6 if b else socket.AF_INET - -    def __repr__(self): -        return "{}:{}".format(self.host, self.port) - -    def __eq__(self, other): -        if not other: -            return False -        other = Address.wrap(other) -        return (self.address, self.family) == (other.address, other.family) - -    def __ne__(self, other): -        return not self.__eq__(other) - -    def __hash__(self): -        return hash(self.address) ^ 42  # different hash than the tuple alone. - - -def ssl_read_select(rlist, timeout): -    """ -    This is a wrapper around select.select() which also works for SSL.Connections -    by taking ssl_connection.pending() into account. - -    Caveats: -        If .pending() > 0 for any of the connections in rlist, we avoid the select syscall -        and **will not include any other connections which may or may not be ready**. - -    Args: -        rlist: wait until ready for reading - -    Returns: -        subset of rlist which is ready for reading. -    """ -    return [ -        conn for conn in rlist -        if isinstance(conn, SSL.Connection) and conn.pending() > 0 -    ] or select.select(rlist, (), (), timeout)[0] - - -def close_socket(sock): -    """ -    Does a hard close of a socket, without emitting a RST. -    """ -    try: -        # We already indicate that we close our end. -        # may raise "Transport endpoint is not connected" on Linux -        sock.shutdown(socket.SHUT_WR) - -        # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending -        # readable data could lead to an immediate RST being sent (which is the -        # case on Windows). -        # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html -        # -        # This in turn results in the following issue: If we send an error page -        # to the client and then close the socket, the RST may be received by -        # the client before the error page and the users sees a connection -        # error rather than the error page. Thus, we try to empty the read -        # buffer on Windows first. (see -        # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) -        # - -        if os.name == "nt":  # pragma: no cover -            # We cannot rely on the shutdown()-followed-by-read()-eof technique -            # proposed by the page above: Some remote machines just don't send -            # a TCP FIN, which would leave us in the unfortunate situation that -            # recv() would block infinitely. As a workaround, we set a timeout -            # here even if we are in blocking mode. -            sock.settimeout(sock.gettimeout() or 20) - -            # limit at a megabyte so that we don't read infinitely -            for _ in range(1024 ** 3 // 4096): -                # may raise a timeout/disconnect exception. -                if not sock.recv(4096): -                    break - -        # Now we can close the other half as well. -        sock.shutdown(socket.SHUT_RD) - -    except socket.error: -        pass - -    sock.close() - - -class _Connection: - -    rbufsize = -1 -    wbufsize = -1 - -    def _makefile(self): -        """ -        Set up .rfile and .wfile attributes from .connection -        """ -        # Ideally, we would use the Buffered IO in Python 3 by default. -        # Unfortunately, the implementation of .peek() is broken for n>1 bytes, -        # as it may just return what's left in the buffer and not all the bytes we want. -        # As a workaround, we just use unbuffered sockets directly. -        # https://mail.python.org/pipermail/python-dev/2009-June/089986.html -        self.rfile = Reader(socket.SocketIO(self.connection, "rb")) -        self.wfile = Writer(socket.SocketIO(self.connection, "wb")) - -    def __init__(self, connection): -        if connection: -            self.connection = connection -            self.ip_address = Address(connection.getpeername()) -            self._makefile() -        else: -            self.connection = None -            self.ip_address = None -            self.rfile = None -            self.wfile = None - -        self.ssl_established = False -        self.finished = False - -    def get_current_cipher(self): -        if not self.ssl_established: -            return None - -        name = self.connection.get_cipher_name() -        bits = self.connection.get_cipher_bits() -        version = self.connection.get_cipher_version() -        return name, bits, version - -    def finish(self): -        self.finished = True -        # If we have an SSL connection, wfile.close == connection.close -        # (We call _FileLike.set_descriptor(conn)) -        # Closing the socket is not our task, therefore we don't call close -        # then. -        if not isinstance(self.connection, SSL.Connection): -            if not getattr(self.wfile, "closed", False): -                try: -                    self.wfile.flush() -                    self.wfile.close() -                except exceptions.TcpDisconnect: -                    pass - -            self.rfile.close() -        else: -            try: -                self.connection.shutdown() -            except SSL.Error: -                pass - -    def _create_ssl_context(self, -                            method=SSL_DEFAULT_METHOD, -                            options=SSL_DEFAULT_OPTIONS, -                            verify_options=SSL.VERIFY_NONE, -                            ca_path=None, -                            ca_pemfile=None, -                            cipher_list=None, -                            alpn_protos=None, -                            alpn_select=None, -                            alpn_select_callback=None, -                            sni=None, -                            ): -        """ -        Creates an SSL Context. - -        :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD -        :param options: A bit field consisting of OpenSSL.SSL.OP_* values -        :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values -        :param ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool -        :param ca_pemfile: Path to a PEM formatted trusted CA certificate -        :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html -        :rtype : SSL.Context -        """ -        context = SSL.Context(method) -        # Options (NO_SSLv2/3) -        if options is not None: -            context.set_options(options) - -        # Verify Options (NONE/PEER and trusted CAs) -        if verify_options is not None: -            def verify_cert(conn, x509, errno, err_depth, is_cert_verified): -                if not is_cert_verified: -                    self.ssl_verification_error = exceptions.InvalidCertificateException( -                        "Certificate Verification Error for {}: {} (errno: {}, depth: {})".format( -                            sni, -                            strutils.native(SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)), "utf8"), -                            errno, -                            err_depth -                        ) -                    ) -                return is_cert_verified - -            context.set_verify(verify_options, verify_cert) -            if ca_path is None and ca_pemfile is None: -                ca_pemfile = certifi.where() -            context.load_verify_locations(ca_pemfile, ca_path) - -        # Workaround for -        # https://github.com/pyca/pyopenssl/issues/190 -        # https://github.com/mitmproxy/mitmproxy/issues/472 -        # Options already set before are not cleared. -        context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY) - -        # Cipher List -        if cipher_list: -            try: -                context.set_cipher_list(cipher_list) - -                # TODO: maybe change this to with newer pyOpenSSL APIs -                context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1')) -            except SSL.Error as v: -                raise exceptions.TlsException("SSL cipher specification error: %s" % str(v)) - -        # SSLKEYLOGFILE -        if log_ssl_key: -            context.set_info_callback(log_ssl_key) - -        if HAS_ALPN: -            if alpn_protos is not None: -                # advertise application layer protocols -                context.set_alpn_protos(alpn_protos) -            elif alpn_select is not None and alpn_select_callback is None: -                # select application layer protocol -                def alpn_select_callback(conn_, options): -                    if alpn_select in options: -                        return bytes(alpn_select) -                    else:  # pragma no cover -                        return options[0] -                context.set_alpn_select_callback(alpn_select_callback) -            elif alpn_select_callback is not None and alpn_select is None: -                context.set_alpn_select_callback(alpn_select_callback) -            elif alpn_select_callback is not None and alpn_select is not None: -                raise exceptions.TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).") - -        return context - - -class ConnectionCloser: -    def __init__(self, conn): -        self.conn = conn -        self._canceled = False - -    def pop(self): -        """ -            Cancel the current closer, and return a fresh one. -        """ -        self._canceled = True -        return ConnectionCloser(self.conn) - -    def __enter__(self): -        return self - -    def __exit__(self, *args): -        if not self._canceled: -            self.conn.close() - - -class TCPClient(_Connection): - -    def __init__(self, address, source_address=None, spoof_source_address=None): -        super().__init__(None) -        self.address = address -        self.source_address = source_address -        self.cert = None -        self.server_certs = [] -        self.ssl_verification_error = None  # type: Optional[exceptions.InvalidCertificateException] -        self.sni = None -        self.spoof_source_address = spoof_source_address - -    @property -    def address(self): -        return self.__address - -    @address.setter -    def address(self, address): -        if address: -            self.__address = Address.wrap(address) -        else: -            self.__address = None - -    @property -    def source_address(self): -        return self.__source_address - -    @source_address.setter -    def source_address(self, source_address): -        if source_address: -            self.__source_address = Address.wrap(source_address) -        else: -            self.__source_address = None - -    def close(self): -        # Make sure to close the real socket, not the SSL proxy. -        # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, -        # it tries to renegotiate... -        if isinstance(self.connection, SSL.Connection): -            close_socket(self.connection._socket) -        else: -            close_socket(self.connection) - -    def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs): -        context = self._create_ssl_context( -            alpn_protos=alpn_protos, -            **sslctx_kwargs) -        # Client Certs -        if cert: -            try: -                context.use_privatekey_file(cert) -                context.use_certificate_file(cert) -            except SSL.Error as v: -                raise exceptions.TlsException("SSL client certificate error: %s" % str(v)) -        return context - -    def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs): -        """ -            cert: Path to a file containing both client cert and private key. - -            options: A bit field consisting of OpenSSL.SSL.OP_* values -            verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values -            ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool -            ca_pemfile: Path to a PEM formatted trusted CA certificate -        """ -        verification_mode = sslctx_kwargs.get('verify_options', None) -        if verification_mode == SSL.VERIFY_PEER and not sni: -            raise exceptions.TlsException("Cannot validate certificate hostname without SNI") - -        context = self.create_ssl_context( -            alpn_protos=alpn_protos, -            sni=sni, -            **sslctx_kwargs -        ) -        self.connection = SSL.Connection(context, self.connection) -        if sni: -            self.sni = sni -            self.connection.set_tlsext_host_name(sni.encode("idna")) -        self.connection.set_connect_state() -        try: -            self.connection.do_handshake() -        except SSL.Error as v: -            if self.ssl_verification_error: -                raise self.ssl_verification_error -            else: -                raise exceptions.TlsException("SSL handshake error: %s" % repr(v)) -        else: -            # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on -            # certificate validation failure -            if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error: -                raise self.ssl_verification_error - -        self.cert = certs.SSLCert(self.connection.get_peer_certificate()) - -        # Keep all server certificates in a list -        for i in self.connection.get_peer_cert_chain(): -            self.server_certs.append(certs.SSLCert(i)) - -        # Validate TLS Hostname -        try: -            crt = dict( -                subjectAltName=[("DNS", x.decode("ascii", "strict")) for x in self.cert.altnames] -            ) -            if self.cert.cn: -                crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]] -            if sni: -                hostname = sni -            else: -                hostname = "no-hostname" -            ssl_match_hostname.match_hostname(crt, hostname) -        except (ValueError, ssl_match_hostname.CertificateError) as e: -            self.ssl_verification_error = exceptions.InvalidCertificateException( -                "Certificate Verification Error for {}: {}".format( -                    sni or repr(self.address), -                    str(e) -                ) -            ) -            if verification_mode == SSL.VERIFY_PEER: -                raise self.ssl_verification_error - -        self.ssl_established = True -        self.rfile.set_descriptor(self.connection) -        self.wfile.set_descriptor(self.connection) - -    def makesocket(self): -        # some parties (cuckoo sandbox) need to hook this -        return socket.socket(self.address.family, socket.SOCK_STREAM) - -    def connect(self): -        try: -            connection = self.makesocket() - -            if self.spoof_source_address: -                try: -                    # 19 is `IP_TRANSPARENT`, which is only available on Python 3.3+ on some OSes -                    if not connection.getsockopt(socket.SOL_IP, 19): -                        connection.setsockopt(socket.SOL_IP, 19, 1) -                except socket.error as e: -                    raise exceptions.TcpException( -                        "Failed to spoof the source address: " + e.strerror -                    ) -            if self.source_address: -                connection.bind(self.source_address()) -            connection.connect(self.address()) -            self.source_address = Address(connection.getsockname()) -        except (socket.error, IOError) as err: -            raise exceptions.TcpException( -                'Error connecting to "%s": %s' % -                (self.address.host, err) -            ) -        self.connection = connection -        self.ip_address = Address(connection.getpeername()) -        self._makefile() -        return ConnectionCloser(self) - -    def settimeout(self, n): -        self.connection.settimeout(n) - -    def gettimeout(self): -        return self.connection.gettimeout() - -    def get_alpn_proto_negotiated(self): -        if HAS_ALPN and self.ssl_established: -            return self.connection.get_alpn_proto_negotiated() -        else: -            return b"" - - -class BaseHandler(_Connection): - -    """ -        The instantiator is expected to call the handle() and finish() methods. -    """ - -    def __init__(self, connection, address, server): -        super().__init__(connection) -        self.address = Address.wrap(address) -        self.server = server -        self.clientcert = None - -    def create_ssl_context(self, -                           cert, key, -                           handle_sni=None, -                           request_client_cert=None, -                           chain_file=None, -                           dhparams=None, -                           extra_chain_certs=None, -                           **sslctx_kwargs): -        """ -            cert: A certs.SSLCert object or the path to a certificate -            chain file. - -            handle_sni: SNI handler, should take a connection object. Server -            name can be retrieved like this: - -                    connection.get_servername() - -            And you can specify the connection keys as follows: - -                    new_context = Context(TLSv1_METHOD) -                    new_context.use_privatekey(key) -                    new_context.use_certificate(cert) -                    connection.set_context(new_context) - -            The request_client_cert argument requires some explanation. We're -            supposed to be able to do this with no negative effects - if the -            client has no cert to present, we're notified and proceed as usual. -            Unfortunately, Android seems to have a bug (tested on 4.2.2) - when -            an Android client is asked to present a certificate it does not -            have, it hangs up, which is frankly bogus. Some time down the track -            we may be able to make the proper behaviour the default again, but -            until then we're conservative. -        """ - -        context = self._create_ssl_context(ca_pemfile=chain_file, **sslctx_kwargs) - -        context.use_privatekey(key) -        if isinstance(cert, certs.SSLCert): -            context.use_certificate(cert.x509) -        else: -            context.use_certificate_chain_file(cert) - -        if extra_chain_certs: -            for i in extra_chain_certs: -                context.add_extra_chain_cert(i.x509) - -        if handle_sni: -            # SNI callback happens during do_handshake() -            context.set_tlsext_servername_callback(handle_sni) - -        if request_client_cert: -            def save_cert(conn_, cert, errno_, depth_, preverify_ok_): -                self.clientcert = certs.SSLCert(cert) -                # Return true to prevent cert verification error -                return True -            context.set_verify(SSL.VERIFY_PEER, save_cert) - -        if dhparams: -            SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams) - -        return context - -    def convert_to_ssl(self, cert, key, **sslctx_kwargs): -        """ -        Convert connection to SSL. -        For a list of parameters, see BaseHandler._create_ssl_context(...) -        """ - -        context = self.create_ssl_context( -            cert, -            key, -            **sslctx_kwargs) -        self.connection = SSL.Connection(context, self.connection) -        self.connection.set_accept_state() -        try: -            self.connection.do_handshake() -        except SSL.Error as v: -            raise exceptions.TlsException("SSL handshake error: %s" % repr(v)) -        self.ssl_established = True -        self.rfile.set_descriptor(self.connection) -        self.wfile.set_descriptor(self.connection) - -    def handle(self):  # pragma: no cover -        raise NotImplementedError - -    def settimeout(self, n): -        self.connection.settimeout(n) - -    def get_alpn_proto_negotiated(self): -        if HAS_ALPN and self.ssl_established: -            return self.connection.get_alpn_proto_negotiated() -        else: -            return b"" - - -class Counter: -    def __init__(self): -        self._count = 0 -        self._lock = threading.Lock() - -    @property -    def count(self): -        with self._lock: -            return self._count - -    def __enter__(self): -        with self._lock: -            self._count += 1 - -    def __exit__(self, *args): -        with self._lock: -            self._count -= 1 - - -class TCPServer: -    request_queue_size = 20 - -    def __init__(self, address): -        self.address = Address.wrap(address) -        self.__is_shut_down = threading.Event() -        self.__shutdown_request = False -        self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) -        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) -        self.socket.bind(self.address()) -        self.address = Address.wrap(self.socket.getsockname()) -        self.socket.listen(self.request_queue_size) -        self.handler_counter = Counter() - -    def connection_thread(self, connection, client_address): -        with self.handler_counter: -            client_address = Address(client_address) -            try: -                self.handle_client_connection(connection, client_address) -            except: -                self.handle_error(connection, client_address) -            finally: -                close_socket(connection) - -    def serve_forever(self, poll_interval=0.1): -        self.__is_shut_down.clear() -        try: -            while not self.__shutdown_request: -                try: -                    r, w_, e_ = select.select( -                        [self.socket], [], [], poll_interval) -                except select.error as ex:  # pragma: no cover -                    if ex[0] == EINTR: -                        continue -                    else: -                        raise -                if self.socket in r: -                    connection, client_address = self.socket.accept() -                    t = basethread.BaseThread( -                        "TCPConnectionHandler (%s: %s:%s -> %s:%s)" % ( -                            self.__class__.__name__, -                            client_address[0], -                            client_address[1], -                            self.address.host, -                            self.address.port -                        ), -                        target=self.connection_thread, -                        args=(connection, client_address), -                    ) -                    t.setDaemon(1) -                    try: -                        t.start() -                    except threading.ThreadError: -                        self.handle_error(connection, Address(client_address)) -                        connection.close() -        finally: -            self.__shutdown_request = False -            self.__is_shut_down.set() - -    def shutdown(self): -        self.__shutdown_request = True -        self.__is_shut_down.wait() -        self.socket.close() -        self.handle_shutdown() - -    def handle_error(self, connection_, client_address, fp=sys.stderr): -        """ -            Called when handle_client_connection raises an exception. -        """ -        # If a thread has persisted after interpreter exit, the module might be -        # none. -        if traceback: -            exc = str(traceback.format_exc()) -            print(u'-' * 40, file=fp) -            print( -                u"Error in processing of request from %s" % repr(client_address), file=fp) -            print(exc, file=fp) -            print(u'-' * 40, file=fp) - -    def handle_client_connection(self, conn, client_address):  # pragma: no cover -        """ -            Called after client connection. -        """ -        raise NotImplementedError - -    def handle_shutdown(self): -        """ -            Called after server shutdown. -        """ - -    def wait_for_silence(self, timeout=5): -        start = time.time() -        while 1: -            if time.time() - start >= timeout: -                raise exceptions.Timeout( -                    "%s service threads still alive" % -                    self.handler_counter.count -                ) -            if self.handler_counter.count == 0: -                return diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py deleted file mode 100644 index 2d6f0a0c..00000000 --- a/netlib/websockets/__init__.py +++ /dev/null @@ -1,35 +0,0 @@ -from .frame import FrameHeader -from .frame import Frame -from .frame import OPCODE -from .frame import CLOSE_REASON -from .masker import Masker -from .utils import MAGIC -from .utils import VERSION -from .utils import client_handshake_headers -from .utils import server_handshake_headers -from .utils import check_handshake -from .utils import check_client_version -from .utils import create_server_nonce -from .utils import get_extensions -from .utils import get_protocol -from .utils import get_client_key -from .utils import get_server_accept - -__all__ = [ -    "FrameHeader", -    "Frame", -    "OPCODE", -    "CLOSE_REASON", -    "Masker", -    "MAGIC", -    "VERSION", -    "client_handshake_headers", -    "server_handshake_headers", -    "check_handshake", -    "check_client_version", -    "create_server_nonce", -    "get_extensions", -    "get_protocol", -    "get_client_key", -    "get_server_accept", -] diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py deleted file mode 100644 index bc4ae43a..00000000 --- a/netlib/websockets/frame.py +++ /dev/null @@ -1,274 +0,0 @@ -import os -import struct -import io - -from netlib import tcp -from mitmproxy.utils import strutils -from mitmproxy.utils import bits -from mitmproxy.utils import human -from mitmproxy.types import bidi -from .masker import Masker - - -MAX_16_BIT_INT = (1 << 16) -MAX_64_BIT_INT = (1 << 64) - -DEFAULT = object() - -# RFC 6455, Section 5.2 - Base Framing Protocol -OPCODE = bidi.BiDi( -    CONTINUE=0x00, -    TEXT=0x01, -    BINARY=0x02, -    CLOSE=0x08, -    PING=0x09, -    PONG=0x0a -) - -# RFC 6455, Section 7.4.1 - Defined Status Codes -CLOSE_REASON = bidi.BiDi( -    NORMAL_CLOSURE=1000, -    GOING_AWAY=1001, -    PROTOCOL_ERROR=1002, -    UNSUPPORTED_DATA=1003, -    RESERVED=1004, -    RESERVED_NO_STATUS=1005, -    RESERVED_ABNORMAL_CLOSURE=1006, -    INVALID_PAYLOAD_DATA=1007, -    POLICY_VIOLATION=1008, -    MESSAGE_TOO_BIG=1009, -    MANDATORY_EXTENSION=1010, -    INTERNAL_ERROR=1011, -    RESERVED_TLS_HANDHSAKE_FAILED=1015, -) - - -class FrameHeader: - -    def __init__( -        self, -        opcode=OPCODE.TEXT, -        payload_length=0, -        fin=False, -        rsv1=False, -        rsv2=False, -        rsv3=False, -        masking_key=DEFAULT, -        mask=DEFAULT, -        length_code=DEFAULT -    ): -        if not 0 <= opcode < 2 ** 4: -            raise ValueError("opcode must be 0-16") -        self.opcode = opcode -        self.payload_length = payload_length -        self.fin = fin -        self.rsv1 = rsv1 -        self.rsv2 = rsv2 -        self.rsv3 = rsv3 - -        if length_code is DEFAULT: -            self.length_code = self._make_length_code(self.payload_length) -        else: -            self.length_code = length_code - -        if mask is DEFAULT and masking_key is DEFAULT: -            self.mask = False -            self.masking_key = b"" -        elif mask is DEFAULT: -            self.mask = 1 -            self.masking_key = masking_key -        elif masking_key is DEFAULT: -            self.mask = mask -            self.masking_key = os.urandom(4) -        else: -            self.mask = mask -            self.masking_key = masking_key - -        if self.masking_key and len(self.masking_key) != 4: -            raise ValueError("Masking key must be 4 bytes.") - -    @classmethod -    def _make_length_code(self, length): -        """ -         A websockets frame contains an initial length_code, and an optional -         extended length code to represent the actual length if length code is -         larger than 125 -        """ -        if length <= 125: -            return length -        elif length >= 126 and length <= 65535: -            return 126 -        else: -            return 127 - -    def __repr__(self): -        vals = [ -            "ws frame:", -            OPCODE.get_name(self.opcode, hex(self.opcode)).lower() -        ] -        flags = [] -        for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]: -            if getattr(self, i): -                flags.append(i) -        if flags: -            vals.extend([":", "|".join(flags)]) -        if self.masking_key: -            vals.append(":key=%s" % repr(self.masking_key)) -        if self.payload_length: -            vals.append(" %s" % human.pretty_size(self.payload_length)) -        return "".join(vals) - -    def __bytes__(self): -        first_byte = bits.setbit(0, 7, self.fin) -        first_byte = bits.setbit(first_byte, 6, self.rsv1) -        first_byte = bits.setbit(first_byte, 5, self.rsv2) -        first_byte = bits.setbit(first_byte, 4, self.rsv3) -        first_byte = first_byte | self.opcode - -        second_byte = bits.setbit(self.length_code, 7, self.mask) - -        b = bytes([first_byte, second_byte]) - -        if self.payload_length < 126: -            pass -        elif self.payload_length < MAX_16_BIT_INT: -            # '!H' pack as 16 bit unsigned short -            # add 2 byte extended payload length -            b += struct.pack('!H', self.payload_length) -        elif self.payload_length < MAX_64_BIT_INT: -            # '!Q' = pack as 64 bit unsigned long long -            # add 8 bytes extended payload length -            b += struct.pack('!Q', self.payload_length) -        else: -            raise ValueError("Payload length exceeds 64bit integer") - -        if self.masking_key: -            b += self.masking_key -        return b - -    @classmethod -    def from_file(cls, fp): -        """ -          read a websockets frame header -        """ -        first_byte, second_byte = fp.safe_read(2) -        fin = bits.getbit(first_byte, 7) -        rsv1 = bits.getbit(first_byte, 6) -        rsv2 = bits.getbit(first_byte, 5) -        rsv3 = bits.getbit(first_byte, 4) -        opcode = first_byte & 0xF -        mask_bit = bits.getbit(second_byte, 7) -        length_code = second_byte & 0x7F - -        # payload_length > 125 indicates you need to read more bytes -        # to get the actual payload length -        if length_code <= 125: -            payload_length = length_code -        elif length_code == 126: -            payload_length, = struct.unpack("!H", fp.safe_read(2)) -        else:  # length_code == 127: -            payload_length, = struct.unpack("!Q", fp.safe_read(8)) - -        # masking key only present if mask bit set -        if mask_bit == 1: -            masking_key = fp.safe_read(4) -        else: -            masking_key = None - -        return cls( -            fin=fin, -            rsv1=rsv1, -            rsv2=rsv2, -            rsv3=rsv3, -            opcode=opcode, -            mask=mask_bit, -            length_code=length_code, -            payload_length=payload_length, -            masking_key=masking_key, -        ) - -    def __eq__(self, other): -        if isinstance(other, FrameHeader): -            return bytes(self) == bytes(other) -        return False - - -class Frame: -    """ -    Represents a single WebSockets frame. -    Constructor takes human readable forms of the frame components. -    from_bytes() reads from a file-like object to create a new Frame. - -    WebSockets Frame as defined in RFC6455 - -       0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -      +-+-+-+-+-------+-+-------------+-------------------------------+ -      |F|R|R|R| opcode|M| Payload len |    Extended payload length    | -      |I|S|S|S|  (4)  |A|     (7)     |             (16/64)           | -      |N|V|V|V|       |S|             |   (if payload len==126/127)   | -      | |1|2|3|       |K|             |                               | -      +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + -      |     Extended payload length continued, if payload len == 127  | -      + - - - - - - - - - - - - - - - +-------------------------------+ -      |                               |Masking-key, if MASK set to 1  | -      +-------------------------------+-------------------------------+ -      | Masking-key (continued)       |          Payload Data         | -      +-------------------------------- - - - - - - - - - - - - - - - + -      :                     Payload Data continued ...                : -      + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + -      |                     Payload Data continued ...                | -      +---------------------------------------------------------------+ -    """ - -    def __init__(self, payload=b"", **kwargs): -        self.payload = payload -        kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) -        self.header = FrameHeader(**kwargs) - -    @classmethod -    def from_bytes(cls, bytestring): -        """ -          Construct a websocket frame from an in-memory bytestring -          to construct a frame from a stream of bytes, use from_file() directly -        """ -        return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) - -    def __repr__(self): -        ret = repr(self.header) -        if self.payload: -            ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload) -        return ret - -    def __bytes__(self): -        """ -            Serialize the frame to wire format. Returns a string. -        """ -        b = bytes(self.header) -        if self.header.masking_key: -            b += Masker(self.header.masking_key)(self.payload) -        else: -            b += self.payload -        return b - -    @classmethod -    def from_file(cls, fp): -        """ -          read a websockets frame sent by a server or client - -          fp is a "file like" object that could be backed by a network -          stream or a disk or an in memory stream reader -        """ -        header = FrameHeader.from_file(fp) -        payload = fp.safe_read(header.payload_length) - -        if header.mask == 1 and header.masking_key: -            payload = Masker(header.masking_key)(payload) - -        frame = cls(payload) -        frame.header = header -        return frame - -    def __eq__(self, other): -        if isinstance(other, Frame): -            return bytes(self) == bytes(other) -        return False diff --git a/netlib/websockets/masker.py b/netlib/websockets/masker.py deleted file mode 100644 index 47b1a688..00000000 --- a/netlib/websockets/masker.py +++ /dev/null @@ -1,25 +0,0 @@ -class Masker: -    """ -    Data sent from the server must be masked to prevent malicious clients -    from sending data over the wire in predictable patterns. - -    Servers do not have to mask data they send to the client. -    https://tools.ietf.org/html/rfc6455#section-5.3 -    """ - -    def __init__(self, key): -        self.key = key -        self.offset = 0 - -    def mask(self, offset, data): -        result = bytearray(data) -        for i in range(len(data)): -            result[i] ^= self.key[offset % 4] -            offset += 1 -        result = bytes(result) -        return result - -    def __call__(self, data): -        ret = self.mask(self.offset, data) -        self.offset += len(ret) -        return ret diff --git a/netlib/websockets/utils.py b/netlib/websockets/utils.py deleted file mode 100644 index 98043662..00000000 --- a/netlib/websockets/utils.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -Collection of WebSockets Protocol utility functions (RFC6455) -Spec: https://tools.ietf.org/html/rfc6455 -""" - - -import base64 -import hashlib -import os - -from netlib import http -from mitmproxy.utils import strutils - -MAGIC = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11' -VERSION = "13" - - -def client_handshake_headers(version=None, key=None, protocol=None, extensions=None): -    """ -        Create the headers for a valid HTTP upgrade request. If Key is not -        specified, it is generated, and can be found in sec-websocket-key in -        the returned header set. - -        Returns an instance of http.Headers -    """ -    if version is None: -        version = VERSION -    if key is None: -        key = base64.b64encode(os.urandom(16)).decode('ascii') -    h = http.Headers( -        connection="upgrade", -        upgrade="websocket", -        sec_websocket_version=version, -        sec_websocket_key=key, -    ) -    if protocol is not None: -        h['sec-websocket-protocol'] = protocol -    if extensions is not None: -        h['sec-websocket-extensions'] = extensions -    return h - - -def server_handshake_headers(client_key, protocol=None, extensions=None): -    """ -      The server response is a valid HTTP 101 response. - -      Returns an instance of http.Headers -    """ -    h = http.Headers( -        connection="upgrade", -        upgrade="websocket", -        sec_websocket_accept=create_server_nonce(client_key), -    ) -    if protocol is not None: -        h['sec-websocket-protocol'] = protocol -    if extensions is not None: -        h['sec-websocket-extensions'] = extensions -    return h - - -def check_handshake(headers): -    return ( -        "upgrade" in headers.get("connection", "").lower() and -        headers.get("upgrade", "").lower() == "websocket" and -        (headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None) -    ) - - -def create_server_nonce(client_nonce): -    return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + MAGIC).digest()) - - -def check_client_version(headers): -    return headers.get("sec-websocket-version", "") == VERSION - - -def get_extensions(headers): -    return headers.get("sec-websocket-extensions", None) - - -def get_protocol(headers): -    return headers.get("sec-websocket-protocol", None) - - -def get_client_key(headers): -    return headers.get("sec-websocket-key", None) - - -def get_server_accept(headers): -    return headers.get("sec-websocket-accept", None) diff --git a/netlib/wsgi.py b/netlib/wsgi.py deleted file mode 100644 index 5a54cd70..00000000 --- a/netlib/wsgi.py +++ /dev/null @@ -1,166 +0,0 @@ -import time -import traceback -import urllib -import io - -from netlib import http -from netlib import tcp -from mitmproxy.utils import strutils - - -class ClientConn: - -    def __init__(self, address): -        self.address = tcp.Address.wrap(address) - - -class Flow: - -    def __init__(self, address, request): -        self.client_conn = ClientConn(address) -        self.request = request - - -class Request: - -    def __init__(self, scheme, method, path, http_version, headers, content): -        self.scheme, self.method, self.path = scheme, method, path -        self.headers, self.content = headers, content -        self.http_version = http_version - - -def date_time_string(): -    """Return the current date and time formatted for a message header.""" -    WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] -    MONTHS = [ -        None, -        'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', -        'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec' -    ] -    now = time.time() -    year, month, day, hh, mm, ss, wd, y_, z_ = time.gmtime(now) -    s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( -        WEEKS[wd], -        day, MONTHS[month], year, -        hh, mm, ss -    ) -    return s - - -class WSGIAdaptor: - -    def __init__(self, app, domain, port, sversion): -        self.app, self.domain, self.port, self.sversion = app, domain, port, sversion - -    def make_environ(self, flow, errsoc, **extra): -        """ -        Raises: -            ValueError, if the content-encoding is invalid. -        """ -        path = strutils.native(flow.request.path, "latin-1") -        if '?' in path: -            path_info, query = strutils.native(path, "latin-1").split('?', 1) -        else: -            path_info = path -            query = '' -        environ = { -            'wsgi.version': (1, 0), -            'wsgi.url_scheme': strutils.native(flow.request.scheme, "latin-1"), -            'wsgi.input': io.BytesIO(flow.request.content or b""), -            'wsgi.errors': errsoc, -            'wsgi.multithread': True, -            'wsgi.multiprocess': False, -            'wsgi.run_once': False, -            'SERVER_SOFTWARE': self.sversion, -            'REQUEST_METHOD': strutils.native(flow.request.method, "latin-1"), -            'SCRIPT_NAME': '', -            'PATH_INFO': urllib.parse.unquote(path_info), -            'QUERY_STRING': query, -            'CONTENT_TYPE': strutils.native(flow.request.headers.get('Content-Type', ''), "latin-1"), -            'CONTENT_LENGTH': strutils.native(flow.request.headers.get('Content-Length', ''), "latin-1"), -            'SERVER_NAME': self.domain, -            'SERVER_PORT': str(self.port), -            'SERVER_PROTOCOL': strutils.native(flow.request.http_version, "latin-1"), -        } -        environ.update(extra) -        if flow.client_conn.address: -            environ["REMOTE_ADDR"] = strutils.native(flow.client_conn.address.host, "latin-1") -            environ["REMOTE_PORT"] = flow.client_conn.address.port - -        for key, value in flow.request.headers.items(): -            key = 'HTTP_' + strutils.native(key, "latin-1").upper().replace('-', '_') -            if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): -                environ[key] = value -        return environ - -    def error_page(self, soc, headers_sent, s): -        """ -            Make a best-effort attempt to write an error page. If headers are -            already sent, we just bung the error into the page. -        """ -        c = """ -            <html> -                <h1>Internal Server Error</h1> -                <pre>{err}"</pre> -            </html> -        """.format(err=s).strip().encode() - -        if not headers_sent: -            soc.write(b"HTTP/1.1 500 Internal Server Error\r\n") -            soc.write(b"Content-Type: text/html\r\n") -            soc.write("Content-Length: {length}\r\n".format(length=len(c)).encode()) -            soc.write(b"\r\n") -        soc.write(c) - -    def serve(self, request, soc, **env): -        state = dict( -            response_started=False, -            headers_sent=False, -            status=None, -            headers=None -        ) - -        def write(data): -            if not state["headers_sent"]: -                soc.write("HTTP/1.1 {status}\r\n".format(status=state["status"]).encode()) -                headers = state["headers"] -                if 'server' not in headers: -                    headers["Server"] = self.sversion -                if 'date' not in headers: -                    headers["Date"] = date_time_string() -                soc.write(bytes(headers)) -                soc.write(b"\r\n") -                state["headers_sent"] = True -            if data: -                soc.write(data) -            soc.flush() - -        def start_response(status, headers, exc_info=None): -            if exc_info: -                if state["headers_sent"]: -                    raise exc_info[1] -            elif state["status"]: -                raise AssertionError('Response already started') -            state["status"] = status -            state["headers"] = http.Headers([[strutils.always_bytes(k), strutils.always_bytes(v)] for k, v in headers]) -            if exc_info: -                self.error_page(soc, state["headers_sent"], traceback.format_tb(exc_info[2])) -                state["headers_sent"] = True - -        errs = io.BytesIO() -        try: -            dataiter = self.app( -                self.make_environ(request, errs, **env), start_response -            ) -            for i in dataiter: -                write(i) -            if not state["headers_sent"]: -                write(b"") -        except Exception: -            try: -                s = traceback.format_exc() -                errs.write(s.encode("utf-8", "replace")) -                self.error_page(soc, state["headers_sent"], s) -            except Exception:    # pragma: no cover -                pass -        return errs.getvalue() | 
