aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mitmproxy/addons/session.py218
-rw-r--r--test/mitmproxy/addons/test_session.py153
2 files changed, 296 insertions, 75 deletions
diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py
index c08097ee..f9d3af3f 100644
--- a/mitmproxy/addons/session.py
+++ b/mitmproxy/addons/session.py
@@ -1,9 +1,11 @@
+import collections
import tempfile
import asyncio
import typing
import bisect
import shutil
import sqlite3
+import copy
import os
from mitmproxy import flowfilter
@@ -48,6 +50,7 @@ class SessionDB:
or create a new one with optional path.
:param db_path:
"""
+ self.live_components = {}
self.tempdir = None
self.con = None
# This is used for fast look-ups over bodies already dumped to database.
@@ -71,11 +74,14 @@ class SessionDB:
shutil.rmtree(self.tempdir)
def __contains__(self, fid):
- return fid in self._get_ids()
+ return any([fid == i for i in self._get_ids()])
+
+ def __len__(self):
+ ln = self.con.execute("SELECT COUNT(*) FROM flow;").fetchall()[0]
+ return ln[0] if ln else 0
def _get_ids(self):
- with self.con as con:
- return [t[0] for t in con.execute("SELECT id FROM flow;").fetchall()]
+ return [t[0] for t in self.con.execute("SELECT id FROM flow;").fetchall()]
def _load_session(self, path):
if not self.is_session_db(path):
@@ -85,8 +91,8 @@ class SessionDB:
def _create_session(self):
script_path = pkg_data.path("io/sql/session_create.sql")
qry = open(script_path, 'r').read()
- with self.con:
- self.con.executescript(qry)
+ self.con.executescript(qry)
+ self.con.commit()
@staticmethod
def is_session_db(path):
@@ -110,62 +116,104 @@ class SessionDB:
c.close()
return False
+ def _disassemble(self, flow):
+        # Some live components of flows cannot be serialized, but they are needed to ensure correct functionality.
+        # We solve this by keeping a dict that maps each flow id to a tuple of its live components, eventually
+        # adding them back when needed.
+ self.live_components[flow.id] = (
+ flow.client_conn.wfile,
+ flow.client_conn.rfile,
+ flow.client_conn.reply,
+ flow.server_conn.wfile,
+ flow.server_conn.rfile,
+ flow.server_conn.reply,
+ (flow.server_conn.via.wfile, flow.server_conn.via.rfile,
+ flow.server_conn.via.reply) if flow.server_conn.via else None,
+ flow.reply
+ )
+
+ def _reassemble(self, flow):
+ if flow.id in self.live_components:
+ cwf, crf, crp, swf, srf, srp, via, rep = self.live_components[flow.id]
+ flow.client_conn.wfile = cwf
+ flow.client_conn.rfile = crf
+ flow.client_conn.reply = crp
+ flow.server_conn.wfile = swf
+ flow.server_conn.rfile = srf
+ flow.server_conn.reply = srp
+ flow.reply = rep
+ if via:
+ flow.server_conn.via.rfile, flow.server_conn.via.wfile, flow.server_conn.via.reply = via
+ return flow
+
def store_flows(self, flows):
body_buf = []
flow_buf = []
for flow in flows:
- if len(flow.request.content) > self.content_threshold:
- body_buf.append((flow.id, self.type_mappings["body"][1], flow.request.content))
- flow.request.content = b""
- self.body_ledger.add(flow.id)
- if flow.response and flow.id not in self.body_ledger:
- if len(flow.response.content) > self.content_threshold:
- body_buf.append((flow.id, self.type_mappings["body"][2], flow.response.content))
- flow.response.content = b""
- flow_buf.append((flow.id, protobuf.dumps(flow)))
- with self.con as con:
- con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?)", flow_buf)
- con.executemany("INSERT INTO body VALUES(?, ?, ?)", body_buf)
+ self._disassemble(flow)
+ f = copy.copy(flow)
+ f.request = copy.deepcopy(flow.request)
+ if flow.response:
+ f.response = copy.deepcopy(flow.response)
+ f.id = flow.id
+ if len(f.request.content) > self.content_threshold and f.id not in self.body_ledger:
+ body_buf.append((f.id, 1, f.request.content))
+ f.request.content = b""
+ self.body_ledger.add(f.id)
+ if f.response and f.id not in self.body_ledger:
+ if len(f.response.content) > self.content_threshold:
+ body_buf.append((f.id, 2, f.response.content))
+ f.response.content = b""
+ flow_buf.append((f.id, protobuf.dumps(f)))
+ self.con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?);", flow_buf)
+ if body_buf:
+ self.con.executemany("INSERT INTO body (flow_id, type_id, content) VALUES(?, ?, ?);", body_buf)
+ self.con.commit()
def retrieve_flows(self, ids=None):
flows = []
with self.con as con:
if not ids:
sql = "SELECT f.content, b.type_id, b.content " \
- "FROM flow f, body b " \
- "WHERE f.id = b.flow_id;"
+ "FROM flow f " \
+ "LEFT OUTER JOIN body b ON f.id = b.flow_id;"
rows = con.execute(sql).fetchall()
else:
sql = "SELECT f.content, b.type_id, b.content " \
- "FROM flow f, body b " \
- "WHERE f.id = b.flow_id" \
- f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])})"
+ "FROM flow f " \
+ "LEFT OUTER JOIN body b ON f.id = b.flow_id " \
+ f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])});"
rows = con.execute(sql, ids).fetchall()
for row in rows:
flow = protobuf.loads(row[0])
- typ = self.type_mappings["body"][row[1]]
- if typ and row[2]:
- setattr(getattr(flow, typ), "content", row[2])
+ if row[1]:
+ typ = self.type_mappings["body"][row[1]]
+ if typ and row[2]:
+ setattr(getattr(flow, typ), "content", row[2])
+ flow = self._reassemble(flow)
flows.append(flow)
return flows
+ def clear(self):
+ self.con.executescript("DELETE FROM body; DELETE FROM annotation; DELETE FROM flow;")
+
matchall = flowfilter.parse(".")
orders = [
- ("t", "time"),
- ("m", "method"),
- ("u", "url"),
- ("z", "size")
+ "time",
+ "method",
+ "url",
+ "size"
]
class Session:
def __init__(self):
- self.dbstore = SessionDB(ctx.options.session_path)
- self._hot_store = []
- self._view = []
+ self.db_store = None
+ self._hot_store = collections.OrderedDict()
self._live_components = {}
+ self._view = []
self.order = orders[0]
self.filter = matchall
self._flush_period = 3.0
@@ -191,6 +239,7 @@ class Session:
def running(self):
if not self.started:
self.started = True
+ self.db_store = SessionDB(ctx.options.session_path)
loop = asyncio.get_event_loop()
tasks = (self._writer, self._tweaker)
loop.create_task(asyncio.gather(*(t() for t in tasks)))
@@ -201,6 +250,60 @@ class Session:
if "view_filter" in updated:
self.set_filter(ctx.options.view_filter)
+ async def _writer(self):
+ while True:
+ await asyncio.sleep(self._flush_period)
+ tof = []
+ to_dump = min(self._flush_rate, len(self._hot_store))
+ for _ in range(to_dump):
+ tof.append(self._hot_store.popitem(last=False)[1])
+ self.db_store.store_flows(tof)
+
+ async def _tweaker(self):
+ while True:
+ await asyncio.sleep(self._tweak_period)
+ if len(self._hot_store) >= 3 * self._flush_rate:
+ self._flush_period *= 0.9
+ self._flush_rate *= 1.1
+ elif len(self._hot_store) < self._flush_rate:
+ self._flush_period *= 1.1
+ self._flush_rate *= 0.9
+
+ def load_view(self, ids=None):
+ flows = []
+ ids_from_store = []
+ if ids is None:
+ ids = [fid for _, fid in self._view]
+ for fid in ids:
+            # A flow may exist both in the database and in hot storage; prefer the hot copy, which is the most updated.
+ if fid in self._hot_store:
+ flows.append(self._hot_store[fid])
+ elif fid in self.db_store:
+ ids_from_store.append(fid)
+ else:
+ flows.append(None)
+ flows += self.db_store.retrieve_flows(ids_from_store)
+ return flows
+
+ def load_storage(self):
+ flows = []
+ flows += self.db_store.retrieve_flows()
+ for flow in self._hot_store.values():
+ flows.append(flow)
+ return flows
+
+ def clear_storage(self):
+ self.db_store.clear()
+ self._hot_store.clear()
+ self._view = []
+
+ def store_count(self):
+ ln = 0
+ for fid in self._hot_store.keys():
+ if fid not in self.db_store:
+ ln += 1
+ return ln + len(self.db_store)
+
def _generate_order(self, f: http.HTTPFlow) -> typing.Union[str, int, float]:
o = self.order
if o == "time":
@@ -225,19 +328,19 @@ class Session:
if order != self.order:
self.order = order
newview = [
- (self._generate_order(f), f.id) for f in self.dbstore.retrieve_flows([t[0] for t in self._view])
+ (self._generate_order(f), f.id) for f in self.load_view()
]
self._view = sorted(newview)
def _refilter(self):
self._view = []
- flows = self.dbstore.retrieve_flows()
+ flows = self.load_storage()
for f in flows:
if self.filter(f):
self._base_add(f)
- def set_filter(self, input_filter: str) -> None:
- filt = flowfilter.parse(input_filter)
+ def set_filter(self, input_filter: typing.Optional[str]) -> None:
+ filt = matchall if not input_filter else flowfilter.parse(input_filter)
if not filt:
raise CommandError(
"Invalid interception filter: %s" % filt
@@ -245,54 +348,21 @@ class Session:
self.filter = filt
self._refilter()
- async def _writer(self):
- while True:
- await asyncio.sleep(self._flush_period)
- tof = []
- to_dump = min(self._flush_rate, len(self._hot_store))
- for _ in range(to_dump):
- tof.append(self._hot_store.pop())
- self.store(tof)
-
- async def _tweaker(self):
- while True:
- await asyncio.sleep(self._tweak_period)
- if len(self._hot_store) >= self._flush_rate:
- self._flush_period *= 0.9
- self._flush_rate *= 0.9
- elif len(self._hot_store) < self._flush_rate:
- self._flush_period *= 1.1
- self._flush_rate *= 1.1
-
- def store(self, flows: typing.Sequence[http.HTTPFlow]) -> None:
- # Some live components of flows cannot be serialized, but they are needed to ensure correct functionality.
- # We solve this by keeping a list of tuples which "save" those components for each flow id, eventually
- # adding them back when needed.
- for f in flows:
- self._live_components[f.id] = (
- f.client_conn.wfile or None,
- f.client_conn.rfile or None,
- f.server_conn.wfile or None,
- f.server_conn.rfile or None,
- f.reply or None
- )
- self.dbstore.store_flows(flows)
-
def _base_add(self, f):
- if f.id not in self._view:
+ if not any([f.id == t[1] for t in self._view]):
o = self._generate_order(f)
self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id))
else:
o = self._generate_order(f)
- self._view = [flow for flow in self._view if flow.id != f.id]
+ self._view = [(order, fid) for order, fid in self._view if fid != f.id]
self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id))
def update(self, flows: typing.Sequence[http.HTTPFlow]) -> None:
for f in flows:
+ if f.id in self._hot_store:
+ self._hot_store.pop(f.id)
+ self._hot_store[f.id] = f
if self.filter(f):
- if f.id in [f.id for f in self._hot_store]:
- self._hot_store = [flow for flow in self._hot_store if flow.id != f.id]
- self._hot_store.append(f)
self._base_add(f)
def request(self, f):
diff --git a/test/mitmproxy/addons/test_session.py b/test/mitmproxy/addons/test_session.py
index d4b1109b..41e8a401 100644
--- a/test/mitmproxy/addons/test_session.py
+++ b/test/mitmproxy/addons/test_session.py
@@ -1,13 +1,37 @@
import sqlite3
+import asyncio
import pytest
import os
+from mitmproxy import ctx
+from mitmproxy import http
+from mitmproxy.test import tflow, tutils
+from mitmproxy.test import taddons
from mitmproxy.addons import session
-from mitmproxy.exceptions import SessionLoadException
+from mitmproxy.exceptions import SessionLoadException, CommandError
from mitmproxy.utils.data import pkg_data
class TestSession:
+
+ @staticmethod
+ def tft(*, method="GET", start=0):
+ f = tflow.tflow()
+ f.request.method = method
+ f.request.timestamp_start = start
+ return f
+
+ @staticmethod
+ def start_session(fp=None):
+ s = session.Session()
+ tctx = taddons.context()
+ tctx.master.addons.add(s)
+ tctx.options.session_path = None
+ if fp:
+ s._flush_period = fp
+ s.running()
+ return s
+
def test_session_temporary(self):
s = session.SessionDB()
td = s.tempdir
@@ -56,3 +80,130 @@ class TestSession:
assert len(rows) == 1
con.close()
os.remove(path)
+
+ def test_session_order_generators(self):
+ s = session.Session()
+ tf = tflow.tflow(resp=True)
+
+ s.order = "time"
+ assert s._generate_order(tf) == 946681200
+
+ s.order = "method"
+ assert s._generate_order(tf) == tf.request.method
+
+ s.order = "url"
+ assert s._generate_order(tf) == tf.request.url
+
+ s.order = "size"
+ assert s._generate_order(tf) == len(tf.request.raw_content) + len(tf.response.raw_content)
+
+ def test_simple(self):
+ s = session.Session()
+ ctx.options = taddons.context()
+ ctx.options.session_path = None
+ s.running()
+ f = self.tft(start=1)
+ assert s.store_count() == 0
+ s.request(f)
+ assert s._view == [(1, f.id)]
+ assert s.load_view([f.id]) == [f]
+ assert s.load_view(['nonexistent']) == [None]
+
+ s.error(f)
+ s.response(f)
+ s.intercept(f)
+ s.resume(f)
+ s.kill(f)
+
+ # Verify that flow has been updated, not duplicated
+ assert s._view == [(1, f.id)]
+ assert s.store_count() == 1
+
+ f2 = self.tft(start=3)
+ s.request(f2)
+ assert s._view == [(1, f.id), (3, f2.id)]
+ s.request(f2)
+ assert s._view == [(1, f.id), (3, f2.id)]
+
+ f3 = self.tft(start=2)
+ s.request(f3)
+ assert s._view == [(1, f.id), (2, f3.id), (3, f2.id)]
+ s.request(f3)
+ assert s._view == [(1, f.id), (2, f3.id), (3, f2.id)]
+ assert s.store_count() == 3
+
+ s.clear_storage()
+ assert len(s._view) == 0
+ assert s.store_count() == 0
+
+ def test_filter(self):
+ s = self.start_session()
+ s.request(self.tft(method="get"))
+ s.request(self.tft(method="put"))
+ s.request(self.tft(method="get"))
+ s.request(self.tft(method="put"))
+ assert len(s._view) == 4
+ s.set_filter("~m get")
+ assert [f.request.method for f in s.load_view()] == ["GET", "GET"]
+ assert s.store_count() == 4
+ with pytest.raises(CommandError):
+ s.set_filter("~notafilter")
+ s.set_filter(None)
+ assert len(s._view) == 4
+
+ @pytest.mark.asyncio
+ async def test_flush_withspecials(self):
+ s = self.start_session(fp=0.5)
+ f = self.tft()
+ s.request(f)
+ await asyncio.sleep(2)
+ assert len(s._hot_store) == 0
+ assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_view(), [f]))])
+
+ f.server_conn.via = tflow.tserver_conn()
+ s.request(f)
+ await asyncio.sleep(1)
+ assert len(s._hot_store) == 0
+ assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_view(), [f]))])
+
+ flows = [self.tft() for _ in range(500)]
+ s.update(flows)
+ fp = s._flush_period
+ fr = s._flush_rate
+ await asyncio.sleep(0.6)
+ assert s._flush_period < fp and s._flush_rate > fr
+
+ @pytest.mark.asyncio
+ async def test_bodies(self):
+ # Need to test for configure
+ # Need to test for set_order
+ s = self.start_session(fp=0.5)
+ f = self.tft()
+ f2 = self.tft(start=1)
+ f.request.content = b"A"*1001
+ s.request(f)
+ s.request(f2)
+ await asyncio.sleep(1.0)
+ content = s.db_store.con.execute(
+ "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f.id]
+ ).fetchall()[0]
+ assert content == (1, b"A"*1001)
+ assert s.db_store.body_ledger == {f.id}
+ f.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A"*1001))
+ f2.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A"*1001))
+        # tresp leaves a stale content-length header after we replace the body above -- override it to match the 1001-byte content
+ f.response.headers['content-length'] = b"1001"
+ f2.response.headers['content-length'] = b"1001"
+ s.response(f)
+ s.response(f2)
+ await asyncio.sleep(1.0)
+ rows = s.db_store.con.execute(
+ "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f.id]
+ ).fetchall()
+ assert len(rows) == 1
+ rows = s.db_store.con.execute(
+ "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f2.id]
+ ).fetchall()
+ assert len(rows) == 1
+ assert s.db_store.body_ledger == {f.id}
+ assert all([lf.__dict__ == rf.__dict__ for lf, rf in list(zip(s.load_view(), [f, f2]))])