about summary refs log tree commit diff stats
path: root/mitmproxy
diff options
context:
space:
mode:
author    madt1m <pietrotirenna.pt@gmail.com> 2018-08-02 05:55:35 +0200
committer madt1m <pietrotirenna.pt@gmail.com> 2018-08-02 05:55:35 +0200
commit 4e0c10b88bc580712d45181aaf641af918457ff3 (patch)
tree   be453bfced89f212ba2577e3614cc9434e69f607 /mitmproxy
parent a839d2ee2a5c668be1d5b2198f89bf44c6c7c78b (diff)
download mitmproxy-4e0c10b88bc580712d45181aaf641af918457ff3.tar.gz
mitmproxy-4e0c10b88bc580712d45181aaf641af918457ff3.tar.bz2
mitmproxy-4e0c10b88bc580712d45181aaf641af918457ff3.zip
tests: 97% coverage reached. Session opportunely patched after emerged defects.
Diffstat (limited to 'mitmproxy')
-rw-r--r-- mitmproxy/addons/session.py | 218
1 file changed, 144 insertions, 74 deletions
diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py
index c08097ee..f9d3af3f 100644
--- a/mitmproxy/addons/session.py
+++ b/mitmproxy/addons/session.py
@@ -1,9 +1,11 @@
+import collections
import tempfile
import asyncio
import typing
import bisect
import shutil
import sqlite3
+import copy
import os
from mitmproxy import flowfilter
@@ -48,6 +50,7 @@ class SessionDB:
or create a new one with optional path.
:param db_path:
"""
+ self.live_components = {}
self.tempdir = None
self.con = None
# This is used for fast look-ups over bodies already dumped to database.
@@ -71,11 +74,14 @@ class SessionDB:
shutil.rmtree(self.tempdir)
def __contains__(self, fid):
- return fid in self._get_ids()
+ return any([fid == i for i in self._get_ids()])
+
+ def __len__(self):
+ ln = self.con.execute("SELECT COUNT(*) FROM flow;").fetchall()[0]
+ return ln[0] if ln else 0
def _get_ids(self):
- with self.con as con:
- return [t[0] for t in con.execute("SELECT id FROM flow;").fetchall()]
+ return [t[0] for t in self.con.execute("SELECT id FROM flow;").fetchall()]
def _load_session(self, path):
if not self.is_session_db(path):
@@ -85,8 +91,8 @@ class SessionDB:
def _create_session(self):
script_path = pkg_data.path("io/sql/session_create.sql")
qry = open(script_path, 'r').read()
- with self.con:
- self.con.executescript(qry)
+ self.con.executescript(qry)
+ self.con.commit()
@staticmethod
def is_session_db(path):
@@ -110,62 +116,104 @@ class SessionDB:
c.close()
return False
+ def _disassemble(self, flow):
+ # Some live components of flows cannot be serialized, but they are needed to ensure correct functionality.
+ # We solve this by keeping a list of tuples which "save" those components for each flow id, eventually
+ # adding them back when needed.
+ self.live_components[flow.id] = (
+ flow.client_conn.wfile,
+ flow.client_conn.rfile,
+ flow.client_conn.reply,
+ flow.server_conn.wfile,
+ flow.server_conn.rfile,
+ flow.server_conn.reply,
+ (flow.server_conn.via.wfile, flow.server_conn.via.rfile,
+ flow.server_conn.via.reply) if flow.server_conn.via else None,
+ flow.reply
+ )
+
+ def _reassemble(self, flow):
+ if flow.id in self.live_components:
+ cwf, crf, crp, swf, srf, srp, via, rep = self.live_components[flow.id]
+ flow.client_conn.wfile = cwf
+ flow.client_conn.rfile = crf
+ flow.client_conn.reply = crp
+ flow.server_conn.wfile = swf
+ flow.server_conn.rfile = srf
+ flow.server_conn.reply = srp
+ flow.reply = rep
+ if via:
+ flow.server_conn.via.rfile, flow.server_conn.via.wfile, flow.server_conn.via.reply = via
+ return flow
+
def store_flows(self, flows):
body_buf = []
flow_buf = []
for flow in flows:
- if len(flow.request.content) > self.content_threshold:
- body_buf.append((flow.id, self.type_mappings["body"][1], flow.request.content))
- flow.request.content = b""
- self.body_ledger.add(flow.id)
- if flow.response and flow.id not in self.body_ledger:
- if len(flow.response.content) > self.content_threshold:
- body_buf.append((flow.id, self.type_mappings["body"][2], flow.response.content))
- flow.response.content = b""
- flow_buf.append((flow.id, protobuf.dumps(flow)))
- with self.con as con:
- con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?)", flow_buf)
- con.executemany("INSERT INTO body VALUES(?, ?, ?)", body_buf)
+ self._disassemble(flow)
+ f = copy.copy(flow)
+ f.request = copy.deepcopy(flow.request)
+ if flow.response:
+ f.response = copy.deepcopy(flow.response)
+ f.id = flow.id
+ if len(f.request.content) > self.content_threshold and f.id not in self.body_ledger:
+ body_buf.append((f.id, 1, f.request.content))
+ f.request.content = b""
+ self.body_ledger.add(f.id)
+ if f.response and f.id not in self.body_ledger:
+ if len(f.response.content) > self.content_threshold:
+ body_buf.append((f.id, 2, f.response.content))
+ f.response.content = b""
+ flow_buf.append((f.id, protobuf.dumps(f)))
+ self.con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?);", flow_buf)
+ if body_buf:
+ self.con.executemany("INSERT INTO body (flow_id, type_id, content) VALUES(?, ?, ?);", body_buf)
+ self.con.commit()
def retrieve_flows(self, ids=None):
flows = []
with self.con as con:
if not ids:
sql = "SELECT f.content, b.type_id, b.content " \
- "FROM flow f, body b " \
- "WHERE f.id = b.flow_id;"
+ "FROM flow f " \
+ "LEFT OUTER JOIN body b ON f.id = b.flow_id;"
rows = con.execute(sql).fetchall()
else:
sql = "SELECT f.content, b.type_id, b.content " \
- "FROM flow f, body b " \
- "WHERE f.id = b.flow_id" \
- f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])})"
+ "FROM flow f " \
+ "LEFT OUTER JOIN body b ON f.id = b.flow_id " \
+ f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])});"
rows = con.execute(sql, ids).fetchall()
for row in rows:
flow = protobuf.loads(row[0])
- typ = self.type_mappings["body"][row[1]]
- if typ and row[2]:
- setattr(getattr(flow, typ), "content", row[2])
+ if row[1]:
+ typ = self.type_mappings["body"][row[1]]
+ if typ and row[2]:
+ setattr(getattr(flow, typ), "content", row[2])
+ flow = self._reassemble(flow)
flows.append(flow)
return flows
+ def clear(self):
+ self.con.executescript("DELETE FROM body; DELETE FROM annotation; DELETE FROM flow;")
+
matchall = flowfilter.parse(".")
orders = [
- ("t", "time"),
- ("m", "method"),
- ("u", "url"),
- ("z", "size")
+ "time",
+ "method",
+ "url",
+ "size"
]
class Session:
def __init__(self):
- self.dbstore = SessionDB(ctx.options.session_path)
- self._hot_store = []
- self._view = []
+ self.db_store = None
+ self._hot_store = collections.OrderedDict()
self._live_components = {}
+ self._view = []
self.order = orders[0]
self.filter = matchall
self._flush_period = 3.0
@@ -191,6 +239,7 @@ class Session:
def running(self):
if not self.started:
self.started = True
+ self.db_store = SessionDB(ctx.options.session_path)
loop = asyncio.get_event_loop()
tasks = (self._writer, self._tweaker)
loop.create_task(asyncio.gather(*(t() for t in tasks)))
@@ -201,6 +250,60 @@ class Session:
if "view_filter" in updated:
self.set_filter(ctx.options.view_filter)
+ async def _writer(self):
+ while True:
+ await asyncio.sleep(self._flush_period)
+ tof = []
+ to_dump = min(self._flush_rate, len(self._hot_store))
+ for _ in range(to_dump):
+ tof.append(self._hot_store.popitem(last=False)[1])
+ self.db_store.store_flows(tof)
+
+ async def _tweaker(self):
+ while True:
+ await asyncio.sleep(self._tweak_period)
+ if len(self._hot_store) >= 3 * self._flush_rate:
+ self._flush_period *= 0.9
+ self._flush_rate *= 1.1
+ elif len(self._hot_store) < self._flush_rate:
+ self._flush_period *= 1.1
+ self._flush_rate *= 0.9
+
+ def load_view(self, ids=None):
+ flows = []
+ ids_from_store = []
+ if ids is None:
+ ids = [fid for _, fid in self._view]
+ for fid in ids:
+ # Flow could be at the same time in database and in hot storage. We want the most updated version.
+ if fid in self._hot_store:
+ flows.append(self._hot_store[fid])
+ elif fid in self.db_store:
+ ids_from_store.append(fid)
+ else:
+ flows.append(None)
+ flows += self.db_store.retrieve_flows(ids_from_store)
+ return flows
+
+ def load_storage(self):
+ flows = []
+ flows += self.db_store.retrieve_flows()
+ for flow in self._hot_store.values():
+ flows.append(flow)
+ return flows
+
+ def clear_storage(self):
+ self.db_store.clear()
+ self._hot_store.clear()
+ self._view = []
+
+ def store_count(self):
+ ln = 0
+ for fid in self._hot_store.keys():
+ if fid not in self.db_store:
+ ln += 1
+ return ln + len(self.db_store)
+
def _generate_order(self, f: http.HTTPFlow) -> typing.Union[str, int, float]:
o = self.order
if o == "time":
@@ -225,19 +328,19 @@ class Session:
if order != self.order:
self.order = order
newview = [
- (self._generate_order(f), f.id) for f in self.dbstore.retrieve_flows([t[0] for t in self._view])
+ (self._generate_order(f), f.id) for f in self.load_view()
]
self._view = sorted(newview)
def _refilter(self):
self._view = []
- flows = self.dbstore.retrieve_flows()
+ flows = self.load_storage()
for f in flows:
if self.filter(f):
self._base_add(f)
- def set_filter(self, input_filter: str) -> None:
- filt = flowfilter.parse(input_filter)
+ def set_filter(self, input_filter: typing.Optional[str]) -> None:
+ filt = matchall if not input_filter else flowfilter.parse(input_filter)
if not filt:
raise CommandError(
"Invalid interception filter: %s" % filt
@@ -245,54 +348,21 @@ class Session:
self.filter = filt
self._refilter()
- async def _writer(self):
- while True:
- await asyncio.sleep(self._flush_period)
- tof = []
- to_dump = min(self._flush_rate, len(self._hot_store))
- for _ in range(to_dump):
- tof.append(self._hot_store.pop())
- self.store(tof)
-
- async def _tweaker(self):
- while True:
- await asyncio.sleep(self._tweak_period)
- if len(self._hot_store) >= self._flush_rate:
- self._flush_period *= 0.9
- self._flush_rate *= 0.9
- elif len(self._hot_store) < self._flush_rate:
- self._flush_period *= 1.1
- self._flush_rate *= 1.1
-
- def store(self, flows: typing.Sequence[http.HTTPFlow]) -> None:
- # Some live components of flows cannot be serialized, but they are needed to ensure correct functionality.
- # We solve this by keeping a list of tuples which "save" those components for each flow id, eventually
- # adding them back when needed.
- for f in flows:
- self._live_components[f.id] = (
- f.client_conn.wfile or None,
- f.client_conn.rfile or None,
- f.server_conn.wfile or None,
- f.server_conn.rfile or None,
- f.reply or None
- )
- self.dbstore.store_flows(flows)
-
def _base_add(self, f):
- if f.id not in self._view:
+ if not any([f.id == t[1] for t in self._view]):
o = self._generate_order(f)
self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id))
else:
o = self._generate_order(f)
- self._view = [flow for flow in self._view if flow.id != f.id]
+ self._view = [(order, fid) for order, fid in self._view if fid != f.id]
self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id))
def update(self, flows: typing.Sequence[http.HTTPFlow]) -> None:
for f in flows:
+ if f.id in self._hot_store:
+ self._hot_store.pop(f.id)
+ self._hot_store[f.id] = f
if self.filter(f):
- if f.id in [f.id for f in self._hot_store]:
- self._hot_store = [flow for flow in self._hot_store if flow.id != f.id]
- self._hot_store.append(f)
self._base_add(f)
def request(self, f):