184 files changed, 5104 insertions, 3941 deletions
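Most of the churn in this commit comes from two changes that recur throughout the diff: the CI matrices now treat Python 3.5 as the primary target (the AppVeyor and Travis hunks swap py27 and py35), and the inline-script API drops the `context` argument, so event handlers receive the flow or connection object directly and reach mitmproxy through `mitmproxy.ctx`. A minimal sketch of a script written against the new API, pieced together from the examples further down:

    from mitmproxy import ctx


    def start():
        # The old `context` parameter is gone; logging goes through mitmproxy.ctx.
        ctx.log("script loaded")


    def response(flow):
        # Handlers receive the HTTPFlow directly, as in examples/add_header.py below.
        flow.response.headers["newheader"] = "foo"

As before, such a script is loaded with mitmdump or mitmproxy via -s.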
diff --git a/.appveyor.yml b/.appveyor.yml index 339342ae..dae12978 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -30,7 +30,7 @@ test_script: deploy_script: ps: | if( - ($Env:TOXENV -match "py27") -and + ($Env:TOXENV -match "py35") -and (($Env:APPVEYOR_REPO_BRANCH -match "master") -or ($Env:APPVEYOR_REPO_TAG -match "true")) ) { pip install -U virtualenv @@ -41,7 +41,6 @@ deploy_script: cache: - C:\Users\appveyor\AppData\Local\pip\cache - - C:\projects\mitmproxy\.tox notifications: - provider: Slack diff --git a/.dockerignore b/.dockerignore deleted file mode 100644 index 6b8710a7..00000000 --- a/.dockerignore +++ /dev/null @@ -1 +0,0 @@ -.git diff --git a/.python-version b/.python-version deleted file mode 100644 index 2339c8bf..00000000 --- a/.python-version +++ /dev/null @@ -1,2 +0,0 @@ -2.7.11 -3.5.1 diff --git a/.sources/bootswatch.less b/.sources/bootswatch.less deleted file mode 100644 index f9e4b827..00000000 --- a/.sources/bootswatch.less +++ /dev/null @@ -1,171 +0,0 @@ -// Bootswatch.less -// Swatch: Journal -// Version: 2.0.4 -// ----------------------------------------------------- - -// TYPOGRAPHY -// ----------------------------------------------------- - -@import url('https://fonts.googleapis.com/css?family=Open+Sans:400,700'); - -h1, h2, h3, h4, h5, h6, .navbar .brand { - font-weight: 700; -} - -// SCAFFOLDING -// ----------------------------------------------------- - -a { - text-decoration: none; -} - -.nav a, .navbar .brand, .subnav a, a.btn, .dropdown-menu a { - text-decoration: none; -} - -// NAVBAR -// ----------------------------------------------------- - -.navbar { - - .navbar-inner { - @shadow: 0 2px 4px rgba(0,0,0,.25), inset 0 -1px 0 rgba(0,0,0,.1); - .box-shadow(@shadow); - border-top: 1px solid #E5E5E5; - .border-radius(0); - } - - .brand { - text-shadow: none; - - &:hover { - background-color: #EEEEEE; - } - } - - .navbar-text { - line-height: 68px; - } - - .nav > li > a { - text-shadow: none; - } - - .dropdown-menu { - .border-radius(0); - } - - .nav li.dropdown.active > .dropdown-toggle, - .nav li.dropdown.active > .dropdown-toggle:hover, - .nav li.dropdown.open > .dropdown-toggle, - .nav li.dropdown.active.open > .dropdown-toggle, - .nav li.dropdown.active.open > .dropdown-toggle:hover { - background-color: @grayLighter; - color: @linkColor; - } - - .nav li.dropdown .dropdown-toggle .caret, - .nav .open .caret, - .nav .open .dropdown-toggle:hover .caret { - border-top-color: @black; - opacity: 1; - } - - .nav-collapse.in .nav li > a:hover { - background-color: @grayLighter; - } - - .nav-collapse .nav li > a { - color: @textColor; - text-decoration: none; - font-weight: normal; - } - - .nav-collapse .navbar-form, - .nav-collapse .navbar-search { - border-color: transparent; - } - - .navbar-search .search-query, - .navbar-search .search-query:hover { - border: 1px solid @grayLighter; - color: @textColor; - .placeholder(@gray); - } -} - -div.subnav { - background-color: @bodyBackground; - background-image: none; - @shadow: 0 1px 2px rgba(0,0,0,.25); - .box-shadow(@shadow); - .border-radius(0); - - &.subnav-fixed { - top: @navbarHeight; - } - - .nav > li > a:hover, - .nav > .active > a, - .nav > .active > a:hover { - color: @textColor; - text-decoration: none; - font-weight: normal; - } - - .nav > li:first-child > a, - .nav > li:first-child > a:hover { - .border-radius(0); - } -} - -// BUTTONS -// ----------------------------------------------------- - -.btn-primary { - .buttonBackground(lighten(@linkColor, 5%), @linkColor); -} - -[class^="icon-"], 
[class*=" icon-"] { - vertical-align: -2px; -} - -// MODALS -// ----------------------------------------------------- - -.modal { - .border-radius(0px); - background: @bodyBackground; -} - -.modal-header { - border-bottom: none; -} - -.modal-header .close { - text-decoration: none; -} - -.modal-footer { - background: transparent; - .box-shadow(none); - border-top: none; -} - - -// MISC -// ----------------------------------------------------- - -code, pre, pre.prettyprint, .well { - background-color: @grayLighter; -} - -.hero-unit { - .box-shadow(inset 0 1px 1px rgba(0,0,0,.05)); - border: 1px solid rgba(0,0,0,.05); - .border-radius(0); -} - -.table-bordered, .well, .prettyprint { - .border-radius(0); -} diff --git a/.sources/make b/.sources/make deleted file mode 100644 index 94648859..00000000 --- a/.sources/make +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/sh -pygmentize -f html ../examples/test_context.py > ../pathod/templates/examples_context.html -pygmentize -f html ../examples/test_setup.py > ../pathod/templates/examples_setup.html -pygmentize -f html ../examples/test_setupall.py > ../pathod/templates/examples_setupall.html -pygmentize -f html ../examples/pathod_pathoc.py > ../pathod/templates/pathod_pathoc.html diff --git a/.sources/variables.less b/.sources/variables.less deleted file mode 100644 index 75ff5be6..00000000 --- a/.sources/variables.less +++ /dev/null @@ -1,208 +0,0 @@ -// Variables.less -// Variables to customize the look and feel of Bootstrap -// Swatch: Journal -// Version: 2.0.4 -// ----------------------------------------------------- - -// GLOBAL VALUES -// -------------------------------------------------- - - -// Grays -// ------------------------- -@black: #000; -@grayDarker: #222; -@grayDark: #333; -@gray: #888; -@grayLight: #999; -@grayLighter: #eee; -@white: #fff; - - -// Accent colors -// ------------------------- -@blue: #4380D3; -@blueDark: darken(@blue, 15%); -@green: #22B24C; -@red: #C00; -@yellow: #FCFADB; -@orange: #FF7F00; -@pink: #CC99CC; -@purple: #7a43b6; -@tan: #FFCA73; - - - -// Scaffolding -// ------------------------- -@bodyBackground: #FCFBFD; -@textColor: @grayDarker; - - -// Links -// ------------------------- -@linkColor: @blue; -@linkColorHover: @red; - - -// Typography -// ------------------------- -@sansFontFamily: 'Open Sans', "Helvetica Neue", Helvetica, Arial, sans-serif; -@serifFontFamily: Georgia, "Times New Roman", Times, serif; -@monoFontFamily: Menlo, Monaco, Consolas, "Courier New", monospace; - -@baseFontSize: 14px; -@baseFontFamily: @sansFontFamily; -@baseLineHeight: 18px; -@altFontFamily: @serifFontFamily; - -@headingsFontFamily: inherit; // empty to use BS default, @baseFontFamily -@headingsFontWeight: bold; // instead of browser default, bold -@headingsColor: inherit; // empty to use BS default, @textColor - - -// Tables -// ------------------------- -@tableBackground: transparent; // overall background-color -@tableBackgroundAccent: @grayLighter; // for striping -@tableBackgroundHover: #f5f5f5; // for hover -@tableBorder: #ddd; // table and cell border - - -// Buttons -// ------------------------- -@btnBackground: @white; -@btnBackgroundHighlight: darken(@white, 10%); -@btnBorder: darken(@white, 20%); - -@btnPrimaryBackground: @linkColor; -@btnPrimaryBackgroundHighlight: spin(@btnPrimaryBackground, 15%); - -@btnInfoBackground: #5bc0de; -@btnInfoBackgroundHighlight: #2f96b4; - -@btnSuccessBackground: #62c462; -@btnSuccessBackgroundHighlight: #51a351; - -@btnWarningBackground: lighten(@orange, 10%); 
-@btnWarningBackgroundHighlight: @orange; - -@btnDangerBackground: #ee5f5b; -@btnDangerBackgroundHighlight: #bd362f; - -@btnInverseBackground: @linkColor; -@btnInverseBackgroundHighlight: darken(@linkColor, 5%); - - -// Forms -// ------------------------- -@inputBackground: @white; -@inputBorder: #ccc; -@inputBorderRadius: 3px; -@inputDisabledBackground: @grayLighter; -@formActionsBackground: @grayLighter; - -// Dropdowns -// ------------------------- -@dropdownBackground: @bodyBackground; -@dropdownBorder: rgba(0,0,0,.2); -@dropdownLinkColor: @textColor; -@dropdownLinkColorHover: @textColor; -@dropdownLinkBackgroundHover: #eee; -@dropdownDividerTop: #e5e5e5; -@dropdownDividerBottom: @white; - - - -// COMPONENT VARIABLES -// -------------------------------------------------- - -// Z-index master list -// ------------------------- -// Used for a bird's eye view of components dependent on the z-axis -// Try to avoid customizing these :) -@zindexDropdown: 1000; -@zindexPopover: 1010; -@zindexTooltip: 1020; -@zindexFixedNavbar: 1030; -@zindexModalBackdrop: 1040; -@zindexModal: 1050; - - -// Sprite icons path -// ------------------------- -@iconSpritePath: "../img/glyphicons-halflings.png"; -@iconWhiteSpritePath: "../img/glyphicons-halflings-white.png"; - - -// Input placeholder text color -// ------------------------- -@placeholderText: @grayLight; - - -// Hr border color -// ------------------------- -@hrBorder: @grayLighter; - - -// Navbar -// ------------------------- -@navbarHeight: 50px; -@navbarBackground: @bodyBackground; -@navbarBackgroundHighlight: @bodyBackground; - -@navbarText: @textColor; -@navbarLinkColor: @linkColor; -@navbarLinkColorHover: @linkColor; -@navbarLinkColorActive: @navbarLinkColorHover; -@navbarLinkBackgroundHover: @grayLighter; -@navbarLinkBackgroundActive: @grayLighter; - -@navbarSearchBackground: lighten(@navbarBackground, 25%); -@navbarSearchBackgroundFocus: @white; -@navbarSearchBorder: darken(@navbarSearchBackground, 30%); -@navbarSearchPlaceholderColor: #ccc; -@navbarBrandColor: @blue; - - -// Hero unit -// ------------------------- -@heroUnitBackground: @grayLighter; -@heroUnitHeadingColor: inherit; -@heroUnitLeadColor: inherit; - - -// Form states and alerts -// ------------------------- -@warningText: #c09853; -@warningBackground: #fcf8e3; -@warningBorder: darken(spin(@warningBackground, -10), 3%); - -@errorText: #b94a48; -@errorBackground: #f2dede; -@errorBorder: darken(spin(@errorBackground, -10), 3%); - -@successText: #468847; -@successBackground: #dff0d8; -@successBorder: darken(spin(@successBackground, -10), 5%); - -@infoText: #3a87ad; -@infoBackground: #d9edf7; -@infoBorder: darken(spin(@infoBackground, -10), 7%); - - - -// GRID -// -------------------------------------------------- - -// Default 940px grid -// ------------------------- -@gridColumns: 12; -@gridColumnWidth: 60px; -@gridGutterWidth: 20px; -@gridRowWidth: (@gridColumns * @gridColumnWidth) + (@gridGutterWidth * (@gridColumns - 1)); - -// Fluid grid -// ------------------------- -@fluidGridColumnWidth: 6.382978723%; -@fluidGridGutterWidth: 2.127659574%; diff --git a/.travis.yml b/.travis.yml index 29d0897c..e832d058 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,7 +10,7 @@ addons: packages: - libssl-dev -env: +env: global: - CI_DEPS=codecov>=2.0.5 - CI_COMMANDS=codecov @@ -18,22 +18,20 @@ env: matrix: fast_finish: true include: - - python: 2.7 - env: TOXENV=py27 - - python: 2.7 - env: TOXENV=py27 NO_ALPN=1 + - python: 3.5 + env: TOXENV=lint +# - os: osx +# osx_image: xcode7.3 +# 
language: generic +# env: TOXENV=py35 - python: 3.5 env: TOXENV=py35 - python: 3.5 env: TOXENV=py35 NO_ALPN=1 - - language: generic - os: osx - osx_image: xcode7.1 - git: - depth: 9999999 + - python: 2.7 env: TOXENV=py27 - - python: 3.5 - env: TOXENV=lint + - python: 2.7 + env: TOXENV=py27 NO_ALPN=1 - python: 3.5 env: TOXENV=docs allow_failures: @@ -44,20 +42,26 @@ install: if [[ $TRAVIS_OS_NAME == "osx" ]] then brew update || brew update # try again if it fails - brew outdated openssl || brew upgrade openssl - brew install python + brew upgrade + brew reinstall openssl + brew reinstall pyenv + eval "$(pyenv init -)" + env PYTHON_CONFIGURE_OPTS="--enable-framework" pyenv install --skip-existing 3.5.2 + pyenv global 3.5.2 + pyenv shell 3.5.2 + pip install -U pip setuptools wheel virtualenv fi - pip install tox -script: set -o pipefail; tox -- --cov netlib --cov mitmproxy --cov pathod 2>&1 | grep -v Cryptography_locking_cb +script: set -o pipefail; python -m tox -- --cov netlib --cov mitmproxy --cov pathod 2>&1 | grep -v Cryptography_locking_cb after_success: - | if [[ $TRAVIS_OS_NAME == "osx" && $TRAVIS_PULL_REQUEST == "false" && ($TRAVIS_BRANCH == "master" || -n $TRAVIS_TAG) ]] then - pip install -U virtualenv - ./dev.sh - source venv/bin/activate + git fetch --unshallow + ./dev.sh 3.5 + source venv3.5/bin/activate pip install -e ./release python ./release/rtool.py bdist python ./release/rtool.py upload-snapshot --bdist --wheel @@ -73,6 +77,5 @@ notifications: cache: directories: - - $HOME/build/mitmproxy/mitmproxy/.tox - - $HOME/.cache/pip - $HOME/.pyenv + - $HOME/.cache/pip diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index a689ed5e..00000000 --- a/Dockerfile +++ /dev/null @@ -1,4 +0,0 @@ -FROM mitmproxy/base:latest-onbuild -EXPOSE 8080 -EXPOSE 8081 -VOLUME /certs @@ -52,9 +52,9 @@ If you want to contribute changes, keep on reading. Hacking ------- -To get started hacking on mitmproxy, make sure you have Python_ 2.7.x. with +To get started hacking on mitmproxy, make sure you have Python_ 3.5.x or above with virtualenv_ installed (you can find installation instructions for virtualenv -here_). Then do the following: +`here <http://virtualenv.readthedocs.org/en/latest/>`_). Then do the following: .. code-block:: text @@ -194,7 +194,6 @@ PR checks will fail and block merging. We are using this command to check for st .. _Python: https://www.python.org/ .. _virtualenv: http://virtualenv.readthedocs.org/en/latest/ -.. _here: http://virtualenv.readthedocs.org/en/latest/installation.html .. _autoenv: https://github.com/kennethreitz/autoenv .. _.env: https://github.com/mitmproxy/mitmproxy/blob/master/.env .. _Sphinx: http://sphinx-doc.org/ diff --git a/docs/dev/models.rst b/docs/dev/models.rst index 02f36f58..7260f1f7 100644 --- a/docs/dev/models.rst +++ b/docs/dev/models.rst @@ -56,8 +56,6 @@ Datastructures :special-members: :no-undoc-members: - .. autoclass:: decoded - .. automodule:: netlib.multidict .. autoclass:: MultiDictView diff --git a/docs/scripting/inlinescripts.rst b/docs/scripting/inlinescripts.rst index 1ee44972..bc9d5ff5 100644 --- a/docs/scripting/inlinescripts.rst +++ b/docs/scripting/inlinescripts.rst @@ -15,9 +15,7 @@ client: :caption: examples/add_header.py :language: python -The first argument to each event method is an instance of -:py:class:`~mitmproxy.script.ScriptContext` that lets the script interact with the global mitmproxy -state. 
The **response** event also gets an instance of :py:class:`~mitmproxy.models.HTTPFlow`, +All events that deal with an HTTP request get an instance of :py:class:`~mitmproxy.models.HTTPFlow`, which we can use to manipulate the response itself. We can now run this script using mitmdump or mitmproxy as follows: @@ -36,11 +34,6 @@ We encourage you to either browse them locally or on `GitHub`_. Events ------ -The ``context`` argument passed to each event method is always a -:py:class:`~mitmproxy.script.ScriptContext` instance. It is guaranteed to be the same object -for the scripts lifetime and is not shared between multiple inline scripts. You can safely use it -to store any form of state you require. - Script Lifecycle Events ^^^^^^^^^^^^^^^^^^^^^^^ @@ -155,8 +148,9 @@ The canonical API documentation is the code, which you can browse here, locally The main classes you will deal with in writing mitmproxy scripts are: -:py:class:`~mitmproxy.script.ScriptContext` - - A handle for interacting with mitmproxy's Flow Master from within scripts. +:py:class:`mitmproxy.flow.FlowMaster` + - The "heart" of mitmproxy, usually subclassed as :py:class:`mitmproxy.dump.DumpMaster` or + :py:class:`mitmproxy.console.ConsoleMaster`. :py:class:`~mitmproxy.models.ClientConnection` - Describes a client connection. :py:class:`~mitmproxy.models.ServerConnection` @@ -173,16 +167,7 @@ The main classes you will deal with in writing mitmproxy scripts are: - A dictionary-like object for managing HTTP headers. :py:class:`netlib.certutils.SSLCert` - Exposes information SSL certificates. -:py:class:`mitmproxy.flow.FlowMaster` - - The "heart" of mitmproxy, usually subclassed as :py:class:`mitmproxy.dump.DumpMaster` or - :py:class:`mitmproxy.console.ConsoleMaster`. - -Script Context --------------- -.. autoclass:: mitmproxy.script.ScriptContext - :members: - :undoc-members: Running scripts in parallel --------------------------- diff --git a/docs/tutorials/gamecenter.rst b/docs/tutorials/gamecenter.rst index 9dce5df8..d0d73b73 100644 --- a/docs/tutorials/gamecenter.rst +++ b/docs/tutorials/gamecenter.rst @@ -51,7 +51,7 @@ The contents of the submission are particularly interesting: <key>context</key> <integer>0</integer> <key>score-value</key> - <integer>0</integer> + <integer>55</integer> <key>timestamp</key> <integer>1363515361321</integer> </dict> diff --git a/examples/add_header.py b/examples/add_header.py index cf1b53cc..3e0b5f1e 100644 --- a/examples/add_header.py +++ b/examples/add_header.py @@ -1,2 +1,2 @@ -def response(context, flow): +def response(flow): flow.response.headers["newheader"] = "foo" diff --git a/examples/change_upstream_proxy.py b/examples/change_upstream_proxy.py index 34a6eece..49d5379f 100644 --- a/examples/change_upstream_proxy.py +++ b/examples/change_upstream_proxy.py @@ -14,7 +14,7 @@ def proxy_address(flow): return ("localhost", 8081) -def request(context, flow): +def request(flow): if flow.request.method == "CONNECT": # If the decision is done by domain, one could also modify the server address here. # We do it after CONNECT here to have the request data available as well. 
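Scripts that previously stashed state on the `context` object now keep it at module level, since handlers no longer receive a context; the filt.py, flowwriter.py and sslstrip.py hunks below all follow this pattern. A small, hypothetical sketch of the same idea (the request_counts counter is illustrative and not part of this diff):

    import collections

    from mitmproxy import ctx

    # Module-level state replaces the old per-script `context` attributes.
    request_counts = collections.defaultdict(int)


    def request(flow):
        request_counts[flow.request.host] += 1


    def done():
        for host, n in sorted(request_counts.items()):
            ctx.log("%s: %d requests" % (host, n))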
diff --git a/examples/custom_contentviews.py b/examples/custom_contentviews.py index 05ebeb69..5a63e2a0 100644 --- a/examples/custom_contentviews.py +++ b/examples/custom_contentviews.py @@ -11,7 +11,7 @@ class ViewPigLatin(contentviews.View): content_types = ["text/html"] def __call__(self, data, **metadata): - if strutils.isXML(data): + if strutils.is_xml(data): parser = lxml.etree.HTMLParser( strip_cdata=True, remove_blank_text=True @@ -20,7 +20,7 @@ class ViewPigLatin(contentviews.View): docinfo = d.getroottree().docinfo def piglify(src): - words = string.split(src) + words = src.split() ret = '' for word in words: idx = -1 @@ -62,9 +62,9 @@ class ViewPigLatin(contentviews.View): pig_view = ViewPigLatin() -def start(context): - context.add_contentview(pig_view) +def start(): + contentviews.add(pig_view) -def done(context): - context.remove_contentview(pig_view) +def done(): + contentviews.remove(pig_view) diff --git a/examples/dns_spoofing.py b/examples/dns_spoofing.py index 8d715f33..c020047f 100644 --- a/examples/dns_spoofing.py +++ b/examples/dns_spoofing.py @@ -28,7 +28,7 @@ import re parse_host_header = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$") -def request(context, flow): +def request(flow): if flow.client_conn.ssl_established: flow.request.scheme = "https" sni = flow.client_conn.connection.get_servername() diff --git a/examples/dup_and_replay.py b/examples/dup_and_replay.py index 9ba91d3b..b47bf951 100644 --- a/examples/dup_and_replay.py +++ b/examples/dup_and_replay.py @@ -1,4 +1,7 @@ -def request(context, flow): - f = context.duplicate_flow(flow) +from mitmproxy import master + + +def request(flow): + f = master.duplicate_flow(flow) f.request.path = "/changed" - context.replay_request(f) + master.replay_request(f, block=True, run_scripthooks=False) diff --git a/examples/fail_with_500.py b/examples/fail_with_500.py index aec85b50..9710f74a 100644 --- a/examples/fail_with_500.py +++ b/examples/fail_with_500.py @@ -1,3 +1,3 @@ -def response(context, flow): +def response(flow): flow.response.status_code = 500 flow.response.content = b"" diff --git a/examples/filt.py b/examples/filt.py index 1a423845..21744edd 100644 --- a/examples/filt.py +++ b/examples/filt.py @@ -3,14 +3,16 @@ import sys from mitmproxy import filt +state = {} -def start(context): + +def start(): if len(sys.argv) != 2: raise ValueError("Usage: -s 'filt.py FILTER'") - context.filter = filt.parse(sys.argv[1]) + state["filter"] = filt.parse(sys.argv[1]) -def response(context, flow): - if flow.match(context.filter): +def response(flow): + if flow.match(state["filter"]): print("Flow matches filter:") print(flow) diff --git a/examples/flowwriter.py b/examples/flowwriter.py index cb5ccb0d..07c7ca20 100644 --- a/examples/flowwriter.py +++ b/examples/flowwriter.py @@ -3,8 +3,10 @@ import sys from mitmproxy.flow import FlowWriter +state = {} -def start(context): + +def start(): if len(sys.argv) != 2: raise ValueError('Usage: -s "flowriter.py filename"') @@ -12,9 +14,9 @@ def start(context): f = sys.stdout else: f = open(sys.argv[1], "wb") - context.flow_writer = FlowWriter(f) + state["flow_writer"] = FlowWriter(f) -def response(context, flow): +def response(flow): if random.choice([True, False]): - context.flow_writer.add(flow) + state["flow_writer"].add(flow) diff --git a/examples/har_extractor.py b/examples/har_extractor.py index d6b50c21..76059d8e 100644 --- a/examples/har_extractor.py +++ b/examples/har_extractor.py @@ -2,6 +2,7 @@ This inline script utilizes harparser.HAR from 
https://github.com/JustusW/harparser to generate a HAR log object. """ +import mitmproxy.ctx import six import sys import pytz @@ -54,12 +55,24 @@ class _HARLog(HAR.log): return self.__page_list__ -def start(context): +class Context(object): + pass + +context = Context() + + +def start(): """ On start we create a HARLog instance. You will have to adapt this to suit your actual needs of HAR generation. As it will probably be necessary to cluster logs by IPs or reset them from time to time. """ + if sys.version_info >= (3, 0): + raise RuntimeError( + "har_extractor.py does not work on Python 3. " + "Please check out https://github.com/mitmproxy/mitmproxy/issues/1320 " + "if you want to help making this work again." + ) context.dump_file = None if len(sys.argv) > 1: context.dump_file = sys.argv[1] @@ -73,7 +86,7 @@ def start(context): context.seen_server = set() -def response(context, flow): +def response(flow): """ Called when a server response has been received. At the time of this message both a request and a response are present and completely done. @@ -127,7 +140,7 @@ def response(context, flow): for k, v in flow.request.query or {}] response_body_size = len(flow.response.content) - response_body_decoded_size = len(flow.response.get_decoded_content()) + response_body_decoded_size = len(flow.response.content) response_body_compression = response_body_decoded_size - response_body_size entry = HAR.entries({ @@ -195,7 +208,7 @@ def response(context, flow): context.HARLog.add(entry) -def done(context): +def done(): """ Called once on script shutdown, after any other events. """ @@ -206,17 +219,19 @@ def done(context): compressed_json_dump = context.HARLog.compress() if context.dump_file == '-': - context.log(pprint.pformat(json.loads(json_dump))) + mitmproxy.ctx.log(pprint.pformat(json.loads(json_dump))) elif context.dump_file.endswith('.zhar'): - file(context.dump_file, "w").write(compressed_json_dump) + with open(context.dump_file, "wb") as f: + f.write(compressed_json_dump) else: - file(context.dump_file, "w").write(json_dump) - context.log( + with open(context.dump_file, "wb") as f: + f.write(json_dump) + mitmproxy.ctx.log( "HAR log finished with %s bytes (%s bytes compressed)" % ( len(json_dump), len(compressed_json_dump) ) ) - context.log( + mitmproxy.ctx.log( "Compression rate is %s%%" % str( 100. * len(compressed_json_dump) / len(json_dump) ) diff --git a/examples/iframe_injector.py b/examples/iframe_injector.py index 9495da93..352c3c24 100644 --- a/examples/iframe_injector.py +++ b/examples/iframe_injector.py @@ -2,27 +2,27 @@ # (this script works best with --anticache) import sys from bs4 import BeautifulSoup -from mitmproxy.models import decoded +iframe_url = None -def start(context): + +def start(): if len(sys.argv) != 2: raise ValueError('Usage: -s "iframe_injector.py url"') - context.iframe_url = sys.argv[1] + global iframe_url + iframe_url = sys.argv[1] -def response(context, flow): - if flow.request.host in context.iframe_url: +def response(flow): + if flow.request.host in iframe_url: return - with decoded(flow.response): # Remove content encoding (gzip, ...) 
- html = BeautifulSoup(flow.response.content, "lxml") - if html.body: - iframe = html.new_tag( - "iframe", - src=context.iframe_url, - frameborder=0, - height=0, - width=0) - html.body.insert(0, iframe) - flow.response.content = str(html) - context.log("Iframe inserted.") + html = BeautifulSoup(flow.response.content, "lxml") + if html.body: + iframe = html.new_tag( + "iframe", + src=iframe_url, + frameborder=0, + height=0, + width=0) + html.body.insert(0, iframe) + flow.response.content = str(html).encode("utf8") diff --git a/examples/modify_form.py b/examples/modify_form.py index 3fe0cf96..b63a1586 100644 --- a/examples/modify_form.py +++ b/examples/modify_form.py @@ -1,4 +1,4 @@ -def request(context, flow): +def request(flow): if flow.request.urlencoded_form: flow.request.urlencoded_form["mitmproxy"] = "rocks" else: diff --git a/examples/modify_querystring.py b/examples/modify_querystring.py index b89e5c8d..ee8a89ad 100644 --- a/examples/modify_querystring.py +++ b/examples/modify_querystring.py @@ -1,2 +1,2 @@ -def request(context, flow): +def request(flow): flow.request.query["mitmproxy"] = "rocks" diff --git a/examples/modify_response_body.py b/examples/modify_response_body.py index 3034892e..b4632248 100644 --- a/examples/modify_response_body.py +++ b/examples/modify_response_body.py @@ -2,19 +2,20 @@ # (this script works best with --anticache) import sys -from mitmproxy.models import decoded +state = {} -def start(context): + +def start(): if len(sys.argv) != 3: raise ValueError('Usage: -s "modify_response_body.py old new"') # You may want to use Python's argparse for more sophisticated argument # parsing. - context.old, context.new = sys.argv[1], sys.argv[2] + state["old"], state["new"] = sys.argv[1].encode(), sys.argv[2].encode() -def response(context, flow): - with decoded(flow.response): # automatically decode gzipped responses. - flow.response.content = flow.response.content.replace( - context.old, - context.new) +def response(flow): + flow.response.content = flow.response.content.replace( + state["old"], + state["new"] + ) diff --git a/examples/nonblocking.py b/examples/nonblocking.py index 4609f389..b81478df 100644 --- a/examples/nonblocking.py +++ b/examples/nonblocking.py @@ -1,9 +1,10 @@ import time +import mitmproxy from mitmproxy.script import concurrent @concurrent # Remove this and see what happens -def request(context, flow): - context.log("handle request: %s%s" % (flow.request.host, flow.request.path)) +def request(flow): + mitmproxy.ctx.log("handle request: %s%s" % (flow.request.host, flow.request.path)) time.sleep(5) - context.log("start request: %s%s" % (flow.request.host, flow.request.path)) + mitmproxy.ctx.log("start request: %s%s" % (flow.request.host, flow.request.path)) diff --git a/examples/proxapp.py b/examples/proxapp.py index 613d3f8b..2935b587 100644 --- a/examples/proxapp.py +++ b/examples/proxapp.py @@ -4,6 +4,7 @@ instance, we're using the Flask framework (http://flask.pocoo.org/) to expose a single simplest-possible page. """ from flask import Flask +import mitmproxy app = Flask("proxapp") @@ -15,10 +16,10 @@ def hello_world(): # Register the app using the magic domain "proxapp" on port 80. Requests to # this domain and port combination will now be routed to the WSGI app instance. -def start(context): - context.app_registry.add(app, "proxapp", 80) +def start(): + mitmproxy.ctx.master.apps.add(app, "proxapp", 80) # SSL works too, but the magic domain needs to be resolvable from the mitmproxy machine due to mitmproxy's design. 
# mitmproxy will connect to said domain and use serve its certificate (unless --no-upstream-cert is set) # but won't send any data. - context.app_registry.add(app, "example.com", 443) + mitmproxy.ctx.master.apps.add(app, "example.com", 443) diff --git a/examples/redirect_requests.py b/examples/redirect_requests.py index d7db3f1c..36594bcd 100644 --- a/examples/redirect_requests.py +++ b/examples/redirect_requests.py @@ -5,7 +5,7 @@ from mitmproxy.models import HTTPResponse from netlib.http import Headers -def request(context, flow): +def request(flow): # pretty_host takes the "Host" header of the request into account, # which is useful in transparent mode where we usually only have the IP # otherwise. @@ -13,9 +13,10 @@ def request(context, flow): # Method 1: Answer with a locally generated response if flow.request.pretty_host.endswith("example.com"): resp = HTTPResponse( - "HTTP/1.1", 200, "OK", + b"HTTP/1.1", 200, b"OK", Headers(Content_Type="text/html"), - "helloworld") + b"helloworld" + ) flow.reply.send(resp) # Method 2: Redirect the request to a different server diff --git a/examples/sslstrip.py b/examples/sslstrip.py index 8dde8e3e..0be1f020 100644 --- a/examples/sslstrip.py +++ b/examples/sslstrip.py @@ -1,40 +1,36 @@ -from netlib.http import decoded import re from six.moves import urllib +# set of SSL/TLS capable hosts +secure_hosts = set() -def start(context): - # set of SSL/TLS capable hosts - context.secure_hosts = set() - -def request(context, flow): +def request(flow): flow.request.headers.pop('If-Modified-Since', None) flow.request.headers.pop('Cache-Control', None) # proxy connections to SSL-enabled hosts - if flow.request.pretty_host in context.secure_hosts: + if flow.request.pretty_host in secure_hosts: flow.request.scheme = 'https' flow.request.port = 443 -def response(context, flow): - with decoded(flow.response): - flow.request.headers.pop('Strict-Transport-Security', None) - flow.request.headers.pop('Public-Key-Pins', None) +def response(flow): + flow.request.headers.pop('Strict-Transport-Security', None) + flow.request.headers.pop('Public-Key-Pins', None) - # strip links in response body - flow.response.content = flow.response.content.replace('https://', 'http://') + # strip links in response body + flow.response.content = flow.response.content.replace('https://', 'http://') - # strip links in 'Location' header - if flow.response.headers.get('Location', '').startswith('https://'): - location = flow.response.headers['Location'] - hostname = urllib.parse.urlparse(location).hostname - if hostname: - context.secure_hosts.add(hostname) - flow.response.headers['Location'] = location.replace('https://', 'http://', 1) + # strip links in 'Location' header + if flow.response.headers.get('Location', '').startswith('https://'): + location = flow.response.headers['Location'] + hostname = urllib.parse.urlparse(location).hostname + if hostname: + secure_hosts.add(hostname) + flow.response.headers['Location'] = location.replace('https://', 'http://', 1) - # strip secure flag from 'Set-Cookie' headers - cookies = flow.response.headers.get_all('Set-Cookie') - cookies = [re.sub(r';\s*secure\s*', '', s) for s in cookies] - flow.response.headers.set_all('Set-Cookie', cookies) + # strip secure flag from 'Set-Cookie' headers + cookies = flow.response.headers.get_all('Set-Cookie') + cookies = [re.sub(r';\s*secure\s*', '', s) for s in cookies] + flow.response.headers.set_all('Set-Cookie', cookies) diff --git a/examples/stream.py b/examples/stream.py index 3adbe437..8598f329 100644 --- 
a/examples/stream.py +++ b/examples/stream.py @@ -1,4 +1,4 @@ -def responseheaders(context, flow): +def responseheaders(flow): """ Enables streaming for all responses. """ diff --git a/examples/stream_modify.py b/examples/stream_modify.py index aa395c03..5e5da95b 100644 --- a/examples/stream_modify.py +++ b/examples/stream_modify.py @@ -16,5 +16,5 @@ def modify(chunks): yield chunk.replace("foo", "bar") -def responseheaders(context, flow): +def responseheaders(flow): flow.response.stream = modify diff --git a/examples/stub.py b/examples/stub.py index a0f73538..e5b4a39a 100644 --- a/examples/stub.py +++ b/examples/stub.py @@ -1,79 +1,87 @@ +import mitmproxy """ This is a script stub, with definitions for all events. """ -def start(context): +def start(): """ - Called once on script startup, before any other events. + Called once on script startup before any other events """ - context.log("start") + mitmproxy.ctx.log("start") -def clientconnect(context, root_layer): +def configure(options): + """ + Called once on script startup before any other events, and whenever options changes. + """ + mitmproxy.ctx.log("configure") + + +def clientconnect(root_layer): """ Called when a client initiates a connection to the proxy. Note that a connection can correspond to multiple HTTP requests """ - context.log("clientconnect") + mitmproxy.ctx.log("clientconnect") -def request(context, flow): +def request(flow): """ Called when a client request has been received. """ - context.log("request") + mitmproxy.ctx.log("request") -def serverconnect(context, server_conn): +def serverconnect(server_conn): """ Called when the proxy initiates a connection to the target server. Note that a connection can correspond to multiple HTTP requests """ - context.log("serverconnect") + mitmproxy.ctx.log("serverconnect") -def responseheaders(context, flow): +def responseheaders(flow): """ Called when the response headers for a server response have been received, but the response body has not been processed yet. Can be used to tell mitmproxy to stream the response. """ - context.log("responseheaders") + mitmproxy.ctx.log("responseheaders") -def response(context, flow): +def response(flow): """ Called when a server response has been received. """ - context.log("response") + mitmproxy.ctx.log("response") -def error(context, flow): +def error(flow): """ Called when a flow error has occured, e.g. invalid server responses, or interrupted connections. This is distinct from a valid server HTTP error response, which is simply a response with an HTTP error code. """ - context.log("error") + mitmproxy.ctx.log("error") -def serverdisconnect(context, server_conn): +def serverdisconnect(server_conn): """ Called when the proxy closes the connection to the target server. """ - context.log("serverdisconnect") + mitmproxy.ctx.log("serverdisconnect") -def clientdisconnect(context, root_layer): +def clientdisconnect(root_layer): """ Called when a client disconnects from the proxy. """ - context.log("clientdisconnect") + mitmproxy.ctx.log("clientdisconnect") -def done(context): +def done(): """ Called once on script shutdown, after any other events. 
""" - context.log("done") + mitmproxy.ctx.log("done") diff --git a/examples/tcp_message.py b/examples/tcp_message.py index 78500c19..b431c23f 100644 --- a/examples/tcp_message.py +++ b/examples/tcp_message.py @@ -11,15 +11,17 @@ mitmdump -T --host --tcp ".*" -q -s examples/tcp_message.py from netlib import strutils -def tcp_message(ctx, tcp_msg): +def tcp_message(tcp_msg): modified_msg = tcp_msg.message.replace("foo", "bar") is_modified = False if modified_msg == tcp_msg.message else True tcp_msg.message = modified_msg - print("[tcp_message{}] from {} {} to {} {}:\r\n{}".format( - " (modified)" if is_modified else "", - "client" if tcp_msg.sender == tcp_msg.client_conn else "server", - tcp_msg.sender.address, - "server" if tcp_msg.receiver == tcp_msg.server_conn else "client", - tcp_msg.receiver.address, strutils.clean_bin(tcp_msg.message))) + print( + "[tcp_message{}] from {} {} to {} {}:\r\n{}".format( + " (modified)" if is_modified else "", + "client" if tcp_msg.sender == tcp_msg.client_conn else "server", + tcp_msg.sender.address, + "server" if tcp_msg.receiver == tcp_msg.server_conn else "client", + tcp_msg.receiver.address, strutils.bytes_to_escaped_str(tcp_msg.message)) + ) diff --git a/examples/tls_passthrough.py b/examples/tls_passthrough.py index 50aab65b..20e8f9be 100644 --- a/examples/tls_passthrough.py +++ b/examples/tls_passthrough.py @@ -20,13 +20,14 @@ Example: Authors: Maximilian Hils, Matthew Tuusberg """ -from __future__ import (absolute_import, print_function, division) +from __future__ import absolute_import, print_function, division import collections import random import sys from enum import Enum +import mitmproxy from mitmproxy.exceptions import TlsProtocolException from mitmproxy.protocol import TlsLayer, RawTCPLayer @@ -97,7 +98,6 @@ class TlsFeedback(TlsLayer): def _establish_tls_with_client(self): server_address = self.server_conn.address - tls_strategy = self.script_context.tls_strategy try: super(TlsFeedback, self)._establish_tls_with_client() @@ -110,15 +110,18 @@ class TlsFeedback(TlsLayer): # inline script hooks below. +tls_strategy = None -def start(context): + +def start(): + global tls_strategy if len(sys.argv) == 2: - context.tls_strategy = ProbabilisticStrategy(float(sys.argv[1])) + tls_strategy = ProbabilisticStrategy(float(sys.argv[1])) else: - context.tls_strategy = ConservativeStrategy() + tls_strategy = ConservativeStrategy() -def next_layer(context, next_layer): +def next_layer(next_layer): """ This hook does the actual magic - if the next layer is planned to be a TLS layer, we check if we want to enter pass-through mode instead. @@ -126,14 +129,13 @@ def next_layer(context, next_layer): if isinstance(next_layer, TlsLayer) and next_layer._client_tls: server_address = next_layer.server_conn.address - if context.tls_strategy.should_intercept(server_address): + if tls_strategy.should_intercept(server_address): # We try to intercept. # Monkey-Patch the layer to get feedback from the TLSLayer if interception worked. next_layer.__class__ = TlsFeedback - next_layer.script_context = context else: # We don't intercept - reply with a pass-through layer and add a "skipped" entry. 
- context.log("TLS passthrough for %s" % repr(next_layer.server_conn.address), "info") - next_layer_replacement = RawTCPLayer(next_layer.ctx, logging=False) + mitmproxy.ctx.log("TLS passthrough for %s" % repr(next_layer.server_conn.address), "info") + next_layer_replacement = RawTCPLayer(next_layer.ctx, ignore=True) next_layer.reply.send(next_layer_replacement) - context.tls_strategy.record_skipped(server_address) + tls_strategy.record_skipped(server_address) diff --git a/examples/upsidedownternet.py b/examples/upsidedownternet.py index 9aac9f05..d5059092 100644 --- a/examples/upsidedownternet.py +++ b/examples/upsidedownternet.py @@ -1,17 +1,15 @@ from six.moves import cStringIO as StringIO from PIL import Image -from mitmproxy.models import decoded -def response(context, flow): +def response(flow): if flow.response.headers.get("content-type", "").startswith("image"): - with decoded(flow.response): # automatically decode gzipped responses. - try: - s = StringIO(flow.response.content) - img = Image.open(s).rotate(180) - s2 = StringIO() - img.save(s2, "png") - flow.response.content = s2.getvalue() - flow.response.headers["content-type"] = "image/png" - except: # Unknown image types etc. - pass + try: + s = StringIO(flow.response.content) + img = Image.open(s).rotate(180) + s2 = StringIO() + img.save(s2, "png") + flow.response.content = s2.getvalue() + flow.response.headers["content-type"] = "image/png" + except: # Unknown image types etc. + pass diff --git a/mitmproxy/addons.py b/mitmproxy/addons.py new file mode 100644 index 00000000..c779aaf8 --- /dev/null +++ b/mitmproxy/addons.py @@ -0,0 +1,64 @@ +from __future__ import absolute_import, print_function, division +from mitmproxy import exceptions +import pprint + + +def _get_name(itm): + return getattr(itm, "name", itm.__class__.__name__) + + +class Addons(object): + def __init__(self, master): + self.chain = [] + self.master = master + master.options.changed.connect(self.options_update) + + def options_update(self, options): + for i in self.chain: + with self.master.handlecontext(): + i.configure(options) + + def add(self, *addons): + self.chain.extend(addons) + for i in addons: + self.invoke_with_context(i, "start") + self.invoke_with_context(i, "configure", self.master.options) + + def remove(self, addon): + self.chain = [i for i in self.chain if i is not addon] + self.invoke_with_context(addon, "done") + + def done(self): + for i in self.chain: + self.invoke_with_context(i, "done") + + def has_addon(self, name): + """ + Is an addon with this name registered? 
+ """ + for i in self.chain: + if _get_name(i) == name: + return True + + def __len__(self): + return len(self.chain) + + def __str__(self): + return pprint.pformat([str(i) for i in self.chain]) + + def invoke_with_context(self, addon, name, *args, **kwargs): + with self.master.handlecontext(): + self.invoke(addon, name, *args, **kwargs) + + def invoke(self, addon, name, *args, **kwargs): + func = getattr(addon, name, None) + if func: + if not callable(func): + raise exceptions.AddonError( + "Addon handler %s not callable" % name + ) + func(*args, **kwargs) + + def __call__(self, name, *args, **kwargs): + for i in self.chain: + self.invoke(i, name, *args, **kwargs) diff --git a/mitmproxy/builtins/__init__.py b/mitmproxy/builtins/__init__.py new file mode 100644 index 00000000..3974d736 --- /dev/null +++ b/mitmproxy/builtins/__init__.py @@ -0,0 +1,23 @@ +from __future__ import absolute_import, print_function, division + +from mitmproxy.builtins import anticache +from mitmproxy.builtins import anticomp +from mitmproxy.builtins import filestreamer +from mitmproxy.builtins import stickyauth +from mitmproxy.builtins import stickycookie +from mitmproxy.builtins import script +from mitmproxy.builtins import replace +from mitmproxy.builtins import setheaders + + +def default_addons(): + return [ + anticache.AntiCache(), + anticomp.AntiComp(), + stickyauth.StickyAuth(), + stickycookie.StickyCookie(), + script.ScriptLoader(), + filestreamer.FileStreamer(), + replace.Replace(), + setheaders.SetHeaders(), + ] diff --git a/mitmproxy/builtins/anticache.py b/mitmproxy/builtins/anticache.py new file mode 100644 index 00000000..f208e2fb --- /dev/null +++ b/mitmproxy/builtins/anticache.py @@ -0,0 +1,13 @@ +from __future__ import absolute_import, print_function, division + + +class AntiCache: + def __init__(self): + self.enabled = False + + def configure(self, options): + self.enabled = options.anticache + + def request(self, flow): + if self.enabled: + flow.request.anticache() diff --git a/mitmproxy/builtins/anticomp.py b/mitmproxy/builtins/anticomp.py new file mode 100644 index 00000000..50bd1b73 --- /dev/null +++ b/mitmproxy/builtins/anticomp.py @@ -0,0 +1,13 @@ +from __future__ import absolute_import, print_function, division + + +class AntiComp: + def __init__(self): + self.enabled = False + + def configure(self, options): + self.enabled = options.anticomp + + def request(self, flow): + if self.enabled: + flow.request.anticomp() diff --git a/mitmproxy/builtins/dumper.py b/mitmproxy/builtins/dumper.py new file mode 100644 index 00000000..239630fb --- /dev/null +++ b/mitmproxy/builtins/dumper.py @@ -0,0 +1,252 @@ +from __future__ import absolute_import, print_function, division + +import itertools +import traceback + +import click + +from mitmproxy import contentviews +from mitmproxy import ctx +from mitmproxy import exceptions +from mitmproxy import filt +from netlib import human +from netlib import strutils + + +def indent(n, text): + l = str(text).strip().splitlines() + pad = " " * n + return "\n".join(pad + i for i in l) + + +class Dumper(): + def __init__(self): + self.filter = None + self.flow_detail = None + self.outfp = None + self.showhost = None + + def echo(self, text, ident=None, **style): + if ident: + text = indent(ident, text) + click.secho(text, file=self.outfp, **style) + if self.outfp: + self.outfp.flush() + + def _echo_message(self, message): + if self.flow_detail >= 2 and hasattr(message, "headers"): + headers = "\r\n".join( + "{}: {}".format( + click.style( + 
strutils.bytes_to_escaped_str(k), fg="blue", bold=True + ), + click.style( + strutils.bytes_to_escaped_str(v), fg="blue" + ) + ) + for k, v in message.headers.fields + ) + self.echo(headers, ident=4) + if self.flow_detail >= 3: + try: + content = message.content + except ValueError: + content = message.get_content(strict=False) + + if content is None: + self.echo("(content missing)", ident=4) + elif content: + self.echo("") + + try: + type, lines = contentviews.get_content_view( + contentviews.get("Auto"), + content, + headers=getattr(message, "headers", None) + ) + except exceptions.ContentViewException: + s = "Content viewer failed: \n" + traceback.format_exc() + ctx.log.debug(s) + type, lines = contentviews.get_content_view( + contentviews.get("Raw"), + content, + headers=getattr(message, "headers", None) + ) + + styles = dict( + highlight=dict(bold=True), + offset=dict(fg="blue"), + header=dict(fg="green", bold=True), + text=dict(fg="green") + ) + + def colorful(line): + yield u" " # we can already indent here + for (style, text) in line: + yield click.style(text, **styles.get(style, {})) + + if self.flow_detail == 3: + lines_to_echo = itertools.islice(lines, 70) + else: + lines_to_echo = lines + + lines_to_echo = list(lines_to_echo) + + content = u"\r\n".join( + u"".join(colorful(line)) for line in lines_to_echo + ) + + self.echo(content) + if next(lines, None): + self.echo("(cut off)", ident=4, dim=True) + + if self.flow_detail >= 2: + self.echo("") + + def _echo_request_line(self, flow): + if flow.request.stickycookie: + stickycookie = click.style( + "[stickycookie] ", fg="yellow", bold=True + ) + else: + stickycookie = "" + + if flow.client_conn: + client = click.style( + strutils.escape_control_characters( + flow.client_conn.address.host + ), + bold=True + ) + elif flow.request.is_replay: + client = click.style("[replay]", fg="yellow", bold=True) + else: + client = "" + + method = flow.request.method + method_color = dict( + GET="green", + DELETE="red" + ).get(method.upper(), "magenta") + method = click.style( + strutils.escape_control_characters(method), + fg=method_color, + bold=True + ) + if self.showhost: + url = flow.request.pretty_url + else: + url = flow.request.url + url = click.style(strutils.escape_control_characters(url), bold=True) + + httpversion = "" + if flow.request.http_version not in ("HTTP/1.1", "HTTP/1.0"): + # We hide "normal" HTTP 1. 
+ httpversion = " " + flow.request.http_version + + line = "{stickycookie}{client} {method} {url}{httpversion}".format( + stickycookie=stickycookie, + client=client, + method=method, + url=url, + httpversion=httpversion + ) + self.echo(line) + + def _echo_response_line(self, flow): + if flow.response.is_replay: + replay = click.style("[replay] ", fg="yellow", bold=True) + else: + replay = "" + + code = flow.response.status_code + code_color = None + if 200 <= code < 300: + code_color = "green" + elif 300 <= code < 400: + code_color = "magenta" + elif 400 <= code < 600: + code_color = "red" + code = click.style( + str(code), + fg=code_color, + bold=True, + blink=(code == 418) + ) + reason = click.style( + strutils.escape_control_characters(flow.response.reason), + fg=code_color, + bold=True + ) + + if flow.response.raw_content is None: + size = "(content missing)" + else: + size = human.pretty_size(len(flow.response.raw_content)) + size = click.style(size, bold=True) + + arrows = click.style(" <<", bold=True) + + line = "{replay} {arrows} {code} {reason} {size}".format( + replay=replay, + arrows=arrows, + code=code, + reason=reason, + size=size + ) + self.echo(line) + + def echo_flow(self, f): + if f.request: + self._echo_request_line(f) + self._echo_message(f.request) + + if f.response: + self._echo_response_line(f) + self._echo_message(f.response) + + if f.error: + self.echo(" << {}".format(f.error.msg), bold=True, fg="red") + + def match(self, f): + if self.flow_detail == 0: + return False + if not self.filt: + return True + elif f.match(self.filt): + return True + return False + + def configure(self, options): + if options.filtstr: + self.filt = filt.parse(options.filtstr) + if not self.filt: + raise exceptions.OptionsError( + "Invalid filter expression: %s" % options.filtstr + ) + else: + self.filt = None + self.flow_detail = options.flow_detail + self.outfp = options.tfile + self.showhost = options.showhost + + def response(self, f): + if self.match(f): + self.echo_flow(f) + + def error(self, f): + if self.match(f): + self.echo_flow(f) + + def tcp_message(self, f): + # FIXME: Filter should be applied here + if self.options.flow_detail == 0: + return + message = f.messages[-1] + direction = "->" if message.from_client else "<-" + self.echo("{client} {direction} tcp {direction} {server}".format( + client=repr(f.client_conn.address), + server=repr(f.server_conn.address), + direction=direction, + )) + self._echo_message(message) diff --git a/mitmproxy/builtins/filestreamer.py b/mitmproxy/builtins/filestreamer.py new file mode 100644 index 00000000..97ddc7c4 --- /dev/null +++ b/mitmproxy/builtins/filestreamer.py @@ -0,0 +1,66 @@ +from __future__ import absolute_import, print_function, division +import os.path + +from mitmproxy import exceptions +from mitmproxy.flow import io + + +class FileStreamer: + def __init__(self): + self.stream = None + self.active_flows = set() # type: Set[models.Flow] + + def start_stream_to_path(self, path, mode, filt): + path = os.path.expanduser(path) + try: + f = open(path, mode) + except IOError as v: + return str(v) + self.stream = io.FilteredFlowWriter(f, filt) + self.active_flows = set() + + def configure(self, options): + # We're already streaming - stop the previous stream and restart + if self.stream: + self.done() + + if options.outfile: + filt = None + if options.get("filtstr"): + filt = filt.parse(options.filtstr) + if not filt: + raise exceptions.OptionsError( + "Invalid filter specification: %s" % options.filtstr + ) + path, mode = options.outfile 
+ if mode not in ("wb", "ab"): + raise exceptions.OptionsError("Invalid mode.") + err = self.start_stream_to_path(path, mode, filt) + if err: + raise exceptions.OptionsError(err) + + def tcp_open(self, flow): + if self.stream: + self.active_flows.add(flow) + + def tcp_close(self, flow): + if self.stream: + self.stream.add(flow) + self.active_flows.discard(flow) + + def response(self, flow): + if self.stream: + self.stream.add(flow) + self.active_flows.discard(flow) + + def request(self, flow): + if self.stream: + self.active_flows.add(flow) + + def done(self): + if self.stream: + for flow in self.active_flows: + self.stream.add(flow) + self.active_flows = set([]) + self.stream.fo.close() + self.stream = None diff --git a/mitmproxy/builtins/replace.py b/mitmproxy/builtins/replace.py new file mode 100644 index 00000000..83b96cee --- /dev/null +++ b/mitmproxy/builtins/replace.py @@ -0,0 +1,49 @@ +import re + +from mitmproxy import exceptions +from mitmproxy import filt + + +class Replace: + def __init__(self): + self.lst = [] + + def configure(self, options): + """ + .replacements is a list of tuples (fpat, rex, s): + + fpatt: a string specifying a filter pattern. + rex: a regular expression. + s: the replacement string + """ + lst = [] + for fpatt, rex, s in options.replacements: + cpatt = filt.parse(fpatt) + if not cpatt: + raise exceptions.OptionsError( + "Invalid filter pattern: %s" % fpatt + ) + try: + re.compile(rex) + except re.error as e: + raise exceptions.OptionsError( + "Invalid regular expression: %s - %s" % (rex, str(e)) + ) + lst.append((rex, s, cpatt)) + self.lst = lst + + def execute(self, f): + for rex, s, cpatt in self.lst: + if cpatt(f): + if f.response: + f.response.replace(rex, s) + else: + f.request.replace(rex, s) + + def request(self, flow): + if not flow.reply.acked: + self.execute(flow) + + def response(self, flow): + if not flow.reply.acked: + self.execute(flow) diff --git a/mitmproxy/builtins/script.py b/mitmproxy/builtins/script.py new file mode 100644 index 00000000..ab068e47 --- /dev/null +++ b/mitmproxy/builtins/script.py @@ -0,0 +1,186 @@ +from __future__ import absolute_import, print_function, division + +import contextlib +import os +import shlex +import sys +import threading +import traceback + +from mitmproxy import exceptions +from mitmproxy import controller +from mitmproxy import ctx + + +import watchdog.events +from watchdog.observers import polling + + +def parse_command(command): + """ + Returns a (path, args) tuple. + """ + if not command or not command.strip(): + raise exceptions.AddonError("Empty script command.") + # Windows: escape all backslashes in the path. + if os.name == "nt": # pragma: no cover + backslashes = shlex.split(command, posix=False)[0].count("\\") + command = command.replace("\\", "\\\\", backslashes) + args = shlex.split(command) # pragma: no cover + args[0] = os.path.expanduser(args[0]) + if not os.path.exists(args[0]): + raise exceptions.AddonError( + ("Script file not found: %s.\r\n" + "If your script path contains spaces, " + "make sure to wrap it in additional quotes, e.g. 
-s \"'./foo bar/baz.py' --args\".") % + args[0]) + elif os.path.isdir(args[0]): + raise exceptions.AddonError("Not a file: %s" % args[0]) + return args[0], args[1:] + + +@contextlib.contextmanager +def scriptenv(path, args): + oldargs = sys.argv + sys.argv = [path] + args + script_dir = os.path.dirname(os.path.abspath(path)) + sys.path.append(script_dir) + try: + yield + except Exception: + _, _, tb = sys.exc_info() + scriptdir = os.path.dirname(os.path.abspath(path)) + for i, s in enumerate(reversed(traceback.extract_tb(tb))): + tb = tb.tb_next + if not os.path.abspath(s[0]).startswith(scriptdir): + break + ctx.log.error("Script error: %s" % "".join(traceback.format_tb(tb))) + finally: + sys.argv = oldargs + sys.path.pop() + + +def load_script(path, args): + with open(path, "rb") as f: + try: + code = compile(f.read(), path, 'exec') + except SyntaxError as e: + ctx.log.error( + "Script error: %s line %s: %s" % ( + e.filename, e.lineno, e.msg + ) + ) + return + ns = {'__file__': os.path.abspath(path)} + with scriptenv(path, args): + exec(code, ns, ns) + return ns + + +class ReloadHandler(watchdog.events.FileSystemEventHandler): + def __init__(self, callback): + self.callback = callback + + def on_modified(self, event): + self.callback() + + def on_created(self, event): + self.callback() + + +class Script: + """ + An addon that manages a single script. + """ + def __init__(self, command): + self.name = command + + self.command = command + self.path, self.args = parse_command(command) + self.ns = None + self.observer = None + self.dead = False + + self.last_options = None + self.should_reload = threading.Event() + + for i in controller.Events: + if not hasattr(self, i): + def mkprox(): + evt = i + + def prox(*args, **kwargs): + self.run(evt, *args, **kwargs) + return prox + setattr(self, i, mkprox()) + + def run(self, name, *args, **kwargs): + # It's possible for ns to be un-initialised if we failed during + # configure + if self.ns is not None and not self.dead: + func = self.ns.get(name) + if func: + with scriptenv(self.path, self.args): + func(*args, **kwargs) + + def reload(self): + self.should_reload.set() + + def tick(self): + if self.should_reload.is_set(): + self.should_reload.clear() + ctx.log.info("Reloading script: %s" % self.name) + self.ns = load_script(self.path, self.args) + self.start() + self.configure(self.last_options) + else: + self.run("tick") + + def start(self): + self.ns = load_script(self.path, self.args) + self.run("start") + + def configure(self, options): + self.last_options = options + if not self.observer: + self.observer = polling.PollingObserver() + # Bind the handler to the real underlying master object + self.observer.schedule( + ReloadHandler(self.reload), + os.path.dirname(self.path) or "." + ) + self.observer.start() + self.run("configure", options) + + def done(self): + self.run("done") + self.dead = True + + +class ScriptLoader(): + """ + An addon that manages loading scripts from options. 
+ """ + def configure(self, options): + for s in options.scripts: + if options.scripts.count(s) > 1: + raise exceptions.OptionsError("Duplicate script: %s" % s) + + for a in ctx.master.addons.chain[:]: + if isinstance(a, Script) and a.name not in options.scripts: + ctx.log.info("Un-loading script: %s" % a.name) + ctx.master.addons.remove(a) + + current = {} + for a in ctx.master.addons.chain[:]: + if isinstance(a, Script): + current[a.name] = a + ctx.master.addons.chain.remove(a) + + for s in options.scripts: + if s in current: + ctx.master.addons.chain.append(current[s]) + else: + ctx.log.info("Loading script: %s" % s) + sc = Script(s) + ctx.master.addons.add(sc) diff --git a/mitmproxy/builtins/setheaders.py b/mitmproxy/builtins/setheaders.py new file mode 100644 index 00000000..6bda3f55 --- /dev/null +++ b/mitmproxy/builtins/setheaders.py @@ -0,0 +1,39 @@ +from mitmproxy import exceptions +from mitmproxy import filt + + +class SetHeaders: + def __init__(self): + self.lst = [] + + def configure(self, options): + """ + options.setheaders is a tuple of (fpatt, header, value) + + fpatt: String specifying a filter pattern. + header: Header name. + value: Header value string + """ + for fpatt, header, value in options.setheaders: + cpatt = filt.parse(fpatt) + if not cpatt: + raise exceptions.OptionsError( + "Invalid setheader filter pattern %s" % fpatt + ) + self.lst.append((fpatt, header, value, cpatt)) + + def run(self, f, hdrs): + for _, header, value, cpatt in self.lst: + if cpatt(f): + hdrs.pop(header, None) + for _, header, value, cpatt in self.lst: + if cpatt(f): + hdrs.add(header, value) + + def request(self, flow): + if not flow.reply.acked: + self.run(flow, flow.request.headers) + + def response(self, flow): + if not flow.reply.acked: + self.run(flow, flow.response.headers) diff --git a/mitmproxy/builtins/stickyauth.py b/mitmproxy/builtins/stickyauth.py new file mode 100644 index 00000000..1309911c --- /dev/null +++ b/mitmproxy/builtins/stickyauth.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import, print_function, division + +from mitmproxy import filt +from mitmproxy import exceptions + + +class StickyAuth: + def __init__(self): + # Compiled filter + self.flt = None + self.hosts = {} + + def configure(self, options): + if options.stickyauth: + flt = filt.parse(options.stickyauth) + if not flt: + raise exceptions.OptionsError( + "stickyauth: invalid filter expression: %s" % options.stickyauth + ) + self.flt = flt + + def request(self, flow): + host = flow.request.host + if "authorization" in flow.request.headers: + self.hosts[host] = flow.request.headers["authorization"] + elif flow.match(self.flt): + if host in self.hosts: + flow.request.headers["authorization"] = self.hosts[host] diff --git a/mitmproxy/builtins/stickycookie.py b/mitmproxy/builtins/stickycookie.py new file mode 100644 index 00000000..dc699bb4 --- /dev/null +++ b/mitmproxy/builtins/stickycookie.py @@ -0,0 +1,80 @@ +import collections +from six.moves import http_cookiejar +from netlib.http import cookies + +from mitmproxy import exceptions +from mitmproxy import filt + + +def ckey(attrs, f): + """ + Returns a (domain, port, path) tuple. 
+ """ + domain = f.request.host + path = "/" + if "domain" in attrs: + domain = attrs["domain"] + if "path" in attrs: + path = attrs["path"] + return (domain, f.request.port, path) + + +def domain_match(a, b): + if http_cookiejar.domain_match(a, b): + return True + elif http_cookiejar.domain_match(a, b.strip(".")): + return True + return False + + +class StickyCookie: + def __init__(self): + self.jar = collections.defaultdict(dict) + self.flt = None + + def configure(self, options): + if options.stickycookie: + flt = filt.parse(options.stickycookie) + if not flt: + raise exceptions.OptionsError( + "stickycookie: invalid filter expression: %s" % options.stickycookie + ) + self.flt = flt + + def response(self, flow): + if self.flt: + for name, (value, attrs) in flow.response.cookies.items(multi=True): + # FIXME: We now know that Cookie.py screws up some cookies with + # valid RFC 822/1123 datetime specifications for expiry. Sigh. + dom_port_path = ckey(attrs, flow) + + if domain_match(flow.request.host, dom_port_path[0]): + if cookies.is_expired(attrs): + # Remove the cookie from jar + self.jar[dom_port_path].pop(name, None) + + # If all cookies of a dom_port_path have been removed + # then remove it from the jar itself + if not self.jar[dom_port_path]: + self.jar.pop(dom_port_path, None) + else: + b = attrs.with_insert(0, name, value) + self.jar[dom_port_path][name] = b + + def request(self, flow): + if self.flt: + l = [] + if flow.match(self.flt): + for domain, port, path in self.jar.keys(): + match = [ + domain_match(flow.request.host, domain), + flow.request.port == port, + flow.request.path.startswith(path) + ] + if all(match): + c = self.jar[(domain, port, path)] + l.extend([cookies.format_cookie_header(c[name].items(multi=True)) for name in c.keys()]) + if l: + # FIXME: we need to formalise this... + flow.request.stickycookie = True + flow.request.headers["cookie"] = "; ".join(l) diff --git a/mitmproxy/cmdline.py b/mitmproxy/cmdline.py index 551fffa0..507ddfc7 100644 --- a/mitmproxy/cmdline.py +++ b/mitmproxy/cmdline.py @@ -284,8 +284,8 @@ def basic_options(parser): ) parser.add_argument( "-v", "--verbose", - action="store_const", dest="verbose", default=1, const=2, - help="Increase event log verbosity." + action="store_const", dest="verbose", default=2, const=3, + help="Increase log verbosity." ) outfile = parser.add_mutually_exclusive_group() outfile.add_argument( @@ -384,7 +384,7 @@ def proxy_options(parser): help=""" Generic TCP SSL proxy mode for all hosts that match the pattern. Similar to --ignore, but SSL connections are intercepted. The - communication contents are printed to the event log in verbose mode. + communication contents are printed to the log in verbose mode. """ ) group.add_argument( diff --git a/mitmproxy/console/common.py b/mitmproxy/console/common.py index b450c19d..5d15e0cd 100644 --- a/mitmproxy/console/common.py +++ b/mitmproxy/console/common.py @@ -4,10 +4,10 @@ import os import urwid import urwid.util +import six import netlib from mitmproxy import flow -from mitmproxy import models from mitmproxy import utils from mitmproxy.console import signals from netlib import human @@ -37,7 +37,7 @@ def is_keypress(k): """ Is this input event a keypress? 
""" - if isinstance(k, basestring): + if isinstance(k, six.string_types): return True @@ -108,7 +108,7 @@ def shortcuts(k): def fcol(s, attr): - s = unicode(s) + s = six.text_type(s) return ( "fixed", len(s), @@ -216,7 +216,7 @@ def save_data(path, data): if not path: return try: - with file(path, "wb") as f: + with open(path, "wb") as f: f.write(data) except IOError as v: signals.status_message.send(message=v.strerror) @@ -257,28 +257,30 @@ def copy_flow_format_data(part, scope, flow): else: data = "" if scope in ("q", "a"): - if flow.request.content is None: + request = flow.request.copy() + request.decode(strict=False) + if request.content is None: return None, "Request content is missing" - with models.decoded(flow.request): - if part == "h": - data += netlib.http.http1.assemble_request(flow.request) - elif part == "c": - data += flow.request.content - else: - raise ValueError("Unknown part: {}".format(part)) - if scope == "a" and flow.request.content and flow.response: + if part == "h": + data += netlib.http.http1.assemble_request(request) + elif part == "c": + data += request.content + else: + raise ValueError("Unknown part: {}".format(part)) + if scope == "a" and flow.request.raw_content and flow.response: # Add padding between request and response data += "\r\n" * 2 if scope in ("s", "a") and flow.response: - if flow.response.content is None: + response = flow.response.copy() + response.decode(strict=False) + if response.content is None: return None, "Response content is missing" - with models.decoded(flow.response): - if part == "h": - data += netlib.http.http1.assemble_response(flow.response) - elif part == "c": - data += flow.response.content - else: - raise ValueError("Unknown part: {}".format(part)) + if part == "h": + data += netlib.http.http1.assemble_response(response) + elif part == "c": + data += response.content + else: + raise ValueError("Unknown part: {}".format(part)) return data, False @@ -364,8 +366,8 @@ def ask_save_body(part, master, state, flow): "q" (request), "s" (response) or None (ask user if necessary). 
""" - request_has_content = flow.request and flow.request.content - response_has_content = flow.response and flow.response.content + request_has_content = flow.request and flow.request.raw_content + response_has_content = flow.response and flow.response.raw_content if part is None: # We first need to determine whether we want to save the request or the @@ -388,12 +390,12 @@ def ask_save_body(part, master, state, flow): elif part == "q" and request_has_content: ask_save_path( "Save request content", - flow.request.get_decoded_content() + flow.request.get_content(strict=False), ) elif part == "s" and response_has_content: ask_save_path( "Save response content", - flow.response.get_decoded_content() + flow.response.get_content(strict=False), ) else: signals.status_message.send(message="No content to save.") @@ -418,9 +420,9 @@ def format_flow(f, focus, extended=False, hostheader=False, marked=False): marked = marked, ) if f.response: - if f.response.content: - contentdesc = human.pretty_size(len(f.response.content)) - elif f.response.content is None: + if f.response.raw_content: + contentdesc = human.pretty_size(len(f.response.raw_content)) + elif f.response.raw_content is None: contentdesc = "[content missing]" else: contentdesc = "[no content]" diff --git a/mitmproxy/console/flowdetailview.py b/mitmproxy/console/flowdetailview.py index 2a493b90..0a03e1c4 100644 --- a/mitmproxy/console/flowdetailview.py +++ b/mitmproxy/console/flowdetailview.py @@ -71,7 +71,7 @@ def flowdetails(state, flow): parts.append( [ "Alt names", - ", ".join(c.altnames) + ", ".join(str(x) for x in c.altnames) ] ) text.extend( diff --git a/mitmproxy/console/flowlist.py b/mitmproxy/console/flowlist.py index 8c20c4b6..bc523874 100644 --- a/mitmproxy/console/flowlist.py +++ b/mitmproxy/console/flowlist.py @@ -44,11 +44,11 @@ footer = [ ] -class EventListBox(urwid.ListBox): +class LogBufferBox(urwid.ListBox): def __init__(self, master): self.master = master - urwid.ListBox.__init__(self, master.eventlist) + urwid.ListBox.__init__(self, master.logbuffer) def keypress(self, size, key): key = common.shortcuts(key) @@ -56,7 +56,7 @@ class EventListBox(urwid.ListBox): self.master.clear_events() key = None elif key == "G": - self.set_focus(len(self.master.eventlist) - 1) + self.set_focus(len(self.master.logbuffer) - 1) elif key == "g": self.set_focus(0) return urwid.ListBox.keypress(self, size, key) @@ -76,7 +76,7 @@ class BodyPile(urwid.Pile): [ FlowListBox(master), urwid.Frame( - EventListBox(master), + LogBufferBox(master), header = self.inactive_header ) ] @@ -118,7 +118,7 @@ class ConnectionItem(urwid.WidgetWrap): return common.format_flow( self.flow, self.f, - hostheader = self.master.showhost, + hostheader = self.master.options.showhost, marked=self.state.flow_marked(self.flow) ) @@ -151,7 +151,7 @@ class ConnectionItem(urwid.WidgetWrap): if k == "a": self.master.start_server_playback( [i.copy() for i in self.master.state.view], - self.master.killextra, self.master.rheaders, + self.master.options.kill, self.master.rheaders, False, self.master.nopop, self.master.options.replay_ignore_params, self.master.options.replay_ignore_content, @@ -161,7 +161,7 @@ class ConnectionItem(urwid.WidgetWrap): elif k == "t": self.master.start_server_playback( [self.flow.copy()], - self.master.killextra, self.master.rheaders, + self.master.options.kill, self.master.rheaders, False, self.master.nopop, self.master.options.replay_ignore_params, self.master.options.replay_ignore_content, @@ -317,11 +317,9 @@ class 
FlowListWalker(urwid.ListWalker): class FlowListBox(urwid.ListBox): def __init__(self, master): + # type: (mitmproxy.console.master.ConsoleMaster) -> None self.master = master - urwid.ListBox.__init__( - self, - FlowListWalker(master, master.state) - ) + super(FlowListBox, self).__init__(FlowListWalker(master, master.state)) def get_method_raw(self, k): if k: @@ -395,13 +393,13 @@ class FlowListBox(urwid.ListBox): elif key == "F": self.master.toggle_follow_flows() elif key == "W": - if self.master.stream: - self.master.stop_stream() + if self.master.options.outfile: + self.master.options.outfile = None else: signals.status_prompt_path.send( self, - prompt = "Stream flows to", - callback = self.master.start_stream_to_path + prompt="Stream flows to", + callback= lambda path: self.master.options.update(outfile=(path, "ab")) ) else: return urwid.ListBox.keypress(self, size, key) diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py index e9b23176..c85a9f73 100644 --- a/mitmproxy/console/flowview.py +++ b/mitmproxy/console/flowview.py @@ -110,7 +110,7 @@ class FlowViewHeader(urwid.WidgetWrap): f, False, extended=True, - hostheader=self.master.showhost + hostheader=self.master.options.showhost ) signals.flow_change.connect(self.sig_flow_change) @@ -120,7 +120,7 @@ class FlowViewHeader(urwid.WidgetWrap): flow, False, extended=True, - hostheader=self.master.showhost + hostheader=self.master.options.showhost ) @@ -176,7 +176,7 @@ class FlowView(tabs.Tabs): self.show() def content_view(self, viewmode, message): - if message.content is None: + if message.raw_content is None: msg, body = "", [urwid.Text([("error", "[content missing]")])] return msg, body else: @@ -200,20 +200,34 @@ class FlowView(tabs.Tabs): def _get_content_view(self, viewmode, message, max_lines, _): try: + content = message.content + if content != message.raw_content: + enc = "[decoded {}]".format( + message.headers.get("content-encoding") + ) + else: + enc = None + except ValueError: + content = message.raw_content + enc = "[cannot decode]" + try: query = None if isinstance(message, models.HTTPRequest): query = message.query description, lines = contentviews.get_content_view( - viewmode, message.content, headers=message.headers, query=query + viewmode, content, headers=message.headers, query=query ) except exceptions.ContentViewException: s = "Content viewer failed: \n" + traceback.format_exc() - signals.add_event(s, "error") + signals.add_log(s, "error") description, lines = contentviews.get_content_view( - contentviews.get("Raw"), message.content, headers=message.headers + contentviews.get("Raw"), content, headers=message.headers ) description = description.replace("Raw", "Couldn't parse: falling back to Raw") + if enc: + description = " ".join([enc, description]) + # Give hint that you have to tab for the response. if description == "No content" and isinstance(message, models.HTTPRequest): description = "No request content (press tab to view response)" @@ -257,7 +271,7 @@ class FlowView(tabs.Tabs): def conn_text(self, conn): if conn: txt = common.format_keyvals( - [(h + ":", v) for (h, v) in conn.headers.fields], + [(h + ":", v) for (h, v) in conn.headers.items(multi=True)], key = "header", val = "text" ) @@ -407,17 +421,16 @@ class FlowView(tabs.Tabs): ) ) if part == "r": - with models.decoded(message): - # Fix an issue caused by some editors when editing a - # request/response body. Many editors make it hard to save a - # file without a terminating newline on the last line. 
When - # editing message bodies, this can cause problems. For now, I - # just strip the newlines off the end of the body when we return - # from an editor. - c = self.master.spawn_editor(message.content or "") - message.content = c.rstrip("\n") + # Fix an issue caused by some editors when editing a + # request/response body. Many editors make it hard to save a + # file without a terminating newline on the last line. When + # editing message bodies, this can cause problems. For now, I + # just strip the newlines off the end of the body when we return + # from an editor. + c = self.master.spawn_editor(message.get_content(strict=False) or b"") + message.content = c.rstrip(b"\n") elif part == "f": - if not message.urlencoded_form and message.content: + if not message.urlencoded_form and message.raw_content: signals.status_prompt_onekey.send( prompt = "Existing body is not a URL-encoded form. Clear and edit?", keys = [ @@ -512,14 +525,10 @@ class FlowView(tabs.Tabs): signals.flow_change.send(self, flow = self.flow) def delete_body(self, t): - if t == "m": - val = None - else: - val = None if self.tab_offset == TAB_REQ: - self.flow.request.content = val + self.flow.request.content = None else: - self.flow.response.content = val + self.flow.response.content = None signals.flow_change.send(self, flow = self.flow) def keypress(self, size, key): @@ -681,10 +690,10 @@ class FlowView(tabs.Tabs): ) key = None elif key == "v": - if conn.content: + if conn.raw_content: t = conn.headers.get("content-type") if "EDITOR" in os.environ or "PAGER" in os.environ: - self.master.spawn_external_viewer(conn.content, t) + self.master.spawn_external_viewer(conn.get_content(strict=False), t) else: signals.status_message.send( message = "Error! Set $EDITOR or $PAGER." diff --git a/mitmproxy/console/grideditor.py b/mitmproxy/console/grideditor.py index 9fa51ccb..87700fd7 100644 --- a/mitmproxy/console/grideditor.py +++ b/mitmproxy/console/grideditor.py @@ -6,11 +6,12 @@ import re import urwid +from mitmproxy import exceptions from mitmproxy import filt -from mitmproxy import script -from mitmproxy import utils +from mitmproxy.builtins import script from mitmproxy.console import common from mitmproxy.console import signals +from netlib import strutils from netlib.http import cookies from netlib.http import user_agents @@ -55,7 +56,7 @@ class TextColumn: o = editor.walker.get_current_value() if o is not None: n = editor.master.spawn_editor(o.encode("string-escape")) - n = utils.clean_hanging_newline(n) + n = strutils.clean_hanging_newline(n) editor.walker.set_current_value(n, False) editor.walker._modified() elif key in ["enter"]: @@ -395,7 +396,7 @@ class GridEditor(urwid.WidgetWrap): if p: try: p = os.path.expanduser(p) - d = file(p, "rb").read() + d = open(p, "rb").read() self.walker.set_current_value(d, unescaped) self.walker._modified() except IOError as v: @@ -643,8 +644,8 @@ class ScriptEditor(GridEditor): def is_error(self, col, val): try: - script.Script.parse_command(val) - except script.ScriptException as e: + script.parse_command(val) + except exceptions.AddonError as e: return str(e) diff --git a/mitmproxy/console/master.py b/mitmproxy/console/master.py index 5fd51f4b..25a0b83f 100644 --- a/mitmproxy/console/master.py +++ b/mitmproxy/console/master.py @@ -14,12 +14,15 @@ import traceback import weakref import urwid +from typing import Optional # noqa +from mitmproxy import builtins from mitmproxy import contentviews from mitmproxy import controller from mitmproxy import exceptions from mitmproxy import flow 
from mitmproxy import script +from mitmproxy import utils from mitmproxy.console import flowlist from mitmproxy.console import flowview from mitmproxy.console import grideditor @@ -30,7 +33,7 @@ from mitmproxy.console import palettes from mitmproxy.console import signals from mitmproxy.console import statusbar from mitmproxy.console import window -from netlib import tcp +from netlib import tcp, strutils EVENTLOG_SIZE = 500 @@ -175,64 +178,37 @@ class ConsoleState(flow.State): self.add_flow_setting(flow, "marked", marked) -class Options(object): - attributes = [ - "app", - "app_domain", - "app_ip", - "anticache", - "anticomp", - "client_replay", - "eventlog", - "follow", - "keepserving", - "kill", - "intercept", - "limit", - "no_server", - "refresh_server_playback", - "rfile", - "scripts", - "showhost", - "replacements", - "rheaders", - "setheaders", - "server_replay", - "stickycookie", - "stickyauth", - "stream_large_bodies", - "verbosity", - "wfile", - "nopop", - "palette", - "palette_transparent", - "no_mouse", - "outfile", - ] - - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - for i in self.attributes: - if not hasattr(self, i): - setattr(self, i, None) +class Options(flow.options.Options): + def __init__( + self, + eventlog=False, # type: bool + follow=False, # type: bool + intercept=False, # type: bool + limit=None, # type: Optional[str] + palette=None, # type: Optional[str] + palette_transparent=False, # type: bool + no_mouse=False, # type: bool + **kwargs + ): + self.eventlog = eventlog + self.follow = follow + self.intercept = intercept + self.limit = limit + self.palette = palette + self.palette_transparent = palette_transparent + self.no_mouse = no_mouse + super(Options, self).__init__(**kwargs) class ConsoleMaster(flow.FlowMaster): palette = [] def __init__(self, server, options): - flow.FlowMaster.__init__(self, server, ConsoleState()) + flow.FlowMaster.__init__(self, options, server, ConsoleState()) self.stream_path = None - self.options = options - - if options.replacements: - for i in options.replacements: - self.replacehooks.add(*i) - - if options.setheaders: - for i in options.setheaders: - self.setheaders.add(*i) + # This line is just for type hinting + self.options = self.options # type: Options + self.options.errored.connect(self.options_error) r = self.set_intercept(options.intercept) if r: @@ -242,30 +218,14 @@ class ConsoleMaster(flow.FlowMaster): if options.limit: self.set_limit(options.limit) - r = self.set_stickycookie(options.stickycookie) - if r: - print("Sticky cookies error: {}".format(r), file=sys.stderr) - sys.exit(1) - - r = self.set_stickyauth(options.stickyauth) - if r: - print("Sticky auth error: {}".format(r), file=sys.stderr) - sys.exit(1) - self.set_stream_large_bodies(options.stream_large_bodies) - self.refresh_server_playback = options.refresh_server_playback - self.anticache = options.anticache - self.anticomp = options.anticomp - self.killextra = options.kill self.rheaders = options.rheaders self.nopop = options.nopop - self.showhost = options.showhost self.palette = options.palette self.palette_transparent = options.palette_transparent - self.eventlog = options.eventlog - self.eventlist = urwid.SimpleListWalker([]) + self.logbuffer = urwid.SimpleListWalker([]) self.follow = options.follow if options.client_replay: @@ -274,56 +234,49 @@ class ConsoleMaster(flow.FlowMaster): if options.server_replay: self.server_playback_path(options.server_replay) - if options.scripts: - for i in options.scripts: - try: - 
self.load_script(i) - except exceptions.ScriptException as e: - print("Script load error: {}".format(e), file=sys.stderr) - sys.exit(1) - - if options.outfile: - err = self.start_stream_to_path( - options.outfile[0], - options.outfile[1] - ) - if err: - print("Stream file error: {}".format(err), file=sys.stderr) - sys.exit(1) - self.view_stack = [] if options.app: self.start_app(self.options.app_host, self.options.app_port) + signals.call_in.connect(self.sig_call_in) signals.pop_view_state.connect(self.sig_pop_view_state) signals.push_view_state.connect(self.sig_push_view_state) - signals.sig_add_event.connect(self.sig_add_event) + signals.sig_add_log.connect(self.sig_add_log) + self.addons.add(*builtins.default_addons()) def __setattr__(self, name, value): self.__dict__[name] = value signals.update_settings.send(self) + def options_error(self, opts, exc): + signals.status_message.send( + message=str(exc), + expire=1 + ) + def load_script(self, command, use_reloader=True): # We default to using the reloader in the console ui. return super(ConsoleMaster, self).load_script(command, use_reloader) - def sig_add_event(self, sender, e, level): - needed = dict(error=0, info=1, debug=2).get(level, 1) - if self.options.verbosity < needed: + def sig_add_log(self, sender, e, level): + if self.options.verbosity < utils.log_tier(level): return if level == "error": + signals.status_message.send( + message = "Error: %s" % str(e) + ) e = urwid.Text(("error", str(e))) else: e = urwid.Text(str(e)) - self.eventlist.append(e) - if len(self.eventlist) > EVENTLOG_SIZE: - self.eventlist.pop(0) - self.eventlist.set_focus(len(self.eventlist) - 1) + self.logbuffer.append(e) + if len(self.logbuffer) > EVENTLOG_SIZE: + self.logbuffer.pop(0) + self.logbuffer.set_focus(len(self.logbuffer) - 1) - def add_event(self, e, level): - signals.add_event(e, level) + def add_log(self, e, level): + signals.add_log(e, level) def sig_call_in(self, sender, seconds, callback, args=()): def cb(*_): @@ -354,25 +307,25 @@ class ConsoleMaster(flow.FlowMaster): status, val = s.run(method, f) if val: if status: - signals.add_event("Method %s return: %s" % (method, val), "debug") + signals.add_log("Method %s return: %s" % (method, val), "debug") else: - signals.add_event( + signals.add_log( "Method %s error: %s" % (method, val[1]), "error") def run_script_once(self, command, f): if not command: return - signals.add_event("Running script on flow: %s" % command, "debug") + signals.add_log("Running script on flow: %s" % command, "debug") try: - s = script.Script(command, script.ScriptContext(self)) + s = script.Script(command) s.load() except script.ScriptException as e: signals.status_message.send( message='Error loading "{}".'.format(command) ) - signals.add_event('Error loading "{}":\n{}'.format(command, e), "error") + signals.add_log('Error loading "{}":\n{}'.format(command, e), "error") return if f.request: @@ -385,7 +338,7 @@ class ConsoleMaster(flow.FlowMaster): signals.flow_change.send(self, flow = f) def toggle_eventlog(self): - self.eventlog = not self.eventlog + self.options.eventlog = not self.options.eventlog signals.pop_view_state.send(self) self.view_flowlist() @@ -416,7 +369,7 @@ class ConsoleMaster(flow.FlowMaster): if flows: self.start_server_playback( flows, - self.killextra, self.rheaders, + self.options.kill, self.rheaders, False, self.nopop, self.options.replay_ignore_params, self.options.replay_ignore_content, @@ -512,7 +465,7 @@ class ConsoleMaster(flow.FlowMaster): if self.options.rfile: ret = 
self.load_flows_path(self.options.rfile) if ret and self.state.flow_count(): - signals.add_event( + signals.add_log( "File truncated or corrupted. " "Loaded as many flows as possible.", "error" @@ -615,7 +568,7 @@ class ConsoleMaster(flow.FlowMaster): if self.state.follow_focus: self.state.set_focus(self.state.flow_count()) - if self.eventlog: + if self.options.eventlog: body = flowlist.BodyPile(self) else: body = flowlist.FlowListBox(self) @@ -652,7 +605,7 @@ class ConsoleMaster(flow.FlowMaster): return path = os.path.expanduser(path) try: - f = file(path, "wb") + f = open(path, "wb") fw = flow.FlowWriter(f) for i in flows: fw.add(i) @@ -705,20 +658,7 @@ class ConsoleMaster(flow.FlowMaster): self.refresh_focus() def edit_scripts(self, scripts): - commands = [x[0] for x in scripts] # remove outer array - if commands == [s.command for s in self.scripts]: - return - - self.unload_scripts() - for command in commands: - try: - self.load_script(command) - except exceptions.ScriptException as e: - signals.status_message.send( - message='Error loading "{}".'.format(command) - ) - signals.add_event('Error loading "{}":\n{}'.format(command, e), "error") - signals.update_settings.send(self) + self.options.scripts = [x[0] for x in scripts] def stop_client_playback_prompt(self, a): if a != "n": @@ -773,7 +713,7 @@ class ConsoleMaster(flow.FlowMaster): signals.flow_change.send(self, flow = f) def clear_events(self): - self.eventlist[:] = [] + self.logbuffer[:] = [] # Handlers @controller.handler @@ -798,8 +738,20 @@ class ConsoleMaster(flow.FlowMaster): return f @controller.handler + def tcp_message(self, f): + super(ConsoleMaster, self).tcp_message(f) + message = f.messages[-1] + direction = "->" if message.from_client else "<-" + self.add_log("{client} {direction} tcp {direction} {server}".format( + client=repr(f.client_conn.address), + server=repr(f.server_conn.address), + direction=direction, + ), "info") + self.add_log(strutils.bytes_to_escaped_str(message.content), "debug") + + @controller.handler def script_change(self, script): if super(ConsoleMaster, self).script_change(script): - signals.status_message.send(message='"{}" reloaded.'.format(script.filename)) + signals.status_message.send(message='"{}" reloaded.'.format(script.path)) else: - signals.status_message.send(message='Error reloading "{}".'.format(script.filename)) + signals.status_message.send(message='Error reloading "{}".'.format(script.path)) diff --git a/mitmproxy/console/options.py b/mitmproxy/console/options.py index 5a01c9d5..e1dd29ee 100644 --- a/mitmproxy/console/options.py +++ b/mitmproxy/console/options.py @@ -36,7 +36,7 @@ class Options(urwid.WidgetWrap): select.Option( "Header Set Patterns", "H", - lambda: master.setheaders.count(), + lambda: len(master.options.setheaders), self.setheaders ), select.Option( @@ -48,13 +48,13 @@ class Options(urwid.WidgetWrap): select.Option( "Replacement Patterns", "R", - lambda: master.replacehooks.count(), + lambda: len(master.options.replacements), self.replacepatterns ), select.Option( "Scripts", "S", - lambda: master.scripts, + lambda: master.options.scripts, self.scripts ), @@ -74,8 +74,8 @@ class Options(urwid.WidgetWrap): select.Option( "Show Host", "w", - lambda: master.showhost, - self.toggle_showhost + lambda: master.options.showhost, + master.options.toggler("showhost") ), select.Heading("Network"), @@ -96,37 +96,37 @@ class Options(urwid.WidgetWrap): select.Option( "Anti-Cache", "a", - lambda: master.anticache, - self.toggle_anticache + lambda: master.options.anticache, + 
master.options.toggler("anticache") ), select.Option( "Anti-Compression", "o", - lambda: master.anticomp, - self.toggle_anticomp + lambda: master.options.anticomp, + master.options.toggler("anticomp") ), select.Option( "Kill Extra", "x", - lambda: master.killextra, - self.toggle_killextra + lambda: master.options.kill, + master.options.toggler("kill") ), select.Option( "No Refresh", "f", - lambda: not master.refresh_server_playback, - self.toggle_refresh_server_playback + lambda: not master.options.refresh_server_playback, + master.options.toggler("refresh_server_playback") ), select.Option( "Sticky Auth", "A", - lambda: master.stickyauth_txt, + lambda: master.options.stickyauth, self.sticky_auth ), select.Option( "Sticky Cookies", "t", - lambda: master.stickycookie_txt, + lambda: master.options.stickycookie, self.sticky_cookie ), ] @@ -140,6 +140,7 @@ class Options(urwid.WidgetWrap): ) self.master.loop.widget.footer.update("") signals.update_settings.connect(self.sig_update_settings) + master.options.changed.connect(self.sig_update_settings) def sig_update_settings(self, sender): self.lb.walker._modified() @@ -151,19 +152,23 @@ class Options(urwid.WidgetWrap): return super(self.__class__, self).keypress(size, key) def clearall(self): - self.master.anticache = False - self.master.anticomp = False - self.master.killextra = False - self.master.showhost = False - self.master.refresh_server_playback = True self.master.server.config.no_upstream_cert = False - self.master.setheaders.clear() - self.master.replacehooks.clear() self.master.set_ignore_filter([]) self.master.set_tcp_filter([]) - self.master.scripts = [] - self.master.set_stickyauth(None) - self.master.set_stickycookie(None) + + self.master.options.update( + anticache = False, + anticomp = False, + kill = False, + refresh_server_playback = True, + replacements = [], + scripts = [], + setheaders = [], + showhost = False, + stickyauth = None, + stickycookie = None + ) + self.master.state.default_body_view = contentviews.get("Auto") signals.update_settings.send(self) @@ -172,41 +177,22 @@ class Options(urwid.WidgetWrap): expire = 1 ) - def toggle_anticache(self): - self.master.anticache = not self.master.anticache - - def toggle_anticomp(self): - self.master.anticomp = not self.master.anticomp - - def toggle_killextra(self): - self.master.killextra = not self.master.killextra - - def toggle_showhost(self): - self.master.showhost = not self.master.showhost - - def toggle_refresh_server_playback(self): - self.master.refresh_server_playback = not self.master.refresh_server_playback - def toggle_upstream_cert(self): self.master.server.config.no_upstream_cert = not self.master.server.config.no_upstream_cert signals.update_settings.send(self) def setheaders(self): - def _set(*args, **kwargs): - self.master.setheaders.set(*args, **kwargs) - signals.update_settings.send(self) self.master.view_grideditor( grideditor.SetHeadersEditor( self.master, - self.master.setheaders.get_specs(), - _set + self.master.options.setheaders, + self.master.options.setter("setheaders") ) ) def ignorepatterns(self): def _set(ignore): self.master.set_ignore_filter(ignore) - signals.update_settings.send(self) self.master.view_grideditor( grideditor.HostPatternEditor( self.master, @@ -216,14 +202,11 @@ class Options(urwid.WidgetWrap): ) def replacepatterns(self): - def _set(*args, **kwargs): - self.master.replacehooks.set(*args, **kwargs) - signals.update_settings.send(self) self.master.view_grideditor( grideditor.ReplaceEditor( self.master, - 
self.master.replacehooks.get_specs(), - _set + self.master.options.replacements, + self.master.options.setter("replacements") ) ) @@ -231,7 +214,7 @@ class Options(urwid.WidgetWrap): self.master.view_grideditor( grideditor.ScriptEditor( self.master, - [[i.command] for i in self.master.scripts], + [[i] for i in self.master.options.scripts], self.master.edit_scripts ) ) @@ -261,15 +244,15 @@ class Options(urwid.WidgetWrap): def sticky_auth(self): signals.status_prompt.send( prompt = "Sticky auth filter", - text = self.master.stickyauth_txt, - callback = self.master.set_stickyauth + text = self.master.options.stickyauth, + callback = self.master.options.setter("stickyauth") ) def sticky_cookie(self): signals.status_prompt.send( prompt = "Sticky cookie filter", - text = self.master.stickycookie_txt, - callback = self.master.set_stickycookie + text = self.master.options.stickycookie, + callback = self.master.options.setter("stickycookie") ) def palette(self): diff --git a/mitmproxy/console/palettes.py b/mitmproxy/console/palettes.py index 36cc3ac0..2e12338f 100644 --- a/mitmproxy/console/palettes.py +++ b/mitmproxy/console/palettes.py @@ -24,7 +24,7 @@ class Palette: # List and Connections 'method', 'focus', 'code_200', 'code_300', 'code_400', 'code_500', 'code_other', - 'error', + 'error', "warn", 'header', 'highlight', 'intercept', 'replay', 'mark', # Hex view @@ -100,6 +100,7 @@ class LowDark(Palette): code_500 = ('light red', 'default'), code_other = ('dark red', 'default'), + warn = ('brown', 'default'), error = ('light red', 'default'), header = ('dark cyan', 'default'), @@ -166,6 +167,7 @@ class LowLight(Palette): code_other = ('light red', 'default'), error = ('light red', 'default'), + warn = ('brown', 'default'), header = ('dark blue', 'default'), highlight = ('black,bold', 'default'), @@ -250,6 +252,7 @@ class SolarizedLight(LowLight): code_other = (sol_magenta, 'default'), error = (sol_red, 'default'), + warn = (sol_orange, 'default'), header = (sol_blue, 'default'), highlight = (sol_base01, 'default'), @@ -299,6 +302,7 @@ class SolarizedDark(LowDark): code_other = (sol_magenta, 'default'), error = (sol_red, 'default'), + warn = (sol_orange, 'default'), header = (sol_blue, 'default'), highlight = (sol_base01, 'default'), diff --git a/mitmproxy/console/signals.py b/mitmproxy/console/signals.py index b57ebf0c..97507834 100644 --- a/mitmproxy/console/signals.py +++ b/mitmproxy/console/signals.py @@ -3,11 +3,11 @@ from __future__ import absolute_import, print_function, division import blinker # Show a status message in the action bar -sig_add_event = blinker.Signal() +sig_add_log = blinker.Signal() -def add_event(e, level): - sig_add_event.send( +def add_log(e, level): + sig_add_log.send( None, e=e, level=level diff --git a/mitmproxy/console/statusbar.py b/mitmproxy/console/statusbar.py index e576b565..8f039e48 100644 --- a/mitmproxy/console/statusbar.py +++ b/mitmproxy/console/statusbar.py @@ -28,9 +28,10 @@ class ActionBar(urwid.WidgetWrap): self.pathprompt = False def sig_message(self, sender, message, expire=None): + if self.prompting: + return w = urwid.Text(message) self._w = w - self.prompting = False if expire: def cb(*args): if w == self._w: @@ -116,12 +117,15 @@ class ActionBar(urwid.WidgetWrap): class StatusBar(urwid.WidgetWrap): def __init__(self, master, helptext): - self.master, self.helptext = master, helptext + # type: (mitmproxy.console.master.ConsoleMaster, object) -> None + self.master = master + self.helptext = helptext self.ab = ActionBar() self.ib = 
urwid.WidgetWrap(urwid.Text("")) - self._w = urwid.Pile([self.ib, self.ab]) + super(StatusBar, self).__init__(urwid.Pile([self.ib, self.ab])) signals.update_settings.connect(self.sig_update_settings) signals.flowlist_change.connect(self.sig_update_settings) + master.options.changed.connect(self.sig_update_settings) self.redraw() def sig_update_settings(self, sender): @@ -133,11 +137,11 @@ class StatusBar(urwid.WidgetWrap): def get_status(self): r = [] - if self.master.setheaders.count(): + if len(self.master.options.setheaders): r.append("[") r.append(("heading_key", "H")) r.append("eaders]") - if self.master.replacehooks.count(): + if len(self.master.options.replacements): r.append("[") r.append(("heading_key", "R")) r.append("eplacing]") @@ -172,29 +176,29 @@ class StatusBar(urwid.WidgetWrap): r.append("[") r.append(("heading_key", "Marked Flows")) r.append("]") - if self.master.stickycookie_txt: + if self.master.options.stickycookie: r.append("[") r.append(("heading_key", "t")) - r.append(":%s]" % self.master.stickycookie_txt) - if self.master.stickyauth_txt: + r.append(":%s]" % self.master.options.stickycookie) + if self.master.options.stickyauth: r.append("[") r.append(("heading_key", "u")) - r.append(":%s]" % self.master.stickyauth_txt) + r.append(":%s]" % self.master.options.stickyauth) if self.master.state.default_body_view.name != "Auto": r.append("[") r.append(("heading_key", "M")) r.append(":%s]" % self.master.state.default_body_view.name) opts = [] - if self.master.anticache: + if self.master.options.anticache: opts.append("anticache") - if self.master.anticomp: + if self.master.options.anticomp: opts.append("anticomp") - if self.master.showhost: + if self.master.options.showhost: opts.append("showhost") - if not self.master.refresh_server_playback: + if not self.master.options.refresh_server_playback: opts.append("norefresh") - if self.master.killextra: + if self.master.options.kill: opts.append("killextra") if self.master.server.config.no_upstream_cert: opts.append("no-upstream-cert") @@ -217,14 +221,13 @@ class StatusBar(urwid.WidgetWrap): dst.address.host, dst.address.port )) - if self.master.scripts: + if self.master.options.scripts: r.append("[") r.append(("heading_key", "s")) - r.append("cripts:%s]" % len(self.master.scripts)) - # r.append("[lt:%0.3f]"%self.master.looptime) + r.append("cripts:%s]" % len(self.master.options.scripts)) - if self.master.stream: - r.append("[W:%s]" % self.master.stream_path) + if self.master.options.outfile: + r.append("[W:%s]" % self.master.options.outfile[0]) return r diff --git a/mitmproxy/contentviews.py b/mitmproxy/contentviews.py index 7c9e4ba1..afdaad7f 100644 --- a/mitmproxy/contentviews.py +++ b/mitmproxy/contentviews.py @@ -31,7 +31,6 @@ from six import BytesIO from mitmproxy import exceptions from mitmproxy.contrib import jsbeautifier from mitmproxy.contrib.wbxml import ASCommandResponse -from netlib import encoding from netlib import http from netlib import multidict from netlib.http import url @@ -143,11 +142,11 @@ class ViewAuto(View): ct = "%s/%s" % (ct[0], ct[1]) if ct in content_types_map: return content_types_map[ct][0](data, **metadata) - elif strutils.isXML(data.decode()): + elif strutils.is_xml(data): return get("XML")(data, **metadata) if metadata.get("query"): return get("Query")(data, **metadata) - if data and strutils.isMostlyBin(data.decode()): + if data and strutils.is_mostly_bin(data): return get("Hex")(data) if not data: return "No content", [] @@ -160,7 +159,7 @@ class ViewRaw(View): content_types = [] def 
__call__(self, data, **metadata): - return "Raw", format_text(strutils.bytes_to_escaped_str(data)) + return "Raw", format_text(strutils.bytes_to_escaped_str(data, True)) class ViewHex(View): @@ -226,7 +225,10 @@ class ViewXML(View): class ViewJSON(View): name = "JSON" prompt = ("json", "s") - content_types = ["application/json"] + content_types = [ + "application/json", + "application/vnd.api+json" + ] def __call__(self, data, **metadata): pj = pretty_json(data) @@ -240,7 +242,7 @@ class ViewHTML(View): content_types = ["text/html"] def __call__(self, data, **metadata): - if strutils.isXML(data.decode()): + if strutils.is_xml(data): parser = lxml.etree.HTMLParser( strip_cdata=True, remove_blank_text=True @@ -597,10 +599,9 @@ def safe_to_print(lines, encoding="utf8"): for line in lines: clean_line = [] for (style, text) in line: - try: - text = strutils.clean_bin(text.decode(encoding, "strict")) - except UnicodeDecodeError: - text = strutils.clean_bin(text).decode(encoding, "strict") + if isinstance(text, bytes): + text = text.decode(encoding, "replace") + text = strutils.escape_control_characters(text) clean_line.append((style, text)) yield clean_line @@ -618,15 +619,6 @@ def get_content_view(viewmode, data, **metadata): Raises: ContentViewException, if the content view threw an error. """ - msg = [] - - headers = metadata.get("headers", {}) - enc = headers.get("content-encoding") - if enc and enc != "identity": - decoded = encoding.decode(enc, data) - if decoded: - data = decoded - msg.append("[decoded %s]" % enc) try: ret = viewmode(data, **metadata) # Third-party viewers can fail in unexpected ways... @@ -637,8 +629,8 @@ def get_content_view(viewmode, data, **metadata): sys.exc_info()[2] ) if not ret: - ret = get("Raw")(data, **metadata) - msg.append("Couldn't parse: falling back to Raw") + desc = "Couldn't parse: falling back to Raw" + _, content = get("Raw")(data, **metadata) else: - msg.append(ret[0]) - return " ".join(msg), safe_to_print(ret[1]) + desc, content = ret + return desc, safe_to_print(content) diff --git a/mitmproxy/contrib/tnetstring.py b/mitmproxy/contrib/tnetstring.py index 9bf20b09..d99a83f9 100644 --- a/mitmproxy/contrib/tnetstring.py +++ b/mitmproxy/contrib/tnetstring.py @@ -1,100 +1,67 @@ -# imported from the tnetstring project: https://github.com/rfk/tnetstring -# -# Copyright (c) 2011 Ryan Kelly -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -# THE SOFTWARE. 
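A minimal round-trip sketch for the rewritten tnetstring module below, assuming it is imported as mitmproxy.contrib.tnetstring and relying only on the encodings shown in its own docstring:

    from mitmproxy.contrib import tnetstring

    encoded = tnetstring.dumps([12345, True, 0])
    # b'19:5:12345#4:true!1:0#]', matching the example in the module docstring
    assert tnetstring.loads(encoded) == [12345, True, 0]
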
""" tnetstring: data serialization using typed netstrings ====================================================== +This is a custom Python 3 implementation of tnetstrings. +Compared to other implementations, the main difference +is that this implementation supports a custom unicode datatype. -This is a data serialization library. It's a lot like JSON but it uses a -new syntax called "typed netstrings" that Zed has proposed for use in the -Mongrel2 webserver. It's designed to be simpler and easier to implement -than JSON, with a happy consequence of also being faster in many cases. - -An ordinary netstring is a blob of data prefixed with its length and postfixed -with a sanity-checking comma. The string "hello world" encodes like this:: +An ordinary tnetstring is a blob of data prefixed with its length and postfixed +with its type. Here are some examples: + >>> tnetstring.dumps("hello world") 11:hello world, - -Typed netstrings add other datatypes by replacing the comma with a type tag. -Here's the integer 12345 encoded as a tnetstring:: - + >>> tnetstring.dumps(12345) 5:12345# - -And here's the list [12345,True,0] which mixes integers and bools:: - + >>> tnetstring.dumps([12345, True, 0]) 19:5:12345#4:true!1:0#] -Simple enough? This module gives you the following functions: +This module gives you the following functions: :dump: dump an object as a tnetstring to a file :dumps: dump an object as a tnetstring to a string :load: load a tnetstring-encoded object from a file :loads: load a tnetstring-encoded object from a string - :pop: pop a tnetstring-encoded object from the front of a string Note that since parsing a tnetstring requires reading all the data into memory at once, there's no efficiency gain from using the file-based versions of these functions. They're only here so you can use load() to read precisely one item from a file or socket without consuming any extra data. -By default tnetstrings work only with byte strings, not unicode. If you want -unicode strings then pass an optional encoding to the various functions, -like so:: +The tnetstrings specification explicitly states that strings are binary blobs +and forbids the use of unicode at the protocol level. +**This implementation decodes dictionary keys as surrogate-escaped ASCII**, +all other strings are returned as plain bytes. - >>> print(repr(tnetstring.loads("2:\\xce\\xb1,"))) - '\\xce\\xb1' - >>> - >>> print(repr(tnetstring.loads("2:\\xce\\xb1,","utf8"))) - u'\u03b1' +:Copyright: (c) 2012-2013 by Ryan Kelly <ryan@rfk.id.au>. +:Copyright: (c) 2014 by Carlo Pires <carlopires@gmail.com>. +:Copyright: (c) 2016 by Maximilian Hils <tnetstring3@maximilianhils.com>. +:License: MIT """ -from collections import deque +import collections import six +from typing import io, Union, Tuple # noqa -__ver_major__ = 0 -__ver_minor__ = 2 -__ver_patch__ = 0 -__ver_sub__ = "" -__version__ = "%d.%d.%d%s" % ( - __ver_major__, __ver_minor__, __ver_patch__, __ver_sub__) +TSerializable = Union[None, bool, int, float, bytes, list, tuple, dict] def dumps(value): + # type: (TSerializable) -> bytes """ This function dumps a python object as a tnetstring. """ # This uses a deque to collect output fragments in reverse order, # then joins them together at the end. It's measurably faster # than creating all the intermediate strings. - # If you're reading this to get a handle on the tnetstring format, - # consider the _gdumps() function instead; it's a standard top-down - # generator that's simpler to understand but much less efficient. 
- q = deque() + q = collections.deque() _rdumpq(q, 0, value) return b''.join(q) def dump(value, file_handle): + # type: (TSerializable, io.BinaryIO) -> None """ This function dumps a python object as a tnetstring and writes it to the given file. @@ -103,6 +70,7 @@ def dump(value, file_handle): def _rdumpq(q, size, value): + # type: (collections.deque, int, TSerializable) -> int """ Dump value as a tnetstring, to a deque instance, last chunks first. @@ -132,10 +100,7 @@ def _rdumpq(q, size, value): data = str(value).encode() ldata = len(data) span = str(ldata).encode() - write(b'#') - write(data) - write(b':') - write(span) + write(b'%s:%s#' % (span, data)) return size + 2 + len(span) + ldata elif isinstance(value, float): # Use repr() for float rather than str(). @@ -145,19 +110,26 @@ def _rdumpq(q, size, value): data = repr(value).encode() ldata = len(data) span = str(ldata).encode() - write(b'^') + write(b'%s:%s^' % (span, data)) + return size + 2 + len(span) + ldata + elif isinstance(value, bytes): + data = value + ldata = len(data) + span = str(ldata).encode() + write(b',') write(data) write(b':') write(span) return size + 2 + len(span) + ldata - elif isinstance(value, bytes): - lvalue = len(value) - span = str(lvalue).encode() - write(b',') - write(value) + elif isinstance(value, six.text_type): + data = value.encode("utf8") + ldata = len(data) + span = str(ldata).encode() + write(b';') + write(data) write(b':') write(span) - return size + 2 + len(span) + lvalue + return size + 2 + len(span) + ldata elif isinstance(value, (list, tuple)): write(b']') init_size = size = size + 1 @@ -181,73 +153,16 @@ def _rdumpq(q, size, value): raise ValueError("unserializable object: {} ({})".format(value, type(value))) -def _gdumps(value): - """ - Generate fragments of value dumped as a tnetstring. - - This is the naive dumping algorithm, implemented as a generator so that - it's easy to pass to "".join() without building a new list. - - This is mainly here for comparison purposes; the _rdumpq version is - measurably faster as it doesn't have to build intermediate strins. - """ - if value is None: - yield b'0:~' - elif value is True: - yield b'4:true!' - elif value is False: - yield b'5:false!' - elif isinstance(value, six.integer_types): - data = str(value).encode() - yield str(len(data)).encode() - yield b':' - yield data - yield b'#' - elif isinstance(value, float): - data = repr(value).encode() - yield str(len(data)).encode() - yield b':' - yield data - yield b'^' - elif isinstance(value, bytes): - yield str(len(value)).encode() - yield b':' - yield value - yield b',' - elif isinstance(value, (list, tuple)): - sub = [] - for item in value: - sub.extend(_gdumps(item)) - sub = b''.join(sub) - yield str(len(sub)).encode() - yield b':' - yield sub - yield b']' - elif isinstance(value, (dict,)): - sub = [] - for (k, v) in value.items(): - sub.extend(_gdumps(k)) - sub.extend(_gdumps(v)) - sub = b''.join(sub) - yield str(len(sub)).encode() - yield b':' - yield sub - yield b'}' - else: - raise ValueError("unserializable object") - - def loads(string): + # type: (bytes) -> TSerializable """ This function parses a tnetstring into a python object. """ - # No point duplicating effort here. In the C-extension version, - # loads() is measurably faster then pop() since it can avoid - # the overhead of building a second string. 
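    # Worked examples of the encodings handled by parse() below:
    #   loads(b'5:12345#') == 12345
    #   loads(b'0:~') is None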
return pop(string)[0] def load(file_handle): + # type: (io.BinaryIO) -> TSerializable """load(file) -> object This function reads a tnetstring from a file and parses it into a @@ -257,119 +172,89 @@ def load(file_handle): # Read the length prefix one char at a time. # Note that the netstring spec explicitly forbids padding zeros. c = file_handle.read(1) - if not c.isdigit(): - raise ValueError("not a tnetstring: missing or invalid length prefix") - datalen = ord(c) - ord('0') - c = file_handle.read(1) - if datalen != 0: - while c.isdigit(): - datalen = (10 * datalen) + (ord(c) - ord('0')) - if datalen > 999999999: - errmsg = "not a tnetstring: absurdly large length prefix" - raise ValueError(errmsg) - c = file_handle.read(1) - if c != b':': + data_length = b"" + while c.isdigit(): + data_length += c + if len(data_length) > 9: + raise ValueError("not a tnetstring: absurdly large length prefix") + c = file_handle.read(1) + if c != b":": raise ValueError("not a tnetstring: missing or invalid length prefix") - # Now we can read and parse the payload. - # This repeats the dispatch logic of pop() so we can avoid - # re-constructing the outermost tnetstring. - data = file_handle.read(datalen) - if len(data) != datalen: - raise ValueError("not a tnetstring: length prefix too big") - tns_type = file_handle.read(1) - if tns_type == b',': + + data = file_handle.read(int(data_length)) + data_type = file_handle.read(1)[0] + + return parse(data_type, data) + + +def parse(data_type, data): + if six.PY2: + data_type = ord(data_type) + # type: (int, bytes) -> TSerializable + if data_type == ord(b','): return data - if tns_type == b'#': + if data_type == ord(b';'): + return data.decode("utf8") + if data_type == ord(b'#'): try: + if six.PY2: + return long(data) return int(data) except ValueError: - raise ValueError("not a tnetstring: invalid integer literal") - if tns_type == b'^': + raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) + if data_type == ord(b'^'): try: return float(data) except ValueError: - raise ValueError("not a tnetstring: invalid float literal") - if tns_type == b'!': + raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) + if data_type == ord(b'!'): if data == b'true': return True elif data == b'false': return False else: - raise ValueError("not a tnetstring: invalid boolean literal") - if tns_type == b'~': + raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) + if data_type == ord(b'~'): if data: raise ValueError("not a tnetstring: invalid null literal") return None - if tns_type == b']': + if data_type == ord(b']'): l = [] while data: item, data = pop(data) l.append(item) return l - if tns_type == b'}': + if data_type == ord(b'}'): d = {} while data: key, data = pop(data) val, data = pop(data) d[key] = val return d - raise ValueError("unknown type tag") - + raise ValueError("unknown type tag: {}".format(data_type)) -def pop(string): - """pop(string,encoding='utf_8') -> (object, remain) +def pop(data): + # type: (bytes) -> Tuple[TSerializable, bytes] + """ This function parses a tnetstring into a python object. It returns a tuple giving the parsed object and a string containing any unparsed data from the end of the string. """ # Parse out data length, type and remaining string. 
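    # For example, pop(b'5:hello,3:foo,') yields (b'hello', b'3:foo,'): the first
    # tnetstring is parsed and the leftover bytes are handed back so the caller
    # can keep parsing.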
try: - dlen, rest = string.split(b':', 1) - dlen = int(dlen) + length, data = data.split(b':', 1) + length = int(length) except ValueError: - raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(string)) + raise ValueError("not a tnetstring: missing or invalid length prefix: {}".format(data)) try: - data, tns_type, remain = rest[:dlen], rest[dlen:dlen + 1], rest[dlen + 1:] + data, data_type, remain = data[:length], data[length], data[length + 1:] except IndexError: - # This fires if len(rest) < dlen, meaning we don't need + # This fires if len(data) < dlen, meaning we don't need # to further validate that data is the right length. - raise ValueError("not a tnetstring: invalid length prefix: {}".format(dlen)) - # Parse the data based on the type tag. - if tns_type == b',': - return data, remain - if tns_type == b'#': - try: - return int(data), remain - except ValueError: - raise ValueError("not a tnetstring: invalid integer literal: {}".format(data)) - if tns_type == b'^': - try: - return float(data), remain - except ValueError: - raise ValueError("not a tnetstring: invalid float literal: {}".format(data)) - if tns_type == b'!': - if data == b'true': - return True, remain - elif data == b'false': - return False, remain - else: - raise ValueError("not a tnetstring: invalid boolean literal: {}".format(data)) - if tns_type == b'~': - if data: - raise ValueError("not a tnetstring: invalid null literal") - return None, remain - if tns_type == b']': - l = [] - while data: - item, data = pop(data) - l.append(item) - return (l, remain) - if tns_type == b'}': - d = {} - while data: - key, data = pop(data) - val, data = pop(data) - d[key] = val - return d, remain - raise ValueError("unknown type tag: {}".format(tns_type)) + raise ValueError("not a tnetstring: invalid length prefix: {}".format(length)) + # Parse the data based on the type tag. + return parse(data_type, data), remain + + +__all__ = ["dump", "dumps", "load", "loads", "pop"] diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py index a170d868..070ec862 100644 --- a/mitmproxy/controller.py +++ b/mitmproxy/controller.py @@ -2,11 +2,14 @@ from __future__ import absolute_import, print_function, division import functools import threading +import contextlib from six.moves import queue +from mitmproxy import addons +from mitmproxy import options +from . import ctx as mitmproxy_ctx from netlib import basethread - from . import exceptions @@ -30,21 +33,67 @@ Events = frozenset([ "error", "log", + "start", + "configure", + "done", + "tick", + "script_change", ]) +class Log(object): + def __init__(self, master): + self.master = master + + def __call__(self, text, level="info"): + self.master.add_log(text, level) + + def debug(self, txt): + self(txt, "debug") + + def info(self, txt): + self(txt, "info") + + def warn(self, txt): + self(txt, "warn") + + def error(self, txt): + self(txt, "error") + + class Master(object): """ The master handles mitmproxy's main event loop. 
""" - def __init__(self, *servers): + def __init__(self, opts, *servers): + self.options = opts or options.Options() + self.addons = addons.Addons(self) self.event_queue = queue.Queue() self.should_exit = threading.Event() self.servers = [] for i in servers: self.add_server(i) + @contextlib.contextmanager + def handlecontext(self): + # Handlecontexts also have to nest - leave cleanup to the outermost + if mitmproxy_ctx.master: + yield + return + mitmproxy_ctx.master = self + mitmproxy_ctx.log = Log(self) + try: + yield + finally: + mitmproxy_ctx.master = None + mitmproxy_ctx.log = None + + def add_log(self, e, level="info"): + """ + level: debug, info, warn, error + """ + def add_server(self, server): # We give a Channel to the server which can be used to communicate with the master channel = Channel(self.event_queue, self.should_exit) @@ -68,26 +117,25 @@ class Master(object): self.shutdown() def tick(self, timeout): + with self.handlecontext(): + self.addons("tick") changed = False try: - # This endless loop runs until the 'Queue.Empty' - # exception is thrown. - while True: - mtype, obj = self.event_queue.get(timeout=timeout) - if mtype not in Events: - raise exceptions.ControlException("Unknown event %s" % repr(mtype)) - handle_func = getattr(self, mtype) - if not hasattr(handle_func, "__dict__"): - raise exceptions.ControlException("Handler %s not a function" % mtype) - if not handle_func.__dict__.get("__handler"): - raise exceptions.ControlException( - "Handler function %s is not decorated with controller.handler" % ( - handle_func - ) + mtype, obj = self.event_queue.get(timeout=timeout) + if mtype not in Events: + raise exceptions.ControlException("Unknown event %s" % repr(mtype)) + handle_func = getattr(self, mtype) + if not callable(handle_func): + raise exceptions.ControlException("Handler %s not callable" % mtype) + if not handle_func.__dict__.get("__handler"): + raise exceptions.ControlException( + "Handler function %s is not decorated with controller.handler" % ( + handle_func ) - handle_func(obj) - self.event_queue.task_done() - changed = True + ) + handle_func(obj) + self.event_queue.task_done() + changed = True except queue.Empty: pass return changed @@ -96,6 +144,7 @@ class Master(object): for server in self.servers: server.shutdown() self.should_exit.set() + self.addons.done() class ServerThread(basethread.BaseThread): @@ -151,15 +200,7 @@ class Channel(object): def handler(f): @functools.wraps(f) - def wrapper(*args, **kwargs): - # We can either be called as a method, or as a wrapped solo function - if len(args) == 1: - message = args[0] - elif len(args) == 2: - message = args[1] - else: - raise exceptions.ControlException("Handler takes one argument: a message") - + def wrapper(master, message): if not hasattr(message, "reply"): raise exceptions.ControlException("Message %s has no reply attribute" % message) @@ -172,10 +213,19 @@ def handler(f): handling = True message.reply.handled = True - ret = f(*args, **kwargs) + with master.handlecontext(): + ret = f(master, message) + if handling: + master.addons(f.__name__, message) if handling and not message.reply.acked and not message.reply.taken: message.reply.ack() + + # Reset the handled flag - it's common for us to feed the same object + # through handlers repeatedly, so we don't want this to persist across + # calls. 
+ if message.reply.handled: + message.reply.handled = False return ret # Mark this function as a handler wrapper wrapper.__dict__["__handler"] = True @@ -216,7 +266,7 @@ class Reply(object): def __del__(self): if not self.acked: # This will be ignored by the interpreter, but emit a warning - raise exceptions.ControlException("Un-acked message") + raise exceptions.ControlException("Un-acked message: %s" % self.obj) class DummyReply(object): diff --git a/mitmproxy/ctx.py b/mitmproxy/ctx.py new file mode 100644 index 00000000..fcfdfd0b --- /dev/null +++ b/mitmproxy/ctx.py @@ -0,0 +1,4 @@ +from typing import Callable # noqa + +master = None # type: "mitmproxy.flow.FlowMaster" +log = None # type: Callable[[str], None] diff --git a/mitmproxy/dump.py b/mitmproxy/dump.py index cc6896ed..eaa368a0 100644 --- a/mitmproxy/dump.py +++ b/mitmproxy/dump.py @@ -1,75 +1,50 @@ from __future__ import absolute_import, print_function, division -import itertools import sys -import traceback + +from typing import Optional # noqa +import typing # noqa import click -from mitmproxy import contentviews from mitmproxy import controller from mitmproxy import exceptions -from mitmproxy import filt from mitmproxy import flow -from netlib import human +from mitmproxy import builtins +from mitmproxy import utils +from mitmproxy.builtins import dumper from netlib import tcp -from netlib import strutils class DumpError(Exception): pass -class Options(object): - attributes = [ - "app", - "app_host", - "app_port", - "anticache", - "anticomp", - "client_replay", - "filtstr", - "flow_detail", - "keepserving", - "kill", - "no_server", - "nopop", - "refresh_server_playback", - "replacements", - "rfile", - "rheaders", - "setheaders", - "server_replay", - "scripts", - "showhost", - "stickycookie", - "stickyauth", - "stream_large_bodies", - "verbosity", - "outfile", - "replay_ignore_content", - "replay_ignore_params", - "replay_ignore_payload_params", - "replay_ignore_host" - ] - - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - for i in self.attributes: - if not hasattr(self, i): - setattr(self, i, None) +class Options(flow.options.Options): + def __init__( + self, + keepserving=False, # type: bool + filtstr=None, # type: Optional[str] + flow_detail=1, # type: int + tfile=None, # type: Optional[typing.io.TextIO] + **kwargs + ): + self.filtstr = filtstr + self.flow_detail = flow_detail + self.keepserving = keepserving + self.tfile = tfile + super(Options, self).__init__(**kwargs) class DumpMaster(flow.FlowMaster): - def __init__(self, server, options, outfile=None): - flow.FlowMaster.__init__(self, server, flow.State()) - self.outfile = outfile - self.o = options - self.anticache = options.anticache - self.anticomp = options.anticomp - self.showhost = options.showhost + def __init__(self, server, options): + flow.FlowMaster.__init__(self, options, server, flow.State()) + self.has_errored = False + self.addons.add(*builtins.default_addons()) + self.addons.add(dumper.Dumper()) + # This line is just for type hinting + self.options = self.options # type: Options self.replay_ignore_params = options.replay_ignore_params self.replay_ignore_content = options.replay_ignore_content self.replay_ignore_host = options.replay_ignore_host @@ -83,34 +58,6 @@ class DumpMaster(flow.FlowMaster): "HTTP/2 is disabled. 
Use --no-http2 to silence this warning.", file=sys.stderr) - if options.filtstr: - self.filt = filt.parse(options.filtstr) - else: - self.filt = None - - if options.stickycookie: - self.set_stickycookie(options.stickycookie) - - if options.stickyauth: - self.set_stickyauth(options.stickyauth) - - if options.outfile: - err = self.start_stream_to_path( - options.outfile[0], - options.outfile[1], - self.filt - ) - if err: - raise DumpError(err) - - if options.replacements: - for i in options.replacements: - self.replacehooks.add(*i) - - if options.setheaders: - for i in options.setheaders: - self.setheaders.add(*i) - if options.server_replay: self.start_server_playback( self._readflow(options.server_replay), @@ -129,22 +76,15 @@ class DumpMaster(flow.FlowMaster): not options.keepserving ) - scripts = options.scripts or [] - for command in scripts: - try: - self.load_script(command, use_reloader=True) - except exceptions.ScriptException as e: - raise DumpError(str(e)) - if options.rfile: try: self.load_flows_file(options.rfile) except exceptions.FlowReadException as v: - self.add_event("Flow file corrupted.", "error") + self.add_log("Flow file corrupted.", "error") raise DumpError(v) - if self.o.app: - self.start_app(self.o.app_host, self.o.app_port) + if self.options.app: + self.start_app(self.options.app_host, self.options.app_port) def _readflow(self, paths): """ @@ -156,204 +96,26 @@ class DumpMaster(flow.FlowMaster): except exceptions.FlowReadException as e: raise DumpError(str(e)) - def add_event(self, e, level="info"): - needed = dict(error=0, info=1, debug=2).get(level, 1) - if self.o.verbosity >= needed: - self.echo( + def add_log(self, e, level="info"): + if level == "error": + self.has_errored = True + if self.options.verbosity >= utils.log_tier(level): + click.secho( e, + file=self.options.tfile, fg="red" if level == "error" else None, dim=(level == "debug"), err=(level == "error") ) - @staticmethod - def indent(n, text): - l = str(text).strip().splitlines() - pad = " " * n - return "\n".join(pad + i for i in l) - - def echo(self, text, indent=None, **style): - if indent: - text = self.indent(indent, text) - click.secho(text, file=self.outfile, **style) - - def _echo_message(self, message): - if self.o.flow_detail >= 2: - headers = "\r\n".join( - "{}: {}".format( - click.style(strutils.bytes_to_escaped_str(k), fg="blue", bold=True), - click.style(strutils.bytes_to_escaped_str(v), fg="blue")) - for k, v in message.headers.fields - ) - self.echo(headers, indent=4) - if self.o.flow_detail >= 3: - if message.content is None: - self.echo("(content missing)", indent=4) - elif message.content: - self.echo("") - - try: - type, lines = contentviews.get_content_view( - contentviews.get("Auto"), - message.content, - headers=message.headers - ) - except exceptions.ContentViewException: - s = "Content viewer failed: \n" + traceback.format_exc() - self.add_event(s, "debug") - type, lines = contentviews.get_content_view( - contentviews.get("Raw"), - message.content, - headers=message.headers - ) - - styles = dict( - highlight=dict(bold=True), - offset=dict(fg="blue"), - header=dict(fg="green", bold=True), - text=dict(fg="green") - ) - - def colorful(line): - yield u" " # we can already indent here - for (style, text) in line: - yield click.style(text, **styles.get(style, {})) - - if self.o.flow_detail == 3: - lines_to_echo = itertools.islice(lines, 70) - else: - lines_to_echo = lines - - lines_to_echo = list(lines_to_echo) - - content = u"\r\n".join( - u"".join(colorful(line)) for line in 
lines_to_echo - ) - - self.echo(content) - if next(lines, None): - self.echo("(cut off)", indent=4, dim=True) - - if self.o.flow_detail >= 2: - self.echo("") - - def _echo_request_line(self, flow): - if flow.request.stickycookie: - stickycookie = click.style("[stickycookie] ", fg="yellow", bold=True) - else: - stickycookie = "" - - if flow.client_conn: - client = click.style(strutils.bytes_to_escaped_str(flow.client_conn.address.host), bold=True) - else: - client = click.style("[replay]", fg="yellow", bold=True) - - method = flow.request.method - method_color = dict( - GET="green", - DELETE="red" - ).get(method.upper(), "magenta") - method = click.style(strutils.bytes_to_escaped_str(method), fg=method_color, bold=True) - if self.showhost: - url = flow.request.pretty_url - else: - url = flow.request.url - url = click.style(strutils.bytes_to_escaped_str(url), bold=True) - - httpversion = "" - if flow.request.http_version not in ("HTTP/1.1", "HTTP/1.0"): - httpversion = " " + flow.request.http_version # We hide "normal" HTTP 1. - - line = "{stickycookie}{client} {method} {url}{httpversion}".format( - stickycookie=stickycookie, - client=client, - method=method, - url=url, - httpversion=httpversion - ) - self.echo(line) - - def _echo_response_line(self, flow): - if flow.response.is_replay: - replay = click.style("[replay] ", fg="yellow", bold=True) - else: - replay = "" - - code = flow.response.status_code - code_color = None - if 200 <= code < 300: - code_color = "green" - elif 300 <= code < 400: - code_color = "magenta" - elif 400 <= code < 600: - code_color = "red" - code = click.style(str(code), fg=code_color, bold=True, blink=(code == 418)) - reason = click.style(strutils.bytes_to_escaped_str(flow.response.reason), fg=code_color, bold=True) - - if flow.response.content is None: - size = "(content missing)" - else: - size = human.pretty_size(len(flow.response.content)) - size = click.style(size, bold=True) - - arrows = click.style("<<", bold=True) - - line = "{replay} {arrows} {code} {reason} {size}".format( - replay=replay, - arrows=arrows, - code=code, - reason=reason, - size=size - ) - self.echo(line) - - def echo_flow(self, f): - if self.o.flow_detail == 0: - return - - if f.request: - self._echo_request_line(f) - self._echo_message(f.request) - - if f.response: - self._echo_response_line(f) - self._echo_message(f.response) - - if f.error: - self.echo(" << {}".format(f.error.msg), bold=True, fg="red") - - if self.outfile: - self.outfile.flush() - - def _process_flow(self, f): - if self.filt and not f.match(self.filt): - return - - self.echo_flow(f) - @controller.handler def request(self, f): - f = flow.FlowMaster.request(self, f) + f = super(DumpMaster, self).request(f) if f: self.state.delete_flow(f) return f - @controller.handler - def response(self, f): - f = flow.FlowMaster.response(self, f) - if f: - self._process_flow(f) - return f - - @controller.handler - def error(self, f): - flow.FlowMaster.error(self, f) - if f: - self._process_flow(f) - return f - def run(self): # pragma: no cover - if self.o.rfile and not self.o.keepserving: - self.unload_scripts() # make sure to trigger script unload events. 
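The free-form Options attribute list in dump.py is replaced by an explicit, typed constructor. A small usage sketch, with values chosen purely for illustration (server construction omitted):

import sys
from mitmproxy import dump

# flow_detail roughly follows the old echo logic: higher values show
# headers and a truncated body view; tfile selects the output stream.
opts = dump.Options(flow_detail=3, keepserving=False, tfile=sys.stdout)
# master = dump.DumpMaster(server, opts)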
+ if self.options.rfile and not self.options.keepserving: return super(DumpMaster, self).run() diff --git a/mitmproxy/exceptions.py b/mitmproxy/exceptions.py index 63bd8d3d..3b41fe1c 100644 --- a/mitmproxy/exceptions.py +++ b/mitmproxy/exceptions.py @@ -95,3 +95,11 @@ class FlowReadException(ProxyException): class ControlException(ProxyException): pass + + +class OptionsError(Exception): + pass + + +class AddonError(Exception): + pass diff --git a/mitmproxy/filt.py b/mitmproxy/filt.py index d98e3749..8b647b22 100644 --- a/mitmproxy/filt.py +++ b/mitmproxy/filt.py @@ -35,10 +35,26 @@ from __future__ import absolute_import, print_function, division import re import sys +import functools + +from mitmproxy.models.http import HTTPFlow +from mitmproxy.models.tcp import TCPFlow +from netlib import strutils import pyparsing as pp +def only(*types): + def decorator(fn): + @functools.wraps(fn) + def filter_types(self, flow): + if isinstance(flow, types): + return fn(self, flow) + return False + return filter_types + return decorator + + class _Token(object): def dump(self, indent=0, fp=sys.stdout): @@ -64,10 +80,29 @@ class FErr(_Action): return True if f.error else False +class FHTTP(_Action): + code = "http" + help = "Match HTTP flows" + + @only(HTTPFlow) + def __call__(self, f): + return True + + +class FTCP(_Action): + code = "tcp" + help = "Match TCP flows" + + @only(TCPFlow) + def __call__(self, f): + return True + + class FReq(_Action): code = "q" help = "Match request with no response" + @only(HTTPFlow) def __call__(self, f): if not f.response: return True @@ -77,40 +112,47 @@ class FResp(_Action): code = "s" help = "Match response" + @only(HTTPFlow) def __call__(self, f): - return True if f.response else False + return bool(f.response) class _Rex(_Action): flags = 0 + is_binary = True def __init__(self, expr): self.expr = expr + if self.is_binary: + expr = strutils.escaped_str_to_bytes(expr) try: - self.re = re.compile(self.expr, self.flags) + self.re = re.compile(expr, self.flags) except: raise ValueError("Cannot compile expression.") -def _check_content_type(expr, o): - val = o.headers.get("content-type") - if val and re.search(expr, val): - return True - return False +def _check_content_type(rex, message): + return any( + name.lower() == b"content-type" and + rex.search(value) + for name, value in message.headers.fields + ) class FAsset(_Action): code = "a" help = "Match asset in response: CSS, Javascript, Flash, images." 
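The filter language becomes flow-type aware: the only() decorator added above makes HTTP-specific filters return False for TCP flows instead of failing, and the new ~http / ~tcp expressions match on flow type. A hypothetical filter class, using the module's private helpers purely to illustrate the guard:

from mitmproxy import filt
from mitmproxy.models.http import HTTPFlow

class FJson(filt._Rex):
    # Hypothetical example, not part of this change set: matches HTTP flows
    # whose request Content-Type contains "json"; any other flow type is
    # rejected by the @only guard before __call__ runs.
    code = "tj"
    help = "Request Content-Type looks like JSON"

    @filt.only(HTTPFlow)
    def __call__(self, f):
        return filt._check_content_type(self.re, f.request)

f = FJson("json")   # compiled like any other _Rex-based filter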
ASSET_TYPES = [ - "text/javascript", - "application/x-javascript", - "application/javascript", - "text/css", - "image/.*", - "application/x-shockwave-flash" + b"text/javascript", + b"application/x-javascript", + b"application/javascript", + b"text/css", + b"image/.*", + b"application/x-shockwave-flash" ] + ASSET_TYPES = [re.compile(x) for x in ASSET_TYPES] + @only(HTTPFlow) def __call__(self, f): if f.response: for i in self.ASSET_TYPES: @@ -123,29 +165,32 @@ class FContentType(_Rex): code = "t" help = "Content-type header" + @only(HTTPFlow) def __call__(self, f): - if _check_content_type(self.expr, f.request): + if _check_content_type(self.re, f.request): return True - elif f.response and _check_content_type(self.expr, f.response): + elif f.response and _check_content_type(self.re, f.response): return True return False -class FRequestContentType(_Rex): +class FContentTypeRequest(_Rex): code = "tq" help = "Request Content-Type header" + @only(HTTPFlow) def __call__(self, f): - return _check_content_type(self.expr, f.request) + return _check_content_type(self.re, f.request) -class FResponseContentType(_Rex): +class FContentTypeResponse(_Rex): code = "ts" help = "Response Content-Type header" + @only(HTTPFlow) def __call__(self, f): if f.response: - return _check_content_type(self.expr, f.response) + return _check_content_type(self.re, f.response) return False @@ -154,6 +199,7 @@ class FHead(_Rex): help = "Header" flags = re.MULTILINE + @only(HTTPFlow) def __call__(self, f): if f.request and self.re.search(bytes(f.request.headers)): return True @@ -167,6 +213,7 @@ class FHeadRequest(_Rex): help = "Request header" flags = re.MULTILINE + @only(HTTPFlow) def __call__(self, f): if f.request and self.re.search(bytes(f.request.headers)): return True @@ -177,6 +224,7 @@ class FHeadResponse(_Rex): help = "Response header" flags = re.MULTILINE + @only(HTTPFlow) def __call__(self, f): if f.response and self.re.search(bytes(f.response.headers)): return True @@ -186,13 +234,19 @@ class FBod(_Rex): code = "b" help = "Body" + @only(HTTPFlow, TCPFlow) def __call__(self, f): - if f.request and f.request.content: - if self.re.search(f.request.get_decoded_content()): - return True - if f.response and f.response.content: - if self.re.search(f.response.get_decoded_content()): - return True + if isinstance(f, HTTPFlow): + if f.request and f.request.raw_content: + if self.re.search(f.request.get_content(strict=False)): + return True + if f.response and f.response.raw_content: + if self.re.search(f.response.get_content(strict=False)): + return True + elif isinstance(f, TCPFlow): + for msg in f.messages: + if self.re.search(msg.content): + return True return False @@ -200,20 +254,32 @@ class FBodRequest(_Rex): code = "bq" help = "Request body" + @only(HTTPFlow, TCPFlow) def __call__(self, f): - if f.request and f.request.content: - if self.re.search(f.request.get_decoded_content()): - return True + if isinstance(f, HTTPFlow): + if f.request and f.request.raw_content: + if self.re.search(f.request.get_content(strict=False)): + return True + elif isinstance(f, TCPFlow): + for msg in f.messages: + if msg.from_client and self.re.search(msg.content): + return True class FBodResponse(_Rex): code = "bs" help = "Response body" + @only(HTTPFlow, TCPFlow) def __call__(self, f): - if f.response and f.response.content: - if self.re.search(f.response.get_decoded_content()): - return True + if isinstance(f, HTTPFlow): + if f.response and f.response.raw_content: + if self.re.search(f.response.get_content(strict=False)): + return 
True + elif isinstance(f, TCPFlow): + for msg in f.messages: + if not msg.from_client and self.re.search(msg.content): + return True class FMethod(_Rex): @@ -221,8 +287,9 @@ class FMethod(_Rex): help = "Method" flags = re.IGNORECASE + @only(HTTPFlow) def __call__(self, f): - return bool(self.re.search(f.request.method)) + return bool(self.re.search(f.request.data.method)) class FDomain(_Rex): @@ -230,13 +297,15 @@ class FDomain(_Rex): help = "Domain" flags = re.IGNORECASE + @only(HTTPFlow) def __call__(self, f): - return bool(self.re.search(f.request.host)) + return bool(self.re.search(f.request.data.host)) class FUrl(_Rex): code = "u" help = "URL" + is_binary = False # FUrl is special, because it can be "naked". @classmethod @@ -245,6 +314,7 @@ class FUrl(_Rex): toks = toks[1:] return klass(*toks) + @only(HTTPFlow) def __call__(self, f): return self.re.search(f.request.url) @@ -252,6 +322,7 @@ class FUrl(_Rex): class FSrc(_Rex): code = "src" help = "Match source address" + is_binary = False def __call__(self, f): return f.client_conn.address and self.re.search(repr(f.client_conn.address)) @@ -260,6 +331,7 @@ class FSrc(_Rex): class FDst(_Rex): code = "dst" help = "Match destination address" + is_binary = False def __call__(self, f): return f.server_conn.address and self.re.search(repr(f.server_conn.address)) @@ -275,6 +347,7 @@ class FCode(_Int): code = "c" help = "HTTP response code" + @only(HTTPFlow) def __call__(self, f): if f.response and f.response.status_code == self.num: return True @@ -322,26 +395,28 @@ class FNot(_Token): filt_unary = [ + FAsset, + FErr, + FHTTP, FReq, FResp, - FAsset, - FErr + FTCP, ] filt_rex = [ - FHeadRequest, - FHeadResponse, - FHead, + FBod, FBodRequest, FBodResponse, - FBod, - FMethod, - FDomain, - FUrl, - FRequestContentType, - FResponseContentType, FContentType, - FSrc, + FContentTypeRequest, + FContentTypeResponse, + FDomain, FDst, + FHead, + FHeadRequest, + FHeadResponse, + FMethod, + FSrc, + FUrl, ] filt_int = [ FCode diff --git a/mitmproxy/flow/__init__.py b/mitmproxy/flow/__init__.py index c14a0fec..b2ab74c6 100644 --- a/mitmproxy/flow/__init__.py +++ b/mitmproxy/flow/__init__.py @@ -4,10 +4,10 @@ from mitmproxy.flow import export, modules from mitmproxy.flow.io import FlowWriter, FilteredFlowWriter, FlowReader, read_flows_from_paths from mitmproxy.flow.master import FlowMaster from mitmproxy.flow.modules import ( - AppRegistry, ReplaceHooks, SetHeaders, StreamLargeBodies, ClientPlaybackState, - ServerPlaybackState, StickyCookieState, StickyAuthState + AppRegistry, StreamLargeBodies, ClientPlaybackState, ServerPlaybackState ) from mitmproxy.flow.state import State, FlowView +from mitmproxy.flow import options # TODO: We may want to remove the imports from .modules and just expose "modules" @@ -15,7 +15,6 @@ __all__ = [ "export", "modules", "FlowWriter", "FilteredFlowWriter", "FlowReader", "read_flows_from_paths", "FlowMaster", - "AppRegistry", "ReplaceHooks", "SetHeaders", "StreamLargeBodies", "ClientPlaybackState", - "ServerPlaybackState", "StickyCookieState", "StickyAuthState", - "State", "FlowView", + "AppRegistry", "StreamLargeBodies", "ClientPlaybackState", + "ServerPlaybackState", "State", "FlowView", "options", ] diff --git a/mitmproxy/flow/export.py b/mitmproxy/flow/export.py index f0ac02ab..deeeb998 100644 --- a/mitmproxy/flow/export.py +++ b/mitmproxy/flow/export.py @@ -4,32 +4,46 @@ import json import re from textwrap import dedent -from six.moves.urllib.parse import quote, quote_plus +import six +from six.moves import urllib import 
netlib.http +def _native(s): + if six.PY2: + if isinstance(s, six.text_type): + return s.encode() + else: + if isinstance(s, six.binary_type): + return s.decode() + return s + + def dictstr(items, indent): lines = [] for k, v in items: - lines.append(indent + "%s: %s,\n" % (repr(k), repr(v))) + lines.append(indent + "%s: %s,\n" % (repr(_native(k)), repr(_native(v)))) return "{\n%s}\n" % "".join(lines) def curl_command(flow): data = "curl " - for k, v in flow.request.headers.fields: + request = flow.request.copy() + request.decode(strict=False) + + for k, v in request.headers.items(multi=True): data += "-H '%s:%s' " % (k, v) - if flow.request.method != "GET": - data += "-X %s " % flow.request.method + if request.method != "GET": + data += "-X %s " % request.method - full_url = flow.request.scheme + "://" + flow.request.host + flow.request.path + full_url = request.scheme + "://" + request.host + request.path data += "'%s'" % full_url - if flow.request.content: - data += " --data-binary '%s'" % flow.request.content + if request.content: + data += " --data-binary '%s'" % _native(request.content) return data @@ -48,7 +62,7 @@ def python_code(flow): print(response.text) """).strip() - components = map(lambda x: quote(x, safe=""), flow.request.path_components) + components = [urllib.parse.quote(c, safe="") for c in flow.request.path_components] url = flow.request.scheme + "://" + flow.request.host + "/" + "/".join(components) args = "" @@ -64,12 +78,12 @@ def python_code(flow): data = "" if flow.request.body: - json_obj = is_json(flow.request.headers, flow.request.body) + json_obj = is_json(flow.request.headers, flow.request.content) if json_obj: data = "\njson = %s\n" % dictstr(sorted(json_obj.items()), " ") args += "\n json=json," else: - data = "\ndata = '''%s'''\n" % flow.request.body + data = "\ndata = '''%s'''\n" % _native(flow.request.content) args += "\n data=data," code = code.format( @@ -85,15 +99,16 @@ def python_code(flow): def raw_request(flow): data = netlib.http.http1.assemble_request(flow.request) - return data + return _native(data) def is_json(headers, content): + # type: (netlib.http.Headers, bytes) -> bool if headers: ct = netlib.http.parse_content_type(headers.get("content-type", "")) if ct and "%s/%s" % (ct[0], ct[1]) == "application/json": try: - return json.loads(content) + return json.loads(content.decode("utf8", "surrogateescape")) except ValueError: return False return False @@ -126,17 +141,21 @@ def locust_code(flow): max_wait = 3000 """).strip() - components = map(lambda x: quote(x, safe=""), flow.request.path_components) - file_name = "_".join(components) - name = re.sub('\W|^(?=\d)', '_', file_name) - url = flow.request.scheme + "://" + flow.request.host + "/" + "/".join(components) + components = [urllib.parse.quote(c, safe="") for c in flow.request.path_components] + name = re.sub('\W|^(?=\d)', '_', "_".join(components)) if name == "" or name is None: new_name = "_".join([str(flow.request.host), str(flow.request.timestamp_start)]) name = re.sub('\W|^(?=\d)', '_', new_name) + + url = flow.request.scheme + "://" + flow.request.host + "/" + "/".join(components) + args = "" headers = "" if flow.request.headers: - lines = [(k, v) for k, v in flow.request.headers.fields if k.lower() not in ["host", "cookie"]] + lines = [ + (_native(k), _native(v)) for k, v in flow.request.headers.fields + if _native(k).lower() not in ["host", "cookie"] + ] lines = [" '%s': '%s',\n" % (k, v) for k, v in lines] headers += "\n headers = {\n%s }\n" % "".join(lines) args += "\n 
headers=headers," @@ -148,8 +167,8 @@ def locust_code(flow): args += "\n params=params," data = "" - if flow.request.body: - data = "\n data = '''%s'''\n" % flow.request.body + if flow.request.content: + data = "\n data = '''%s'''\n" % _native(flow.request.content) args += "\n data=data," code = code.format( @@ -164,8 +183,8 @@ def locust_code(flow): host = flow.request.scheme + "://" + flow.request.host code = code.replace(host, "' + self.locust.host + '") - code = code.replace(quote_plus(host), "' + quote_plus(self.locust.host) + '") - code = code.replace(quote(host), "' + quote(self.locust.host) + '") + code = code.replace(urllib.parse.quote_plus(host), "' + quote_plus(self.locust.host) + '") + code = code.replace(urllib.parse.quote(host), "' + quote(self.locust.host) + '") code = code.replace("'' + ", "") return code diff --git a/mitmproxy/flow/io.py b/mitmproxy/flow/io.py index 671ddf43..276d7a5b 100644 --- a/mitmproxy/flow/io.py +++ b/mitmproxy/flow/io.py @@ -49,7 +49,7 @@ class FlowReader: yield models.FLOW_TYPES[data["type"]].from_state(data) except ValueError: # Error is due to EOF - if can_tell and self.fo.tell() == off and self.fo.read() == '': + if can_tell and self.fo.tell() == off and self.fo.read() == b'': return raise exceptions.FlowReadException("Invalid data format.") diff --git a/mitmproxy/flow/io_compat.py b/mitmproxy/flow/io_compat.py index 1023e87f..bcfbd375 100644 --- a/mitmproxy/flow/io_compat.py +++ b/mitmproxy/flow/io_compat.py @@ -3,51 +3,112 @@ This module handles the import of mitmproxy flows generated by old versions. """ from __future__ import absolute_import, print_function, division -from netlib import version +import six + +from netlib import version, strutils + + +def convert_011_012(data): + data[b"version"] = (0, 12) + return data + + +def convert_012_013(data): + data[b"version"] = (0, 13) + return data def convert_013_014(data): - data["request"]["first_line_format"] = data["request"].pop("form_in") - data["request"]["http_version"] = "HTTP/" + ".".join(str(x) for x in data["request"].pop("httpversion")) - data["response"]["status_code"] = data["response"].pop("code") - data["response"]["body"] = data["response"].pop("content") - data["server_conn"].pop("state") - data["server_conn"]["via"] = None - data["version"] = (0, 14) + data[b"request"][b"first_line_format"] = data[b"request"].pop(b"form_in") + data[b"request"][b"http_version"] = b"HTTP/" + ".".join( + str(x) for x in data[b"request"].pop(b"httpversion")).encode() + data[b"response"][b"http_version"] = b"HTTP/" + ".".join( + str(x) for x in data[b"response"].pop(b"httpversion")).encode() + data[b"response"][b"status_code"] = data[b"response"].pop(b"code") + data[b"response"][b"body"] = data[b"response"].pop(b"content") + data[b"server_conn"].pop(b"state") + data[b"server_conn"][b"via"] = None + data[b"version"] = (0, 14) return data def convert_014_015(data): - data["version"] = (0, 15) + data[b"version"] = (0, 15) return data def convert_015_016(data): - for m in ("request", "response"): - if "body" in data[m]: - data[m]["content"] = data[m].pop("body") - if "httpversion" in data[m]: - data[m]["http_version"] = data[m].pop("httpversion") - if "msg" in data["response"]: - data["response"]["reason"] = data["response"].pop("msg") - data["request"].pop("form_out", None) - data["version"] = (0, 16) + for m in (b"request", b"response"): + if b"body" in data[m]: + data[m][b"content"] = data[m].pop(b"body") + if b"msg" in data[b"response"]: + data[b"response"][b"reason"] = 
data[b"response"].pop(b"msg") + data[b"request"].pop(b"form_out", None) + data[b"version"] = (0, 16) return data def convert_016_017(data): - data["server_conn"]["peer_address"] = None - data["version"] = (0, 17) + data[b"server_conn"][b"peer_address"] = None + data[b"version"] = (0, 17) return data def convert_017_018(data): + # convert_unicode needs to be called for every dual release and the first py3-only release + data = convert_unicode(data) + data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address") data["version"] = (0, 18) return data +def _convert_dict_keys(o): + # type: (Any) -> Any + if isinstance(o, dict): + return {strutils.native(k): _convert_dict_keys(v) for k, v in o.items()} + else: + return o + + +def _convert_dict_vals(o, values_to_convert): + # type: (dict, dict) -> dict + for k, v in values_to_convert.items(): + if not o or k not in o: + continue + if v is True: + o[k] = strutils.native(o[k]) + else: + _convert_dict_vals(o[k], v) + return o + + +def convert_unicode(data): + # type: (dict) -> dict + """ + The Python 2 version of mitmproxy serializes everything as bytes. + This method converts between Python 3 and Python 2 dumpfiles. + """ + if not six.PY2: + data = _convert_dict_keys(data) + data = _convert_dict_vals( + data, { + "type": True, + "id": True, + "request": { + "first_line_format": True + }, + "error": { + "msg": True + } + } + ) + return data + + converters = { + (0, 11): convert_011_012, + (0, 12): convert_012_013, (0, 13): convert_013_014, (0, 14): convert_014_015, (0, 15): convert_015_016, @@ -58,14 +119,17 @@ converters = { def migrate_flow(flow_data): while True: - flow_version = tuple(flow_data["version"][:2]) - if flow_version == version.IVERSION[:2]: + flow_version = tuple(flow_data.get(b"version", flow_data.get("version"))) + if flow_version[:2] == version.IVERSION[:2]: break - elif flow_version in converters: - flow_data = converters[flow_version](flow_data) + elif flow_version[:2] in converters: + flow_data = converters[flow_version[:2]](flow_data) else: - v = ".".join(str(i) for i in flow_data["version"]) + v = ".".join(str(i) for i in flow_version) raise ValueError( "{} cannot read files serialized with version {}.".format(version.MITMPROXY, v) ) + # TODO: This should finally be moved in the converter for the first py3-only release. + # It's here so that a py2 0.18 dump can be read by py3 0.18 and vice versa. 
+ flow_data = convert_unicode(flow_data) return flow_data diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py index 289102a1..64a242ba 100644 --- a/mitmproxy/flow/master.py +++ b/mitmproxy/flow/master.py @@ -8,15 +8,12 @@ from typing import List, Optional, Set # noqa import netlib.exceptions from mitmproxy import controller from mitmproxy import exceptions -from mitmproxy import filt from mitmproxy import models -from mitmproxy import script from mitmproxy.flow import io from mitmproxy.flow import modules from mitmproxy.onboarding import app from mitmproxy.protocol import http_replay from mitmproxy.proxy.config import HostMatcher -from netlib import strutils class FlowMaster(controller.Master): @@ -28,35 +25,20 @@ class FlowMaster(controller.Master): if len(self.servers) > 0: return self.servers[0] - def __init__(self, server, state): - super(FlowMaster, self).__init__() + def __init__(self, options, server, state): + super(FlowMaster, self).__init__(options) if server: self.add_server(server) self.state = state - self.active_flows = set() # type: Set[models.Flow] self.server_playback = None # type: Optional[modules.ServerPlaybackState] self.client_playback = None # type: Optional[modules.ClientPlaybackState] self.kill_nonreplay = False - self.scripts = [] # type: List[script.Script] - self.pause_scripts = False - self.stickycookie_state = None # type: Optional[modules.StickyCookieState] - self.stickycookie_txt = None - - self.stickyauth_state = False # type: Optional[modules.StickyAuthState] - self.stickyauth_txt = None - - self.anticache = False - self.anticomp = False self.stream_large_bodies = None # type: Optional[modules.StreamLargeBodies] - self.refresh_server_playback = False - self.replacehooks = modules.ReplaceHooks() - self.setheaders = modules.SetHeaders() self.replay_ignore_params = False self.replay_ignore_content = None self.replay_ignore_host = False - self.stream = None self.apps = modules.AppRegistry() def start_app(self, host, port): @@ -66,48 +48,6 @@ class FlowMaster(controller.Master): port ) - def add_event(self, e, level="info"): - """ - level: debug, info, error - """ - - def unload_scripts(self): - for s in self.scripts[:]: - self.unload_script(s) - - def unload_script(self, script_obj): - try: - script_obj.unload() - except script.ScriptException as e: - self.add_event("Script error:\n" + str(e), "error") - script.reloader.unwatch(script_obj) - self.scripts.remove(script_obj) - - def load_script(self, command, use_reloader=False): - """ - Loads a script. - - Raises: - ScriptException - """ - s = script.Script(command, script.ScriptContext(self)) - s.load() - if use_reloader: - script.reloader.watch(s, lambda: self.event_queue.put(("script_change", s))) - self.scripts.append(s) - - def _run_single_script_hook(self, script_obj, name, *args, **kwargs): - if script_obj and not self.pause_scripts: - try: - script_obj.run(name, *args, **kwargs) - except script.ScriptException as e: - self.add_event("Script error:\n{}".format(e), "error") - - def run_scripts(self, name, msg): - for script_obj in self.scripts: - if not msg.reply.acked: - self._run_single_script_hook(script_obj, name, msg) - def get_ignore_filter(self): return self.server.config.check_ignore.patterns @@ -120,34 +60,12 @@ class FlowMaster(controller.Master): def set_tcp_filter(self, host_patterns): self.server.config.check_tcp = HostMatcher(host_patterns) - def set_stickycookie(self, txt): - if txt: - flt = filt.parse(txt) - if not flt: - return "Invalid filter expression." 
- self.stickycookie_state = modules.StickyCookieState(flt) - self.stickycookie_txt = txt - else: - self.stickycookie_state = None - self.stickycookie_txt = None - def set_stream_large_bodies(self, max_size): if max_size is not None: self.stream_large_bodies = modules.StreamLargeBodies(max_size) else: self.stream_large_bodies = False - def set_stickyauth(self, txt): - if txt: - flt = filt.parse(txt) - if not flt: - return "Invalid filter expression." - self.stickyauth_state = modules.StickyAuthState(flt) - self.stickyauth_txt = txt - else: - self.stickyauth_state = None - self.stickyauth_txt = None - def start_client_playback(self, flows, exit): """ flows: List of flows. @@ -201,7 +119,7 @@ class FlowMaster(controller.Master): return None response = rflow.response.copy() response.is_replay = True - if self.refresh_server_playback: + if self.options.refresh_server_playback: response.refresh() flow.response = response return True @@ -235,8 +153,12 @@ class FlowMaster(controller.Master): return super(FlowMaster, self).tick(timeout) def duplicate_flow(self, f): + """ + Duplicate flow, and insert it into state without triggering any of + the normal flow events. + """ f2 = f.copy() - self.load_flow(f2) + self.state.add_flow(f2) return f2 def create_request(self, method, scheme, host, port, path): @@ -321,47 +243,36 @@ class FlowMaster(controller.Master): raise exceptions.FlowReadException(v.strerror) def process_new_request(self, f): - if self.stickycookie_state: - self.stickycookie_state.handle_request(f) - if self.stickyauth_state: - self.stickyauth_state.handle_request(f) - - if self.anticache: - f.request.anticache() - if self.anticomp: - f.request.anticomp() - if self.server_playback: pb = self.do_server_playback(f) if not pb and self.kill_nonreplay: f.kill(self) - def process_new_response(self, f): - if self.stickycookie_state: - self.stickycookie_state.handle_response(f) - - def replay_request(self, f, block=False, run_scripthooks=True): + def replay_request(self, f, block=False): """ Returns None if successful, or error message if not. """ - if f.live and run_scripthooks: + if f.live: return "Can't replay live request." if f.intercepted: return "Can't replay while intercepting..." - if f.request.content is None: + if f.request.raw_content is None: return "Can't replay request with missing content..." if f.request: f.backup() f.request.is_replay = True + + # TODO: We should be able to remove this. 
if "Content-Length" in f.request.headers: - f.request.headers["Content-Length"] = str(len(f.request.content)) + f.request.headers["Content-Length"] = str(len(f.request.raw_content)) + f.response = None f.error = None self.process_new_request(f) rt = http_replay.RequestReplayThread( self.server.config, f, - self.event_queue if run_scripthooks else False, + self.event_queue, self.should_exit ) rt.start() # pragma: no cover @@ -370,32 +281,31 @@ class FlowMaster(controller.Master): @controller.handler def log(self, l): - self.add_event(l.msg, l.level) + self.add_log(l.msg, l.level) @controller.handler def clientconnect(self, root_layer): - self.run_scripts("clientconnect", root_layer) + pass @controller.handler def clientdisconnect(self, root_layer): - self.run_scripts("clientdisconnect", root_layer) + pass @controller.handler def serverconnect(self, server_conn): - self.run_scripts("serverconnect", server_conn) + pass @controller.handler def serverdisconnect(self, server_conn): - self.run_scripts("serverdisconnect", server_conn) + pass @controller.handler def next_layer(self, top_layer): - self.run_scripts("next_layer", top_layer) + pass @controller.handler def error(self, f): self.state.update_flow(f) - self.run_scripts("error", f) if self.client_playback: self.client_playback.clear(f) return f @@ -411,20 +321,13 @@ class FlowMaster(controller.Master): **{"mitmproxy.master": self} ) if err: - self.add_event("Error in wsgi app. %s" % err, "error") + self.add_log("Error in wsgi app. %s" % err, "error") f.reply.kill() return if f not in self.state.flows: # don't add again on replay self.state.add_flow(f) - self.active_flows.add(f) - if not f.reply.acked: - self.replacehooks.run(f) - if not f.reply.acked: - self.setheaders.run(f) if not f.reply.acked: self.process_new_request(f) - if not f.reply.acked: - self.run_scripts("request", f) return f @controller.handler @@ -435,24 +338,14 @@ class FlowMaster(controller.Master): except netlib.exceptions.HttpException: f.reply.kill() return - self.run_scripts("responseheaders", f) return f @controller.handler def response(self, f): - self.active_flows.discard(f) self.state.update_flow(f) if not f.reply.acked: - self.replacehooks.run(f) - if not f.reply.acked: - self.setheaders.run(f) - self.run_scripts("response", f) - if not f.reply.acked: if self.client_playback: self.client_playback.clear(f) - self.process_new_response(f) - if self.stream: - self.stream.add(f) return f def handle_intercept(self, f): @@ -462,91 +355,22 @@ class FlowMaster(controller.Master): self.state.update_flow(f) @controller.handler - def script_change(self, s): - """ - Handle a script whose contents have been changed on the file system. - - Args: - s (script.Script): the changed script - - Returns: - True, if reloading was successful. - False, otherwise. - """ - ok = True - # We deliberately do not want to fail here. - # In the worst case, we have an "empty" script object. - try: - s.unload() - except script.ScriptException as e: - ok = False - self.add_event('Error reloading "{}":\n{}'.format(s.filename, e), 'error') - try: - s.load() - except script.ScriptException as e: - ok = False - self.add_event('Error reloading "{}":\n{}'.format(s.filename, e), 'error') - else: - self.add_event('"{}" reloaded.'.format(s.filename), 'info') - return ok - - @controller.handler def tcp_open(self, flow): # TODO: This would break mitmproxy currently. 
# self.state.add_flow(flow) - self.active_flows.add(flow) - self.run_scripts("tcp_open", flow) + pass @controller.handler def tcp_message(self, flow): - self.run_scripts("tcp_message", flow) - message = flow.messages[-1] - direction = "->" if message.from_client else "<-" - self.add_event("{client} {direction} tcp {direction} {server}".format( - client=repr(flow.client_conn.address), - server=repr(flow.server_conn.address), - direction=direction, - ), "info") - self.add_event(strutils.clean_bin(message.content), "debug") + pass @controller.handler def tcp_error(self, flow): - self.add_event("Error in TCP connection to {}: {}".format( + self.add_log("Error in TCP connection to {}: {}".format( repr(flow.server_conn.address), flow.error ), "info") - self.run_scripts("tcp_error", flow) @controller.handler def tcp_close(self, flow): - self.active_flows.discard(flow) - if self.stream: - self.stream.add(flow) - self.run_scripts("tcp_close", flow) - - def shutdown(self): - super(FlowMaster, self).shutdown() - - # Add all flows that are still active - if self.stream: - for flow in self.active_flows: - self.stream.add(flow) - self.stop_stream() - - self.unload_scripts() - - def start_stream(self, fp, filt): - self.stream = io.FilteredFlowWriter(fp, filt) - - def stop_stream(self): - self.stream.fo.close() - self.stream = None - - def start_stream_to_path(self, path, mode="wb", filt=None): - path = os.path.expanduser(path) - try: - f = open(path, mode) - self.start_stream(f, filt) - except IOError as v: - return str(v) - self.stream_path = path + pass diff --git a/mitmproxy/flow/modules.py b/mitmproxy/flow/modules.py index 2998d259..fb3c52da 100644 --- a/mitmproxy/flow/modules.py +++ b/mitmproxy/flow/modules.py @@ -1,17 +1,13 @@ from __future__ import absolute_import, print_function, division -import collections import hashlib -import re -from six.moves import http_cookiejar from six.moves import urllib from mitmproxy import controller -from mitmproxy import filt from netlib import wsgi from netlib import version -from netlib.http import cookies +from netlib import strutils from netlib.http import http1 @@ -42,112 +38,6 @@ class AppRegistry: return self.apps.get((host, request.port), None) -class ReplaceHooks: - def __init__(self): - self.lst = [] - - def set(self, r): - self.clear() - for i in r: - self.add(*i) - - def add(self, fpatt, rex, s): - """ - add a replacement hook. - - fpatt: a string specifying a filter pattern. - rex: a regular expression. - s: the replacement string - - returns true if hook was added, false if the pattern could not be - parsed. - """ - cpatt = filt.parse(fpatt) - if not cpatt: - return False - try: - re.compile(rex) - except re.error: - return False - self.lst.append((fpatt, rex, s, cpatt)) - return True - - def get_specs(self): - """ - Retrieve the hook specifcations. Returns a list of (fpatt, rex, s) - tuples. - """ - return [i[:3] for i in self.lst] - - def count(self): - return len(self.lst) - - def run(self, f): - for _, rex, s, cpatt in self.lst: - if cpatt(f): - if f.response: - f.response.replace(rex, s) - else: - f.request.replace(rex, s) - - def clear(self): - self.lst = [] - - -class SetHeaders: - def __init__(self): - self.lst = [] - - def set(self, r): - self.clear() - for i in r: - self.add(*i) - - def add(self, fpatt, header, value): - """ - Add a set header hook. - - fpatt: String specifying a filter pattern. - header: Header name. - value: Header value string - - Returns True if hook was added, False if the pattern could not be - parsed. 
- """ - cpatt = filt.parse(fpatt) - if not cpatt: - return False - self.lst.append((fpatt, header, value, cpatt)) - return True - - def get_specs(self): - """ - Retrieve the hook specifcations. Returns a list of (fpatt, rex, s) - tuples. - """ - return [i[:3] for i in self.lst] - - def count(self): - return len(self.lst) - - def clear(self): - self.lst = [] - - def run(self, f): - for _, header, value, cpatt in self.lst: - if cpatt(f): - if f.response: - f.response.headers.pop(header, None) - else: - f.request.headers.pop(header, None) - for _, header, value, cpatt in self.lst: - if cpatt(f): - if f.response: - f.response.headers.add(header, value) - else: - f.request.headers.add(header, value) - - class StreamLargeBodies(object): def __init__(self, max_size): self.max_size = max_size @@ -157,7 +47,7 @@ class StreamLargeBodies(object): expected_size = http1.expected_http_body_size( flow.request, flow.response if not is_request else None ) - if not r.content and not (0 <= expected_size <= self.max_size): + if not r.raw_content and not (0 <= expected_size <= self.max_size): # r.stream may already be a callable, which we want to preserve. r.stream = r.stream or True @@ -216,7 +106,7 @@ class ServerPlaybackState: self.nopop = nopop self.ignore_params = ignore_params self.ignore_content = ignore_content - self.ignore_payload_params = ignore_payload_params + self.ignore_payload_params = [strutils.always_bytes(x) for x in (ignore_payload_params or ())] self.ignore_host = ignore_host self.fmap = {} for i in flows: @@ -251,7 +141,7 @@ class ServerPlaybackState: if p[0] not in self.ignore_payload_params ) else: - key.append(str(r.content)) + key.append(str(r.raw_content)) if not self.ignore_host: key.append(r.host) @@ -271,7 +161,7 @@ class ServerPlaybackState: v = r.headers.get(i) headers.append((i, v)) key.append(headers) - return hashlib.sha256(repr(key)).digest() + return hashlib.sha256(repr(key).encode("utf8", "surrogateescape")).digest() def next_flow(self, request): """ @@ -286,73 +176,3 @@ class ServerPlaybackState: return l[0] else: return l.pop(0) - - -class StickyCookieState: - def __init__(self, flt): - """ - flt: Compiled filter. - """ - self.jar = collections.defaultdict(dict) - self.flt = flt - - def ckey(self, attrs, f): - """ - Returns a (domain, port, path) tuple. - """ - domain = f.request.host - path = "/" - if "domain" in attrs: - domain = attrs["domain"] - if "path" in attrs: - path = attrs["path"] - return (domain, f.request.port, path) - - def domain_match(self, a, b): - if http_cookiejar.domain_match(a, b): - return True - elif http_cookiejar.domain_match(a, b.strip(".")): - return True - return False - - def handle_response(self, f): - for name, (value, attrs) in f.response.cookies.items(multi=True): - # FIXME: We now know that Cookie.py screws up some cookies with - # valid RFC 822/1123 datetime specifications for expiry. Sigh. 
- a = self.ckey(attrs, f) - if self.domain_match(f.request.host, a[0]): - b = attrs.with_insert(0, name, value) - self.jar[a][name] = b - - def handle_request(self, f): - l = [] - if f.match(self.flt): - for domain, port, path in self.jar.keys(): - match = [ - self.domain_match(f.request.host, domain), - f.request.port == port, - f.request.path.startswith(path) - ] - if all(match): - c = self.jar[(domain, port, path)] - l.extend([cookies.format_cookie_header(c[name].items(multi=True)) for name in c.keys()]) - if l: - f.request.stickycookie = True - f.request.headers["cookie"] = "; ".join(l) - - -class StickyAuthState: - def __init__(self, flt): - """ - flt: Compiled filter. - """ - self.flt = flt - self.hosts = {} - - def handle_request(self, f): - host = f.request.host - if "authorization" in f.request.headers: - self.hosts[host] = f.request.headers["authorization"] - elif f.match(self.flt): - if host in self.hosts: - f.request.headers["authorization"] = self.hosts[host] diff --git a/mitmproxy/flow/options.py b/mitmproxy/flow/options.py new file mode 100644 index 00000000..6c2e3933 --- /dev/null +++ b/mitmproxy/flow/options.py @@ -0,0 +1,69 @@ +from __future__ import absolute_import, print_function, division +from mitmproxy import options +from typing import Tuple, Optional, Sequence # noqa + +APP_HOST = "mitm.it" +APP_PORT = 80 + + +class Options(options.Options): + def __init__( + self, + # TODO: rename to onboarding_app_* + app=True, # type: bool + app_host=APP_HOST, # type: str + app_port=APP_PORT, # type: int + anticache=False, # type: bool + anticomp=False, # type: bool + client_replay=None, # type: Optional[str] + kill=False, # type: bool + no_server=False, # type: bool + nopop=False, # type: bool + refresh_server_playback=False, # type: bool + rfile=None, # type: Optional[str] + scripts=(), # type: Sequence[str] + showhost=False, # type: bool + replacements=(), # type: Sequence[Tuple[str, str, str]] + rheaders=(), # type: Sequence[str] + setheaders=(), # type: Sequence[Tuple[str, str, str]] + server_replay=None, # type: Optional[str] + stickycookie=None, # type: Optional[str] + stickyauth=None, # type: Optional[str] + stream_large_bodies=None, # type: Optional[str] + verbosity=2, # type: int + outfile=None, # type: Tuple[str, str] + replay_ignore_content=False, # type: bool + replay_ignore_params=(), # type: Sequence[str] + replay_ignore_payload_params=(), # type: Sequence[str] + replay_ignore_host=False, # type: bool + ): + # We could replace all assignments with clever metaprogramming, + # but type hints are a much more valueable asset. 
+ + self.app = app + self.app_host = app_host + self.app_port = app_port + self.anticache = anticache + self.anticomp = anticomp + self.client_replay = client_replay + self.kill = kill + self.no_server = no_server + self.nopop = nopop + self.refresh_server_playback = refresh_server_playback + self.rfile = rfile + self.scripts = scripts + self.showhost = showhost + self.replacements = replacements + self.rheaders = rheaders + self.setheaders = setheaders + self.server_replay = server_replay + self.stickycookie = stickycookie + self.stickyauth = stickyauth + self.stream_large_bodies = stream_large_bodies + self.verbosity = verbosity + self.outfile = outfile + self.replay_ignore_content = replay_ignore_content + self.replay_ignore_params = replay_ignore_params + self.replay_ignore_payload_params = replay_ignore_payload_params + self.replay_ignore_host = replay_ignore_host + super(Options, self).__init__() diff --git a/mitmproxy/main.py b/mitmproxy/main.py index bf01a3cb..316db91a 100644 --- a/mitmproxy/main.py +++ b/mitmproxy/main.py @@ -76,7 +76,11 @@ def mitmproxy(args=None): # pragma: no cover server = get_server(console_options.no_server, proxy_config) - m = console.master.ConsoleMaster(server, console_options) + try: + m = console.master.ConsoleMaster(server, console_options) + except exceptions.OptionsError as e: + print("mitmproxy: %s" % e, file=sys.stderr) + sys.exit(1) try: m.run() except (KeyboardInterrupt, _thread.error): @@ -109,11 +113,14 @@ def mitmdump(args=None): # pragma: no cover signal.signal(signal.SIGTERM, cleankill) master.run() - except dump.DumpError as e: + except (dump.DumpError, exceptions.OptionsError) as e: print("mitmdump: %s" % e, file=sys.stderr) sys.exit(1) except (KeyboardInterrupt, _thread.error): pass + if master.has_errored: + print("mitmdump: errors occurred during run", file=sys.stderr) + sys.exit(1) def mitmweb(args=None): # pragma: no cover @@ -137,7 +144,11 @@ def mitmweb(args=None): # pragma: no cover server = get_server(web_options.no_server, proxy_config) - m = web.master.WebMaster(server, web_options) + try: + m = web.master.WebMaster(server, web_options) + except exceptions.OptionsError as e: + print("mitmweb: %s" % e, file=sys.stderr) + sys.exit(1) try: m.run() except (KeyboardInterrupt, _thread.error): diff --git a/mitmproxy/models/connections.py b/mitmproxy/models/connections.py index b8e0567a..570e89a9 100644 --- a/mitmproxy/models/connections.py +++ b/mitmproxy/models/connections.py @@ -205,6 +205,8 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): self.wfile.flush() def establish_ssl(self, clientcerts, sni, **kwargs): + if sni and not isinstance(sni, six.string_types): + raise ValueError("sni must be str, not " + type(sni).__name__) clientcert = None if clientcerts: if os.path.isfile(clientcerts): @@ -212,7 +214,7 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): else: path = os.path.join( clientcerts, - self.address.host.encode("idna")) + ".pem" + self.address.host.encode("idna").decode()) + ".pem" if os.path.exists(path): clientcert = path diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py index 0e4f80cb..f4993b7a 100644 --- a/mitmproxy/models/flow.py +++ b/mitmproxy/models/flow.py @@ -9,6 +9,7 @@ from mitmproxy.models.connections import ClientConnection from mitmproxy.models.connections import ServerConnection from netlib import version +from typing import Optional # noqa class Error(stateobject.StateObject): @@ -70,18 +71,13 @@ class Flow(stateobject.StateObject): def __init__(self, type, 
client_conn, server_conn, live=None): self.type = type self.id = str(uuid.uuid4()) - self.client_conn = client_conn - """@type: ClientConnection""" - self.server_conn = server_conn - """@type: ServerConnection""" + self.client_conn = client_conn # type: ClientConnection + self.server_conn = server_conn # type: ServerConnection self.live = live - """@type: LiveConnection""" - self.error = None - """@type: Error""" - self.intercepted = False - """@type: bool""" - self._backup = None + self.error = None # type: Error + self.intercepted = False # type: bool + self._backup = None # type: Optional[Flow] self.reply = None _stateobject_attributes = dict( diff --git a/mitmproxy/models/http.py b/mitmproxy/models/http.py index 01f5f1ee..1fd28f00 100644 --- a/mitmproxy/models/http.py +++ b/mitmproxy/models/http.py @@ -1,9 +1,10 @@ from __future__ import absolute_import, print_function, division import cgi +import warnings +import six from mitmproxy.models.flow import Flow -from netlib import encoding from netlib import version from netlib.http import Headers from netlib.http import Request @@ -20,10 +21,8 @@ class MessageMixin(object): header. Doesn't change the message iteself or its headers. """ - ce = self.headers.get("content-encoding") - if not self.content or ce not in encoding.ENCODINGS: - return self.content - return encoding.decode(ce, self.content) + warnings.warn(".get_decoded_content() is deprecated, please use .content directly instead.", DeprecationWarning) + return self.content class HTTPRequest(MessageMixin, Request): @@ -220,7 +219,7 @@ class HTTPFlow(Flow): If f is a string, it will be compiled as a filter expression. If the expression is invalid, ValueError is raised. """ - if isinstance(f, basestring): + if isinstance(f, six.string_types): from .. import filt f = filt.parse(f) diff --git a/mitmproxy/models/tcp.py b/mitmproxy/models/tcp.py index e33475c2..6650141d 100644 --- a/mitmproxy/models/tcp.py +++ b/mitmproxy/models/tcp.py @@ -7,6 +7,8 @@ from typing import List import netlib.basetypes from mitmproxy.models.flow import Flow +import six + class TCPMessage(netlib.basetypes.Serializable): @@ -53,3 +55,22 @@ class TCPFlow(Flow): def __repr__(self): return "<TCPFlow ({} messages)>".format(len(self.messages)) + + def match(self, f): + """ + Match this flow against a compiled filter expression. Returns True + if matched, False if not. + + If f is a string, it will be compiled as a filter expression. If + the expression is invalid, ValueError is raised. + """ + if isinstance(f, six.string_types): + from .. import filt + + f = filt.parse(f) + if not f: + raise ValueError("Invalid filter expression.") + if f: + return f(self) + + return True diff --git a/mitmproxy/options.py b/mitmproxy/options.py new file mode 100644 index 00000000..04353dca --- /dev/null +++ b/mitmproxy/options.py @@ -0,0 +1,104 @@ +from __future__ import absolute_import, print_function, division + +import contextlib +import blinker +import pprint + +from mitmproxy import exceptions + + +class Options(object): + """ + .changed is a blinker Signal that triggers whenever options are + updated. If any handler in the chain raises an exceptions.OptionsError + exception, all changes are rolled back, the exception is suppressed, + and the .errored signal is notified. + """ + _initialized = False + attributes = [] + + def __new__(cls, *args, **kwargs): + # Initialize instance._opts before __init__ is called. + # This allows us to call super().__init__() last, which then sets + # ._initialized = True as the final operation. 
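The new mitmproxy/options.py module introduced here gives every tool a shared options container: writes after initialization go through a rollback context and fire a blinker "changed" signal, and setter()/toggler() hand out callables for UI bindings. A short usage sketch with a hypothetical subclass and option name:

from __future__ import print_function
from mitmproxy import options

class MyOptions(options.Options):
    # Hypothetical subclass: concrete option sets define their fields in
    # __init__ and then call the base constructor, which marks the instance
    # as initialized so later writes hit the rollback/changed machinery.
    def __init__(self, verbose=False):
        self.verbose = verbose
        super(MyOptions, self).__init__()

def on_change(opts):
    print("options changed:", opts)

o = MyOptions()
o.changed.connect(on_change)
o.verbose = True          # fires the "changed" signal
o.toggler("verbose")()    # flips it back, firing again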
+ instance = super(Options, cls).__new__(cls) + instance.__dict__["_opts"] = {} + return instance + + def __init__(self): + self.__dict__["changed"] = blinker.Signal() + self.__dict__["errored"] = blinker.Signal() + self.__dict__["_initialized"] = True + + @contextlib.contextmanager + def rollback(self): + old = self._opts.copy() + try: + yield + except exceptions.OptionsError as e: + # Notify error handlers + self.errored.send(self, exc=e) + # Rollback + self.__dict__["_opts"] = old + self.changed.send(self) + + def __eq__(self, other): + return self._opts == other._opts + + def __copy__(self): + return self.__class__(**self._opts) + + def __getattr__(self, attr): + if attr in self._opts: + return self._opts[attr] + else: + raise AttributeError() + + def __setattr__(self, attr, value): + if not self._initialized: + self._opts[attr] = value + return + if attr not in self._opts: + raise KeyError("No such option: %s" % attr) + with self.rollback(): + self._opts[attr] = value + self.changed.send(self) + + def get(self, k, d=None): + return self._opts.get(k, d) + + def update(self, **kwargs): + for k in kwargs: + if k not in self._opts: + raise KeyError("No such option: %s" % k) + with self.rollback(): + self._opts.update(kwargs) + self.changed.send(self) + + def setter(self, attr): + """ + Generate a setter for a given attribute. This returns a callable + taking a single argument. + """ + if attr not in self._opts: + raise KeyError("No such option: %s" % attr) + return lambda x: self.__setattr__(attr, x) + + def toggler(self, attr): + """ + Generate a toggler for a boolean attribute. This returns a callable + that takes no arguments. + """ + if attr not in self._opts: + raise KeyError("No such option: %s" % attr) + return lambda: self.__setattr__(attr, not getattr(self, attr)) + + def __repr__(self): + options = pprint.pformat(self._opts, indent=4).strip(" {}") + if "\n" in options: + options = "\n " + options + "\n" + return "{mod}.{cls}({{{options}}})".format( + mod=type(self).__module__, + cls=type(self).__name__, + options=options + ) diff --git a/mitmproxy/platform/osx.py b/mitmproxy/platform/osx.py index b5dce793..6a555f32 100644 --- a/mitmproxy/platform/osx.py +++ b/mitmproxy/platform/osx.py @@ -23,12 +23,12 @@ class Resolver(object): try: stxt = subprocess.check_output(self.STATECMD, stderr=subprocess.STDOUT) except subprocess.CalledProcessError as e: - if "sudo: a password is required" in e.output: + if "sudo: a password is required" in e.output.decode(errors="replace"): insufficient_priv = True else: raise RuntimeError("Error getting pfctl state: " + repr(e)) else: - insufficient_priv = "sudo: a password is required" in stxt + insufficient_priv = "sudo: a password is required" in stxt.decode(errors="replace") if insufficient_priv: raise RuntimeError( diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py index 187c17f6..2c70f288 100644 --- a/mitmproxy/protocol/http.py +++ b/mitmproxy/protocol/http.py @@ -41,10 +41,10 @@ class _HttpTransmissionLayer(base.Layer): yield "this is a generator" # pragma: no cover def send_response(self, response): - if response.content is None: + if response.data.content is None: raise netlib.exceptions.HttpException("Cannot assemble flow with missing content") self.send_response_headers(response) - self.send_response_body(response, [response.content]) + self.send_response_body(response, [response.data.content]) def send_response_headers(self, response): raise NotImplementedError() diff --git a/mitmproxy/protocol/http2.py 
b/mitmproxy/protocol/http2.py index b9a30c7e..ee66393f 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -5,7 +5,6 @@ import time import traceback import h2.exceptions -import hyperframe import six from h2 import connection from h2 import events @@ -55,12 +54,12 @@ class SafeH2Connection(connection.H2Connection): self.update_settings(new_settings) self.conn.send(self.data_to_send()) - def safe_send_headers(self, is_zombie, stream_id, headers): - # make sure to have a lock - if is_zombie(): # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") - self.send_headers(stream_id, headers.fields) - self.conn.send(self.data_to_send()) + def safe_send_headers(self, is_zombie, stream_id, headers, **kwargs): + with self.lock: + if is_zombie(): # pragma: no cover + raise exceptions.Http2ProtocolException("Zombie Stream") + self.send_headers(stream_id, headers.fields, **kwargs) + self.conn.send(self.data_to_send()) def safe_send_body(self, is_zombie, stream_id, chunks): for chunk in chunks: @@ -79,7 +78,7 @@ class SafeH2Connection(connection.H2Connection): self.send_data(stream_id, frame_chunk) try: self.conn.send(self.data_to_send()) - except Exception as e: + except Exception as e: # pragma: no cover raise e finally: self.lock.release() @@ -131,6 +130,7 @@ class Http2Layer(base.Layer): [repr(event)] ) + eid = None if hasattr(event, 'stream_id'): if is_server and event.stream_id % 2 == 1: eid = self.server_to_client_stream_ids[event.stream_id] @@ -138,92 +138,164 @@ class Http2Layer(base.Layer): eid = event.stream_id if isinstance(event, events.RequestReceived): - headers = netlib.http.Headers([[k, v] for k, v in event.headers]) - self.streams[eid] = Http2SingleStreamLayer(self, eid, headers) - self.streams[eid].timestamp_start = time.time() - self.streams[eid].start() + return self._handle_request_received(eid, event) elif isinstance(event, events.ResponseReceived): - headers = netlib.http.Headers([[k, v] for k, v in event.headers]) - self.streams[eid].queued_data_length = 0 - self.streams[eid].timestamp_start = time.time() - self.streams[eid].response_headers = headers - self.streams[eid].response_arrived.set() + return self._handle_response_received(eid, event) elif isinstance(event, events.DataReceived): - if self.config.body_size_limit and self.streams[eid].queued_data_length > self.config.body_size_limit: - raise netlib.exceptions.HttpException("HTTP body too large. 
Limit is {}.".format(self.config.body_size_limit)) - self.streams[eid].data_queue.put(event.data) - self.streams[eid].queued_data_length += len(event.data) - source_conn.h2.safe_increment_flow_control(event.stream_id, event.flow_controlled_length) + return self._handle_data_received(eid, event, source_conn) elif isinstance(event, events.StreamEnded): - self.streams[eid].timestamp_end = time.time() - self.streams[eid].data_finished.set() + return self._handle_stream_ended(eid) elif isinstance(event, events.StreamReset): - self.streams[eid].zombie = time.time() - if eid in self.streams and event.error_code == 0x8: - if is_server: - other_stream_id = self.streams[eid].client_stream_id - else: - other_stream_id = self.streams[eid].server_stream_id - if other_stream_id is not None: - other_conn.h2.safe_reset_stream(other_stream_id, event.error_code) + return self._handle_stream_reset(eid, event, is_server, other_conn) elif isinstance(event, events.RemoteSettingsChanged): - new_settings = dict([(id, cs.new_value) for (id, cs) in six.iteritems(event.changed_settings)]) - other_conn.h2.safe_update_settings(new_settings) + return self._handle_remote_settings_changed(event, other_conn) elif isinstance(event, events.ConnectionTerminated): - if event.error_code == h2.errors.NO_ERROR: - # Do not immediately terminate the other connection. - # Some streams might be still sending data to the client. - return False - else: - # Something terrible has happened - kill everything! - self.client_conn.h2.close_connection( - error_code=event.error_code, - last_stream_id=event.last_stream_id, - additional_data=event.additional_data - ) - self.client_conn.send(self.client_conn.h2.data_to_send()) - self._kill_all_streams() - return False - + return self._handle_connection_terminated(event) elif isinstance(event, events.PushedStreamReceived): - # pushed stream ids should be unique and not dependent on race conditions - # only the parent stream id must be looked up first - parent_eid = self.server_to_client_stream_ids[event.parent_stream_id] - with self.client_conn.h2.lock: - self.client_conn.h2.push_stream(parent_eid, event.pushed_stream_id, event.headers) - self.client_conn.send(self.client_conn.h2.data_to_send()) - - headers = netlib.http.Headers([[str(k), str(v)] for k, v in event.headers]) - self.streams[event.pushed_stream_id] = Http2SingleStreamLayer(self, event.pushed_stream_id, headers) - self.streams[event.pushed_stream_id].timestamp_start = time.time() - self.streams[event.pushed_stream_id].pushed = True - self.streams[event.pushed_stream_id].parent_stream_id = parent_eid - self.streams[event.pushed_stream_id].timestamp_end = time.time() - self.streams[event.pushed_stream_id].request_data_finished.set() - self.streams[event.pushed_stream_id].start() + return self._handle_pushed_stream_received(event) elif isinstance(event, events.PriorityUpdated): - stream_id = event.stream_id - if stream_id in self.streams.keys() and self.streams[stream_id].server_stream_id: - stream_id = self.streams[stream_id].server_stream_id + return self._handle_priority_updated(eid, event) + elif isinstance(event, events.TrailersReceived): + raise NotImplementedError('TrailersReceived not implemented') - depends_on = event.depends_on - if depends_on in self.streams.keys() and self.streams[depends_on].server_stream_id: - depends_on = self.streams[depends_on].server_stream_id + # fail-safe for unhandled events + return True - # weight is between 1 and 256 (inclusive), but represented as uint8 (0 to 255) - frame = 
hyperframe.frame.PriorityFrame(stream_id, depends_on, event.weight - 1, event.exclusive) - self.server_conn.send(frame.serialize()) - elif isinstance(event, events.TrailersReceived): - raise NotImplementedError() + def _handle_request_received(self, eid, event): + headers = netlib.http.Headers([[k, v] for k, v in event.headers]) + self.streams[eid] = Http2SingleStreamLayer(self, eid, headers) + self.streams[eid].timestamp_start = time.time() + self.streams[eid].no_body = (event.stream_ended is not None) + if event.priority_updated is not None: + self.streams[eid].priority_exclusive = event.priority_updated.exclusive + self.streams[eid].priority_depends_on = event.priority_updated.depends_on + self.streams[eid].priority_weight = event.priority_updated.weight + self.streams[eid].handled_priority_event = event.priority_updated + self.streams[eid].start() + return True + + def _handle_response_received(self, eid, event): + headers = netlib.http.Headers([[k, v] for k, v in event.headers]) + self.streams[eid].queued_data_length = 0 + self.streams[eid].timestamp_start = time.time() + self.streams[eid].response_headers = headers + self.streams[eid].response_arrived.set() + return True + + def _handle_data_received(self, eid, event, source_conn): + if self.config.body_size_limit and self.streams[eid].queued_data_length > self.config.body_size_limit: + self.streams[eid].zombie = time.time() + source_conn.h2.safe_reset_stream(event.stream_id, h2.errors.REFUSED_STREAM) + self.log("HTTP body too large. Limit is {}.".format(self.config.body_size_limit), "info") + else: + self.streams[eid].data_queue.put(event.data) + self.streams[eid].queued_data_length += len(event.data) + source_conn.h2.safe_increment_flow_control(event.stream_id, event.flow_controlled_length) + return True + + def _handle_stream_ended(self, eid): + self.streams[eid].timestamp_end = time.time() + self.streams[eid].data_finished.set() + return True + + def _handle_stream_reset(self, eid, event, is_server, other_conn): + self.streams[eid].zombie = time.time() + if eid in self.streams and event.error_code == h2.errors.CANCEL: + if is_server: + other_stream_id = self.streams[eid].client_stream_id + else: + other_stream_id = self.streams[eid].server_stream_id + if other_stream_id is not None: + other_conn.h2.safe_reset_stream(other_stream_id, event.error_code) + return True + + def _handle_remote_settings_changed(self, event, other_conn): + new_settings = dict([(key, cs.new_value) for (key, cs) in six.iteritems(event.changed_settings)]) + other_conn.h2.safe_update_settings(new_settings) + return True + + def _handle_connection_terminated(self, event): + if event.error_code != h2.errors.NO_ERROR: + # Something terrible has happened - kill everything! + self.client_conn.h2.close_connection( + error_code=event.error_code, + last_stream_id=event.last_stream_id, + additional_data=event.additional_data + ) + self.client_conn.send(self.client_conn.h2.data_to_send()) + self._kill_all_streams() + else: + """ + Do not immediately terminate the other connection. + Some streams might be still sending data to the client. 
+ """ + return False + + def _handle_pushed_stream_received(self, event): + # pushed stream ids should be unique and not dependent on race conditions + # only the parent stream id must be looked up first + parent_eid = self.server_to_client_stream_ids[event.parent_stream_id] + with self.client_conn.h2.lock: + self.client_conn.h2.push_stream(parent_eid, event.pushed_stream_id, event.headers) + self.client_conn.send(self.client_conn.h2.data_to_send()) + + headers = netlib.http.Headers([[k, v] for k, v in event.headers]) + self.streams[event.pushed_stream_id] = Http2SingleStreamLayer(self, event.pushed_stream_id, headers) + self.streams[event.pushed_stream_id].timestamp_start = time.time() + self.streams[event.pushed_stream_id].pushed = True + self.streams[event.pushed_stream_id].parent_stream_id = parent_eid + self.streams[event.pushed_stream_id].timestamp_end = time.time() + self.streams[event.pushed_stream_id].request_data_finished.set() + self.streams[event.pushed_stream_id].start() + return True + def _handle_priority_updated(self, eid, event): + if eid in self.streams and self.streams[eid].handled_priority_event is event: + # this event was already handled during stream creation + # HeadersFrame + Priority information as RequestReceived + return True + + with self.server_conn.h2.lock: + mapped_stream_id = event.stream_id + if mapped_stream_id in self.streams and self.streams[mapped_stream_id].server_stream_id: + # if the stream is already up and running and was sent to the server + # use the mapped server stream id to update priority information + mapped_stream_id = self.streams[mapped_stream_id].server_stream_id + + if eid in self.streams: + self.streams[eid].priority_exclusive = event.exclusive + self.streams[eid].priority_depends_on = event.depends_on + self.streams[eid].priority_weight = event.weight + + self.server_conn.h2.prioritize( + mapped_stream_id, + weight=event.weight, + depends_on=self._map_depends_on_stream_id(mapped_stream_id, event.depends_on), + exclusive=event.exclusive + ) + self.server_conn.send(self.server_conn.h2.data_to_send()) return True + def _map_depends_on_stream_id(self, stream_id, depends_on): + mapped_depends_on = depends_on + if mapped_depends_on in self.streams and self.streams[mapped_depends_on].server_stream_id: + # if the depends-on-stream is already up and running and was sent to the server + # use the mapped server stream id to update priority information + mapped_depends_on = self.streams[mapped_depends_on].server_stream_id + if stream_id == mapped_depends_on: + # looks like one of the streams wasn't opened yet + # prevent self-dependent streams which result in ProtocolError + mapped_depends_on += 2 + return mapped_depends_on + def _cleanup_streams(self): death_time = time.time() - 10 - for stream_id in self.streams.keys(): - zombie = self.streams[stream_id].zombie - if zombie and zombie <= death_time: - self.streams.pop(stream_id, None) + + zombie_streams = [(stream_id, stream) for stream_id, stream in list(self.streams.items()) if stream.zombie] + outdated_streams = [stream_id for stream_id, stream in zombie_streams if stream.zombie <= death_time] + + for stream_id in outdated_streams: # pragma: no cover + self.streams.pop(stream_id, None) def _kill_all_streams(self): for stream in self.streams.values(): @@ -267,8 +339,8 @@ class Http2Layer(base.Layer): self._kill_all_streams() return - self._cleanup_streams() - except Exception as e: + self._cleanup_streams() + except Exception as e: # pragma: no cover self.log(repr(e), "info") 
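# Worked sketch of the depends-on mapping above (not part of the changeset).
# "server_ids" is a hypothetical {client_stream_id: server_stream_id or None}
# dict standing in for self.streams[...].server_stream_id.
def map_depends_on(stream_id, depends_on, server_ids):
    mapped = server_ids.get(depends_on) or depends_on
    if stream_id == mapped:
        # The dependency has not been opened towards the server yet; bump it so
        # the stream never depends on itself, which h2 rejects as a ProtocolError.
        mapped += 2
    return mapped

assert map_depends_on(7, 3, {3: 9}) == 9   # dependency already proxied as stream 9
assert map_depends_on(5, 1, {1: 5}) == 7   # would be self-dependent, bumped by 2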
self.log(traceback.format_exc(), "debug") self._kill_all_streams() @@ -296,6 +368,22 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) self.response_queued_data_length = 0 self.response_data_finished = threading.Event() + self.no_body = False + + self.priority_exclusive = None + self.priority_depends_on = None + self.priority_weight = None + self.handled_priority_event = None + + def check_close_connection(self, flow): + # This layer only handles a single stream. + # RFC 7540 8.1: An HTTP request/response exchange fully consumes a single stream. + return True + + def set_server(self, *args, **kwargs): # pragma: no cover + # do not mess with the server connection - all streams share it. + pass + @property def data_queue(self): if self.response_arrived.is_set(): @@ -330,39 +418,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) if self.zombie: # pragma: no cover raise exceptions.Http2ProtocolException("Zombie Stream") - authority = self.request_headers.get(':authority', '') - method = self.request_headers.get(':method', 'GET') - scheme = self.request_headers.get(':scheme', 'https') - path = self.request_headers.get(':path', '/') - self.request_headers.clear(":method") - self.request_headers.clear(":scheme") - self.request_headers.clear(":path") - host = None - port = None - - if path == '*' or path.startswith("/"): - first_line_format = "relative" - elif method == 'CONNECT': # pragma: no cover - raise NotImplementedError("CONNECT over HTTP/2 is not implemented.") - else: # pragma: no cover - first_line_format = "absolute" - # FIXME: verify if path or :host contains what we need - scheme, host, port, _ = netlib.http.url.parse(path) - - if authority: - host, _, port = authority.partition(':') - - if not host: - host = 'localhost' - if not port: - port = 443 if scheme == 'https' else 80 - port = int(port) - data = [] while self.request_data_queue.qsize() > 0: data.append(self.request_data_queue.get()) data = b"".join(data) + first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_headers) + return models.HTTPRequest( first_line_format, method, @@ -412,25 +474,39 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) headers.insert(0, ":path", message.path) headers.insert(0, ":method", message.method) headers.insert(0, ":scheme", message.scheme) - self.server_stream_id = self.server_conn.h2.get_next_available_stream_id() - self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id + + priority_exclusive = None + priority_depends_on = None + priority_weight = None + if self.handled_priority_event: + # only send priority information if they actually came with the original HeadersFrame + # and not if they got updated before/after with a PriorityFrame + priority_exclusive = self.priority_exclusive + priority_depends_on = self._map_depends_on_stream_id(self.server_stream_id, self.priority_depends_on) + priority_weight = self.priority_weight try: self.server_conn.h2.safe_send_headers( self.is_zombie, self.server_stream_id, headers, + end_stream=self.no_body, + priority_exclusive=priority_exclusive, + priority_depends_on=priority_depends_on, + priority_weight=priority_weight, ) - except Exception as e: + except Exception as e: # pragma: no cover raise e finally: self.server_conn.h2.lock.release() - self.server_conn.h2.safe_send_body( - self.is_zombie, - self.server_stream_id, - message.body - ) + if not self.no_body: + self.server_conn.h2.safe_send_body( 
+ self.is_zombie, + self.server_stream_id, + [message.body] + ) + if self.zombie: # pragma: no cover raise exceptions.Http2ProtocolException("Zombie Stream") @@ -442,12 +518,12 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) status_code = int(self.response_headers.get(':status', 502)) headers = self.response_headers.copy() - headers.clear(":status") + headers.pop(":status", None) return models.HTTPResponse( http_version=b"HTTP/2.0", status_code=status_code, - reason='', + reason=b'', headers=headers, content=None, timestamp_start=self.timestamp_start, @@ -458,7 +534,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) while True: try: yield self.response_data_queue.get(timeout=1) - except queue.Empty: + except queue.Empty: # pragma: no cover pass if self.response_data_finished.is_set(): if self.zombie: # pragma: no cover @@ -472,6 +548,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) def send_response_headers(self, response): headers = response.headers.copy() headers.insert(0, ":status", str(response.status_code)) + for forbidden_header in h2.utilities.CONNECTION_HEADERS: + if forbidden_header in headers: + del headers[forbidden_header] with self.client_conn.h2.lock: self.client_conn.h2.safe_send_headers( self.is_zombie, @@ -490,24 +569,12 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) if self.zombie: # pragma: no cover raise exceptions.Http2ProtocolException("Zombie Stream") - def check_close_connection(self, flow): - # This layer only handles a single stream. - # RFC 7540 8.1: An HTTP request/response exchange fully consumes a single stream. - return True - - def set_server(self, *args, **kwargs): # pragma: no cover - # do not mess with the server connection - all streams share it. 
- pass - def run(self): - self() - - def __call__(self): layer = http.HttpLayer(self, self.mode) try: layer() - except exceptions.ProtocolException as e: + except exceptions.ProtocolException as e: # pragma: no cover self.log(repr(e), "info") self.log(traceback.format_exc(), "debug") diff --git a/mitmproxy/protocol/http_replay.py b/mitmproxy/protocol/http_replay.py index e804eba9..986de845 100644 --- a/mitmproxy/protocol/http_replay.py +++ b/mitmproxy/protocol/http_replay.py @@ -49,7 +49,7 @@ class RequestReplayThread(basethread.BaseThread): server = models.ServerConnection(server_address, (self.config.host, 0)) server.connect() if r.scheme == "https": - connect_request = models.make_connect_request((r.host, r.port)) + connect_request = models.make_connect_request((r.data.host, r.port)) server.wfile.write(http1.assemble_request(connect_request)) server.wfile.flush() resp = http1.read_response( diff --git a/mitmproxy/protocol/tls.py b/mitmproxy/protocol/tls.py index 9f883b2b..8ef34493 100644 --- a/mitmproxy/protocol/tls.py +++ b/mitmproxy/protocol/tls.py @@ -10,6 +10,7 @@ import netlib.exceptions from mitmproxy import exceptions from mitmproxy.contrib.tls import _constructs from mitmproxy.protocol import base +from netlib import utils # taken from https://testssl.sh/openssl-rfc.mappping.html @@ -274,10 +275,11 @@ class TlsClientHello(object): is_valid_sni_extension = ( extension.type == 0x00 and len(extension.server_names) == 1 and - extension.server_names[0].type == 0 + extension.server_names[0].type == 0 and + utils.is_valid_host(extension.server_names[0].name) ) if is_valid_sni_extension: - return extension.server_names[0].name + return extension.server_names[0].name.decode("idna") @property def alpn_protocols(self): @@ -403,13 +405,14 @@ class TlsLayer(base.Layer): self._establish_tls_with_server() def set_server_tls(self, server_tls, sni=None): + # type: (bool, Union[six.text_type, None, False]) -> None """ Set the TLS settings for the next server connection that will be established. This function will not alter an existing connection. Args: server_tls: Shall we establish TLS with the server? - sni: ``bytes`` for a custom SNI value, + sni: ``str`` for a custom SNI value, ``None`` for the client SNI value, ``False`` if no SNI value should be sent. """ @@ -602,9 +605,9 @@ class TlsLayer(base.Layer): host = upstream_cert.cn.decode("utf8").encode("idna") # Also add SNI values. if self._client_hello.sni: - sans.add(self._client_hello.sni) + sans.add(self._client_hello.sni.encode("idna")) if self._custom_server_sni: - sans.add(self._custom_server_sni) + sans.add(self._custom_server_sni.encode("idna")) # RFC 2818: If a subjectAltName extension of type dNSName is present, that MUST be used as the identity. # In other words, the Common Name is irrelevant then. diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py index 57183c7e..4d6509d4 100644 --- a/mitmproxy/proxy/root_context.py +++ b/mitmproxy/proxy/root_context.py @@ -100,7 +100,7 @@ class RootContext(object): is_ascii = ( len(d) == 3 and # expect A-Za-z - all(65 <= x <= 90 and 97 <= x <= 122 for x in six.iterbytes(d)) + all(65 <= x <= 90 or 97 <= x <= 122 for x in six.iterbytes(d)) ) if self.config.rawtcp and not is_ascii: return protocol.RawTCPLayer(top_layer) diff --git a/mitmproxy/script/__init__.py b/mitmproxy/script/__init__.py index d6bff4c7..e75f282a 100644 --- a/mitmproxy/script/__init__.py +++ b/mitmproxy/script/__init__.py @@ -1,13 +1,5 @@ -from . 
import reloader from .concurrent import concurrent -from .script import Script -from .script_context import ScriptContext -from ..exceptions import ScriptException __all__ = [ - "Script", - "ScriptContext", "concurrent", - "ScriptException", - "reloader" ] diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py index 56d39d0b..0cc0514e 100644 --- a/mitmproxy/script/concurrent.py +++ b/mitmproxy/script/concurrent.py @@ -13,14 +13,14 @@ class ScriptThread(basethread.BaseThread): def concurrent(fn): - if fn.__name__ not in controller.Events: + if fn.__name__ not in controller.Events - set(["start", "configure", "tick"]): raise NotImplementedError( "Concurrent decorator not supported for '%s' method." % fn.__name__ ) - def _concurrent(ctx, obj): + def _concurrent(obj): def run(): - fn(ctx, obj) + fn(obj) if not obj.reply.acked: obj.reply.ack() obj.reply.take() diff --git a/mitmproxy/script/reloader.py b/mitmproxy/script/reloader.py deleted file mode 100644 index 50401034..00000000 --- a/mitmproxy/script/reloader.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import absolute_import, print_function, division - -import os - -from watchdog.events import RegexMatchingEventHandler - -from watchdog.observers.polling import PollingObserver as Observer -# We occasionally have watchdog errors on Windows, Linux and Mac when using the native observers. -# After reading through the watchdog source code and issue tracker, -# we may want to replace this with a very simple implementation of our own. - -_observers = {} - - -def watch(script, callback): - if script in _observers: - raise RuntimeError("Script already observed") - script_dir = os.path.dirname(os.path.abspath(script.filename)) - script_name = os.path.basename(script.filename) - event_handler = _ScriptModificationHandler(callback, filename=script_name) - observer = Observer() - observer.schedule(event_handler, script_dir) - observer.start() - _observers[script] = observer - - -def unwatch(script): - observer = _observers.pop(script, None) - if observer: - observer.stop() - observer.join() - - -class _ScriptModificationHandler(RegexMatchingEventHandler): - - def __init__(self, callback, filename='.*'): - - super(_ScriptModificationHandler, self).__init__( - ignore_directories=True, - regexes=['.*' + filename] - ) - self.callback = callback - - def on_modified(self, event): - self.callback() - -__all__ = ["watch", "unwatch"] diff --git a/mitmproxy/script/script.py b/mitmproxy/script/script.py deleted file mode 100644 index 9ff79f52..00000000 --- a/mitmproxy/script/script.py +++ /dev/null @@ -1,146 +0,0 @@ -""" -The script object representing mitmproxy inline scripts. -Script objects know nothing about mitmproxy or mitmproxy's API - this knowledge is provided -by the mitmproxy-specific ScriptContext. -""" -# Do not import __future__ here, this would apply transitively to the inline scripts. -from __future__ import absolute_import, print_function, division - -import inspect -import os -import shlex -import sys -import contextlib -import warnings - -import six - -from mitmproxy import exceptions - - -@contextlib.contextmanager -def setargs(args): - oldargs = sys.argv - sys.argv = args - try: - yield - finally: - sys.argv = oldargs - - -class Script(object): - - """ - Script object representing an inline script. 
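# Usage sketch for the slimmed-down script API above (not part of the changeset).
# Hooks now take a single object instead of (ctx, obj), and @concurrent runs the
# hook on its own ScriptThread; "start", "configure" and "tick" cannot be made
# concurrent. The sleep stands in for slow work such as an external lookup.
import time

from mitmproxy.script import concurrent


@concurrent
def request(flow):
    time.sleep(1)                                 # the proxy keeps serving meanwhile
    flow.request.headers["x-slow-script"] = "done"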
- """ - - def __init__(self, command, context): - self.command = command - self.args = self.parse_command(command) - self.ctx = context - self.ns = None - - def __enter__(self): - self.load() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if exc_val: - return False # reraise the exception - self.unload() - - @property - def filename(self): - return self.args[0] - - @staticmethod - def parse_command(command): - if not command or not command.strip(): - raise exceptions.ScriptException("Empty script command.") - # Windows: escape all backslashes in the path. - if os.name == "nt": # pragma: no cover - backslashes = shlex.split(command, posix=False)[0].count("\\") - command = command.replace("\\", "\\\\", backslashes) - args = shlex.split(command) # pragma: no cover - args[0] = os.path.expanduser(args[0]) - if not os.path.exists(args[0]): - raise exceptions.ScriptException( - ("Script file not found: %s.\r\n" - "If your script path contains spaces, " - "make sure to wrap it in additional quotes, e.g. -s \"'./foo bar/baz.py' --args\".") % - args[0]) - elif os.path.isdir(args[0]): - raise exceptions.ScriptException("Not a file: %s" % args[0]) - return args - - def load(self): - """ - Loads an inline script. - - Returns: - The return value of self.run("start", ...) - - Raises: - ScriptException on failure - """ - if self.ns is not None: - raise exceptions.ScriptException("Script is already loaded") - script_dir = os.path.dirname(os.path.abspath(self.args[0])) - self.ns = {'__file__': os.path.abspath(self.args[0])} - sys.path.append(script_dir) - sys.path.append(os.path.join(script_dir, "..")) - try: - with open(self.filename) as f: - code = compile(f.read(), self.filename, 'exec') - exec(code, self.ns, self.ns) - except Exception: - six.reraise( - exceptions.ScriptException, - exceptions.ScriptException.from_exception_context(), - sys.exc_info()[2] - ) - finally: - sys.path.pop() - sys.path.pop() - - start_fn = self.ns.get("start") - if start_fn and len(inspect.getargspec(start_fn).args) == 2: - warnings.warn( - "The 'args' argument of the start() script hook is deprecated. " - "Please use sys.argv instead." - ) - return self.run("start", self.args) - return self.run("start") - - def unload(self): - try: - return self.run("done") - finally: - self.ns = None - - def run(self, name, *args, **kwargs): - """ - Runs an inline script hook. - - Returns: - The return value of the method. - None, if the script does not provide the method. - - Raises: - ScriptException if there was an exception. - """ - if self.ns is None: - raise exceptions.ScriptException("Script not loaded.") - f = self.ns.get(name) - if f: - try: - with setargs(self.args): - return f(self.ctx, *args, **kwargs) - except Exception: - six.reraise( - exceptions.ScriptException, - exceptions.ScriptException.from_exception_context(), - sys.exc_info()[2] - ) - else: - return None diff --git a/mitmproxy/script/script_context.py b/mitmproxy/script/script_context.py deleted file mode 100644 index 44e2736b..00000000 --- a/mitmproxy/script/script_context.py +++ /dev/null @@ -1,61 +0,0 @@ -""" -The mitmproxy script context provides an API to inline scripts. -""" -from __future__ import absolute_import, print_function, division - -from mitmproxy import contentviews - - -class ScriptContext(object): - - """ - The script context should be used to interact with the global mitmproxy state from within a - script. 
- """ - - def __init__(self, master): - self._master = master - - def log(self, message, level="info"): - """ - Logs an event. - - By default, only events with level "error" get displayed. This can be controlled with the "-v" switch. - How log messages are handled depends on the front-end. mitmdump will print them to stdout, - mitmproxy sends output to the eventlog for display ("e" keyboard shortcut). - """ - self._master.add_event(message, level) - - def kill_flow(self, f): - """ - Kills a flow immediately. No further data will be sent to the client or the server. - """ - f.kill(self._master) - - def duplicate_flow(self, f): - """ - Returns a duplicate of the specified flow. The flow is also - injected into the current state, and is ready for editing, replay, - etc. - """ - self._master.pause_scripts = True - f = self._master.duplicate_flow(f) - self._master.pause_scripts = False - return f - - def replay_request(self, f): - """ - Replay the request on the current flow. The response will be added - to the flow object. - """ - return self._master.replay_request(f, block=True, run_scripthooks=False) - - @property - def app_registry(self): - return self._master.apps - - def add_contentview(self, view_obj): - contentviews.add(view_obj) - - def remove_contentview(self, view_obj): - contentviews.remove(view_obj) diff --git a/mitmproxy/stateobject.py b/mitmproxy/stateobject.py index 8db6cda3..5e4ae6e3 100644 --- a/mitmproxy/stateobject.py +++ b/mitmproxy/stateobject.py @@ -52,19 +52,20 @@ class StateObject(netlib.basetypes.Serializable): """ state = state.copy() for attr, cls in six.iteritems(self._stateobject_attributes): - if state.get(attr) is None: - setattr(self, attr, state.pop(attr)) + val = state.pop(attr) + if val is None: + setattr(self, attr, val) else: curr = getattr(self, attr) if hasattr(curr, "set_state"): - curr.set_state(state.pop(attr)) + curr.set_state(val) elif hasattr(cls, "from_state"): - obj = cls.from_state(state.pop(attr)) + obj = cls.from_state(val) setattr(self, attr, obj) elif _is_list(cls): cls = cls.__parameters__[0] if cls.__parameters__ else cls.__args__[0] - setattr(self, attr, [cls.from_state(x) for x in state.pop(attr)]) + setattr(self, attr, [cls.from_state(x) for x in val]) else: # primitive types such as int, str, ... - setattr(self, attr, cls(state.pop(attr))) + setattr(self, attr, cls(val)) if state: raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state)) diff --git a/mitmproxy/utils.py b/mitmproxy/utils.py index 15785c72..1c75dd83 100644 --- a/mitmproxy/utils.py +++ b/mitmproxy/utils.py @@ -36,3 +36,7 @@ class LRUCache: d = self.cacheList.pop() self.cache.pop(d) return ret + + +def log_tier(level): + return dict(error=0, warn=1, info=2, debug=3).get(level) diff --git a/mitmproxy/web/app.py b/mitmproxy/web/app.py index a2798472..8c080e98 100644 --- a/mitmproxy/web/app.py +++ b/mitmproxy/web/app.py @@ -12,34 +12,57 @@ from io import BytesIO from mitmproxy.flow import FlowWriter, FlowReader from mitmproxy import filt +from mitmproxy import models from netlib import version -def _strip_content(flow_state): +def convert_flow_to_json_dict(flow): + # type: (models.Flow) -> dict """ Remove flow message content and cert to save transmission space. Args: - flow_state: The original flow state. Will be left unmodified + flow: The original flow. 
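# Usage sketch for the new log_tier() helper above (not part of the changeset).
# It ranks log levels numerically so verbosity checks become integer comparisons;
# the "verbosity" variable below is an illustrative assumption.
from mitmproxy import utils

assert utils.log_tier("error") == 0
assert utils.log_tier("debug") == 3
assert utils.log_tier("bogus") is None        # unknown levels map to None

verbosity = utils.log_tier("info")
if utils.log_tier("warn") <= verbosity:
    pass                                       # a "warn" message would be shown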
""" - for attr in ("request", "response"): - if attr in flow_state: - message = flow_state[attr] - if message is None: - continue - if message["content"]: - message["contentLength"] = len(message["content"]) - else: - message["contentLength"] = None - del message["content"] - - if "backup" in flow_state: - del flow_state["backup"] - flow_state["modified"] = True - - flow_state.get("server_conn", {}).pop("cert", None) - - return flow_state + f = { + "id": flow.id, + "intercepted": flow.intercepted, + "client_conn": flow.client_conn.get_state(), + "server_conn": flow.server_conn.get_state(), + "type": flow.type + } + if flow.error: + f["error"] = flow.error.get_state() + + if isinstance(flow, models.HTTPFlow): + if flow.request: + f["request"] = { + "method": flow.request.method, + "scheme": flow.request.scheme, + "host": flow.request.host, + "port": flow.request.port, + "path": flow.request.path, + "http_version": flow.request.http_version, + "headers": tuple(flow.request.headers.items(True)), + "contentLength": len(flow.request.content) if flow.request.content is not None else None, + "timestamp_start": flow.request.timestamp_start, + "timestamp_end": flow.request.timestamp_end, + "is_replay": flow.request.is_replay, + } + if flow.response: + f["response"] = { + "http_version": flow.response.http_version, + "status_code": flow.response.status_code, + "reason": flow.response.reason, + "headers": tuple(flow.response.headers.items(True)), + "contentLength": len(flow.response.content) if flow.response.content is not None else None, + "timestamp_start": flow.response.timestamp_start, + "timestamp_end": flow.response.timestamp_end, + "is_replay": flow.response.is_replay, + } + f.get("server_conn", {}).pop("cert", None) + + return f class APIError(tornado.web.HTTPError): @@ -158,7 +181,7 @@ class Flows(RequestHandler): def get(self): self.write(dict( - data=[_strip_content(f.get_state()) for f in self.state.flows] + data=[convert_flow_to_json_dict(f) for f in self.state.flows] )) @@ -272,7 +295,7 @@ class FlowContent(RequestHandler): def get(self, flow_id, message): message = getattr(self.flow, message) - if not message.content: + if not message.raw_content: raise APIError(400, "No content.") content_encoding = message.headers.get("Content-Encoding", None) @@ -295,7 +318,7 @@ class FlowContent(RequestHandler): self.set_header("Content-Type", "application/text") self.set_header("X-Content-Type-Options", "nosniff") self.set_header("X-Frame-Options", "DENY") - self.write(message.content) + self.write(message.raw_content) class Events(RequestHandler): @@ -321,7 +344,7 @@ class Settings(RequestHandler): http2=self.master.server.config.http2, anticache=self.master.options.anticache, anticomp=self.master.options.anticomp, - stickyauth=self.master.stickyauth_txt, + stickyauth=self.master.options.stickyauth, stickycookie=self.master.stickycookie_txt, stream= self.master.stream_large_bodies.max_size if self.master.stream_large_bodies else False ) @@ -355,7 +378,7 @@ class Settings(RequestHandler): self.master.set_stickycookie(v) update[k] = v elif k == "stickyauth": - self.master.set_stickyauth(v) + self.master.options.stickyauth = v update[k] = v elif k == "stream": self.master.set_stream_large_bodies(v) diff --git a/mitmproxy/web/master.py b/mitmproxy/web/master.py index d034a24b..83f18539 100644 --- a/mitmproxy/web/master.py +++ b/mitmproxy/web/master.py @@ -6,6 +6,9 @@ import collections import tornado.httpserver import tornado.ioloop +from typing import Optional # noqa + +from mitmproxy import 
builtins from mitmproxy import controller from mitmproxy import exceptions from mitmproxy import flow @@ -27,7 +30,7 @@ class WebFlowView(flow.FlowView): app.ClientConnection.broadcast( type="UPDATE_FLOWS", cmd="add", - data=app._strip_content(f.get_state()) + data=app.convert_flow_to_json_dict(f) ) def _update(self, f): @@ -35,7 +38,7 @@ class WebFlowView(flow.FlowView): app.ClientConnection.broadcast( type="UPDATE_FLOWS", cmd="update", - data=app._strip_content(f.get_state()) + data=app.convert_flow_to_json_dict(f) ) def _remove(self, f): @@ -64,7 +67,7 @@ class WebState(flow.State): self._last_event_id = 0 self.events = collections.deque(maxlen=1000) - def add_event(self, e, level): + def add_log(self, e, level): self._last_event_id += 1 entry = { "id": self._last_event_id, @@ -88,50 +91,28 @@ class WebState(flow.State): ) -class Options(object): - attributes = [ - "app", - "app_domain", - "app_ip", - "anticache", - "anticomp", - "client_replay", - "eventlog", - "keepserving", - "kill", - "intercept", - "no_server", - "outfile", - "refresh_server_playback", - "rfile", - "scripts", - "showhost", - "replacements", - "rheaders", - "setheaders", - "server_replay", - "stickycookie", - "stickyauth", - "stream_large_bodies", - "verbosity", - "wfile", - "nopop", - - "wdebug", - "wport", - "wiface", - "wauthenticator", - "wsingleuser", - "whtpasswd", - ] - - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - for i in self.attributes: - if not hasattr(self, i): - setattr(self, i, None) - +class Options(flow.options.Options): + def __init__( + self, + intercept=False, # type: bool + wdebug=bool, # type: bool + wport=8081, # type: int + wiface="127.0.0.1", # type: str + wauthenticator=None, # type: Optional[authentication.PassMan] + wsingleuser=None, # type: Optional[str] + whtpasswd=None, # type: Optional[str] + **kwargs + ): + self.wdebug = wdebug + self.wport = wport + self.wiface = wiface + self.wauthenticator = wauthenticator + self.wsingleuser = wsingleuser + self.whtpasswd = whtpasswd + self.intercept = intercept + super(Options, self).__init__(**kwargs) + + # TODO: This doesn't belong here. 
def process_web_options(self, parser): if self.wsingleuser or self.whtpasswd: if self.wsingleuser: @@ -153,14 +134,18 @@ class Options(object): class WebMaster(flow.FlowMaster): def __init__(self, server, options): - self.options = options - super(WebMaster, self).__init__(server, WebState()) - self.app = app.Application(self, self.options.wdebug, self.options.wauthenticator) + super(WebMaster, self).__init__(options, server, WebState()) + self.addons.add(*builtins.default_addons()) + self.app = app.Application( + self, self.options.wdebug, self.options.wauthenticator + ) + # This line is just for type hinting + self.options = self.options # type: Options if options.rfile: try: self.load_flows_file(options.rfile) except exceptions.FlowReadException as v: - self.add_event( + self.add_log( "Could not read flow file: %s" % v, "error" ) @@ -215,6 +200,6 @@ class WebMaster(flow.FlowMaster): super(WebMaster, self).error(f) return self._process_flow(f) - def add_event(self, e, level="info"): - super(WebMaster, self).add_event(e, level) - return self.state.add_event(e, level) + def add_log(self, e, level="info"): + super(WebMaster, self).add_log(e, level) + return self.state.add_log(e, level) diff --git a/netlib/debug.py b/netlib/debug.py index a395afcb..29c7f655 100644 --- a/netlib/debug.py +++ b/netlib/debug.py @@ -7,8 +7,6 @@ import signal import platform import traceback -import psutil - from netlib import version from OpenSSL import SSL @@ -19,7 +17,7 @@ def sysinfo(): "Mitmproxy version: %s" % version.VERSION, "Python version: %s" % platform.python_version(), "Platform: %s" % platform.platform(), - "SSL version: %s" % SSL.SSLeay_version(SSL.SSLEAY_VERSION), + "SSL version: %s" % SSL.SSLeay_version(SSL.SSLEAY_VERSION).decode(), ] d = platform.linux_distribution() t = "Linux distro: %s %s %s" % d @@ -40,15 +38,32 @@ def sysinfo(): def dump_info(sig, frm, file=sys.stdout): # pragma: no cover - p = psutil.Process() - print("****************************************************", file=file) print("Summary", file=file) print("=======", file=file) - print("num threads: ", p.num_threads(), file=file) - if hasattr(p, "num_fds"): - print("num fds: ", p.num_fds(), file=file) - print("memory: ", p.memory_info(), file=file) + + try: + import psutil + except: + print("(psutil not installed, skipping some debug info)", file=file) + else: + p = psutil.Process() + print("num threads: ", p.num_threads(), file=file) + if hasattr(p, "num_fds"): + print("num fds: ", p.num_fds(), file=file) + print("memory: ", p.memory_info(), file=file) + + print(file=file) + print("Files", file=file) + print("=====", file=file) + for i in p.open_files(): + print(i, file=file) + + print(file=file) + print("Connections", file=file) + print("===========", file=file) + for i in p.connections(): + print(i, file=file) print(file=file) print("Threads", file=file) @@ -63,18 +78,6 @@ def dump_info(sig, frm, file=sys.stdout): # pragma: no cover for i in bthreads: print(i._threadinfo(), file=file) - print(file=file) - print("Files", file=file) - print("=====", file=file) - for i in p.open_files(): - print(i, file=file) - - print(file=file) - print("Connections", file=file) - print("===========", file=file) - for i in p.connections(): - print(i, file=file) - print("****************************************************", file=file) diff --git a/netlib/encoding.py b/netlib/encoding.py index 98502451..8b67b543 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -1,39 +1,62 @@ """ - Utility functions for decoding response bodies. 
+Utility functions for decoding response bodies. """ from __future__ import absolute_import + +import codecs from io import BytesIO import gzip import zlib +from typing import Union # noqa + -ENCODINGS = {"identity", "gzip", "deflate"} +def decode(obj, encoding, errors='strict'): + # type: (Union[str, bytes], str) -> Union[str, bytes] + """ + Decode the given input object + Returns: + The decoded value -def decode(e, content): - if not isinstance(content, bytes): - return None - encoding_map = { - "identity": identity, - "gzip": decode_gzip, - "deflate": decode_deflate, - } - if e not in encoding_map: - return None - return encoding_map[e](content) + Raises: + ValueError, if decoding fails. + """ + try: + try: + return custom_decode[encoding](obj) + except KeyError: + return codecs.decode(obj, encoding, errors) + except Exception as e: + raise ValueError("{} when decoding {} with {}".format( + type(e).__name__, + repr(obj)[:10], + repr(encoding), + )) + + +def encode(obj, encoding, errors='strict'): + # type: (Union[str, bytes], str) -> Union[str, bytes] + """ + Encode the given input object + Returns: + The encoded value -def encode(e, content): - if not isinstance(content, bytes): - return None - encoding_map = { - "identity": identity, - "gzip": encode_gzip, - "deflate": encode_deflate, - } - if e not in encoding_map: - return None - return encoding_map[e](content) + Raises: + ValueError, if encoding fails. + """ + try: + try: + return custom_encode[encoding](obj) + except KeyError: + return codecs.encode(obj, encoding, errors) + except Exception as e: + raise ValueError("{} when encoding {} with {}".format( + type(e).__name__, + repr(obj)[:10], + repr(encoding), + )) def identity(content): @@ -46,10 +69,7 @@ def identity(content): def decode_gzip(content): gfile = gzip.GzipFile(fileobj=BytesIO(content)) - try: - return gfile.read() - except (IOError, EOFError): - return None + return gfile.read() def encode_gzip(content): @@ -70,12 +90,9 @@ def decode_deflate(content): http://bugs.python.org/issue5784 """ try: - try: - return zlib.decompress(content) - except zlib.error: - return zlib.decompress(content, -15) + return zlib.decompress(content) except zlib.error: - return None + return zlib.decompress(content, -15) def encode_deflate(content): @@ -84,4 +101,16 @@ def encode_deflate(content): """ return zlib.compress(content) -__all__ = ["ENCODINGS", "encode", "decode"] + +custom_decode = { + "identity": identity, + "gzip": decode_gzip, + "deflate": decode_deflate, +} +custom_encode = { + "identity": identity, + "gzip": encode_gzip, + "deflate": encode_deflate, +} + +__all__ = ["encode", "decode"] diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index 768a85df..dd0af99c 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -1,7 +1,8 @@ import collections +import email.utils import re +import time -import email.utils from netlib import multidict """ @@ -260,3 +261,29 @@ def refresh_set_cookie_header(c, delta): if not ret: raise ValueError("Invalid Cookie") return ret + + +def is_expired(cookie_attrs): + """ + Determines whether a cookie has expired. 
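# Usage sketch for the reworked netlib.encoding API above (not part of the
# changeset): encode()/decode() now take (data, encoding), raise ValueError on
# failure instead of returning None, and fall through to the codecs module for
# anything that is not "identity"/"gzip"/"deflate".
from netlib import encoding

packed = encoding.encode(b"hello world", "gzip")
assert encoding.decode(packed, "gzip") == b"hello world"
assert encoding.encode(b"abc", "identity") == b"abc"

assert encoding.encode(u"z\u00fcrich", "utf8") == b"z\xc3\xbcrich"   # via codecs
assert encoding.decode(b"z\xc3\xbcrich", "utf8") == u"z\u00fcrich"

try:
    encoding.decode(b"not gzip data", "gzip")
except ValueError:
    pass   # invalid payloads surface as ValueError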
+ + Returns: boolean + """ + + # See if 'expires' time is in the past + expires = False + if 'expires' in cookie_attrs: + e = email.utils.parsedate_tz(cookie_attrs["expires"]) + if e: + exp_ts = email.utils.mktime_tz(e) + now_ts = time.time() + expires = exp_ts < now_ts + + # or if Max-Age is 0 + max_age = False + try: + max_age = int(cookie_attrs.get('Max-Age', 1)) == 0 + except ValueError: + pass + + return expires or max_age diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 14888ea9..36e5060c 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function, division import re +import collections import six from netlib import multidict from netlib import strutils @@ -148,6 +149,15 @@ class Headers(multidict.MultiDict): value = _always_bytes(value) super(Headers, self).insert(index, key, value) + def items(self, multi=False): + if multi: + return ( + (_native(k), _native(v)) + for k, v in self.fields + ) + else: + return super(Headers, self).items() + def replace(self, pattern, repl, flags=0): """ Replaces a regular expression pattern with repl in each "name: value" @@ -156,8 +166,10 @@ class Headers(multidict.MultiDict): Returns: The number of replacements made. """ - pattern = _always_bytes(pattern) - repl = _always_bytes(repl) + if isinstance(pattern, six.text_type): + pattern = strutils.escaped_str_to_bytes(pattern) + if isinstance(repl, six.text_type): + repl = strutils.escaped_str_to_bytes(repl) pattern = re.compile(pattern, flags) replacements = 0 @@ -172,8 +184,8 @@ class Headers(multidict.MultiDict): pass else: replacements += n - fields.append([name, value]) - self.fields = fields + fields.append((name, value)) + self.fields = tuple(fields) return replacements @@ -195,10 +207,22 @@ def parse_content_type(c): ts = parts[0].split("/", 1) if len(ts) != 2: return None - d = {} + d = collections.OrderedDict() if len(parts) == 2: for i in parts[1].split(";"): clause = i.split("=", 1) if len(clause) == 2: d[clause[0].strip()] = clause[1].strip() return ts[0].lower(), ts[1].lower(), d + + +def assemble_content_type(type, subtype, parameters): + if not parameters: + return "{}/{}".format(type, subtype) + params = "; ".join( + "{}={}".format(k, v) + for k, v in parameters.items() + ) + return "{}/{}; {}".format( + type, subtype, params + ) diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 511328f1..e74732d2 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -5,7 +5,7 @@ from netlib import exceptions def assemble_request(request): - if request.content is None: + if request.data.content is None: raise exceptions.HttpException("Cannot assemble flow with missing content") head = assemble_request_head(request) body = b"".join(assemble_body(request.data.headers, [request.data.content])) @@ -19,7 +19,7 @@ def assemble_request_head(request): def assemble_response(response): - if response.content is None: + if response.data.content is None: raise exceptions.HttpException("Cannot assemble flow with missing content") head = assemble_response_head(response) body = b"".join(assemble_body(response.data.headers, [response.data.content])) diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index a4c341fd..70fffbd4 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -244,7 +244,7 @@ def _read_request_line(rfile): raise exceptions.HttpReadDisconnect("Client disconnected") try: - method, path, http_version = line.split(b" 
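# Usage sketch for the netlib.http helpers above (not part of the changeset).
# Plain dicts are used for brevity; mitmproxy itself passes its own attribute
# containers.
from netlib.http import cookies
from netlib.http.headers import assemble_content_type, parse_content_type

assert cookies.is_expired({"expires": "Thu, 01 Jan 1970 00:00:00 GMT"})
assert cookies.is_expired({"Max-Age": "0"})
assert not cookies.is_expired({"Max-Age": "3600"})
assert not cookies.is_expired({})

ct = parse_content_type("text/html; charset=UTF-8")
assert ct == ("text", "html", {"charset": "UTF-8"})
assert assemble_content_type(*ct) == "text/html; charset=UTF-8"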
") + method, path, http_version = line.split() if path == b"*" or path.startswith(b"/"): form = "relative" @@ -291,8 +291,7 @@ def _read_response_line(rfile): raise exceptions.HttpReadDisconnect("Server disconnected") try: - - parts = line.split(b" ", 2) + parts = line.split(None, 2) if len(parts) == 2: # handle missing message gracefully parts.append(b"") diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py index 6a979a0d..60064190 100644 --- a/netlib/http/http2/__init__.py +++ b/netlib/http/http2/__init__.py @@ -1,6 +1,8 @@ from __future__ import absolute_import, print_function, division from netlib.http.http2 import framereader +from netlib.http.http2.utils import parse_headers __all__ = [ "framereader", + "parse_headers", ] diff --git a/netlib/http/http2/utils.py b/netlib/http/http2/utils.py new file mode 100644 index 00000000..164bacc8 --- /dev/null +++ b/netlib/http/http2/utils.py @@ -0,0 +1,37 @@ +from netlib.http import url + + +def parse_headers(headers): + authority = headers.get(':authority', '').encode() + method = headers.get(':method', 'GET').encode() + scheme = headers.get(':scheme', 'https').encode() + path = headers.get(':path', '/').encode() + + headers.pop(":method", None) + headers.pop(":scheme", None) + headers.pop(":path", None) + + host = None + port = None + + if path == b'*' or path.startswith(b"/"): + first_line_format = "relative" + elif method == b'CONNECT': # pragma: no cover + raise NotImplementedError("CONNECT over HTTP/2 is not implemented.") + else: # pragma: no cover + first_line_format = "absolute" + # FIXME: verify if path or :host contains what we need + scheme, host, port, _ = url.parse(path) + + if authority: + host, _, port = authority.partition(b':') + + if not host: + host = b'localhost' + + if not port: + port = 443 if scheme == b'https' else 80 + + port = int(port) + + return first_line_format, method, scheme, host, port, path diff --git a/netlib/http/message.py b/netlib/http/message.py index b633b671..34709f0a 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -1,5 +1,6 @@ from __future__ import absolute_import, print_function, division +import re import warnings import six @@ -51,7 +52,23 @@ class MessageData(basetypes.Serializable): return cls(**state) +class CachedDecode(object): + __slots__ = ["encoded", "encoding", "strict", "decoded"] + + def __init__(self, object, encoding, strict, decoded): + self.encoded = object + self.encoding = encoding + self.strict = strict + self.decoded = decoded + +no_cached_decode = CachedDecode(None, None, None, None) + + class Message(basetypes.Serializable): + def __init__(self): + self._content_cache = no_cached_decode # type: CachedDecode + self._text_cache = no_cached_decode # type: CachedDecode + def __eq__(self, other): if isinstance(other, Message): return self.data == other.data @@ -89,19 +106,82 @@ class Message(basetypes.Serializable): self.data.headers = h @property - def content(self): + def raw_content(self): + # type: () -> bytes """ The raw (encoded) HTTP message body - See also: :py:attr:`text` + See also: :py:attr:`content`, :py:class:`text` """ return self.data.content - @content.setter - def content(self, content): + @raw_content.setter + def raw_content(self, content): self.data.content = content - if isinstance(content, bytes): - self.headers["content-length"] = str(len(content)) + + def get_content(self, strict=True): + # type: (bool) -> bytes + """ + The HTTP message body decoded with the content-encoding header (e.g. 
gzip) + + Raises: + ValueError, when the content-encoding is invalid and strict is True. + + See also: :py:class:`raw_content`, :py:attr:`text` + """ + if self.raw_content is None: + return None + ce = self.headers.get("content-encoding") + cached = ( + self._content_cache.encoded == self.raw_content and + (self._content_cache.strict or not strict) and + self._content_cache.encoding == ce + ) + if not cached: + is_strict = True + if ce: + try: + decoded = encoding.decode(self.raw_content, ce) + except ValueError: + if strict: + raise + is_strict = False + decoded = self.raw_content + else: + decoded = self.raw_content + self._content_cache = CachedDecode(self.raw_content, ce, is_strict, decoded) + return self._content_cache.decoded + + def set_content(self, value): + if value is None: + self.raw_content = None + return + if not isinstance(value, bytes): + raise TypeError( + "Message content must be bytes, not {}. " + "Please use .text if you want to assign a str." + .format(type(value).__name__) + ) + ce = self.headers.get("content-encoding") + cached = ( + self._content_cache.decoded == value and + self._content_cache.encoding == ce and + self._content_cache.strict + ) + if not cached: + try: + encoded = encoding.encode(value, ce or "identity") + except ValueError: + # So we have an invalid content-encoding? + # Let's remove it! + del self.headers["content-encoding"] + ce = None + encoded = value + self._content_cache = CachedDecode(encoded, ce, True, value) + self.raw_content = self._content_cache.encoded + self.headers["content-length"] = str(len(self.raw_content)) + + content = property(get_content, set_content) @property def http_version(self): @@ -136,56 +216,108 @@ class Message(basetypes.Serializable): def timestamp_end(self, timestamp_end): self.data.timestamp_end = timestamp_end - @property - def text(self): - """ - The decoded HTTP message body. - Decoded contents are not cached, so accessing this attribute repeatedly is relatively expensive. + def _get_content_type_charset(self): + # type: () -> Optional[str] + ct = headers.parse_content_type(self.headers.get("content-type", "")) + if ct: + return ct[2].get("charset") - .. note:: - This is not implemented yet. + def _guess_encoding(self): + # type: () -> str + enc = self._get_content_type_charset() + if enc: + return enc - See also: :py:attr:`content`, :py:class:`decoded` + if "json" in self.headers.get("content-type", ""): + return "utf8" + else: + # We may also want to check for HTML meta tags here at some point. + return "latin-1" + + def get_text(self, strict=True): + # type: (bool) -> six.text_type """ - # This attribute should be called text, because that's what requests does. - raise NotImplementedError() + The HTTP message body decoded with both content-encoding header (e.g. gzip) + and content-type header charset. - @text.setter - def text(self, text): - raise NotImplementedError() + Raises: + ValueError, when either content-encoding or charset is invalid and strict is True. 
- def decode(self): + See also: :py:attr:`content`, :py:class:`raw_content` + """ + if self.raw_content is None: + return None + enc = self._guess_encoding() + + content = self.get_content(strict) + cached = ( + self._text_cache.encoded == content and + (self._text_cache.strict or not strict) and + self._text_cache.encoding == enc + ) + if not cached: + is_strict = self._content_cache.strict + try: + decoded = encoding.decode(content, enc) + except ValueError: + if strict: + raise + is_strict = False + decoded = self.content.decode("utf8", "replace" if six.PY2 else "surrogateescape") + self._text_cache = CachedDecode(content, enc, is_strict, decoded) + return self._text_cache.decoded + + def set_text(self, text): + if text is None: + self.content = None + return + enc = self._guess_encoding() + + cached = ( + self._text_cache.decoded == text and + self._text_cache.encoding == enc and + self._text_cache.strict + ) + if not cached: + try: + encoded = encoding.encode(text, enc) + except ValueError: + # Fall back to UTF-8 and update the content-type header. + ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {}) + ct[2]["charset"] = "utf-8" + self.headers["content-type"] = headers.assemble_content_type(*ct) + enc = "utf8" + encoded = text.encode(enc, "replace" if six.PY2 else "surrogateescape") + self._text_cache = CachedDecode(encoded, enc, True, text) + self.content = self._text_cache.encoded + + text = property(get_text, set_text) + + def decode(self, strict=True): """ - Decodes body based on the current Content-Encoding header, then - removes the header. If there is no Content-Encoding header, no - action is taken. + Decodes body based on the current Content-Encoding header, then + removes the header. If there is no Content-Encoding header, no + action is taken. - Returns: - True, if decoding succeeded. - False, otherwise. + Raises: + ValueError, when the content-encoding is invalid and strict is True. """ - ce = self.headers.get("content-encoding") - data = encoding.decode(ce, self.content) - if data is None: - return False - self.content = data + self.raw_content = self.get_content(strict) self.headers.pop("content-encoding", None) - return True def encode(self, e): """ - Encodes body with the encoding e, where e is "gzip", "deflate" or "identity". + Encodes body with the encoding e, where e is "gzip", "deflate" or "identity". + Any existing content-encodings are overwritten, + the content is not decoded beforehand. - Returns: - True, if decoding succeeded. - False, otherwise. + Raises: + ValueError, when the specified content-encoding is invalid. """ - data = encoding.encode(e, self.content) - if data is None: - return False - self.content = data self.headers["content-encoding"] = e - return True + self.content = self.raw_content + if "content-encoding" not in self.headers: + raise ValueError("Invalid content encoding {}".format(repr(e))) def replace(self, pattern, repl, flags=0): """ @@ -196,13 +328,15 @@ class Message(basetypes.Serializable): Returns: The number of replacements made. """ - # TODO: Proper distinction between text and bytes. 
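# Usage sketch for the three-level body API above (not part of the changeset):
# .raw_content is the wire body, .content undoes the Content-Encoding, .text also
# applies the content-type charset, and decoded values are cached. The Response
# constructor call is an assumption based on ResponseData's signature.
from netlib.http import Response

resp = Response(b"HTTP/1.1", 200, b"OK", (), b"")
resp.headers["content-encoding"] = "gzip"
resp.headers["content-type"] = "text/plain; charset=utf-8"

resp.text = u"flag: \u2713"                   # charset-encoded, then gzip-compressed
assert resp.content == u"flag: \u2713".encode("utf-8")
assert resp.raw_content != resp.content       # raw_content still holds the gzip bytes
assert resp.text == u"flag: \u2713"

resp.decode()                                  # decode in place, drop the header
assert "content-encoding" not in resp.headers
resp.get_content(strict=False)                 # tolerant access if decoding fails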
+ if isinstance(pattern, six.text_type): + pattern = strutils.escaped_str_to_bytes(pattern) + if isinstance(repl, six.text_type): + repl = strutils.escaped_str_to_bytes(repl) replacements = 0 if self.content: - with decoded(self): - self.content, replacements = strutils.safe_subn( - pattern, repl, self.content, flags=flags - ) + self.content, replacements = re.subn( + pattern, repl, self.content, flags=flags + ) replacements += self.headers.replace(pattern, repl, flags) return replacements @@ -221,29 +355,16 @@ class Message(basetypes.Serializable): class decoded(object): """ - A context manager that decodes a request or response, and then - re-encodes it with the same encoding after execution of the block. - - Example: - - .. code-block:: python - - with decoded(request): - request.content = request.content.replace("foo", "bar") + Deprecated: You can now directly use :py:attr:`content`. + :py:attr:`raw_content` has the encoded content. """ - def __init__(self, message): - self.message = message - ce = message.headers.get("content-encoding") - if ce in encoding.ENCODINGS: - self.ce = ce - else: - self.ce = None + def __init__(self, message): # pragma no cover + warnings.warn("decoded() is deprecated, you can now directly use .content instead. " + ".raw_content has the encoded content.", DeprecationWarning) - def __enter__(self): - if self.ce: - self.message.decode() + def __enter__(self): # pragma no cover + pass - def __exit__(self, type, value, tb): - if self.ce: - self.message.encode(self.ce) + def __exit__(self, type, value, tb): # pragma no cover + pass diff --git a/netlib/http/request.py b/netlib/http/request.py index 01801d42..ecaa9b79 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -5,7 +5,6 @@ import re import six from six.moves import urllib -from netlib import encoding from netlib import multidict from netlib import strutils from netlib.http import multipart @@ -23,8 +22,20 @@ host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$") class RequestData(message.MessageData): def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=(), content=None, timestamp_start=None, timestamp_end=None): + if isinstance(method, six.text_type): + method = method.encode("ascii", "strict") + if isinstance(scheme, six.text_type): + scheme = scheme.encode("ascii", "strict") + if isinstance(host, six.text_type): + host = host.encode("idna", "strict") + if isinstance(path, six.text_type): + path = path.encode("ascii", "strict") + if isinstance(http_version, six.text_type): + http_version = http_version.encode("ascii", "strict") if not isinstance(headers, nheaders.Headers): headers = nheaders.Headers(headers) + if isinstance(content, six.text_type): + raise ValueError("Content must be bytes, not {}".format(type(content).__name__)) self.first_line_format = first_line_format self.method = method @@ -44,6 +55,7 @@ class Request(message.Message): An HTTP request. """ def __init__(self, *args, **kwargs): + super(Request, self).__init__() self.data = RequestData(*args, **kwargs) def __repr__(self): @@ -65,10 +77,14 @@ class Request(message.Message): Returns: The number of replacements made. """ - # TODO: Proper distinction between text and bytes. 
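# A minimal sketch of the stricter RequestData constructor above: text arguments are
# normalized to bytes (ASCII/IDNA), so a Request can be built from native strings.
# Assumes netlib.http re-exports Request; illustrative only, not part of the changeset.
from netlib.http import Request

req = Request(
    "relative", "GET", "http", "example.com", 80, "/",
    "HTTP/1.1",
)
assert req.data.method == b"GET"        # encoded to ASCII bytes by RequestData
assert req.data.host == b"example.com"  # host names are IDNA-encoded
assert req.scheme == "http"             # the properties still return native strings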
+ if isinstance(pattern, six.text_type): + pattern = strutils.escaped_str_to_bytes(pattern) + if isinstance(repl, six.text_type): + repl = strutils.escaped_str_to_bytes(repl) + c = super(Request, self).replace(pattern, repl, flags) - self.path, pc = strutils.safe_subn( - pattern, repl, self.path, flags=flags + self.path, pc = re.subn( + pattern, repl, self.data.path, flags=flags ) c += pc return c @@ -102,6 +118,8 @@ class Request(message.Message): """ HTTP request scheme, which should be "http" or "https". """ + if not self.data.scheme: + return self.data.scheme return message._native(self.data.scheme) @scheme.setter @@ -321,7 +339,7 @@ class Request(message.Message): self.headers["accept-encoding"] = ( ', '.join( e - for e in encoding.ENCODINGS + for e in {"gzip", "identity", "deflate"} if e in accept_encoding ) ) @@ -341,7 +359,10 @@ class Request(message.Message): def _get_urlencoded_form(self): is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() if is_valid_content_type: - return tuple(netlib.http.url.decode(self.content)) + try: + return tuple(netlib.http.url.decode(self.content)) + except ValueError: + pass return () def _set_urlencoded_form(self, value): @@ -350,7 +371,7 @@ class Request(message.Message): This will overwrite the existing content if there is one. """ self.headers["content-type"] = "application/x-www-form-urlencoded" - self.content = netlib.http.url.encode(value) + self.content = netlib.http.url.encode(value).encode() @urlencoded_form.setter def urlencoded_form(self, value): @@ -370,7 +391,10 @@ class Request(message.Message): def _get_multipart_form(self): is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() if is_valid_content_type: - return multipart.decode(self.headers, self.content) + try: + return multipart.decode(self.headers, self.content) + except ValueError: + pass return () def _set_multipart_form(self, value): diff --git a/netlib/http/response.py b/netlib/http/response.py index 17d69418..85f54940 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function, division from email.utils import parsedate_tz, formatdate, mktime_tz import time +import six from netlib.http import cookies from netlib.http import headers as nheaders @@ -13,8 +14,14 @@ from netlib import human class ResponseData(message.MessageData): def __init__(self, http_version, status_code, reason=None, headers=(), content=None, timestamp_start=None, timestamp_end=None): + if isinstance(http_version, six.text_type): + http_version = http_version.encode("ascii", "strict") + if isinstance(reason, six.text_type): + reason = reason.encode("ascii", "strict") if not isinstance(headers, nheaders.Headers): headers = nheaders.Headers(headers) + if isinstance(content, six.text_type): + raise ValueError("Content must be bytes, not {}".format(type(content).__name__)) self.http_version = http_version self.status_code = status_code @@ -30,13 +37,14 @@ class Response(message.Message): An HTTP response. 
""" def __init__(self, *args, **kwargs): + super(Response, self).__init__() self.data = ResponseData(*args, **kwargs) def __repr__(self): - if self.content: + if self.raw_content: details = "{}, {}".format( self.headers.get("content-type", "unknown content type"), - human.pretty_size(len(self.content)) + human.pretty_size(len(self.raw_content)) ) else: details = "no content" diff --git a/netlib/multidict.py b/netlib/multidict.py index 50c879d9..51053ff6 100644 --- a/netlib/multidict.py +++ b/netlib/multidict.py @@ -170,18 +170,10 @@ class _MultiDict(MutableMapping, basetypes.Serializable): else: return super(_MultiDict, self).items() - def clear(self, key): - """ - Removes all items with the specified key, and does not raise an - exception if the key does not exist. - """ - if key in self: - del self[key] - def collect(self): """ Returns a list of (key, value) tuples, where values are either - singular if threre is only one matching item for a key, or a list + singular if there is only one matching item for a key, or a list if there are more than one. The order of the keys matches the order in the underlying fields list. """ @@ -204,18 +196,16 @@ class _MultiDict(MutableMapping, basetypes.Serializable): .. code-block:: python # Simple dict with duplicate values. - >>> d - MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] + >>> d = MultiDict([("name", "value"), ("a", False), ("a", 42)]) >>> d.to_dict() { "name": "value", - "a": ["false", "42"] + "a": [False, 42] } """ - d = {} - for k, v in self.collect(): - d[k] = v - return d + return { + k: v for k, v in self.collect() + } def get_state(self): return self.fields @@ -307,4 +297,4 @@ class MultiDictView(_MultiDict): @fields.setter def fields(self, value): - return self._setter(value) + self._setter(value) diff --git a/netlib/strutils.py b/netlib/strutils.py index 5ad41c7e..32e77927 100644 --- a/netlib/strutils.py +++ b/netlib/strutils.py @@ -1,5 +1,5 @@ +from __future__ import absolute_import, print_function, division import re -import unicodedata import codecs import six @@ -20,68 +20,80 @@ def native(s, *encoding_opts): """ if not isinstance(s, (six.binary_type, six.text_type)): raise TypeError("%r is neither bytes nor unicode" % s) - if six.PY3: - if isinstance(s, six.binary_type): - return s.decode(*encoding_opts) - else: + if six.PY2: if isinstance(s, six.text_type): return s.encode(*encoding_opts) + else: + if isinstance(s, six.binary_type): + return s.decode(*encoding_opts) return s -def clean_bin(s, keep_spacing=True): - """ - Cleans binary data to make it safe to display. +# Translate control characters to "safe" characters. This implementation initially +# replaced them with the matching control pictures (http://unicode.org/charts/PDF/U2400.pdf), +# but that turned out to render badly with monospace fonts. We are back to "." therefore. +_control_char_trans = { + x: ord(".") # x + 0x2400 for unicode control group pictures + for x in range(32) +} +_control_char_trans[127] = ord(".") # 0x2421 +_control_char_trans_newline = _control_char_trans.copy() +for x in ("\r", "\n", "\t"): + del _control_char_trans_newline[ord(x)] - Args: - keep_spacing: If False, tabs and newlines will also be replaced. - """ - if isinstance(s, six.text_type): - if keep_spacing: - keep = u" \n\r\t" - else: - keep = u" " - return u"".join( - ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"." 
- for ch in s - ) - else: - if keep_spacing: - keep = (9, 10, 13) # \t, \n, \r, - else: - keep = () - return b"".join( - six.int2byte(ch) if (31 < ch < 127 or ch in keep) else b"." - for ch in six.iterbytes(s) - ) +if six.PY2: + pass +else: + _control_char_trans = str.maketrans(_control_char_trans) + _control_char_trans_newline = str.maketrans(_control_char_trans_newline) -def safe_subn(pattern, repl, target, *args, **kwargs): + +def escape_control_characters(text, keep_spacing=True): """ - There are Unicode conversion problems with re.subn. We try to smooth - that over by casting the pattern and replacement to strings. We really - need a better solution that is aware of the actual content ecoding. + Replace all unicode C1 control characters from the given text with their respective control pictures. + For example, a null byte is replaced with the unicode character "\u2400". + + Args: + keep_spacing: If True, tabs and newlines will not be replaced. """ - return re.subn(str(pattern), str(repl), target, *args, **kwargs) + # type: (six.string_types) -> six.text_type + if not isinstance(text, six.string_types): + raise ValueError("text type must be unicode but is {}".format(type(text).__name__)) + + trans = _control_char_trans_newline if keep_spacing else _control_char_trans + if six.PY2: + return u"".join( + six.unichr(trans.get(ord(ch), ord(ch))) + for ch in text + ) + return text.translate(trans) -def bytes_to_escaped_str(data): +def bytes_to_escaped_str(data, keep_spacing=False): """ Take bytes and return a safe string that can be displayed to the user. Single quotes are always escaped, double quotes are never escaped: "'" + bytes_to_escaped_str(...) + "'" gives a valid Python string. + + Args: + keep_spacing: If True, tabs and newlines will not be escaped. """ - # TODO: We may want to support multi-byte characters without escaping them. - # One way to do would be calling .decode("utf8", "backslashreplace") first - # and then escaping UTF8 control chars (see clean_bin). if not isinstance(data, bytes): raise ValueError("data must be bytes, but is {}".format(data.__class__.__name__)) # We always insert a double-quote here so that we get a single-quoted string back # https://stackoverflow.com/questions/29019340/why-does-python-use-different-quotes-for-representing-strings-depending-on-their - return repr(b'"' + data).lstrip("b")[2:-1] + ret = repr(b'"' + data).lstrip("b")[2:-1] + if keep_spacing: + ret = re.sub( + r"(?<!\\)(\\\\)*\\([nrt])", + lambda m: (m.group(1) or "") + dict(n="\n", r="\r", t="\t")[m.group(2)], + ret + ) + return ret def escaped_str_to_bytes(data): @@ -103,24 +115,17 @@ def escaped_str_to_bytes(data): return codecs.escape_decode(data)[0] -def isBin(s): - """ - Does this string have any non-ASCII characters? 
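# A minimal sketch of the strutils helpers introduced above, which replace clean_bin()
# and safe_subn(); illustrative only, not part of the changeset.
from netlib import strutils

# Control characters are rendered as "." so arbitrary data can be displayed safely;
# keep_spacing controls whether tabs and newlines survive.
assert strutils.escape_control_characters(u"foo\x00bar\n") == u"foo.bar\n"
assert strutils.escape_control_characters(u"foo\x00bar\n", keep_spacing=False) == u"foo.bar."

# bytes_to_escaped_str() yields a printable, single-quotable representation; with
# keep_spacing=True, tabs and newlines are left unescaped.
assert strutils.bytes_to_escaped_str(b"foo\nbar") == "foo\\nbar"
assert strutils.bytes_to_escaped_str(b"foo\nbar", keep_spacing=True) == "foo\nbar"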
- """ - for i in s: - i = ord(i) - if i < 9 or 13 < i < 32 or 126 < i: - return True - return False - - -def isMostlyBin(s): - s = s[:100] - return sum(isBin(ch) for ch in s) / len(s) > 0.3 +def is_mostly_bin(s): + # type: (bytes) -> bool + return sum( + i < 9 or 13 < i < 32 or 126 < i + for i in six.iterbytes(s[:100]) + ) / len(s[:100]) > 0.3 -def isXML(s): - return s.strip().startswith("<") +def is_xml(s): + # type: (bytes) -> bool + return s.strip().startswith(b"<") def clean_hanging_newline(t): @@ -141,8 +146,12 @@ def hexdump(s): A generator of (offset, hex, str) tuples """ for i in range(0, len(s), 16): - offset = "{:0=10x}".format(i).encode() + offset = "{:0=10x}".format(i) part = s[i:i + 16] - x = b" ".join("{:0=2x}".format(i).encode() for i in six.iterbytes(part)) + x = " ".join("{:0=2x}".format(i) for i in six.iterbytes(part)) x = x.ljust(47) # 16*2 + 15 - yield (offset, x, clean_bin(part, False)) + part_repr = native(escape_control_characters( + part.decode("ascii", "replace").replace(u"\ufffd", u"."), + False + )) + yield (offset, x, part_repr) diff --git a/netlib/tcp.py b/netlib/tcp.py index 69dafc1f..cf099edd 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -676,7 +676,7 @@ class TCPClient(_Connection): self.connection = SSL.Connection(context, self.connection) if sni: self.sni = sni - self.connection.set_tlsext_host_name(sni) + self.connection.set_tlsext_host_name(sni.encode("idna")) self.connection.set_connect_state() try: self.connection.do_handshake() @@ -705,7 +705,7 @@ class TCPClient(_Connection): if self.cert.cn: crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]] if sni: - hostname = sni.decode("ascii", "strict") + hostname = sni else: hostname = "no-hostname" ssl_match_hostname.match_hostname(crt, hostname) diff --git a/netlib/utils.py b/netlib/utils.py index 79340cbd..9eebf22c 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -56,6 +56,13 @@ class Data(object): dirname = os.path.dirname(inspect.getsourcefile(m)) self.dirname = os.path.abspath(dirname) + def push(self, subpath): + """ + Change the data object to a path relative to the module. + """ + self.dirname = os.path.join(self.dirname, subpath) + return self + def path(self, path): """ Returns a path to the package data housed at 'path' under this @@ -73,11 +80,9 @@ _label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE) def is_valid_host(host): + # type: (bytes) -> bool """ Checks if a hostname is valid. - - Args: - host (bytes): The hostname """ try: host.decode("idna") diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index 42196ffb..7d355699 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -255,7 +255,7 @@ class Frame(object): def __repr__(self): ret = repr(self.header) if self.payload: - ret = ret + "\nPayload:\n" + strutils.clean_bin(self.payload).decode("ascii") + ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload) return ret def human_readable(self): diff --git a/netlib/wsgi.py b/netlib/wsgi.py index c66fddc2..0def75b5 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -54,6 +54,10 @@ class WSGIAdaptor(object): self.app, self.domain, self.port, self.sversion = app, domain, port, sversion def make_environ(self, flow, errsoc, **extra): + """ + Raises: + ValueError, if the content-encoding is invalid. + """ path = strutils.native(flow.request.path, "latin-1") if '?' 
in path: path_info, query = strutils.native(path, "latin-1").split('?', 1) diff --git a/pathod/language/http.py b/pathod/language/http.py index 5bd6e385..fdc5bba6 100644 --- a/pathod/language/http.py +++ b/pathod/language/http.py @@ -181,7 +181,7 @@ class Response(_HTTPMessage): l.append( status_codes.RESPONSES.get( status_code, - b"Unknown code" + "Unknown code" ).encode() ) return l diff --git a/pathod/language/http2.py b/pathod/language/http2.py index 2693446e..c0313baa 100644 --- a/pathod/language/http2.py +++ b/pathod/language/http2.py @@ -273,7 +273,7 @@ class Request(_HTTP2Message): req = http.Request( b'', self.method.string(), - b'', + b'http', b'', b'', path, diff --git a/pathod/log.py b/pathod/log.py index 23e9a2ce..47837101 100644 --- a/pathod/log.py +++ b/pathod/log.py @@ -62,7 +62,14 @@ class LogCtx(object): for line in strutils.hexdump(data): self("\t%s %s %s" % line) else: - for i in strutils.clean_bin(data).split(b"\n"): + data = strutils.native( + strutils.escape_control_characters( + data + .decode("ascii", "replace") + .replace(u"\ufffd", u".") + ) + ) + for i in data.split("\n"): self("\t%s" % i) def __call__(self, line): diff --git a/pathod/pathod.py b/pathod/pathod.py index 3df86aae..7087cba6 100644 --- a/pathod/pathod.py +++ b/pathod/pathod.py @@ -89,7 +89,10 @@ class PathodHandler(tcp.BaseHandler): self.http2_framedump = http2_framedump def handle_sni(self, connection): - self.sni = connection.get_servername() + sni = connection.get_servername() + if sni: + sni = sni.decode("idna") + self.sni = sni def http_serve_crafted(self, crafted, logctx): error, crafted = self.server.check_policy( diff --git a/pathod/protocols/http2.py b/pathod/protocols/http2.py index c8728940..5ad120de 100644 --- a/pathod/protocols/http2.py +++ b/pathod/protocols/http2.py @@ -7,8 +7,7 @@ import hyperframe.frame from hpack.hpack import Encoder, Decoder from netlib import utils, strutils -from netlib.http import url -from netlib.http.http2 import framereader +from netlib.http import http2 import netlib.http.headers import netlib.http.response import netlib.http.request @@ -101,46 +100,15 @@ class HTTP2StateProtocol(object): timestamp_end = time.time() - authority = headers.get(':authority', b'') - method = headers.get(':method', 'GET') - scheme = headers.get(':scheme', 'https') - path = headers.get(':path', '/') - - headers.clear(":method") - headers.clear(":scheme") - headers.clear(":path") - - host = None - port = None - - if path == '*' or path.startswith("/"): - first_line_format = "relative" - elif method == 'CONNECT': - first_line_format = "authority" - if ":" in authority: - host, port = authority.split(":", 1) - else: - host = authority - else: - first_line_format = "absolute" - # FIXME: verify if path or :host contains what we need - scheme, host, port, _ = url.parse(path) - scheme = scheme.decode('ascii') - host = host.decode('ascii') - - if host is None: - host = 'localhost' - if port is None: - port = 80 if scheme == 'http' else 443 - port = int(port) + first_line_format, method, scheme, host, port, path = http2.parse_headers(headers) request = netlib.http.request.Request( first_line_format, - method.encode('ascii'), - scheme.encode('ascii'), - host.encode('ascii'), + method, + scheme, + host, port, - path.encode('ascii'), + path, b"HTTP/2.0", headers, body, @@ -213,10 +181,10 @@ class HTTP2StateProtocol(object): headers = request.headers.copy() if ':authority' not in headers: - headers.insert(0, b':authority', authority.encode('ascii')) - headers.insert(0, b':scheme', 
request.scheme.encode('ascii')) - headers.insert(0, b':path', request.path.encode('ascii')) - headers.insert(0, b':method', request.method.encode('ascii')) + headers.insert(0, ':authority', authority) + headers.insert(0, ':scheme', request.scheme) + headers.insert(0, ':path', request.path) + headers.insert(0, ':method', request.method) if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -286,7 +254,7 @@ class HTTP2StateProtocol(object): def read_frame(self, hide=False): while True: - frm = framereader.http2_read_frame(self.tcp_handler.rfile) + frm = http2.framereader.http2_read_frame(self.tcp_handler.rfile) if not hide and self.dump_frames: # pragma no cover print(frm.human_readable("<<")) @@ -429,7 +397,7 @@ class HTTP2StateProtocol(object): self._handle_unexpected_frame(frm) headers = netlib.http.headers.Headers( - (k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks) + [[k, v] for k, v in self.decoder.decode(header_blocks, raw=True)] ) return stream_id, headers, body diff --git a/release/rtool.py b/release/rtool.py index 04e1249d..4e43eaef 100755 --- a/release/rtool.py +++ b/release/rtool.py @@ -76,7 +76,7 @@ def get_snapshot_version(): return "{version}dev{tag_dist:04}-0x{commit}".format( version=get_version(), # this should already be the next version tag_dist=tag_dist, - commit=commit + commit=commit.decode() ) diff --git a/release/setup.py b/release/setup.py index 9876af0a..78155140 100644 --- a/release/setup.py +++ b/release/setup.py @@ -10,7 +10,7 @@ setup( "virtualenv>=14.0.5, <14.1", "wheel>=0.29.0, <0.30", "six>=1.10.0, <1.11", - "pysftp>=0.2.8, <0.3", + "pysftp>=0.2.8, !=0.2.9, <0.3", ], entry_points={ "console_scripts": [ @@ -2,13 +2,13 @@ from setuptools import setup, find_packages from codecs import open import os +from netlib import version + # Based on https://github.com/pypa/sampleproject/blob/master/setup.py # and https://python-packaging-user-guide.readthedocs.org/ here = os.path.abspath(os.path.dirname(__file__)) -from netlib import version - with open(os.path.join(here, 'README.rst'), encoding='utf-8') as f: long_description = f.read() @@ -32,6 +32,8 @@ setup( "Programming Language :: Python", "Programming Language :: Python :: 2", "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.5", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Topic :: Security", @@ -66,13 +68,12 @@ setup( "construct>=2.5.2, <2.6", "cryptography>=1.3, <1.5", "Flask>=0.10.1, <0.12", - "h2>=2.3.1, <3", + "h2>=2.4.0, <3", "html2text>=2016.1.8, <=2016.5.29", "hyperframe>=4.0.1, <5", "lxml>=3.5.0, <3.7", - "Pillow>=3.2, <3.3", + "Pillow>=3.2, <3.4", "passlib>=1.6.5, <1.7", - "psutil>=4.2, <4.4", "pyasn1>=0.1.9, <0.2", "pyOpenSSL>=16.0, <17.0", "pyparsing>=2.1.3, <2.2", @@ -99,10 +100,10 @@ setup( 'dev': [ "tox>=2.3, <3", "mock>=2.0, <2.1", - "pytest>=2.8.7, <2.10", - "pytest-cov>=2.2.1, <2.3", - "pytest-timeout>=1.0.0, <1.1", - "pytest-xdist>=1.14, <1.15", + "pytest>=2.8.7, <3", + "pytest-cov>=2.2.1, <3", + "pytest-timeout>=1.0.0, <2", + "pytest-xdist>=1.14, <2", "sphinx>=1.3.5, <1.5", "sphinx-autobuild>=0.5.2, <0.7", "sphinxcontrib-documentedlist>=0.4.0, <0.5", diff --git a/test/mitmproxy/builtins/__init__.py b/test/mitmproxy/builtins/__init__.py new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/test/mitmproxy/builtins/__init__.py diff --git a/test/mitmproxy/builtins/test_anticache.py 
b/test/mitmproxy/builtins/test_anticache.py new file mode 100644 index 00000000..127e1c1a --- /dev/null +++ b/test/mitmproxy/builtins/test_anticache.py @@ -0,0 +1,23 @@ +from .. import tutils, mastertest +from mitmproxy.builtins import anticache +from mitmproxy.flow import master +from mitmproxy.flow import state +from mitmproxy.flow import options + + +class TestAntiCache(mastertest.MasterTest): + def test_simple(self): + s = state.State() + m = master.FlowMaster(options.Options(anticache = True), None, s) + sa = anticache.AntiCache() + m.addons.add(sa) + + f = tutils.tflow(resp=True) + self.invoke(m, "request", f) + + f = tutils.tflow(resp=True) + f.request.headers["if-modified-since"] = "test" + f.request.headers["if-none-match"] = "test" + self.invoke(m, "request", f) + assert "if-modified-since" not in f.request.headers + assert "if-none-match" not in f.request.headers diff --git a/test/mitmproxy/builtins/test_anticomp.py b/test/mitmproxy/builtins/test_anticomp.py new file mode 100644 index 00000000..601e56c8 --- /dev/null +++ b/test/mitmproxy/builtins/test_anticomp.py @@ -0,0 +1,22 @@ +from .. import tutils, mastertest +from mitmproxy.builtins import anticomp +from mitmproxy.flow import master +from mitmproxy.flow import state +from mitmproxy.flow import options + + +class TestAntiComp(mastertest.MasterTest): + def test_simple(self): + s = state.State() + m = master.FlowMaster(options.Options(anticomp = True), None, s) + sa = anticomp.AntiComp() + m.addons.add(sa) + + f = tutils.tflow(resp=True) + self.invoke(m, "request", f) + + f = tutils.tflow(resp=True) + + f.request.headers["Accept-Encoding"] = "foobar" + self.invoke(m, "request", f) + assert f.request.headers["Accept-Encoding"] == "identity" diff --git a/test/mitmproxy/builtins/test_dumper.py b/test/mitmproxy/builtins/test_dumper.py new file mode 100644 index 00000000..57e3d036 --- /dev/null +++ b/test/mitmproxy/builtins/test_dumper.py @@ -0,0 +1,86 @@ +from .. 
import tutils, mastertest +from six.moves import cStringIO as StringIO + +from mitmproxy.builtins import dumper +from mitmproxy.flow import state +from mitmproxy import exceptions +from mitmproxy import dump +from mitmproxy import models +import netlib.tutils +import mock + + +class TestDumper(mastertest.MasterTest): + def test_simple(self): + d = dumper.Dumper() + sio = StringIO() + + d.configure(dump.Options(tfile = sio, flow_detail = 0)) + d.response(tutils.tflow()) + assert not sio.getvalue() + + d.configure(dump.Options(tfile = sio, flow_detail = 4)) + d.response(tutils.tflow()) + assert sio.getvalue() + + sio = StringIO() + d.configure(dump.Options(tfile = sio, flow_detail = 4)) + d.response(tutils.tflow(resp=True)) + assert "<<" in sio.getvalue() + + sio = StringIO() + d.configure(dump.Options(tfile = sio, flow_detail = 4)) + d.response(tutils.tflow(err=True)) + assert "<<" in sio.getvalue() + + sio = StringIO() + d.configure(dump.Options(tfile = sio, flow_detail = 4)) + flow = tutils.tflow() + flow.request = netlib.tutils.treq() + flow.request.stickycookie = True + flow.client_conn = mock.MagicMock() + flow.client_conn.address.host = "foo" + flow.response = netlib.tutils.tresp(content=None) + flow.response.is_replay = True + flow.response.status_code = 300 + d.response(flow) + assert sio.getvalue() + + sio = StringIO() + d.configure(dump.Options(tfile = sio, flow_detail = 4)) + flow = tutils.tflow(resp=netlib.tutils.tresp(content=b"{")) + flow.response.headers["content-type"] = "application/json" + flow.response.status_code = 400 + d.response(flow) + assert sio.getvalue() + + sio = StringIO() + d.configure(dump.Options(tfile = sio)) + flow = tutils.tflow() + flow.request.content = None + flow.response = models.HTTPResponse.wrap(netlib.tutils.tresp()) + flow.response.content = None + d.response(flow) + assert "content missing" in sio.getvalue() + + +class TestContentView(mastertest.MasterTest): + @mock.patch("mitmproxy.contentviews.get_content_view") + def test_contentview(self, get_content_view): + se = exceptions.ContentViewException(""), ("x", iter([])) + get_content_view.side_effect = se + + s = state.State() + sio = StringIO() + m = mastertest.RecordingMaster( + dump.Options( + flow_detail=4, + verbosity=3, + tfile=sio, + ), + None, s + ) + d = dumper.Dumper() + m.addons.add(d) + self.invoke(m, "response", tutils.tflow()) + assert "Content viewer failed" in m.event_log[0][1] diff --git a/test/mitmproxy/builtins/test_filestreamer.py b/test/mitmproxy/builtins/test_filestreamer.py new file mode 100644 index 00000000..002006b7 --- /dev/null +++ b/test/mitmproxy/builtins/test_filestreamer.py @@ -0,0 +1,46 @@ +from __future__ import absolute_import, print_function, division + +from .. 
import tutils, mastertest + +import os.path + +from mitmproxy.builtins import filestreamer +from mitmproxy.flow import master, FlowReader +from mitmproxy.flow import state +from mitmproxy.flow import options + + +class TestStream(mastertest.MasterTest): + def test_stream(self): + with tutils.tmpdir() as tdir: + p = os.path.join(tdir, "foo") + + def r(): + r = FlowReader(open(p, "rb")) + return list(r.stream()) + + s = state.State() + m = master.FlowMaster( + options.Options( + outfile = (p, "wb") + ), + None, + s + ) + sa = filestreamer.FileStreamer() + + m.addons.add(sa) + f = tutils.tflow(resp=True) + self.invoke(m, "request", f) + self.invoke(m, "response", f) + m.addons.remove(sa) + + assert r()[0].response + + m.options.outfile = (p, "ab") + + m.addons.add(sa) + f = tutils.tflow() + self.invoke(m, "request", f) + m.addons.remove(sa) + assert not r()[1].response diff --git a/test/mitmproxy/builtins/test_replace.py b/test/mitmproxy/builtins/test_replace.py new file mode 100644 index 00000000..f8010bec --- /dev/null +++ b/test/mitmproxy/builtins/test_replace.py @@ -0,0 +1,52 @@ +from .. import tutils, mastertest +from mitmproxy.builtins import replace +from mitmproxy.flow import master +from mitmproxy.flow import state +from mitmproxy.flow import options + + +class TestReplace(mastertest.MasterTest): + def test_configure(self): + r = replace.Replace() + r.configure(options.Options( + replacements=[("one", "two", "three")] + )) + tutils.raises( + "invalid filter pattern", + r.configure, + options.Options( + replacements=[("~b", "two", "three")] + ) + ) + tutils.raises( + "invalid regular expression", + r.configure, + options.Options( + replacements=[("foo", "+", "three")] + ) + ) + + def test_simple(self): + s = state.State() + m = master.FlowMaster( + options.Options( + replacements = [ + ("~q", "foo", "bar"), + ("~s", "foo", "bar"), + ] + ), + None, + s + ) + sa = replace.Replace() + m.addons.add(sa) + + f = tutils.tflow() + f.request.content = b"foo" + self.invoke(m, "request", f) + assert f.request.content == b"bar" + + f = tutils.tflow(resp=True) + f.response.content = b"foo" + self.invoke(m, "response", f) + assert f.response.content == b"bar" diff --git a/test/mitmproxy/builtins/test_script.py b/test/mitmproxy/builtins/test_script.py new file mode 100644 index 00000000..c9616249 --- /dev/null +++ b/test/mitmproxy/builtins/test_script.py @@ -0,0 +1,191 @@ +import time + +from mitmproxy.builtins import script +from mitmproxy import exceptions +from mitmproxy.flow import master +from mitmproxy.flow import state +from mitmproxy.flow import options + +from .. 
import tutils, mastertest + + +class TestParseCommand: + def test_empty_command(self): + with tutils.raises(exceptions.AddonError): + script.parse_command("") + + with tutils.raises(exceptions.AddonError): + script.parse_command(" ") + + def test_no_script_file(self): + with tutils.raises("not found"): + script.parse_command("notfound") + + with tutils.tmpdir() as dir: + with tutils.raises("not a file"): + script.parse_command(dir) + + def test_parse_args(self): + with tutils.chdir(tutils.test_data.dirname): + assert script.parse_command("data/addonscripts/recorder.py") == ("data/addonscripts/recorder.py", []) + assert script.parse_command("data/addonscripts/recorder.py foo bar") == ("data/addonscripts/recorder.py", ["foo", "bar"]) + assert script.parse_command("data/addonscripts/recorder.py 'foo bar'") == ("data/addonscripts/recorder.py", ["foo bar"]) + + @tutils.skip_not_windows + def test_parse_windows(self): + with tutils.chdir(tutils.test_data.dirname): + assert script.parse_command( + "data\\addonscripts\\recorder.py" + ) == ("data\\addonscripts\\recorder.py", []) + assert script.parse_command( + "data\\addonscripts\\recorder.py 'foo \\ bar'" + ) == ("data\\addonscripts\\recorder.py", ['foo \\ bar']) + + +def test_load_script(): + ns = script.load_script( + tutils.test_data.path( + "data/addonscripts/recorder.py" + ), [] + ) + assert ns["configure"] + + +class TestScript(mastertest.MasterTest): + def test_simple(self): + s = state.State() + m = master.FlowMaster(options.Options(), None, s) + sc = script.Script( + tutils.test_data.path( + "data/addonscripts/recorder.py" + ) + ) + m.addons.add(sc) + assert sc.ns["call_log"] == [ + ("solo", "start", (), {}), + ("solo", "configure", (options.Options(),), {}) + ] + + sc.ns["call_log"] = [] + f = tutils.tflow(resp=True) + self.invoke(m, "request", f) + + recf = sc.ns["call_log"][0] + assert recf[1] == "request" + + def test_reload(self): + s = state.State() + m = mastertest.RecordingMaster(options.Options(), None, s) + with tutils.tmpdir(): + with open("foo.py", "w"): + pass + sc = script.Script("foo.py") + m.addons.add(sc) + + for _ in range(100): + with open("foo.py", "a") as f: + f.write(".") + m.addons.invoke_with_context(sc, "tick") + time.sleep(0.1) + if m.event_log: + return + raise AssertionError("Change event not detected.") + + def test_exception(self): + s = state.State() + m = mastertest.RecordingMaster(options.Options(), None, s) + sc = script.Script( + tutils.test_data.path("data/addonscripts/error.py") + ) + m.addons.add(sc) + f = tutils.tflow(resp=True) + self.invoke(m, "request", f) + assert m.event_log[0][0] == "error" + + def test_duplicate_flow(self): + s = state.State() + fm = master.FlowMaster(None, None, s) + fm.addons.add( + script.Script( + tutils.test_data.path("data/addonscripts/duplicate_flow.py") + ) + ) + f = tutils.tflow() + fm.request(f) + assert fm.state.flow_count() == 2 + assert not fm.state.view[0].request.is_replay + assert fm.state.view[1].request.is_replay + + +class TestScriptLoader(mastertest.MasterTest): + def test_simple(self): + s = state.State() + o = options.Options(scripts=[]) + m = master.FlowMaster(o, None, s) + sc = script.ScriptLoader() + m.addons.add(sc) + assert len(m.addons) == 1 + o.update( + scripts = [ + tutils.test_data.path("data/addonscripts/recorder.py") + ] + ) + assert len(m.addons) == 2 + o.update(scripts = []) + assert len(m.addons) == 1 + + def test_dupes(self): + s = state.State() + o = options.Options(scripts=["one", "one"]) + m = master.FlowMaster(o, None, s) + sc = 
script.ScriptLoader() + tutils.raises(exceptions.OptionsError, m.addons.add, sc) + + def test_order(self): + rec = tutils.test_data.path("data/addonscripts/recorder.py") + + s = state.State() + o = options.Options( + scripts = [ + "%s %s" % (rec, "a"), + "%s %s" % (rec, "b"), + "%s %s" % (rec, "c"), + ] + ) + m = mastertest.RecordingMaster(o, None, s) + sc = script.ScriptLoader() + m.addons.add(sc) + + debug = [(i[0], i[1]) for i in m.event_log if i[0] == "debug"] + assert debug == [ + ('debug', 'a start'), ('debug', 'a configure'), + ('debug', 'b start'), ('debug', 'b configure'), + ('debug', 'c start'), ('debug', 'c configure') + ] + m.event_log[:] = [] + + o.scripts = [ + "%s %s" % (rec, "c"), + "%s %s" % (rec, "a"), + "%s %s" % (rec, "b"), + ] + debug = [(i[0], i[1]) for i in m.event_log if i[0] == "debug"] + assert debug == [ + ('debug', 'c configure'), + ('debug', 'a configure'), + ('debug', 'b configure'), + ] + m.event_log[:] = [] + + o.scripts = [ + "%s %s" % (rec, "x"), + "%s %s" % (rec, "a"), + ] + debug = [(i[0], i[1]) for i in m.event_log if i[0] == "debug"] + assert debug == [ + ('debug', 'c done'), + ('debug', 'b done'), + ('debug', 'x start'), + ('debug', 'x configure'), + ('debug', 'a configure'), + ] diff --git a/test/mitmproxy/builtins/test_setheaders.py b/test/mitmproxy/builtins/test_setheaders.py new file mode 100644 index 00000000..1a8d048c --- /dev/null +++ b/test/mitmproxy/builtins/test_setheaders.py @@ -0,0 +1,64 @@ +from .. import tutils, mastertest + +from mitmproxy.builtins import setheaders +from mitmproxy.flow import state +from mitmproxy.flow import options + + +class TestSetHeaders(mastertest.MasterTest): + def mkmaster(self, **opts): + s = state.State() + m = mastertest.RecordingMaster(options.Options(**opts), None, s) + sh = setheaders.SetHeaders() + m.addons.add(sh) + return m, sh + + def test_configure(self): + sh = setheaders.SetHeaders() + tutils.raises( + "invalid setheader filter pattern", + sh.configure, + options.Options( + setheaders = [("~b", "one", "two")] + ) + ) + + def test_setheaders(self): + m, sh = self.mkmaster( + setheaders = [ + ("~q", "one", "two"), + ("~s", "one", "three") + ] + ) + f = tutils.tflow() + f.request.headers["one"] = "xxx" + self.invoke(m, "request", f) + assert f.request.headers["one"] == "two" + + f = tutils.tflow(resp=True) + f.response.headers["one"] = "xxx" + self.invoke(m, "response", f) + assert f.response.headers["one"] == "three" + + m, sh = self.mkmaster( + setheaders = [ + ("~s", "one", "two"), + ("~s", "one", "three") + ] + ) + f = tutils.tflow(resp=True) + f.request.headers["one"] = "xxx" + f.response.headers["one"] = "xxx" + self.invoke(m, "response", f) + assert f.response.headers.get_all("one") == ["two", "three"] + + m, sh = self.mkmaster( + setheaders = [ + ("~q", "one", "two"), + ("~q", "one", "three") + ] + ) + f = tutils.tflow() + f.request.headers["one"] = "xxx" + self.invoke(m, "request", f) + assert f.request.headers.get_all("one") == ["two", "three"] diff --git a/test/mitmproxy/builtins/test_stickyauth.py b/test/mitmproxy/builtins/test_stickyauth.py new file mode 100644 index 00000000..1e617402 --- /dev/null +++ b/test/mitmproxy/builtins/test_stickyauth.py @@ -0,0 +1,23 @@ +from .. 
import tutils, mastertest +from mitmproxy.builtins import stickyauth +from mitmproxy.flow import master +from mitmproxy.flow import state +from mitmproxy.flow import options + + +class TestStickyAuth(mastertest.MasterTest): + def test_simple(self): + s = state.State() + m = master.FlowMaster(options.Options(stickyauth = ".*"), None, s) + sa = stickyauth.StickyAuth() + m.addons.add(sa) + + f = tutils.tflow(resp=True) + f.request.headers["authorization"] = "foo" + self.invoke(m, "request", f) + + assert "address" in sa.hosts + + f = tutils.tflow(resp=True) + self.invoke(m, "request", f) + assert f.request.headers["authorization"] == "foo" diff --git a/test/mitmproxy/builtins/test_stickycookie.py b/test/mitmproxy/builtins/test_stickycookie.py new file mode 100644 index 00000000..b8d703bd --- /dev/null +++ b/test/mitmproxy/builtins/test_stickycookie.py @@ -0,0 +1,131 @@ +from .. import tutils, mastertest +from mitmproxy.builtins import stickycookie +from mitmproxy.flow import master +from mitmproxy.flow import state +from mitmproxy.flow import options +from netlib import tutils as ntutils + + +def test_domain_match(): + assert stickycookie.domain_match("www.google.com", ".google.com") + assert stickycookie.domain_match("google.com", ".google.com") + + +class TestStickyCookie(mastertest.MasterTest): + def mk(self): + s = state.State() + m = master.FlowMaster(options.Options(stickycookie = ".*"), None, s) + sc = stickycookie.StickyCookie() + m.addons.add(sc) + return s, m, sc + + def test_config(self): + sc = stickycookie.StickyCookie() + tutils.raises( + "invalid filter", + sc.configure, + options.Options(stickycookie = "~b") + ) + + def test_simple(self): + s, m, sc = self.mk() + m.addons.add(sc) + + f = tutils.tflow(resp=True) + f.response.headers["set-cookie"] = "foo=bar" + self.invoke(m, "request", f) + + f.reply.acked = False + self.invoke(m, "response", f) + + assert sc.jar + assert "cookie" not in f.request.headers + + f = f.copy() + f.reply.acked = False + self.invoke(m, "request", f) + assert f.request.headers["cookie"] == "foo=bar" + + def _response(self, s, m, sc, cookie, host): + f = tutils.tflow(req=ntutils.treq(host=host, port=80), resp=True) + f.response.headers["Set-Cookie"] = cookie + self.invoke(m, "response", f) + return f + + def test_response(self): + s, m, sc = self.mk() + + c = "SSID=mooo; domain=.google.com, FOO=bar; Domain=.google.com; Path=/; " \ + "Expires=Wed, 13-Jan-2021 22:23:01 GMT; Secure; " + + self._response(s, m, sc, c, "host") + assert not sc.jar.keys() + + self._response(s, m, sc, c, "www.google.com") + assert sc.jar.keys() + + sc.jar.clear() + self._response( + s, m, sc, "SSID=mooo", "www.google.com" + ) + assert list(sc.jar.keys())[0] == ('www.google.com', 80, '/') + + def test_response_multiple(self): + s, m, sc = self.mk() + + # Test setting of multiple cookies + c1 = "somecookie=test; Path=/" + c2 = "othercookie=helloworld; Path=/" + f = self._response(s, m, sc, c1, "www.google.com") + f.response.headers["Set-Cookie"] = c2 + self.invoke(m, "response", f) + googlekey = list(sc.jar.keys())[0] + assert len(sc.jar[googlekey].keys()) == 2 + + def test_response_weird(self): + s, m, sc = self.mk() + + # Test setting of weird cookie keys + f = tutils.tflow(req=ntutils.treq(host="www.google.com", port=80), resp=True) + cs = [ + "foo/bar=hello", + "foo:bar=world", + "foo@bar=fizz", + "foo,bar=buzz", + ] + for c in cs: + f.response.headers["Set-Cookie"] = c + self.invoke(m, "response", f) + googlekey = list(sc.jar.keys())[0] + assert len(sc.jar[googlekey].keys()) 
== len(cs) + + def test_response_overwrite(self): + s, m, sc = self.mk() + + # Test overwriting of a cookie value + c1 = "somecookie=helloworld; Path=/" + c2 = "somecookie=newvalue; Path=/" + f = self._response(s, m, sc, c1, "www.google.com") + f.response.headers["Set-Cookie"] = c2 + self.invoke(m, "response", f) + googlekey = list(sc.jar.keys())[0] + assert len(sc.jar[googlekey].keys()) == 1 + assert list(sc.jar[googlekey]["somecookie"].items())[0][1] == "newvalue" + + def test_response_delete(self): + s, m, sc = self.mk() + + # Test that a cookie is be deleted + # by setting the expire time in the past + f = self._response(s, m, sc, "duffer=zafar; Path=/", "www.google.com") + f.response.headers["Set-Cookie"] = "duffer=; Expires=Thu, 01-Jan-1970 00:00:00 GMT" + self.invoke(m, "response", f) + assert not sc.jar.keys() + + def test_request(self): + s, m, sc = self.mk() + + f = self._response(s, m, sc, "SSID=mooo", "www.google.com") + assert "cookie" not in f.request.headers + self.invoke(m, "request", f) + assert "cookie" in f.request.headers diff --git a/test/mitmproxy/console/test_master.py b/test/mitmproxy/console/test_master.py index 33261c28..b84e4c1c 100644 --- a/test/mitmproxy/console/test_master.py +++ b/test/mitmproxy/console/test_master.py @@ -111,12 +111,14 @@ def test_options(): class TestMaster(mastertest.MasterTest): - def mkmaster(self, filt, **options): - o = console.master.Options(filtstr=filt, **options) + def mkmaster(self, **options): + if "verbosity" not in options: + options["verbosity"] = 0 + o = console.master.Options(**options) return console.master.ConsoleMaster(None, o) def test_basic(self): - m = self.mkmaster(None) + m = self.mkmaster() for i in (1, 2, 3): - self.dummy_cycle(m, 1, "") + self.dummy_cycle(m, 1, b"") assert len(m.state.flows) == i diff --git a/test/mitmproxy/data/scripts/concurrent_decorator.py b/test/mitmproxy/data/addonscripts/concurrent_decorator.py index e017f605..a56c2af1 100644 --- a/test/mitmproxy/data/scripts/concurrent_decorator.py +++ b/test/mitmproxy/data/addonscripts/concurrent_decorator.py @@ -1,7 +1,6 @@ import time from mitmproxy.script import concurrent - @concurrent -def request(context, flow): +def request(flow): time.sleep(0.1) diff --git a/test/mitmproxy/data/scripts/concurrent_decorator_err.py b/test/mitmproxy/data/addonscripts/concurrent_decorator_err.py index 349e5dd6..756869c8 100644 --- a/test/mitmproxy/data/scripts/concurrent_decorator_err.py +++ b/test/mitmproxy/data/addonscripts/concurrent_decorator_err.py @@ -2,5 +2,5 @@ from mitmproxy.script import concurrent @concurrent -def start(context): +def start(): pass diff --git a/test/mitmproxy/data/addonscripts/duplicate_flow.py b/test/mitmproxy/data/addonscripts/duplicate_flow.py new file mode 100644 index 00000000..b466423c --- /dev/null +++ b/test/mitmproxy/data/addonscripts/duplicate_flow.py @@ -0,0 +1,6 @@ +from mitmproxy import ctx + + +def request(flow): + f = ctx.master.duplicate_flow(flow) + ctx.master.replay_request(f, block=True) diff --git a/test/mitmproxy/data/addonscripts/error.py b/test/mitmproxy/data/addonscripts/error.py new file mode 100644 index 00000000..8ece9fce --- /dev/null +++ b/test/mitmproxy/data/addonscripts/error.py @@ -0,0 +1,7 @@ + +def mkerr(): + raise ValueError("Error!") + + +def request(flow): + mkerr() diff --git a/test/mitmproxy/data/addonscripts/recorder.py b/test/mitmproxy/data/addonscripts/recorder.py new file mode 100644 index 00000000..b6ac8d89 --- /dev/null +++ b/test/mitmproxy/data/addonscripts/recorder.py @@ -0,0 +1,25 @@ +from 
mitmproxy import controller +from mitmproxy import ctx +import sys + +call_log = [] + +if len(sys.argv) > 1: + name = sys.argv[1] +else: + name = "solo" + +# Keep a log of all possible event calls +evts = list(controller.Events) + ["configure"] +for i in evts: + def mkprox(): + evt = i + + def prox(*args, **kwargs): + lg = (name, evt, args, kwargs) + if evt != "log": + ctx.log.info(str(lg)) + call_log.append(lg) + ctx.log.debug("%s %s" % (name, evt)) + return prox + globals()[i] = mkprox() diff --git a/test/mitmproxy/data/scripts/stream_modify.py b/test/mitmproxy/data/addonscripts/stream_modify.py index e26d83f1..bc616342 100644 --- a/test/mitmproxy/data/scripts/stream_modify.py +++ b/test/mitmproxy/data/addonscripts/stream_modify.py @@ -1,7 +1,8 @@ + def modify(chunks): for chunk in chunks: - yield chunk.replace("foo", "bar") + yield chunk.replace(b"foo", b"bar") -def responseheaders(context, flow): +def responseheaders(flow): flow.response.stream = modify diff --git a/test/mitmproxy/data/scripts/tcp_stream_modify.py b/test/mitmproxy/data/addonscripts/tcp_stream_modify.py index 0965beba..af4ccf7e 100644 --- a/test/mitmproxy/data/scripts/tcp_stream_modify.py +++ b/test/mitmproxy/data/addonscripts/tcp_stream_modify.py @@ -1,4 +1,5 @@ -def tcp_message(ctx, flow): + +def tcp_message(flow): message = flow.messages[-1] if not message.from_client: message.content = message.content.replace(b"foo", b"bar") diff --git a/test/mitmproxy/data/dumpfile-010 b/test/mitmproxy/data/dumpfile-010 Binary files differnew file mode 100644 index 00000000..435795bf --- /dev/null +++ b/test/mitmproxy/data/dumpfile-010 diff --git a/test/mitmproxy/data/dumpfile-011 b/test/mitmproxy/data/dumpfile-011 Binary files differnew file mode 100644 index 00000000..2534ad89 --- /dev/null +++ b/test/mitmproxy/data/dumpfile-011 diff --git a/test/mitmproxy/data/dumpfile-012 b/test/mitmproxy/data/dumpfile-012 deleted file mode 100644 index 49c2350d..00000000 --- a/test/mitmproxy/data/dumpfile-012 +++ /dev/null @@ -1,35 +0,0 @@ -4092:8:response,491:11:httpversion,8:1:1#1:1#]13:timestamp_end,14:1449080668.874^3:msg,12:Not Modified,15:timestamp_start,14:1449080668.863^7:headers,330:35:13:Cache-Control,14:max-age=604800,]40:4:Date,29:Wed, 02 Dec 2015 18:24:32 GMT,]32:4:Etag,21:"359670651+gzip+gzip",]43:7:Expires,29:Wed, 09 Dec 2015 18:24:32 GMT,]50:13:Last-Modified,29:Fri, 09 Aug 2013 23:54:35 GMT,]27:6:Server,14:ECS (lga/1312),]26:4:Vary,15:Accept-Encoding,]16:7:X-Cache,3:HIT,]25:17:x-ec-custom-error,1:1,]]7:content,0:,4:code,3:304#}4:type,4:http,2:id,36:d209a4fc-8e12-43cb-9250-b0b052d2caf8,5:error,0:~7:version,9:1:0#2:12#]11:client_conn,208:15:ssl_established,4:true!10:clientcert,0:~13:timestamp_end,0:~19:timestamp_ssl_setup,14:1449080668.754^7:address,53:7:address,20:9:127.0.0.1,5:58199#]8:use_ipv6,5:false!}15:timestamp_start,14:1449080666.523^}11:server_conn,2479:15:ssl_established,4:true!14:source_address,57:7:address,24:12:10.67.56.236,5:58201#]8:use_ipv6,5:false!}13:timestamp_end,0:~7:address,54:7:address,21:11:example.com,3:443#]8:use_ipv6,5:false!}15:timestamp_start,14:1449080668.046^3:sni,11:example.com,4:cert,2122:-----BEGIN CERTIFICATE----- -MIIF8jCCBNqgAwIBAgIQDmTF+8I2reFLFyrrQceMsDANBgkqhkiG9w0BAQsFADBw -MQswCQYDVQQGEwJVUzEVMBMGA1UEChMMRGlnaUNlcnQgSW5jMRkwFwYDVQQLExB3 -d3cuZGlnaWNlcnQuY29tMS8wLQYDVQQDEyZEaWdpQ2VydCBTSEEyIEhpZ2ggQXNz -dXJhbmNlIFNlcnZlciBDQTAeFw0xNTExMDMwMDAwMDBaFw0xODExMjgxMjAwMDBa -MIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEUMBIGA1UEBxML 
-TG9zIEFuZ2VsZXMxPDA6BgNVBAoTM0ludGVybmV0IENvcnBvcmF0aW9uIGZvciBB -c3NpZ25lZCBOYW1lcyBhbmQgTnVtYmVyczETMBEGA1UECxMKVGVjaG5vbG9neTEY -MBYGA1UEAxMPd3d3LmV4YW1wbGUub3JnMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A -MIIBCgKCAQEAs0CWL2FjPiXBl61lRfvvE0KzLJmG9LWAC3bcBjgsH6NiVVo2dt6u -Xfzi5bTm7F3K7srfUBYkLO78mraM9qizrHoIeyofrV/n+pZZJauQsPjCPxMEJnRo -D8Z4KpWKX0LyDu1SputoI4nlQ/htEhtiQnuoBfNZxF7WxcxGwEsZuS1KcXIkHl5V -RJOreKFHTaXcB1qcZ/QRaBIv0yhxvK1yBTwWddT4cli6GfHcCe3xGMaSL328Fgs3 -jYrvG29PueB6VJi/tbbPu6qTfwp/H1brqdjh29U52Bhb0fJkM9DWxCP/Cattcc7a -z8EXnCO+LK8vkhw/kAiJWPKx4RBvgy73nwIDAQABo4ICUDCCAkwwHwYDVR0jBBgw -FoAUUWj/kK8CB3U8zNllZGKiErhZcjswHQYDVR0OBBYEFKZPYB4fLdHn8SOgKpUW -5Oia6m5IMIGBBgNVHREEejB4gg93d3cuZXhhbXBsZS5vcmeCC2V4YW1wbGUuY29t -ggtleGFtcGxlLmVkdYILZXhhbXBsZS5uZXSCC2V4YW1wbGUub3Jngg93d3cuZXhh -bXBsZS5jb22CD3d3dy5leGFtcGxlLmVkdYIPd3d3LmV4YW1wbGUubmV0MA4GA1Ud -DwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwdQYDVR0f -BG4wbDA0oDKgMIYuaHR0cDovL2NybDMuZGlnaWNlcnQuY29tL3NoYTItaGEtc2Vy -dmVyLWc0LmNybDA0oDKgMIYuaHR0cDovL2NybDQuZGlnaWNlcnQuY29tL3NoYTIt -aGEtc2VydmVyLWc0LmNybDBMBgNVHSAERTBDMDcGCWCGSAGG/WwBATAqMCgGCCsG -AQUFBwIBFhxodHRwczovL3d3dy5kaWdpY2VydC5jb20vQ1BTMAgGBmeBDAECAjCB -gwYIKwYBBQUHAQEEdzB1MCQGCCsGAQUFBzABhhhodHRwOi8vb2NzcC5kaWdpY2Vy -dC5jb20wTQYIKwYBBQUHMAKGQWh0dHA6Ly9jYWNlcnRzLmRpZ2ljZXJ0LmNvbS9E -aWdpQ2VydFNIQTJIaWdoQXNzdXJhbmNlU2VydmVyQ0EuY3J0MAwGA1UdEwEB/wQC -MAAwDQYJKoZIhvcNAQELBQADggEBAISomhGn2L0LJn5SJHuyVZ3qMIlRCIdvqe0Q -6ls+C8ctRwRO3UU3x8q8OH+2ahxlQmpzdC5al4XQzJLiLjiJ2Q1p+hub8MFiMmVP -PZjb2tZm2ipWVuMRM+zgpRVM6nVJ9F3vFfUSHOb4/JsEIUvPY+d8/Krc+kPQwLvy -ieqRbcuFjmqfyPmUv1U9QoI4TQikpw7TZU0zYZANP4C/gj4Ry48/znmUaRvy2kvI -l7gRQ21qJTK5suoiYoYNo3J9T+pXPGU7Lydz/HwW+w0DpArtAaukI8aNX4ohFUKS -wDSiIIWIWJiJGbEeIO0TIFwEVWTOnbNl/faPXpk5IRXicapqiII= ------END CERTIFICATE----- -,19:timestamp_ssl_setup,14:1449080668.358^5:state,0:]19:timestamp_tcp_setup,14:1449080668.177^}11:intercepted,5:false!7:request,727:9:is_replay,5:false!4:port,3:443#6:scheme,5:https,6:method,3:GET,4:path,1:/,8:form_out,8:relative,11:httpversion,8:1:1#1:1#]4:host,11:example.com,7:headers,460:22:4:Host,11:example.com,]91:10:User-Agent,73:Mozilla/5.0 (Windows NT 10.0; WOW64; rv:41.0) Gecko/20100101 Firefox/41.0,]76:6:Accept,63:text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8,]46:15:Accept-Language,23:de,en-US;q=0.7,en;q=0.3,]36:15:Accept-Encoding,13:gzip, deflate,]28:10:Connection,10:keep-alive,]54:17:If-Modified-Since,29:Fri, 09 Aug 2013 23:54:35 GMT,]42:13:If-None-Match,21:"359670651+gzip+gzip",]29:13:Cache-Control,9:max-age=0,]]7:content,0:,7:form_in,8:relative,15:timestamp_start,14:1449080668.754^13:timestamp_end,14:1449080668.757^}}
\ No newline at end of file diff --git a/test/mitmproxy/data/dumpfile-013 b/test/mitmproxy/data/dumpfile-013 deleted file mode 100644 index ede06f23..00000000 --- a/test/mitmproxy/data/dumpfile-013 +++ /dev/null @@ -1,35 +0,0 @@ -4092:8:response,491:11:httpversion,8:1:1#1:1#]13:timestamp_end,14:1449080668.874^3:msg,12:Not Modified,15:timestamp_start,14:1449080668.863^7:headers,330:35:13:Cache-Control,14:max-age=604800,]40:4:Date,29:Wed, 02 Dec 2015 18:24:32 GMT,]32:4:Etag,21:"359670651+gzip+gzip",]43:7:Expires,29:Wed, 09 Dec 2015 18:24:32 GMT,]50:13:Last-Modified,29:Fri, 09 Aug 2013 23:54:35 GMT,]27:6:Server,14:ECS (lga/1312),]26:4:Vary,15:Accept-Encoding,]16:7:X-Cache,3:HIT,]25:17:x-ec-custom-error,1:1,]]7:content,0:,4:code,3:304#}4:type,4:http,2:id,36:d209a4fc-8e12-43cb-9250-b0b052d2caf8,5:error,0:~7:version,9:1:0#2:13#]11:client_conn,208:15:ssl_established,4:true!10:clientcert,0:~13:timestamp_end,0:~19:timestamp_ssl_setup,14:1449080668.754^7:address,53:7:address,20:9:127.0.0.1,5:58199#]8:use_ipv6,5:false!}15:timestamp_start,14:1449080666.523^}11:server_conn,2479:15:ssl_established,4:true!14:source_address,57:7:address,24:12:10.67.56.236,5:58201#]8:use_ipv6,5:false!}13:timestamp_end,0:~7:address,54:7:address,21:11:example.com,3:443#]8:use_ipv6,5:false!}15:timestamp_start,14:1449080668.046^3:sni,11:example.com,4:cert,2122:-----BEGIN CERTIFICATE----- -MIIF8jCCBNqgAwIBAgIQDmTF+8I2reFLFyrrQceMsDANBgkqhkiG9w0BAQsFADBw -MQswCQYDVQQGEwJVUzEVMBMGA1UEChMMRGlnaUNlcnQgSW5jMRkwFwYDVQQLExB3 -d3cuZGlnaWNlcnQuY29tMS8wLQYDVQQDEyZEaWdpQ2VydCBTSEEyIEhpZ2ggQXNz -dXJhbmNlIFNlcnZlciBDQTAeFw0xNTExMDMwMDAwMDBaFw0xODExMjgxMjAwMDBa -MIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEUMBIGA1UEBxML -TG9zIEFuZ2VsZXMxPDA6BgNVBAoTM0ludGVybmV0IENvcnBvcmF0aW9uIGZvciBB -c3NpZ25lZCBOYW1lcyBhbmQgTnVtYmVyczETMBEGA1UECxMKVGVjaG5vbG9neTEY -MBYGA1UEAxMPd3d3LmV4YW1wbGUub3JnMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A -MIIBCgKCAQEAs0CWL2FjPiXBl61lRfvvE0KzLJmG9LWAC3bcBjgsH6NiVVo2dt6u -Xfzi5bTm7F3K7srfUBYkLO78mraM9qizrHoIeyofrV/n+pZZJauQsPjCPxMEJnRo -D8Z4KpWKX0LyDu1SputoI4nlQ/htEhtiQnuoBfNZxF7WxcxGwEsZuS1KcXIkHl5V -RJOreKFHTaXcB1qcZ/QRaBIv0yhxvK1yBTwWddT4cli6GfHcCe3xGMaSL328Fgs3 -jYrvG29PueB6VJi/tbbPu6qTfwp/H1brqdjh29U52Bhb0fJkM9DWxCP/Cattcc7a -z8EXnCO+LK8vkhw/kAiJWPKx4RBvgy73nwIDAQABo4ICUDCCAkwwHwYDVR0jBBgw -FoAUUWj/kK8CB3U8zNllZGKiErhZcjswHQYDVR0OBBYEFKZPYB4fLdHn8SOgKpUW -5Oia6m5IMIGBBgNVHREEejB4gg93d3cuZXhhbXBsZS5vcmeCC2V4YW1wbGUuY29t -ggtleGFtcGxlLmVkdYILZXhhbXBsZS5uZXSCC2V4YW1wbGUub3Jngg93d3cuZXhh -bXBsZS5jb22CD3d3dy5leGFtcGxlLmVkdYIPd3d3LmV4YW1wbGUubmV0MA4GA1Ud -DwEB/wQEAwIFoDAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwdQYDVR0f -BG4wbDA0oDKgMIYuaHR0cDovL2NybDMuZGlnaWNlcnQuY29tL3NoYTItaGEtc2Vy -dmVyLWc0LmNybDA0oDKgMIYuaHR0cDovL2NybDQuZGlnaWNlcnQuY29tL3NoYTIt -aGEtc2VydmVyLWc0LmNybDBMBgNVHSAERTBDMDcGCWCGSAGG/WwBATAqMCgGCCsG -AQUFBwIBFhxodHRwczovL3d3dy5kaWdpY2VydC5jb20vQ1BTMAgGBmeBDAECAjCB -gwYIKwYBBQUHAQEEdzB1MCQGCCsGAQUFBzABhhhodHRwOi8vb2NzcC5kaWdpY2Vy -dC5jb20wTQYIKwYBBQUHMAKGQWh0dHA6Ly9jYWNlcnRzLmRpZ2ljZXJ0LmNvbS9E -aWdpQ2VydFNIQTJIaWdoQXNzdXJhbmNlU2VydmVyQ0EuY3J0MAwGA1UdEwEB/wQC -MAAwDQYJKoZIhvcNAQELBQADggEBAISomhGn2L0LJn5SJHuyVZ3qMIlRCIdvqe0Q -6ls+C8ctRwRO3UU3x8q8OH+2ahxlQmpzdC5al4XQzJLiLjiJ2Q1p+hub8MFiMmVP -PZjb2tZm2ipWVuMRM+zgpRVM6nVJ9F3vFfUSHOb4/JsEIUvPY+d8/Krc+kPQwLvy -ieqRbcuFjmqfyPmUv1U9QoI4TQikpw7TZU0zYZANP4C/gj4Ry48/znmUaRvy2kvI -l7gRQ21qJTK5suoiYoYNo3J9T+pXPGU7Lydz/HwW+w0DpArtAaukI8aNX4ohFUKS -wDSiIIWIWJiJGbEeIO0TIFwEVWTOnbNl/faPXpk5IRXicapqiII= ------END CERTIFICATE----- 
-,19:timestamp_ssl_setup,14:1449080668.358^5:state,0:]19:timestamp_tcp_setup,14:1449080668.177^}11:intercepted,5:false!7:request,727:9:is_replay,5:false!4:port,3:443#6:scheme,5:https,6:method,3:GET,4:path,1:/,8:form_out,8:relative,11:httpversion,8:1:1#1:1#]4:host,11:example.com,7:headers,460:22:4:Host,11:example.com,]91:10:User-Agent,73:Mozilla/5.0 (Windows NT 10.0; WOW64; rv:41.0) Gecko/20100101 Firefox/41.0,]76:6:Accept,63:text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8,]46:15:Accept-Language,23:de,en-US;q=0.7,en;q=0.3,]36:15:Accept-Encoding,13:gzip, deflate,]28:10:Connection,10:keep-alive,]54:17:If-Modified-Since,29:Fri, 09 Aug 2013 23:54:35 GMT,]42:13:If-None-Match,21:"359670651+gzip+gzip",]29:13:Cache-Control,9:max-age=0,]]7:content,0:,7:form_in,8:relative,15:timestamp_start,14:1449080668.754^13:timestamp_end,14:1449080668.757^}}
\ No newline at end of file diff --git a/test/mitmproxy/data/scripts/a.py b/test/mitmproxy/data/scripts/a.py deleted file mode 100644 index 33dbaa64..00000000 --- a/test/mitmproxy/data/scripts/a.py +++ /dev/null @@ -1,20 +0,0 @@ -import sys - -from a_helper import parser - -var = 0 - - -def start(ctx): - global var - var = parser.parse_args(sys.argv[1:]).var - - -def here(ctx): - global var - var += 1 - return var - - -def errargs(): - pass diff --git a/test/mitmproxy/data/scripts/a_helper.py b/test/mitmproxy/data/scripts/a_helper.py deleted file mode 100644 index e1f1c649..00000000 --- a/test/mitmproxy/data/scripts/a_helper.py +++ /dev/null @@ -1,4 +0,0 @@ -import argparse - -parser = argparse.ArgumentParser() -parser.add_argument('--var', type=int) diff --git a/test/mitmproxy/data/scripts/all.py b/test/mitmproxy/data/scripts/all.py index dad2aade..bf8e93ec 100644 --- a/test/mitmproxy/data/scripts/all.py +++ b/test/mitmproxy/data/scripts/all.py @@ -1,36 +1,37 @@ +import mitmproxy log = [] -def clientconnect(ctx, cc): - ctx.log("XCLIENTCONNECT") +def clientconnect(cc): + mitmproxy.ctx.log("XCLIENTCONNECT") log.append("clientconnect") -def serverconnect(ctx, cc): - ctx.log("XSERVERCONNECT") +def serverconnect(cc): + mitmproxy.ctx.log("XSERVERCONNECT") log.append("serverconnect") -def request(ctx, f): - ctx.log("XREQUEST") +def request(f): + mitmproxy.ctx.log("XREQUEST") log.append("request") -def response(ctx, f): - ctx.log("XRESPONSE") +def response(f): + mitmproxy.ctx.log("XRESPONSE") log.append("response") -def responseheaders(ctx, f): - ctx.log("XRESPONSEHEADERS") +def responseheaders(f): + mitmproxy.ctx.log("XRESPONSEHEADERS") log.append("responseheaders") -def clientdisconnect(ctx, cc): - ctx.log("XCLIENTDISCONNECT") +def clientdisconnect(cc): + mitmproxy.ctx.log("XCLIENTDISCONNECT") log.append("clientdisconnect") -def error(ctx, cc): - ctx.log("XERROR") +def error(cc): + mitmproxy.ctx.log("XERROR") log.append("error") diff --git a/test/mitmproxy/data/scripts/duplicate_flow.py b/test/mitmproxy/data/scripts/duplicate_flow.py deleted file mode 100644 index e13af786..00000000 --- a/test/mitmproxy/data/scripts/duplicate_flow.py +++ /dev/null @@ -1,4 +0,0 @@ - -def request(ctx, f): - f = ctx.duplicate_flow(f) - ctx.replay_request(f) diff --git a/test/mitmproxy/data/scripts/loaderr.py b/test/mitmproxy/data/scripts/loaderr.py deleted file mode 100644 index 8dc4d56d..00000000 --- a/test/mitmproxy/data/scripts/loaderr.py +++ /dev/null @@ -1,3 +0,0 @@ - - -a = x diff --git a/test/mitmproxy/data/scripts/reqerr.py b/test/mitmproxy/data/scripts/reqerr.py deleted file mode 100644 index e7c503a8..00000000 --- a/test/mitmproxy/data/scripts/reqerr.py +++ /dev/null @@ -1,2 +0,0 @@ -def request(ctx, r): - raise ValueError diff --git a/test/mitmproxy/data/scripts/starterr.py b/test/mitmproxy/data/scripts/starterr.py deleted file mode 100644 index 82d773bd..00000000 --- a/test/mitmproxy/data/scripts/starterr.py +++ /dev/null @@ -1,3 +0,0 @@ - -def start(ctx): - raise ValueError() diff --git a/test/mitmproxy/data/scripts/syntaxerr.py b/test/mitmproxy/data/scripts/syntaxerr.py deleted file mode 100644 index 219d6b84..00000000 --- a/test/mitmproxy/data/scripts/syntaxerr.py +++ /dev/null @@ -1,3 +0,0 @@ - - -a + diff --git a/test/mitmproxy/data/scripts/unloaderr.py b/test/mitmproxy/data/scripts/unloaderr.py deleted file mode 100644 index fba02734..00000000 --- a/test/mitmproxy/data/scripts/unloaderr.py +++ /dev/null @@ -1,2 +0,0 @@ -def done(ctx): - raise RuntimeError() diff --git 
a/test/mitmproxy/data/test_flow_export/python_post_json.py b/test/mitmproxy/data/test_flow_export/python_post_json.py index 6c1b9740..5ef110f3 100644 --- a/test/mitmproxy/data/test_flow_export/python_post_json.py +++ b/test/mitmproxy/data/test_flow_export/python_post_json.py @@ -8,8 +8,8 @@ headers = { json = { - u'email': u'example@example.com', - u'name': u'example', + 'email': 'example@example.com', + 'name': 'example', } diff --git a/test/mitmproxy/mastertest.py b/test/mitmproxy/mastertest.py index 9e726a32..dcc0dc48 100644 --- a/test/mitmproxy/mastertest.py +++ b/test/mitmproxy/mastertest.py @@ -3,24 +3,31 @@ import mock from . import tutils import netlib.tutils -from mitmproxy import flow, proxy, models +from mitmproxy.flow import master +from mitmproxy import flow, proxy, models, controller class MasterTest: + def invoke(self, master, handler, *message): + with master.handlecontext(): + func = getattr(master, handler) + func(*message) + if message: + message[0].reply = controller.DummyReply() + def cycle(self, master, content): f = tutils.tflow(req=netlib.tutils.treq(content=content)) l = proxy.Log("connect") l.reply = mock.MagicMock() master.log(l) - master.clientconnect(f.client_conn) - master.serverconnect(f.server_conn) - master.request(f) + self.invoke(master, "clientconnect", f.client_conn) + self.invoke(master, "clientconnect", f.client_conn) + self.invoke(master, "serverconnect", f.server_conn) + self.invoke(master, "request", f) if not f.error: f.response = models.HTTPResponse.wrap(netlib.tutils.tresp(content=content)) - f.reply.acked = False - f = master.response(f) - f.client_conn.reply.acked = False - master.clientdisconnect(f.client_conn) + self.invoke(master, "response", f) + self.invoke(master, "clientdisconnect", f) return f def dummy_cycle(self, master, n, content): @@ -34,3 +41,12 @@ class MasterTest: t = tutils.tflow(resp=True) fw.add(t) f.close() + + +class RecordingMaster(master.FlowMaster): + def __init__(self, *args, **kwargs): + master.FlowMaster.__init__(self, *args, **kwargs) + self.event_log = [] + + def add_log(self, e, level): + self.event_log.append((level, e)) diff --git a/test/mitmproxy/script/test_concurrent.py b/test/mitmproxy/script/test_concurrent.py index 62541f3f..080746e8 100644 --- a/test/mitmproxy/script/test_concurrent.py +++ b/test/mitmproxy/script/test_concurrent.py @@ -1,28 +1,46 @@ -from mitmproxy.script import Script -from test.mitmproxy import tutils +from test.mitmproxy import tutils, mastertest from mitmproxy import controller +from mitmproxy.builtins import script +from mitmproxy import options +from mitmproxy.flow import master +from mitmproxy.flow import state import time class Thing: def __init__(self): self.reply = controller.DummyReply() + self.live = True -@tutils.skip_appveyor -def test_concurrent(): - with Script(tutils.test_data.path("data/scripts/concurrent_decorator.py"), None) as s: - f1, f2 = Thing(), Thing() - s.run("request", f1) - s.run("request", f2) +class TestConcurrent(mastertest.MasterTest): + @tutils.skip_appveyor + def test_concurrent(self): + s = state.State() + m = master.FlowMaster(options.Options(), None, s) + sc = script.Script( + tutils.test_data.path( + "data/addonscripts/concurrent_decorator.py" + ) + ) + m.addons.add(sc) + f1, f2 = tutils.tflow(), tutils.tflow() + self.invoke(m, "request", f1) + self.invoke(m, "request", f2) start = time.time() while time.time() - start < 5: if f1.reply.acked and f2.reply.acked: return raise ValueError("Script never acked") - -def test_concurrent_err(): - s = 
Script(tutils.test_data.path("data/scripts/concurrent_decorator_err.py"), None) - with tutils.raises("Concurrent decorator not supported for 'start' method"): - s.load() + def test_concurrent_err(self): + s = state.State() + m = mastertest.RecordingMaster(options.Options(), None, s) + sc = script.Script( + tutils.test_data.path( + "data/addonscripts/concurrent_decorator_err.py" + ) + ) + with m.handlecontext(): + sc.start() + assert "decorator not supported" in m.event_log[0][1] diff --git a/test/mitmproxy/script/test_reloader.py b/test/mitmproxy/script/test_reloader.py deleted file mode 100644 index 0345f6ed..00000000 --- a/test/mitmproxy/script/test_reloader.py +++ /dev/null @@ -1,34 +0,0 @@ -import mock -from mitmproxy.script.reloader import watch, unwatch -from test.mitmproxy import tutils -from threading import Event - - -def test_simple(): - with tutils.tmpdir(): - with open("foo.py", "w"): - pass - - script = mock.Mock() - script.filename = "foo.py" - - e = Event() - - def _onchange(): - e.set() - - watch(script, _onchange) - with tutils.raises("already observed"): - watch(script, _onchange) - - # Some reloaders don't register a change directly after watching, because they first need to initialize. - # To test if watching works at all, we do repeated writes every 100ms. - for _ in range(100): - with open("foo.py", "a") as f: - f.write(".") - if e.wait(0.1): - break - else: - raise AssertionError("No change detected.") - - unwatch(script) diff --git a/test/mitmproxy/script/test_script.py b/test/mitmproxy/script/test_script.py deleted file mode 100644 index fe98fab5..00000000 --- a/test/mitmproxy/script/test_script.py +++ /dev/null @@ -1,83 +0,0 @@ -from mitmproxy.script import Script -from mitmproxy.exceptions import ScriptException -from test.mitmproxy import tutils - - -class TestParseCommand: - def test_empty_command(self): - with tutils.raises(ScriptException): - Script.parse_command("") - - with tutils.raises(ScriptException): - Script.parse_command(" ") - - def test_no_script_file(self): - with tutils.raises("not found"): - Script.parse_command("notfound") - - with tutils.tmpdir() as dir: - with tutils.raises("not a file"): - Script.parse_command(dir) - - def test_parse_args(self): - with tutils.chdir(tutils.test_data.dirname): - assert Script.parse_command("data/scripts/a.py") == ["data/scripts/a.py"] - assert Script.parse_command("data/scripts/a.py foo bar") == ["data/scripts/a.py", "foo", "bar"] - assert Script.parse_command("data/scripts/a.py 'foo bar'") == ["data/scripts/a.py", "foo bar"] - - @tutils.skip_not_windows - def test_parse_windows(self): - with tutils.chdir(tutils.test_data.dirname): - assert Script.parse_command("data\\scripts\\a.py") == ["data\\scripts\\a.py"] - assert Script.parse_command("data\\scripts\\a.py 'foo \\ bar'") == ["data\\scripts\\a.py", 'foo \\ bar'] - - -def test_simple(): - with tutils.chdir(tutils.test_data.path("data/scripts")): - s = Script("a.py --var 42", None) - assert s.filename == "a.py" - assert s.ns is None - - s.load() - assert s.ns["var"] == 42 - - s.run("here") - assert s.ns["var"] == 43 - - s.unload() - assert s.ns is None - - with tutils.raises(ScriptException): - s.run("here") - - with Script("a.py --var 42", None) as s: - s.run("here") - - -def test_script_exception(): - with tutils.chdir(tutils.test_data.path("data/scripts")): - s = Script("syntaxerr.py", None) - with tutils.raises(ScriptException): - s.load() - - s = Script("starterr.py", None) - with tutils.raises(ScriptException): - s.load() - - s = Script("a.py", None) - 
s.load() - with tutils.raises(ScriptException): - s.load() - - s = Script("a.py", None) - with tutils.raises(ScriptException): - s.run("here") - - with tutils.raises(ScriptException): - with Script("reqerr.py", None) as s: - s.run("request", None) - - s = Script("unloaderr.py", None) - s.load() - with tutils.raises(ScriptException): - s.unload() diff --git a/test/mitmproxy/test_addons.py b/test/mitmproxy/test_addons.py new file mode 100644 index 00000000..1861d4ac --- /dev/null +++ b/test/mitmproxy/test_addons.py @@ -0,0 +1,20 @@ +from __future__ import absolute_import, print_function, division +from mitmproxy import addons +from mitmproxy import controller +from mitmproxy import options + + +class TAddon: + def __init__(self, name): + self.name = name + + def __repr__(self): + return "Addon(%s)" % self.name + + +def test_simple(): + m = controller.Master(options.Options()) + a = addons.Addons(m) + a.add(TAddon("one")) + assert a.has_addon("one") + assert not a.has_addon("two") diff --git a/test/mitmproxy/test_contentview.py b/test/mitmproxy/test_contentview.py index 52fceeac..2db9ab40 100644 --- a/test/mitmproxy/test_contentview.py +++ b/test/mitmproxy/test_contentview.py @@ -1,6 +1,5 @@ from mitmproxy.exceptions import ContentViewException from netlib.http import Headers -from netlib import encoding from netlib.http import url from netlib import multidict @@ -201,6 +200,13 @@ Larry ) assert "Raw" in r[0] + r = cv.get_content_view( + cv.get("Auto"), + b"[1, 2, 3]", + headers=Headers(content_type="application/vnd.api+json") + ) + assert r[0] == "JSON" + tutils.raises( ContentViewException, cv.get_content_view, @@ -209,28 +215,6 @@ Larry headers=Headers() ) - r = cv.get_content_view( - cv.get("Auto"), - encoding.encode('gzip', b"[1, 2, 3]"), - headers=Headers( - content_type="application/json", - content_encoding="gzip" - ) - ) - assert "decoded gzip" in r[0] - assert "JSON" in r[0] - - r = cv.get_content_view( - cv.get("XML"), - encoding.encode('gzip', b"[1, 2, 3]"), - headers=Headers( - content_type="application/json", - content_encoding="gzip" - ) - ) - assert "decoded gzip" in r[0] - assert "Raw" in r[0] - def test_add_cv(self): class TestContentView(cv.View): name = "test" diff --git a/test/mitmproxy/test_contrib_tnetstring.py b/test/mitmproxy/test_contrib_tnetstring.py index 17654ad9..05c4a7c9 100644 --- a/test/mitmproxy/test_contrib_tnetstring.py +++ b/test/mitmproxy/test_contrib_tnetstring.py @@ -15,7 +15,9 @@ FORMAT_EXAMPLES = { {b'hello': [12345678901, b'this', True, None, b'\x00\x00\x00\x00']}, b'5:12345#': 12345, b'12:this is cool,': b'this is cool', + b'19:this is unicode \xe2\x98\x85;': u'this is unicode \u2605', b'0:,': b'', + b'0:;': u'', b'0:~': None, b'4:true!': True, b'5:false!': False, @@ -43,7 +45,7 @@ def get_random_object(random=random, depth=0): d = {} for _ in range(n): n = random.randint(0, 100) - k = bytes([random.randint(32, 126) for _ in range(n)]) + k = str([random.randint(32, 126) for _ in range(n)]) d[k] = get_random_object(random, depth + 1) return d else: @@ -78,12 +80,6 @@ class Test_Format(unittest.TestCase): self.assertEqual(v, tnetstring.loads(tnetstring.dumps(v))) self.assertEqual((v, b""), tnetstring.pop(tnetstring.dumps(v))) - def test_unicode_handling(self): - with self.assertRaises(ValueError): - tnetstring.dumps(u"hello") - self.assertEqual(tnetstring.dumps(u"hello".encode()), b"5:hello,") - self.assertEqual(type(tnetstring.loads(b"5:hello,")), bytes) - def test_roundtrip_format_unicode(self): for _ in range(500): v = get_random_object() diff 
--git a/test/mitmproxy/test_controller.py b/test/mitmproxy/test_controller.py index 5a68e15b..6d4b8fe6 100644 --- a/test/mitmproxy/test_controller.py +++ b/test/mitmproxy/test_controller.py @@ -25,7 +25,7 @@ class TestMaster(object): # Speed up test super(DummyMaster, self).tick(0) - m = DummyMaster() + m = DummyMaster(None) assert not m.should_exit.is_set() msg = TMsg() msg.reply = controller.DummyReply() @@ -34,7 +34,7 @@ class TestMaster(object): assert m.should_exit.is_set() def test_server_simple(self): - m = controller.Master() + m = controller.Master(None) s = DummyServer(None) m.add_server(s) m.start() diff --git a/test/mitmproxy/test_dump.py b/test/mitmproxy/test_dump.py index 234490f8..90f33264 100644 --- a/test/mitmproxy/test_dump.py +++ b/test/mitmproxy/test_dump.py @@ -1,127 +1,79 @@ import os from six.moves import cStringIO as StringIO -from mitmproxy.exceptions import ContentViewException -import netlib.tutils - -from mitmproxy import dump, flow, models +from mitmproxy import dump, flow, exceptions from . import tutils, mastertest import mock -def test_strfuncs(): - o = dump.Options() - m = dump.DumpMaster(None, o) - - m.outfile = StringIO() - m.o.flow_detail = 0 - m.echo_flow(tutils.tflow()) - assert not m.outfile.getvalue() - - m.o.flow_detail = 4 - m.echo_flow(tutils.tflow()) - assert m.outfile.getvalue() - - m.outfile = StringIO() - m.echo_flow(tutils.tflow(resp=True)) - assert "<<" in m.outfile.getvalue() - - m.outfile = StringIO() - m.echo_flow(tutils.tflow(err=True)) - assert "<<" in m.outfile.getvalue() - - flow = tutils.tflow() - flow.request = netlib.tutils.treq() - flow.request.stickycookie = True - flow.client_conn = mock.MagicMock() - flow.client_conn.address.host = "foo" - flow.response = netlib.tutils.tresp(content=None) - flow.response.is_replay = True - flow.response.status_code = 300 - m.echo_flow(flow) - - flow = tutils.tflow(resp=netlib.tutils.tresp(content="{")) - flow.response.headers["content-type"] = "application/json" - flow.response.status_code = 400 - m.echo_flow(flow) - - -@mock.patch("mitmproxy.contentviews.get_content_view") -def test_contentview(get_content_view): - get_content_view.side_effect = ContentViewException(""), ("x", iter([])) - - o = dump.Options(flow_detail=4, verbosity=3) - m = dump.DumpMaster(None, o, StringIO()) - m.echo_flow(tutils.tflow()) - assert "Content viewer failed" in m.outfile.getvalue() - - class TestDumpMaster(mastertest.MasterTest): def dummy_cycle(self, master, n, content): mastertest.MasterTest.dummy_cycle(self, master, n, content) - return master.outfile.getvalue() + return master.options.tfile.getvalue() def mkmaster(self, filt, **options): - cs = StringIO() - o = dump.Options(filtstr=filt, **options) - return dump.DumpMaster(None, o, outfile=cs) + if "verbosity" not in options: + options["verbosity"] = 0 + if "flow_detail" not in options: + options["flow_detail"] = 0 + o = dump.Options(filtstr=filt, tfile=StringIO(), **options) + return dump.DumpMaster(None, o) def test_basic(self): for i in (1, 2, 3): - assert "GET" in self.dummy_cycle(self.mkmaster("~s", flow_detail=i), 1, "") assert "GET" in self.dummy_cycle( self.mkmaster("~s", flow_detail=i), 1, - "\x00\x00\x00" + b"" + ) + assert "GET" in self.dummy_cycle( + self.mkmaster("~s", flow_detail=i), + 1, + b"\x00\x00\x00" ) assert "GET" in self.dummy_cycle( self.mkmaster("~s", flow_detail=i), - 1, "ascii" + 1, + b"ascii" ) def test_error(self): - cs = StringIO() - o = dump.Options(flow_detail=1) - m = dump.DumpMaster(None, o, outfile=cs) + o = dump.Options( + 
tfile=StringIO(), + flow_detail=1 + ) + m = dump.DumpMaster(None, o) f = tutils.tflow(err=True) - m.request(f) + m.error(f) assert m.error(f) - assert "error" in cs.getvalue() - - def test_missing_content(self): - cs = StringIO() - o = dump.Options(flow_detail=3) - m = dump.DumpMaster(None, o, outfile=cs) - f = tutils.tflow() - f.request.content = None - m.request(f) - f.response = models.HTTPResponse.wrap(netlib.tutils.tresp()) - f.response.content = None - m.response(f) - assert "content missing" in cs.getvalue() + assert "error" in o.tfile.getvalue() def test_replay(self): - cs = StringIO() - o = dump.Options(server_replay=["nonexistent"], kill=True) - tutils.raises(dump.DumpError, dump.DumpMaster, None, o, outfile=cs) + tutils.raises(dump.DumpError, dump.DumpMaster, None, o) with tutils.tmpdir() as t: p = os.path.join(t, "rep") self.flowfile(p) o = dump.Options(server_replay=[p], kill=True) - m = dump.DumpMaster(None, o, outfile=cs) + o.verbosity = 0 + o.flow_detail = 0 + m = dump.DumpMaster(None, o) - self.cycle(m, "content") - self.cycle(m, "content") + self.cycle(m, b"content") + self.cycle(m, b"content") o = dump.Options(server_replay=[p], kill=False) - m = dump.DumpMaster(None, o, outfile=cs) - self.cycle(m, "nonexistent") + o.verbosity = 0 + o.flow_detail = 0 + m = dump.DumpMaster(None, o) + self.cycle(m, b"nonexistent") o = dump.Options(client_replay=[p], kill=False) - m = dump.DumpMaster(None, o, outfile=cs) + o.verbosity = 0 + o.flow_detail = 0 + m = dump.DumpMaster(None, o) def test_read(self): with tutils.tmpdir() as t: @@ -129,9 +81,8 @@ class TestDumpMaster(mastertest.MasterTest): self.flowfile(p) assert "GET" in self.dummy_cycle( self.mkmaster(None, flow_detail=1, rfile=p), - 0, "", + 1, b"", ) - tutils.raises( dump.DumpError, self.mkmaster, None, verbosity=1, rfile="/nonexistent" @@ -147,7 +98,7 @@ class TestDumpMaster(mastertest.MasterTest): def test_filter(self): assert "GET" not in self.dummy_cycle( - self.mkmaster("~u foo", verbosity=1), 1, "" + self.mkmaster("~u foo", verbosity=1), 1, b"" ) def test_app(self): @@ -157,24 +108,32 @@ class TestDumpMaster(mastertest.MasterTest): assert len(m.apps.apps) == 1 def test_replacements(self): - cs = StringIO() - o = dump.Options(replacements=[(".*", "content", "foo")]) - m = dump.DumpMaster(None, o, outfile=cs) - f = self.cycle(m, "content") - assert f.request.content == "foo" + o = dump.Options( + replacements=[(".*", "content", "foo")], + tfile = StringIO(), + ) + o.verbosity = 0 + o.flow_detail = 0 + m = dump.DumpMaster(None, o) + f = self.cycle(m, b"content") + assert f.request.content == b"foo" def test_setheader(self): - cs = StringIO() - o = dump.Options(setheaders=[(".*", "one", "two")]) - m = dump.DumpMaster(None, o, outfile=cs) - f = self.cycle(m, "content") + o = dump.Options( + setheaders=[(".*", "one", "two")], + tfile=StringIO() + ) + o.verbosity = 0 + o.flow_detail = 0 + m = dump.DumpMaster(None, o) + f = self.cycle(m, b"content") assert f.request.headers["one"] == "two" def test_write(self): with tutils.tmpdir() as d: p = os.path.join(d, "a") self.dummy_cycle( - self.mkmaster(None, outfile=(p, "wb"), verbosity=0), 1, "" + self.mkmaster(None, outfile=(p, "wb"), verbosity=0), 1, b"" ) assert len(list(flow.FlowReader(open(p, "rb")).stream())) == 1 @@ -183,17 +142,17 @@ class TestDumpMaster(mastertest.MasterTest): p = os.path.join(d, "a.append") self.dummy_cycle( self.mkmaster(None, outfile=(p, "wb"), verbosity=0), - 1, "" + 1, b"" ) self.dummy_cycle( self.mkmaster(None, outfile=(p, "ab"), verbosity=0), - 1, "" + 
1, b"" ) assert len(list(flow.FlowReader(open(p, "rb")).stream())) == 2 def test_write_err(self): tutils.raises( - dump.DumpError, + exceptions.OptionsError, self.mkmaster, None, outfile = ("nonexistentdir/foo", "wb") ) @@ -201,9 +160,10 @@ class TestDumpMaster(mastertest.MasterTest): ret = self.dummy_cycle( self.mkmaster( None, - scripts=[tutils.test_data.path("data/scripts/all.py")], verbosity=1 + scripts=[tutils.test_data.path("data/scripts/all.py")], + verbosity=2 ), - 1, "", + 1, b"", ) assert "XCLIENTCONNECT" in ret assert "XSERVERCONNECT" in ret @@ -211,12 +171,12 @@ class TestDumpMaster(mastertest.MasterTest): assert "XRESPONSE" in ret assert "XCLIENTDISCONNECT" in ret tutils.raises( - dump.DumpError, + exceptions.AddonError, self.mkmaster, None, scripts=["nonexistent"] ) tutils.raises( - dump.DumpError, + exceptions.AddonError, self.mkmaster, None, scripts=["starterr.py"] ) @@ -224,11 +184,11 @@ class TestDumpMaster(mastertest.MasterTest): def test_stickycookie(self): self.dummy_cycle( self.mkmaster(None, stickycookie = ".*"), - 1, "" + 1, b"" ) def test_stickyauth(self): self.dummy_cycle( self.mkmaster(None, stickyauth = ".*"), - 1, "" + 1, b"" ) diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py index 607d6faf..0ec85f52 100644 --- a/test/mitmproxy/test_examples.py +++ b/test/mitmproxy/test_examples.py @@ -1,155 +1,126 @@ -import glob import json -import os -from contextlib import contextmanager -from mitmproxy import script -from mitmproxy.proxy import config +import six +import sys +import os.path +from mitmproxy.flow import master +from mitmproxy.flow import state +from mitmproxy import options +from mitmproxy import contentviews +from mitmproxy.builtins import script import netlib.utils from netlib import tutils as netutils from netlib.http import Headers -from . import tservers, tutils +from . import tutils, mastertest -example_dir = netlib.utils.Data(__name__).path("../../examples") +example_dir = netlib.utils.Data(__name__).push("../../examples") -class DummyContext(object): - """Emulate script.ScriptContext() functionality.""" +class ScriptError(Exception): + pass - contentview = None - def log(self, *args, **kwargs): - pass +class RaiseMaster(master.FlowMaster): + def add_log(self, e, level): + if level in ("warn", "error"): + raise ScriptError(e) - def add_contentview(self, view_obj): - self.contentview = view_obj - def remove_contentview(self, view_obj): - self.contentview = None +def tscript(cmd, args=""): + cmd = example_dir.path(cmd) + " " + args + m = RaiseMaster(options.Options(), None, state.State()) + sc = script.Script(cmd) + m.addons.add(sc) + return m, sc -@contextmanager -def example(command): - command = os.path.join(example_dir, command) - ctx = DummyContext() - with script.Script(command, ctx) as s: - yield s +class TestScripts(mastertest.MasterTest): + def test_add_header(self): + m, _ = tscript("add_header.py") + f = tutils.tflow(resp=netutils.tresp()) + self.invoke(m, "response", f) + assert f.response.headers["newheader"] == "foo" + def test_custom_contentviews(self): + m, sc = tscript("custom_contentviews.py") + pig = contentviews.get("pig_latin_HTML") + _, fmt = pig(b"<html>test!</html>") + assert any(b'esttay!' 
in val[0][1] for val in fmt) + assert not pig(b"gobbledygook") -def test_load_scripts(): - scripts = glob.glob("%s/*.py" % example_dir) + def test_iframe_injector(self): + with tutils.raises(ScriptError): + tscript("iframe_injector.py") - tmaster = tservers.TestMaster(config.ProxyConfig()) - - for f in scripts: - if "har_extractor" in f: - continue - if "flowwriter" in f: - f += " -" - if "iframe_injector" in f: - f += " foo" # one argument required - if "filt" in f: - f += " ~a" - if "modify_response_body" in f: - f += " foo bar" # two arguments required - - s = script.Script(f, script.ScriptContext(tmaster)) - try: - s.load() - except Exception as v: - if "ImportError" not in str(v): - raise - else: - s.unload() - - -def test_add_header(): - flow = tutils.tflow(resp=netutils.tresp()) - with example("add_header.py") as ex: - ex.run("response", flow) - assert flow.response.headers["newheader"] == "foo" - - -def test_custom_contentviews(): - with example("custom_contentviews.py") as ex: - pig = ex.ctx.contentview - _, fmt = pig("<html>test!</html>") - assert any('esttay!' in val[0][1] for val in fmt) - assert not pig("gobbledygook") - - -def test_iframe_injector(): - with tutils.raises(script.ScriptException): - with example("iframe_injector.py") as ex: - pass - - flow = tutils.tflow(resp=netutils.tresp(content="<html>mitmproxy</html>")) - with example("iframe_injector.py http://example.org/evil_iframe") as ex: - ex.run("response", flow) + m, sc = tscript("iframe_injector.py", "http://example.org/evil_iframe") + flow = tutils.tflow(resp=netutils.tresp(content=b"<html>mitmproxy</html>")) + self.invoke(m, "response", flow) content = flow.response.content - assert 'iframe' in content and 'evil_iframe' in content - - -def test_modify_form(): - form_header = Headers(content_type="application/x-www-form-urlencoded") - flow = tutils.tflow(req=netutils.treq(headers=form_header)) - with example("modify_form.py") as ex: - ex.run("request", flow) - assert flow.request.urlencoded_form["mitmproxy"] == "rocks" - - flow.request.headers["content-type"] = "" - ex.run("request", flow) - assert list(flow.request.urlencoded_form.items()) == [("foo", "bar")] - - -def test_modify_querystring(): - flow = tutils.tflow(req=netutils.treq(path="/search?q=term")) - with example("modify_querystring.py") as ex: - ex.run("request", flow) - assert flow.request.query["mitmproxy"] == "rocks" - - flow.request.path = "/" - ex.run("request", flow) - assert flow.request.query["mitmproxy"] == "rocks" - - -def test_modify_response_body(): - with tutils.raises(script.ScriptException): - with example("modify_response_body.py"): - assert True - - flow = tutils.tflow(resp=netutils.tresp(content="I <3 mitmproxy")) - with example("modify_response_body.py mitmproxy rocks") as ex: - assert ex.ctx.old == "mitmproxy" and ex.ctx.new == "rocks" - ex.run("response", flow) - assert flow.response.content == "I <3 rocks" - - -def test_redirect_requests(): - flow = tutils.tflow(req=netutils.treq(host="example.org")) - with example("redirect_requests.py") as ex: - ex.run("request", flow) - assert flow.request.host == "mitmproxy.org" - - -def test_har_extractor(): - with tutils.raises(script.ScriptException): - with example("har_extractor.py"): - pass - - times = dict( - timestamp_start=746203272, - timestamp_end=746203272, - ) - - flow = tutils.tflow( - req=netutils.treq(**times), - resp=netutils.tresp(**times) - ) - - with example("har_extractor.py -") as ex: - ex.run("response", flow) - - with 
open(tutils.test_data.path("data/har_extractor.har")) as fp: - test_data = json.load(fp) - assert json.loads(ex.ctx.HARLog.json()) == test_data["test_response"] + assert b'iframe' in content and b'evil_iframe' in content + + def test_modify_form(self): + m, sc = tscript("modify_form.py") + + form_header = Headers(content_type="application/x-www-form-urlencoded") + f = tutils.tflow(req=netutils.treq(headers=form_header)) + self.invoke(m, "request", f) + + assert f.request.urlencoded_form[b"mitmproxy"] == b"rocks" + + f.request.headers["content-type"] = "" + self.invoke(m, "request", f) + assert list(f.request.urlencoded_form.items()) == [(b"foo", b"bar")] + + def test_modify_querystring(self): + m, sc = tscript("modify_querystring.py") + f = tutils.tflow(req=netutils.treq(path="/search?q=term")) + + self.invoke(m, "request", f) + assert f.request.query["mitmproxy"] == "rocks" + + f.request.path = "/" + self.invoke(m, "request", f) + assert f.request.query["mitmproxy"] == "rocks" + + def test_modify_response_body(self): + with tutils.raises(ScriptError): + tscript("modify_response_body.py") + + m, sc = tscript("modify_response_body.py", "mitmproxy rocks") + f = tutils.tflow(resp=netutils.tresp(content=b"I <3 mitmproxy")) + self.invoke(m, "response", f) + assert f.response.content == b"I <3 rocks" + + def test_redirect_requests(self): + m, sc = tscript("redirect_requests.py") + f = tutils.tflow(req=netutils.treq(host="example.org")) + self.invoke(m, "request", f) + assert f.request.host == "mitmproxy.org" + + def test_har_extractor(self): + if sys.version_info >= (3, 0): + with tutils.raises("does not work on Python 3"): + tscript("har_extractor.py") + return + + with tutils.raises(ScriptError): + tscript("har_extractor.py") + + with tutils.tmpdir() as tdir: + times = dict( + timestamp_start=746203272, + timestamp_end=746203272, + ) + + path = os.path.join(tdir, "file") + m, sc = tscript("har_extractor.py", six.moves.shlex_quote(path)) + f = tutils.tflow( + req=netutils.treq(**times), + resp=netutils.tresp(**times) + ) + self.invoke(m, "response", f) + m.addons.remove(sc) + + with open(path, "rb") as f: + test_data = json.load(f) + assert len(test_data["log"]["pages"]) == 1 diff --git a/test/mitmproxy/test_filt.py b/test/mitmproxy/test_filt.py index 9fe36b2a..69f042bb 100644 --- a/test/mitmproxy/test_filt.py +++ b/test/mitmproxy/test_filt.py @@ -1,6 +1,8 @@ from six.moves import cStringIO as StringIO -from mitmproxy import filt from mock import patch + +from mitmproxy import filt + from . 
import tutils @@ -73,7 +75,7 @@ class TestParsing: self._dump(a) -class TestMatching: +class TestMatchingHTTPFlow: def req(self): return tutils.tflow() @@ -87,6 +89,11 @@ class TestMatching: def q(self, q, o): return filt.parse(q)(o) + def test_http(self): + s = self.req() + assert self.q("~http", s) + assert not self.q("~tcp", s) + def test_asset(self): s = self.resp() assert not self.q("~a", s) @@ -247,6 +254,186 @@ class TestMatching: assert not self.q("!~c 201 !~c 200", s) +class TestMatchingTCPFlow: + + def flow(self): + return tutils.ttcpflow() + + def err(self): + return tutils.ttcpflow(err=True) + + def q(self, q, o): + return filt.parse(q)(o) + + def test_tcp(self): + f = self.flow() + assert self.q("~tcp", f) + assert not self.q("~http", f) + + def test_ferr(self): + e = self.err() + assert self.q("~e", e) + + def test_body(self): + f = self.flow() + + # Messages sent by client or server + assert self.q("~b hello", f) + assert self.q("~b me", f) + assert not self.q("~b nonexistent", f) + + # Messages sent by client + assert self.q("~bq hello", f) + assert not self.q("~bq me", f) + assert not self.q("~bq nonexistent", f) + + # Messages sent by server + assert self.q("~bs me", f) + assert not self.q("~bs hello", f) + assert not self.q("~bs nonexistent", f) + + def test_src(self): + f = self.flow() + assert self.q("~src address", f) + assert not self.q("~src foobar", f) + assert self.q("~src :22", f) + assert not self.q("~src :99", f) + assert self.q("~src address:22", f) + + def test_dst(self): + f = self.flow() + f.server_conn = tutils.tserver_conn() + assert self.q("~dst address", f) + assert not self.q("~dst foobar", f) + assert self.q("~dst :22", f) + assert not self.q("~dst :99", f) + assert self.q("~dst address:22", f) + + def test_and(self): + f = self.flow() + f.server_conn = tutils.tserver_conn() + assert self.q("~b hello & ~b me", f) + assert not self.q("~src wrongaddress & ~b hello", f) + assert self.q("(~src :22 & ~dst :22) & ~b hello", f) + assert not self.q("(~src address:22 & ~dst :22) & ~b nonexistent", f) + assert not self.q("(~src address:22 & ~dst :99) & ~b hello", f) + + def test_or(self): + f = self.flow() + f.server_conn = tutils.tserver_conn() + assert self.q("~b hello | ~b me", f) + assert self.q("~src :22 | ~b me", f) + assert not self.q("~src :99 | ~dst :99", f) + assert self.q("(~src :22 | ~dst :22) | ~b me", f) + + def test_not(self): + f = self.flow() + assert not self.q("! ~src :22", f) + assert self.q("! 
~src :99", f) + assert self.q("!~src :99 !~src :99", f) + assert not self.q("!~src :99 !~src :22", f) + + def test_request(self): + f = self.flow() + assert not self.q("~q", f) + + def test_response(self): + f = self.flow() + assert not self.q("~s", f) + + def test_headers(self): + f = self.flow() + assert not self.q("~h whatever", f) + + # Request headers + assert not self.q("~hq whatever", f) + + # Response headers + assert not self.q("~hs whatever", f) + + def test_content_type(self): + f = self.flow() + assert not self.q("~t whatever", f) + + # Request content-type + assert not self.q("~tq whatever", f) + + # Response content-type + assert not self.q("~ts whatever", f) + + def test_code(self): + f = self.flow() + assert not self.q("~c 200", f) + + def test_domain(self): + f = self.flow() + assert not self.q("~d whatever", f) + + def test_method(self): + f = self.flow() + assert not self.q("~m whatever", f) + + def test_url(self): + f = self.flow() + assert not self.q("~u whatever", f) + + +class TestMatchingDummyFlow: + + def flow(self): + return tutils.tdummyflow() + + def err(self): + return tutils.tdummyflow(err=True) + + def q(self, q, o): + return filt.parse(q)(o) + + def test_filters(self): + e = self.err() + f = self.flow() + f.server_conn = tutils.tserver_conn() + + assert not self.q("~a", f) + + assert not self.q("~b whatever", f) + assert not self.q("~bq whatever", f) + assert not self.q("~bs whatever", f) + + assert not self.q("~c 0", f) + + assert not self.q("~d whatever", f) + + assert self.q("~dst address", f) + assert not self.q("~dst nonexistent", f) + + assert self.q("~e", e) + assert not self.q("~e", f) + + assert not self.q("~http", f) + + assert not self.q("~h whatever", f) + assert not self.q("~hq whatever", f) + assert not self.q("~hs whatever", f) + + assert not self.q("~m whatever", f) + + assert not self.q("~s", f) + + assert self.q("~src address", f) + assert not self.q("~src nonexistent", f) + + assert not self.q("~tcp", f) + + assert not self.q("~t whatever", f) + assert not self.q("~tq whatever", f) + assert not self.q("~ts whatever", f) + + assert not self.q("~u whatever", f) + + assert not self.q("~q", f) + + @patch('traceback.extract_tb') def test_pyparsing_bug(extract_tb): """https://github.com/mitmproxy/mitmproxy/issues/1087""" diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index 9eaab9aa..90f7f915 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -1,13 +1,11 @@ -import os.path -from six.moves import cStringIO as StringIO - import mock +import io import netlib.utils from netlib.http import Headers from mitmproxy import filt, controller, flow from mitmproxy.contrib import tnetstring -from mitmproxy.exceptions import FlowReadException, ScriptException +from mitmproxy.exceptions import FlowReadException from mitmproxy.models import Error from mitmproxy.models import Flow from mitmproxy.models import HTTPFlow @@ -40,94 +38,12 @@ def test_app_registry(): assert ar.get(r) -class TestStickyCookieState: - - def _response(self, cookie, host): - s = flow.StickyCookieState(filt.parse(".*")) - f = tutils.tflow(req=netlib.tutils.treq(host=host, port=80), resp=True) - f.response.headers["Set-Cookie"] = cookie - s.handle_response(f) - return s, f - - def test_domain_match(self): - s = flow.StickyCookieState(filt.parse(".*")) - assert s.domain_match("www.google.com", ".google.com") - assert s.domain_match("google.com", ".google.com") - - def test_response(self): - c = "SSID=mooo; domain=.google.com, FOO=bar; 
Domain=.google.com; Path=/; " \ - "Expires=Wed, 13-Jan-2021 22:23:01 GMT; Secure; " - - s, f = self._response(c, "host") - assert not s.jar.keys() - - s, f = self._response(c, "www.google.com") - assert s.jar.keys() - - s, f = self._response("SSID=mooo", "www.google.com") - assert s.jar.keys()[0] == ('www.google.com', 80, '/') - - # Test setting of multiple cookies - c1 = "somecookie=test; Path=/" - c2 = "othercookie=helloworld; Path=/" - s, f = self._response(c1, "www.google.com") - f.response.headers["Set-Cookie"] = c2 - s.handle_response(f) - googlekey = s.jar.keys()[0] - assert len(s.jar[googlekey].keys()) == 2 - - # Test setting of weird cookie keys - s = flow.StickyCookieState(filt.parse(".*")) - f = tutils.tflow(req=netlib.tutils.treq(host="www.google.com", port=80), resp=True) - cs = [ - "foo/bar=hello", - "foo:bar=world", - "foo@bar=fizz", - "foo,bar=buzz", - ] - for c in cs: - f.response.headers["Set-Cookie"] = c - s.handle_response(f) - googlekey = s.jar.keys()[0] - assert len(s.jar[googlekey].keys()) == len(cs) - - # Test overwriting of a cookie value - c1 = "somecookie=helloworld; Path=/" - c2 = "somecookie=newvalue; Path=/" - s, f = self._response(c1, "www.google.com") - f.response.headers["Set-Cookie"] = c2 - s.handle_response(f) - googlekey = s.jar.keys()[0] - assert len(s.jar[googlekey].keys()) == 1 - assert s.jar[googlekey]["somecookie"].items()[0][1] == "newvalue" - - def test_request(self): - s, f = self._response("SSID=mooo", "www.google.com") - assert "cookie" not in f.request.headers - s.handle_request(f) - assert "cookie" in f.request.headers - - -class TestStickyAuthState: - - def test_response(self): - s = flow.StickyAuthState(filt.parse(".*")) - f = tutils.tflow(resp=True) - f.request.headers["authorization"] = "foo" - s.handle_request(f) - assert "address" in s.hosts - - f = tutils.tflow(resp=True) - s.handle_request(f) - assert f.request.headers["authorization"] == "foo" - - class TestClientPlaybackState: def test_tick(self): first = tutils.tflow() s = flow.State() - fm = flow.FlowMaster(None, s) + fm = flow.FlowMaster(None, None, s) fm.start_client_playback([first, tutils.tflow()], True) c = fm.client_playback c.testing = True @@ -264,26 +180,26 @@ class TestServerPlaybackState: "param1", "param2"], False) r = tutils.tflow(resp=True) r.request.headers["Content-Type"] = "application/x-www-form-urlencoded" - r.request.content = "paramx=x¶m1=1" + r.request.content = b"paramx=x¶m1=1" r2 = tutils.tflow(resp=True) r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded" - r2.request.content = "paramx=x¶m1=1" + r2.request.content = b"paramx=x¶m1=1" # same parameters assert s._hash(r) == s._hash(r2) # ignored parameters != - r2.request.content = "paramx=x¶m1=2" + r2.request.content = b"paramx=x¶m1=2" assert s._hash(r) == s._hash(r2) # missing parameter - r2.request.content = "paramx=x" + r2.request.content = b"paramx=x" assert s._hash(r) == s._hash(r2) # ignorable parameter added - r2.request.content = "paramx=x¶m1=2" + r2.request.content = b"paramx=x¶m1=2" assert s._hash(r) == s._hash(r2) # not ignorable parameter changed - r2.request.content = "paramx=y¶m1=1" + r2.request.content = b"paramx=y¶m1=1" assert not s._hash(r) == s._hash(r2) # not ignorable parameter missing - r2.request.content = "param1=1" + r2.request.content = b"param1=1" assert not s._hash(r) == s._hash(r2) def test_ignore_payload_params_other_content_type(self): @@ -292,14 +208,14 @@ class TestServerPlaybackState: "param1", "param2"], False) r = tutils.tflow(resp=True) 
r.request.headers["Content-Type"] = "application/json" - r.request.content = '{"param1":"1"}' + r.request.content = b'{"param1":"1"}' r2 = tutils.tflow(resp=True) r2.request.headers["Content-Type"] = "application/json" - r2.request.content = '{"param1":"1"}' + r2.request.content = b'{"param1":"1"}' # same content assert s._hash(r) == s._hash(r2) # distint content (note only x-www-form-urlencoded payload is analysed) - r2.request.content = '{"param1":"2"}' + r2.request.content = b'{"param1":"2"}' assert not s._hash(r) == s._hash(r2) def test_ignore_payload_wins_over_params(self): @@ -309,10 +225,10 @@ class TestServerPlaybackState: "param1", "param2"], False) r = tutils.tflow(resp=True) r.request.headers["Content-Type"] = "application/x-www-form-urlencoded" - r.request.content = "paramx=y" + r.request.content = b"paramx=y" r2 = tutils.tflow(resp=True) r2.request.headers["Content-Type"] = "application/x-www-form-urlencoded" - r2.request.content = "paramx=x" + r2.request.content = b"paramx=x" # same parameters assert s._hash(r) == s._hash(r2) @@ -329,10 +245,10 @@ class TestServerPlaybackState: r = tutils.tflow(resp=True) r2 = tutils.tflow(resp=True) - r.request.content = "foo" - r2.request.content = "foo" + r.request.content = b"foo" + r2.request.content = b"foo" assert s._hash(r) == s._hash(r2) - r2.request.content = "bar" + r2.request.content = b"bar" assert not s._hash(r) == s._hash(r2) # now ignoring content @@ -347,12 +263,12 @@ class TestServerPlaybackState: False) r = tutils.tflow(resp=True) r2 = tutils.tflow(resp=True) - r.request.content = "foo" - r2.request.content = "foo" + r.request.content = b"foo" + r2.request.content = b"foo" assert s._hash(r) == s._hash(r2) - r2.request.content = "bar" + r2.request.content = b"bar" assert s._hash(r) == s._hash(r2) - r2.request.content = "" + r2.request.content = b"" assert s._hash(r) == s._hash(r2) r2.request.content = None assert s._hash(r) == s._hash(r2) @@ -377,7 +293,7 @@ class TestServerPlaybackState: assert s._hash(r) == s._hash(r2) -class TestFlow(object): +class TestHTTPFlow(object): def test_copy(self): f = tutils.tflow(resp=True) @@ -420,13 +336,13 @@ class TestFlow(object): def test_backup(self): f = tutils.tflow() f.response = HTTPResponse.wrap(netlib.tutils.tresp()) - f.request.content = "foo" + f.request.content = b"foo" assert not f.modified() f.backup() - f.request.content = "bar" + f.request.content = b"bar" assert f.modified() f.revert() - assert f.request.content == "foo" + assert f.request.content == b"foo" def test_backup_idempotence(self): f = tutils.tflow(resp=True) @@ -458,7 +374,7 @@ class TestFlow(object): def test_kill(self): s = flow.State() - fm = flow.FlowMaster(None, s) + fm = flow.FlowMaster(None, None, s) f = tutils.tflow() f.intercept(mock.Mock()) f.kill(fm) @@ -467,7 +383,7 @@ class TestFlow(object): def test_killall(self): s = flow.State() - fm = flow.FlowMaster(None, s) + fm = flow.FlowMaster(None, None, s) f = tutils.tflow() f.intercept(fm) @@ -486,8 +402,8 @@ class TestFlow(object): def test_replace_unicode(self): f = tutils.tflow(resp=True) - f.response.content = "\xc2foo" - f.replace("foo", u"bar") + f.response.content = b"\xc2foo" + f.replace(b"foo", u"bar") def test_replace_no_content(self): f = tutils.tflow() @@ -497,34 +413,48 @@ class TestFlow(object): def test_replace(self): f = tutils.tflow(resp=True) f.request.headers["foo"] = "foo" - f.request.content = "afoob" + f.request.content = b"afoob" f.response.headers["foo"] = "foo" - f.response.content = "afoob" + f.response.content = b"afoob" assert 
f.replace("foo", "bar") == 6 assert f.request.headers["bar"] == "bar" - assert f.request.content == "abarb" + assert f.request.content == b"abarb" assert f.response.headers["bar"] == "bar" - assert f.response.content == "abarb" + assert f.response.content == b"abarb" def test_replace_encoded(self): f = tutils.tflow(resp=True) - f.request.content = "afoob" + f.request.content = b"afoob" f.request.encode("gzip") - f.response.content = "afoob" + f.response.content = b"afoob" f.response.encode("gzip") f.replace("foo", "bar") - assert f.request.content != "abarb" + assert f.request.raw_content != b"abarb" f.request.decode() - assert f.request.content == "abarb" + assert f.request.raw_content == b"abarb" - assert f.response.content != "abarb" + assert f.response.raw_content != b"abarb" f.response.decode() - assert f.response.content == "abarb" + assert f.response.raw_content == b"abarb" + + +class TestTCPFlow: + + def test_match(self): + f = tutils.ttcpflow() + assert not f.match("~b nonexistent") + assert f.match(None) + assert not f.match("~b nonexistent") + + f = tutils.ttcpflow(err=True) + assert f.match("~e") + + tutils.raises(ValueError, f.match, "~") class TestState: @@ -667,7 +597,7 @@ class TestState: class TestSerialize: def _treader(self): - sio = StringIO() + sio = io.BytesIO() w = flow.FlowWriter(sio) for i in range(3): f = tutils.tflow(resp=True) @@ -684,9 +614,9 @@ class TestSerialize: return flow.FlowReader(sio) def test_roundtrip(self): - sio = StringIO() + sio = io.BytesIO() f = tutils.tflow() - f.request.content = "".join(chr(i) for i in range(255)) + f.request.content = bytes(bytearray(range(256))) w = flow.FlowWriter(sio) w.add(f) @@ -702,7 +632,7 @@ class TestSerialize: def test_load_flows(self): r = self._treader() s = flow.State() - fm = flow.FlowMaster(None, s) + fm = flow.FlowMaster(None, None, s) fm.load_flows(r) assert len(s.flows) == 6 @@ -713,12 +643,12 @@ class TestSerialize: mode="reverse", upstream_server=("https", ("use-this-domain", 80)) ) - fm = flow.FlowMaster(DummyServer(conf), s) + fm = flow.FlowMaster(None, DummyServer(conf), s) fm.load_flows(r) assert s.flows[0].request.host == "use-this-domain" def test_filter(self): - sio = StringIO() + sio = io.BytesIO() fl = filt.parse("~c 200") w = flow.FilteredFlowWriter(sio, fl) @@ -735,8 +665,8 @@ class TestSerialize: assert len(list(r.stream())) def test_error(self): - sio = StringIO() - sio.write("bogus") + sio = io.BytesIO() + sio.write(b"bogus") sio.seek(0) r = flow.FlowReader(sio) tutils.raises(FlowReadException, list, r.stream()) @@ -748,7 +678,7 @@ class TestSerialize: f = tutils.tflow() d = f.get_state() d["version"] = (0, 0) - sio = StringIO() + sio = io.BytesIO() tnetstring.dump(d, sio) sio.seek(0) @@ -758,32 +688,17 @@ class TestSerialize: class TestFlowMaster: - def test_load_script(self): - s = flow.State() - fm = flow.FlowMaster(None, s) - - fm.load_script(tutils.test_data.path("data/scripts/a.py")) - fm.load_script(tutils.test_data.path("data/scripts/a.py")) - fm.unload_scripts() - with tutils.raises(ScriptException): - fm.load_script("nonexistent") - try: - fm.load_script(tutils.test_data.path("data/scripts/starterr.py")) - except ScriptException as e: - assert "ValueError" in str(e) - assert len(fm.scripts) == 0 - def test_getset_ignore(self): p = mock.Mock() p.config.check_ignore = HostMatcher() - fm = flow.FlowMaster(p, flow.State()) + fm = flow.FlowMaster(None, p, flow.State()) assert not fm.get_ignore_filter() fm.set_ignore_filter(["^apple\.com:", ":443$"]) assert fm.get_ignore_filter() def 
test_replay(self): s = flow.State() - fm = flow.FlowMaster(None, s) + fm = flow.FlowMaster(None, None, s) f = tutils.tflow(resp=True) f.request.content = None assert "missing" in fm.replay_request(f) @@ -792,55 +707,11 @@ class TestFlowMaster: assert "intercepting" in fm.replay_request(f) f.live = True - assert "live" in fm.replay_request(f, run_scripthooks=True) - - def test_script_reqerr(self): - s = flow.State() - fm = flow.FlowMaster(None, s) - fm.load_script(tutils.test_data.path("data/scripts/reqerr.py")) - f = tutils.tflow() - fm.clientconnect(f.client_conn) - assert fm.request(f) - - def test_script(self): - s = flow.State() - fm = flow.FlowMaster(None, s) - fm.load_script(tutils.test_data.path("data/scripts/all.py")) - f = tutils.tflow(resp=True) - - f.client_conn.acked = False - fm.clientconnect(f.client_conn) - assert fm.scripts[0].ns["log"][-1] == "clientconnect" - f.server_conn.acked = False - fm.serverconnect(f.server_conn) - assert fm.scripts[0].ns["log"][-1] == "serverconnect" - f.reply.acked = False - fm.request(f) - assert fm.scripts[0].ns["log"][-1] == "request" - f.reply.acked = False - fm.response(f) - assert fm.scripts[0].ns["log"][-1] == "response" - # load second script - fm.load_script(tutils.test_data.path("data/scripts/all.py")) - assert len(fm.scripts) == 2 - f.server_conn.reply.acked = False - fm.clientdisconnect(f.server_conn) - assert fm.scripts[0].ns["log"][-1] == "clientdisconnect" - assert fm.scripts[1].ns["log"][-1] == "clientdisconnect" - - # unload first script - fm.unload_scripts() - assert len(fm.scripts) == 0 - fm.load_script(tutils.test_data.path("data/scripts/all.py")) - - f.error = tutils.terr() - f.reply.acked = False - fm.error(f) - assert fm.scripts[0].ns["log"][-1] == "error" + assert "live" in fm.replay_request(f) def test_duplicate_flow(self): s = flow.State() - fm = flow.FlowMaster(None, s) + fm = flow.FlowMaster(None, None, s) f = tutils.tflow(resp=True) fm.load_flow(f) assert s.flow_count() == 1 @@ -851,14 +722,12 @@ class TestFlowMaster: def test_create_flow(self): s = flow.State() - fm = flow.FlowMaster(None, s) + fm = flow.FlowMaster(None, None, s) assert fm.create_request("GET", "http", "example.com", 80, "/") def test_all(self): s = flow.State() - fm = flow.FlowMaster(None, s) - fm.anticache = True - fm.anticomp = True + fm = flow.FlowMaster(None, None, s) f = tutils.tflow(req=None) fm.clientconnect(f.client_conn) f.request = HTTPRequest.wrap(netlib.tutils.treq()) @@ -875,7 +744,6 @@ class TestFlowMaster: f.error.reply = controller.DummyReply() fm.error(f) - fm.load_script(tutils.test_data.path("data/scripts/a.py")) fm.shutdown() def test_client_playback(self): @@ -883,7 +751,11 @@ class TestFlowMaster: f = tutils.tflow(resp=True) pb = [tutils.tflow(resp=True), f] - fm = flow.FlowMaster(DummyServer(ProxyConfig()), s) + fm = flow.FlowMaster( + flow.options.Options(), + DummyServer(ProxyConfig()), + s + ) assert not fm.start_server_playback( pb, False, @@ -911,7 +783,7 @@ class TestFlowMaster: f.response = HTTPResponse.wrap(netlib.tutils.tresp(content=f.request)) pb = [f] - fm = flow.FlowMaster(None, s) + fm = flow.FlowMaster(flow.options.Options(), None, s) fm.refresh_server_playback = True assert not fm.do_server_playback(tutils.tflow()) @@ -938,7 +810,7 @@ class TestFlowMaster: None, False) r = tutils.tflow() - r.request.content = "gibble" + r.request.content = b"gibble" assert not fm.do_server_playback(r) assert fm.do_server_playback(tutils.tflow()) @@ -953,7 +825,7 @@ class TestFlowMaster: f = tutils.tflow() f.response = 
HTTPResponse.wrap(netlib.tutils.tresp(content=f.request)) pb = [f] - fm = flow.FlowMaster(None, s) + fm = flow.FlowMaster(None, None, s) fm.refresh_server_playback = True fm.start_server_playback( pb, @@ -971,74 +843,6 @@ class TestFlowMaster: fm.process_new_request(f) assert "killed" in f.error.msg - def test_stickycookie(self): - s = flow.State() - fm = flow.FlowMaster(None, s) - assert "Invalid" in fm.set_stickycookie("~h") - fm.set_stickycookie(".*") - assert fm.stickycookie_state - fm.set_stickycookie(None) - assert not fm.stickycookie_state - - fm.set_stickycookie(".*") - f = tutils.tflow(resp=True) - f.response.headers["set-cookie"] = "foo=bar" - fm.request(f) - f.reply.acked = False - fm.response(f) - assert fm.stickycookie_state.jar - assert "cookie" not in f.request.headers - f = f.copy() - f.reply.acked = False - fm.request(f) - assert f.request.headers["cookie"] == "foo=bar" - - def test_stickyauth(self): - s = flow.State() - fm = flow.FlowMaster(None, s) - assert "Invalid" in fm.set_stickyauth("~h") - fm.set_stickyauth(".*") - assert fm.stickyauth_state - fm.set_stickyauth(None) - assert not fm.stickyauth_state - - fm.set_stickyauth(".*") - f = tutils.tflow(resp=True) - f.request.headers["authorization"] = "foo" - fm.request(f) - - f = tutils.tflow(resp=True) - assert fm.stickyauth_state.hosts - assert "authorization" not in f.request.headers - fm.request(f) - assert f.request.headers["authorization"] == "foo" - - def test_stream(self): - with tutils.tmpdir() as tdir: - p = os.path.join(tdir, "foo") - - def r(): - r = flow.FlowReader(open(p, "rb")) - return list(r.stream()) - - s = flow.State() - fm = flow.FlowMaster(None, s) - f = tutils.tflow(resp=True) - - fm.start_stream(file(p, "ab"), None) - fm.request(f) - fm.response(f) - fm.stop_stream() - - assert r()[0].response - - f = tutils.tflow() - fm.start_stream(file(p, "ab"), None) - fm.request(f) - fm.shutdown() - - assert not r()[1].response - class TestRequest: @@ -1073,23 +877,14 @@ class TestRequest: assert r.url == "https://address:22/path" assert r.pretty_url == "https://foo.com:22/path" - def test_anticache(self): - r = HTTPRequest.wrap(netlib.tutils.treq()) - r.headers = Headers() - r.headers["if-modified-since"] = "test" - r.headers["if-none-match"] = "test" - r.anticache() - assert "if-modified-since" not in r.headers - assert "if-none-match" not in r.headers - def test_replace(self): r = HTTPRequest.wrap(netlib.tutils.treq()) r.path = "path/foo" r.headers["Foo"] = "fOo" - r.content = "afoob" + r.content = b"afoob" assert r.replace("foo(?i)", "boo") == 4 assert r.path == "path/boo" - assert "foo" not in r.content + assert b"foo" not in r.content assert r.headers["boo"] == "boo" def test_constrain_encoding(self): @@ -1102,16 +897,6 @@ class TestRequest: r.constrain_encoding() assert "oink" not in r.headers["accept-encoding"] - def test_get_decoded_content(self): - r = HTTPRequest.wrap(netlib.tutils.treq()) - r.content = None - r.headers["content-encoding"] = "identity" - assert r.get_decoded_content() is None - - r.content = "falafel" - r.encode("gzip") - assert r.get_decoded_content() == "falafel" - def test_get_content_type(self): resp = HTTPResponse.wrap(netlib.tutils.tresp()) resp.headers = Headers(content_type="text/plain") @@ -1129,9 +914,9 @@ class TestResponse: def test_replace(self): r = HTTPResponse.wrap(netlib.tutils.tresp()) r.headers["Foo"] = "fOo" - r.content = "afoob" + r.content = b"afoob" assert r.replace("foo(?i)", "boo") == 3 - assert "foo" not in r.content + assert b"foo" not in r.content assert 
r.headers["boo"] == "boo" def test_get_content_type(self): @@ -1180,104 +965,3 @@ class TestClientConnection: assert c3.get_state() == c.get_state() assert str(c) - - -def test_replacehooks(): - h = flow.ReplaceHooks() - h.add("~q", "foo", "bar") - assert h.lst - - h.set( - [ - (".*", "one", "two"), - (".*", "three", "four"), - ] - ) - assert h.count() == 2 - - h.clear() - assert not h.lst - - h.add("~q", "foo", "bar") - h.add("~s", "foo", "bar") - - v = h.get_specs() - assert v == [('~q', 'foo', 'bar'), ('~s', 'foo', 'bar')] - assert h.count() == 2 - h.clear() - assert h.count() == 0 - - f = tutils.tflow() - f.request.content = "foo" - h.add("~s", "foo", "bar") - h.run(f) - assert f.request.content == "foo" - - f = tutils.tflow(resp=True) - f.request.content = "foo" - f.response.content = "foo" - h.run(f) - assert f.response.content == "bar" - assert f.request.content == "foo" - - f = tutils.tflow() - h.clear() - h.add("~q", "foo", "bar") - f.request.content = "foo" - h.run(f) - assert f.request.content == "bar" - - assert not h.add("~", "foo", "bar") - assert not h.add("foo", "*", "bar") - - -def test_setheaders(): - h = flow.SetHeaders() - h.add("~q", "foo", "bar") - assert h.lst - - h.set( - [ - (".*", "one", "two"), - (".*", "three", "four"), - ] - ) - assert h.count() == 2 - - h.clear() - assert not h.lst - - h.add("~q", "foo", "bar") - h.add("~s", "foo", "bar") - - v = h.get_specs() - assert v == [('~q', 'foo', 'bar'), ('~s', 'foo', 'bar')] - assert h.count() == 2 - h.clear() - assert h.count() == 0 - - f = tutils.tflow() - f.request.content = "foo" - h.add("~s", "foo", "bar") - h.run(f) - assert f.request.content == "foo" - - h.clear() - h.add("~s", "one", "two") - h.add("~s", "one", "three") - f = tutils.tflow(resp=True) - f.request.headers["one"] = "xxx" - f.response.headers["one"] = "xxx" - h.run(f) - assert f.request.headers["one"] == "xxx" - assert f.response.headers.get_all("one") == ["two", "three"] - - h.clear() - h.add("~q", "one", "two") - h.add("~q", "one", "three") - f = tutils.tflow() - f.request.headers["one"] = "xxx" - h.run(f) - assert f.request.headers.get_all("one") == ["two", "three"] - - assert not h.add("~", "foo", "bar") diff --git a/test/mitmproxy/test_flow_export.py b/test/mitmproxy/test_flow_export.py index 9a263b1b..e6d65e40 100644 --- a/test/mitmproxy/test_flow_export.py +++ b/test/mitmproxy/test_flow_export.py @@ -21,15 +21,15 @@ def python_equals(testdata, text): def req_get(): - return netlib.tutils.treq(method='GET', content='', path=b"/path?a=foo&a=bar&b=baz") + return netlib.tutils.treq(method=b'GET', content=b'', path=b"/path?a=foo&a=bar&b=baz") def req_post(): - return netlib.tutils.treq(method='POST', headers=()) + return netlib.tutils.treq(method=b'POST', headers=()) def req_patch(): - return netlib.tutils.treq(method='PATCH', path=b"/path?query=param") + return netlib.tutils.treq(method=b'PATCH', path=b"/path?query=param") class TestExportCurlCommand(): @@ -60,7 +60,7 @@ class TestExportPythonCode(): def test_post_json(self): p = req_post() - p.content = '{"name": "example", "email": "example@example.com"}' + p.content = b'{"name": "example", "email": "example@example.com"}' p.headers = Headers(content_type="application/json") flow = tutils.tflow(req=p) python_equals("data/test_flow_export/python_post_json.py", export.python_code(flow)) @@ -112,7 +112,7 @@ class TestExportLocustCode(): def test_post(self): p = req_post() - p.content = '''content''' + p.content = b'content' p.headers = '' flow = tutils.tflow(req=p) 
python_equals("data/test_flow_export/locust_post.py", export.locust_code(flow)) @@ -142,14 +142,14 @@ class TestIsJson(): def test_json_type(self): headers = Headers(content_type="application/json") - assert export.is_json(headers, "foobar") is False + assert export.is_json(headers, b"foobar") is False def test_valid(self): headers = Headers(content_type="application/foobar") - j = export.is_json(headers, '{"name": "example", "email": "example@example.com"}') + j = export.is_json(headers, b'{"name": "example", "email": "example@example.com"}') assert j is False def test_valid2(self): headers = Headers(content_type="application/json") - j = export.is_json(headers, '{"name": "example", "email": "example@example.com"}') + j = export.is_json(headers, b'{"name": "example", "email": "example@example.com"}') assert isinstance(j, dict) diff --git a/test/mitmproxy/test_flow_format_compat.py b/test/mitmproxy/test_flow_format_compat.py index b2cef88d..cc80db81 100644 --- a/test/mitmproxy/test_flow_format_compat.py +++ b/test/mitmproxy/test_flow_format_compat.py @@ -4,7 +4,7 @@ from . import tutils def test_load(): - with open(tutils.test_data.path("data/dumpfile-013"), "rb") as f: + with open(tutils.test_data.path("data/dumpfile-011"), "rb") as f: flow_reader = FlowReader(f) flows = list(flow_reader.stream()) assert len(flows) == 1 @@ -12,7 +12,7 @@ def test_load(): def test_cannot_convert(): - with open(tutils.test_data.path("data/dumpfile-012"), "rb") as f: + with open(tutils.test_data.path("data/dumpfile-010"), "rb") as f: flow_reader = FlowReader(f) with tutils.raises(FlowReadException): list(flow_reader.stream()) diff --git a/test/mitmproxy/test_options.py b/test/mitmproxy/test_options.py new file mode 100644 index 00000000..af619b27 --- /dev/null +++ b/test/mitmproxy/test_options.py @@ -0,0 +1,100 @@ +from __future__ import absolute_import, print_function, division +import copy + +from mitmproxy import options +from mitmproxy import exceptions +from netlib import tutils + + +class TO(options.Options): + def __init__(self, one=None, two=None): + self.one = one + self.two = two + super(TO, self).__init__() + + +def test_options(): + o = TO(two="three") + assert o.one is None + assert o.two == "three" + o.one = "one" + assert o.one == "one" + + with tutils.raises(TypeError): + TO(nonexistent = "value") + with tutils.raises("no such option"): + o.nonexistent = "value" + with tutils.raises("no such option"): + o.update(nonexistent = "value") + + rec = [] + + def sub(opts): + rec.append(copy.copy(opts)) + + o.changed.connect(sub) + + o.one = "ninety" + assert len(rec) == 1 + assert rec[-1].one == "ninety" + + o.update(one="oink") + assert len(rec) == 2 + assert rec[-1].one == "oink" + + +def test_setter(): + o = TO(two="three") + f = o.setter("two") + f("xxx") + assert o.two == "xxx" + with tutils.raises("no such option"): + o.setter("nonexistent") + + +def test_toggler(): + o = TO(two=True) + f = o.toggler("two") + f() + assert o.two is False + f() + assert o.two is True + with tutils.raises("no such option"): + o.toggler("nonexistent") + + +def test_rollback(): + o = TO(one="two") + + rec = [] + + def sub(opts): + rec.append(copy.copy(opts)) + + recerr = [] + + def errsub(opts, **kwargs): + recerr.append(kwargs) + + def err(opts): + if opts.one == "ten": + raise exceptions.OptionsError() + + o.changed.connect(sub) + o.changed.connect(err) + o.errored.connect(errsub) + + o.one = "ten" + assert isinstance(recerr[0]["exc"], exceptions.OptionsError) + assert o.one == "two" + assert len(rec) == 2 + 
assert rec[0].one == "ten" + assert rec[1].one == "two" + + +def test_repr(): + assert repr(TO()) == "test.mitmproxy.test_options.TO({'one': None, 'two': None})" + assert repr(TO(one='x' * 60)) == """test.mitmproxy.test_options.TO({ + 'one': 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx', + 'two': None +})""" diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py index 932c8df2..b8f724bd 100644 --- a/test/mitmproxy/test_protocol_http2.py +++ b/test/mitmproxy/test_protocol_http2.py @@ -3,9 +3,10 @@ from __future__ import (absolute_import, print_function, division) import pytest -import traceback import os import tempfile +import traceback + import h2 from mitmproxy.proxy.config import ProxyConfig @@ -46,6 +47,11 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase): self.wfile.write(h2_conn.data_to_send()) self.wfile.flush() + if 'h2_server_settings' in self.kwargs: + h2_conn.update_settings(self.kwargs['h2_server_settings']) + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + done = False while not done: try: @@ -54,7 +60,10 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase): except HttpException: print(traceback.format_exc()) assert False + except netlib.exceptions.TcpDisconnect: + break except: + print(traceback.format_exc()) break self.wfile.write(h2_conn.data_to_send()) self.wfile.flush() @@ -64,8 +73,11 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase): if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile): done = True break + except netlib.exceptions.TcpDisconnect: + done = True except: done = True + print(traceback.format_exc()) break def handle_server_event(self, h2_conn, rfile, wfile): @@ -75,31 +87,31 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase): class _Http2TestBase(object): @classmethod - def setup_class(self): - self.config = ProxyConfig(**self.get_proxy_config()) + def setup_class(cls): + cls.config = ProxyConfig(**cls.get_proxy_config()) - tmaster = tservers.TestMaster(self.config) + tmaster = tservers.TestMaster(cls.config) tmaster.start_app(APP_HOST, APP_PORT) - self.proxy = tservers.ProxyThread(tmaster) - self.proxy.start() + cls.proxy = tservers.ProxyThread(tmaster) + cls.proxy.start() @classmethod def teardown_class(cls): cls.proxy.shutdown() - @property - def master(self): - return self.proxy.tmaster - @classmethod def get_proxy_config(cls): cls.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy") return dict( - no_upstream_cert = False, - cadir = cls.cadir, - authenticator = None, + no_upstream_cert=False, + cadir=cls.cadir, + authenticator=None, ) + @property + def master(self): + return self.proxy.tmaster + def setup(self): self.master.clear_log() self.master.state.clear() @@ -120,7 +132,7 @@ class _Http2TestBase(object): client.wfile.flush() # read CONNECT response - while client.rfile.readline() != "\r\n": + while client.rfile.readline() != b"\r\n": pass client.convert_to_ssl(alpn_protos=[b'h2']) @@ -132,11 +144,26 @@ class _Http2TestBase(object): return client, h2_conn - def _send_request(self, wfile, h2_conn, stream_id=1, headers=[], body=b''): + def _send_request(self, + wfile, + h2_conn, + stream_id=1, + headers=[], + body=b'', + end_stream=None, + priority_exclusive=None, + priority_depends_on=None, + priority_weight=None): + if end_stream is None: + end_stream = (len(body) == 0) + h2_conn.send_headers( stream_id=stream_id, headers=headers, - end_stream=(len(body) == 0), + end_stream=end_stream, + 
priority_exclusive=priority_exclusive, + priority_depends_on=priority_depends_on, + priority_weight=priority_weight, ) if body: h2_conn.send_data(stream_id, body) @@ -145,8 +172,7 @@ class _Http2TestBase(object): wfile.flush() -@requires_alpn -class TestSimple(_Http2TestBase, _Http2ServerBase): +class _Http2Test(_Http2TestBase, _Http2ServerBase): @classmethod def setup_class(self): @@ -158,14 +184,19 @@ class TestSimple(_Http2TestBase, _Http2ServerBase): _Http2TestBase.teardown_class() _Http2ServerBase.teardown_class() + +@requires_alpn +class TestSimple(_Http2Test): + request_body_buffer = b'' + @classmethod def handle_server_event(self, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.ConnectionTerminated): return False elif isinstance(event, h2.events.RequestReceived): - assert ('client-foo', 'client-bar-1') in event.headers - assert ('client-foo', 'client-bar-2') in event.headers - + assert (b'client-foo', b'client-bar-1') in event.headers + assert (b'client-foo', b'client-bar-2') in event.headers + elif isinstance(event, h2.events.StreamEnded): import warnings with warnings.catch_warnings(): # Ignore UnicodeWarning: @@ -181,23 +212,30 @@ class TestSimple(_Http2TestBase, _Http2ServerBase): ('föo', 'bär'), ('X-Stream-ID', str(event.stream_id)), ]) - h2_conn.send_data(event.stream_id, b'foobar') + h2_conn.send_data(event.stream_id, b'response body') h2_conn.end_stream(event.stream_id) wfile.write(h2_conn.data_to_send()) wfile.flush() + elif isinstance(event, h2.events.DataReceived): + self.request_body_buffer += event.data return True def test_simple(self): + response_body_buffer = b'' client, h2_conn = self._setup_connection() - self._send_request(client.wfile, h2_conn, headers=[ - (':authority', "127.0.0.1:%s" % self.server.server.address.port), - (':method', 'GET'), - (':scheme', 'https'), - (':path', '/'), - ('ClIeNt-FoO', 'client-bar-1'), - ('ClIeNt-FoO', 'client-bar-2'), - ], body='my request body echoed back to me') + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('ClIeNt-FoO', 'client-bar-1'), + ('ClIeNt-FoO', 'client-bar-2'), + ], + body=b'request body') done = False while not done: @@ -212,7 +250,9 @@ class TestSimple(_Http2TestBase, _Http2ServerBase): client.wfile.flush() for event in events: - if isinstance(event, h2.events.StreamEnded): + if isinstance(event, h2.events.DataReceived): + response_body_buffer += event.data + elif isinstance(event, h2.events.StreamEnded): done = True h2_conn.close_connection() @@ -223,41 +263,226 @@ class TestSimple(_Http2TestBase, _Http2ServerBase): assert self.master.state.flows[0].response.status_code == 200 assert self.master.state.flows[0].response.headers['server-foo'] == 'server-bar' assert self.master.state.flows[0].response.headers['föo'] == 'bär' - assert self.master.state.flows[0].response.body == b'foobar' + assert self.master.state.flows[0].response.body == b'response body' + assert self.request_body_buffer == b'request body' + assert response_body_buffer == b'response body' @requires_alpn -class TestWithBodies(_Http2TestBase, _Http2ServerBase): - tmp_data_buffer_foobar = b'' +class TestRequestWithPriority(_Http2Test): @classmethod - def setup_class(self): - _Http2TestBase.setup_class() - _Http2ServerBase.setup_class() + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, 
h2.events.RequestReceived): + import warnings + with warnings.catch_warnings(): + # Ignore UnicodeWarning: + # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison + # failed to convert both arguments to Unicode - interpreting + # them as being unequal. + # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + + warnings.simplefilter("ignore") + + headers = [(':status', '200')] + if event.priority_updated: + headers.append(('priority_exclusive', event.priority_updated.exclusive)) + headers.append(('priority_depends_on', event.priority_updated.depends_on)) + headers.append(('priority_weight', event.priority_updated.weight)) + h2_conn.send_headers(event.stream_id, headers) + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_request_with_priority(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + priority_exclusive = True, + priority_depends_on = 42424242, + priority_weight = 42, + ) + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.master.state.flows[0].response.headers['priority_exclusive'] == 'True' + assert self.master.state.flows[0].response.headers['priority_depends_on'] == '42424242' + assert self.master.state.flows[0].response.headers['priority_weight'] == '42' + + def test_request_without_priority(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + ) + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert 'priority_exclusive' not in self.master.state.flows[0].response.headers + assert 'priority_depends_on' not in self.master.state.flows[0].response.headers + assert 'priority_weight' not in self.master.state.flows[0].response.headers + + +@requires_alpn +class TestPriority(_Http2Test): + priority_data = None @classmethod - def teardown_class(self): - _Http2TestBase.teardown_class() - _Http2ServerBase.teardown_class() + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.PriorityUpdated): + self.priority_data = (event.exclusive, event.depends_on, event.weight) + elif isinstance(event, h2.events.RequestReceived): + import warnings + with 
warnings.catch_warnings(): + # Ignore UnicodeWarning: + # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison + # failed to convert both arguments to Unicode - interpreting + # them as being unequal. + # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + + warnings.simplefilter("ignore") + + headers = [(':status', '200')] + h2_conn.send_headers(event.stream_id, headers) + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_priority(self): + client, h2_conn = self._setup_connection() + + h2_conn.prioritize(1, exclusive=True, depends_on=0, weight=42) + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + ) + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.priority_data == (True, 0, 42) + + +@requires_alpn +class TestPriorityWithExistingStream(_Http2Test): + priority_data = [] @classmethod def handle_server_event(self, event, h2_conn, rfile, wfile): if isinstance(event, h2.events.ConnectionTerminated): return False - if isinstance(event, h2.events.DataReceived): - self.tmp_data_buffer_foobar += event.data + elif isinstance(event, h2.events.PriorityUpdated): + self.priority_data.append((event.exclusive, event.depends_on, event.weight)) + elif isinstance(event, h2.events.RequestReceived): + assert not event.priority_updated + + import warnings + with warnings.catch_warnings(): + # Ignore UnicodeWarning: + # h2/utilities.py:64: UnicodeWarning: Unicode equal comparison + # failed to convert both arguments to Unicode - interpreting + # them as being unequal. 
+ # elif header[0] in (b'cookie', u'cookie') and len(header[1]) < 20: + + warnings.simplefilter("ignore") + + headers = [(':status', '200')] + h2_conn.send_headers(event.stream_id, headers) + wfile.write(h2_conn.data_to_send()) + wfile.flush() elif isinstance(event, h2.events.StreamEnded): - h2_conn.send_headers(1, [ - (':status', '200'), - ]) - h2_conn.send_data(1, self.tmp_data_buffer_foobar) - h2_conn.end_stream(1) + h2_conn.end_stream(event.stream_id) wfile.write(h2_conn.data_to_send()) wfile.flush() - return True - def test_with_bodies(self): + def test_priority_with_existing_stream(self): client, h2_conn = self._setup_connection() self._send_request( @@ -269,9 +494,14 @@ class TestWithBodies(_Http2TestBase, _Http2ServerBase): (':scheme', 'https'), (':path', '/'), ], - body='foobar with request body', + end_stream=False, ) + h2_conn.prioritize(1, exclusive=True, depends_on=0, weight=42) + h2_conn.end_stream(1) + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + done = False while not done: try: @@ -292,21 +522,112 @@ class TestWithBodies(_Http2TestBase, _Http2ServerBase): client.wfile.write(h2_conn.data_to_send()) client.wfile.flush() - assert self.master.state.flows[0].response.body == b'foobar with request body' + assert len(self.master.state.flows) == 1 + assert self.priority_data == [(True, 0, 42)] @requires_alpn -class TestPushPromise(_Http2TestBase, _Http2ServerBase): +class TestStreamResetFromServer(_Http2Test): @classmethod - def setup_class(self): - _Http2TestBase.setup_class() - _Http2ServerBase.setup_class() + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + h2_conn.reset_stream(event.stream_id, 0x8) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_request_with_priority(self): + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + ) + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamReset): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 1 + assert self.master.state.flows[0].response is None + + +@requires_alpn +class TestBodySizeLimit(_Http2Test): @classmethod - def teardown_class(self): - _Http2TestBase.teardown_class() - _Http2ServerBase.teardown_class() + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + return True + + def test_body_size_limit(self): + self.config.body_size_limit = 20 + + client, h2_conn = self._setup_connection() + + self._send_request( + client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ], + body=b'very long body over 20 characters long', + ) + + done = False + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = 
h2_conn.receive_data(raw) + except HttpException: + print(traceback.format_exc()) + assert False + + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamReset): + done = True + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == 0 + + +@requires_alpn +class TestPushPromise(_Http2Test): @classmethod def handle_server_event(self, event, h2_conn, rfile, wfile): @@ -459,17 +780,7 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase): @requires_alpn -class TestConnectionLost(_Http2TestBase, _Http2ServerBase): - - @classmethod - def setup_class(self): - _Http2TestBase.setup_class() - _Http2ServerBase.setup_class() - - @classmethod - def teardown_class(self): - _Http2TestBase.teardown_class() - _Http2ServerBase.teardown_class() +class TestConnectionLost(_Http2Test): @classmethod def handle_server_event(self, event, h2_conn, rfile, wfile): @@ -508,3 +819,105 @@ class TestConnectionLost(_Http2TestBase, _Http2ServerBase): if len(self.master.state.flows) == 1: assert self.master.state.flows[0].response is None + + +@requires_alpn +class TestMaxConcurrentStreams(_Http2Test): + + @classmethod + def setup_class(self): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class(h2_server_settings={h2.settings.MAX_CONCURRENT_STREAMS: 2}) + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + h2_conn.send_headers(event.stream_id, [ + (':status', '200'), + ('X-Stream-ID', str(event.stream_id)), + ]) + h2_conn.send_data(event.stream_id, 'Stream-ID {}'.format(event.stream_id).encode()) + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_max_concurrent_streams(self): + client, h2_conn = self._setup_connection() + new_streams = [1, 3, 5, 7, 9, 11] + for id in new_streams: + # this will exceed MAX_CONCURRENT_STREAMS on the server connection + # and cause mitmproxy to throttle stream creation to the server + self._send_request(client.wfile, h2_conn, stream_id=id, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('X-Stream-ID', str(id)), + ]) + + ended_streams = 0 + while ended_streams != len(new_streams): + try: + header, body = framereader.http2_read_raw_frame(client.rfile) + events = h2_conn.receive_data(b''.join([header, body])) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + ended_streams += 1 + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == len(new_streams) + for flow in self.master.state.flows: + assert flow.response.status_code == 200 + assert b"Stream-ID " in flow.response.body + + +@requires_alpn +class TestConnectionTerminated(_Http2Test): + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.RequestReceived): + h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data=b'foobar') + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_connection_terminated(self): + client, h2_conn = self._setup_connection() + + 
self._send_request(client.wfile, h2_conn, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ]) + + done = False + connection_terminated_event = None + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + for event in events: + if isinstance(event, h2.events.ConnectionTerminated): + connection_terminated_event = event + done = True + except: + break + + assert len(self.master.state.flows) == 1 + assert connection_terminated_event is not None + assert connection_terminated_event.error_code == 5 + assert connection_terminated_event.last_stream_id == 42 + assert connection_terminated_event.additional_data == b'foobar' diff --git a/test/mitmproxy/test_script.py b/test/mitmproxy/test_script.py deleted file mode 100644 index 81994780..00000000 --- a/test/mitmproxy/test_script.py +++ /dev/null @@ -1,13 +0,0 @@ -from mitmproxy import flow -from . import tutils - - -def test_duplicate_flow(): - s = flow.State() - fm = flow.FlowMaster(None, s) - fm.load_script(tutils.test_data.path("data/scripts/duplicate_flow.py")) - f = tutils.tflow() - fm.request(f) - assert fm.state.flow_count() == 2 - assert not fm.state.view[0].request.is_replay - assert fm.state.view[1].request.is_replay diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py index 432340c0..2e580d47 100644 --- a/test/mitmproxy/test_server.py +++ b/test/mitmproxy/test_server.py @@ -1,6 +1,7 @@ import os import socket import time +import types from OpenSSL import SSL from netlib.exceptions import HttpReadDisconnect, HttpException from netlib.tcp import Address @@ -12,6 +13,7 @@ from netlib.http import authentication, http1 from netlib.tutils import raises from pathod import pathoc, pathod +from mitmproxy.builtins import script from mitmproxy import controller from mitmproxy.proxy.config import HostMatcher from mitmproxy.models import Error, HTTPResponse, HTTPFlow @@ -91,10 +93,10 @@ class CommonMixin: def test_invalid_http(self): t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) t.connect() - t.wfile.write("invalid\r\n\r\n") + t.wfile.write(b"invalid\r\n\r\n") t.wfile.flush() line = t.rfile.readline() - assert ("Bad Request" in line) or ("Bad Gateway" in line) + assert (b"Bad Request" in line) or (b"Bad Gateway" in line) def test_sni(self): if not self.ssl: @@ -190,8 +192,9 @@ class TcpMixin: assert i_cert == i2_cert == n_cert # Make sure that TCP messages are in the event log. - assert any("305" in m for m in self.master.tlog) - assert any("306" in m for m in self.master.tlog) + # Re-enable and fix this when we start keeping TCPFlows in the state. 
+ # assert any("305" in m for m in self.master.tlog) + # assert any("306" in m for m in self.master.tlog) class AppMixin: @@ -199,7 +202,7 @@ class AppMixin: def test_app(self): ret = self.app("/") assert ret.status_code == 200 - assert "mitmproxy" in ret.content + assert b"mitmproxy" in ret.content class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin): @@ -208,14 +211,14 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin): p = self.pathoc() ret = p.request("get:'http://errapp/'") assert ret.status_code == 500 - assert "ValueError" in ret.content + assert b"ValueError" in ret.content def test_invalid_connect(self): t = tcp.TCPClient(("127.0.0.1", self.proxy.port)) t.connect() - t.wfile.write("CONNECT invalid\n\n") + t.wfile.write(b"CONNECT invalid\n\n") t.wfile.flush() - assert "Bad Request" in t.rfile.readline() + assert b"Bad Request" in t.rfile.readline() def test_upstream_ssl_error(self): p = self.pathoc() @@ -285,10 +288,13 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin): self.master.set_stream_large_bodies(None) def test_stream_modify(self): - self.master.load_script(tutils.test_data.path("data/scripts/stream_modify.py")) + s = script.Script( + tutils.test_data.path("data/addonscripts/stream_modify.py") + ) + self.master.addons.add(s) d = self.pathod('200:b"foo"') - assert d.content == "bar" - self.master.unload_scripts() + assert d.content == b"bar" + self.master.addons.remove(s) class TestHTTPAuth(tservers.HTTPProxyTest): @@ -356,7 +362,7 @@ class TestHTTPSUpstreamServerVerificationWTrustedCert(tservers.HTTPProxyTest): """ ssl = True ssloptions = pathod.SSLOptions( - cn="trusted-cert", + cn=b"trusted-cert", certs=[ ("trusted-cert", tutils.test_data.path("data/trusted-server.crt")) ]) @@ -383,7 +389,7 @@ class TestHTTPSUpstreamServerVerificationWBadCert(tservers.HTTPProxyTest): """ ssl = True ssloptions = pathod.SSLOptions( - cn="untrusted-cert", + cn=b"untrusted-cert", certs=[ ("untrusted-cert", tutils.test_data.path("data/untrusted-server.crt")) ]) @@ -423,7 +429,7 @@ class TestHTTPSNoCommonName(tservers.HTTPProxyTest): ssl = True ssloptions = pathod.SSLOptions( certs=[ - ("*", tutils.test_data.path("data/no_common_name.pem")) + (b"*", tutils.test_data.path("data/no_common_name.pem")) ] ) @@ -448,7 +454,7 @@ class TestSocks5(tservers.SocksModeTest): p = self.pathoc() f = p.request("get:/p/200") assert f.status_code == 502 - assert "SOCKS5 mode failure" in f.content + assert b"SOCKS5 mode failure" in f.content def test_no_connect(self): """ @@ -471,7 +477,7 @@ class TestSocks5(tservers.SocksModeTest): p.rfile.read(2) # read server greeting f = p.request("get:/p/200") # the request doesn't matter, error response from handshake will be read anyway. 
assert f.status_code == 502 - assert "SOCKS5 mode failure" in f.content + assert b"SOCKS5 mode failure" in f.content class TestHttps2Http(tservers.ReverseProxyTest): @@ -510,15 +516,15 @@ class TestTransparent(tservers.TransparentProxyTest, CommonMixin, TcpMixin): ssl = False def test_tcp_stream_modify(self): - self.master.load_script(tutils.test_data.path("data/scripts/tcp_stream_modify.py")) - + s = script.Script( + tutils.test_data.path("data/addonscripts/tcp_stream_modify.py") + ) + self.master.addons.add(s) self._tcpproxy_on() d = self.pathod('200:b"foo"') self._tcpproxy_off() - - assert d.content == "bar" - - self.master.unload_scripts() + assert d.content == b"bar" + self.master.addons.remove(s) class TestTransparentSSL(tservers.TransparentProxyTest, CommonMixin, TcpMixin): @@ -561,10 +567,10 @@ class TestProxy(tservers.HTTPProxyTest): # call pathod server, wait a second to complete the request connection.send( - "GET http://localhost:%d/p/304:b@1k HTTP/1.1\r\n" % + b"GET http://localhost:%d/p/304:b@1k HTTP/1.1\r\n" % self.server.port) time.sleep(1) - connection.send("\r\n") + connection.send(b"\r\n") connection.recv(50000) connection.close() @@ -579,17 +585,17 @@ class TestProxy(tservers.HTTPProxyTest): connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection.connect(("localhost", self.proxy.port)) connection.send( - "GET http://localhost:%d/p/200:b@1k HTTP/1.1\r\n" % + b"GET http://localhost:%d/p/200:b@1k HTTP/1.1\r\n" % self.server.port) - connection.send("\r\n") + connection.send(b"\r\n") # a bit hacky: make sure that we don't just read the headers only. recvd = 0 while recvd < 1024: recvd += len(connection.recv(5000)) connection.send( - "GET http://localhost:%d/p/200:b@1k HTTP/1.1\r\n" % + b"GET http://localhost:%d/p/200:b@1k HTTP/1.1\r\n" % self.server.port) - connection.send("\r\n") + connection.send(b"\r\nb") recvd = 0 while recvd < 1024: recvd += len(connection.recv(5000)) @@ -639,7 +645,7 @@ class MasterRedirectRequest(tservers.TestMaster): @controller.handler def response(self, f): - f.response.content = str(f.client_conn.address.port) + f.response.content = bytes(f.client_conn.address.port) f.response.headers["server-conn-id"] = str(f.server_conn.source_address.port) super(MasterRedirectRequest, self).response(f) @@ -721,12 +727,12 @@ class TestStreamRequest(tservers.HTTPProxyTest): def test_stream_chunked(self): connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM) connection.connect(("127.0.0.1", self.proxy.port)) - fconn = connection.makefile() + fconn = connection.makefile("rb") spec = '200:h"Transfer-Encoding"="chunked":r:b"4\\r\\nthis\\r\\n11\\r\\nisatest__reachhex\\r\\n0\\r\\n\\r\\n"' connection.send( - "GET %s/p/%s HTTP/1.1\r\n" % - (self.server.urlbase, spec)) - connection.send("\r\n") + b"GET %s/p/%s HTTP/1.1\r\n" % + (self.server.urlbase.encode(), spec.encode())) + connection.send(b"\r\n") resp = http1.read_response_head(fconn) @@ -734,7 +740,7 @@ class TestStreamRequest(tservers.HTTPProxyTest): assert resp.status_code == 200 chunks = list(http1.read_body(fconn, None)) - assert chunks == ["this", "isatest__reachhex"] + assert chunks == [b"this", b"isatest__reachhex"] connection.close() @@ -833,20 +839,15 @@ class TestUpstreamProxy(tservers.HTTPUpstreamProxyTest, CommonMixin, AppMixin): ssl = False def test_order(self): - self.proxy.tmaster.replacehooks.add( - "~q", - "foo", - "bar") # replace in request - self.chain[0].tmaster.replacehooks.add("~q", "bar", "baz") - self.chain[1].tmaster.replacehooks.add("~q", "foo", "oh noes!") - 
self.chain[0].tmaster.replacehooks.add( - "~s", - "baz", - "ORLY") # replace in response - + self.proxy.tmaster.options.replacements = [ + ("~q", "foo", "bar"), + ("~q", "bar", "baz"), + ("~q", "foo", "oh noes!"), + ("~s", "baz", "ORLY") + ] p = self.pathoc() req = p.request("get:'%s/p/418:b\"foo\"'" % self.server.urlbase) - assert req.content == "ORLY" + assert req.content == b"ORLY" assert req.status_code == 418 @@ -907,7 +908,7 @@ class TestUpstreamProxySSL( def test_simple(self): p = self.pathoc() req = p.request("get:'/p/418:b\"content\"'") - assert req.content == "content" + assert req.content == b"content" assert req.status_code == 418 # CONNECT from pathoc to chain[0], @@ -944,7 +945,7 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest): f.reply.kill() return _func(f) - setattr(master, attr, handler) + setattr(master, attr, types.MethodType(handler, master)) kill_requests( self.chain[1].tmaster, @@ -965,7 +966,7 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxyTest): p = self.pathoc() req = p.request("get:'/p/418:b\"content\"'") - assert req.content == "content" + assert req.content == b"content" assert req.status_code == 418 assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request @@ -1013,9 +1014,9 @@ class AddUpstreamCertsToClientChainMixin: ssl = True servercert = tutils.test_data.path("data/trusted-server.crt") ssloptions = pathod.SSLOptions( - cn="trusted-cert", + cn=b"trusted-cert", certs=[ - ("trusted-cert", servercert) + (b"trusted-cert", servercert) ] ) diff --git a/test/mitmproxy/test_web_master.py b/test/mitmproxy/test_web_master.py index 98f53c93..2ab440ce 100644 --- a/test/mitmproxy/test_web_master.py +++ b/test/mitmproxy/test_web_master.py @@ -3,15 +3,12 @@ from . import mastertest class TestWebMaster(mastertest.MasterTest): - def mkmaster(self, filt, **options): - o = master.Options( - filtstr=filt, - **options - ) + def mkmaster(self, **options): + o = master.Options(**options) return master.WebMaster(None, o) def test_basic(self): - m = self.mkmaster(None) + m = self.mkmaster() for i in (1, 2, 3): - self.dummy_cycle(m, 1, "") + self.dummy_cycle(m, 1, b"") assert len(m.state.flows) == i diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py index 24ebb476..9b830b2d 100644 --- a/test/mitmproxy/tservers.py +++ b/test/mitmproxy/tservers.py @@ -9,7 +9,9 @@ from mitmproxy.proxy.server import ProxyServer import pathod.test import pathod.pathoc from mitmproxy import flow, controller +from mitmproxy.flow import options from mitmproxy.cmdline import APP_HOST, APP_PORT +from mitmproxy import builtins testapp = flask.Flask(__name__) @@ -34,7 +36,8 @@ class TestMaster(flow.FlowMaster): config.port = 0 s = ProxyServer(config) state = flow.State() - flow.FlowMaster.__init__(self, s, state) + flow.FlowMaster.__init__(self, options.Options(), s, state) + self.addons.add(*builtins.default_addons()) self.apps.add(testapp, "testapp", 80) self.apps.add(errapp, "errapp", 80) self.clear_log() @@ -42,7 +45,7 @@ class TestMaster(flow.FlowMaster): def clear_log(self): self.tlog = [] - def add_event(self, message, level=None): + def add_log(self, message, level=None): self.tlog.append(message) @@ -148,7 +151,6 @@ class HTTPProxyTest(ProxyTestBase): Constructs a pathod GET request, with the appropriate base and proxy. 
""" p = self.pathoc(sni=sni) - spec = spec.encode("string_escape") if self.ssl: q = "get:'/p/%s'" % spec else: diff --git a/test/mitmproxy/tutils.py b/test/mitmproxy/tutils.py index d0a09035..d743a9e6 100644 --- a/test/mitmproxy/tutils.py +++ b/test/mitmproxy/tutils.py @@ -4,18 +4,19 @@ import tempfile import argparse import sys -from mitmproxy.models.tcp import TCPMessage -from six.moves import cStringIO as StringIO from contextlib import contextmanager - from unittest.case import SkipTest +from six.moves import cStringIO as StringIO + import netlib.utils import netlib.tutils from mitmproxy import controller from mitmproxy.models import ( ClientConnection, ServerConnection, Error, HTTPRequest, HTTPResponse, HTTPFlow, TCPFlow ) +from mitmproxy.models.tcp import TCPMessage +from mitmproxy.models.flow import Flow def _skip_windows(*args): @@ -47,6 +48,27 @@ def skip_appveyor(fn): return fn +class DummyFlow(Flow): + """A flow that is neither HTTP nor TCP.""" + + def __init__(self, client_conn, server_conn, live=None): + super(DummyFlow, self).__init__("dummy", client_conn, server_conn, live) + + +def tdummyflow(client_conn=True, server_conn=True, err=None): + if client_conn is True: + client_conn = tclient_conn() + if server_conn is True: + server_conn = tserver_conn() + if err is True: + err = terr() + + f = DummyFlow(client_conn, server_conn) + f.error = err + f.reply = controller.DummyReply() + return f + + def ttcpflow(client_conn=True, server_conn=True, messages=True, err=None): if client_conn is True: client_conn = tclient_conn() diff --git a/test/netlib/http/http1/test_assemble.py b/test/netlib/http/http1/test_assemble.py index 50d29384..841ea58a 100644 --- a/test/netlib/http/http1/test_assemble.py +++ b/test/netlib/http/http1/test_assemble.py @@ -24,7 +24,7 @@ def test_assemble_request(): def test_assemble_request_head(): - c = assemble_request_head(treq(content="foo")) + c = assemble_request_head(treq(content=b"foo")) assert b"GET" in c assert b"qvalue" in c assert b"content-length" in c diff --git a/test/netlib/http/http1/test_read.py b/test/netlib/http/http1/test_read.py index 5285ac1d..c8a40ecb 100644 --- a/test/netlib/http/http1/test_read.py +++ b/test/netlib/http/http1/test_read.py @@ -1,6 +1,9 @@ from __future__ import absolute_import, print_function, division + from io import BytesIO from mock import Mock +import pytest + from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect, TcpDisconnect from netlib.http import Headers from netlib.http.http1.read import ( @@ -23,11 +26,18 @@ def test_get_header_tokens(): assert get_header_tokens(headers, "foo") == ["bar", "voing", "oink"] -def test_read_request(): - rfile = BytesIO(b"GET / HTTP/1.1\r\n\r\nskip") +@pytest.mark.parametrize("input", [ + b"GET / HTTP/1.1\r\n\r\nskip", + b"GET / HTTP/1.1\r\n\r\nskip", + b"GET / HTTP/1.1\r\n\r\nskip", + b"GET / HTTP/1.1 \r\n\r\nskip", +]) +def test_read_request(input): + rfile = BytesIO(input) r = read_request(rfile) assert r.method == "GET" assert r.content == b"" + assert r.http_version == "HTTP/1.1" assert r.timestamp_end assert rfile.read() == b"skip" @@ -50,11 +60,19 @@ def test_read_request_head(): assert rfile.read() == b"skip" -def test_read_response(): +@pytest.mark.parametrize("input", [ + b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody", + b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody", + b"HTTP/1.1 418 I'm a teapot\r\n\r\nbody", + b"HTTP/1.1 418 I'm a teapot \r\n\r\nbody", +]) +def test_read_response(input): req = treq() - rfile = BytesIO(b"HTTP/1.1 418 I'm a 
teapot\r\n\r\nbody") + rfile = BytesIO(input) r = read_response(rfile, req) + assert r.http_version == "HTTP/1.1" assert r.status_code == 418 + assert r.reason == "I'm a teapot" assert r.content == b"body" assert r.timestamp_end diff --git a/test/netlib/http/test_cookies.py b/test/netlib/http/test_cookies.py index 83b85656..17e21b94 100644 --- a/test/netlib/http/test_cookies.py +++ b/test/netlib/http/test_cookies.py @@ -245,3 +245,24 @@ def test_refresh_cookie(): assert cookies.refresh_set_cookie_header(c, 0) c = "foo/bar=bla" assert cookies.refresh_set_cookie_header(c, 0) + + +def test_is_expired(): + CA = cookies.CookieAttrs + + # A cookie can be expired + # by setting the expire time in the past + assert cookies.is_expired(CA([("Expires", "Thu, 01-Jan-1970 00:00:00 GMT")])) + + # or by setting Max-Age to 0 + assert cookies.is_expired(CA([("Max-Age", "0")])) + + # or both + assert cookies.is_expired(CA([("Expires", "Thu, 01-Jan-1970 00:00:00 GMT"), ("Max-Age", "0")])) + + assert not cookies.is_expired(CA([("Expires", "Thu, 24-Aug-2063 00:00:00 GMT")])) + assert not cookies.is_expired(CA([("Max-Age", "1")])) + assert not cookies.is_expired(CA([("Expires", "Thu, 15-Jul-2068 00:00:00 GMT"), ("Max-Age", "1")])) + + assert not cookies.is_expired(CA([("Max-Age", "nan")])) + assert not cookies.is_expired(CA([("Expires", "false")])) diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index 51819b86..51537310 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -1,4 +1,6 @@ -from netlib.http import Headers, parse_content_type +import collections + +from netlib.http.headers import Headers, parse_content_type, assemble_content_type from netlib.tutils import raises @@ -81,3 +83,10 @@ def test_parse_content_type(): v = p("text/html; charset=UTF-8") assert v == ('text', 'html', {'charset': 'UTF-8'}) + + +def test_assemble_content_type(): + p = assemble_content_type + assert p("text", "html", {}) == "text/html" + assert p("text", "html", {"charset": "utf8"}) == "text/html; charset=utf8" + assert p("text", "html", collections.OrderedDict([("charset", "utf8"), ("foo", "bar")])) == "text/html; charset=utf8; foo=bar" diff --git a/test/netlib/http/test_message.py b/test/netlib/http/test_message.py index f5bf7f0c..deebd6f2 100644 --- a/test/netlib/http/test_message.py +++ b/test/netlib/http/test_message.py @@ -1,14 +1,17 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function, division -from netlib.http import decoded +import mock +import six + from netlib.tutils import tresp +from netlib import http, tutils def _test_passthrough_attr(message, attr): assert getattr(message, attr) == getattr(message.data, attr) - setattr(message, attr, "foo") - assert getattr(message.data, attr) == "foo" + setattr(message, attr, b"foo") + assert getattr(message.data, attr) == b"foo" def _test_decoded_attr(message, attr): @@ -68,6 +71,15 @@ class TestMessage(object): assert resp != 0 + def test_hash(self): + resp = tresp() + assert hash(resp) + + def test_serializable(self): + resp = tresp() + resp2 = http.Response.from_state(resp.get_state()) + assert resp == resp2 + def test_content_length_update(self): resp = tresp() resp.content = b"foo" @@ -76,9 +88,9 @@ class TestMessage(object): resp.content = b"" assert resp.data.content == b"" assert resp.headers["content-length"] == "0" - - def test_content_basic(self): - _test_passthrough_attr(tresp(), "content") + resp.raw_content = b"bar" + assert resp.data.content == b"bar" + assert 
resp.headers["content-length"] == "0" def test_headers(self): _test_passthrough_attr(tresp(), "headers") @@ -89,65 +101,201 @@ class TestMessage(object): def test_timestamp_end(self): _test_passthrough_attr(tresp(), "timestamp_end") - def teste_http_version(self): + def test_http_version(self): _test_decoded_attr(tresp(), "http_version") -class TestDecodedDecorator(object): - +class TestMessageContentEncoding(object): def test_simple(self): r = tresp() - assert r.content == b"message" + assert r.raw_content == b"message" assert "content-encoding" not in r.headers - assert r.encode("gzip") + r.encode("gzip") assert r.headers["content-encoding"] - assert r.content != b"message" - with decoded(r): - assert "content-encoding" not in r.headers - assert r.content == b"message" - assert r.headers["content-encoding"] - assert r.content != b"message" + assert r.raw_content != b"message" + assert r.content == b"message" + assert r.raw_content != b"message" + + r.raw_content = b"foo" + with mock.patch("netlib.encoding.decode") as e: + assert r.content + assert e.call_count == 1 + e.reset_mock() + assert r.content + assert e.call_count == 0 def test_modify(self): r = tresp() assert "content-encoding" not in r.headers - assert r.encode("gzip") + r.encode("gzip") - with decoded(r): + r.content = b"foo" + assert r.raw_content != b"foo" + r.decode() + assert r.raw_content == b"foo" + + r.encode("identity") + with mock.patch("netlib.encoding.encode") as e: r.content = b"foo" + assert e.call_count == 0 + r.content = b"bar" + assert e.call_count == 1 - assert r.content != b"foo" - r.decode() - assert r.content == b"foo" + with tutils.raises(TypeError): + r.content = u"foo" def test_unknown_ce(self): r = tresp() r.headers["content-encoding"] = "zopfli" - r.content = b"foo" - with decoded(r): - assert r.headers["content-encoding"] - assert r.content == b"foo" + r.raw_content = b"foo" + with tutils.raises(ValueError): + assert r.content assert r.headers["content-encoding"] - assert r.content == b"foo" + assert r.get_content(strict=False) == b"foo" def test_cannot_decode(self): r = tresp() - assert r.encode("gzip") - r.content = b"foo" - with decoded(r): - assert r.headers["content-encoding"] - assert r.content == b"foo" + r.encode("gzip") + r.raw_content = b"foo" + with tutils.raises(ValueError): + assert r.content assert r.headers["content-encoding"] - assert r.content != b"foo" - r.decode() + assert r.get_content(strict=False) == b"foo" + + with tutils.raises(ValueError): + r.decode() + assert r.raw_content == b"foo" + assert "content-encoding" in r.headers + + r.decode(strict=False) assert r.content == b"foo" + assert "content-encoding" not in r.headers + + def test_none(self): + r = tresp(content=None) + assert r.content is None + r.content = b"foo" + assert r.content is not None + r.content = None + assert r.content is None def test_cannot_encode(self): r = tresp() - assert r.encode("gzip") - with decoded(r): - r.content = None + r.encode("gzip") + r.content = None + assert r.headers["content-encoding"] + assert r.raw_content is None + r.headers["content-encoding"] = "zopfli" + r.content = b"foo" assert "content-encoding" not in r.headers - assert r.content is None + assert r.raw_content == b"foo" + + with tutils.raises(ValueError): + r.encode("zopfli") + assert r.raw_content == b"foo" + assert "content-encoding" not in r.headers + + +class TestMessageText(object): + def test_simple(self): + r = tresp(content=b'\xfc') + assert r.raw_content == b"\xfc" + assert r.content == b"\xfc" + assert r.text == u"ü" + 
+ r.encode("gzip") + assert r.text == u"ü" + r.decode() + assert r.text == u"ü" + + r.headers["content-type"] = "text/html; charset=latin1" + r.content = b"\xc3\xbc" + assert r.text == u"ü" + r.headers["content-type"] = "text/html; charset=utf8" + assert r.text == u"ü" + + r.encode("identity") + r.raw_content = b"foo" + with mock.patch("netlib.encoding.decode") as e: + assert r.text + assert e.call_count == 2 + e.reset_mock() + assert r.text + assert e.call_count == 0 + + def test_guess_json(self): + r = tresp(content=b'"\xc3\xbc"') + r.headers["content-type"] = "application/json" + assert r.text == u'"ü"' + + def test_none(self): + r = tresp(content=None) + assert r.text is None + r.text = u"foo" + assert r.text is not None + r.text = None + assert r.text is None + + def test_modify(self): + r = tresp() + + r.text = u"ü" + assert r.raw_content == b"\xfc" + + r.headers["content-type"] = "text/html; charset=utf8" + r.text = u"ü" + assert r.raw_content == b"\xc3\xbc" + assert r.headers["content-length"] == "2" + + r.encode("identity") + with mock.patch("netlib.encoding.encode") as e: + e.return_value = b"" + r.text = u"ü" + assert e.call_count == 0 + r.text = u"ä" + assert e.call_count == 2 + + def test_unknown_ce(self): + r = tresp() + r.headers["content-type"] = "text/html; charset=wtf" + r.raw_content = b"foo" + with tutils.raises(ValueError): + assert r.text == u"foo" + assert r.get_text(strict=False) == u"foo" + + def test_cannot_decode(self): + r = tresp() + r.headers["content-type"] = "text/html; charset=utf8" + r.raw_content = b"\xFF" + with tutils.raises(ValueError): + assert r.text + + assert r.get_text(strict=False) == u'\ufffd' if six.PY2 else '\udcff' + + def test_cannot_encode(self): + r = tresp() + r.content = None + assert "content-type" not in r.headers + assert r.raw_content is None + + r.headers["content-type"] = "text/html; charset=latin1; foo=bar" + r.text = u"☃" + assert r.headers["content-type"] == "text/html; charset=utf-8; foo=bar" + assert r.raw_content == b'\xe2\x98\x83' + + r.headers["content-type"] = "gibberish" + r.text = u"☃" + assert r.headers["content-type"] == "text/plain; charset=utf-8" + assert r.raw_content == b'\xe2\x98\x83' + + del r.headers["content-type"] + r.text = u"☃" + assert r.headers["content-type"] == "text/plain; charset=utf-8" + assert r.raw_content == b'\xe2\x98\x83' + + r.headers["content-type"] = "text/html; charset=latin1" + r.text = u'\udcff' + assert r.headers["content-type"] == "text/html; charset=utf-8" + assert r.raw_content == b'\xed\xb3\xbf' if six.PY2 else b"\xFF" diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py index c03db339..f3cd8b71 100644 --- a/test/netlib/http/test_request.py +++ b/test/netlib/http/test_request.py @@ -248,20 +248,20 @@ class TestRequestUtils(object): assert "gzip" in request.headers["Accept-Encoding"] def test_get_urlencoded_form(self): - request = treq(content="foobar=baz") + request = treq(content=b"foobar=baz") assert not request.urlencoded_form request.headers["Content-Type"] = "application/x-www-form-urlencoded" - assert list(request.urlencoded_form.items()) == [("foobar", "baz")] + assert list(request.urlencoded_form.items()) == [(b"foobar", b"baz")] def test_set_urlencoded_form(self): request = treq() - request.urlencoded_form = [('foo', 'bar'), ('rab', 'oof')] + request.urlencoded_form = [(b'foo', b'bar'), (b'rab', b'oof')] assert request.headers["Content-Type"] == "application/x-www-form-urlencoded" assert request.content def test_get_multipart_form(self): - request = 
treq(content="foobar") + request = treq(content=b"foobar") assert not request.multipart_form request.headers["Content-Type"] = "multipart/form-data" diff --git a/test/netlib/test_encoding.py b/test/netlib/test_encoding.py index 0ff1aad1..de10fc48 100644 --- a/test/netlib/test_encoding.py +++ b/test/netlib/test_encoding.py @@ -1,37 +1,39 @@ -from netlib import encoding +from netlib import encoding, tutils def test_identity(): - assert b"string" == encoding.decode("identity", b"string") - assert b"string" == encoding.encode("identity", b"string") - assert not encoding.encode("nonexistent", b"string") - assert not encoding.decode("nonexistent encoding", b"string") + assert b"string" == encoding.decode(b"string", "identity") + assert b"string" == encoding.encode(b"string", "identity") + with tutils.raises(ValueError): + encoding.encode(b"string", "nonexistent encoding") def test_gzip(): assert b"string" == encoding.decode( - "gzip", encoding.encode( - "gzip", - b"string" - ) + b"string", + "gzip" + ), + "gzip" ) - assert encoding.decode("gzip", b"bogus") is None + with tutils.raises(ValueError): + encoding.decode(b"bogus", "gzip") def test_deflate(): assert b"string" == encoding.decode( - "deflate", encoding.encode( - "deflate", - b"string" - ) + b"string", + "deflate" + ), + "deflate" ) assert b"string" == encoding.decode( - "deflate", encoding.encode( - "deflate", - b"string" - )[2:-4] + b"string", + "deflate" + )[2:-4], + "deflate" ) - assert encoding.decode("deflate", b"bogus") is None + with tutils.raises(ValueError): + encoding.decode(b"bogus", "deflate") diff --git a/test/netlib/test_strutils.py b/test/netlib/test_strutils.py index 84a0dded..7c3eacc6 100644 --- a/test/netlib/test_strutils.py +++ b/test/netlib/test_strutils.py @@ -1,9 +1,15 @@ -# coding=utf-8 import six from netlib import strutils, tutils +def test_always_bytes(): + assert strutils.always_bytes(bytes(bytearray(range(256)))) == bytes(bytearray(range(256))) + assert strutils.always_bytes("foo") == b"foo" + with tutils.raises(ValueError): + strutils.always_bytes(u"\u2605", "ascii") + + def test_native(): with tutils.raises(TypeError): strutils.native(42) @@ -15,22 +21,26 @@ def test_native(): assert strutils.native(b"foo") == u"foo" -def test_clean_bin(): - assert strutils.clean_bin(b"one") == b"one" - assert strutils.clean_bin(b"\00ne") == b".ne" - assert strutils.clean_bin(b"\nne") == b"\nne" - assert strutils.clean_bin(b"\nne", False) == b".ne" - assert strutils.clean_bin(u"\u2605".encode("utf8")) == b"..." - - assert strutils.clean_bin(u"one") == u"one" - assert strutils.clean_bin(u"\00ne") == u".ne" - assert strutils.clean_bin(u"\nne") == u"\nne" - assert strutils.clean_bin(u"\nne", False) == u".ne" - assert strutils.clean_bin(u"\u2605") == u"\u2605" - - -def test_safe_subn(): - assert strutils.safe_subn("foo", u"bar", "\xc2foo") +def test_escape_control_characters(): + assert strutils.escape_control_characters(u"one") == u"one" + assert strutils.escape_control_characters(u"\00ne") == u".ne" + assert strutils.escape_control_characters(u"\nne") == u"\nne" + assert strutils.escape_control_characters(u"\nne", False) == u".ne" + assert strutils.escape_control_characters(u"\u2605") == u"\u2605" + assert ( + strutils.escape_control_characters(bytes(bytearray(range(128))).decode()) == + u'.........\t\n..\r.................. !"#$%&\'()*+,-./0123456789:;<' + u'=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~.' 
+ ) + assert ( + strutils.escape_control_characters(bytes(bytearray(range(128))).decode(), False) == + u'................................ !"#$%&\'()*+,-./0123456789:;<' + u'=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~.' + ) + + if not six.PY2: + with tutils.raises(ValueError): + strutils.escape_control_characters(b"foo") def test_bytes_to_escaped_str(): @@ -41,6 +51,14 @@ def test_bytes_to_escaped_str(): assert strutils.bytes_to_escaped_str(b"'") == r"\'" assert strutils.bytes_to_escaped_str(b'"') == r'"' + assert strutils.bytes_to_escaped_str(b"\r\n\t") == "\\r\\n\\t" + assert strutils.bytes_to_escaped_str(b"\r\n\t", True) == "\r\n\t" + + assert strutils.bytes_to_escaped_str(b"\n", True) == "\n" + assert strutils.bytes_to_escaped_str(b"\\n", True) == "\\ \\ n".replace(" ", "") + assert strutils.bytes_to_escaped_str(b"\\\n", True) == "\\ \\ \n".replace(" ", "") + assert strutils.bytes_to_escaped_str(b"\\\\n", True) == "\\ \\ \\ \\ n".replace(" ", "") + with tutils.raises(ValueError): strutils.bytes_to_escaped_str(u"such unicode") @@ -49,10 +67,9 @@ def test_escaped_str_to_bytes(): assert strutils.escaped_str_to_bytes("foo") == b"foo" assert strutils.escaped_str_to_bytes("\x08") == b"\b" assert strutils.escaped_str_to_bytes("&!?=\\\\)") == br"&!?=\)" - assert strutils.escaped_str_to_bytes("ü") == b'\xc3\xbc' assert strutils.escaped_str_to_bytes(u"\\x08") == b"\b" assert strutils.escaped_str_to_bytes(u"&!?=\\\\)") == br"&!?=\)" - assert strutils.escaped_str_to_bytes(u"ü") == b'\xc3\xbc' + assert strutils.escaped_str_to_bytes(u"\u00fc") == b'\xc3\xbc' if six.PY2: with tutils.raises(ValueError): @@ -62,17 +79,15 @@ def test_escaped_str_to_bytes(): strutils.escaped_str_to_bytes(b"very byte") -def test_isBin(): - assert not strutils.isBin("testing\n\r") - assert strutils.isBin("testing\x01") - assert strutils.isBin("testing\x0e") - assert strutils.isBin("testing\x7f") +def test_is_mostly_bin(): + assert not strutils.is_mostly_bin(b"foo\xFF") + assert strutils.is_mostly_bin(b"foo" + b"\xFF" * 10) -def test_isXml(): - assert not strutils.isXML("foo") - assert strutils.isXML("<foo") - assert strutils.isXML(" \n<foo") +def test_is_xml(): + assert not strutils.is_xml(b"foo") + assert strutils.is_xml(b"<foo") + assert strutils.is_xml(b" \n<foo") def test_clean_hanging_newline(): diff --git a/test/netlib/test_tcp.py b/test/netlib/test_tcp.py index 590bcc01..273427d5 100644 --- a/test/netlib/test_tcp.py +++ b/test/netlib/test_tcp.py @@ -169,7 +169,7 @@ class TestServerSSL(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com", options=SSL.OP_ALL) + c.convert_to_ssl(sni="foo.com", options=SSL.OP_ALL) testval = b"echo!\n" c.wfile.write(testval) c.wfile.flush() @@ -179,7 +179,7 @@ class TestServerSSL(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): assert not c.get_current_cipher() - c.convert_to_ssl(sni=b"foo.com") + c.convert_to_ssl(sni="foo.com") ret = c.get_current_cipher() assert ret assert "AES" in ret[0] @@ -195,7 +195,7 @@ class TestSSLv3Only(tservers.ServerTestBase): def test_failure(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - tutils.raises(TlsException, c.convert_to_ssl, sni=b"foo.com") + tutils.raises(TlsException, c.convert_to_ssl, sni="foo.com") class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): @@ -238,7 +238,7 @@ class TestSSLUpstreamCertVerificationWBadServerCert(tservers.ServerTestBase): 
with c.connect(): with tutils.raises(InvalidCertificateException): c.convert_to_ssl( - sni=b"example.mitmproxy.org", + sni="example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") ) @@ -272,7 +272,7 @@ class TestSSLUpstreamCertVerificationWBadHostname(tservers.ServerTestBase): with c.connect(): with tutils.raises(InvalidCertificateException): c.convert_to_ssl( - sni=b"mitmproxy.org", + sni="mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") ) @@ -291,7 +291,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): c.convert_to_ssl( - sni=b"example.mitmproxy.org", + sni="example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_pemfile=tutils.test_data.path("data/verificationcerts/trusted-root.crt") ) @@ -307,7 +307,7 @@ class TestSSLUpstreamCertVerificationWValidCertChain(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): c.convert_to_ssl( - sni=b"example.mitmproxy.org", + sni="example.mitmproxy.org", verify_options=SSL.VERIFY_PEER, ca_path=tutils.test_data.path("data/verificationcerts/") ) @@ -371,8 +371,8 @@ class TestSNI(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com") - assert c.sni == b"foo.com" + c.convert_to_ssl(sni="foo.com") + assert c.sni == "foo.com" assert c.rfile.readline() == b"foo.com" @@ -385,7 +385,7 @@ class TestServerCipherList(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com") + c.convert_to_ssl(sni="foo.com") assert c.rfile.readline() == b"['RC4-SHA']" @@ -405,7 +405,7 @@ class TestServerCurrentCipher(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - c.convert_to_ssl(sni=b"foo.com") + c.convert_to_ssl(sni="foo.com") assert b"RC4-SHA" in c.rfile.readline() @@ -418,7 +418,7 @@ class TestServerCipherListError(tservers.ServerTestBase): def test_echo(self): c = tcp.TCPClient(("127.0.0.1", self.port)) with c.connect(): - tutils.raises("handshake error", c.convert_to_ssl, sni=b"foo.com") + tutils.raises("handshake error", c.convert_to_ssl, sni="foo.com") class TestClientCipherListError(tservers.ServerTestBase): @@ -433,7 +433,7 @@ class TestClientCipherListError(tservers.ServerTestBase): tutils.raises( "cipher specification", c.convert_to_ssl, - sni=b"foo.com", + sni="foo.com", cipher_list="bogus" ) diff --git a/test/netlib/tservers.py b/test/netlib/tservers.py index 803aaa72..666f97ac 100644 --- a/test/netlib/tservers.py +++ b/test/netlib/tservers.py @@ -24,7 +24,7 @@ class _ServerThread(threading.Thread): class _TServer(tcp.TCPServer): - def __init__(self, ssl, q, handler_klass, addr): + def __init__(self, ssl, q, handler_klass, addr, **kwargs): """ ssl: A dictionary of SSL parameters: @@ -42,6 +42,8 @@ class _TServer(tcp.TCPServer): self.q = q self.handler_klass = handler_klass + if self.handler_klass is not None: + self.handler_klass.kwargs = kwargs self.last_handler = None def handle_client_connection(self, request, client_address): @@ -89,16 +91,16 @@ class ServerTestBase(object): addr = ("localhost", 0) @classmethod - def setup_class(cls): + def setup_class(cls, **kwargs): cls.q = queue.Queue() - s = cls.makeserver() + s = cls.makeserver(**kwargs) 
         cls.port = s.address.port
         cls.server = _ServerThread(s)
         cls.server.start()
 
     @classmethod
-    def makeserver(cls):
-        return _TServer(cls.ssl, cls.q, cls.handler, cls.addr)
+    def makeserver(cls, **kwargs):
+        return _TServer(cls.ssl, cls.q, cls.handler, cls.addr, **kwargs)
 
     @classmethod
     def teardown_class(cls):
diff --git a/test/pathod/test_pathoc.py b/test/pathod/test_pathoc.py
index 28f9f0f8..361a863b 100644
--- a/test/pathod/test_pathoc.py
+++ b/test/pathod/test_pathoc.py
@@ -54,10 +54,10 @@ class TestDaemonSSL(PathocTestDaemon):
     def test_sni(self):
         self.tval(
             ["get:/p/200"],
-            sni=b"foobar.com"
+            sni="foobar.com"
         )
         log = self.d.log()
-        assert log[0]["request"]["sni"] == b"foobar.com"
+        assert log[0]["request"]["sni"] == "foobar.com"
 
     def test_showssl(self):
         assert "certificate chain" in self.tval(["get:/p/200"], showssl=True)
diff --git a/test/pathod/test_protocols_http2.py b/test/pathod/test_protocols_http2.py
index e42c2858..8d7efc82 100644
--- a/test/pathod/test_protocols_http2.py
+++ b/test/pathod/test_protocols_http2.py
@@ -367,37 +367,6 @@ class TestReadRequestAbsolute(netlib_tservers.ServerTestBase):
             assert req.port == 22
 
 
-class TestReadRequestConnect(netlib_tservers.ServerTestBase):
-    class handler(tcp.BaseHandler):
-        def handle(self):
-            self.wfile.write(
-                codecs.decode('00001b0105000000014287bdab4e9c17b7ff44871c92585422e08541871c92585422e085', 'hex_codec'))
-            self.wfile.write(
-                codecs.decode('00001d0105000000014287bdab4e9c17b7ff44882f91d35d055c87a741882f91d35d055c87a7', 'hex_codec'))
-            self.wfile.flush()
-
-    ssl = True
-
-    def test_connect(self):
-        c = tcp.TCPClient(("127.0.0.1", self.port))
-        with c.connect():
-            c.convert_to_ssl()
-            protocol = HTTP2StateProtocol(c, is_server=True)
-            protocol.connection_preface_performed = True
-
-            req = protocol.read_request(NotImplemented)
-            assert req.first_line_format == "authority"
-            assert req.method == "CONNECT"
-            assert req.host == "address"
-            assert req.port == 22
-
-            req = protocol.read_request(NotImplemented)
-            assert req.first_line_format == "authority"
-            assert req.method == "CONNECT"
-            assert req.host == "example.com"
-            assert req.port == 443
-
-
 class TestReadResponse(netlib_tservers.ServerTestBase):
     class handler(tcp.BaseHandler):
         def handle(self):
diff --git a/tox.ini b/tox.ini
--- a/tox.ini
+++ b/tox.ini
@@ -1,27 +1,20 @@
 [tox]
 envlist = py27, py35, docs, lint
+skipsdist = True
 
 [testenv]
-usedevelop=True
 deps =
   {env:CI_DEPS:}
   -rrequirements.txt
 passenv = CODECOV_TOKEN CI CI_* TRAVIS TRAVIS_* APPVEYOR APPVEYOR_*
-setenv =
-    TESTS = test/
-    HOME = {envtmpdir}
+setenv = HOME = {envtmpdir}
 commands =
-    py.test --timeout 60 {posargs} {env:TESTS}
+    py.test --timeout 60 {posargs} {env:CI_COMMANDS:python -c ""}
 
-[testenv:py35]
-setenv =
-    TESTS = test/netlib test/pathod/ test/mitmproxy/script test/mitmproxy/test_contentview.py test/mitmproxy/test_custom_contentview.py test/mitmproxy/test_app.py test/mitmproxy/test_controller.py test/mitmproxy/test_fuzzing.py test/mitmproxy/test_script.py test/mitmproxy/test_web_app.py test/mitmproxy/test_utils.py test/mitmproxy/test_stateobject.py test/mitmproxy/test_cmdline.py test/mitmproxy/test_contrib_tnetstring.py test/mitmproxy/test_proxy.py test/mitmproxy/test_protocol_http1.py test/mitmproxy/test_platform_pf.py
-    HOME = {envtmpdir}
-
 [testenv:docs]
 changedir = docs
-commands = sphinx-build -W -b html -d {envtmpdir}/doctrees . {envtmpdir}/html
 
 [testenv:lint]
 deps = flake8>=2.6.2, <3
diff --git a/web/src/js/components/Header/FlowMenu.jsx b/web/src/js/components/Header/FlowMenu.jsx
index 8d13dd6a..9855cde3 100644
--- a/web/src/js/components/Header/FlowMenu.jsx
+++ b/web/src/js/components/Header/FlowMenu.jsx
@@ -15,11 +15,11 @@ function FlowMenu({ flow, acceptFlow, replayFlow, duplicateFlow, removeFlow, rev
     return (
         <div>
             <div className="menu-row">
-                <Button disabled={!flow.intercepted} title="[a]ccept intercepted flow" text="Accept" icon="fa-play" onClick={() => acceptFlow(flow)} />
+                <Button disabled={!flow || !flow.intercepted} title="[a]ccept intercepted flow" text="Accept" icon="fa-play" onClick={() => acceptFlow(flow)} />
                 <Button title="[r]eplay flow" text="Replay" icon="fa-repeat" onClick={() => replayFlow(flow)} />
                 <Button title="[D]uplicate flow" text="Duplicate" icon="fa-copy" onClick={() => duplicateFlow(flow)} />
                 <Button title="[d]elete flow" text="Delete" icon="fa-trash" onClick={() => removeFlow(flow)}/>
-                <Button disabled={!flow.modified} title="revert changes to flow [V]" text="Revert" icon="fa-history" onClick={() => revertFlow(flow)} />
+                <Button disabled={!flow || !flow.modified} title="revert changes to flow [V]" text="Revert" icon="fa-history" onClick={() => revertFlow(flow)} />
                 <Button title="download" text="Download" icon="fa-download" onClick={() => window.location = MessageUtils.getContentURL(flow, flow.response)}/>
             </div>
             <div className="clearfix"/>
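
Quick illustration (not part of the diff above): the netlib changes in this changeset pass SNI values as native str rather than bytes and tighten the strutils escaping round trip. A minimal usage sketch follows, assuming the netlib package from this revision is importable; the endpoint used for the TLS example is hypothetical, and the assertion values are taken directly from the test assertions above.

# Sketch only -- assumes netlib from this revision; values come from the tests in this diff.
from netlib import strutils, tcp

# Escaping round trip exercised by test_bytes_to_escaped_str / test_escaped_str_to_bytes:
assert strutils.bytes_to_escaped_str(b"\r\n\t") == "\\r\\n\\t"
assert strutils.escaped_str_to_bytes("\\x08") == b"\b"
assert strutils.escaped_str_to_bytes(u"\u00fc") == b'\xc3\xbc'

# SNI is now passed and stored as a native str, not bytes:
c = tcp.TCPClient(("example.mitmproxy.org", 443))  # hypothetical endpoint
with c.connect():
    c.convert_to_ssl(sni="example.mitmproxy.org")
    assert c.sni == "example.mitmproxy.org"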