diff options
56 files changed, 1674 insertions, 680 deletions
diff --git a/docs/certinstall.rst b/docs/certinstall.rst index 5d97e92c..1bd6df99 100644 --- a/docs/certinstall.rst +++ b/docs/certinstall.rst @@ -42,7 +42,7 @@ iOS  See http://jasdev.me/intercepting-ios-traffic -and http://web.archive.org/web/20150920082614/http://kb.mit.edu/confluence/pages/viewpage.action?pageId=152600377 +and https://web.archive.org/web/20150920082614/http://kb.mit.edu/confluence/pages/viewpage.action?pageId=152600377  iOS Simulator  ^^^^^^^^^^^^^ @@ -52,7 +52,7 @@ See https://github.com/ADVTOOLS/ADVTrustStore#how-to-use-advtruststore  Java  ^^^^ -See http://docs.oracle.com/cd/E19906-01/820-4916/geygn/index.html +See https://docs.oracle.com/cd/E19906-01/820-4916/geygn/index.html  Android/Android Simulator  ^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -62,7 +62,7 @@ See http://wiki.cacert.org/FAQ/ImportRootCert#Android_Phones_.26_Tablets  Windows  ^^^^^^^ -See http://windows.microsoft.com/en-ca/windows/import-export-certificates-private-keys#1TC=windows-7 +See https://web.archive.org/web/20160612045445/http://windows.microsoft.com/en-ca/windows/import-export-certificates-private-keys#1TC=windows-7  Windows (automated)  ^^^^^^^^^^^^^^^^^^^ @@ -79,7 +79,7 @@ See https://support.apple.com/kb/PH7297?locale=en_US  Ubuntu/Debian  ^^^^^^^^^^^^^ -See http://askubuntu.com/questions/73287/how-do-i-install-a-root-certificate/94861#94861 +See https://askubuntu.com/questions/73287/how-do-i-install-a-root-certificate/94861#94861  Mozilla Firefox  ^^^^^^^^^^^^^^^ @@ -89,7 +89,7 @@ See https://wiki.mozilla.org/MozillaRootCertificate#Mozilla_Firefox  Chrome on Linux  ^^^^^^^^^^^^^^^ -See https://code.google.com/p/chromium/wiki/LinuxCertManagement +See https://stackoverflow.com/a/15076602/198996  The mitmproxy certificate authority @@ -205,4 +205,4 @@ directory and uses this as the client cert. -.. _Certificate Pinning: http://security.stackexchange.com/questions/29988/what-is-certificate-pinning/ +.. _Certificate Pinning: https://security.stackexchange.com/questions/29988/what-is-certificate-pinning/ diff --git a/examples/complex/har_dump.py b/examples/complex/har_dump.py index f7c1e658..51983b54 100644 --- a/examples/complex/har_dump.py +++ b/examples/complex/har_dump.py @@ -10,7 +10,7 @@ import zlib  import os  from datetime import datetime -import pytz +from datetime import timezone  import mitmproxy @@ -89,7 +89,7 @@ def response(flow):      # Timings set to -1 will be ignored as per spec.      full_time = sum(v for v in timings.values() if v > -1) -    started_date_time = format_datetime(datetime.utcfromtimestamp(flow.request.timestamp_start)) +    started_date_time = datetime.fromtimestamp(flow.request.timestamp_start, timezone.utc).isoformat()      # Response body size and encoding      response_body_size = len(flow.response.raw_content) @@ -173,10 +173,6 @@ def done():          mitmproxy.ctx.log("HAR dump finished (wrote %s bytes to file)" % len(json_dump)) -def format_datetime(dt): -    return dt.replace(tzinfo=pytz.timezone("UTC")).isoformat() - -  def format_cookies(cookie_list):      rv = [] @@ -198,7 +194,7 @@ def format_cookies(cookie_list):          # Expiration time needs to be formatted          expire_ts = cookies.get_expiration_ts(attrs)          if expire_ts is not None: -            cookie_har["expires"] = format_datetime(datetime.fromtimestamp(expire_ts)) +            cookie_har["expires"] = datetime.fromtimestamp(expire_ts, timezone.utc).isoformat()          rv.append(cookie_har) diff --git a/examples/complex/xss_scanner.py b/examples/complex/xss_scanner.py new file mode 100755 index 00000000..a0572d5d --- /dev/null +++ b/examples/complex/xss_scanner.py @@ -0,0 +1,407 @@ +""" + + __   __ _____ _____     _____ + \ \ / // ____/ ____|   / ____| +  \ V /| (___| (___    | (___   ___ __ _ _ __  _ __   ___ _ __ +   > <  \___ \\___ \    \___ \ / __/ _` | '_ \| '_ \ / _ \ '__| +  / . \ ____) |___) |   ____) | (_| (_| | | | | | | |  __/ | + /_/ \_\_____/_____/   |_____/ \___\__,_|_| |_|_| |_|\___|_| + + +This script automatically scans all visited webpages for XSS and SQLi vulnerabilities. + +Usage: mitmproxy -s xss_scanner.py + +This script scans for vulnerabilities by injecting a fuzzing payload (see PAYLOAD below) into 4 different places +and examining the HTML to look for XSS and SQLi injection vulnerabilities. The XSS scanning functionality works by +looking to see whether it is possible to inject HTML based off of of where the payload appears in the page and what +characters are escaped. In addition, it also looks for any script tags that load javascript from unclaimed domains. +The SQLi scanning functionality works by using regular expressions to look for errors from a number of different +common databases. Since it is only looking for errors, it will not find blind SQLi vulnerabilities. + +The 4 places it injects the payload into are: +1. URLs         (e.g. https://example.com/ -> https://example.com/PAYLOAD/) +2. Queries      (e.g. https://example.com/index.html?a=b -> https://example.com/index.html?a=PAYLOAD) +3. Referers     (e.g. The referer changes from https://example.com to PAYLOAD) +4. User Agents  (e.g. The UA changes from Chrome to PAYLOAD) + +Reports from this script show up in the event log (viewable by pressing e) and formatted like: + +===== XSS Found ==== +XSS URL: http://daviddworken.com/vulnerableUA.php +Injection Point: User Agent +Suggested Exploit: <script>alert(0)</script> +Line: 1029zxcs'd"ao<ac>so[sb]po(pc)se;sl/bsl\eq=3847asd + +""" + +from mitmproxy import ctx +from socket import gaierror, gethostbyname +from urllib.parse import urlparse +import requests +import re +from html.parser import HTMLParser +from mitmproxy import http +from typing import Dict, Union, Tuple, Optional, List, NamedTuple + +# The actual payload is put between a frontWall and a backWall to make it easy +# to locate the payload with regular expressions +FRONT_WALL = b"1029zxc" +BACK_WALL = b"3847asd" +PAYLOAD = b"""s'd"ao<ac>so[sb]po(pc)se;sl/bsl\\eq=""" +FULL_PAYLOAD = FRONT_WALL + PAYLOAD + BACK_WALL + +# A XSSData is a named tuple with the following fields: +#   - url -> str +#   - injection_point -> str +#   - exploit -> str +#   - line -> str +XSSData = NamedTuple('XSSData', [('url', str), +                                 ('injection_point', str), +                                 ('exploit', str), +                                 ('line', str)]) + +# A SQLiData is named tuple with the following fields: +#   - url -> str +#   - injection_point -> str +#   - regex -> str +#   - dbms -> str +SQLiData = NamedTuple('SQLiData', [('url', str), +                                   ('injection_point', str), +                                   ('regex', str), +                                   ('dbms', str)]) + + +VulnData = Tuple[Optional[XSSData], Optional[SQLiData]] +Cookies = Dict[str, str] + + +def get_cookies(flow: http.HTTPFlow) -> Cookies: +    """ Return a dict going from cookie names to cookie values +          - Note that it includes both the cookies sent in the original request and +            the cookies sent by the server """ +    return {name: value for name, value in flow.request.cookies.fields} + + +def find_unclaimed_URLs(body: Union[str, bytes], requestUrl: bytes) -> None: +    """ Look for unclaimed URLs in script tags and log them if found""" +    class ScriptURLExtractor(HTMLParser): +        script_URLs = [] + +        def handle_starttag(self, tag, attrs): +            if tag == "script" and "src" in [name for name, value in attrs]: +                for name, value in attrs: +                    if name == "src": +                        self.script_URLs.append(value) + +    parser = ScriptURLExtractor() +    try: +        parser.feed(body) +    except TypeError: +        parser.feed(body.decode('utf-8')) +    for url in parser.script_URLs: +        parser = urlparse(url) +        domain = parser.netloc +        try: +            gethostbyname(domain) +        except gaierror: +            ctx.log.error("XSS found in %s due to unclaimed URL \"%s\" in script tag." % (requestUrl, url)) + + +def test_end_of_URL_injection(original_body: str, request_URL: str, cookies: Cookies) -> VulnData: +    """ Test the given URL for XSS via injection onto the end of the URL and +        log the XSS if found """ +    parsed_URL = urlparse(request_URL) +    path = parsed_URL.path +    if path != "" and path[-1] != "/":  # ensure the path ends in a / +        path += "/" +    path += FULL_PAYLOAD.decode('utf-8')  # the path must be a string while the payload is bytes +    url = parsed_URL._replace(path=path).geturl() +    body = requests.get(url, cookies=cookies).text.lower() +    xss_info = get_XSS_data(body, url, "End of URL") +    sqli_info = get_SQLi_data(body, original_body, url, "End of URL") +    return xss_info, sqli_info + + +def test_referer_injection(original_body: str, request_URL: str, cookies: Cookies) -> VulnData: +    """ Test the given URL for XSS via injection into the referer and +        log the XSS if found """ +    body = requests.get(request_URL, headers={'referer': FULL_PAYLOAD}, cookies=cookies).text.lower() +    xss_info = get_XSS_data(body, request_URL, "Referer") +    sqli_info = get_SQLi_data(body, original_body, request_URL, "Referer") +    return xss_info, sqli_info + + +def test_user_agent_injection(original_body: str, request_URL: str, cookies: Cookies) -> VulnData: +    """ Test the given URL for XSS via injection into the user agent and +        log the XSS if found """ +    body = requests.get(request_URL, headers={'User-Agent': FULL_PAYLOAD}, cookies=cookies).text.lower() +    xss_info = get_XSS_data(body, request_URL, "User Agent") +    sqli_info = get_SQLi_data(body, original_body, request_URL, "User Agent") +    return xss_info, sqli_info + + +def test_query_injection(original_body: str, request_URL: str, cookies: Cookies): +    """ Test the given URL for XSS via injection into URL queries and +        log the XSS if found """ +    parsed_URL = urlparse(request_URL) +    query_string = parsed_URL.query +    # queries is a list of parameters where each parameter is set to the payload +    queries = [query.split("=")[0] + "=" + FULL_PAYLOAD.decode('utf-8') for query in query_string.split("&")] +    new_query_string = "&".join(queries) +    new_URL = parsed_URL._replace(query=new_query_string).geturl() +    body = requests.get(new_URL, cookies=cookies).text.lower() +    xss_info = get_XSS_data(body, new_URL, "Query") +    sqli_info = get_SQLi_data(body, original_body, new_URL, "Query") +    return xss_info, sqli_info + + +def log_XSS_data(xss_info: Optional[XSSData]) -> None: +    """ Log information about the given XSS to mitmproxy """ +    # If it is None, then there is no info to log +    if not xss_info: +        return +    ctx.log.error("===== XSS Found ====") +    ctx.log.error("XSS URL: %s" % xss_info.url) +    ctx.log.error("Injection Point: %s" % xss_info.injection_point) +    ctx.log.error("Suggested Exploit: %s" % xss_info.exploit) +    ctx.log.error("Line: %s" % xss_info.line) + + +def log_SQLi_data(sqli_info: Optional[SQLiData]) -> None: +    """ Log information about the given SQLi to mitmproxy """ +    if not sqli_info: +        return +    ctx.log.error("===== SQLi Found =====") +    ctx.log.error("SQLi URL: %s" % sqli_info.url.decode('utf-8')) +    ctx.log.error("Injection Point: %s" % sqli_info.injection_point.decode('utf-8')) +    ctx.log.error("Regex used: %s" % sqli_info.regex.decode('utf-8')) +    ctx.log.error("Suspected DBMS: %s" % sqli_info.dbms.decode('utf-8')) + + +def get_SQLi_data(new_body: str, original_body: str, request_URL: str, injection_point: str) -> Optional[SQLiData]: +    """ Return a SQLiDict if there is a SQLi otherwise return None +        String String URL String -> (SQLiDict or None) """ +    # Regexes taken from Damn Small SQLi Scanner: https://github.com/stamparm/DSSS/blob/master/dsss.py#L17 +    DBMS_ERRORS = { +        "MySQL": (r"SQL syntax.*MySQL", r"Warning.*mysql_.*", r"valid MySQL result", r"MySqlClient\."), +        "PostgreSQL": (r"PostgreSQL.*ERROR", r"Warning.*\Wpg_.*", r"valid PostgreSQL result", r"Npgsql\."), +        "Microsoft SQL Server": (r"Driver.* SQL[\-\_\ ]*Server", r"OLE DB.* SQL Server", r"(\W|\A)SQL Server.*Driver", +                                 r"Warning.*mssql_.*", r"(\W|\A)SQL Server.*[0-9a-fA-F]{8}", +                                 r"(?s)Exception.*\WSystem\.Data\.SqlClient\.", r"(?s)Exception.*\WRoadhouse\.Cms\."), +        "Microsoft Access": (r"Microsoft Access Driver", r"JET Database Engine", r"Access Database Engine"), +        "Oracle": (r"\bORA-[0-9][0-9][0-9][0-9]", r"Oracle error", r"Oracle.*Driver", r"Warning.*\Woci_.*", r"Warning.*\Wora_.*"), +        "IBM DB2": (r"CLI Driver.*DB2", r"DB2 SQL error", r"\bdb2_\w+\("), +        "SQLite": (r"SQLite/JDBCDriver", r"SQLite.Exception", r"System.Data.SQLite.SQLiteException", r"Warning.*sqlite_.*", +                   r"Warning.*SQLite3::", r"\[SQLITE_ERROR\]"), +        "Sybase": (r"(?i)Warning.*sybase.*", r"Sybase message", r"Sybase.*Server message.*"), +    } +    for dbms, regexes in DBMS_ERRORS.items(): +        for regex in regexes: +            if re.search(regex, new_body) and not re.search(regex, original_body): +                return SQLiData(request_URL, +                                injection_point, +                                regex, +                                dbms) + + +# A qc is either ' or " +def inside_quote(qc: str, substring: bytes, text_index: int, body: bytes) -> bool: +    """ Whether the Numberth occurence of the first string in the second +        string is inside quotes as defined by the supplied QuoteChar """ +    substring = substring.decode('utf-8') +    body = body.decode('utf-8') +    num_substrings_found = 0 +    in_quote = False +    for index, char in enumerate(body): +        # Whether the next chunk of len(substring) chars is the substring +        next_part_is_substring = ( +            (not (index + len(substring) > len(body))) and +            (body[index:index + len(substring)] == substring) +        ) +        # Whether this char is escaped with a \ +        is_not_escaped = ( +            (index - 1 < 0 or index - 1 > len(body)) or +            (body[index - 1] != "\\") +        ) +        if char == qc and is_not_escaped: +            in_quote = not in_quote +        if next_part_is_substring: +            if num_substrings_found == text_index: +                return in_quote +            num_substrings_found += 1 +    return False + + +def paths_to_text(html: str, str: str) -> List[str]: +    """ Return list of Paths to a given str in the given HTML tree +          - Note that it does a BFS """ + +    def remove_last_occurence_of_sub_string(str: str, substr: str): +        """ Delete the last occurence of substr from str +        String String -> String +        """ +        index = str.rfind(substr) +        return str[:index] + str[index + len(substr):] + +    class PathHTMLParser(HTMLParser): +        currentPath = "" +        paths = [] + +        def handle_starttag(self, tag, attrs): +            self.currentPath += ("/" + tag) + +        def handle_endtag(self, tag): +            self.currentPath = remove_last_occurence_of_sub_string(self.currentPath, "/" + tag) + +        def handle_data(self, data): +            if str in data: +                self.paths.append(self.currentPath) + +    parser = PathHTMLParser() +    parser.feed(html) +    return parser.paths + + +def get_XSS_data(body: str, request_URL: str, injection_point: str) -> Optional[XSSData]: +    """ Return a XSSDict if there is a XSS otherwise return None """ +    def in_script(text, index, body) -> bool: +        """ Whether the Numberth occurence of the first string in the second +            string is inside a script tag """ +        paths = paths_to_text(body.decode('utf-8'), text.decode("utf-8")) +        try: +            path = paths[index] +            return "script" in path +        except IndexError: +            return False + +    def in_HTML(text: bytes, index: int, body: bytes) -> bool: +        """ Whether the Numberth occurence of the first string in the second +            string is inside the HTML but not inside a script tag or part of +            a HTML attribute""" +        # if there is a < then lxml will interpret that as a tag, so only search for the stuff before it +        text = text.split(b"<")[0] +        paths = paths_to_text(body.decode('utf-8'), text.decode("utf-8")) +        try: +            path = paths[index] +            return "script" not in path +        except IndexError: +            return False + +    def inject_javascript_handler(html: str) -> bool: +        """ Whether you can inject a Javascript:alert(0) as a link """ +        class injectJSHandlerHTMLParser(HTMLParser): +            injectJSHandler = False + +            def handle_starttag(self, tag, attrs): +                for name, value in attrs: +                    if name == "href" and value.startswith(FRONT_WALL.decode('utf-8')): +                        self.injectJSHandler = True + +        parser = injectJSHandlerHTMLParser() +        parser.feed(html) +        return parser.injectJSHandler +    # Only convert the body to bytes if needed +    if isinstance(body, str): +        body = bytes(body, 'utf-8') +    # Regex for between 24 and 72 (aka 24*3) characters encapsulated by the walls +    regex = re.compile(b"""%s.{24,72}?%s""" % (FRONT_WALL, BACK_WALL)) +    matches = regex.findall(body) +    for index, match in enumerate(matches): +        # Where the string is injected into the HTML +        in_script = in_script(match, index, body) +        in_HTML = in_HTML(match, index, body) +        in_tag = not in_script and not in_HTML +        in_single_quotes = inside_quote("'", match, index, body) +        in_double_quotes = inside_quote('"', match, index, body) +        # Whether you can inject: +        inject_open_angle = b"ao<ac" in match  # open angle brackets +        inject_close_angle = b"ac>so" in match  # close angle brackets +        inject_single_quotes = b"s'd" in match  # single quotes +        inject_double_quotes = b'd"ao' in match  # double quotes +        inject_slash = b"sl/bsl" in match  # forward slashes +        inject_semi = b"se;sl" in match  # semicolons +        inject_equals = b"eq=" in match  # equals sign +        if in_script and inject_slash and inject_open_angle and inject_close_angle:  # e.g. <script>PAYLOAD</script> +            return XSSData(request_URL, +                           injection_point, +                           '</script><script>alert(0)</script><script>', +                           match.decode('utf-8')) +        elif in_script and in_single_quotes and inject_single_quotes and inject_semi:  # e.g. <script>t='PAYLOAD';</script> +            return XSSData(request_URL, +                           injection_point, +                           "';alert(0);g='", +                           match.decode('utf-8')) +        elif in_script and in_double_quotes and inject_double_quotes and inject_semi:  # e.g. <script>t="PAYLOAD";</script> +            return XSSData(request_URL, +                           injection_point, +                           '";alert(0);g="', +                           match.decode('utf-8')) +        elif in_tag and in_single_quotes and inject_single_quotes and inject_open_angle and inject_close_angle and inject_slash: +            # e.g. <a href='PAYLOAD'>Test</a> +            return XSSData(request_URL, +                           injection_point, +                           "'><script>alert(0)</script>", +                           match.decode('utf-8')) +        elif in_tag and in_double_quotes and inject_double_quotes and inject_open_angle and inject_close_angle and inject_slash: +            # e.g. <a href="PAYLOAD">Test</a> +            return XSSData(request_URL, +                           injection_point, +                           '"><script>alert(0)</script>', +                           match.decode('utf-8')) +        elif in_tag and not in_double_quotes and not in_single_quotes and inject_open_angle and inject_close_angle and inject_slash: +            # e.g. <a href=PAYLOAD>Test</a> +            return XSSData(request_URL, +                           injection_point, +                           '><script>alert(0)</script>', +                           match.decode('utf-8')) +        elif inject_javascript_handler(body.decode('utf-8')):  # e.g. <html><a href=PAYLOAD>Test</a> +            return XSSData(request_URL, +                           injection_point, +                           'Javascript:alert(0)', +                           match.decode('utf-8')) +        elif in_tag and in_double_quotes and inject_double_quotes and inject_equals:  # e.g. <a href="PAYLOAD">Test</a> +            return XSSData(request_URL, +                           injection_point, +                           '" onmouseover="alert(0)" t="', +                           match.decode('utf-8')) +        elif in_tag and in_single_quotes and inject_single_quotes and inject_equals:  # e.g. <a href='PAYLOAD'>Test</a> +            return XSSData(request_URL, +                           injection_point, +                           "' onmouseover='alert(0)' t='", +                           match.decode('utf-8')) +        elif in_tag and not in_single_quotes and not in_double_quotes and inject_equals:  # e.g. <a href=PAYLOAD>Test</a> +            return XSSData(request_URL, +                           injection_point, +                           " onmouseover=alert(0) t=", +                           match.decode('utf-8')) +        elif in_HTML and not in_script and inject_open_angle and inject_close_angle and inject_slash:  # e.g. <html>PAYLOAD</html> +            return XSSData(request_URL, +                           injection_point, +                           '<script>alert(0)</script>', +                           match.decode('utf-8')) +        else: +            return None + + +# response is mitmproxy's entry point +def response(flow: http.HTTPFlow) -> None: +    cookiesDict = get_cookies(flow) +    # Example: http://xss.guru/unclaimedScriptTag.html +    find_unclaimed_URLs(flow.response.content, flow.request.url) +    results = test_end_of_URL_injection(flow.response.content.decode('utf-8'), flow.request.url, cookiesDict) +    log_XSS_data(results[0]) +    log_SQLi_data(results[1]) +    # Example: https://daviddworken.com/vulnerableReferer.php +    results = test_referer_injection(flow.response.content.decode('utf-8'), flow.request.url, cookiesDict) +    log_XSS_data(results[0]) +    log_SQLi_data(results[1]) +    # Example: https://daviddworken.com/vulnerableUA.php +    results = test_user_agent_injection(flow.response.content.decode('utf-8'), flow.request.url, cookiesDict) +    log_XSS_data(results[0]) +    log_SQLi_data(results[1]) +    if "?" in flow.request.url: +        # Example: https://daviddworken.com/vulnerable.php?name= +        results = test_query_injection(flow.response.content.decode('utf-8'), flow.request.url, cookiesDict) +        log_XSS_data(results[0]) +        log_SQLi_data(results[1]) diff --git a/mitmproxy/certs.py b/mitmproxy/certs.py index 4b939c80..6485eed7 100644 --- a/mitmproxy/certs.py +++ b/mitmproxy/certs.py @@ -93,9 +93,9 @@ def dummy_cert(privkey, cacert, commonname, sans):          try:              ipaddress.ip_address(i.decode("ascii"))          except ValueError: -            ss.append(b"DNS: %s" % i) +            ss.append(b"DNS:%s" % i)          else: -            ss.append(b"IP: %s" % i) +            ss.append(b"IP:%s" % i)      ss = b", ".join(ss)      cert = OpenSSL.crypto.X509() @@ -356,14 +356,14 @@ class CertStore:  class _GeneralName(univ.Choice): -    # We are only interested in dNSNames. We use a default handler to ignore -    # other types. -    # TODO: We should also handle iPAddresses. +    # We only care about dNSName and iPAddress      componentType = namedtype.NamedTypes(          namedtype.NamedType('dNSName', char.IA5String().subtype(              implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) -        ) -        ), +        )), +        namedtype.NamedType('iPAddress', univ.OctetString().subtype( +            implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 7) +        )),      ) @@ -477,5 +477,10 @@ class SSLCert(serializable.Serializable):                  except PyAsn1Error:                      continue                  for i in dec[0]: -                    altnames.append(i[0].asOctets()) +                    if i[0] is None and isinstance(i[1], univ.OctetString) and not isinstance(i[1], char.IA5String): +                        # This would give back the IP address: b'.'.join([str(e).encode() for e in i[1].asNumbers()]) +                        continue +                    else: +                        e = i[0].asOctets() +                    altnames.append(e)          return altnames diff --git a/mitmproxy/connections.py b/mitmproxy/connections.py index f914c7d2..9359b67d 100644 --- a/mitmproxy/connections.py +++ b/mitmproxy/connections.py @@ -54,24 +54,35 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):          return bool(self.connection) and not self.finished      def __repr__(self): +        if self.ssl_established: +            tls = "[{}] ".format(self.tls_version) +        else: +            tls = "" +          if self.alpn_proto_negotiated:              alpn = "[ALPN: {}] ".format(                  strutils.bytes_to_escaped_str(self.alpn_proto_negotiated)              )          else:              alpn = "" -        return "<ClientConnection: {ssl}{alpn}{address}>".format( -            ssl="[ssl] " if self.ssl_established else "", + +        return "<ClientConnection: {tls}{alpn}{host}:{port}>".format( +            tls=tls,              alpn=alpn, -            address=repr(self.address) +            host=self.address[0], +            port=self.address[1],          )      @property      def tls_established(self):          return self.ssl_established +    @tls_established.setter +    def tls_established(self, value): +        self.ssl_established = value +      _stateobject_attributes = dict( -        address=tcp.Address, +        address=tuple,          ssl_established=bool,          clientcert=certs.SSLCert,          mitmcert=certs.SSLCert, @@ -99,7 +110,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):      @classmethod      def make_dummy(cls, address):          return cls.from_state(dict( -            address=dict(address=address, use_ipv6=False), +            address=address,              clientcert=None,              mitmcert=None,              ssl_established=False, @@ -143,6 +154,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):          cert: The certificate presented by the remote during the TLS handshake          sni: Server Name Indication sent by the proxy during the TLS handshake          alpn_proto_negotiated: The negotiated application protocol +        tls_version: TLS version          via: The underlying server connection (e.g. the connection to the upstream proxy in upstream proxy mode)          timestamp_start: Connection start timestamp          timestamp_tcp_setup: TCP ACK received timestamp @@ -154,6 +166,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):          tcp.TCPClient.__init__(self, address, source_address, spoof_source_address)          self.alpn_proto_negotiated = None +        self.tls_version = None          self.via = None          self.timestamp_start = None          self.timestamp_end = None @@ -165,35 +178,41 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):      def __repr__(self):          if self.ssl_established and self.sni: -            ssl = "[ssl: {0}] ".format(self.sni) +            tls = "[{}: {}] ".format(self.tls_version or "TLS", self.sni)          elif self.ssl_established: -            ssl = "[ssl] " +            tls = "[{}] ".format(self.tls_version or "TLS")          else: -            ssl = "" +            tls = ""          if self.alpn_proto_negotiated:              alpn = "[ALPN: {}] ".format(                  strutils.bytes_to_escaped_str(self.alpn_proto_negotiated)              )          else:              alpn = "" -        return "<ServerConnection: {ssl}{alpn}{address}>".format( -            ssl=ssl, +        return "<ServerConnection: {tls}{alpn}{host}:{port}>".format( +            tls=tls,              alpn=alpn, -            address=repr(self.address) +            host=self.address[0], +            port=self.address[1],          )      @property      def tls_established(self):          return self.ssl_established +    @tls_established.setter +    def tls_established(self, value): +        self.ssl_established = value +      _stateobject_attributes = dict( -        address=tcp.Address, -        ip_address=tcp.Address, -        source_address=tcp.Address, +        address=tuple, +        ip_address=tuple, +        source_address=tuple,          ssl_established=bool,          cert=certs.SSLCert,          sni=str,          alpn_proto_negotiated=bytes, +        tls_version=str,          timestamp_start=float,          timestamp_tcp_setup=float,          timestamp_ssl_setup=float, @@ -209,12 +228,13 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):      @classmethod      def make_dummy(cls, address):          return cls.from_state(dict( -            address=dict(address=address, use_ipv6=False), -            ip_address=dict(address=address, use_ipv6=False), +            address=address, +            ip_address=address,              cert=None,              sni=None,              alpn_proto_negotiated=None, -            source_address=dict(address=('', 0), use_ipv6=False), +            tls_version=None, +            source_address=('', 0),              ssl_established=False,              timestamp_start=None,              timestamp_tcp_setup=None, @@ -244,13 +264,14 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):              else:                  path = os.path.join(                      clientcerts, -                    self.address.host.encode("idna").decode()) + ".pem" +                    self.address[0].encode("idna").decode()) + ".pem"                  if os.path.exists(path):                      clientcert = path          self.convert_to_ssl(cert=clientcert, sni=sni, **kwargs)          self.sni = sni          self.alpn_proto_negotiated = self.get_alpn_proto_negotiated() +        self.tls_version = self.connection.get_protocol_version_name()          self.timestamp_ssl_setup = time.time()      def finish(self): diff --git a/mitmproxy/flowfilter.py b/mitmproxy/flowfilter.py index 4a0eeeb1..7c4f95f7 100644 --- a/mitmproxy/flowfilter.py +++ b/mitmproxy/flowfilter.py @@ -348,7 +348,10 @@ class FSrc(_Rex):      is_binary = False      def __call__(self, f): -        return f.client_conn.address and self.re.search(repr(f.client_conn.address)) +        if not f.client_conn or not f.client_conn.address: +            return False +        r = "{}:{}".format(f.client_conn.address[0], f.client_conn.address[1]) +        return f.client_conn.address and self.re.search(r)  class FDst(_Rex): @@ -357,7 +360,10 @@ class FDst(_Rex):      is_binary = False      def __call__(self, f): -        return f.server_conn.address and self.re.search(repr(f.server_conn.address)) +        if not f.server_conn or not f.server_conn.address: +            return False +        r = "{}:{}".format(f.server_conn.address[0], f.server_conn.address[1]) +        return f.server_conn.address and self.re.search(r)  class _Int(_Action): @@ -425,6 +431,7 @@ filter_unary = [      FReq,      FResp,      FTCP, +    FWebSocket,  ]  filter_rex = [      FBod, diff --git a/mitmproxy/http.py b/mitmproxy/http.py index f0cabcf8..c6b17533 100644 --- a/mitmproxy/http.py +++ b/mitmproxy/http.py @@ -5,7 +5,6 @@ from mitmproxy import flow  from mitmproxy.net import http  from mitmproxy import version -from mitmproxy.net import tcp  from mitmproxy import connections  # noqa @@ -245,9 +244,8 @@ def make_error_response(  def make_connect_request(address): -    address = tcp.Address.wrap(address)      return HTTPRequest( -        "authority", b"CONNECT", None, address.host, address.port, None, b"HTTP/1.1", +        "authority", b"CONNECT", None, address[0], address[1], None, b"HTTP/1.1",          http.Headers(), b""      ) diff --git a/mitmproxy/io_compat.py b/mitmproxy/io_compat.py index c12d2098..16cbc9fe 100644 --- a/mitmproxy/io_compat.py +++ b/mitmproxy/io_compat.py @@ -88,12 +88,20 @@ def convert_019_100(data):  def convert_100_200(data):      data["version"] = (2, 0, 0) +    data["client_conn"]["address"] = data["client_conn"]["address"]["address"] +    data["server_conn"]["address"] = data["server_conn"]["address"]["address"] +    data["server_conn"]["source_address"] = data["server_conn"]["source_address"]["address"] +    if data["server_conn"]["ip_address"]: +        data["server_conn"]["ip_address"] = data["server_conn"]["ip_address"]["address"]      return data  def convert_200_300(data):      data["version"] = (3, 0, 0)      data["client_conn"]["mitmcert"] = None +    data["server_conn"]["tls_version"] = None +    if data["server_conn"]["via"]: +        data["server_conn"]["via"]["tls_version"] = None      return data diff --git a/mitmproxy/master.py b/mitmproxy/master.py index 3a3f4399..633f32aa 100644 --- a/mitmproxy/master.py +++ b/mitmproxy/master.py @@ -149,8 +149,8 @@ class Master:          """          if isinstance(f, http.HTTPFlow):              if self.server and self.options.mode == "reverse": -                f.request.host = self.server.config.upstream_server.address.host -                f.request.port = self.server.config.upstream_server.address.port +                f.request.host = self.server.config.upstream_server.address[0] +                f.request.port = self.server.config.upstream_server.address[1]                  f.request.scheme = self.server.config.upstream_server.scheme          f.reply = controller.DummyReply()          for e, o in eventsequence.iterate(f): diff --git a/mitmproxy/net/check.py b/mitmproxy/net/check.py index f793d397..d30c1df6 100644 --- a/mitmproxy/net/check.py +++ b/mitmproxy/net/check.py @@ -1,3 +1,4 @@ +import ipaddress  import re  # Allow underscore in host name @@ -6,17 +7,26 @@ _label_valid = re.compile(b"(?!-)[A-Z\d\-_]{1,63}(?<!-)$", re.IGNORECASE)  def is_valid_host(host: bytes) -> bool:      """ -        Checks if a hostname is valid. +    Checks if the passed bytes are a valid DNS hostname or an IPv4/IPv6 address.      """      try:          host.decode("idna")      except ValueError:          return False +    # RFC1035: 255 bytes or less.      if len(host) > 255:          return False      if host and host[-1:] == b".":          host = host[:-1] -    return all(_label_valid.match(x) for x in host.split(b".")) +    # DNS hostname +    if all(_label_valid.match(x) for x in host.split(b".")): +        return True +    # IPv4/IPv6 address +    try: +        ipaddress.ip_address(host.decode('idna')) +        return True +    except ValueError: +        return False  def is_valid_port(port): diff --git a/mitmproxy/net/socks.py b/mitmproxy/net/socks.py index a972283e..570a4afb 100644 --- a/mitmproxy/net/socks.py +++ b/mitmproxy/net/socks.py @@ -2,7 +2,6 @@ import struct  import array  import ipaddress -from mitmproxy.net import tcp  from mitmproxy.net import check  from mitmproxy.types import bidi @@ -179,7 +178,7 @@ class Message:          self.ver = ver          self.msg = msg          self.atyp = atyp -        self.addr = tcp.Address.wrap(addr) +        self.addr = addr      def assert_socks5(self):          if self.ver != VERSION.SOCKS5: @@ -199,37 +198,34 @@ class Message:          if atyp == ATYP.IPV4_ADDRESS:              # We use tnoa here as ntop is not commonly available on Windows.              host = ipaddress.IPv4Address(f.safe_read(4)).compressed -            use_ipv6 = False          elif atyp == ATYP.IPV6_ADDRESS:              host = ipaddress.IPv6Address(f.safe_read(16)).compressed -            use_ipv6 = True          elif atyp == ATYP.DOMAINNAME:              length, = struct.unpack("!B", f.safe_read(1))              host = f.safe_read(length)              if not check.is_valid_host(host):                  raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Invalid hostname: %s" % host)              host = host.decode("idna") -            use_ipv6 = False          else:              raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED,                               "Socks Request: Unknown ATYP: %s" % atyp)          port, = struct.unpack("!H", f.safe_read(2)) -        addr = tcp.Address((host, port), use_ipv6=use_ipv6) +        addr = (host, port)          return cls(ver, msg, atyp, addr)      def to_file(self, f):          f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp))          if self.atyp == ATYP.IPV4_ADDRESS: -            f.write(ipaddress.IPv4Address(self.addr.host).packed) +            f.write(ipaddress.IPv4Address(self.addr[0]).packed)          elif self.atyp == ATYP.IPV6_ADDRESS: -            f.write(ipaddress.IPv6Address(self.addr.host).packed) +            f.write(ipaddress.IPv6Address(self.addr[0]).packed)          elif self.atyp == ATYP.DOMAINNAME: -            f.write(struct.pack("!B", len(self.addr.host))) -            f.write(self.addr.host.encode("idna")) +            f.write(struct.pack("!B", len(self.addr[0]))) +            f.write(self.addr[0].encode("idna"))          else:              raise SocksError(                  REP.ADDRESS_TYPE_NOT_SUPPORTED,                  "Unknown ATYP: %s" % self.atyp              ) -        f.write(struct.pack("!H", self.addr.port)) +        f.write(struct.pack("!H", self.addr[1])) diff --git a/mitmproxy/net/tcp.py b/mitmproxy/net/tcp.py index eabc8006..dc5e2ee2 100644 --- a/mitmproxy/net/tcp.py +++ b/mitmproxy/net/tcp.py @@ -19,7 +19,6 @@ from OpenSSL import SSL  from mitmproxy import certs  from mitmproxy.utils import version_check -from mitmproxy.types import serializable  from mitmproxy import exceptions  from mitmproxy.types import basethread @@ -29,6 +28,10 @@ version_check.check_pyopenssl_version()  socket_fileobject = socket.SocketIO +# workaround for https://bugs.python.org/issue29515 +# Python 3.5 and 3.6 for Windows is missing a constant +IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41) +  EINTR = 4  HAS_ALPN = SSL._lib.Cryptography_HAS_ALPN @@ -299,73 +302,6 @@ class Reader(_FileLike):              raise NotImplementedError("Can only peek into (pyOpenSSL) sockets") -class Address(serializable.Serializable): - -    """ -        This class wraps an IPv4/IPv6 tuple to provide named attributes and -        ipv6 information. -    """ - -    def __init__(self, address, use_ipv6=False): -        self.address = tuple(address) -        self.use_ipv6 = use_ipv6 - -    def get_state(self): -        return { -            "address": self.address, -            "use_ipv6": self.use_ipv6 -        } - -    def set_state(self, state): -        self.address = state["address"] -        self.use_ipv6 = state["use_ipv6"] - -    @classmethod -    def from_state(cls, state): -        return Address(**state) - -    @classmethod -    def wrap(cls, t): -        if isinstance(t, cls): -            return t -        else: -            return cls(t) - -    def __call__(self): -        return self.address - -    @property -    def host(self): -        return self.address[0] - -    @property -    def port(self): -        return self.address[1] - -    @property -    def use_ipv6(self): -        return self.family == socket.AF_INET6 - -    @use_ipv6.setter -    def use_ipv6(self, b): -        self.family = socket.AF_INET6 if b else socket.AF_INET - -    def __repr__(self): -        return "{}:{}".format(self.host, self.port) - -    def __eq__(self, other): -        if not other: -            return False -        other = Address.wrap(other) -        return (self.address, self.family) == (other.address, other.family) - -    def __ne__(self, other): -        return not self.__eq__(other) - -    def __hash__(self): -        return hash(self.address) ^ 42  # different hash than the tuple alone. - -  def ssl_read_select(rlist, timeout):      """      This is a wrapper around select.select() which also works for SSL.Connections @@ -452,7 +388,7 @@ class _Connection:      def __init__(self, connection):          if connection:              self.connection = connection -            self.ip_address = Address(connection.getpeername()) +            self.ip_address = connection.getpeername()              self._makefile()          else:              self.connection = None @@ -629,28 +565,6 @@ class TCPClient(_Connection):          self.sni = None          self.spoof_source_address = spoof_source_address -    @property -    def address(self): -        return self.__address - -    @address.setter -    def address(self, address): -        if address: -            self.__address = Address.wrap(address) -        else: -            self.__address = None - -    @property -    def source_address(self): -        return self.__source_address - -    @source_address.setter -    def source_address(self, source_address): -        if source_address: -            self.__source_address = Address.wrap(source_address) -        else: -            self.__source_address = None -      def close(self):          # Make sure to close the real socket, not the SSL proxy.          # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection, @@ -741,34 +655,57 @@ class TCPClient(_Connection):          self.rfile.set_descriptor(self.connection)          self.wfile.set_descriptor(self.connection) -    def makesocket(self): +    def makesocket(self, family, type, proto):          # some parties (cuckoo sandbox) need to hook this -        return socket.socket(self.address.family, socket.SOCK_STREAM) +        return socket.socket(family, type, proto) + +    def create_connection(self, timeout=None): +        # Based on the official socket.create_connection implementation of Python 3.6. +        # https://github.com/python/cpython/blob/3cc5817cfaf5663645f4ee447eaed603d2ad290a/Lib/socket.py + +        err = None +        for res in socket.getaddrinfo(self.address[0], self.address[1], 0, socket.SOCK_STREAM): +            af, socktype, proto, canonname, sa = res +            sock = None +            try: +                sock = self.makesocket(af, socktype, proto) +                if timeout: +                    sock.settimeout(timeout) +                if self.source_address: +                    sock.bind(self.source_address) +                if self.spoof_source_address: +                    try: +                        if not sock.getsockopt(socket.SOL_IP, socket.IP_TRANSPARENT): +                            sock.setsockopt(socket.SOL_IP, socket.IP_TRANSPARENT, 1) +                    except Exception as e: +                        # socket.IP_TRANSPARENT might not be available on every OS and Python version +                        raise exceptions.TcpException( +                            "Failed to spoof the source address: " + e.strerror +                        ) +                sock.connect(sa) +                return sock + +            except socket.error as _: +                err = _ +                if sock is not None: +                    sock.close() + +        if err is not None: +            raise err +        else: +            raise socket.error("getaddrinfo returns an empty list")      def connect(self):          try: -            connection = self.makesocket() - -            if self.spoof_source_address: -                try: -                    # 19 is `IP_TRANSPARENT`, which is only available on Python 3.3+ on some OSes -                    if not connection.getsockopt(socket.SOL_IP, 19): -                        connection.setsockopt(socket.SOL_IP, 19, 1) -                except socket.error as e: -                    raise exceptions.TcpException( -                        "Failed to spoof the source address: " + e.strerror -                    ) -            if self.source_address: -                connection.bind(self.source_address()) -            connection.connect(self.address()) -            self.source_address = Address(connection.getsockname()) +            connection = self.create_connection()          except (socket.error, IOError) as err:              raise exceptions.TcpException(                  'Error connecting to "%s": %s' % -                (self.address.host, err) +                (self.address[0], err)              )          self.connection = connection -        self.ip_address = Address(connection.getpeername()) +        self.source_address = connection.getsockname() +        self.ip_address = connection.getpeername()          self._makefile()          return ConnectionCloser(self) @@ -793,7 +730,7 @@ class BaseHandler(_Connection):      def __init__(self, connection, address, server):          super().__init__(connection) -        self.address = Address.wrap(address) +        self.address = address          self.server = server          self.clientcert = None @@ -915,19 +852,36 @@ class TCPServer:      request_queue_size = 20      def __init__(self, address): -        self.address = Address.wrap(address) +        self.address = address          self.__is_shut_down = threading.Event()          self.__shutdown_request = False -        self.socket = socket.socket(self.address.family, socket.SOCK_STREAM) -        self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) -        self.socket.bind(self.address()) -        self.address = Address.wrap(self.socket.getsockname()) + +        if self.address == 'localhost': +            raise socket.error("Binding to 'localhost' is prohibited. Please use '::1' or '127.0.0.1' directly.") + +        try: +            # First try to bind an IPv6 socket, with possible IPv4 if the OS supports it. +            # This allows us to accept connections for ::1 and 127.0.0.1 on the same socket. +            # Only works if self.address == "" +            self.socket = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) +            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +            self.socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) +            self.socket.bind(self.address) +        except socket.error: +            self.socket = None + +        if not self.socket: +            # Binding to an IPv6 socket failed, lets fall back to IPv4. +            self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +            self.socket.bind(self.address) + +        self.address = self.socket.getsockname()          self.socket.listen(self.request_queue_size)          self.handler_counter = Counter()      def connection_thread(self, connection, client_address):          with self.handler_counter: -            client_address = Address(client_address)              try:                  self.handle_client_connection(connection, client_address)              except: @@ -954,8 +908,8 @@ class TCPServer:                              self.__class__.__name__,                              client_address[0],                              client_address[1], -                            self.address.host, -                            self.address.port +                            self.address[0], +                            self.address[1],                          ),                          target=self.connection_thread,                          args=(connection, client_address), @@ -964,7 +918,7 @@ class TCPServer:                      try:                          t.start()                      except threading.ThreadError: -                        self.handle_error(connection, Address(client_address)) +                        self.handle_error(connection, client_address)                          connection.close()          finally:              self.__shutdown_request = False diff --git a/mitmproxy/net/wsgi.py b/mitmproxy/net/wsgi.py index 8bc5bb89..a40dceca 100644 --- a/mitmproxy/net/wsgi.py +++ b/mitmproxy/net/wsgi.py @@ -4,14 +4,13 @@ import urllib  import io  from mitmproxy.net import http -from mitmproxy.net import tcp  from mitmproxy.utils import strutils  class ClientConn:      def __init__(self, address): -        self.address = tcp.Address.wrap(address) +        self.address = address  class Flow: @@ -84,8 +83,8 @@ class WSGIAdaptor:          }          environ.update(extra)          if flow.client_conn.address: -            environ["REMOTE_ADDR"] = strutils.always_str(flow.client_conn.address.host, "latin-1") -            environ["REMOTE_PORT"] = flow.client_conn.address.port +            environ["REMOTE_ADDR"] = strutils.always_str(flow.client_conn.address[0], "latin-1") +            environ["REMOTE_PORT"] = flow.client_conn.address[1]          for key, value in flow.request.headers.items():              key = 'HTTP_' + strutils.always_str(key, "latin-1").upper().replace('-', '_') diff --git a/mitmproxy/proxy/config.py b/mitmproxy/proxy/config.py index 513c0b5b..ea2f7c7f 100644 --- a/mitmproxy/proxy/config.py +++ b/mitmproxy/proxy/config.py @@ -23,8 +23,7 @@ class HostMatcher:      def __call__(self, address):          if not address:              return False -        address = tcp.Address.wrap(address) -        host = "%s:%s" % (address.host, address.port) +        host = "%s:%s" % address          if any(rex.search(host) for rex in self.regexes):              return True          else: @@ -47,7 +46,7 @@ def parse_server_spec(spec):              "Invalid server specification: %s" % spec          )      host, port = p[1:3] -    address = tcp.Address((host.decode("ascii"), port)) +    address = (host.decode("ascii"), port)      scheme = p[0].decode("ascii").lower()      return ServerSpec(scheme, address) diff --git a/mitmproxy/proxy/protocol/base.py b/mitmproxy/proxy/protocol/base.py index 93619171..b10bb8f5 100644 --- a/mitmproxy/proxy/protocol/base.py +++ b/mitmproxy/proxy/protocol/base.py @@ -101,7 +101,7 @@ class ServerConnectionMixin:          self.server_conn = None          if self.config.options.spoof_source_address and self.config.options.upstream_bind_address == '':              self.server_conn = connections.ServerConnection( -                server_address, (self.ctx.client_conn.address.host, 0), True) +                server_address, (self.ctx.client_conn.address[0], 0), True)          else:              self.server_conn = connections.ServerConnection(                  server_address, (self.config.options.upstream_bind_address, 0), @@ -118,8 +118,8 @@ class ServerConnectionMixin:          address = self.server_conn.address          if address:              self_connect = ( -                address.port == self.config.options.listen_port and -                address.host in ("localhost", "127.0.0.1", "::1") +                address[1] == self.config.options.listen_port and +                address[0] in ("localhost", "127.0.0.1", "::1")              )              if self_connect:                  raise exceptions.ProtocolException( @@ -133,7 +133,7 @@ class ServerConnectionMixin:          """          if self.server_conn.connected():              self.disconnect() -        self.log("Set new server address: " + repr(address), "debug") +        self.log("Set new server address: {}:{}".format(address[0], address[1]), "debug")          self.server_conn.address = address          self.__check_self_connect() @@ -150,7 +150,7 @@ class ServerConnectionMixin:          self.server_conn = connections.ServerConnection(              address, -            (self.server_conn.source_address.host, 0), +            (self.server_conn.source_address[0], 0),              self.config.options.spoof_source_address          ) diff --git a/mitmproxy/proxy/protocol/http.py b/mitmproxy/proxy/protocol/http.py index a351ad66..b6f8463d 100644 --- a/mitmproxy/proxy/protocol/http.py +++ b/mitmproxy/proxy/protocol/http.py @@ -8,7 +8,6 @@ from mitmproxy import http  from mitmproxy import flow  from mitmproxy.proxy.protocol import base  from mitmproxy.proxy.protocol.websocket import WebSocketLayer -from mitmproxy.net import tcp  from mitmproxy.net import websockets @@ -59,7 +58,7 @@ class ConnectServerConnection:      """      def __init__(self, address, ctx): -        self.address = tcp.Address.wrap(address) +        self.address = address          self._ctx = ctx      @property @@ -112,9 +111,8 @@ class UpstreamConnectLayer(base.Layer):      def set_server(self, address):          if self.ctx.server_conn.connected():              self.ctx.disconnect() -        address = tcp.Address.wrap(address) -        self.connect_request.host = address.host -        self.connect_request.port = address.port +        self.connect_request.host = address[0] +        self.connect_request.port = address[1]          self.server_conn.address = address @@ -291,7 +289,7 @@ class HttpLayer(base.Layer):          # update host header in reverse proxy mode          if self.config.options.mode == "reverse" and not self.config.options.keep_host_header: -            f.request.host_header = self.config.upstream_server.address.host +            f.request.host_header = self.config.upstream_server.address[0]          # Determine .scheme, .host and .port attributes for inline scripts. For          # absolute-form requests, they are directly given in the request. For @@ -302,8 +300,8 @@ class HttpLayer(base.Layer):              # Setting request.host also updates the host header, which we want              # to preserve              host_header = f.request.host_header -            f.request.host = self.__initial_server_conn.address.host -            f.request.port = self.__initial_server_conn.address.port +            f.request.host = self.__initial_server_conn.address[0] +            f.request.port = self.__initial_server_conn.address[1]              f.request.host_header = host_header  # set again as .host overwrites this.              f.request.scheme = "https" if self.__initial_server_tls else "http"          self.channel.ask("request", f) @@ -453,14 +451,14 @@ class HttpLayer(base.Layer):              self.set_server(address)      def establish_server_connection(self, host: str, port: int, scheme: str): -        address = tcp.Address((host, port))          tls = (scheme == "https")          if self.mode is HTTPMode.regular or self.mode is HTTPMode.transparent:              # If there's an existing connection that doesn't match our expectations, kill it. +            address = (host, port)              if address != self.server_conn.address or tls != self.server_tls:                  self.set_server(address) -                self.set_server_tls(tls, address.host) +                self.set_server_tls(tls, address[0])              # Establish connection is neccessary.              if not self.server_conn.connected():                  self.connect() diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py index cdce24b3..01406798 100644 --- a/mitmproxy/proxy/protocol/http2.py +++ b/mitmproxy/proxy/protocol/http2.py @@ -97,7 +97,6 @@ class Http2Layer(base.Layer):              client_side=False,              header_encoding=False,              validate_outbound_headers=False, -            normalize_outbound_headers=False,              validate_inbound_headers=False)          self.connections[self.client_conn] = SafeH2Connection(self.client_conn, config=config) @@ -107,7 +106,6 @@ class Http2Layer(base.Layer):                  client_side=True,                  header_encoding=False,                  validate_outbound_headers=False, -                normalize_outbound_headers=False,                  validate_inbound_headers=False)              self.connections[self.server_conn] = SafeH2Connection(self.server_conn, config=config)          self.connections[self.server_conn].initiate_connection() @@ -599,9 +597,6 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr      def send_response_headers(self, response):          headers = response.headers.copy()          headers.insert(0, ":status", str(response.status_code)) -        for forbidden_header in h2.utilities.CONNECTION_HEADERS: -            if forbidden_header in headers: -                del headers[forbidden_header]          with self.connections[self.client_conn].lock:              self.connections[self.client_conn].safe_send_headers(                  self.raise_zombie, diff --git a/mitmproxy/proxy/protocol/tls.py b/mitmproxy/proxy/protocol/tls.py index 08ce53d0..7d15130f 100644 --- a/mitmproxy/proxy/protocol/tls.py +++ b/mitmproxy/proxy/protocol/tls.py @@ -545,8 +545,9 @@ class TlsLayer(base.Layer):              raise exceptions.InvalidServerCertificate(str(e))          except exceptions.TlsException as e:              raise exceptions.TlsProtocolException( -                "Cannot establish TLS with {address} (sni: {sni}): {e}".format( -                    address=repr(self.server_conn.address), +                "Cannot establish TLS with {host}:{port} (sni: {sni}): {e}".format( +                    host=self.server_conn.address[0], +                    port=self.server_conn.address[1],                      sni=self.server_sni,                      e=repr(e)                  ) @@ -567,7 +568,7 @@ class TlsLayer(base.Layer):          # However, we may just want to establish TLS so that we can send an error message to the client,          # in which case the address can be None.          if self.server_conn.address: -            host = self.server_conn.address.host.encode("idna") +            host = self.server_conn.address[0].encode("idna")          # Should we incorporate information from the server certificate?          use_upstream_cert = ( diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py index 180fc9ca..1987c8dc 100644 --- a/mitmproxy/proxy/root_context.py +++ b/mitmproxy/proxy/root_context.py @@ -68,7 +68,7 @@ class RootContext:                  top_layer,                  client_tls,                  top_layer.server_tls, -                top_layer.server_conn.address.host +                top_layer.server_conn.address[0]              )          if isinstance(top_layer, protocol.ServerConnectionMixin) or isinstance(top_layer, protocol.UpstreamConnectLayer):              return protocol.TlsLayer(top_layer, client_tls, client_tls) @@ -104,7 +104,7 @@ class RootContext:          Send a log message to the master.          """          full_msg = [ -            "{}: {}".format(repr(self.client_conn.address), msg) +            "{}:{}: {}".format(self.client_conn.address[0], self.client_conn.address[1], msg)          ]          for i in subs:              full_msg.append("  -> " + i) diff --git a/mitmproxy/proxy/server.py b/mitmproxy/proxy/server.py index 97018dad..8082cb64 100644 --- a/mitmproxy/proxy/server.py +++ b/mitmproxy/proxy/server.py @@ -1,4 +1,3 @@ -import socket  import sys  import traceback @@ -46,10 +45,10 @@ class ProxyServer(tcp.TCPServer):              )              if config.options.mode == "transparent":                  platform.init_transparent_mode() -        except socket.error as e: +        except Exception as e:              raise exceptions.ServerException(                  'Error starting proxy server: ' + repr(e) -            ) +            ) from e          self.channel = None      def set_channel(self, channel): diff --git a/mitmproxy/test/tflow.py b/mitmproxy/test/tflow.py index ea7be4b9..fd665055 100644 --- a/mitmproxy/test/tflow.py +++ b/mitmproxy/test/tflow.py @@ -1,3 +1,5 @@ +import io +  from mitmproxy.net import websockets  from mitmproxy.test import tutils  from mitmproxy import tcp @@ -72,7 +74,8 @@ def twebsocketflow(client_conn=True, server_conn=True, messages=True, err=None,      if messages is True:          messages = [              websocket.WebSocketMessage(websockets.OPCODE.BINARY, True, b"hello binary"), -            websocket.WebSocketMessage(websockets.OPCODE.TEXT, False, "hello text".encode()), +            websocket.WebSocketMessage(websockets.OPCODE.TEXT, True, "hello text".encode()), +            websocket.WebSocketMessage(websockets.OPCODE.TEXT, False, "it's me".encode()),          ]      if err is True:          err = terr() @@ -142,7 +145,7 @@ def tclient_conn():      @return: mitmproxy.proxy.connection.ClientConnection      """      c = connections.ClientConnection.from_state(dict( -        address=dict(address=("address", 22), use_ipv6=True), +        address=("address", 22),          clientcert=None,          mitmcert=None,          ssl_established=False, @@ -155,6 +158,8 @@ def tclient_conn():          tls_version="TLSv1.2",      ))      c.reply = controller.DummyReply() +    c.rfile = io.BytesIO() +    c.wfile = io.BytesIO()      return c @@ -163,8 +168,8 @@ def tserver_conn():      @return: mitmproxy.proxy.connection.ServerConnection      """      c = connections.ServerConnection.from_state(dict( -        address=dict(address=("address", 22), use_ipv6=True), -        source_address=dict(address=("address", 22), use_ipv6=True), +        address=("address", 22), +        source_address=("address", 22),          ip_address=None,          cert=None,          timestamp_start=1, @@ -174,9 +179,12 @@ def tserver_conn():          ssl_established=False,          sni="address",          alpn_proto_negotiated=None, +        tls_version=None,          via=None,      ))      c.reply = controller.DummyReply() +    c.rfile = io.BytesIO() +    c.wfile = io.BytesIO()      return c diff --git a/mitmproxy/tools/console/flowdetailview.py b/mitmproxy/tools/console/flowdetailview.py index d713787a..691f19a5 100644 --- a/mitmproxy/tools/console/flowdetailview.py +++ b/mitmproxy/tools/console/flowdetailview.py @@ -30,8 +30,8 @@ def flowdetails(state, flow: http.HTTPFlow):      if sc is not None:          text.append(urwid.Text([("head", "Server Connection:")]))          parts = [ -            ["Address", repr(sc.address)], -            ["Resolved Address", repr(sc.ip_address)], +            ["Address", "{}:{}".format(sc.address[0], sc.address[1])], +            ["Resolved Address", "{}:{}".format(sc.ip_address[0], sc.ip_address[1])],          ]          if resp:              parts.append(["HTTP Version", resp.http_version]) @@ -92,7 +92,7 @@ def flowdetails(state, flow: http.HTTPFlow):          text.append(urwid.Text([("head", "Client Connection:")]))          parts = [ -            ["Address", repr(cc.address)], +            ["Address", "{}:{}".format(cc.address[0], cc.address[1])],          ]          if req:              parts.append(["HTTP Version", req.http_version]) diff --git a/mitmproxy/tools/console/flowview.py b/mitmproxy/tools/console/flowview.py index a97a9b31..90cca1c5 100644 --- a/mitmproxy/tools/console/flowview.py +++ b/mitmproxy/tools/console/flowview.py @@ -681,7 +681,7 @@ class FlowView(tabs.Tabs):          encoding_map = {              "z": "gzip",              "d": "deflate", -            "b": "brotli", +            "b": "br",          }          conn.encode(encoding_map[key])          signals.flow_change.send(self, flow = self.flow) diff --git a/mitmproxy/tools/console/master.py b/mitmproxy/tools/console/master.py index 4ab9e1f4..d68dc93c 100644 --- a/mitmproxy/tools/console/master.py +++ b/mitmproxy/tools/console/master.py @@ -429,9 +429,11 @@ class ConsoleMaster(master.Master):          super().tcp_message(f)          message = f.messages[-1]          direction = "->" if message.from_client else "<-" -        signals.add_log("{client} {direction} tcp {direction} {server}".format( -            client=repr(f.client_conn.address), -            server=repr(f.server_conn.address), +        signals.add_log("{client_host}:{client_port} {direction} tcp {direction} {server_host}:{server_port}".format( +            client_host=f.client_conn.address[0], +            client_port=f.client_conn.address[1], +            server_host=f.server_conn.address[0], +            server_port=f.server_conn.address[1],              direction=direction,          ), "info")          signals.add_log(strutils.bytes_to_escaped_str(message.content), "debug") diff --git a/mitmproxy/tools/console/palettepicker.py b/mitmproxy/tools/console/palettepicker.py index 4c5c62a0..1f238b0d 100644 --- a/mitmproxy/tools/console/palettepicker.py +++ b/mitmproxy/tools/console/palettepicker.py @@ -43,7 +43,7 @@ class PalettePicker(urwid.WidgetWrap):                  i,                  None,                  lambda: self.master.options.console_palette == name, -                lambda: setattr(self.master.options, "palette", name) +                lambda: setattr(self.master.options, "console_palette", name)              )          for i in high: @@ -59,7 +59,7 @@ class PalettePicker(urwid.WidgetWrap):                      "Transparent",                      "T",                      lambda: master.options.console_palette_transparent, -                    master.options.toggler("palette_transparent") +                    master.options.toggler("console_palette_transparent")                  )              ]          ) diff --git a/mitmproxy/tools/console/statusbar.py b/mitmproxy/tools/console/statusbar.py index 2c7f9efb..d90d932b 100644 --- a/mitmproxy/tools/console/statusbar.py +++ b/mitmproxy/tools/console/statusbar.py @@ -238,8 +238,8 @@ class StatusBar(urwid.WidgetWrap):              dst = self.master.server.config.upstream_server              r.append("[dest:%s]" % mitmproxy.net.http.url.unparse(                  dst.scheme, -                dst.address.host, -                dst.address.port +                dst.address[0], +                dst.address[1],              ))          if self.master.options.scripts:              r.append("[") @@ -272,10 +272,10 @@ class StatusBar(urwid.WidgetWrap):          ]          if self.master.server.bound: -            host = self.master.server.address.host +            host = self.master.server.address[0]              if host == "0.0.0.0":                  host = "*" -            boundaddr = "[%s:%s]" % (host, self.master.server.address.port) +            boundaddr = "[%s:%s]" % (host, self.master.server.address[1])          else:              boundaddr = ""          t.extend(self.get_status()) diff --git a/mitmproxy/tools/dump.py b/mitmproxy/tools/dump.py index e1e40fb0..fefbddfb 100644 --- a/mitmproxy/tools/dump.py +++ b/mitmproxy/tools/dump.py @@ -46,7 +46,7 @@ class DumpMaster(master.Master):          if not self.options.no_server:              self.add_log( -                "Proxy server listening at http://{}".format(server.address), +                "Proxy server listening at http://{}:{}".format(server.address[0], server.address[1]),                  "info"              ) diff --git a/mitmproxy/tools/web/app.py b/mitmproxy/tools/web/app.py index 1f3467cc..893c3dde 100644 --- a/mitmproxy/tools/web/app.py +++ b/mitmproxy/tools/web/app.py @@ -85,6 +85,7 @@ def flow_to_json(flow: mitmproxy.flow.Flow) -> dict:                  "is_replay": flow.response.is_replay,              }      f.get("server_conn", {}).pop("cert", None) +    f.get("client_conn", {}).pop("mitmcert", None)      return f diff --git a/mitmproxy/tools/web/master.py b/mitmproxy/tools/web/master.py index 6ebcfe47..8c7f579d 100644 --- a/mitmproxy/tools/web/master.py +++ b/mitmproxy/tools/web/master.py @@ -109,7 +109,7 @@ class WebMaster(master.Master):          tornado.ioloop.PeriodicCallback(lambda: self.tick(timeout=0), 5).start()          self.add_log( -            "Proxy server listening at http://{}/".format(self.server.address), +            "Proxy server listening at http://{}:{}/".format(self.server.address[0], self.server.address[1]),              "info"          ) diff --git a/mitmproxy/tools/web/static/app.js b/mitmproxy/tools/web/static/app.js index 480d9c71..c21a478b 100644 --- a/mitmproxy/tools/web/static/app.js +++ b/mitmproxy/tools/web/static/app.js @@ -2046,7 +2046,7 @@ function ConnectionInfo(_ref2) {                  _react2.default.createElement(                      'td',                      null, -                    conn.address.address.join(':') +                    conn.address.join(':')                  )              ),              conn.sni && _react2.default.createElement( @@ -8449,7 +8449,7 @@ module.exports = function () {      function destination(regex) {        regex = new RegExp(regex, "i");        function destinationFilter(flow) { -        return !!flow.server_conn.address && regex.test(flow.server_conn.address.address[0] + ":" + flow.server_conn.address.address[1]); +        return !!flow.server_conn.address && regex.test(flow.server_conn.address[0] + ":" + flow.server_conn.address[1]);        }        destinationFilter.desc = "destination address matches " + regex;        return destinationFilter; @@ -8509,7 +8509,7 @@ module.exports = function () {      function source(regex) {        regex = new RegExp(regex, "i");        function sourceFilter(flow) { -        return !!flow.client_conn.address && regex.test(flow.client_conn.address.address[0] + ":" + flow.client_conn.address.address[1]); +        return !!flow.client_conn.address && regex.test(flow.client_conn.address[0] + ":" + flow.client_conn.address[1]);        }        sourceFilter.desc = "source address matches " + regex;        return sourceFilter; diff --git a/pathod/pathoc.py b/pathod/pathoc.py index 549444ca..4a613349 100644 --- a/pathod/pathoc.py +++ b/pathod/pathoc.py @@ -239,7 +239,7 @@ class Pathoc(tcp.TCPClient):              is_client=True,              staticdir=os.getcwd(),              unconstrained_file_access=True, -            request_host=self.address.host, +            request_host=self.address[0],              protocol=self.protocol,          ) @@ -286,7 +286,7 @@ class Pathoc(tcp.TCPClient):                  socks.VERSION.SOCKS5,                  socks.CMD.CONNECT,                  socks.ATYP.DOMAINNAME, -                tcp.Address.wrap(connect_to) +                connect_to,              )              connect_request.to_file(self.wfile)              self.wfile.flush() diff --git a/pathod/pathod.py b/pathod/pathod.py index 8d57897b..7416d325 100644 --- a/pathod/pathod.py +++ b/pathod/pathod.py @@ -166,7 +166,7 @@ class PathodHandler(tcp.BaseHandler):                      headers=headers.fields,                      http_version=http_version,                      sni=self.sni, -                    remote_address=self.address(), +                    remote_address=self.address,                      clientcert=clientcert,                      first_line_format=first_line_format                  ), diff --git a/pathod/protocols/http2.py b/pathod/protocols/http2.py index 628e3f33..7c88c5c7 100644 --- a/pathod/protocols/http2.py +++ b/pathod/protocols/http2.py @@ -172,9 +172,9 @@ class HTTP2StateProtocol:      def assemble_request(self, request):          assert isinstance(request, mitmproxy.net.http.request.Request) -        authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host -        if self.tcp_handler.address.port != 443: -            authority += ":%d" % self.tcp_handler.address.port +        authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address[0] +        if self.tcp_handler.address[1] != 443: +            authority += ":%d" % self.tcp_handler.address[1]          headers = request.headers.copy() diff --git a/pathod/test.py b/pathod/test.py index b819d723..81f5805f 100644 --- a/pathod/test.py +++ b/pathod/test.py @@ -97,8 +97,8 @@ class _PaThread(basethread.BaseThread):              **self.daemonargs          )          self.name = "PathodThread (%s:%s)" % ( -            self.server.address.host, -            self.server.address.port +            self.server.address[0], +            self.server.address[1],          ) -        self.q.put(self.server.address.port) +        self.q.put(self.server.address[1])          self.server.serve_forever() @@ -34,16 +34,11 @@ exclude =      mitmproxy/proxy/root_context.py      mitmproxy/proxy/server.py      mitmproxy/tools/ -    mitmproxy/certs.py -    mitmproxy/connections.py      mitmproxy/controller.py      mitmproxy/export.py      mitmproxy/flow.py -    mitmproxy/flowfilter.py -    mitmproxy/http.py      mitmproxy/io_compat.py      mitmproxy/master.py -    mitmproxy/optmanager.py      pathod/pathoc.py      pathod/pathod.py      pathod/test.py @@ -54,8 +49,6 @@ exclude =      mitmproxy/addonmanager.py      mitmproxy/addons/onboardingapp/app.py      mitmproxy/addons/termlog.py -    mitmproxy/certs.py -    mitmproxy/connections.py      mitmproxy/contentviews/base.py      mitmproxy/contentviews/wbxml.py      mitmproxy/contentviews/xml_html.py @@ -64,8 +57,6 @@ exclude =      mitmproxy/exceptions.py      mitmproxy/export.py      mitmproxy/flow.py -    mitmproxy/flowfilter.py -    mitmproxy/http.py      mitmproxy/io.py      mitmproxy/io_compat.py      mitmproxy/log.py @@ -78,7 +69,6 @@ exclude =      mitmproxy/net/http/url.py      mitmproxy/net/tcp.py      mitmproxy/options.py -    mitmproxy/optmanager.py      mitmproxy/proxy/config.py      mitmproxy/proxy/modes/http_proxy.py      mitmproxy/proxy/modes/reverse_proxy.py @@ -113,7 +113,6 @@ setup(          ],          'examples': [              "beautifulsoup4>=4.4.1, <4.6", -            "pytz>=2015.07.0, <=2016.10",              "Pillow>=3.2, <4.1",          ]      } diff --git a/test/mitmproxy/addons/test_dumper.py b/test/mitmproxy/addons/test_dumper.py index 6a66d0c9..22d2c2c6 100644 --- a/test/mitmproxy/addons/test_dumper.py +++ b/test/mitmproxy/addons/test_dumper.py @@ -70,7 +70,7 @@ def test_simple():          flow.request = tutils.treq()          flow.request.stickycookie = True          flow.client_conn = mock.MagicMock() -        flow.client_conn.address.host = "foo" +        flow.client_conn.address[0] = "foo"          flow.response = tutils.tresp(content=None)          flow.response.is_replay = True          flow.response.status_code = 300 @@ -176,7 +176,7 @@ def test_websocket():          ctx.configure(d, flow_detail=3, showhost=True)          f = tflow.twebsocketflow()          d.websocket_message(f) -        assert "hello text" in sio.getvalue() +        assert "it's me" in sio.getvalue()          sio.truncate(0)          d.websocket_end(f) diff --git a/test/mitmproxy/examples/test_xss_scanner.py b/test/mitmproxy/examples/test_xss_scanner.py new file mode 100755 index 00000000..14ee6902 --- /dev/null +++ b/test/mitmproxy/examples/test_xss_scanner.py @@ -0,0 +1,368 @@ +import pytest +import requests +from examples.complex import xss_scanner as xss +from mitmproxy.test import tflow, tutils + + +class TestXSSScanner(): +    def test_get_XSS_info(self): +        # First type of exploit: <script>PAYLOAD</script> +        # Exploitable: +        xss_info = xss.get_XSS_data(b"<html><script>%s</script><html>" % +                                    xss.FULL_PAYLOAD, +                                    "https://example.com", +                                    "End of URL") +        expected_xss_info = xss.XSSData('https://example.com', +                                        "End of URL", +                                        '</script><script>alert(0)</script><script>', +                                        xss.FULL_PAYLOAD.decode('utf-8')) +        assert xss_info == expected_xss_info +        xss_info = xss.get_XSS_data(b"<html><script>%s</script><html>" % +                                    xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b'"', b"%22"), +                                    "https://example.com", +                                    "End of URL") +        expected_xss_info = xss.XSSData("https://example.com", +                                        "End of URL", +                                        '</script><script>alert(0)</script><script>', +                                        xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b'"', b"%22").decode('utf-8')) +        assert xss_info == expected_xss_info +        # Non-Exploitable: +        xss_info = xss.get_XSS_data(b"<html><script>%s</script><html>" % +                                    xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b'"', b"%22").replace(b"/", b"%2F"), +                                    "https://example.com", +                                    "End of URL") +        assert xss_info is None +        # Second type of exploit: <script>t='PAYLOAD'</script> +        # Exploitable: +        xss_info = xss.get_XSS_data(b"<html><script>t='%s';</script></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").replace(b"\"", b"%22"), +                                    "https://example.com", +                                    "End of URL") +        expected_xss_info = xss.XSSData("https://example.com", +                                        "End of URL", +                                        "';alert(0);g='", +                                        xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") +                                        .replace(b"\"", b"%22").decode('utf-8')) +        assert xss_info == expected_xss_info +        # Non-Exploitable: +        xss_info = xss.get_XSS_data(b"<html><script>t='%s';</script></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b"\"", b"%22").replace(b"'", b"%22"), +                                    "https://example.com", +                                    "End of URL") +        assert xss_info is None +        # Third type of exploit: <script>t="PAYLOAD"</script> +        # Exploitable: +        xss_info = xss.get_XSS_data(b"<html><script>t=\"%s\";</script></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").replace(b"'", b"%27"), +                                    "https://example.com", +                                    "End of URL") +        expected_xss_info = xss.XSSData("https://example.com", +                                        "End of URL", +                                        '";alert(0);g="', +                                        xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") +                                        .replace(b"'", b"%27").decode('utf-8')) +        assert xss_info == expected_xss_info +        # Non-Exploitable: +        xss_info = xss.get_XSS_data(b"<html><script>t=\"%s\";</script></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b"'", b"%27").replace(b"\"", b"%22"), +                                    "https://example.com", +                                    "End of URL") +        assert xss_info is None +        # Fourth type of exploit: <a href='PAYLOAD'>Test</a> +        # Exploitable: +        xss_info = xss.get_XSS_data(b"<html><a href='%s'>Test</a></html>" % +                                    xss.FULL_PAYLOAD, +                                    "https://example.com", +                                    "End of URL") +        expected_xss_info = xss.XSSData("https://example.com", +                                        "End of URL", +                                        "'><script>alert(0)</script>", +                                        xss.FULL_PAYLOAD.decode('utf-8')) +        assert xss_info == expected_xss_info +        # Non-Exploitable: +        xss_info = xss.get_XSS_data(b"<html><a href='OtherStuff%s'>Test</a></html>" % +                                    xss.FULL_PAYLOAD.replace(b"'", b"%27"), +                                    "https://example.com", +                                    "End of URL") +        assert xss_info is None +        # Fifth type of exploit: <a href="PAYLOAD">Test</a> +        # Exploitable: +        xss_info = xss.get_XSS_data(b"<html><a href=\"%s\">Test</a></html>" % +                                    xss.FULL_PAYLOAD.replace(b"'", b"%27"), +                                    "https://example.com", +                                    "End of URL") +        expected_xss_info = xss.XSSData("https://example.com", +                                        "End of URL", +                                        "\"><script>alert(0)</script>", +                                        xss.FULL_PAYLOAD.replace(b"'", b"%27").decode('utf-8')) +        assert xss_info == expected_xss_info +        # Non-Exploitable: +        xss_info = xss.get_XSS_data(b"<html><a href=\"OtherStuff%s\">Test</a></html>" % +                                    xss.FULL_PAYLOAD.replace(b"'", b"%27").replace(b"\"", b"%22"), +                                    "https://example.com", +                                    "End of URL") +        assert xss_info is None +        # Sixth type of exploit: <a href=PAYLOAD>Test</a> +        # Exploitable: +        xss_info = xss.get_XSS_data(b"<html><a href=%s>Test</a></html>" % +                                    xss.FULL_PAYLOAD, +                                    "https://example.com", +                                    "End of URL") +        expected_xss_info = xss.XSSData("https://example.com", +                                        "End of URL", +                                        "><script>alert(0)</script>", +                                        xss.FULL_PAYLOAD.decode('utf-8')) +        assert xss_info == expected_xss_info +        # Non-Exploitable +        xss_info = xss.get_XSS_data(b"<html><a href=OtherStuff%s>Test</a></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") +                                    .replace(b"=", b"%3D"), +                                    "https://example.com", +                                    "End of URL") +        assert xss_info is None +        # Seventh type of exploit: <html>PAYLOAD</html> +        # Exploitable: +        xss_info = xss.get_XSS_data(b"<html><b>%s</b></html>" % +                                    xss.FULL_PAYLOAD, +                                    "https://example.com", +                                    "End of URL") +        expected_xss_info = xss.XSSData("https://example.com", +                                        "End of URL", +                                        "<script>alert(0)</script>", +                                        xss.FULL_PAYLOAD.decode('utf-8')) +        assert xss_info == expected_xss_info +        # Non-Exploitable +        xss_info = xss.get_XSS_data(b"<html><b>%s</b></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").replace(b"/", b"%2F"), +                                    "https://example.com", +                                    "End of URL") +        assert xss_info is None +        # Eighth type of exploit: <a href=PAYLOAD>Test</a> +        # Exploitable: +        xss_info = xss.get_XSS_data(b"<html><a href=%s>Test</a></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), +                                    "https://example.com", +                                    "End of URL") +        expected_xss_info = xss.XSSData("https://example.com", +                                        "End of URL", +                                        "Javascript:alert(0)", +                                        xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8')) +        assert xss_info == expected_xss_info +        # Non-Exploitable: +        xss_info = xss.get_XSS_data(b"<html><a href=OtherStuff%s>Test</a></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") +                                    .replace(b"=", b"%3D"), +                                    "https://example.com", +                                    "End of URL") +        assert xss_info is None +        # Ninth type of exploit: <a href="STUFF PAYLOAD">Test</a> +        # Exploitable: +        xss_info = xss.get_XSS_data(b"<html><a href=\"STUFF %s\">Test</a></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), +                                    "https://example.com", +                                    "End of URL") +        expected_xss_info = xss.XSSData("https://example.com", +                                        "End of URL", +                                        '" onmouseover="alert(0)" t="', +                                        xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8')) +        assert xss_info == expected_xss_info +        # Non-Exploitable: +        xss_info = xss.get_XSS_data(b"<html><a href=\"STUFF %s\">Test</a></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") +                                    .replace(b'"', b"%22"), +                                    "https://example.com", +                                    "End of URL") +        assert xss_info is None +        # Tenth type of exploit: <a href='STUFF PAYLOAD'>Test</a> +        # Exploitable: +        xss_info = xss.get_XSS_data(b"<html><a href='STUFF %s'>Test</a></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), +                                    "https://example.com", +                                    "End of URL") +        expected_xss_info = xss.XSSData("https://example.com", +                                        "End of URL", +                                        "' onmouseover='alert(0)' t='", +                                        xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8')) +        assert xss_info == expected_xss_info +        # Non-Exploitable: +        xss_info = xss.get_XSS_data(b"<html><a href='STUFF %s'>Test</a></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") +                                    .replace(b"'", b"%22"), +                                    "https://example.com", +                                    "End of URL") +        assert xss_info is None +        # Eleventh type of exploit: <a href=STUFF_PAYLOAD>Test</a> +        # Exploitable: +        xss_info = xss.get_XSS_data(b"<html><a href=STUFF%s>Test</a></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E"), +                                    "https://example.com", +                                    "End of URL") +        expected_xss_info = xss.XSSData("https://example.com", +                                        "End of URL", +                                        " onmouseover=alert(0) t=", +                                        xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E").decode('utf-8')) +        assert xss_info == expected_xss_info +        # Non-Exploitable: +        xss_info = xss.get_XSS_data(b"<html><a href=STUFF_%s>Test</a></html>" % +                                    xss.FULL_PAYLOAD.replace(b"<", b"%3C").replace(b">", b"%3E") +                                    .replace(b"=", b"%3D"), +                                    "https://example.com", +                                    "End of URL") +        assert xss_info is None + +    def test_get_SQLi_data(self): +        sqli_data = xss.get_SQLi_data("<html>SQL syntax MySQL</html>", +                                      "<html></html>", +                                      "https://example.com", +                                      "End of URL") +        expected_sqli_data = xss.SQLiData("https://example.com", +                                          "End of URL", +                                          "SQL syntax.*MySQL", +                                          "MySQL") +        assert sqli_data == expected_sqli_data +        sqli_data = xss.get_SQLi_data("<html>SQL syntax MySQL</html>", +                                      "<html>SQL syntax MySQL</html>", +                                      "https://example.com", +                                      "End of URL") +        assert sqli_data is None + +    def test_inside_quote(self): +        assert not xss.inside_quote("'", b"no", 0, b"no") +        assert xss.inside_quote("'", b"yes", 0, b"'yes'") +        assert xss.inside_quote("'", b"yes", 1, b"'yes'otherJunk'yes'more") +        assert not xss.inside_quote("'", b"longStringNotInIt", 1, b"short") + +    def test_paths_to_text(self): +        text = xss.paths_to_text("""<html><head><h1>STRING</h1></head> +                                    <script>STRING</script> +                                    <a href=STRING></a></html>""", "STRING") +        expected_text = ["/html/head/h1", "/html/script"] +        assert text == expected_text +        assert xss.paths_to_text("""<html></html>""", "STRING") == [] + +    def mocked_requests_vuln(*args, headers=None, cookies=None): +        class MockResponse: +            def __init__(self, html, headers=None, cookies=None): +                self.text = html +        return MockResponse("<html>%s</html>" % xss.FULL_PAYLOAD) + +    def mocked_requests_invuln(*args, headers=None, cookies=None): +        class MockResponse: +            def __init__(self, html, headers=None, cookies=None): +                self.text = html +        return MockResponse("<html></html>") + +    def test_test_end_of_url_injection(self, monkeypatch): +        monkeypatch.setattr(requests, 'get', self.mocked_requests_vuln) +        xss_info = xss.test_end_of_URL_injection("<html></html>", "https://example.com/index.html", {})[0] +        expected_xss_info = xss.XSSData('https://example.com/index.html/1029zxcs\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\eq=3847asd', +                                        'End of URL', +                                        '<script>alert(0)</script>', +                                        '1029zxcs\\\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\\\eq=3847asd') +        sqli_info = xss.test_end_of_URL_injection("<html></html>", "https://example.com/", {})[1] +        assert xss_info == expected_xss_info +        assert sqli_info is None + +    def test_test_referer_injection(self, monkeypatch): +        monkeypatch.setattr(requests, 'get', self.mocked_requests_vuln) +        xss_info = xss.test_referer_injection("<html></html>", "https://example.com/", {})[0] +        expected_xss_info = xss.XSSData('https://example.com/', +                                        'Referer', +                                        '<script>alert(0)</script>', +                                        '1029zxcs\\\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\\\eq=3847asd') +        sqli_info = xss.test_referer_injection("<html></html>", "https://example.com/", {})[1] +        assert xss_info == expected_xss_info +        assert sqli_info is None + +    def test_test_user_agent_injection(self, monkeypatch): +        monkeypatch.setattr(requests, 'get', self.mocked_requests_vuln) +        xss_info = xss.test_user_agent_injection("<html></html>", "https://example.com/", {})[0] +        expected_xss_info = xss.XSSData('https://example.com/', +                                        'User Agent', +                                        '<script>alert(0)</script>', +                                        '1029zxcs\\\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\\\eq=3847asd') +        sqli_info = xss.test_user_agent_injection("<html></html>", "https://example.com/", {})[1] +        assert xss_info == expected_xss_info +        assert sqli_info is None + +    def test_test_query_injection(self, monkeypatch): +        monkeypatch.setattr(requests, 'get', self.mocked_requests_vuln) +        xss_info = xss.test_query_injection("<html></html>", "https://example.com/vuln.php?cmd=ls", {})[0] +        expected_xss_info = xss.XSSData('https://example.com/vuln.php?cmd=1029zxcs\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\eq=3847asd', +                                        'Query', +                                        '<script>alert(0)</script>', +                                        '1029zxcs\\\'d"ao<ac>so[sb]po(pc)se;sl/bsl\\\\eq=3847asd') +        sqli_info = xss.test_query_injection("<html></html>", "https://example.com/vuln.php?cmd=ls", {})[1] +        assert xss_info == expected_xss_info +        assert sqli_info is None + +    @pytest.fixture +    def logger(self): +        class Logger(): +            def __init__(self): +                self.args = [] + +            def error(self, str): +                self.args.append(str) +        return Logger() + +    def test_find_unclaimed_URLs(self, monkeypatch, logger): +        logger.args = [] +        monkeypatch.setattr("mitmproxy.ctx.log", logger) +        xss.find_unclaimed_URLs("<html><script src=\"http://google.com\"></script></html>", +                                "https://example.com") +        assert logger.args == [] +        xss.find_unclaimed_URLs("<html><script src=\"http://unclaimedDomainName.com\"></script></html>", +                                "https://example.com") +        assert logger.args[0] == 'XSS found in https://example.com due to unclaimed URL "http://unclaimedDomainName.com" in script tag.' + +    def test_log_XSS_data(self, monkeypatch, logger): +        logger.args = [] +        monkeypatch.setattr("mitmproxy.ctx.log", logger) +        xss.log_XSS_data(None) +        assert logger.args == [] +        # self, url: str, injection_point: str, exploit: str, line: str +        xss.log_XSS_data(xss.XSSData('https://example.com', +                                     'Location', +                                     'String', +                                     'Line of HTML')) +        assert logger.args[0] == '===== XSS Found ====' +        assert logger.args[1] == 'XSS URL: https://example.com' +        assert logger.args[2] == 'Injection Point: Location' +        assert logger.args[3] == 'Suggested Exploit: String' +        assert logger.args[4] == 'Line: Line of HTML' + +    def test_log_SQLi_data(self, monkeypatch, logger): +        logger.args = [] +        monkeypatch.setattr("mitmproxy.ctx.log", logger) +        xss.log_SQLi_data(None) +        assert logger.args == [] +        xss.log_SQLi_data(xss.SQLiData(b'https://example.com', +                                       b'Location', +                                       b'Oracle.*Driver', +                                       b'Oracle')) +        assert logger.args[0] == '===== SQLi Found =====' +        assert logger.args[1] == 'SQLi URL: https://example.com' +        assert logger.args[2] == 'Injection Point: Location' +        assert logger.args[3] == 'Regex used: Oracle.*Driver' + +    def test_get_cookies(self): +        mocked_req = tutils.treq() +        mocked_req.cookies = [("cookieName2", "cookieValue2")] +        mocked_flow = tflow.tflow(req=mocked_req) +        # It only uses the request cookies +        assert xss.get_cookies(mocked_flow) == {"cookieName2": "cookieValue2"} + +    def test_response(self, monkeypatch, logger): +        logger.args = [] +        monkeypatch.setattr("mitmproxy.ctx.log", logger) +        monkeypatch.setattr(requests, 'get', self.mocked_requests_invuln) +        mocked_flow = tflow.tflow(req=tutils.treq(path=b"index.html?q=1"), resp=tutils.tresp(content=b'<html></html>')) +        xss.response(mocked_flow) +        assert logger.args == [] + +    def test_data_equals(self): +        xssData = xss.XSSData("a", "b", "c", "d") +        sqliData = xss.SQLiData("a", "b", "c", "d") +        assert xssData == xssData +        assert sqliData == sqliData diff --git a/test/mitmproxy/net/test_check.py b/test/mitmproxy/net/test_check.py index 9dbc02e0..0ffd6b2e 100644 --- a/test/mitmproxy/net/test_check.py +++ b/test/mitmproxy/net/test_check.py @@ -11,3 +11,4 @@ def test_is_valid_host():      assert check.is_valid_host(b"one.two.")      # Allow underscore      assert check.is_valid_host(b"one_two") +    assert check.is_valid_host(b"::1") diff --git a/test/mitmproxy/net/test_socks.py b/test/mitmproxy/net/test_socks.py index e00dd410..fbd31ef4 100644 --- a/test/mitmproxy/net/test_socks.py +++ b/test/mitmproxy/net/test_socks.py @@ -3,7 +3,6 @@ from io import BytesIO  import pytest  from mitmproxy.net import socks -from mitmproxy.net import tcp  from mitmproxy.test import tutils @@ -176,7 +175,7 @@ def test_message_ipv6():      msg.to_file(out)      assert out.getvalue() == raw.getvalue()[:-2] -    assert msg.addr.host == ipv6_addr +    assert msg.addr[0] == ipv6_addr  def test_message_invalid_host(): @@ -196,6 +195,6 @@ def test_message_unknown_atyp():      with pytest.raises(socks.SocksError):          socks.Message.from_file(raw) -    m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050))) +    m = socks.Message(5, 1, 0x02, ("example.com", 5050))      with pytest.raises(socks.SocksError):          m.to_file(BytesIO()) diff --git a/test/mitmproxy/net/test_tcp.py b/test/mitmproxy/net/test_tcp.py index ff6362c8..cf010f6e 100644 --- a/test/mitmproxy/net/test_tcp.py +++ b/test/mitmproxy/net/test_tcp.py @@ -116,11 +116,11 @@ class TestServerBind(tservers.ServerTestBase):  class TestServerIPv6(tservers.ServerTestBase):      handler = EchoHandler -    addr = tcp.Address(("localhost", 0), use_ipv6=True) +    addr = ("::1", 0)      def test_echo(self):          testval = b"echo!\n" -        c = tcp.TCPClient(tcp.Address(("::1", self.port), use_ipv6=True)) +        c = tcp.TCPClient(("::1", self.port))          with c.connect():              c.wfile.write(testval)              c.wfile.flush() @@ -132,7 +132,7 @@ class TestEcho(tservers.ServerTestBase):      def test_echo(self):          testval = b"echo!\n" -        c = tcp.TCPClient(("127.0.0.1", self.port)) +        c = tcp.TCPClient(("localhost", self.port))          with c.connect():              c.wfile.write(testval)              c.wfile.flush() @@ -602,12 +602,6 @@ class TestDHParams(tservers.ServerTestBase):              ret = c.get_current_cipher()              assert ret[0] == "DHE-RSA-AES256-SHA" -    def test_create_dhparams(self): -        with tutils.tmpdir() as d: -            filename = os.path.join(d, "dhparam.pem") -            certs.CertStore.load_dhparam(filename) -            assert os.path.exists(filename) -  class TestTCPClient: @@ -783,18 +777,6 @@ class TestPeekSSL(TestPeek):              return conn.pop() -class TestAddress: -    def test_simple(self): -        a = tcp.Address(("localhost", 80), True) -        assert a.use_ipv6 -        b = tcp.Address(("foo.com", 80), True) -        assert not a == b -        c = tcp.Address(("localhost", 80), True) -        assert a == c -        assert not a != c -        assert repr(a) == "localhost:80" - -  class TestSSLKeyLogger(tservers.ServerTestBase):      handler = EchoHandler      ssl = dict( diff --git a/test/mitmproxy/net/tservers.py b/test/mitmproxy/net/tservers.py index 68a2caa0..ebe6d3eb 100644 --- a/test/mitmproxy/net/tservers.py +++ b/test/mitmproxy/net/tservers.py @@ -86,13 +86,13 @@ class _TServer(tcp.TCPServer):  class ServerTestBase:      ssl = None      handler = None -    addr = ("localhost", 0) +    addr = ("127.0.0.1", 0)      @classmethod      def setup_class(cls, **kwargs):          cls.q = queue.Queue()          s = cls.makeserver(**kwargs) -        cls.port = s.address.port +        cls.port = s.address[1]          cls.server = _ServerThread(s)          cls.server.start() diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py index cb9c0474..871d02fe 100644 --- a/test/mitmproxy/proxy/protocol/test_http2.py +++ b/test/mitmproxy/proxy/protocol/test_http2.py @@ -124,10 +124,10 @@ class _Http2TestBase:              b'CONNECT',              b'',              b'localhost', -            self.server.server.address.port, +            self.server.server.address[1],              b'/',              b'HTTP/1.1', -            [(b'host', b'localhost:%d' % self.server.server.address.port)], +            [(b'host', b'localhost:%d' % self.server.server.address[1])],              b'',          )))          client.wfile.flush() @@ -231,7 +231,7 @@ class TestSimple(_Http2Test):              client.wfile,              h2_conn,              headers=[ -                (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), +                (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),                  (':method', 'GET'),                  (':scheme', 'https'),                  (':path', '/'), @@ -272,75 +272,6 @@ class TestSimple(_Http2Test):  @requires_alpn -class TestForbiddenHeaders(_Http2Test): - -    @classmethod -    def handle_server_event(cls, event, h2_conn, rfile, wfile): -        if isinstance(event, h2.events.ConnectionTerminated): -            return False -        elif isinstance(event, h2.events.StreamEnded): -            import warnings -            with warnings.catch_warnings(): -                # Ignore UnicodeWarning: -                # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison -                # failed to convert both arguments to Unicode - interpreting -                # them as being unequal. -                #     elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: - -                warnings.simplefilter("ignore") - -                h2_conn.config.validate_outbound_headers = False -                h2_conn.send_headers(event.stream_id, [ -                    (':status', '200'), -                    ('keep-alive', 'foobar'), -                ]) -            h2_conn.send_data(event.stream_id, b'response body') -            h2_conn.end_stream(event.stream_id) -            wfile.write(h2_conn.data_to_send()) -            wfile.flush() -        return True - -    def test_forbidden_headers(self): -        client, h2_conn = self._setup_connection() - -        self._send_request( -            client.wfile, -            h2_conn, -            headers=[ -                (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), -                (':method', 'GET'), -                (':scheme', 'https'), -                (':path', '/'), -            ]) - -        done = False -        while not done: -            try: -                raw = b''.join(http2.read_raw_frame(client.rfile)) -                events = h2_conn.receive_data(raw) -            except exceptions.HttpException: -                print(traceback.format_exc()) -                assert False - -            client.wfile.write(h2_conn.data_to_send()) -            client.wfile.flush() - -            for event in events: -                if isinstance(event, h2.events.ResponseReceived): -                    assert 'keep-alive' not in event.headers -                elif isinstance(event, h2.events.StreamEnded): -                    done = True - -        h2_conn.close_connection() -        client.wfile.write(h2_conn.data_to_send()) -        client.wfile.flush() - -        assert len(self.master.state.flows) == 1 -        assert self.master.state.flows[0].response.status_code == 200 -        assert self.master.state.flows[0].response.headers['keep-alive'] == 'foobar' - - -@requires_alpn  class TestRequestWithPriority(_Http2Test):      @classmethod @@ -384,7 +315,7 @@ class TestRequestWithPriority(_Http2Test):              client.wfile,              h2_conn,              headers=[ -                (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), +                (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),                  (':method', 'GET'),                  (':scheme', 'https'),                  (':path', '/'), @@ -469,7 +400,7 @@ class TestPriority(_Http2Test):              client.wfile,              h2_conn,              headers=[ -                (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), +                (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),                  (':method', 'GET'),                  (':scheme', 'https'),                  (':path', '/'), @@ -527,7 +458,7 @@ class TestStreamResetFromServer(_Http2Test):              client.wfile,              h2_conn,              headers=[ -                (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), +                (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),                  (':method', 'GET'),                  (':scheme', 'https'),                  (':path', '/'), @@ -576,7 +507,7 @@ class TestBodySizeLimit(_Http2Test):              client.wfile,              h2_conn,              headers=[ -                (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), +                (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),                  (':method', 'GET'),                  (':scheme', 'https'),                  (':path', '/'), @@ -672,7 +603,7 @@ class TestPushPromise(_Http2Test):          client, h2_conn = self._setup_connection()          self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ -            (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), +            (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),              (':method', 'GET'),              (':scheme', 'https'),              (':path', '/'), @@ -728,7 +659,7 @@ class TestPushPromise(_Http2Test):          client, h2_conn = self._setup_connection()          self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ -            (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), +            (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),              (':method', 'GET'),              (':scheme', 'https'),              (':path', '/'), @@ -791,7 +722,7 @@ class TestConnectionLost(_Http2Test):          client, h2_conn = self._setup_connection()          self._send_request(client.wfile, h2_conn, stream_id=1, headers=[ -            (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), +            (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),              (':method', 'GET'),              (':scheme', 'https'),              (':path', '/'), @@ -848,7 +779,7 @@ class TestMaxConcurrentStreams(_Http2Test):              # this will exceed MAX_CONCURRENT_STREAMS on the server connection              # and cause mitmproxy to throttle stream creation to the server              self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[ -                (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), +                (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),                  (':method', 'GET'),                  (':scheme', 'https'),                  (':path', '/'), @@ -894,7 +825,7 @@ class TestConnectionTerminated(_Http2Test):          client, h2_conn = self._setup_connection()          self._send_request(client.wfile, h2_conn, headers=[ -            (':authority', "127.0.0.1:{}".format(self.server.server.address.port)), +            (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),              (':method', 'GET'),              (':scheme', 'https'),              (':path', '/'), diff --git a/test/mitmproxy/proxy/protocol/test_websocket.py b/test/mitmproxy/proxy/protocol/test_websocket.py index 4ea01d34..bac0e527 100644 --- a/test/mitmproxy/proxy/protocol/test_websocket.py +++ b/test/mitmproxy/proxy/protocol/test_websocket.py @@ -87,8 +87,8 @@ class _WebSocketTestBase:              "authority",              "CONNECT",              "", -            "localhost", -            self.server.server.address.port, +            "127.0.0.1", +            self.server.server.address[1],              "",              "HTTP/1.1",              content=b'') @@ -105,8 +105,8 @@ class _WebSocketTestBase:              "relative",              "GET",              "http", -            "localhost", -            self.server.server.address.port, +            "127.0.0.1", +            self.server.server.address[1],              "/ws",              "HTTP/1.1",              headers=http.Headers( diff --git a/test/mitmproxy/proxy/test_server.py b/test/mitmproxy/proxy/test_server.py index 46beea41..56b09b9a 100644 --- a/test/mitmproxy/proxy/test_server.py +++ b/test/mitmproxy/proxy/test_server.py @@ -17,7 +17,6 @@ from mitmproxy.net import socks  from mitmproxy import certs  from mitmproxy import exceptions  from mitmproxy.net.http import http1 -from mitmproxy.net.tcp import Address  from pathod import pathoc  from pathod import pathod @@ -581,7 +580,7 @@ class TestHttps2Http(tservers.ReverseProxyTest):      def get_options(cls):          opts = super().get_options()          s = parse_server_spec(opts.upstream_server) -        opts.upstream_server = "http://%s" % s.address +        opts.upstream_server = "http://{}:{}".format(s.address[0], s.address[1])          return opts      def pathoc(self, ssl, sni=None): @@ -740,7 +739,7 @@ class MasterRedirectRequest(tservers.TestMaster):              # This part should have no impact, but it should also not cause any exceptions.              addr = f.live.server_conn.address -            addr2 = Address(("127.0.0.1", self.redirect_port)) +            addr2 = ("127.0.0.1", self.redirect_port)              f.live.set_server(addr2)              f.live.set_server(addr) @@ -750,8 +749,8 @@ class MasterRedirectRequest(tservers.TestMaster):      @controller.handler      def response(self, f): -        f.response.content = bytes(f.client_conn.address.port) -        f.response.headers["server-conn-id"] = str(f.server_conn.source_address.port) +        f.response.content = bytes(f.client_conn.address[1]) +        f.response.headers["server-conn-id"] = str(f.server_conn.source_address[1])          super().response(f) diff --git a/test/mitmproxy/test_certs.py b/test/mitmproxy/test_certs.py index f1eff9ba..9bd3ad25 100644 --- a/test/mitmproxy/test_certs.py +++ b/test/mitmproxy/test_certs.py @@ -117,6 +117,12 @@ class TestCertStore:              ret = ca1.get_cert(b"foo.com", [])              assert ret[0].serial == dc[0].serial +    def test_create_dhparams(self): +        with tutils.tmpdir() as d: +            filename = os.path.join(d, "dhparam.pem") +            certs.CertStore.load_dhparam(filename) +            assert os.path.exists(filename) +  class TestDummyCert: @@ -127,9 +133,10 @@ class TestDummyCert:                  ca.default_privatekey,                  ca.default_ca,                  b"foo.com", -                [b"one.com", b"two.com", b"*.three.com"] +                [b"one.com", b"two.com", b"*.three.com", b"127.0.0.1"]              )              assert r.cn == b"foo.com" +            assert r.altnames == [b'one.com', b'two.com', b'*.three.com']              r = certs.dummy_cert(                  ca.default_privatekey, @@ -138,6 +145,7 @@ class TestDummyCert:                  []              )              assert r.cn is None +            assert r.altnames == []  class TestSSLCert: @@ -179,3 +187,20 @@ class TestSSLCert:              d = f.read()          s = certs.SSLCert.from_der(d)          assert s.cn + +    def test_state(self): +        with open(tutils.test_data.path("mitmproxy/net/data/text_cert"), "rb") as f: +            d = f.read() +        c = certs.SSLCert.from_pem(d) + +        c.get_state() +        c2 = c.copy() +        a = c.get_state() +        b = c2.get_state() +        assert a == b +        assert c == c2 +        assert c is not c2 + +        x = certs.SSLCert('') +        x.set_state(a) +        assert x == c diff --git a/test/mitmproxy/test_connections.py b/test/mitmproxy/test_connections.py index 777ab4dd..0083f57c 100644 --- a/test/mitmproxy/test_connections.py +++ b/test/mitmproxy/test_connections.py @@ -1 +1,210 @@ -# TODO: write tests +import socket +import os +import threading +import ssl +import OpenSSL +import pytest +from unittest import mock + +from mitmproxy import connections +from mitmproxy import exceptions +from mitmproxy.net import tcp +from mitmproxy.net.http import http1 +from mitmproxy.test import tflow +from mitmproxy.test import tutils +from .net import tservers +from pathod import test + + +class TestClientConnection: + +    def test_send(self): +        c = tflow.tclient_conn() +        c.send(b'foobar') +        c.send([b'foo', b'bar']) +        with pytest.raises(TypeError): +            c.send('string') +        with pytest.raises(TypeError): +            c.send(['string', 'not']) +        assert c.wfile.getvalue() == b'foobarfoobar' + +    def test_repr(self): +        c = tflow.tclient_conn() +        assert 'address:22' in repr(c) +        assert 'ALPN' in repr(c) +        assert 'TLS' not in repr(c) + +        c.alpn_proto_negotiated = None +        c.tls_established = True +        assert 'ALPN' not in repr(c) +        assert 'TLS' in repr(c) + +    def test_tls_established_property(self): +        c = tflow.tclient_conn() +        c.tls_established = True +        assert c.ssl_established +        assert c.tls_established +        c.tls_established = False +        assert not c.ssl_established +        assert not c.tls_established + +    def test_make_dummy(self): +        c = connections.ClientConnection.make_dummy(('foobar', 1234)) +        assert c.address == ('foobar', 1234) + +    def test_state(self): +        c = tflow.tclient_conn() +        assert connections.ClientConnection.from_state(c.get_state()).get_state() == \ +            c.get_state() + +        c2 = tflow.tclient_conn() +        c2.address = (c2.address[0], 4242) +        assert not c == c2 + +        c2.timestamp_start = 42 +        c.set_state(c2.get_state()) +        assert c.timestamp_start == 42 + +        c3 = c.copy() +        assert c3.get_state() == c.get_state() + + +class TestServerConnection: + +    def test_send(self): +        c = tflow.tserver_conn() +        c.send(b'foobar') +        c.send([b'foo', b'bar']) +        with pytest.raises(TypeError): +            c.send('string') +        with pytest.raises(TypeError): +            c.send(['string', 'not']) +        assert c.wfile.getvalue() == b'foobarfoobar' + +    def test_repr(self): +        c = tflow.tserver_conn() + +        c.sni = 'foobar' +        c.tls_established = True +        c.alpn_proto_negotiated = b'h2' +        assert 'address:22' in repr(c) +        assert 'ALPN' in repr(c) +        assert 'TLS: foobar' in repr(c) + +        c.sni = None +        c.tls_established = True +        c.alpn_proto_negotiated = None +        assert 'ALPN' not in repr(c) +        assert 'TLS' in repr(c) + +        c.sni = None +        c.tls_established = False +        assert 'TLS' not in repr(c) + +    def test_tls_established_property(self): +        c = tflow.tserver_conn() +        c.tls_established = True +        assert c.ssl_established +        assert c.tls_established +        c.tls_established = False +        assert not c.ssl_established +        assert not c.tls_established + +    def test_make_dummy(self): +        c = connections.ServerConnection.make_dummy(('foobar', 1234)) +        assert c.address == ('foobar', 1234) + +    def test_simple(self): +        d = test.Daemon() +        c = connections.ServerConnection((d.IFACE, d.port)) +        c.connect() +        f = tflow.tflow() +        f.server_conn = c +        f.request.path = "/p/200:da" + +        # use this protocol just to assemble - not for actual sending +        c.wfile.write(http1.assemble_request(f.request)) +        c.wfile.flush() + +        assert http1.read_response(c.rfile, f.request, 1000) +        assert d.last_log() + +        c.finish() +        d.shutdown() + +    def test_terminate_error(self): +        d = test.Daemon() +        c = connections.ServerConnection((d.IFACE, d.port)) +        c.connect() +        c.connection = mock.Mock() +        c.connection.recv = mock.Mock(return_value=False) +        c.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect) +        c.finish() +        d.shutdown() + +    def test_sni(self): +        c = connections.ServerConnection(('', 1234)) +        with pytest.raises(ValueError, matches='sni must be str, not '): +            c.establish_ssl(None, b'foobar') + + +class TestClientConnectionTLS: + +    @pytest.mark.parametrize("sni", [ +        None, +        "example.com" +    ]) +    def test_tls_with_sni(self, sni): +        address = ('127.0.0.1', 0) +        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) +        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) +        sock.bind(address) +        sock.listen() +        address = sock.getsockname() + +        def client_run(): +            ctx = ssl.create_default_context() +            ctx.check_hostname = False +            ctx.verify_mode = ssl.CERT_NONE +            s = socket.create_connection(address) +            s = ctx.wrap_socket(s, server_hostname=sni) +            s.send(b'foobar') +            s.shutdown(socket.SHUT_RDWR) +        threading.Thread(target=client_run).start() + +        connection, client_address = sock.accept() +        c = connections.ClientConnection(connection, client_address, None) + +        cert = tutils.test_data.path("mitmproxy/net/data/server.crt") +        key = OpenSSL.crypto.load_privatekey( +            OpenSSL.crypto.FILETYPE_PEM, +            open(tutils.test_data.path("mitmproxy/net/data/server.key"), "rb").read()) +        c.convert_to_ssl(cert, key) +        assert c.connected() +        assert c.sni == sni +        assert c.tls_established +        assert c.rfile.read(6) == b'foobar' +        c.finish() + + +class TestServerConnectionTLS(tservers.ServerTestBase): +    ssl = True + +    class handler(tcp.BaseHandler): +        def handle(self): +            self.finish() + +    @pytest.mark.parametrize("clientcert", [ +        None, +        tutils.test_data.path("mitmproxy/data/clientcert"), +        os.path.join(tutils.test_data.path("mitmproxy/data/clientcert"), "client.pem"), +    ]) +    def test_tls(self, clientcert): +        c = connections.ServerConnection(("127.0.0.1", self.port)) +        c.connect() +        c.establish_ssl(clientcert, "foo.com") +        assert c.connected() +        assert c.sni == "foo.com" +        assert c.tls_established +        c.close() +        c.finish() diff --git a/test/mitmproxy/test_eventsequence.py b/test/mitmproxy/test_eventsequence.py index fe0f92b3..871d4b9d 100644 --- a/test/mitmproxy/test_eventsequence.py +++ b/test/mitmproxy/test_eventsequence.py @@ -32,6 +32,8 @@ def test_websocket_flow(err):      assert len(f.messages) == 1      assert next(i) == ("websocket_message", f)      assert len(f.messages) == 2 +    assert next(i) == ("websocket_message", f) +    assert len(f.messages) == 3      if err:          assert next(i) == ("websocket_error", f)      assert next(i) == ("websocket_end", f) diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index a78e5f80..0ac3bfd6 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -2,160 +2,18 @@ import io  import pytest  from mitmproxy.test import tflow -from mitmproxy.net.http import Headers  import mitmproxy.io  from mitmproxy import flowfilter, options  from mitmproxy.contrib import tnetstring -from mitmproxy.exceptions import FlowReadException, Kill +from mitmproxy.exceptions import FlowReadException  from mitmproxy import flow  from mitmproxy import http -from mitmproxy import connections  from mitmproxy.proxy import ProxyConfig  from mitmproxy.proxy.server import DummyServer  from mitmproxy import master  from . import tservers -class TestHTTPFlow: - -    def test_copy(self): -        f = tflow.tflow(resp=True) -        f.get_state() -        f2 = f.copy() -        a = f.get_state() -        b = f2.get_state() -        del a["id"] -        del b["id"] -        assert a == b -        assert not f == f2 -        assert f is not f2 -        assert f.request.get_state() == f2.request.get_state() -        assert f.request is not f2.request -        assert f.request.headers == f2.request.headers -        assert f.request.headers is not f2.request.headers -        assert f.response.get_state() == f2.response.get_state() -        assert f.response is not f2.response - -        f = tflow.tflow(err=True) -        f2 = f.copy() -        assert f is not f2 -        assert f.request is not f2.request -        assert f.request.headers == f2.request.headers -        assert f.request.headers is not f2.request.headers -        assert f.error.get_state() == f2.error.get_state() -        assert f.error is not f2.error - -    def test_match(self): -        f = tflow.tflow(resp=True) -        assert not flowfilter.match("~b test", f) -        assert flowfilter.match(None, f) -        assert not flowfilter.match("~b test", f) - -        f = tflow.tflow(err=True) -        assert flowfilter.match("~e", f) - -        with pytest.raises(ValueError): -            flowfilter.match("~", f) - -    def test_backup(self): -        f = tflow.tflow() -        f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) -        f.request.content = b"foo" -        assert not f.modified() -        f.backup() -        f.request.content = b"bar" -        assert f.modified() -        f.revert() -        assert f.request.content == b"foo" - -    def test_backup_idempotence(self): -        f = tflow.tflow(resp=True) -        f.backup() -        f.revert() -        f.backup() -        f.revert() - -    def test_getset_state(self): -        f = tflow.tflow(resp=True) -        state = f.get_state() -        assert f.get_state() == http.HTTPFlow.from_state( -            state).get_state() - -        f.response = None -        f.error = flow.Error("error") -        state = f.get_state() -        assert f.get_state() == http.HTTPFlow.from_state( -            state).get_state() - -        f2 = f.copy() -        f2.id = f.id  # copy creates a different uuid -        assert f.get_state() == f2.get_state() -        assert not f == f2 -        f2.error = flow.Error("e2") -        assert not f == f2 -        f.set_state(f2.get_state()) -        assert f.get_state() == f2.get_state() - -    def test_kill(self): -        f = tflow.tflow() -        f.reply.handle() -        f.intercept() -        assert f.killable -        f.kill() -        assert not f.killable -        assert f.reply.value == Kill - -    def test_resume(self): -        f = tflow.tflow() -        f.reply.handle() -        f.intercept() -        assert f.reply.state == "taken" -        f.resume() -        assert f.reply.state == "committed" - -    def test_replace_unicode(self): -        f = tflow.tflow(resp=True) -        f.response.content = b"\xc2foo" -        f.replace(b"foo", u"bar") - -    def test_replace_no_content(self): -        f = tflow.tflow() -        f.request.content = None -        assert f.replace("foo", "bar") == 0 - -    def test_replace(self): -        f = tflow.tflow(resp=True) -        f.request.headers["foo"] = "foo" -        f.request.content = b"afoob" - -        f.response.headers["foo"] = "foo" -        f.response.content = b"afoob" - -        assert f.replace("foo", "bar") == 6 - -        assert f.request.headers["bar"] == "bar" -        assert f.request.content == b"abarb" -        assert f.response.headers["bar"] == "bar" -        assert f.response.content == b"abarb" - -    def test_replace_encoded(self): -        f = tflow.tflow(resp=True) -        f.request.content = b"afoob" -        f.request.encode("gzip") -        f.response.content = b"afoob" -        f.response.encode("gzip") - -        f.replace("foo", "bar") - -        assert f.request.raw_content != b"abarb" -        f.request.decode() -        assert f.request.raw_content == b"abarb" - -        assert f.response.raw_content != b"abarb" -        f.response.decode() -        assert f.response.raw_content == b"abarb" - -  class TestSerialize:      def _treader(self): @@ -307,88 +165,6 @@ class TestFlowMaster:          fm.shutdown() -class TestRequest: - -    def test_simple(self): -        f = tflow.tflow() -        r = f.request -        u = r.url -        r.url = u -        with pytest.raises(ValueError): -            setattr(r, "url", "") -        assert r.url == u -        r2 = r.copy() -        assert r.get_state() == r2.get_state() - -    def test_get_url(self): -        r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) - -        assert r.url == "http://address:22/path" - -        r.scheme = "https" -        assert r.url == "https://address:22/path" - -        r.host = "host" -        r.port = 42 -        assert r.url == "https://host:42/path" - -        r.host = "address" -        r.port = 22 -        assert r.url == "https://address:22/path" - -        assert r.pretty_url == "https://address:22/path" -        r.headers["Host"] = "foo.com:22" -        assert r.url == "https://address:22/path" -        assert r.pretty_url == "https://foo.com:22/path" - -    def test_replace(self): -        r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) -        r.path = "path/foo" -        r.headers["Foo"] = "fOo" -        r.content = b"afoob" -        assert r.replace("foo(?i)", "boo") == 4 -        assert r.path == "path/boo" -        assert b"foo" not in r.content -        assert r.headers["boo"] == "boo" - -    def test_constrain_encoding(self): -        r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) -        r.headers["accept-encoding"] = "gzip, oink" -        r.constrain_encoding() -        assert "oink" not in r.headers["accept-encoding"] - -        r.headers.set_all("accept-encoding", ["gzip", "oink"]) -        r.constrain_encoding() -        assert "oink" not in r.headers["accept-encoding"] - -    def test_get_content_type(self): -        resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) -        resp.headers = Headers(content_type="text/plain") -        assert resp.headers["content-type"] == "text/plain" - - -class TestResponse: - -    def test_simple(self): -        f = tflow.tflow(resp=True) -        resp = f.response -        resp2 = resp.copy() -        assert resp2.get_state() == resp.get_state() - -    def test_replace(self): -        r = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) -        r.headers["Foo"] = "fOo" -        r.content = b"afoob" -        assert r.replace("foo(?i)", "boo") == 3 -        assert b"foo" not in r.content -        assert r.headers["boo"] == "boo" - -    def test_get_content_type(self): -        resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) -        resp.headers = Headers(content_type="text/plain") -        assert resp.headers["content-type"] == "text/plain" - -  class TestError:      def test_getset_state(self): @@ -409,23 +185,4 @@ class TestError:      def test_repr(self):          e = flow.Error("yay")          assert repr(e) - - -class TestClientConnection: -    def test_state(self): -        c = tflow.tclient_conn() -        assert connections.ClientConnection.from_state(c.get_state()).get_state() == \ -            c.get_state() - -        c2 = tflow.tclient_conn() -        c2.address.address = (c2.address.host, 4242) -        assert not c == c2 - -        c2.timestamp_start = 42 -        c.set_state(c2.get_state()) -        assert c.timestamp_start == 42 - -        c3 = c.copy() -        assert c3.get_state() == c.get_state() - -        assert str(c) +        assert str(e) diff --git a/test/mitmproxy/test_flowfilter.py b/test/mitmproxy/test_flowfilter.py index bfce265e..46fff477 100644 --- a/test/mitmproxy/test_flowfilter.py +++ b/test/mitmproxy/test_flowfilter.py @@ -1,4 +1,5 @@  import io +import pytest  from unittest.mock import patch  from mitmproxy.test import tflow @@ -134,6 +135,12 @@ class TestMatchingHTTPFlow:          e = self.err()          assert self.q("~e", e) +    def test_fmarked(self): +        q = self.req() +        assert not self.q("~marked", q) +        q.marked = True +        assert self.q("~marked", q) +      def test_head(self):          q = self.req()          s = self.resp() @@ -221,6 +228,11 @@ class TestMatchingHTTPFlow:          assert not self.q("~src :99", q)          assert self.q("~src address:22", q) +        q.client_conn.address = None +        assert not self.q('~src address:22', q) +        q.client_conn = None +        assert not self.q('~src address:22', q) +      def test_dst(self):          q = self.req()          q.server_conn = tflow.tserver_conn() @@ -230,6 +242,11 @@ class TestMatchingHTTPFlow:          assert not self.q("~dst :99", q)          assert self.q("~dst address:22", q) +        q.server_conn.address = None +        assert not self.q('~dst address:22', q) +        q.server_conn = None +        assert not self.q('~dst address:22', q) +      def test_and(self):          s = self.resp()          assert self.q("~c 200 & ~h head", s) @@ -269,6 +286,7 @@ class TestMatchingTCPFlow:          f = self.flow()          assert self.q("~tcp", f)          assert not self.q("~http", f) +        assert not self.q("~websocket", f)      def test_ferr(self):          e = self.err() @@ -378,6 +396,87 @@ class TestMatchingTCPFlow:          assert not self.q("~u whatever", f) +class TestMatchingWebSocketFlow: + +    def flow(self): +        return tflow.twebsocketflow() + +    def err(self): +        return tflow.twebsocketflow(err=True) + +    def q(self, q, o): +        return flowfilter.parse(q)(o) + +    def test_websocket(self): +        f = self.flow() +        assert self.q("~websocket", f) +        assert not self.q("~tcp", f) +        assert not self.q("~http", f) + +    def test_ferr(self): +        e = self.err() +        assert self.q("~e", e) + +    def test_body(self): +        f = self.flow() + +        # Messages sent by client or server +        assert self.q("~b hello", f) +        assert self.q("~b me", f) +        assert not self.q("~b nonexistent", f) + +        # Messages sent by client +        assert self.q("~bq hello", f) +        assert not self.q("~bq me", f) +        assert not self.q("~bq nonexistent", f) + +        # Messages sent by server +        assert self.q("~bs me", f) +        assert not self.q("~bs hello", f) +        assert not self.q("~bs nonexistent", f) + +    def test_src(self): +        f = self.flow() +        assert self.q("~src address", f) +        assert not self.q("~src foobar", f) +        assert self.q("~src :22", f) +        assert not self.q("~src :99", f) +        assert self.q("~src address:22", f) + +    def test_dst(self): +        f = self.flow() +        f.server_conn = tflow.tserver_conn() +        assert self.q("~dst address", f) +        assert not self.q("~dst foobar", f) +        assert self.q("~dst :22", f) +        assert not self.q("~dst :99", f) +        assert self.q("~dst address:22", f) + +    def test_and(self): +        f = self.flow() +        f.server_conn = tflow.tserver_conn() +        assert self.q("~b hello & ~b me", f) +        assert not self.q("~src wrongaddress & ~b hello", f) +        assert self.q("(~src :22 & ~dst :22) & ~b hello", f) +        assert not self.q("(~src address:22 & ~dst :22) & ~b nonexistent", f) +        assert not self.q("(~src address:22 & ~dst :99) & ~b hello", f) + +    def test_or(self): +        f = self.flow() +        f.server_conn = tflow.tserver_conn() +        assert self.q("~b hello | ~b me", f) +        assert self.q("~src :22 | ~b me", f) +        assert not self.q("~src :99 | ~dst :99", f) +        assert self.q("(~src :22 | ~dst :22) | ~b me", f) + +    def test_not(self): +        f = self.flow() +        assert not self.q("! ~src :22", f) +        assert self.q("! ~src :99", f) +        assert self.q("!~src :99 !~src :99", f) +        assert not self.q("!~src :99 !~src :22", f) + +  class TestMatchingDummyFlow:      def flow(self): @@ -411,6 +510,8 @@ class TestMatchingDummyFlow:          assert not self.q("~e", f)          assert not self.q("~http", f) +        assert not self.q("~tcp", f) +        assert not self.q("~websocket", f)          assert not self.q("~h whatever", f)          assert not self.q("~hq whatever", f) @@ -440,3 +541,11 @@ def test_pyparsing_bug(extract_tb):      # The text is a string with leading and trailing whitespace stripped; if the source is not available it is None.      extract_tb.return_value = [("", 1, "test", None)]      assert flowfilter.parse("test") + + +def test_match(): +    with pytest.raises(ValueError): +        flowfilter.match('[foobar', None) + +    assert flowfilter.match(None, None) +    assert not flowfilter.match('foobar', None) diff --git a/test/mitmproxy/test_http.py b/test/mitmproxy/test_http.py index 777ab4dd..889eb0a7 100644 --- a/test/mitmproxy/test_http.py +++ b/test/mitmproxy/test_http.py @@ -1 +1,256 @@ -# TODO: write tests +import pytest + +from mitmproxy.test import tflow +from mitmproxy.net.http import Headers +import mitmproxy.io +from mitmproxy import flowfilter +from mitmproxy.exceptions import Kill +from mitmproxy import flow +from mitmproxy import http + + +class TestHTTPRequest: + +    def test_simple(self): +        f = tflow.tflow() +        r = f.request +        u = r.url +        r.url = u +        with pytest.raises(ValueError): +            setattr(r, "url", "") +        assert r.url == u +        r2 = r.copy() +        assert r.get_state() == r2.get_state() +        assert hash(r) + +    def test_get_url(self): +        r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) + +        assert r.url == "http://address:22/path" + +        r.scheme = "https" +        assert r.url == "https://address:22/path" + +        r.host = "host" +        r.port = 42 +        assert r.url == "https://host:42/path" + +        r.host = "address" +        r.port = 22 +        assert r.url == "https://address:22/path" + +        assert r.pretty_url == "https://address:22/path" +        r.headers["Host"] = "foo.com:22" +        assert r.url == "https://address:22/path" +        assert r.pretty_url == "https://foo.com:22/path" + +    def test_replace(self): +        r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) +        r.path = "path/foo" +        r.headers["Foo"] = "fOo" +        r.content = b"afoob" +        assert r.replace("foo(?i)", "boo") == 4 +        assert r.path == "path/boo" +        assert b"foo" not in r.content +        assert r.headers["boo"] == "boo" + +    def test_constrain_encoding(self): +        r = http.HTTPRequest.wrap(mitmproxy.test.tutils.treq()) +        r.headers["accept-encoding"] = "gzip, oink" +        r.constrain_encoding() +        assert "oink" not in r.headers["accept-encoding"] + +        r.headers.set_all("accept-encoding", ["gzip", "oink"]) +        r.constrain_encoding() +        assert "oink" not in r.headers["accept-encoding"] + +    def test_get_content_type(self): +        resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) +        resp.headers = Headers(content_type="text/plain") +        assert resp.headers["content-type"] == "text/plain" + + +class TestHTTPResponse: + +    def test_simple(self): +        f = tflow.tflow(resp=True) +        resp = f.response +        resp2 = resp.copy() +        assert resp2.get_state() == resp.get_state() + +    def test_replace(self): +        r = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) +        r.headers["Foo"] = "fOo" +        r.content = b"afoob" +        assert r.replace("foo(?i)", "boo") == 3 +        assert b"foo" not in r.content +        assert r.headers["boo"] == "boo" + +    def test_get_content_type(self): +        resp = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) +        resp.headers = Headers(content_type="text/plain") +        assert resp.headers["content-type"] == "text/plain" + + +class TestHTTPFlow: + +    def test_copy(self): +        f = tflow.tflow(resp=True) +        assert repr(f) +        f.get_state() +        f2 = f.copy() +        a = f.get_state() +        b = f2.get_state() +        del a["id"] +        del b["id"] +        assert a == b +        assert not f == f2 +        assert f is not f2 +        assert f.request.get_state() == f2.request.get_state() +        assert f.request is not f2.request +        assert f.request.headers == f2.request.headers +        assert f.request.headers is not f2.request.headers +        assert f.response.get_state() == f2.response.get_state() +        assert f.response is not f2.response + +        f = tflow.tflow(err=True) +        f2 = f.copy() +        assert f is not f2 +        assert f.request is not f2.request +        assert f.request.headers == f2.request.headers +        assert f.request.headers is not f2.request.headers +        assert f.error.get_state() == f2.error.get_state() +        assert f.error is not f2.error + +    def test_match(self): +        f = tflow.tflow(resp=True) +        assert not flowfilter.match("~b test", f) +        assert flowfilter.match(None, f) +        assert not flowfilter.match("~b test", f) + +        f = tflow.tflow(err=True) +        assert flowfilter.match("~e", f) + +        with pytest.raises(ValueError): +            flowfilter.match("~", f) + +    def test_backup(self): +        f = tflow.tflow() +        f.response = http.HTTPResponse.wrap(mitmproxy.test.tutils.tresp()) +        f.request.content = b"foo" +        assert not f.modified() +        f.backup() +        f.request.content = b"bar" +        assert f.modified() +        f.revert() +        assert f.request.content == b"foo" + +    def test_backup_idempotence(self): +        f = tflow.tflow(resp=True) +        f.backup() +        f.revert() +        f.backup() +        f.revert() + +    def test_getset_state(self): +        f = tflow.tflow(resp=True) +        state = f.get_state() +        assert f.get_state() == http.HTTPFlow.from_state( +            state).get_state() + +        f.response = None +        f.error = flow.Error("error") +        state = f.get_state() +        assert f.get_state() == http.HTTPFlow.from_state( +            state).get_state() + +        f2 = f.copy() +        f2.id = f.id  # copy creates a different uuid +        assert f.get_state() == f2.get_state() +        assert not f == f2 +        f2.error = flow.Error("e2") +        assert not f == f2 +        f.set_state(f2.get_state()) +        assert f.get_state() == f2.get_state() + +    def test_kill(self): +        f = tflow.tflow() +        f.reply.handle() +        f.intercept() +        assert f.killable +        f.kill() +        assert not f.killable +        assert f.reply.value == Kill + +    def test_resume(self): +        f = tflow.tflow() +        f.reply.handle() +        f.intercept() +        assert f.reply.state == "taken" +        f.resume() +        assert f.reply.state == "committed" + +    def test_replace_unicode(self): +        f = tflow.tflow(resp=True) +        f.response.content = b"\xc2foo" +        f.replace(b"foo", u"bar") + +    def test_replace_no_content(self): +        f = tflow.tflow() +        f.request.content = None +        assert f.replace("foo", "bar") == 0 + +    def test_replace(self): +        f = tflow.tflow(resp=True) +        f.request.headers["foo"] = "foo" +        f.request.content = b"afoob" + +        f.response.headers["foo"] = "foo" +        f.response.content = b"afoob" + +        assert f.replace("foo", "bar") == 6 + +        assert f.request.headers["bar"] == "bar" +        assert f.request.content == b"abarb" +        assert f.response.headers["bar"] == "bar" +        assert f.response.content == b"abarb" + +    def test_replace_encoded(self): +        f = tflow.tflow(resp=True) +        f.request.content = b"afoob" +        f.request.encode("gzip") +        f.response.content = b"afoob" +        f.response.encode("gzip") + +        f.replace("foo", "bar") + +        assert f.request.raw_content != b"abarb" +        f.request.decode() +        assert f.request.raw_content == b"abarb" + +        assert f.response.raw_content != b"abarb" +        f.response.decode() +        assert f.response.raw_content == b"abarb" + + +def test_make_error_response(): +    resp = http.make_error_response(543, 'foobar', Headers()) +    assert resp + + +def test_make_connect_request(): +    req = http.make_connect_request(('invalidhost', 1234)) +    assert req.first_line_format == 'authority' +    assert req.method == 'CONNECT' +    assert req.http_version == 'HTTP/1.1' + + +def test_make_connect_response(): +    resp = http.make_connect_response('foobar') +    assert resp.http_version == 'foobar' +    assert resp.status_code == 200 + + +def test_expect_continue_response(): +    assert http.expect_continue_response.http_version == 'HTTP/1.1' +    assert http.expect_continue_response.status_code == 100 diff --git a/test/mitmproxy/test_optmanager.py b/test/mitmproxy/test_optmanager.py index 65691fdf..161b0dcf 100644 --- a/test/mitmproxy/test_optmanager.py +++ b/test/mitmproxy/test_optmanager.py @@ -30,6 +30,14 @@ class TD2(TD):          super().__init__(three=three, **kwargs) +class TM(optmanager.OptManager): +    def __init__(self, one="one", two=["foo"], three=None): +        self.one = one +        self.two = two +        self.three = three +        super().__init__() + +  def test_defaults():      assert TD2.default("one") == "done"      assert TD2.default("two") == "dtwo" @@ -203,6 +211,9 @@ def test_serialize():      t = ""      o2.load(t) +    with pytest.raises(exceptions.OptionsError, matches='No such option: foobar'): +        o2.load("foobar: '123'") +  def test_serialize_defaults():      o = options.Options() @@ -224,13 +235,10 @@ def test_saving():          o.load_paths(dst)          assert o.three == "foo" - -class TM(optmanager.OptManager): -    def __init__(self, one="one", two=["foo"], three=None): -        self.one = one -        self.two = two -        self.three = three -        super().__init__() +        with open(dst, 'a') as f: +            f.write("foobar: '123'") +        with pytest.raises(exceptions.OptionsError, matches=''): +            o.load_paths(dst)  def test_merge(): diff --git a/test/mitmproxy/test_proxy.py b/test/mitmproxy/test_proxy.py index a14c851e..37cec57a 100644 --- a/test/mitmproxy/test_proxy.py +++ b/test/mitmproxy/test_proxy.py @@ -4,62 +4,17 @@ from unittest import mock  from OpenSSL import SSL  import pytest -from mitmproxy.test import tflow  from mitmproxy.tools import cmdline  from mitmproxy import options  from mitmproxy.proxy import ProxyConfig -from mitmproxy import connections  from mitmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler  from mitmproxy.proxy import config -from mitmproxy import exceptions -from pathod import test -from mitmproxy.net.http import http1  from mitmproxy.test import tutils  from ..conftest import skip_windows -class TestServerConnection: - -    def test_simple(self): -        self.d = test.Daemon() -        sc = connections.ServerConnection((self.d.IFACE, self.d.port)) -        sc.connect() -        f = tflow.tflow() -        f.server_conn = sc -        f.request.path = "/p/200:da" - -        # use this protocol just to assemble - not for actual sending -        sc.wfile.write(http1.assemble_request(f.request)) -        sc.wfile.flush() - -        assert http1.read_response(sc.rfile, f.request, 1000) -        assert self.d.last_log() - -        sc.finish() -        self.d.shutdown() - -    def test_terminate_error(self): -        self.d = test.Daemon() -        sc = connections.ServerConnection((self.d.IFACE, self.d.port)) -        sc.connect() -        sc.connection = mock.Mock() -        sc.connection.recv = mock.Mock(return_value=False) -        sc.connection.flush = mock.Mock(side_effect=exceptions.TcpDisconnect) -        sc.finish() -        self.d.shutdown() - -    def test_repr(self): -        sc = tflow.tserver_conn() -        assert "address:22" in repr(sc) -        assert "ssl" not in repr(sc) -        sc.ssl_established = True -        assert "ssl" in repr(sc) -        sc.sni = "foo" -        assert "foo" in repr(sc) - -  class MockParser(argparse.ArgumentParser):      """ @@ -160,7 +115,7 @@ class TestProxyServer:              ProxyServer(conf)      def test_err_2(self): -        conf = ProxyConfig(options.Options(listen_host="invalidhost")) +        conf = ProxyConfig(options.Options(listen_host="256.256.256.256"))          with pytest.raises(Exception, match="Error starting proxy server"):              ProxyServer(conf) diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py index 298fddcb..9a289ae5 100644 --- a/test/mitmproxy/tservers.py +++ b/test/mitmproxy/tservers.py @@ -98,13 +98,14 @@ class ProxyThread(threading.Thread):          threading.Thread.__init__(self)          self.tmaster = tmaster          self.name = "ProxyThread (%s:%s)" % ( -            tmaster.server.address.host, tmaster.server.address.port +            tmaster.server.address[0], +            tmaster.server.address[1],          )          controller.should_exit = False      @property      def port(self): -        return self.tmaster.server.address.port +        return self.tmaster.server.address[1]      @property      def tlog(self): diff --git a/web/src/js/components/FlowView/Details.jsx b/web/src/js/components/FlowView/Details.jsx index 10ec6553..a73abf37 100644 --- a/web/src/js/components/FlowView/Details.jsx +++ b/web/src/js/components/FlowView/Details.jsx @@ -26,7 +26,7 @@ export function ConnectionInfo({ conn }) {              <tbody>                  <tr key="address">                      <td>Address:</td> -                    <td>{conn.address.address.join(':')}</td> +                    <td>{conn.address.join(':')}</td>                  </tr>                  {conn.sni && (                      <tr key="sni"> diff --git a/web/src/js/filt/filt.peg b/web/src/js/filt/filt.peg index 989bfdd3..7122a1a5 100644 --- a/web/src/js/filt/filt.peg +++ b/web/src/js/filt/filt.peg @@ -106,7 +106,7 @@ function destination(regex){      function destinationFilter(flow){      return (!!flow.server_conn.address)             && -           regex.test(flow.server_conn.address.address[0] + ":" + flow.server_conn.address.address[1]); +           regex.test(flow.server_conn.address[0] + ":" + flow.server_conn.address[1]);      }      destinationFilter.desc = "destination address matches " + regex;      return destinationFilter; @@ -172,7 +172,7 @@ function source(regex){      function sourceFilter(flow){          return (!!flow.client_conn.address)                 && -               regex.test(flow.client_conn.address.address[0] + ":" + flow.client_conn.address.address[1]); +               regex.test(flow.client_conn.address[0] + ":" + flow.client_conn.address[1]);      }      sourceFilter.desc = "source address matches " + regex;      return sourceFilter;  | 
