aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.travis.yml8
-rw-r--r--examples/filt.py20
-rw-r--r--examples/flowwriter.py25
-rw-r--r--examples/iframe_injector.py37
-rw-r--r--examples/remote_debug.py19
-rw-r--r--examples/stub.py2
-rw-r--r--mitmproxy/addons.py15
-rw-r--r--mitmproxy/builtins/anticache.py2
-rw-r--r--mitmproxy/builtins/anticomp.py2
-rw-r--r--mitmproxy/builtins/dumper.py79
-rw-r--r--mitmproxy/builtins/filestreamer.py2
-rw-r--r--mitmproxy/builtins/replace.py2
-rw-r--r--mitmproxy/builtins/script.py95
-rw-r--r--mitmproxy/builtins/setheaders.py2
-rw-r--r--mitmproxy/builtins/stickyauth.py2
-rw-r--r--mitmproxy/builtins/stickycookie.py2
-rw-r--r--mitmproxy/console/common.py23
-rw-r--r--mitmproxy/console/flowlist.py32
-rw-r--r--mitmproxy/console/flowview.py303
-rw-r--r--mitmproxy/console/help.py4
-rw-r--r--mitmproxy/console/master.py119
-rw-r--r--mitmproxy/console/options.py2
-rw-r--r--mitmproxy/console/searchable.py4
-rw-r--r--mitmproxy/console/statusbar.py6
-rw-r--r--mitmproxy/console/tabs.py2
-rw-r--r--mitmproxy/contentviews.py17
-rw-r--r--mitmproxy/controller.py2
-rw-r--r--mitmproxy/ctx.py2
-rw-r--r--mitmproxy/dump.py4
-rw-r--r--mitmproxy/filt.py16
-rw-r--r--mitmproxy/flow/io_compat.py1
-rw-r--r--mitmproxy/models/flow.py24
-rw-r--r--mitmproxy/models/http.py19
-rw-r--r--mitmproxy/models/tcp.py21
-rw-r--r--mitmproxy/optmanager.py18
-rw-r--r--mitmproxy/protocol/http2.py2
-rw-r--r--mitmproxy/proxy/config.py4
-rw-r--r--mitmproxy/web/app.py3
-rw-r--r--mitmproxy/web/master.py2
-rw-r--r--netlib/encoding.py49
-rw-r--r--netlib/http/message.py121
-rw-r--r--netlib/http/request.py26
-rw-r--r--netlib/http/url.py41
-rw-r--r--netlib/multidict.py6
-rw-r--r--netlib/strutils.py6
-rw-r--r--test/mitmproxy/builtins/test_anticache.py5
-rw-r--r--test/mitmproxy/builtins/test_anticomp.py5
-rw-r--r--test/mitmproxy/builtins/test_dumper.py29
-rw-r--r--test/mitmproxy/builtins/test_filestreamer.py13
-rw-r--r--test/mitmproxy/builtins/test_replace.py26
-rw-r--r--test/mitmproxy/builtins/test_script.py49
-rw-r--r--test/mitmproxy/builtins/test_setheaders.py13
-rw-r--r--test/mitmproxy/builtins/test_stickyauth.py5
-rw-r--r--test/mitmproxy/builtins/test_stickycookie.py11
-rw-r--r--test/mitmproxy/data/addonscripts/addon.py22
-rw-r--r--test/mitmproxy/data/addonscripts/recorder.py36
-rw-r--r--test/mitmproxy/data/dumpfile-011bin5465 -> 5046 bytes
-rw-r--r--test/mitmproxy/script/test_concurrent.py2
-rw-r--r--test/mitmproxy/test_addons.py5
-rw-r--r--test/mitmproxy/test_contentview.py4
-rw-r--r--test/mitmproxy/test_examples.py5
-rw-r--r--test/mitmproxy/test_flow.py2
-rw-r--r--test/mitmproxy/test_optmanager.py8
-rw-r--r--test/mitmproxy/test_protocol_http2.py75
-rw-r--r--test/mitmproxy/test_server.py4
-rw-r--r--test/mitmproxy/tservers.py2
-rw-r--r--test/netlib/http/test_message.py37
-rw-r--r--test/netlib/http/test_url.py44
-rw-r--r--test/netlib/test_encoding.py30
-rw-r--r--test/netlib/test_multidict.py14
-rw-r--r--web/src/js/components/FlowTable/FlowRow.jsx7
-rw-r--r--web/src/js/components/MainView.jsx2
-rw-r--r--web/src/js/utils.js26
73 files changed, 933 insertions, 741 deletions
diff --git a/.travis.yml b/.travis.yml
index e832d058..e9566ebe 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -20,10 +20,10 @@ matrix:
include:
- python: 3.5
env: TOXENV=lint
-# - os: osx
-# osx_image: xcode7.3
-# language: generic
-# env: TOXENV=py35
+ - os: osx
+ osx_image: xcode7.3
+ language: generic
+ env: TOXENV=py35
- python: 3.5
env: TOXENV=py35
- python: 3.5
diff --git a/examples/filt.py b/examples/filt.py
index 21744edd..9ccf9fa1 100644
--- a/examples/filt.py
+++ b/examples/filt.py
@@ -1,18 +1,20 @@
-# This scripts demonstrates how to use mitmproxy's filter pattern in inline scripts.
+# This scripts demonstrates how to use mitmproxy's filter pattern in scripts.
# Usage: mitmdump -s "filt.py FILTER"
import sys
from mitmproxy import filt
-state = {}
+
+class Filter:
+ def __init__(self, spec):
+ self.filter = filt.parse(spec)
+
+ def response(self, flow):
+ if flow.match(self.filter):
+ print("Flow matches filter:")
+ print(flow)
def start():
if len(sys.argv) != 2:
raise ValueError("Usage: -s 'filt.py FILTER'")
- state["filter"] = filt.parse(sys.argv[1])
-
-
-def response(flow):
- if flow.match(state["filter"]):
- print("Flow matches filter:")
- print(flow)
+ return Filter(sys.argv[1])
diff --git a/examples/flowwriter.py b/examples/flowwriter.py
index 07c7ca20..df2e5a40 100644
--- a/examples/flowwriter.py
+++ b/examples/flowwriter.py
@@ -3,20 +3,21 @@ import sys
from mitmproxy.flow import FlowWriter
-state = {}
+
+class Writer:
+ def __init__(self, path):
+ if path == "-":
+ f = sys.stdout
+ else:
+ f = open(path, "wb")
+ self.w = FlowWriter(f)
+
+ def response(self, flow):
+ if random.choice([True, False]):
+ self.w.add(flow)
def start():
if len(sys.argv) != 2:
raise ValueError('Usage: -s "flowriter.py filename"')
-
- if sys.argv[1] == "-":
- f = sys.stdout
- else:
- f = open(sys.argv[1], "wb")
- state["flow_writer"] = FlowWriter(f)
-
-
-def response(flow):
- if random.choice([True, False]):
- state["flow_writer"].add(flow)
+ return Writer(sys.argv[1])
diff --git a/examples/iframe_injector.py b/examples/iframe_injector.py
index 352c3c24..33d18bbd 100644
--- a/examples/iframe_injector.py
+++ b/examples/iframe_injector.py
@@ -3,26 +3,27 @@
import sys
from bs4 import BeautifulSoup
-iframe_url = None
+
+class Injector:
+ def __init__(self, iframe_url):
+ self.iframe_url = iframe_url
+
+ def response(self, flow):
+ if flow.request.host in self.iframe_url:
+ return
+ html = BeautifulSoup(flow.response.content, "lxml")
+ if html.body:
+ iframe = html.new_tag(
+ "iframe",
+ src=self.iframe_url,
+ frameborder=0,
+ height=0,
+ width=0)
+ html.body.insert(0, iframe)
+ flow.response.content = str(html).encode("utf8")
def start():
if len(sys.argv) != 2:
raise ValueError('Usage: -s "iframe_injector.py url"')
- global iframe_url
- iframe_url = sys.argv[1]
-
-
-def response(flow):
- if flow.request.host in iframe_url:
- return
- html = BeautifulSoup(flow.response.content, "lxml")
- if html.body:
- iframe = html.new_tag(
- "iframe",
- src=iframe_url,
- frameborder=0,
- height=0,
- width=0)
- html.body.insert(0, iframe)
- flow.response.content = str(html).encode("utf8")
+ return Injector(sys.argv[1])
diff --git a/examples/remote_debug.py b/examples/remote_debug.py
new file mode 100644
index 00000000..fb864f78
--- /dev/null
+++ b/examples/remote_debug.py
@@ -0,0 +1,19 @@
+"""
+This script enables remote debugging of the mitmproxy *UI* with PyCharm.
+For general debugging purposes, it is easier to just debug mitmdump within PyCharm.
+
+Usage:
+ - pip install pydevd on the mitmproxy machine
+ - Open the Run/Debug Configuration dialog box in PyCharm, and select the Python Remote Debug configuration type.
+ - Debugging works in the way that mitmproxy connects to the debug server on startup.
+ Specify host and port that mitmproxy can use to reach your PyCharm instance on startup.
+ - Adjust this inline script accordingly.
+ - Start debug server in PyCharm
+ - Set breakpoints
+ - Start mitmproxy -s remote_debug.py
+"""
+
+
+def start():
+ import pydevd
+ pydevd.settrace("localhost", port=5678, stdoutToServer=True, stderrToServer=True)
diff --git a/examples/stub.py b/examples/stub.py
index e5b4a39a..4f5061e2 100644
--- a/examples/stub.py
+++ b/examples/stub.py
@@ -11,7 +11,7 @@ def start():
mitmproxy.ctx.log("start")
-def configure(options):
+def configure(options, updated):
"""
Called once on script startup before any other events, and whenever options changes.
"""
diff --git a/mitmproxy/addons.py b/mitmproxy/addons.py
index c779aaf8..a4bea9fa 100644
--- a/mitmproxy/addons.py
+++ b/mitmproxy/addons.py
@@ -13,16 +13,23 @@ class Addons(object):
self.master = master
master.options.changed.connect(self.options_update)
- def options_update(self, options):
+ def options_update(self, options, updated):
for i in self.chain:
with self.master.handlecontext():
- i.configure(options)
+ i.configure(options, updated)
- def add(self, *addons):
+ def add(self, options, *addons):
+ if not addons:
+ raise ValueError("No adons specified.")
self.chain.extend(addons)
for i in addons:
self.invoke_with_context(i, "start")
- self.invoke_with_context(i, "configure", self.master.options)
+ self.invoke_with_context(
+ i,
+ "configure",
+ self.master.options,
+ self.master.options.keys()
+ )
def remove(self, addon):
self.chain = [i for i in self.chain if i is not addon]
diff --git a/mitmproxy/builtins/anticache.py b/mitmproxy/builtins/anticache.py
index f208e2fb..41a5ed95 100644
--- a/mitmproxy/builtins/anticache.py
+++ b/mitmproxy/builtins/anticache.py
@@ -5,7 +5,7 @@ class AntiCache:
def __init__(self):
self.enabled = False
- def configure(self, options):
+ def configure(self, options, updated):
self.enabled = options.anticache
def request(self, flow):
diff --git a/mitmproxy/builtins/anticomp.py b/mitmproxy/builtins/anticomp.py
index 50bd1b73..823e960c 100644
--- a/mitmproxy/builtins/anticomp.py
+++ b/mitmproxy/builtins/anticomp.py
@@ -5,7 +5,7 @@ class AntiComp:
def __init__(self):
self.enabled = False
- def configure(self, options):
+ def configure(self, options, updated):
self.enabled = options.anticomp
def request(self, flow):
diff --git a/mitmproxy/builtins/dumper.py b/mitmproxy/builtins/dumper.py
index 239630fb..74c2e6b2 100644
--- a/mitmproxy/builtins/dumper.py
+++ b/mitmproxy/builtins/dumper.py
@@ -5,6 +5,8 @@ import traceback
import click
+import typing # noqa
+
from mitmproxy import contentviews
from mitmproxy import ctx
from mitmproxy import exceptions
@@ -19,12 +21,25 @@ def indent(n, text):
return "\n".join(pad + i for i in l)
-class Dumper():
+class Dumper(object):
def __init__(self):
- self.filter = None
- self.flow_detail = None
- self.outfp = None
- self.showhost = None
+ self.filter = None # type: filt.TFilter
+ self.flow_detail = None # type: int
+ self.outfp = None # type: typing.io.TextIO
+ self.showhost = None # type: bool
+
+ def configure(self, options, updated):
+ if options.filtstr:
+ self.filter = filt.parse(options.filtstr)
+ if not self.filter:
+ raise exceptions.OptionsError(
+ "Invalid filter expression: %s" % options.filtstr
+ )
+ else:
+ self.filter = None
+ self.flow_detail = options.flow_detail
+ self.outfp = options.tfile
+ self.showhost = options.showhost
def echo(self, text, ident=None, **style):
if ident:
@@ -59,7 +74,7 @@ class Dumper():
self.echo("")
try:
- type, lines = contentviews.get_content_view(
+ _, lines = contentviews.get_content_view(
contentviews.get("Auto"),
content,
headers=getattr(message, "headers", None)
@@ -67,7 +82,7 @@ class Dumper():
except exceptions.ContentViewException:
s = "Content viewer failed: \n" + traceback.format_exc()
ctx.log.debug(s)
- type, lines = contentviews.get_content_view(
+ _, lines = contentviews.get_content_view(
contentviews.get("Raw"),
content,
headers=getattr(message, "headers", None)
@@ -114,9 +129,8 @@ class Dumper():
if flow.client_conn:
client = click.style(
strutils.escape_control_characters(
- flow.client_conn.address.host
- ),
- bold=True
+ repr(flow.client_conn.address)
+ )
)
elif flow.request.is_replay:
client = click.style("[replay]", fg="yellow", bold=True)
@@ -139,17 +153,23 @@ class Dumper():
url = flow.request.url
url = click.style(strutils.escape_control_characters(url), bold=True)
- httpversion = ""
+ http_version = ""
if flow.request.http_version not in ("HTTP/1.1", "HTTP/1.0"):
# We hide "normal" HTTP 1.
- httpversion = " " + flow.request.http_version
+ http_version = " " + flow.request.http_version
- line = "{stickycookie}{client} {method} {url}{httpversion}".format(
- stickycookie=stickycookie,
+ if self.flow_detail >= 2:
+ linebreak = "\n "
+ else:
+ linebreak = ""
+
+ line = "{client}: {linebreak}{stickycookie}{method} {url}{http_version}".format(
client=client,
+ stickycookie=stickycookie,
+ linebreak=linebreak,
method=method,
url=url,
- httpversion=httpversion
+ http_version=http_version
)
self.echo(line)
@@ -185,9 +205,14 @@ class Dumper():
size = human.pretty_size(len(flow.response.raw_content))
size = click.style(size, bold=True)
- arrows = click.style(" <<", bold=True)
+ arrows = click.style(" <<", bold=True)
+ if self.flow_detail == 1:
+ # This aligns the HTTP response code with the HTTP request method:
+ # 127.0.0.1:59519: GET http://example.com/
+ # << 304 Not Modified 0b
+ arrows = " " * (len(repr(flow.client_conn.address)) - 2) + arrows
- line = "{replay} {arrows} {code} {reason} {size}".format(
+ line = "{replay}{arrows} {code} {reason} {size}".format(
replay=replay,
arrows=arrows,
code=code,
@@ -211,25 +236,12 @@ class Dumper():
def match(self, f):
if self.flow_detail == 0:
return False
- if not self.filt:
+ if not self.filter:
return True
- elif f.match(self.filt):
+ elif f.match(self.filter):
return True
return False
- def configure(self, options):
- if options.filtstr:
- self.filt = filt.parse(options.filtstr)
- if not self.filt:
- raise exceptions.OptionsError(
- "Invalid filter expression: %s" % options.filtstr
- )
- else:
- self.filt = None
- self.flow_detail = options.flow_detail
- self.outfp = options.tfile
- self.showhost = options.showhost
-
def response(self, f):
if self.match(f):
self.echo_flow(f)
@@ -239,8 +251,7 @@ class Dumper():
self.echo_flow(f)
def tcp_message(self, f):
- # FIXME: Filter should be applied here
- if self.options.flow_detail == 0:
+ if not self.match(f):
return
message = f.messages[-1]
direction = "->" if message.from_client else "<-"
diff --git a/mitmproxy/builtins/filestreamer.py b/mitmproxy/builtins/filestreamer.py
index 97ddc7c4..ffa565ac 100644
--- a/mitmproxy/builtins/filestreamer.py
+++ b/mitmproxy/builtins/filestreamer.py
@@ -19,7 +19,7 @@ class FileStreamer:
self.stream = io.FilteredFlowWriter(f, filt)
self.active_flows = set()
- def configure(self, options):
+ def configure(self, options, updated):
# We're already streaming - stop the previous stream and restart
if self.stream:
self.done()
diff --git a/mitmproxy/builtins/replace.py b/mitmproxy/builtins/replace.py
index 83b96cee..74d30c05 100644
--- a/mitmproxy/builtins/replace.py
+++ b/mitmproxy/builtins/replace.py
@@ -8,7 +8,7 @@ class Replace:
def __init__(self):
self.lst = []
- def configure(self, options):
+ def configure(self, options, updated):
"""
.replacements is a list of tuples (fpat, rex, s):
diff --git a/mitmproxy/builtins/script.py b/mitmproxy/builtins/script.py
index ab068e47..c960dd1c 100644
--- a/mitmproxy/builtins/script.py
+++ b/mitmproxy/builtins/script.py
@@ -16,6 +16,19 @@ import watchdog.events
from watchdog.observers import polling
+class NS:
+ def __init__(self, ns):
+ self.__dict__["ns"] = ns
+
+ def __getattr__(self, key):
+ if key not in self.ns:
+ raise AttributeError("No such element: %s", key)
+ return self.ns[key]
+
+ def __setattr__(self, key, value):
+ self.__dict__["ns"][key] = value
+
+
def parse_command(command):
"""
Returns a (path, args) tuple.
@@ -74,18 +87,27 @@ def load_script(path, args):
ns = {'__file__': os.path.abspath(path)}
with scriptenv(path, args):
exec(code, ns, ns)
- return ns
+ return NS(ns)
class ReloadHandler(watchdog.events.FileSystemEventHandler):
def __init__(self, callback):
self.callback = callback
+ def filter(self, event):
+ if event.is_directory:
+ return False
+ if os.path.basename(event.src_path).startswith("."):
+ return False
+ return True
+
def on_modified(self, event):
- self.callback()
+ if self.filter(event):
+ self.callback()
def on_created(self, event):
- self.callback()
+ if self.filter(event):
+ self.callback()
class Script:
@@ -118,29 +140,35 @@ class Script:
# It's possible for ns to be un-initialised if we failed during
# configure
if self.ns is not None and not self.dead:
- func = self.ns.get(name)
+ func = getattr(self.ns, name, None)
if func:
with scriptenv(self.path, self.args):
- func(*args, **kwargs)
+ return func(*args, **kwargs)
def reload(self):
self.should_reload.set()
+ def load_script(self):
+ self.ns = load_script(self.path, self.args)
+ ret = self.run("start")
+ if ret:
+ self.ns = ret
+ self.run("start")
+
def tick(self):
if self.should_reload.is_set():
self.should_reload.clear()
ctx.log.info("Reloading script: %s" % self.name)
self.ns = load_script(self.path, self.args)
self.start()
- self.configure(self.last_options)
+ self.configure(self.last_options, self.last_options.keys())
else:
self.run("tick")
def start(self):
- self.ns = load_script(self.path, self.args)
- self.run("start")
+ self.load_script()
- def configure(self, options):
+ def configure(self, options, updated):
self.last_options = options
if not self.observer:
self.observer = polling.PollingObserver()
@@ -150,7 +178,7 @@ class Script:
os.path.dirname(self.path) or "."
)
self.observer.start()
- self.run("configure", options)
+ self.run("configure", options, updated)
def done(self):
self.run("done")
@@ -161,26 +189,27 @@ class ScriptLoader():
"""
An addon that manages loading scripts from options.
"""
- def configure(self, options):
- for s in options.scripts:
- if options.scripts.count(s) > 1:
- raise exceptions.OptionsError("Duplicate script: %s" % s)
-
- for a in ctx.master.addons.chain[:]:
- if isinstance(a, Script) and a.name not in options.scripts:
- ctx.log.info("Un-loading script: %s" % a.name)
- ctx.master.addons.remove(a)
-
- current = {}
- for a in ctx.master.addons.chain[:]:
- if isinstance(a, Script):
- current[a.name] = a
- ctx.master.addons.chain.remove(a)
-
- for s in options.scripts:
- if s in current:
- ctx.master.addons.chain.append(current[s])
- else:
- ctx.log.info("Loading script: %s" % s)
- sc = Script(s)
- ctx.master.addons.add(sc)
+ def configure(self, options, updated):
+ if "scripts" in updated:
+ for s in options.scripts:
+ if options.scripts.count(s) > 1:
+ raise exceptions.OptionsError("Duplicate script: %s" % s)
+
+ for a in ctx.master.addons.chain[:]:
+ if isinstance(a, Script) and a.name not in options.scripts:
+ ctx.log.info("Un-loading script: %s" % a.name)
+ ctx.master.addons.remove(a)
+
+ current = {}
+ for a in ctx.master.addons.chain[:]:
+ if isinstance(a, Script):
+ current[a.name] = a
+ ctx.master.addons.chain.remove(a)
+
+ for s in options.scripts:
+ if s in current:
+ ctx.master.addons.chain.append(current[s])
+ else:
+ ctx.log.info("Loading script: %s" % s)
+ sc = Script(s)
+ ctx.master.addons.add(options, sc)
diff --git a/mitmproxy/builtins/setheaders.py b/mitmproxy/builtins/setheaders.py
index 6bda3f55..4a784a1d 100644
--- a/mitmproxy/builtins/setheaders.py
+++ b/mitmproxy/builtins/setheaders.py
@@ -6,7 +6,7 @@ class SetHeaders:
def __init__(self):
self.lst = []
- def configure(self, options):
+ def configure(self, options, updated):
"""
options.setheaders is a tuple of (fpatt, header, value)
diff --git a/mitmproxy/builtins/stickyauth.py b/mitmproxy/builtins/stickyauth.py
index 1309911c..98fb65ed 100644
--- a/mitmproxy/builtins/stickyauth.py
+++ b/mitmproxy/builtins/stickyauth.py
@@ -10,7 +10,7 @@ class StickyAuth:
self.flt = None
self.hosts = {}
- def configure(self, options):
+ def configure(self, options, updated):
if options.stickyauth:
flt = filt.parse(options.stickyauth)
if not flt:
diff --git a/mitmproxy/builtins/stickycookie.py b/mitmproxy/builtins/stickycookie.py
index dc699bb4..88333d5c 100644
--- a/mitmproxy/builtins/stickycookie.py
+++ b/mitmproxy/builtins/stickycookie.py
@@ -32,7 +32,7 @@ class StickyCookie:
self.jar = collections.defaultdict(dict)
self.flt = None
- def configure(self, options):
+ def configure(self, options, updated):
if options.stickycookie:
flt = filt.parse(options.stickycookie)
if not flt:
diff --git a/mitmproxy/console/common.py b/mitmproxy/console/common.py
index 281fd658..9fb8b5c9 100644
--- a/mitmproxy/console/common.py
+++ b/mitmproxy/console/common.py
@@ -134,7 +134,11 @@ def save_data(path, data):
if not path:
return
try:
- with open(path, "wb") as f:
+ if isinstance(data, bytes):
+ mode = "wb"
+ else:
+ mode = "w"
+ with open(path, mode) as f:
f.write(data)
except IOError as v:
signals.status_message.send(message=v.strerror)
@@ -193,10 +197,9 @@ def ask_scope_and_callback(flow, cb, *args):
def copy_to_clipboard_or_prompt(data):
# pyperclip calls encode('utf-8') on data to be copied without checking.
# if data are already encoded that way UnicodeDecodeError is thrown.
- toclip = ""
- try:
- toclip = data.decode('utf-8')
- except (UnicodeDecodeError):
+ if isinstance(data, bytes):
+ toclip = data.decode("utf8", "replace")
+ else:
toclip = data
try:
@@ -216,7 +219,7 @@ def copy_to_clipboard_or_prompt(data):
def format_flow_data(key, scope, flow):
- data = ""
+ data = b""
if scope in ("q", "b"):
request = flow.request.copy()
request.decode(strict=False)
@@ -230,7 +233,7 @@ def format_flow_data(key, scope, flow):
raise ValueError("Unknown key: {}".format(key))
if scope == "b" and flow.request.raw_content and flow.response:
# Add padding between request and response
- data += "\r\n" * 2
+ data += b"\r\n" * 2
if scope in ("s", "b") and flow.response:
response = flow.response.copy()
response.decode(strict=False)
@@ -293,7 +296,7 @@ def ask_save_body(scope, flow):
)
elif scope == "b" and request_has_content and response_has_content:
ask_save_path(
- (flow.request.get_content(strict=False) + "\n" +
+ (flow.request.get_content(strict=False) + b"\n" +
flow.response.get_content(strict=False)),
"Save request & response content to"
)
@@ -407,7 +410,7 @@ def raw_format_flow(f, focus, extended):
return urwid.Pile(pile)
-def format_flow(f, focus, extended=False, hostheader=False, marked=False):
+def format_flow(f, focus, extended=False, hostheader=False):
d = dict(
intercepted = f.intercepted,
acked = f.reply.acked,
@@ -420,7 +423,7 @@ def format_flow(f, focus, extended=False, hostheader=False, marked=False):
err_msg = f.error.msg if f.error else None,
- marked = marked,
+ marked = f.marked,
)
if f.response:
if f.response.raw_content:
diff --git a/mitmproxy/console/flowlist.py b/mitmproxy/console/flowlist.py
index 53e934f1..43742083 100644
--- a/mitmproxy/console/flowlist.py
+++ b/mitmproxy/console/flowlist.py
@@ -120,23 +120,17 @@ class ConnectionItem(urwid.WidgetWrap):
self.flow,
self.f,
hostheader = self.master.options.showhost,
- marked=self.state.flow_marked(self.flow)
)
def selectable(self):
return True
def save_flows_prompt(self, k):
- if k == "a":
+ if k == "l":
signals.status_prompt_path.send(
- prompt = "Save all flows to",
+ prompt = "Save listed flows to",
callback = self.master.save_flows
)
- elif k == "m":
- signals.status_prompt_path.send(
- prompt = "Save marked flows to",
- callback = self.master.save_marked_flows
- )
else:
signals.status_prompt_path.send(
prompt = "Save this flow to",
@@ -188,17 +182,16 @@ class ConnectionItem(urwid.WidgetWrap):
self.flow.accept_intercept(self.master)
signals.flowlist_change.send(self)
elif key == "d":
- self.flow.kill(self.master)
+ if not self.flow.reply.acked:
+ self.flow.kill(self.master)
self.state.delete_flow(self.flow)
signals.flowlist_change.send(self)
elif key == "D":
f = self.master.duplicate_flow(self.flow)
- self.master.view_flow(f)
+ self.master.state.set_focus_flow(f)
+ signals.flowlist_change.send(self)
elif key == "m":
- if self.state.flow_marked(self.flow):
- self.state.set_flow_marked(self.flow, False)
- else:
- self.state.set_flow_marked(self.flow, True)
+ self.flow.marked = not self.flow.marked
signals.flowlist_change.send(self)
elif key == "M":
if self.state.mark_filter:
@@ -233,7 +226,7 @@ class ConnectionItem(urwid.WidgetWrap):
)
elif key == "U":
for f in self.state.flows:
- self.state.set_flow_marked(f, False)
+ f.marked = False
signals.flowlist_change.send(self)
elif key == "V":
if not self.flow.modified():
@@ -247,14 +240,14 @@ class ConnectionItem(urwid.WidgetWrap):
self,
prompt = "Save",
keys = (
- ("all flows", "a"),
+ ("listed flows", "l"),
("this flow", "t"),
- ("marked flows", "m"),
),
callback = self.save_flows_prompt,
)
elif key == "X":
- self.flow.kill(self.master)
+ if not self.flow.reply.acked:
+ self.flow.kill(self.master)
elif key == "enter":
if self.flow.request:
self.master.view_flow(self.flow)
@@ -356,7 +349,8 @@ class FlowListBox(urwid.ListBox):
return
scheme, host, port, path = parts
f = self.master.create_request(method, scheme, host, port, path)
- self.master.view_flow(f)
+ self.master.state.set_focus_flow(f)
+ signals.flowlist_change.send(self)
def keypress(self, size, key):
key = common.shortcuts(key)
diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py
index 938c8e86..c354563f 100644
--- a/mitmproxy/console/flowview.py
+++ b/mitmproxy/console/flowview.py
@@ -6,6 +6,7 @@ import sys
import traceback
import urwid
+from typing import Optional, Union # noqa
from mitmproxy import contentviews
from mitmproxy import controller
@@ -38,7 +39,7 @@ def _mkhelp():
("d", "delete flow"),
("e", "edit request/response"),
("f", "load full body data"),
- ("m", "change body display mode for this entity"),
+ ("m", "change body display mode for this entity\n(default mode can be changed in the options)"),
(None,
common.highlight_key("automatic", "a") +
[("text", ": automatic detection")]
@@ -75,7 +76,6 @@ def _mkhelp():
common.highlight_key("xml", "x") +
[("text", ": XML")]
),
- ("M", "change default body display mode"),
("E", "export flow to file"),
("r", "replay request"),
("V", "revert changes to request"),
@@ -105,7 +105,8 @@ footer = [
class FlowViewHeader(urwid.WidgetWrap):
def __init__(self, master, f):
- self.master, self.flow = master, f
+ self.master = master # type: "mitmproxy.console.master.ConsoleMaster"
+ self.flow = f # type: models.HTTPFlow
self._w = common.format_flow(
f,
False,
@@ -135,14 +136,15 @@ class FlowView(tabs.Tabs):
def __init__(self, master, state, flow, tab_offset):
self.master, self.state, self.flow = master, state, flow
- tabs.Tabs.__init__(self,
- [
- (self.tab_request, self.view_request),
- (self.tab_response, self.view_response),
- (self.tab_details, self.view_details),
- ],
- tab_offset
- )
+ super(FlowView, self).__init__(
+ [
+ (self.tab_request, self.view_request),
+ (self.tab_response, self.view_response),
+ (self.tab_details, self.view_details),
+ ],
+ tab_offset
+ )
+
self.show()
self.last_displayed_body = None
signals.flow_change.connect(self.sig_flow_change)
@@ -189,15 +191,21 @@ class FlowView(tabs.Tabs):
limit = sys.maxsize
else:
limit = contentviews.VIEW_CUTOFF
+
+ flow_modify_cache_invalidation = hash((
+ message.raw_content,
+ message.headers.fields,
+ getattr(message, "path", None),
+ ))
return cache.get(
- self._get_content_view,
+ # We move message into this partial function as it is not hashable.
+ lambda *args: self._get_content_view(message, *args),
viewmode,
- message,
limit,
- message # Cache invalidation
+ flow_modify_cache_invalidation
)
- def _get_content_view(self, viewmode, message, max_lines, _):
+ def _get_content_view(self, message, viewmode, max_lines, _):
try:
content = message.content
@@ -396,7 +404,7 @@ class FlowView(tabs.Tabs):
if not self.flow.response:
self.flow.response = models.HTTPResponse(
self.flow.request.http_version,
- 200, "OK", Headers(), ""
+ 200, b"OK", Headers(), b""
)
self.flow.response.reply = controller.DummyReply()
message = self.flow.response
@@ -524,30 +532,24 @@ class FlowView(tabs.Tabs):
)
signals.flow_change.send(self, flow = self.flow)
- def delete_body(self, t):
+ def keypress(self, size, key):
+ conn = None # type: Optional[Union[models.HTTPRequest, models.HTTPResponse]]
if self.tab_offset == TAB_REQ:
- self.flow.request.content = None
- else:
- self.flow.response.content = None
- signals.flow_change.send(self, flow = self.flow)
+ conn = self.flow.request
+ elif self.tab_offset == TAB_RESP:
+ conn = self.flow.response
- def keypress(self, size, key):
key = super(self.__class__, self).keypress(size, key)
+ # Special case: Space moves over to the next flow.
+ # We need to catch that before applying common.shortcuts()
if key == " ":
self.view_next_flow(self.flow)
return
key = common.shortcuts(key)
- if self.tab_offset == TAB_REQ:
- conn = self.flow.request
- elif self.tab_offset == TAB_RESP:
- conn = self.flow.response
- else:
- conn = None
-
if key in ("up", "down", "page up", "page down"):
- # Why doesn't this just work??
+ # Pass scroll events to the wrapped widget
self._w.keypress(size, key)
elif key == "a":
self.flow.accept_intercept(self.master)
@@ -563,10 +565,12 @@ class FlowView(tabs.Tabs):
else:
self.view_next_flow(self.flow)
f = self.flow
- f.kill(self.master)
+ if not f.reply.acked:
+ f.kill(self.master)
self.state.delete_flow(f)
elif key == "D":
f = self.master.duplicate_flow(self.flow)
+ signals.pop_view_state.send(self)
self.master.view_flow(f)
signals.status_message.send(message="Duplicated.")
elif key == "p":
@@ -577,12 +581,12 @@ class FlowView(tabs.Tabs):
signals.status_message.send(message=r)
signals.flow_change.send(self, flow = self.flow)
elif key == "V":
- if not self.flow.modified():
+ if self.flow.modified():
+ self.state.revert(self.flow)
+ signals.flow_change.send(self, flow = self.flow)
+ signals.status_message.send(message="Reverted.")
+ else:
signals.status_message.send(message="Flow not modified.")
- return
- self.state.revert(self.flow)
- signals.flow_change.send(self, flow = self.flow)
- signals.status_message.send(message="Reverted.")
elif key == "W":
signals.status_prompt_path.send(
prompt = "Save this flow",
@@ -595,133 +599,128 @@ class FlowView(tabs.Tabs):
callback = self.master.run_script_once,
args = (self.flow,)
)
-
- if not conn and key in set(list("befgmxvzEC")):
+ elif key == "e":
+ if self.tab_offset == TAB_REQ:
+ signals.status_prompt_onekey.send(
+ prompt="Edit request",
+ keys=(
+ ("cookies", "c"),
+ ("query", "q"),
+ ("path", "p"),
+ ("url", "u"),
+ ("header", "h"),
+ ("form", "f"),
+ ("raw body", "r"),
+ ("method", "m"),
+ ),
+ callback=self.edit
+ )
+ elif self.tab_offset == TAB_RESP:
+ signals.status_prompt_onekey.send(
+ prompt="Edit response",
+ keys=(
+ ("cookies", "c"),
+ ("code", "o"),
+ ("message", "m"),
+ ("header", "h"),
+ ("raw body", "r"),
+ ),
+ callback=self.edit
+ )
+ else:
+ signals.status_message.send(
+ message="Tab to the request or response",
+ expire=1
+ )
+ elif key in set("bfgmxvzEC") and not conn:
signals.status_message.send(
message = "Tab to the request or response",
expire = 1
)
- elif conn:
- if key == "b":
- if self.tab_offset == TAB_REQ:
- common.ask_save_body(
- "q", self.master, self.state, self.flow
- )
+ return
+ elif key == "b":
+ if self.tab_offset == TAB_REQ:
+ common.ask_save_body("q", self.flow)
+ else:
+ common.ask_save_body("s", self.flow)
+ elif key == "f":
+ signals.status_message.send(message="Loading all body data...")
+ self.state.add_flow_setting(
+ self.flow,
+ (self.tab_offset, "fullcontents"),
+ True
+ )
+ signals.flow_change.send(self, flow = self.flow)
+ signals.status_message.send(message="")
+ elif key == "m":
+ p = list(contentviews.view_prompts)
+ p.insert(0, ("Clear", "C"))
+ signals.status_prompt_onekey.send(
+ self,
+ prompt = "Display mode",
+ keys = p,
+ callback = self.change_this_display_mode
+ )
+ elif key == "E":
+ if self.tab_offset == TAB_REQ:
+ scope = "q"
+ else:
+ scope = "s"
+ signals.status_prompt_onekey.send(
+ self,
+ prompt = "Export to file",
+ keys = [(e[0], e[1]) for e in export.EXPORTERS],
+ callback = common.export_to_clip_or_file,
+ args = (scope, self.flow, common.ask_save_path)
+ )
+ elif key == "C":
+ if self.tab_offset == TAB_REQ:
+ scope = "q"
+ else:
+ scope = "s"
+ signals.status_prompt_onekey.send(
+ self,
+ prompt = "Export to clipboard",
+ keys = [(e[0], e[1]) for e in export.EXPORTERS],
+ callback = common.export_to_clip_or_file,
+ args = (scope, self.flow, common.copy_to_clipboard_or_prompt)
+ )
+ elif key == "x":
+ conn.content = None
+ signals.flow_change.send(self, flow=self.flow)
+ elif key == "v":
+ if conn.raw_content:
+ t = conn.headers.get("content-type")
+ if "EDITOR" in os.environ or "PAGER" in os.environ:
+ self.master.spawn_external_viewer(conn.get_content(strict=False), t)
else:
- common.ask_save_body(
- "s", self.master, self.state, self.flow
- )
- elif key == "e":
- if self.tab_offset == TAB_REQ:
- signals.status_prompt_onekey.send(
- prompt = "Edit request",
- keys = (
- ("cookies", "c"),
- ("query", "q"),
- ("path", "p"),
- ("url", "u"),
- ("header", "h"),
- ("form", "f"),
- ("raw body", "r"),
- ("method", "m"),
- ),
- callback = self.edit
+ signals.status_message.send(
+ message = "Error! Set $EDITOR or $PAGER."
)
- else:
- signals.status_prompt_onekey.send(
- prompt = "Edit response",
- keys = (
- ("cookies", "c"),
- ("code", "o"),
- ("message", "m"),
- ("header", "h"),
- ("raw body", "r"),
- ),
- callback = self.edit
+ elif key == "z":
+ self.flow.backup()
+ e = conn.headers.get("content-encoding", "identity")
+ if e != "identity":
+ try:
+ conn.decode()
+ except ValueError:
+ signals.status_message.send(
+ message = "Could not decode - invalid data?"
)
- key = None
- elif key == "f":
- signals.status_message.send(message="Loading all body data...")
- self.state.add_flow_setting(
- self.flow,
- (self.tab_offset, "fullcontents"),
- True
- )
- signals.flow_change.send(self, flow = self.flow)
- signals.status_message.send(message="")
- elif key == "m":
- p = list(contentviews.view_prompts)
- p.insert(0, ("Clear", "C"))
- signals.status_prompt_onekey.send(
- self,
- prompt = "Display mode",
- keys = p,
- callback = self.change_this_display_mode
- )
- key = None
- elif key == "E":
- if self.tab_offset == TAB_REQ:
- scope = "q"
- else:
- scope = "s"
- signals.status_prompt_onekey.send(
- self,
- prompt = "Export to file",
- keys = [(e[0], e[1]) for e in export.EXPORTERS],
- callback = common.export_to_clip_or_file,
- args = (scope, self.flow, common.ask_save_path)
- )
- elif key == "C":
- if self.tab_offset == TAB_REQ:
- scope = "q"
- else:
- scope = "s"
- signals.status_prompt_onekey.send(
- self,
- prompt = "Export to clipboard",
- keys = [(e[0], e[1]) for e in export.EXPORTERS],
- callback = common.export_to_clip_or_file,
- args = (scope, self.flow, common.copy_to_clipboard_or_prompt)
- )
- elif key == "x":
+ else:
signals.status_prompt_onekey.send(
- prompt = "Delete body",
+ prompt = "Select encoding: ",
keys = (
- ("completely", "c"),
- ("mark as missing", "m"),
+ ("gzip", "z"),
+ ("deflate", "d"),
),
- callback = self.delete_body
+ callback = self.encode_callback,
+ args = (conn,)
)
- key = None
- elif key == "v":
- if conn.raw_content:
- t = conn.headers.get("content-type")
- if "EDITOR" in os.environ or "PAGER" in os.environ:
- self.master.spawn_external_viewer(conn.get_content(strict=False), t)
- else:
- signals.status_message.send(
- message = "Error! Set $EDITOR or $PAGER."
- )
- elif key == "z":
- self.flow.backup()
- e = conn.headers.get("content-encoding", "identity")
- if e != "identity":
- if not conn.decode():
- signals.status_message.send(
- message = "Could not decode - invalid data?"
- )
- else:
- signals.status_prompt_onekey.send(
- prompt = "Select encoding: ",
- keys = (
- ("gzip", "z"),
- ("deflate", "d"),
- ),
- callback = self.encode_callback,
- args = (conn,)
- )
- signals.flow_change.send(self, flow = self.flow)
- return key
+ signals.flow_change.send(self, flow = self.flow)
+ else:
+ # Key is not handled here.
+ return key
def encode_callback(self, key, conn):
encoding_map = {
diff --git a/mitmproxy/console/help.py b/mitmproxy/console/help.py
index 064d3cb5..ff4a072f 100644
--- a/mitmproxy/console/help.py
+++ b/mitmproxy/console/help.py
@@ -1,5 +1,7 @@
from __future__ import absolute_import, print_function, division
+import platform
+
import urwid
from mitmproxy import filt
@@ -9,7 +11,7 @@ from mitmproxy.console import signals
from netlib import version
footer = [
- ("heading", 'mitmproxy v%s ' % version.VERSION),
+ ("heading", 'mitmproxy {} (Python {}) '.format(version.VERSION, platform.python_version())),
('heading_key', "q"), ":back ",
]
diff --git a/mitmproxy/console/master.py b/mitmproxy/console/master.py
index 4fd6cb78..db414147 100644
--- a/mitmproxy/console/master.py
+++ b/mitmproxy/console/master.py
@@ -34,6 +34,7 @@ from mitmproxy.console import palettes
from mitmproxy.console import signals
from mitmproxy.console import statusbar
from mitmproxy.console import window
+from mitmproxy.filt import FMarked
from netlib import tcp, strutils
EVENTLOG_SIZE = 500
@@ -48,7 +49,7 @@ class ConsoleState(flow.State):
self.default_body_view = contentviews.get("Auto")
self.flowsettings = weakref.WeakKeyDictionary()
self.last_search = None
- self.last_filter = None
+ self.last_filter = ""
self.mark_filter = False
def __setattr__(self, name, value):
@@ -66,7 +67,6 @@ class ConsoleState(flow.State):
def add_flow(self, f):
super(ConsoleState, self).add_flow(f)
self.update_focus()
- self.set_flow_marked(f, False)
return f
def update_flow(self, f):
@@ -86,10 +86,10 @@ class ConsoleState(flow.State):
def set_focus(self, idx):
if self.view:
- if idx >= len(self.view):
- idx = len(self.view) - 1
- elif idx < 0:
+ if idx is None or idx < 0:
idx = 0
+ elif idx >= len(self.view):
+ idx = len(self.view) - 1
self.focus = idx
else:
self.focus = None
@@ -123,48 +123,71 @@ class ConsoleState(flow.State):
self.set_focus(self.focus)
return ret
- def filter_marked(self, m):
- def actual_func(x):
- if x.id in m:
- return True
- return False
- return actual_func
+ def get_nearest_matching_flow(self, flow, filt):
+ fidx = self.view.index(flow)
+ dist = 1
+
+ fprev = fnext = True
+ while fprev or fnext:
+ fprev, _ = self.get_from_pos(fidx - dist)
+ fnext, _ = self.get_from_pos(fidx + dist)
+
+ if fprev and fprev.match(filt):
+ return fprev
+ elif fnext and fnext.match(filt):
+ return fnext
+
+ dist += 1
+
+ return None
def enable_marked_filter(self):
+ marked_flows = [f for f in self.flows if f.marked]
+ if not marked_flows:
+ return
+
+ marked_filter = "~%s" % FMarked.code
+
+ # Save Focus
+ last_focus, _ = self.get_focus()
+ nearest_marked = self.get_nearest_matching_flow(last_focus, marked_filter)
+
self.last_filter = self.limit_txt
- marked_flows = []
- for f in self.flows:
- if self.flow_marked(f):
- marked_flows.append(f.id)
- if len(marked_flows) > 0:
- f = self.filter_marked(marked_flows)
- self.view._close()
- self.view = flow.FlowView(self.flows, f)
- self.focus = 0
- self.set_focus(self.focus)
- self.mark_filter = True
+ self.set_limit(marked_filter)
+
+ # Restore Focus
+ if last_focus.marked:
+ self.set_focus_flow(last_focus)
+ else:
+ self.set_focus_flow(nearest_marked)
+
+ self.mark_filter = True
def disable_marked_filter(self):
- if self.last_filter is None:
- self.view = flow.FlowView(self.flows, None)
+ marked_filter = "~%s" % FMarked.code
+
+ # Save Focus
+ last_focus, _ = self.get_focus()
+ nearest_marked = self.get_nearest_matching_flow(last_focus, marked_filter)
+
+ self.set_limit(self.last_filter)
+ self.last_filter = ""
+
+ # Restore Focus
+ if last_focus.marked:
+ self.set_focus_flow(last_focus)
else:
- self.set_limit(self.last_filter)
- self.focus = 0
- self.set_focus(self.focus)
- self.last_filter = None
+ self.set_focus_flow(nearest_marked)
+
self.mark_filter = False
def clear(self):
- marked_flows = []
- for f in self.flows:
- if self.flow_marked(f):
- marked_flows.append(f)
-
+ marked_flows = [f for f in self.view if f.marked]
super(ConsoleState, self).clear()
for f in marked_flows:
self.add_flow(f)
- self.set_flow_marked(f, True)
+ f.marked = True
if len(self.flows.views) == 0:
self.focus = None
@@ -172,12 +195,6 @@ class ConsoleState(flow.State):
self.focus = 0
self.set_focus(self.focus)
- def flow_marked(self, flow):
- return self.get_flow_setting(flow, "marked", False)
-
- def set_flow_marked(self, flow, marked):
- self.add_flow_setting(flow, "marked", marked)
-
class Options(mitmproxy.options.Options):
def __init__(
@@ -242,7 +259,7 @@ class ConsoleMaster(flow.FlowMaster):
signals.pop_view_state.connect(self.sig_pop_view_state)
signals.push_view_state.connect(self.sig_push_view_state)
signals.sig_add_log.connect(self.sig_add_log)
- self.addons.add(*builtins.default_addons())
+ self.addons.add(options, *builtins.default_addons())
def __setattr__(self, name, value):
self.__dict__[name] = value
@@ -254,10 +271,6 @@ class ConsoleMaster(flow.FlowMaster):
expire=1
)
- def load_script(self, command, use_reloader=True):
- # We default to using the reloader in the console ui.
- return super(ConsoleMaster, self).load_script(command, use_reloader)
-
def sig_add_log(self, sender, e, level):
if self.options.verbosity < utils.log_tier(level):
return
@@ -352,7 +365,7 @@ class ConsoleMaster(flow.FlowMaster):
try:
return flow.read_flows_from_paths(path)
except exceptions.FlowReadException as e:
- signals.status_message.send(message=e.strerror)
+ signals.status_message.send(message=str(e))
def client_playback_path(self, path):
if not isinstance(path, list):
@@ -619,13 +632,6 @@ class ConsoleMaster(flow.FlowMaster):
def save_flows(self, path):
return self._write_flows(path, self.state.view)
- def save_marked_flows(self, path):
- marked_flows = []
- for f in self.state.view:
- if self.state.flow_marked(f):
- marked_flows.append(f)
- return self._write_flows(path, marked_flows)
-
def load_flows_callback(self, path):
if not path:
return
@@ -748,10 +754,3 @@ class ConsoleMaster(flow.FlowMaster):
direction=direction,
), "info")
self.add_log(strutils.bytes_to_escaped_str(message.content), "debug")
-
- @controller.handler
- def script_change(self, script):
- if super(ConsoleMaster, self).script_change(script):
- signals.status_message.send(message='"{}" reloaded.'.format(script.path))
- else:
- signals.status_message.send(message='Error reloading "{}".'.format(script.path))
diff --git a/mitmproxy/console/options.py b/mitmproxy/console/options.py
index 62564a60..f9fc3764 100644
--- a/mitmproxy/console/options.py
+++ b/mitmproxy/console/options.py
@@ -140,7 +140,7 @@ class Options(urwid.WidgetWrap):
)
self.master.loop.widget.footer.update("")
signals.update_settings.connect(self.sig_update_settings)
- master.options.changed.connect(self.sig_update_settings)
+ master.options.changed.connect(lambda sender, updated: self.sig_update_settings(sender))
def sig_update_settings(self, sender):
self.lb.walker._modified()
diff --git a/mitmproxy/console/searchable.py b/mitmproxy/console/searchable.py
index c60d1cd9..d58d3d13 100644
--- a/mitmproxy/console/searchable.py
+++ b/mitmproxy/console/searchable.py
@@ -78,9 +78,9 @@ class Searchable(urwid.ListBox):
return
# Start search at focus + 1
if backwards:
- rng = xrange(len(self.body) - 1, -1, -1)
+ rng = range(len(self.body) - 1, -1, -1)
else:
- rng = xrange(1, len(self.body) + 1)
+ rng = range(1, len(self.body) + 1)
for i in rng:
off = (self.focus_position + i) % len(self.body)
w = self.body[off]
diff --git a/mitmproxy/console/statusbar.py b/mitmproxy/console/statusbar.py
index 3120fa71..156d1176 100644
--- a/mitmproxy/console/statusbar.py
+++ b/mitmproxy/console/statusbar.py
@@ -124,7 +124,7 @@ class StatusBar(urwid.WidgetWrap):
super(StatusBar, self).__init__(urwid.Pile([self.ib, self.master.ab]))
signals.update_settings.connect(self.sig_update_settings)
signals.flowlist_change.connect(self.sig_update_settings)
- master.options.changed.connect(self.sig_update_settings)
+ master.options.changed.connect(lambda sender, updated: self.sig_update_settings(sender))
self.redraw()
def sig_update_settings(self, sender):
@@ -171,10 +171,6 @@ class StatusBar(urwid.WidgetWrap):
r.append("[")
r.append(("heading_key", "l"))
r.append(":%s]" % self.master.state.limit_txt)
- if self.master.state.mark_filter:
- r.append("[")
- r.append(("heading_key", "Marked Flows"))
- r.append("]")
if self.master.options.stickycookie:
r.append("[")
r.append(("heading_key", "t"))
diff --git a/mitmproxy/console/tabs.py b/mitmproxy/console/tabs.py
index bfcdeba3..a5e9c510 100644
--- a/mitmproxy/console/tabs.py
+++ b/mitmproxy/console/tabs.py
@@ -25,7 +25,7 @@ class Tab(urwid.WidgetWrap):
class Tabs(urwid.WidgetWrap):
def __init__(self, tabs, tab_offset=0):
- urwid.WidgetWrap.__init__(self, "")
+ super(Tabs, self).__init__("")
self.tab_offset = tab_offset
self.tabs = tabs
self.show()
diff --git a/mitmproxy/contentviews.py b/mitmproxy/contentviews.py
index afdaad7f..e155bc01 100644
--- a/mitmproxy/contentviews.py
+++ b/mitmproxy/contentviews.py
@@ -20,6 +20,8 @@ import logging
import subprocess
import sys
+from typing import Mapping # noqa
+
import html2text
import lxml.etree
import lxml.html
@@ -76,6 +78,7 @@ def pretty_json(s):
def format_dict(d):
+ # type: (Mapping[Union[str,bytes], Union[str,bytes]]) -> Generator[Tuple[Union[str,bytes], Union[str,bytes]]]
"""
Helper function that transforms the given dictionary into a list of
("key", key )
@@ -85,7 +88,7 @@ def format_dict(d):
max_key_len = max(len(k) for k in d.keys())
max_key_len = min(max_key_len, KEY_MAX)
for key, value in d.items():
- key += ":"
+ key += b":" if isinstance(key, bytes) else u":"
key = key.ljust(max_key_len + 2)
yield [
("header", key),
@@ -106,12 +109,16 @@ class View(object):
prompt = ()
content_types = []
- def __call__(self, data, **metadata):
+ def __call__(
+ self,
+ data, # type: bytes
+ **metadata
+ ):
"""
Transform raw data into human-readable output.
Args:
- data: the data to decode/format as bytes.
+ data: the data to decode/format.
metadata: optional keyword-only arguments for metadata. Implementations must not
rely on a given argument being present.
@@ -278,6 +285,10 @@ class ViewURLEncoded(View):
content_types = ["application/x-www-form-urlencoded"]
def __call__(self, data, **metadata):
+ try:
+ data = data.decode("ascii", "strict")
+ except ValueError:
+ return None
d = url.decode(data)
return "URLEncoded form", format_dict(multidict.MultiDict(d))
diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py
index 070ec862..35817a85 100644
--- a/mitmproxy/controller.py
+++ b/mitmproxy/controller.py
@@ -37,8 +37,6 @@ Events = frozenset([
"configure",
"done",
"tick",
-
- "script_change",
])
diff --git a/mitmproxy/ctx.py b/mitmproxy/ctx.py
index fcfdfd0b..5d2905fa 100644
--- a/mitmproxy/ctx.py
+++ b/mitmproxy/ctx.py
@@ -1,4 +1,4 @@
from typing import Callable # noqa
master = None # type: "mitmproxy.flow.FlowMaster"
-log = None # type: Callable[[str], None]
+log = None # type: "mitmproxy.controller.Log"
diff --git a/mitmproxy/dump.py b/mitmproxy/dump.py
index 4f34ab95..83f44d87 100644
--- a/mitmproxy/dump.py
+++ b/mitmproxy/dump.py
@@ -42,8 +42,8 @@ class DumpMaster(flow.FlowMaster):
def __init__(self, server, options):
flow.FlowMaster.__init__(self, options, server, flow.State())
self.has_errored = False
- self.addons.add(*builtins.default_addons())
- self.addons.add(dumper.Dumper())
+ self.addons.add(options, *builtins.default_addons())
+ self.addons.add(options, dumper.Dumper())
# This line is just for type hinting
self.options = self.options # type: Options
self.replay_ignore_params = options.replay_ignore_params
diff --git a/mitmproxy/filt.py b/mitmproxy/filt.py
index 8b647b22..67915e5b 100644
--- a/mitmproxy/filt.py
+++ b/mitmproxy/filt.py
@@ -39,9 +39,12 @@ import functools
from mitmproxy.models.http import HTTPFlow
from mitmproxy.models.tcp import TCPFlow
+from mitmproxy.models.flow import Flow
+
from netlib import strutils
import pyparsing as pp
+from typing import Callable
def only(*types):
@@ -80,6 +83,14 @@ class FErr(_Action):
return True if f.error else False
+class FMarked(_Action):
+ code = "marked"
+ help = "Match marked flows"
+
+ def __call__(self, f):
+ return f.marked
+
+
class FHTTP(_Action):
code = "http"
help = "Match HTTP flows"
@@ -398,6 +409,7 @@ filt_unary = [
FAsset,
FErr,
FHTTP,
+ FMarked,
FReq,
FResp,
FTCP,
@@ -471,7 +483,11 @@ def _make():
bnf = _make()
+TFilter = Callable[[Flow], bool]
+
+
def parse(s):
+ # type: (str) -> TFilter
try:
filt = bnf.parseString(s, parseAll=True)[0]
filt.pattern = s
diff --git a/mitmproxy/flow/io_compat.py b/mitmproxy/flow/io_compat.py
index 8cd883c3..061bf16d 100644
--- a/mitmproxy/flow/io_compat.py
+++ b/mitmproxy/flow/io_compat.py
@@ -60,6 +60,7 @@ def convert_017_018(data):
data = convert_unicode(data)
data["server_conn"]["ip_address"] = data["server_conn"].pop("peer_address")
+ data["marked"] = False
data["version"] = (0, 18)
return data
diff --git a/mitmproxy/models/flow.py b/mitmproxy/models/flow.py
index f4993b7a..f4a2b54b 100644
--- a/mitmproxy/models/flow.py
+++ b/mitmproxy/models/flow.py
@@ -8,6 +8,8 @@ from mitmproxy import stateobject
from mitmproxy.models.connections import ClientConnection
from mitmproxy.models.connections import ServerConnection
+import six
+
from netlib import version
from typing import Optional # noqa
@@ -79,6 +81,7 @@ class Flow(stateobject.StateObject):
self.intercepted = False # type: bool
self._backup = None # type: Optional[Flow]
self.reply = None
+ self.marked = False # type: bool
_stateobject_attributes = dict(
id=str,
@@ -86,7 +89,8 @@ class Flow(stateobject.StateObject):
client_conn=ClientConnection,
server_conn=ServerConnection,
type=str,
- intercepted=bool
+ intercepted=bool,
+ marked=bool,
)
def get_state(self):
@@ -173,3 +177,21 @@ class Flow(stateobject.StateObject):
self.intercepted = False
self.reply.ack()
master.handle_accept_intercept(self)
+
+ def match(self, f):
+ """
+ Match this flow against a compiled filter expression. Returns True
+ if matched, False if not.
+
+ If f is a string, it will be compiled as a filter expression. If
+ the expression is invalid, ValueError is raised.
+ """
+ if isinstance(f, six.string_types):
+ from .. import filt
+
+ f = filt.parse(f)
+ if not f:
+ raise ValueError("Invalid filter expression.")
+ if f:
+ return f(self)
+ return True
diff --git a/mitmproxy/models/http.py b/mitmproxy/models/http.py
index 1fd28f00..7781e61f 100644
--- a/mitmproxy/models/http.py
+++ b/mitmproxy/models/http.py
@@ -2,7 +2,6 @@ from __future__ import absolute_import, print_function, division
import cgi
import warnings
-import six
from mitmproxy.models.flow import Flow
from netlib import version
@@ -211,24 +210,6 @@ class HTTPFlow(Flow):
f.response = self.response.copy()
return f
- def match(self, f):
- """
- Match this flow against a compiled filter expression. Returns True
- if matched, False if not.
-
- If f is a string, it will be compiled as a filter expression. If
- the expression is invalid, ValueError is raised.
- """
- if isinstance(f, six.string_types):
- from .. import filt
-
- f = filt.parse(f)
- if not f:
- raise ValueError("Invalid filter expression.")
- if f:
- return f(self)
- return True
-
def replace(self, pattern, repl, *args, **kwargs):
"""
Replaces a regular expression pattern with repl in both request and
diff --git a/mitmproxy/models/tcp.py b/mitmproxy/models/tcp.py
index 6650141d..e33475c2 100644
--- a/mitmproxy/models/tcp.py
+++ b/mitmproxy/models/tcp.py
@@ -7,8 +7,6 @@ from typing import List
import netlib.basetypes
from mitmproxy.models.flow import Flow
-import six
-
class TCPMessage(netlib.basetypes.Serializable):
@@ -55,22 +53,3 @@ class TCPFlow(Flow):
def __repr__(self):
return "<TCPFlow ({} messages)>".format(len(self.messages))
-
- def match(self, f):
- """
- Match this flow against a compiled filter expression. Returns True
- if matched, False if not.
-
- If f is a string, it will be compiled as a filter expression. If
- the expression is invalid, ValueError is raised.
- """
- if isinstance(f, six.string_types):
- from .. import filt
-
- f = filt.parse(f)
- if not f:
- raise ValueError("Invalid filter expression.")
- if f:
- return f(self)
-
- return True
diff --git a/mitmproxy/optmanager.py b/mitmproxy/optmanager.py
index e94ef51d..140c7ca8 100644
--- a/mitmproxy/optmanager.py
+++ b/mitmproxy/optmanager.py
@@ -35,7 +35,7 @@ class OptManager(object):
self.__dict__["_initialized"] = True
@contextlib.contextmanager
- def rollback(self):
+ def rollback(self, updated):
old = self._opts.copy()
try:
yield
@@ -44,7 +44,7 @@ class OptManager(object):
self.errored.send(self, exc=e)
# Rollback
self.__dict__["_opts"] = old
- self.changed.send(self)
+ self.changed.send(self, updated=updated)
def __eq__(self, other):
return self._opts == other._opts
@@ -62,22 +62,22 @@ class OptManager(object):
if not self._initialized:
self._opts[attr] = value
return
- if attr not in self._opts:
- raise KeyError("No such option: %s" % attr)
- with self.rollback():
- self._opts[attr] = value
- self.changed.send(self)
+ self.update(**{attr: value})
+
+ def keys(self):
+ return set(self._opts.keys())
def get(self, k, d=None):
return self._opts.get(k, d)
def update(self, **kwargs):
+ updated = set(kwargs.keys())
for k in kwargs:
if k not in self._opts:
raise KeyError("No such option: %s" % k)
- with self.rollback():
+ with self.rollback(updated):
self._opts.update(kwargs)
- self.changed.send(self)
+ self.changed.send(self, updated=updated)
def setter(self, attr):
"""
diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py
index 1285e10e..8308f44d 100644
--- a/mitmproxy/protocol/http2.py
+++ b/mitmproxy/protocol/http2.py
@@ -584,6 +584,8 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
except exceptions.ProtocolException as e: # pragma: no cover
self.log(repr(e), "info")
self.log(traceback.format_exc(), "debug")
+ except exceptions.Kill:
+ self.log("Connection killed", "info")
if not self.zombie:
self.zombie = time.time()
diff --git a/mitmproxy/proxy/config.py b/mitmproxy/proxy/config.py
index 7aa4c736..a74ba7e2 100644
--- a/mitmproxy/proxy/config.py
+++ b/mitmproxy/proxy/config.py
@@ -79,10 +79,10 @@ class ProxyConfig:
self.certstore = None
self.clientcerts = None
self.openssl_verification_mode_server = None
- self.configure(options)
+ self.configure(options, set(options.keys()))
options.changed.connect(self.configure)
- def configure(self, options):
+ def configure(self, options, updated):
conflict = all(
[
options.add_upstream_certs_to_client_chain,
diff --git a/mitmproxy/web/app.py b/mitmproxy/web/app.py
index 8ccc21c5..f8f85f3d 100644
--- a/mitmproxy/web/app.py
+++ b/mitmproxy/web/app.py
@@ -234,7 +234,8 @@ class AcceptFlow(RequestHandler):
class FlowHandler(RequestHandler):
def delete(self, flow_id):
- self.flow.kill(self.master)
+ if not self.flow.reply.acked:
+ self.flow.kill(self.master)
self.state.delete_flow(self.flow)
def put(self, flow_id):
diff --git a/mitmproxy/web/master.py b/mitmproxy/web/master.py
index 3d384612..9ddb61d4 100644
--- a/mitmproxy/web/master.py
+++ b/mitmproxy/web/master.py
@@ -136,7 +136,7 @@ class WebMaster(flow.FlowMaster):
def __init__(self, server, options):
super(WebMaster, self).__init__(options, server, WebState())
- self.addons.add(*builtins.default_addons())
+ self.addons.add(options, *builtins.default_addons())
self.app = app.Application(
self, self.options.wdebug, self.options.wauthenticator
)
diff --git a/netlib/encoding.py b/netlib/encoding.py
index e3cf5f30..da282194 100644
--- a/netlib/encoding.py
+++ b/netlib/encoding.py
@@ -4,6 +4,7 @@ Utility functions for decoding response bodies.
from __future__ import absolute_import
import codecs
+import collections
from io import BytesIO
import gzip
import zlib
@@ -11,7 +12,15 @@ import zlib
from typing import Union # noqa
-def decode(obj, encoding, errors='strict'):
+# We have a shared single-element cache for encoding and decoding.
+# This is quite useful in practice, e.g.
+# flow.request.content = flow.request.content.replace(b"foo", b"bar")
+# does not require an .encode() call if content does not contain b"foo"
+CachedDecode = collections.namedtuple("CachedDecode", "encoded encoding errors decoded")
+_cache = CachedDecode(None, None, None, None)
+
+
+def decode(encoded, encoding, errors='strict'):
# type: (Union[str, bytes], str, str) -> Union[str, bytes]
"""
Decode the given input object
@@ -22,20 +31,32 @@ def decode(obj, encoding, errors='strict'):
Raises:
ValueError, if decoding fails.
"""
+ global _cache
+ cached = (
+ isinstance(encoded, bytes) and
+ _cache.encoded == encoded and
+ _cache.encoding == encoding and
+ _cache.errors == errors
+ )
+ if cached:
+ return _cache.decoded
try:
try:
- return custom_decode[encoding](obj)
+ decoded = custom_decode[encoding](encoded)
except KeyError:
- return codecs.decode(obj, encoding, errors)
+ decoded = codecs.decode(encoded, encoding, errors)
+ if encoding in ("gzip", "deflate"):
+ _cache = CachedDecode(encoded, encoding, errors, decoded)
+ return decoded
except Exception as e:
raise ValueError("{} when decoding {} with {}".format(
type(e).__name__,
- repr(obj)[:10],
+ repr(encoded)[:10],
repr(encoding),
))
-def encode(obj, encoding, errors='strict'):
+def encode(decoded, encoding, errors='strict'):
# type: (Union[str, bytes], str, str) -> Union[str, bytes]
"""
Encode the given input object
@@ -46,15 +67,27 @@ def encode(obj, encoding, errors='strict'):
Raises:
ValueError, if encoding fails.
"""
+ global _cache
+ cached = (
+ isinstance(decoded, bytes) and
+ _cache.decoded == decoded and
+ _cache.encoding == encoding and
+ _cache.errors == errors
+ )
+ if cached:
+ return _cache.encoded
try:
try:
- return custom_encode[encoding](obj)
+ encoded = custom_encode[encoding](decoded)
except KeyError:
- return codecs.encode(obj, encoding, errors)
+ encoded = codecs.encode(decoded, encoding, errors)
+ if encoding in ("gzip", "deflate"):
+ _cache = CachedDecode(encoded, encoding, errors, decoded)
+ return encoded
except Exception as e:
raise ValueError("{} when encoding {} with {}".format(
type(e).__name__,
- repr(obj)[:10],
+ repr(decoded)[:10],
repr(encoding),
))
diff --git a/netlib/http/message.py b/netlib/http/message.py
index 34709f0a..be35b8d1 100644
--- a/netlib/http/message.py
+++ b/netlib/http/message.py
@@ -32,9 +32,6 @@ class MessageData(basetypes.Serializable):
def __ne__(self, other):
return not self.__eq__(other)
- def __hash__(self):
- return hash(frozenset(self.__dict__.items()))
-
def set_state(self, state):
for k, v in state.items():
if k == "headers":
@@ -52,23 +49,7 @@ class MessageData(basetypes.Serializable):
return cls(**state)
-class CachedDecode(object):
- __slots__ = ["encoded", "encoding", "strict", "decoded"]
-
- def __init__(self, object, encoding, strict, decoded):
- self.encoded = object
- self.encoding = encoding
- self.strict = strict
- self.decoded = decoded
-
-no_cached_decode = CachedDecode(None, None, None, None)
-
-
class Message(basetypes.Serializable):
- def __init__(self):
- self._content_cache = no_cached_decode # type: CachedDecode
- self._text_cache = no_cached_decode # type: CachedDecode
-
def __eq__(self, other):
if isinstance(other, Message):
return self.data == other.data
@@ -77,9 +58,6 @@ class Message(basetypes.Serializable):
def __ne__(self, other):
return not self.__eq__(other)
- def __hash__(self):
- return hash(self.data) ^ 1
-
def get_state(self):
return self.data.get_state()
@@ -132,25 +110,15 @@ class Message(basetypes.Serializable):
if self.raw_content is None:
return None
ce = self.headers.get("content-encoding")
- cached = (
- self._content_cache.encoded == self.raw_content and
- (self._content_cache.strict or not strict) and
- self._content_cache.encoding == ce
- )
- if not cached:
- is_strict = True
- if ce:
- try:
- decoded = encoding.decode(self.raw_content, ce)
- except ValueError:
- if strict:
- raise
- is_strict = False
- decoded = self.raw_content
- else:
- decoded = self.raw_content
- self._content_cache = CachedDecode(self.raw_content, ce, is_strict, decoded)
- return self._content_cache.decoded
+ if ce:
+ try:
+ return encoding.decode(self.raw_content, ce)
+ except ValueError:
+ if strict:
+ raise
+ return self.raw_content
+ else:
+ return self.raw_content
def set_content(self, value):
if value is None:
@@ -163,22 +131,13 @@ class Message(basetypes.Serializable):
.format(type(value).__name__)
)
ce = self.headers.get("content-encoding")
- cached = (
- self._content_cache.decoded == value and
- self._content_cache.encoding == ce and
- self._content_cache.strict
- )
- if not cached:
- try:
- encoded = encoding.encode(value, ce or "identity")
- except ValueError:
- # So we have an invalid content-encoding?
- # Let's remove it!
- del self.headers["content-encoding"]
- ce = None
- encoded = value
- self._content_cache = CachedDecode(encoded, ce, True, value)
- self.raw_content = self._content_cache.encoded
+ try:
+ self.raw_content = encoding.encode(value, ce or "identity")
+ except ValueError:
+ # So we have an invalid content-encoding?
+ # Let's remove it!
+ del self.headers["content-encoding"]
+ self.raw_content = value
self.headers["content-length"] = str(len(self.raw_content))
content = property(get_content, set_content)
@@ -250,22 +209,12 @@ class Message(basetypes.Serializable):
enc = self._guess_encoding()
content = self.get_content(strict)
- cached = (
- self._text_cache.encoded == content and
- (self._text_cache.strict or not strict) and
- self._text_cache.encoding == enc
- )
- if not cached:
- is_strict = self._content_cache.strict
- try:
- decoded = encoding.decode(content, enc)
- except ValueError:
- if strict:
- raise
- is_strict = False
- decoded = self.content.decode("utf8", "replace" if six.PY2 else "surrogateescape")
- self._text_cache = CachedDecode(content, enc, is_strict, decoded)
- return self._text_cache.decoded
+ try:
+ return encoding.decode(content, enc)
+ except ValueError:
+ if strict:
+ raise
+ return content.decode("utf8", "replace" if six.PY2 else "surrogateescape")
def set_text(self, text):
if text is None:
@@ -273,23 +222,15 @@ class Message(basetypes.Serializable):
return
enc = self._guess_encoding()
- cached = (
- self._text_cache.decoded == text and
- self._text_cache.encoding == enc and
- self._text_cache.strict
- )
- if not cached:
- try:
- encoded = encoding.encode(text, enc)
- except ValueError:
- # Fall back to UTF-8 and update the content-type header.
- ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
- ct[2]["charset"] = "utf-8"
- self.headers["content-type"] = headers.assemble_content_type(*ct)
- enc = "utf8"
- encoded = text.encode(enc, "replace" if six.PY2 else "surrogateescape")
- self._text_cache = CachedDecode(encoded, enc, True, text)
- self.content = self._text_cache.encoded
+ try:
+ self.content = encoding.encode(text, enc)
+ except ValueError:
+ # Fall back to UTF-8 and update the content-type header.
+ ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
+ ct[2]["charset"] = "utf-8"
+ self.headers["content-type"] = headers.assemble_content_type(*ct)
+ enc = "utf8"
+ self.content = text.encode(enc, "replace" if six.PY2 else "surrogateescape")
text = property(get_text, set_text)
diff --git a/netlib/http/request.py b/netlib/http/request.py
index ecaa9b79..061217a3 100644
--- a/netlib/http/request.py
+++ b/netlib/http/request.py
@@ -253,14 +253,13 @@ class Request(message.Message):
)
def _get_query(self):
- _, _, _, _, query, _ = urllib.parse.urlparse(self.url)
+ query = urllib.parse.urlparse(self.url).query
return tuple(netlib.http.url.decode(query))
- def _set_query(self, value):
- query = netlib.http.url.encode(value)
- scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)
- _, _, _, self.path = netlib.http.url.parse(
- urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
+ def _set_query(self, query_data):
+ query = netlib.http.url.encode(query_data)
+ _, _, path, params, _, fragment = urllib.parse.urlparse(self.url)
+ self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment])
@query.setter
def query(self, value):
@@ -296,19 +295,18 @@ class Request(message.Message):
The URL's path components as a tuple of strings.
Components are unquoted.
"""
- _, _, path, _, _, _ = urllib.parse.urlparse(self.url)
+ path = urllib.parse.urlparse(self.url).path
# This needs to be a tuple so that it's immutable.
# Otherwise, this would fail silently:
# request.path_components.append("foo")
- return tuple(urllib.parse.unquote(i) for i in path.split("/") if i)
+ return tuple(netlib.http.url.unquote(i) for i in path.split("/") if i)
@path_components.setter
def path_components(self, components):
- components = map(lambda x: urllib.parse.quote(x, safe=""), components)
+ components = map(lambda x: netlib.http.url.quote(x, safe=""), components)
path = "/" + "/".join(components)
- scheme, netloc, _, params, query, fragment = urllib.parse.urlparse(self.url)
- _, _, _, self.path = netlib.http.url.parse(
- urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment]))
+ _, _, _, params, query, fragment = urllib.parse.urlparse(self.url)
+ self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment])
def anticache(self):
"""
@@ -365,13 +363,13 @@ class Request(message.Message):
pass
return ()
- def _set_urlencoded_form(self, value):
+ def _set_urlencoded_form(self, form_data):
"""
Sets the body to the URL-encoded form data, and adds the appropriate content-type header.
This will overwrite the existing content if there is one.
"""
self.headers["content-type"] = "application/x-www-form-urlencoded"
- self.content = netlib.http.url.encode(value).encode()
+ self.content = netlib.http.url.encode(form_data).encode()
@urlencoded_form.setter
def urlencoded_form(self, value):
diff --git a/netlib/http/url.py b/netlib/http/url.py
index 2fc6e7ee..076854b9 100644
--- a/netlib/http/url.py
+++ b/netlib/http/url.py
@@ -82,18 +82,51 @@ def unparse(scheme, host, port, path=""):
def encode(s):
+ # type: Sequence[Tuple[str,str]] -> str
"""
Takes a list of (key, value) tuples and returns a urlencoded string.
"""
- s = [tuple(i) for i in s]
- return urllib.parse.urlencode(s, False)
+ if six.PY2:
+ return urllib.parse.urlencode(s, False)
+ else:
+ return urllib.parse.urlencode(s, False, errors="surrogateescape")
def decode(s):
"""
- Takes a urlencoded string and returns a list of (key, value) tuples.
+ Takes a urlencoded string and returns a list of surrogate-escaped (key, value) tuples.
+ """
+ if six.PY2:
+ return urllib.parse.parse_qsl(s, keep_blank_values=True)
+ else:
+ return urllib.parse.parse_qsl(s, keep_blank_values=True, errors='surrogateescape')
+
+
+def quote(b, safe="/"):
+ """
+ Returns:
+ An ascii-encodable str.
+ """
+ # type: (str) -> str
+ if six.PY2:
+ return urllib.parse.quote(b, safe=safe)
+ else:
+ return urllib.parse.quote(b, safe=safe, errors="surrogateescape")
+
+
+def unquote(s):
"""
- return urllib.parse.parse_qsl(s, keep_blank_values=True)
+ Args:
+ s: A surrogate-escaped str
+ Returns:
+ A surrogate-escaped str
+ """
+ # type: (str) -> str
+
+ if six.PY2:
+ return urllib.parse.unquote(s)
+ else:
+ return urllib.parse.unquote(s, errors="surrogateescape")
def hostport(scheme, host, port):
diff --git a/netlib/multidict.py b/netlib/multidict.py
index 51053ff6..e9fec155 100644
--- a/netlib/multidict.py
+++ b/netlib/multidict.py
@@ -79,9 +79,6 @@ class _MultiDict(MutableMapping, basetypes.Serializable):
def __ne__(self, other):
return not self.__eq__(other)
- def __hash__(self):
- return hash(self.fields)
-
def get_all(self, key):
"""
Return the list of all values for a given key.
@@ -241,6 +238,9 @@ class ImmutableMultiDict(MultiDict):
__delitem__ = set_all = insert = _immutable
+ def __hash__(self):
+ return hash(self.fields)
+
def with_delitem(self, key):
"""
Returns:
diff --git a/netlib/strutils.py b/netlib/strutils.py
index 32e77927..8f27ebb7 100644
--- a/netlib/strutils.py
+++ b/netlib/strutils.py
@@ -51,8 +51,7 @@ else:
def escape_control_characters(text, keep_spacing=True):
"""
- Replace all unicode C1 control characters from the given text with their respective control pictures.
- For example, a null byte is replaced with the unicode character "\u2400".
+ Replace all unicode C1 control characters from the given text with a single "."
Args:
keep_spacing: If True, tabs and newlines will not be replaced.
@@ -99,6 +98,9 @@ def bytes_to_escaped_str(data, keep_spacing=False):
def escaped_str_to_bytes(data):
"""
Take an escaped string and return the unescaped bytes equivalent.
+
+ Raises:
+ ValueError, if the escape sequence is invalid.
"""
if not isinstance(data, six.string_types):
if six.PY2:
diff --git a/test/mitmproxy/builtins/test_anticache.py b/test/mitmproxy/builtins/test_anticache.py
index 5a00af03..ac321e26 100644
--- a/test/mitmproxy/builtins/test_anticache.py
+++ b/test/mitmproxy/builtins/test_anticache.py
@@ -8,9 +8,10 @@ from mitmproxy import options
class TestAntiCache(mastertest.MasterTest):
def test_simple(self):
s = state.State()
- m = master.FlowMaster(options.Options(anticache = True), None, s)
+ o = options.Options(anticache = True)
+ m = master.FlowMaster(o, None, s)
sa = anticache.AntiCache()
- m.addons.add(sa)
+ m.addons.add(o, sa)
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
diff --git a/test/mitmproxy/builtins/test_anticomp.py b/test/mitmproxy/builtins/test_anticomp.py
index 6bfd54bb..a5f5a270 100644
--- a/test/mitmproxy/builtins/test_anticomp.py
+++ b/test/mitmproxy/builtins/test_anticomp.py
@@ -8,9 +8,10 @@ from mitmproxy import options
class TestAntiComp(mastertest.MasterTest):
def test_simple(self):
s = state.State()
- m = master.FlowMaster(options.Options(anticomp = True), None, s)
+ o = options.Options(anticomp = True)
+ m = master.FlowMaster(o, None, s)
sa = anticomp.AntiComp()
- m.addons.add(sa)
+ m.addons.add(o, sa)
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
diff --git a/test/mitmproxy/builtins/test_dumper.py b/test/mitmproxy/builtins/test_dumper.py
index 57e3d036..6287fe86 100644
--- a/test/mitmproxy/builtins/test_dumper.py
+++ b/test/mitmproxy/builtins/test_dumper.py
@@ -15,26 +15,27 @@ class TestDumper(mastertest.MasterTest):
d = dumper.Dumper()
sio = StringIO()
- d.configure(dump.Options(tfile = sio, flow_detail = 0))
+ updated = set(["tfile", "flow_detail"])
+ d.configure(dump.Options(tfile = sio, flow_detail = 0), updated)
d.response(tutils.tflow())
assert not sio.getvalue()
- d.configure(dump.Options(tfile = sio, flow_detail = 4))
+ d.configure(dump.Options(tfile = sio, flow_detail = 4), updated)
d.response(tutils.tflow())
assert sio.getvalue()
sio = StringIO()
- d.configure(dump.Options(tfile = sio, flow_detail = 4))
+ d.configure(dump.Options(tfile = sio, flow_detail = 4), updated)
d.response(tutils.tflow(resp=True))
assert "<<" in sio.getvalue()
sio = StringIO()
- d.configure(dump.Options(tfile = sio, flow_detail = 4))
+ d.configure(dump.Options(tfile = sio, flow_detail = 4), updated)
d.response(tutils.tflow(err=True))
assert "<<" in sio.getvalue()
sio = StringIO()
- d.configure(dump.Options(tfile = sio, flow_detail = 4))
+ d.configure(dump.Options(tfile = sio, flow_detail = 4), updated)
flow = tutils.tflow()
flow.request = netlib.tutils.treq()
flow.request.stickycookie = True
@@ -47,7 +48,7 @@ class TestDumper(mastertest.MasterTest):
assert sio.getvalue()
sio = StringIO()
- d.configure(dump.Options(tfile = sio, flow_detail = 4))
+ d.configure(dump.Options(tfile = sio, flow_detail = 4), updated)
flow = tutils.tflow(resp=netlib.tutils.tresp(content=b"{"))
flow.response.headers["content-type"] = "application/json"
flow.response.status_code = 400
@@ -55,7 +56,7 @@ class TestDumper(mastertest.MasterTest):
assert sio.getvalue()
sio = StringIO()
- d.configure(dump.Options(tfile = sio))
+ d.configure(dump.Options(tfile = sio), updated)
flow = tutils.tflow()
flow.request.content = None
flow.response = models.HTTPResponse.wrap(netlib.tutils.tresp())
@@ -72,15 +73,13 @@ class TestContentView(mastertest.MasterTest):
s = state.State()
sio = StringIO()
- m = mastertest.RecordingMaster(
- dump.Options(
- flow_detail=4,
- verbosity=3,
- tfile=sio,
- ),
- None, s
+ o = dump.Options(
+ flow_detail=4,
+ verbosity=3,
+ tfile=sio,
)
+ m = mastertest.RecordingMaster(o, None, s)
d = dumper.Dumper()
- m.addons.add(d)
+ m.addons.add(o, d)
self.invoke(m, "response", tutils.tflow())
assert "Content viewer failed" in m.event_log[0][1]
diff --git a/test/mitmproxy/builtins/test_filestreamer.py b/test/mitmproxy/builtins/test_filestreamer.py
index c1d5947f..0e69b340 100644
--- a/test/mitmproxy/builtins/test_filestreamer.py
+++ b/test/mitmproxy/builtins/test_filestreamer.py
@@ -20,16 +20,13 @@ class TestStream(mastertest.MasterTest):
return list(r.stream())
s = state.State()
- m = master.FlowMaster(
- options.Options(
- outfile = (p, "wb")
- ),
- None,
- s
+ o = options.Options(
+ outfile = (p, "wb")
)
+ m = master.FlowMaster(o, None, s)
sa = filestreamer.FileStreamer()
- m.addons.add(sa)
+ m.addons.add(o, sa)
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
self.invoke(m, "response", f)
@@ -39,7 +36,7 @@ class TestStream(mastertest.MasterTest):
m.options.outfile = (p, "ab")
- m.addons.add(sa)
+ m.addons.add(o, sa)
f = tutils.tflow()
self.invoke(m, "request", f)
m.addons.remove(sa)
diff --git a/test/mitmproxy/builtins/test_replace.py b/test/mitmproxy/builtins/test_replace.py
index a0b4b722..5e70ce56 100644
--- a/test/mitmproxy/builtins/test_replace.py
+++ b/test/mitmproxy/builtins/test_replace.py
@@ -8,38 +8,38 @@ from mitmproxy import options
class TestReplace(mastertest.MasterTest):
def test_configure(self):
r = replace.Replace()
+ updated = set(["replacements"])
r.configure(options.Options(
replacements=[("one", "two", "three")]
- ))
+ ), updated)
tutils.raises(
"invalid filter pattern",
r.configure,
options.Options(
replacements=[("~b", "two", "three")]
- )
+ ),
+ updated
)
tutils.raises(
"invalid regular expression",
r.configure,
options.Options(
replacements=[("foo", "+", "three")]
- )
+ ),
+ updated
)
def test_simple(self):
s = state.State()
- m = master.FlowMaster(
- options.Options(
- replacements = [
- ("~q", "foo", "bar"),
- ("~s", "foo", "bar"),
- ]
- ),
- None,
- s
+ o = options.Options(
+ replacements = [
+ ("~q", "foo", "bar"),
+ ("~s", "foo", "bar"),
+ ]
)
+ m = master.FlowMaster(o, None, s)
sa = replace.Replace()
- m.addons.add(sa)
+ m.addons.add(o, sa)
f = tutils.tflow()
f.request.content = b"foo"
diff --git a/test/mitmproxy/builtins/test_script.py b/test/mitmproxy/builtins/test_script.py
index f37c7f94..2870fd17 100644
--- a/test/mitmproxy/builtins/test_script.py
+++ b/test/mitmproxy/builtins/test_script.py
@@ -48,39 +48,41 @@ def test_load_script():
"data/addonscripts/recorder.py"
), []
)
- assert ns["configure"]
+ assert ns.start
class TestScript(mastertest.MasterTest):
def test_simple(self):
s = state.State()
- m = master.FlowMaster(options.Options(), None, s)
+ o = options.Options()
+ m = master.FlowMaster(o, None, s)
sc = script.Script(
tutils.test_data.path(
"data/addonscripts/recorder.py"
)
)
- m.addons.add(sc)
- assert sc.ns["call_log"] == [
+ m.addons.add(o, sc)
+ assert sc.ns.call_log == [
("solo", "start", (), {}),
- ("solo", "configure", (options.Options(),), {})
+ ("solo", "configure", (o, o.keys()), {})
]
- sc.ns["call_log"] = []
+ sc.ns.call_log = []
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
- recf = sc.ns["call_log"][0]
+ recf = sc.ns.call_log[0]
assert recf[1] == "request"
def test_reload(self):
s = state.State()
- m = mastertest.RecordingMaster(options.Options(), None, s)
+ o = options.Options()
+ m = mastertest.RecordingMaster(o, None, s)
with tutils.tmpdir():
with open("foo.py", "w"):
pass
sc = script.Script("foo.py")
- m.addons.add(sc)
+ m.addons.add(o, sc)
for _ in range(100):
with open("foo.py", "a") as f:
@@ -93,19 +95,22 @@ class TestScript(mastertest.MasterTest):
def test_exception(self):
s = state.State()
- m = mastertest.RecordingMaster(options.Options(), None, s)
+ o = options.Options()
+ m = mastertest.RecordingMaster(o, None, s)
sc = script.Script(
tutils.test_data.path("data/addonscripts/error.py")
)
- m.addons.add(sc)
+ m.addons.add(o, sc)
f = tutils.tflow(resp=True)
self.invoke(m, "request", f)
assert m.event_log[0][0] == "error"
def test_duplicate_flow(self):
s = state.State()
- fm = master.FlowMaster(None, None, s)
+ o = options.Options()
+ fm = master.FlowMaster(o, None, s)
fm.addons.add(
+ o,
script.Script(
tutils.test_data.path("data/addonscripts/duplicate_flow.py")
)
@@ -116,6 +121,20 @@ class TestScript(mastertest.MasterTest):
assert not fm.state.view[0].request.is_replay
assert fm.state.view[1].request.is_replay
+ def test_addon(self):
+ s = state.State()
+ o = options.Options()
+ m = master.FlowMaster(o, None, s)
+ sc = script.Script(
+ tutils.test_data.path(
+ "data/addonscripts/addon.py"
+ )
+ )
+ m.addons.add(o, sc)
+ assert sc.ns.event_log == [
+ 'scriptstart', 'addonstart', 'addonconfigure'
+ ]
+
class TestScriptLoader(mastertest.MasterTest):
def test_simple(self):
@@ -123,7 +142,7 @@ class TestScriptLoader(mastertest.MasterTest):
o = options.Options(scripts=[])
m = master.FlowMaster(o, None, s)
sc = script.ScriptLoader()
- m.addons.add(sc)
+ m.addons.add(o, sc)
assert len(m.addons) == 1
o.update(
scripts = [
@@ -139,7 +158,7 @@ class TestScriptLoader(mastertest.MasterTest):
o = options.Options(scripts=["one", "one"])
m = master.FlowMaster(o, None, s)
sc = script.ScriptLoader()
- tutils.raises(exceptions.OptionsError, m.addons.add, sc)
+ tutils.raises(exceptions.OptionsError, m.addons.add, o, sc)
def test_order(self):
rec = tutils.test_data.path("data/addonscripts/recorder.py")
@@ -154,7 +173,7 @@ class TestScriptLoader(mastertest.MasterTest):
)
m = mastertest.RecordingMaster(o, None, s)
sc = script.ScriptLoader()
- m.addons.add(sc)
+ m.addons.add(o, sc)
debug = [(i[0], i[1]) for i in m.event_log if i[0] == "debug"]
assert debug == [
diff --git a/test/mitmproxy/builtins/test_setheaders.py b/test/mitmproxy/builtins/test_setheaders.py
index 4465719d..41c18360 100644
--- a/test/mitmproxy/builtins/test_setheaders.py
+++ b/test/mitmproxy/builtins/test_setheaders.py
@@ -8,19 +8,20 @@ from mitmproxy import options
class TestSetHeaders(mastertest.MasterTest):
def mkmaster(self, **opts):
s = state.State()
- m = mastertest.RecordingMaster(options.Options(**opts), None, s)
+ o = options.Options(**opts)
+ m = mastertest.RecordingMaster(o, None, s)
sh = setheaders.SetHeaders()
- m.addons.add(sh)
+ m.addons.add(o, sh)
return m, sh
def test_configure(self):
sh = setheaders.SetHeaders()
+ o = options.Options(
+ setheaders = [("~b", "one", "two")]
+ )
tutils.raises(
"invalid setheader filter pattern",
- sh.configure,
- options.Options(
- setheaders = [("~b", "one", "two")]
- )
+ sh.configure, o, o.keys()
)
def test_setheaders(self):
diff --git a/test/mitmproxy/builtins/test_stickyauth.py b/test/mitmproxy/builtins/test_stickyauth.py
index 9233f435..5757fb2d 100644
--- a/test/mitmproxy/builtins/test_stickyauth.py
+++ b/test/mitmproxy/builtins/test_stickyauth.py
@@ -8,9 +8,10 @@ from mitmproxy import options
class TestStickyAuth(mastertest.MasterTest):
def test_simple(self):
s = state.State()
- m = master.FlowMaster(options.Options(stickyauth = ".*"), None, s)
+ o = options.Options(stickyauth = ".*")
+ m = master.FlowMaster(o, None, s)
sa = stickyauth.StickyAuth()
- m.addons.add(sa)
+ m.addons.add(o, sa)
f = tutils.tflow(resp=True)
f.request.headers["authorization"] = "foo"
diff --git a/test/mitmproxy/builtins/test_stickycookie.py b/test/mitmproxy/builtins/test_stickycookie.py
index 81b540db..e9d92c83 100644
--- a/test/mitmproxy/builtins/test_stickycookie.py
+++ b/test/mitmproxy/builtins/test_stickycookie.py
@@ -14,22 +14,23 @@ def test_domain_match():
class TestStickyCookie(mastertest.MasterTest):
def mk(self):
s = state.State()
- m = master.FlowMaster(options.Options(stickycookie = ".*"), None, s)
+ o = options.Options(stickycookie = ".*")
+ m = master.FlowMaster(o, None, s)
sc = stickycookie.StickyCookie()
- m.addons.add(sc)
+ m.addons.add(o, sc)
return s, m, sc
def test_config(self):
sc = stickycookie.StickyCookie()
+ o = options.Options(stickycookie = "~b")
tutils.raises(
"invalid filter",
- sc.configure,
- options.Options(stickycookie = "~b")
+ sc.configure, o, o.keys()
)
def test_simple(self):
s, m, sc = self.mk()
- m.addons.add(sc)
+ m.addons.add(m.options, sc)
f = tutils.tflow(resp=True)
f.response.headers["set-cookie"] = "foo=bar"
diff --git a/test/mitmproxy/data/addonscripts/addon.py b/test/mitmproxy/data/addonscripts/addon.py
new file mode 100644
index 00000000..84173cb6
--- /dev/null
+++ b/test/mitmproxy/data/addonscripts/addon.py
@@ -0,0 +1,22 @@
+event_log = []
+
+
+class Addon:
+ @property
+ def event_log(self):
+ return event_log
+
+ def start(self):
+ event_log.append("addonstart")
+
+ def configure(self, options, updated):
+ event_log.append("addonconfigure")
+
+
+def configure(options, updated):
+ event_log.append("addonconfigure")
+
+
+def start():
+ event_log.append("scriptstart")
+ return Addon()
diff --git a/test/mitmproxy/data/addonscripts/recorder.py b/test/mitmproxy/data/addonscripts/recorder.py
index b6ac8d89..890e6f4e 100644
--- a/test/mitmproxy/data/addonscripts/recorder.py
+++ b/test/mitmproxy/data/addonscripts/recorder.py
@@ -2,24 +2,24 @@ from mitmproxy import controller
from mitmproxy import ctx
import sys
-call_log = []
-if len(sys.argv) > 1:
- name = sys.argv[1]
-else:
- name = "solo"
+class CallLogger:
+ call_log = []
-# Keep a log of all possible event calls
-evts = list(controller.Events) + ["configure"]
-for i in evts:
- def mkprox():
- evt = i
+ def __init__(self, name = "solo"):
+ self.name = name
- def prox(*args, **kwargs):
- lg = (name, evt, args, kwargs)
- if evt != "log":
- ctx.log.info(str(lg))
- call_log.append(lg)
- ctx.log.debug("%s %s" % (name, evt))
- return prox
- globals()[i] = mkprox()
+ def __getattr__(self, attr):
+ if attr in controller.Events:
+ def prox(*args, **kwargs):
+ lg = (self.name, attr, args, kwargs)
+ if attr != "log":
+ ctx.log.info(str(lg))
+ self.call_log.append(lg)
+ ctx.log.debug("%s %s" % (self.name, attr))
+ return prox
+ raise AttributeError
+
+
+def start():
+ return CallLogger(*sys.argv[1:])
diff --git a/test/mitmproxy/data/dumpfile-011 b/test/mitmproxy/data/dumpfile-011
index 2534ad89..936ac0cc 100644
--- a/test/mitmproxy/data/dumpfile-011
+++ b/test/mitmproxy/data/dumpfile-011
Binary files differ
diff --git a/test/mitmproxy/script/test_concurrent.py b/test/mitmproxy/script/test_concurrent.py
index 080746e8..a5f76994 100644
--- a/test/mitmproxy/script/test_concurrent.py
+++ b/test/mitmproxy/script/test_concurrent.py
@@ -23,7 +23,7 @@ class TestConcurrent(mastertest.MasterTest):
"data/addonscripts/concurrent_decorator.py"
)
)
- m.addons.add(sc)
+ m.addons.add(m.options, sc)
f1, f2 = tutils.tflow(), tutils.tflow()
self.invoke(m, "request", f1)
self.invoke(m, "request", f2)
diff --git a/test/mitmproxy/test_addons.py b/test/mitmproxy/test_addons.py
index 1861d4ac..a5085ea0 100644
--- a/test/mitmproxy/test_addons.py
+++ b/test/mitmproxy/test_addons.py
@@ -13,8 +13,9 @@ class TAddon:
def test_simple():
- m = controller.Master(options.Options())
+ o = options.Options()
+ m = controller.Master(o)
a = addons.Addons(m)
- a.add(TAddon("one"))
+ a.add(o, TAddon("one"))
assert a.has_addon("one")
assert not a.has_addon("two")
diff --git a/test/mitmproxy/test_contentview.py b/test/mitmproxy/test_contentview.py
index 2db9ab40..aad53b37 100644
--- a/test/mitmproxy/test_contentview.py
+++ b/test/mitmproxy/test_contentview.py
@@ -59,10 +59,10 @@ class TestContentView:
assert f[0] == "Query"
def test_view_urlencoded(self):
- d = url.encode([("one", "two"), ("three", "four")])
+ d = url.encode([("one", "two"), ("three", "four")]).encode()
v = cv.ViewURLEncoded()
assert v(d)
- d = url.encode([("adsfa", "")])
+ d = url.encode([("adsfa", "")]).encode()
v = cv.ViewURLEncoded()
assert v(d)
diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py
index 0ec85f52..34fcc261 100644
--- a/test/mitmproxy/test_examples.py
+++ b/test/mitmproxy/test_examples.py
@@ -27,10 +27,11 @@ class RaiseMaster(master.FlowMaster):
def tscript(cmd, args=""):
+ o = options.Options()
cmd = example_dir.path(cmd) + " " + args
- m = RaiseMaster(options.Options(), None, state.State())
+ m = RaiseMaster(o, None, state.State())
sc = script.Script(cmd)
- m.addons.add(sc)
+ m.addons.add(o, sc)
return m, sc
diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py
index 36b212a7..74992130 100644
--- a/test/mitmproxy/test_flow.py
+++ b/test/mitmproxy/test_flow.py
@@ -615,6 +615,7 @@ class TestSerialize:
def test_roundtrip(self):
sio = io.BytesIO()
f = tutils.tflow()
+ f.marked = True
f.request.content = bytes(bytearray(range(256)))
w = flow.FlowWriter(sio)
w.add(f)
@@ -627,6 +628,7 @@ class TestSerialize:
f2 = l[0]
assert f2.get_state() == f.get_state()
assert f2.request == f.request
+ assert f2.marked
def test_load_flows(self):
r = self._treader()
diff --git a/test/mitmproxy/test_optmanager.py b/test/mitmproxy/test_optmanager.py
index 67f76ecd..8414e6b5 100644
--- a/test/mitmproxy/test_optmanager.py
+++ b/test/mitmproxy/test_optmanager.py
@@ -15,6 +15,8 @@ class TO(optmanager.OptManager):
def test_options():
o = TO(two="three")
+ assert o.keys() == set(["one", "two"])
+
assert o.one is None
assert o.two == "three"
o.one = "one"
@@ -29,7 +31,7 @@ def test_options():
rec = []
- def sub(opts):
+ def sub(opts, updated):
rec.append(copy.copy(opts))
o.changed.connect(sub)
@@ -68,7 +70,7 @@ def test_rollback():
rec = []
- def sub(opts):
+ def sub(opts, updated):
rec.append(copy.copy(opts))
recerr = []
@@ -76,7 +78,7 @@ def test_rollback():
def errsub(opts, **kwargs):
recerr.append(kwargs)
- def err(opts):
+ def err(opts, updated):
if opts.one == "ten":
raise exceptions.OptionsError()
diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py
index afbffb67..aa096a72 100644
--- a/test/mitmproxy/test_protocol_http2.py
+++ b/test/mitmproxy/test_protocol_http2.py
@@ -30,7 +30,7 @@ logging.getLogger("PIL.PngImagePlugin").setLevel(logging.WARNING)
requires_alpn = pytest.mark.skipif(
not netlib.tcp.HAS_ALPN,
- reason="requires OpenSSL with ALPN support")
+ reason='requires OpenSSL with ALPN support')
class _Http2ServerBase(netlib_tservers.ServerTestBase):
@@ -80,7 +80,7 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
print(traceback.format_exc())
break
- def handle_server_event(self, h2_conn, rfile, wfile):
+ def handle_server_event(self, event, h2_conn, rfile, wfile):
raise NotImplementedError()
@@ -88,7 +88,6 @@ class _Http2TestBase(object):
@classmethod
def setup_class(cls):
- cls.masteroptions = options.Options()
opts = cls.get_options()
cls.config = ProxyConfig(opts)
@@ -145,12 +144,14 @@ class _Http2TestBase(object):
wfile,
h2_conn,
stream_id=1,
- headers=[],
+ headers=None,
body=b'',
end_stream=None,
priority_exclusive=None,
priority_depends_on=None,
priority_weight=None):
+ if headers is None:
+ headers = []
if end_stream is None:
end_stream = (len(body) == 0)
@@ -172,12 +173,12 @@ class _Http2TestBase(object):
class _Http2Test(_Http2TestBase, _Http2ServerBase):
@classmethod
- def setup_class(self):
+ def setup_class(cls):
_Http2TestBase.setup_class()
_Http2ServerBase.setup_class()
@classmethod
- def teardown_class(self):
+ def teardown_class(cls):
_Http2TestBase.teardown_class()
_Http2ServerBase.teardown_class()
@@ -187,7 +188,7 @@ class TestSimple(_Http2Test):
request_body_buffer = b''
@classmethod
- def handle_server_event(self, event, h2_conn, rfile, wfile):
+ def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
@@ -214,7 +215,7 @@ class TestSimple(_Http2Test):
wfile.write(h2_conn.data_to_send())
wfile.flush()
elif isinstance(event, h2.events.DataReceived):
- self.request_body_buffer += event.data
+ cls.request_body_buffer += event.data
return True
def test_simple(self):
@@ -225,7 +226,7 @@ class TestSimple(_Http2Test):
client.wfile,
h2_conn,
headers=[
- (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@@ -269,7 +270,7 @@ class TestSimple(_Http2Test):
class TestRequestWithPriority(_Http2Test):
@classmethod
- def handle_server_event(self, event, h2_conn, rfile, wfile):
+ def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
@@ -301,14 +302,14 @@ class TestRequestWithPriority(_Http2Test):
client.wfile,
h2_conn,
headers=[
- (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
],
- priority_exclusive = True,
- priority_depends_on = 42424242,
- priority_weight = 42,
+ priority_exclusive=True,
+ priority_depends_on=42424242,
+ priority_weight=42,
)
done = False
@@ -343,7 +344,7 @@ class TestRequestWithPriority(_Http2Test):
client.wfile,
h2_conn,
headers=[
- (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@@ -381,11 +382,11 @@ class TestPriority(_Http2Test):
priority_data = None
@classmethod
- def handle_server_event(self, event, h2_conn, rfile, wfile):
+ def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.PriorityUpdated):
- self.priority_data = (event.exclusive, event.depends_on, event.weight)
+ cls.priority_data = (event.exclusive, event.depends_on, event.weight)
elif isinstance(event, h2.events.RequestReceived):
import warnings
with warnings.catch_warnings():
@@ -415,7 +416,7 @@ class TestPriority(_Http2Test):
client.wfile,
h2_conn,
headers=[
- (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@@ -451,11 +452,11 @@ class TestPriorityWithExistingStream(_Http2Test):
priority_data = []
@classmethod
- def handle_server_event(self, event, h2_conn, rfile, wfile):
+ def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.PriorityUpdated):
- self.priority_data.append((event.exclusive, event.depends_on, event.weight))
+ cls.priority_data.append((event.exclusive, event.depends_on, event.weight))
elif isinstance(event, h2.events.RequestReceived):
assert not event.priority_updated
@@ -486,7 +487,7 @@ class TestPriorityWithExistingStream(_Http2Test):
client.wfile,
h2_conn,
headers=[
- (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@@ -527,7 +528,7 @@ class TestPriorityWithExistingStream(_Http2Test):
class TestStreamResetFromServer(_Http2Test):
@classmethod
- def handle_server_event(self, event, h2_conn, rfile, wfile):
+ def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
@@ -543,7 +544,7 @@ class TestStreamResetFromServer(_Http2Test):
client.wfile,
h2_conn,
headers=[
- (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@@ -578,7 +579,7 @@ class TestStreamResetFromServer(_Http2Test):
class TestBodySizeLimit(_Http2Test):
@classmethod
- def handle_server_event(self, event, h2_conn, rfile, wfile):
+ def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
return True
@@ -592,7 +593,7 @@ class TestBodySizeLimit(_Http2Test):
client.wfile,
h2_conn,
headers=[
- (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@@ -627,7 +628,7 @@ class TestBodySizeLimit(_Http2Test):
class TestPushPromise(_Http2Test):
@classmethod
- def handle_server_event(self, event, h2_conn, rfile, wfile):
+ def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
@@ -637,14 +638,14 @@ class TestPushPromise(_Http2Test):
h2_conn.send_headers(1, [(':status', '200')])
h2_conn.push_stream(1, 2, [
- (':authority', "127.0.0.1:%s" % self.port),
+ (':authority', "127.0.0.1:{}".format(cls.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/pushed_stream_foo'),
('foo', 'bar')
])
h2_conn.push_stream(1, 4, [
- (':authority', "127.0.0.1:%s" % self.port),
+ (':authority', "127.0.0.1:{}".format(cls.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/pushed_stream_bar'),
@@ -675,7 +676,7 @@ class TestPushPromise(_Http2Test):
client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
- (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@@ -728,7 +729,7 @@ class TestPushPromise(_Http2Test):
client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
- (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@@ -780,7 +781,7 @@ class TestPushPromise(_Http2Test):
class TestConnectionLost(_Http2Test):
@classmethod
- def handle_server_event(self, event, h2_conn, rfile, wfile):
+ def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.RequestReceived):
h2_conn.send_headers(1, [(':status', '200')])
wfile.write(h2_conn.data_to_send())
@@ -791,7 +792,7 @@ class TestConnectionLost(_Http2Test):
client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
- (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@@ -822,12 +823,12 @@ class TestConnectionLost(_Http2Test):
class TestMaxConcurrentStreams(_Http2Test):
@classmethod
- def setup_class(self):
+ def setup_class(cls):
_Http2TestBase.setup_class()
_Http2ServerBase.setup_class(h2_server_settings={h2.settings.MAX_CONCURRENT_STREAMS: 2})
@classmethod
- def handle_server_event(self, event, h2_conn, rfile, wfile):
+ def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.ConnectionTerminated):
return False
elif isinstance(event, h2.events.RequestReceived):
@@ -848,7 +849,7 @@ class TestMaxConcurrentStreams(_Http2Test):
# this will exceed MAX_CONCURRENT_STREAMS on the server connection
# and cause mitmproxy to throttle stream creation to the server
self._send_request(client.wfile, h2_conn, stream_id=id, headers=[
- (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@@ -883,7 +884,7 @@ class TestMaxConcurrentStreams(_Http2Test):
class TestConnectionTerminated(_Http2Test):
@classmethod
- def handle_server_event(self, event, h2_conn, rfile, wfile):
+ def handle_server_event(cls, event, h2_conn, rfile, wfile):
if isinstance(event, h2.events.RequestReceived):
h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data=b'foobar')
wfile.write(h2_conn.data_to_send())
@@ -894,7 +895,7 @@ class TestConnectionTerminated(_Http2Test):
client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, headers=[
- (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
diff --git a/test/mitmproxy/test_server.py b/test/mitmproxy/test_server.py
index 233af597..6230fc1f 100644
--- a/test/mitmproxy/test_server.py
+++ b/test/mitmproxy/test_server.py
@@ -291,7 +291,7 @@ class TestHTTP(tservers.HTTPProxyTest, CommonMixin, AppMixin):
s = script.Script(
tutils.test_data.path("data/addonscripts/stream_modify.py")
)
- self.master.addons.add(s)
+ self.master.addons.add(self.master.options, s)
d = self.pathod('200:b"foo"')
assert d.content == b"bar"
self.master.addons.remove(s)
@@ -523,7 +523,7 @@ class TestTransparent(tservers.TransparentProxyTest, CommonMixin, TcpMixin):
s = script.Script(
tutils.test_data.path("data/addonscripts/tcp_stream_modify.py")
)
- self.master.addons.add(s)
+ self.master.addons.add(self.master.options, s)
self._tcpproxy_on()
d = self.pathod('200:b"foo"')
self._tcpproxy_off()
diff --git a/test/mitmproxy/tservers.py b/test/mitmproxy/tservers.py
index f5119166..d364162c 100644
--- a/test/mitmproxy/tservers.py
+++ b/test/mitmproxy/tservers.py
@@ -34,7 +34,7 @@ class TestMaster(flow.FlowMaster):
s = ProxyServer(config)
state = flow.State()
flow.FlowMaster.__init__(self, opts, s, state)
- self.addons.add(*builtins.default_addons())
+ self.addons.add(opts, *builtins.default_addons())
self.apps.add(testapp, "testapp", 80)
self.apps.add(errapp, "errapp", 80)
self.clear_log()
diff --git a/test/netlib/http/test_message.py b/test/netlib/http/test_message.py
index deebd6f2..12e4706c 100644
--- a/test/netlib/http/test_message.py
+++ b/test/netlib/http/test_message.py
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function, division
-import mock
import six
from netlib.tutils import tresp
@@ -71,10 +70,6 @@ class TestMessage(object):
assert resp != 0
- def test_hash(self):
- resp = tresp()
- assert hash(resp)
-
def test_serializable(self):
resp = tresp()
resp2 = http.Response.from_state(resp.get_state())
@@ -117,14 +112,6 @@ class TestMessageContentEncoding(object):
assert r.content == b"message"
assert r.raw_content != b"message"
- r.raw_content = b"foo"
- with mock.patch("netlib.encoding.decode") as e:
- assert r.content
- assert e.call_count == 1
- e.reset_mock()
- assert r.content
- assert e.call_count == 0
-
def test_modify(self):
r = tresp()
assert "content-encoding" not in r.headers
@@ -135,13 +122,6 @@ class TestMessageContentEncoding(object):
r.decode()
assert r.raw_content == b"foo"
- r.encode("identity")
- with mock.patch("netlib.encoding.encode") as e:
- r.content = b"foo"
- assert e.call_count == 0
- r.content = b"bar"
- assert e.call_count == 1
-
with tutils.raises(TypeError):
r.content = u"foo"
@@ -216,15 +196,6 @@ class TestMessageText(object):
r.headers["content-type"] = "text/html; charset=utf8"
assert r.text == u"ü"
- r.encode("identity")
- r.raw_content = b"foo"
- with mock.patch("netlib.encoding.decode") as e:
- assert r.text
- assert e.call_count == 2
- e.reset_mock()
- assert r.text
- assert e.call_count == 0
-
def test_guess_json(self):
r = tresp(content=b'"\xc3\xbc"')
r.headers["content-type"] = "application/json"
@@ -249,14 +220,6 @@ class TestMessageText(object):
assert r.raw_content == b"\xc3\xbc"
assert r.headers["content-length"] == "2"
- r.encode("identity")
- with mock.patch("netlib.encoding.encode") as e:
- e.return_value = b""
- r.text = u"ü"
- assert e.call_count == 0
- r.text = u"ä"
- assert e.call_count == 2
-
def test_unknown_ce(self):
r = tresp()
r.headers["content-type"] = "text/html; charset=wtf"
diff --git a/test/netlib/http/test_url.py b/test/netlib/http/test_url.py
index 26b37230..768e5130 100644
--- a/test/netlib/http/test_url.py
+++ b/test/netlib/http/test_url.py
@@ -1,3 +1,4 @@
+import six
from netlib import tutils
from netlib.http import url
@@ -57,10 +58,49 @@ def test_unparse():
assert url.unparse("https", "foo.com", 443, "") == "https://foo.com"
-def test_urlencode():
+if six.PY2:
+ surrogates = bytes(bytearray(range(256)))
+else:
+ surrogates = bytes(range(256)).decode("utf8", "surrogateescape")
+
+surrogates_quoted = (
+ '%00%01%02%03%04%05%06%07%08%09%0A%0B%0C%0D%0E%0F'
+ '%10%11%12%13%14%15%16%17%18%19%1A%1B%1C%1D%1E%1F'
+ '%20%21%22%23%24%25%26%27%28%29%2A%2B%2C-./'
+ '0123456789%3A%3B%3C%3D%3E%3F'
+ '%40ABCDEFGHIJKLMNO'
+ 'PQRSTUVWXYZ%5B%5C%5D%5E_'
+ '%60abcdefghijklmno'
+ 'pqrstuvwxyz%7B%7C%7D%7E%7F'
+ '%80%81%82%83%84%85%86%87%88%89%8A%8B%8C%8D%8E%8F'
+ '%90%91%92%93%94%95%96%97%98%99%9A%9B%9C%9D%9E%9F'
+ '%A0%A1%A2%A3%A4%A5%A6%A7%A8%A9%AA%AB%AC%AD%AE%AF'
+ '%B0%B1%B2%B3%B4%B5%B6%B7%B8%B9%BA%BB%BC%BD%BE%BF'
+ '%C0%C1%C2%C3%C4%C5%C6%C7%C8%C9%CA%CB%CC%CD%CE%CF'
+ '%D0%D1%D2%D3%D4%D5%D6%D7%D8%D9%DA%DB%DC%DD%DE%DF'
+ '%E0%E1%E2%E3%E4%E5%E6%E7%E8%E9%EA%EB%EC%ED%EE%EF'
+ '%F0%F1%F2%F3%F4%F5%F6%F7%F8%F9%FA%FB%FC%FD%FE%FF'
+)
+
+
+def test_encode():
assert url.encode([('foo', 'bar')])
+ assert url.encode([('foo', surrogates)])
-def test_urldecode():
+def test_decode():
s = "one=two&three=four"
assert len(url.decode(s)) == 2
+ assert url.decode(surrogates)
+
+
+def test_quote():
+ assert url.quote("foo") == "foo"
+ assert url.quote("foo bar") == "foo%20bar"
+ assert url.quote(surrogates) == surrogates_quoted
+
+
+def test_unquote():
+ assert url.unquote("foo") == "foo"
+ assert url.unquote("foo%20bar") == "foo bar"
+ assert url.unquote(surrogates_quoted) == surrogates
diff --git a/test/netlib/test_encoding.py b/test/netlib/test_encoding.py
index de10fc48..a5e81379 100644
--- a/test/netlib/test_encoding.py
+++ b/test/netlib/test_encoding.py
@@ -1,3 +1,4 @@
+import mock
from netlib import encoding, tutils
@@ -37,3 +38,32 @@ def test_deflate():
)
with tutils.raises(ValueError):
encoding.decode(b"bogus", "deflate")
+
+
+def test_cache():
+ decode_gzip = mock.MagicMock()
+ decode_gzip.return_value = b"decoded"
+ encode_gzip = mock.MagicMock()
+ encode_gzip.return_value = b"encoded"
+
+ with mock.patch.dict(encoding.custom_decode, gzip=decode_gzip):
+ with mock.patch.dict(encoding.custom_encode, gzip=encode_gzip):
+ assert encoding.decode(b"encoded", "gzip") == b"decoded"
+ assert decode_gzip.call_count == 1
+
+ # should be cached
+ assert encoding.decode(b"encoded", "gzip") == b"decoded"
+ assert decode_gzip.call_count == 1
+
+ # the other way around as well
+ assert encoding.encode(b"decoded", "gzip") == b"encoded"
+ assert encode_gzip.call_count == 0
+
+ # different encoding
+ decode_gzip.return_value = b"bar"
+ assert encoding.encode(b"decoded", "deflate") != b"decoded"
+ assert encode_gzip.call_count == 0
+
+ # This is not in the cache anymore
+ assert encoding.encode(b"decoded", "gzip") == b"encoded"
+ assert encode_gzip.call_count == 1
diff --git a/test/netlib/test_multidict.py b/test/netlib/test_multidict.py
index 038441e7..58ae0f98 100644
--- a/test/netlib/test_multidict.py
+++ b/test/netlib/test_multidict.py
@@ -45,7 +45,7 @@ class TestMultiDict(object):
assert md["foo"] == "bar"
with tutils.raises(KeyError):
- md["bar"]
+ assert md["bar"]
md_multi = TMultiDict(
[("foo", "a"), ("foo", "b")]
@@ -101,6 +101,15 @@ class TestMultiDict(object):
assert TMultiDict() != self._multi()
assert TMultiDict() != 42
+ def test_hash(self):
+ """
+ If a class defines mutable objects and implements an __eq__() method,
+ it should not implement __hash__(), since the implementation of hashable
+ collections requires that a key's hash value is immutable.
+ """
+ with tutils.raises(TypeError):
+ assert hash(TMultiDict())
+
def test_get_all(self):
md = self._multi()
assert md.get_all("foo") == ["bar"]
@@ -197,6 +206,9 @@ class TestImmutableMultiDict(object):
with tutils.raises(TypeError):
md.add("foo", "bar")
+ def test_hash(self):
+ assert hash(TImmutableMultiDict())
+
def test_with_delitem(self):
md = TImmutableMultiDict([("foo", "bar")])
assert md.with_delitem("foo").fields == ()
diff --git a/web/src/js/components/FlowTable/FlowRow.jsx b/web/src/js/components/FlowTable/FlowRow.jsx
index 749bc0ce..7961d502 100644
--- a/web/src/js/components/FlowTable/FlowRow.jsx
+++ b/web/src/js/components/FlowTable/FlowRow.jsx
@@ -1,6 +1,7 @@
import React, { PropTypes } from 'react'
import classnames from 'classnames'
import columns from './FlowColumns'
+import { pure } from '../../utils'
FlowRow.propTypes = {
onSelect: PropTypes.func.isRequired,
@@ -9,7 +10,7 @@ FlowRow.propTypes = {
selected: PropTypes.bool,
}
-export default function FlowRow({ flow, selected, highlighted, onSelect }) {
+function FlowRow({ flow, selected, highlighted, onSelect }) {
const className = classnames({
'selected': selected,
'highlighted': highlighted,
@@ -19,10 +20,12 @@ export default function FlowRow({ flow, selected, highlighted, onSelect }) {
})
return (
- <tr className={className} onClick={() => onSelect(flow)}>
+ <tr className={className} onClick={() => onSelect(flow.id)}>
{columns.map(Column => (
<Column key={Column.name} flow={flow}/>
))}
</tr>
)
}
+
+export default pure(FlowRow)
diff --git a/web/src/js/components/MainView.jsx b/web/src/js/components/MainView.jsx
index d7d1ebeb..f45f9eef 100644
--- a/web/src/js/components/MainView.jsx
+++ b/web/src/js/components/MainView.jsx
@@ -22,7 +22,7 @@ class MainView extends Component {
flows={flows}
selected={selectedFlow}
highlight={highlight}
- onSelect={flow => this.props.selectFlow(flow.id)}
+ onSelect={this.props.selectFlow}
/>
{selectedFlow && [
<Splitter key="splitter"/>,
diff --git a/web/src/js/utils.js b/web/src/js/utils.js
index d3b99bd0..cc17c565 100644
--- a/web/src/js/utils.js
+++ b/web/src/js/utils.js
@@ -1,7 +1,9 @@
-import _ from "lodash";
+import _ from 'lodash'
+import React from 'react'
+import shallowEqual from 'shallowequal'
window._ = _;
-window.React = require("react");
+window.React = React;
export var Key = {
UP: 38,
@@ -106,15 +108,27 @@ fetchApi.put = (url, json, options) => fetchApi(
}
)
-
export function getDiff(obj1, obj2) {
let result = {...obj2};
for(let key in obj1) {
if(_.isEqual(obj2[key], obj1[key]))
- result[key] = undefined;
+ result[key] = undefined
else if(!(Array.isArray(obj2[key]) && Array.isArray(obj1[key])) &&
typeof obj2[key] == 'object' && typeof obj1[key] == 'object')
- result[key] = getDiff(obj1[key], obj2[key]);
+ result[key] = getDiff(obj1[key], obj2[key])
+ }
+ return result
+}
+
+export const pure = renderFn => class extends React.Component {
+ static displayName = renderFn.name
+
+ shouldComponentUpdate(nextProps) {
+ console.log(!shallowEqual(this.props, nextProps))
+ return !shallowEqual(this.props, nextProps)
+ }
+
+ render() {
+ return renderFn(this.props)
}
- return result;
}