diff options
author | madt1m <pietrotirenna.pt@gmail.com> | 2018-08-02 05:55:35 +0200 |
---|---|---|
committer | madt1m <pietrotirenna.pt@gmail.com> | 2018-08-02 05:55:35 +0200 |
commit | 4e0c10b88bc580712d45181aaf641af918457ff3 (patch) | |
tree | be453bfced89f212ba2577e3614cc9434e69f607 /mitmproxy | |
parent | a839d2ee2a5c668be1d5b2198f89bf44c6c7c78b (diff) | |
download | mitmproxy-4e0c10b88bc580712d45181aaf641af918457ff3.tar.gz mitmproxy-4e0c10b88bc580712d45181aaf641af918457ff3.tar.bz2 mitmproxy-4e0c10b88bc580712d45181aaf641af918457ff3.zip |
tests: 97% coverage reached. Session patched appropriately after defects emerged.
Diffstat (limited to 'mitmproxy')
-rw-r--r-- | mitmproxy/addons/session.py | 218 |
1 file changed, 144 insertions, 74 deletions
diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py index c08097ee..f9d3af3f 100644 --- a/mitmproxy/addons/session.py +++ b/mitmproxy/addons/session.py @@ -1,9 +1,11 @@ +import collections import tempfile import asyncio import typing import bisect import shutil import sqlite3 +import copy import os from mitmproxy import flowfilter @@ -48,6 +50,7 @@ class SessionDB: or create a new one with optional path. :param db_path: """ + self.live_components = {} self.tempdir = None self.con = None # This is used for fast look-ups over bodies already dumped to database. @@ -71,11 +74,14 @@ class SessionDB: shutil.rmtree(self.tempdir) def __contains__(self, fid): - return fid in self._get_ids() + return any([fid == i for i in self._get_ids()]) + + def __len__(self): + ln = self.con.execute("SELECT COUNT(*) FROM flow;").fetchall()[0] + return ln[0] if ln else 0 def _get_ids(self): - with self.con as con: - return [t[0] for t in con.execute("SELECT id FROM flow;").fetchall()] + return [t[0] for t in self.con.execute("SELECT id FROM flow;").fetchall()] def _load_session(self, path): if not self.is_session_db(path): @@ -85,8 +91,8 @@ class SessionDB: def _create_session(self): script_path = pkg_data.path("io/sql/session_create.sql") qry = open(script_path, 'r').read() - with self.con: - self.con.executescript(qry) + self.con.executescript(qry) + self.con.commit() @staticmethod def is_session_db(path): @@ -110,62 +116,104 @@ class SessionDB: c.close() return False + def _disassemble(self, flow): + # Some live components of flows cannot be serialized, but they are needed to ensure correct functionality. + # We solve this by keeping a list of tuples which "save" those components for each flow id, eventually + # adding them back when needed. 
+ self.live_components[flow.id] = ( + flow.client_conn.wfile, + flow.client_conn.rfile, + flow.client_conn.reply, + flow.server_conn.wfile, + flow.server_conn.rfile, + flow.server_conn.reply, + (flow.server_conn.via.wfile, flow.server_conn.via.rfile, + flow.server_conn.via.reply) if flow.server_conn.via else None, + flow.reply + ) + + def _reassemble(self, flow): + if flow.id in self.live_components: + cwf, crf, crp, swf, srf, srp, via, rep = self.live_components[flow.id] + flow.client_conn.wfile = cwf + flow.client_conn.rfile = crf + flow.client_conn.reply = crp + flow.server_conn.wfile = swf + flow.server_conn.rfile = srf + flow.server_conn.reply = srp + flow.reply = rep + if via: + flow.server_conn.via.rfile, flow.server_conn.via.wfile, flow.server_conn.via.reply = via + return flow + def store_flows(self, flows): body_buf = [] flow_buf = [] for flow in flows: - if len(flow.request.content) > self.content_threshold: - body_buf.append((flow.id, self.type_mappings["body"][1], flow.request.content)) - flow.request.content = b"" - self.body_ledger.add(flow.id) - if flow.response and flow.id not in self.body_ledger: - if len(flow.response.content) > self.content_threshold: - body_buf.append((flow.id, self.type_mappings["body"][2], flow.response.content)) - flow.response.content = b"" - flow_buf.append((flow.id, protobuf.dumps(flow))) - with self.con as con: - con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?)", flow_buf) - con.executemany("INSERT INTO body VALUES(?, ?, ?)", body_buf) + self._disassemble(flow) + f = copy.copy(flow) + f.request = copy.deepcopy(flow.request) + if flow.response: + f.response = copy.deepcopy(flow.response) + f.id = flow.id + if len(f.request.content) > self.content_threshold and f.id not in self.body_ledger: + body_buf.append((f.id, 1, f.request.content)) + f.request.content = b"" + self.body_ledger.add(f.id) + if f.response and f.id not in self.body_ledger: + if len(f.response.content) > self.content_threshold: + 
body_buf.append((f.id, 2, f.response.content)) + f.response.content = b"" + flow_buf.append((f.id, protobuf.dumps(f))) + self.con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?);", flow_buf) + if body_buf: + self.con.executemany("INSERT INTO body (flow_id, type_id, content) VALUES(?, ?, ?);", body_buf) + self.con.commit() def retrieve_flows(self, ids=None): flows = [] with self.con as con: if not ids: sql = "SELECT f.content, b.type_id, b.content " \ - "FROM flow f, body b " \ - "WHERE f.id = b.flow_id;" + "FROM flow f " \ + "LEFT OUTER JOIN body b ON f.id = b.flow_id;" rows = con.execute(sql).fetchall() else: sql = "SELECT f.content, b.type_id, b.content " \ - "FROM flow f, body b " \ - "WHERE f.id = b.flow_id" \ - f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])})" + "FROM flow f " \ + "LEFT OUTER JOIN body b ON f.id = b.flow_id " \ + f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])});" rows = con.execute(sql, ids).fetchall() for row in rows: flow = protobuf.loads(row[0]) - typ = self.type_mappings["body"][row[1]] - if typ and row[2]: - setattr(getattr(flow, typ), "content", row[2]) + if row[1]: + typ = self.type_mappings["body"][row[1]] + if typ and row[2]: + setattr(getattr(flow, typ), "content", row[2]) + flow = self._reassemble(flow) flows.append(flow) return flows + def clear(self): + self.con.executescript("DELETE FROM body; DELETE FROM annotation; DELETE FROM flow;") + matchall = flowfilter.parse(".") orders = [ - ("t", "time"), - ("m", "method"), - ("u", "url"), - ("z", "size") + "time", + "method", + "url", + "size" ] class Session: def __init__(self): - self.dbstore = SessionDB(ctx.options.session_path) - self._hot_store = [] - self._view = [] + self.db_store = None + self._hot_store = collections.OrderedDict() self._live_components = {} + self._view = [] self.order = orders[0] self.filter = matchall self._flush_period = 3.0 @@ -191,6 +239,7 @@ class Session: def running(self): if not self.started: self.started = True + 
self.db_store = SessionDB(ctx.options.session_path) loop = asyncio.get_event_loop() tasks = (self._writer, self._tweaker) loop.create_task(asyncio.gather(*(t() for t in tasks))) @@ -201,6 +250,60 @@ class Session: if "view_filter" in updated: self.set_filter(ctx.options.view_filter) + async def _writer(self): + while True: + await asyncio.sleep(self._flush_period) + tof = [] + to_dump = min(self._flush_rate, len(self._hot_store)) + for _ in range(to_dump): + tof.append(self._hot_store.popitem(last=False)[1]) + self.db_store.store_flows(tof) + + async def _tweaker(self): + while True: + await asyncio.sleep(self._tweak_period) + if len(self._hot_store) >= 3 * self._flush_rate: + self._flush_period *= 0.9 + self._flush_rate *= 1.1 + elif len(self._hot_store) < self._flush_rate: + self._flush_period *= 1.1 + self._flush_rate *= 0.9 + + def load_view(self, ids=None): + flows = [] + ids_from_store = [] + if ids is None: + ids = [fid for _, fid in self._view] + for fid in ids: + # Flow could be at the same time in database and in hot storage. We want the most updated version. 
+ if fid in self._hot_store: + flows.append(self._hot_store[fid]) + elif fid in self.db_store: + ids_from_store.append(fid) + else: + flows.append(None) + flows += self.db_store.retrieve_flows(ids_from_store) + return flows + + def load_storage(self): + flows = [] + flows += self.db_store.retrieve_flows() + for flow in self._hot_store.values(): + flows.append(flow) + return flows + + def clear_storage(self): + self.db_store.clear() + self._hot_store.clear() + self._view = [] + + def store_count(self): + ln = 0 + for fid in self._hot_store.keys(): + if fid not in self.db_store: + ln += 1 + return ln + len(self.db_store) + def _generate_order(self, f: http.HTTPFlow) -> typing.Union[str, int, float]: o = self.order if o == "time": @@ -225,19 +328,19 @@ class Session: if order != self.order: self.order = order newview = [ - (self._generate_order(f), f.id) for f in self.dbstore.retrieve_flows([t[0] for t in self._view]) + (self._generate_order(f), f.id) for f in self.load_view() ] self._view = sorted(newview) def _refilter(self): self._view = [] - flows = self.dbstore.retrieve_flows() + flows = self.load_storage() for f in flows: if self.filter(f): self._base_add(f) - def set_filter(self, input_filter: str) -> None: - filt = flowfilter.parse(input_filter) + def set_filter(self, input_filter: typing.Optional[str]) -> None: + filt = matchall if not input_filter else flowfilter.parse(input_filter) if not filt: raise CommandError( "Invalid interception filter: %s" % filt @@ -245,54 +348,21 @@ class Session: self.filter = filt self._refilter() - async def _writer(self): - while True: - await asyncio.sleep(self._flush_period) - tof = [] - to_dump = min(self._flush_rate, len(self._hot_store)) - for _ in range(to_dump): - tof.append(self._hot_store.pop()) - self.store(tof) - - async def _tweaker(self): - while True: - await asyncio.sleep(self._tweak_period) - if len(self._hot_store) >= self._flush_rate: - self._flush_period *= 0.9 - self._flush_rate *= 0.9 - elif 
len(self._hot_store) < self._flush_rate: - self._flush_period *= 1.1 - self._flush_rate *= 1.1 - - def store(self, flows: typing.Sequence[http.HTTPFlow]) -> None: - # Some live components of flows cannot be serialized, but they are needed to ensure correct functionality. - # We solve this by keeping a list of tuples which "save" those components for each flow id, eventually - # adding them back when needed. - for f in flows: - self._live_components[f.id] = ( - f.client_conn.wfile or None, - f.client_conn.rfile or None, - f.server_conn.wfile or None, - f.server_conn.rfile or None, - f.reply or None - ) - self.dbstore.store_flows(flows) - def _base_add(self, f): - if f.id not in self._view: + if not any([f.id == t[1] for t in self._view]): o = self._generate_order(f) self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id)) else: o = self._generate_order(f) - self._view = [flow for flow in self._view if flow.id != f.id] + self._view = [(order, fid) for order, fid in self._view if fid != f.id] self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id)) def update(self, flows: typing.Sequence[http.HTTPFlow]) -> None: for f in flows: + if f.id in self._hot_store: + self._hot_store.pop(f.id) + self._hot_store[f.id] = f if self.filter(f): - if f.id in [f.id for f in self._hot_store]: - self._hot_store = [flow for flow in self._hot_store if flow.id != f.id] - self._hot_store.append(f) self._base_add(f) def request(self, f): |