From 53b85d23605409a46754f217d061f7d8a1d3cb6d Mon Sep 17 00:00:00 2001 From: Pietro Francesco Tirenna Date: Sat, 28 Jul 2018 17:44:02 +0200 Subject: session: adding methods to capture and store flows --- mitmproxy/addons/session.py | 107 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py index c49b95c4..6176bd5b 100644 --- a/mitmproxy/addons/session.py +++ b/mitmproxy/addons/session.py @@ -1,12 +1,30 @@ import tempfile +import asyncio +import typing +import bisect import shutil import sqlite3 import os +from mitmproxy import types +from mitmproxy import http +from mitmproxy import ctx from mitmproxy.exceptions import SessionLoadException from mitmproxy.utils.data import pkg_data +class KeyifyList(object): + def __init__(self, inner, key): + self.inner = inner + self.key = key + + def __len__(self): + return len(self.inner) + + def __getitem__(self, k): + return self.key(self.inner[k]) + + # Could be implemented using async libraries class SessionDB: """ @@ -71,3 +89,92 @@ class SessionDB: if c: c.close() return False + + +orders = [ + ("t", "time"), + ("m", "method"), + ("u", "url"), + ("z", "size") +] + + +class Session: + def __init__(self): + self.sdb = SessionDB(ctx.options.session_path) + self._hot_store = [] + self._view = [] + self.order = orders[0] + self._flush_period = 3.0 + self._flush_rate = 150 + + def load(self, loader): + loader.add_option( + "session_path", typing.Optional[types.Path], None, + "Path of session to load or to create." + ) + loader.add_option( + "view_order", str, "time", + "Flow sort order.", + choices=list(map(lambda c: c[1], orders)) + ) + + def _generate_order(self, f: http.HTTPFlow) -> typing.Union[str, int, float]: + o = self.order + if o == "time": + return f.request.timestamp_start or 0 + if o == "method": + return f.request.method + if o == "url": + return f.request.url + if o == "size": + s = 0 + if f.request.raw_content: + s += len(f.request.raw_content) + if f.response and f.response.raw_content: + s += len(f.response.raw_content) + return s + + async def _writer(self): + while True: + await asyncio.sleep(self._flush_period) + tof = [] + to_dump = min(self._flush_rate, len(self._hot_store)) + for _ in range(to_dump): + tof.append(self._hot_store.pop()) + self.store(tof) + + def store(self, flows: typing.Sequence[http.HTTPFlow]): + pass + + def running(self): + pass + + def add(self, flows: typing.Sequence[http.HTTPFlow]) -> None: + for f in flows: + if f.id not in [f.id for f in self._hot_store] and f.id not in self.sdb: + # Flow has to be filtered here before adding to view. Later + o = self._generate_order(f) + self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id)) + self._hot_store.append(f) + + def update(self, flow): + pass + + def request(self, f): + self.add([f]) + + def error(self, f): + self.update([f]) + + def response(self, f): + self.update([f]) + + def intercept(self, f): + self.update([f]) + + def resume(self, f): + self.update([f]) + + def kill(self, f): + self.update([f]) -- cgit v1.2.3 From ccb5fd7c9981b65e8bb543076e89e381481340f7 Mon Sep 17 00:00:00 2001 From: madt1m Date: Wed, 1 Aug 2018 01:55:20 +0200 Subject: session: basic flow capture implemented --- mitmproxy/addons/session.py | 110 +++++++++++++++++++++++++++++++++++++++----- 1 file changed, 99 insertions(+), 11 deletions(-) diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py index 6176bd5b..010d3616 100644 --- a/mitmproxy/addons/session.py +++ b/mitmproxy/addons/session.py @@ -9,7 +9,8 @@ import os from mitmproxy import types from mitmproxy import http from mitmproxy import ctx -from mitmproxy.exceptions import SessionLoadException +from mitmproxy.io import protobuf +from mitmproxy.exceptions import SessionLoadException, CommandError from mitmproxy.utils.data import pkg_data @@ -32,6 +33,13 @@ class SessionDB: for Sessions and handles creation, retrieving and insertion in tables. """ + content_threshold = 1000 + type_mappings = { + "body": { + "request" : 1, + "response" : 2 + } + } def __init__(self, db_path=None): """ @@ -58,6 +66,13 @@ class SessionDB: if self.tempdir: shutil.rmtree(self.tempdir) + def __contains__(self, fid): + return fid in self._get_ids() + + def _get_ids(self): + with self.con as con: + return [t[0] for t in con.execute("SELECT id FROM flow;").fetchall()] + def _load_session(self, path): if not self.is_session_db(path): raise SessionLoadException('Given path does not point to a valid Session') @@ -90,6 +105,23 @@ class SessionDB: c.close() return False + def store_flows(self, flows): + body_buf = [] + flow_buf = [] + for flow in flows: + if len(flow.request.content) > self.content_threshold: + body_buf.append((flow.id, self.type_mappings["body"]["request"], flow.request.content)) + flow.request.content = b"" + if flow.response: + if len(flow.response.content) > self.content_threshold: + body_buf.append((flow.id, self.type_mappings["body"]["response"], flow.response.content)) + flow.response.content = b"" + flow_buf.append((flow.id, protobuf.dumps(flow))) + with self.con as con: + con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?)", flow_buf) + con.executemany("INSERT INTO body VALUES(?, ?, ?)", body_buf) + + orders = [ ("t", "time"), @@ -101,12 +133,15 @@ orders = [ class Session: def __init__(self): - self.sdb = SessionDB(ctx.options.session_path) + self.dbstore = SessionDB(ctx.options.session_path) self._hot_store = [] self._view = [] + self._live_components = {} self.order = orders[0] self._flush_period = 3.0 + self._tweak_period = 0.5 self._flush_rate = 150 + self.started = False def load(self, loader): loader.add_option( @@ -118,6 +153,23 @@ class Session: "Flow sort order.", choices=list(map(lambda c: c[1], orders)) ) + loader.add_option( + "view_filter", typing.Optional[str], None, + "Limit the view to matching flows." + ) + + def running(self): + if not self.started: + self.started = True + loop = asyncio.get_event_loop() + tasks = (self._writer, self._tweaker) + loop.create_task(asyncio.gather(*(t() for t in tasks))) + + def configure(self, updated): + if "view_order" in updated: + self.set_order(ctx.options.view_order) + if "view_filter" in updated: + self.set_filter(ctx.options.view_filter) def _generate_order(self, f: http.HTTPFlow) -> typing.Union[str, int, float]: o = self.order @@ -135,6 +187,12 @@ class Session: s += len(f.response.raw_content) return s + def set_order(self, order: str) -> None: + pass + + def set_filter(self, filt: str) -> None: + pass + async def _writer(self): while True: await asyncio.sleep(self._flush_period) @@ -144,22 +202,52 @@ class Session: tof.append(self._hot_store.pop()) self.store(tof) - def store(self, flows: typing.Sequence[http.HTTPFlow]): - pass + async def _tweaker(self): + while True: + await asyncio.sleep(self._tweak_period) + if len(self._hot_store) >= self._flush_rate: + self._flush_period *= 0.9 + self._flush_rate *= 0.9 + elif len(self._hot_store) < self._flush_rate: + self._flush_period *= 1.1 + self._flush_rate *= 1.1 - def running(self): - pass + def store(self, flows: typing.Sequence[http.HTTPFlow]) -> None: + # Some live components of flows cannot be serialized, but they are needed to ensure correct functionality. + # We solve this by keeping a list of tuples which "save" those components for each flow id, eventually + # adding them back when needed. + for f in flows: + self._live_components[f.id] = ( + f.client_conn.wfile or None, + f.client_conn.rfile or None, + f.server_conn.wfile or None, + f.server_conn.rfile or None, + f.reply or None + ) + self.dbstore.store_flows(flows) + + def _base_add(self, f): + if f.id not in self._view: + o = self._generate_order(f) + self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id)) + else: + o = self._generate_order(f) + self._view = [flow for flow in self._view if flow.id != f.id] + self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id)) def add(self, flows: typing.Sequence[http.HTTPFlow]) -> None: for f in flows: - if f.id not in [f.id for f in self._hot_store] and f.id not in self.sdb: + if f.id not in [f.id for f in self._hot_store] and f.id not in self.dbstore: # Flow has to be filtered here before adding to view. Later - o = self._generate_order(f) - self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id)) self._hot_store.append(f) + self._base_add(f) - def update(self, flow): - pass + def update(self, flows: typing.Sequence[http.HTTPFlow]) -> None: + for f in flows: + if f.id in [f.id for f in self._hot_store]: + self._hot_store = [flow for flow in self._hot_store if flow.id != f.id] + self._hot_store.append(f) + self._base_add(f) def request(self, f): self.add([f]) -- cgit v1.2.3 From afe41eb75cb7ec5d5edf54fe490bdde49294683b Mon Sep 17 00:00:00 2001 From: madt1m Date: Wed, 1 Aug 2018 12:00:07 +0200 Subject: protobuf: changed return type annotation in loads to enhance granularity --- mitmproxy/io/protobuf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mitmproxy/io/protobuf.py b/mitmproxy/io/protobuf.py index 9a00eacf..c8ca3acc 100644 --- a/mitmproxy/io/protobuf.py +++ b/mitmproxy/io/protobuf.py @@ -189,7 +189,7 @@ def load_http(hf: http_pb2.HTTPFlow) -> HTTPFlow: return f -def loads(b: bytes, typ="http") -> flow.Flow: +def loads(b: bytes, typ="http") -> typing.Union[HTTPFlow]: if typ != 'http': raise exceptions.TypeError("Flow types different than HTTP not supported yet!") else: -- cgit v1.2.3 From a839d2ee2a5c668be1d5b2198f89bf44c6c7c78b Mon Sep 17 00:00:00 2001 From: madt1m Date: Wed, 1 Aug 2018 12:00:28 +0200 Subject: session: implemented filter and refilter. Ready for testing implementation --- mitmproxy/addons/session.py | 88 ++++++++++++++++++++++++++++--------- mitmproxy/io/sql/session_create.sql | 2 + 2 files changed, 69 insertions(+), 21 deletions(-) diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py index 010d3616..c08097ee 100644 --- a/mitmproxy/addons/session.py +++ b/mitmproxy/addons/session.py @@ -6,6 +6,7 @@ import shutil import sqlite3 import os +from mitmproxy import flowfilter from mitmproxy import types from mitmproxy import http from mitmproxy import ctx @@ -36,8 +37,8 @@ class SessionDB: content_threshold = 1000 type_mappings = { "body": { - "request" : 1, - "response" : 2 + 1: "request", + 2: "response" } } @@ -49,6 +50,9 @@ class SessionDB: """ self.tempdir = None self.con = None + # This is used for fast look-ups over bodies already dumped to database. + # This permits to enforce one-to-one relationship between flow and body table. + self.body_ledger = set() if db_path is not None and os.path.isfile(db_path): self._load_session(db_path) else: @@ -91,6 +95,7 @@ class SessionDB: is a valid Session SQLite DB. :return: True if valid, False if invalid. """ + c = None try: c = sqlite3.connect(f'file:{path}?mode=rw', uri=True) cursor = c.cursor() @@ -100,7 +105,7 @@ class SessionDB: if all(elem in rows for elem in tables): c.close() return True - except: + except sqlite3.Error: if c: c.close() return False @@ -110,18 +115,42 @@ class SessionDB: flow_buf = [] for flow in flows: if len(flow.request.content) > self.content_threshold: - body_buf.append((flow.id, self.type_mappings["body"]["request"], flow.request.content)) + body_buf.append((flow.id, self.type_mappings["body"][1], flow.request.content)) flow.request.content = b"" - if flow.response: + self.body_ledger.add(flow.id) + if flow.response and flow.id not in self.body_ledger: if len(flow.response.content) > self.content_threshold: - body_buf.append((flow.id, self.type_mappings["body"]["response"], flow.response.content)) + body_buf.append((flow.id, self.type_mappings["body"][2], flow.response.content)) flow.response.content = b"" flow_buf.append((flow.id, protobuf.dumps(flow))) with self.con as con: con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?)", flow_buf) con.executemany("INSERT INTO body VALUES(?, ?, ?)", body_buf) - + def retrieve_flows(self, ids=None): + flows = [] + with self.con as con: + if not ids: + sql = "SELECT f.content, b.type_id, b.content " \ + "FROM flow f, body b " \ + "WHERE f.id = b.flow_id;" + rows = con.execute(sql).fetchall() + else: + sql = "SELECT f.content, b.type_id, b.content " \ + "FROM flow f, body b " \ + "WHERE f.id = b.flow_id" \ + f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])})" + rows = con.execute(sql, ids).fetchall() + for row in rows: + flow = protobuf.loads(row[0]) + typ = self.type_mappings["body"][row[1]] + if typ and row[2]: + setattr(getattr(flow, typ), "content", row[2]) + flows.append(flow) + return flows + + +matchall = flowfilter.parse(".") orders = [ ("t", "time"), @@ -138,6 +167,7 @@ class Session: self._view = [] self._live_components = {} self.order = orders[0] + self.filter = matchall self._flush_period = 3.0 self._tweak_period = 0.5 self._flush_rate = 150 @@ -188,10 +218,32 @@ class Session: return s def set_order(self, order: str) -> None: - pass + if order not in orders: + raise CommandError( + "Unknown flow order: %s" % order + ) + if order != self.order: + self.order = order + newview = [ + (self._generate_order(f), f.id) for f in self.dbstore.retrieve_flows([t[0] for t in self._view]) + ] + self._view = sorted(newview) + + def _refilter(self): + self._view = [] + flows = self.dbstore.retrieve_flows() + for f in flows: + if self.filter(f): + self._base_add(f) - def set_filter(self, filt: str) -> None: - pass + def set_filter(self, input_filter: str) -> None: + filt = flowfilter.parse(input_filter) + if not filt: + raise CommandError( + "Invalid interception filter: %s" % filt + ) + self.filter = filt + self._refilter() async def _writer(self): while True: @@ -235,22 +287,16 @@ class Session: self._view = [flow for flow in self._view if flow.id != f.id] self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id)) - def add(self, flows: typing.Sequence[http.HTTPFlow]) -> None: + def update(self, flows: typing.Sequence[http.HTTPFlow]) -> None: for f in flows: - if f.id not in [f.id for f in self._hot_store] and f.id not in self.dbstore: - # Flow has to be filtered here before adding to view. Later + if self.filter(f): + if f.id in [f.id for f in self._hot_store]: + self._hot_store = [flow for flow in self._hot_store if flow.id != f.id] self._hot_store.append(f) self._base_add(f) - def update(self, flows: typing.Sequence[http.HTTPFlow]) -> None: - for f in flows: - if f.id in [f.id for f in self._hot_store]: - self._hot_store = [flow for flow in self._hot_store if flow.id != f.id] - self._hot_store.append(f) - self._base_add(f) - def request(self, f): - self.add([f]) + self.update([f]) def error(self, f): self.update([f]) diff --git a/mitmproxy/io/sql/session_create.sql b/mitmproxy/io/sql/session_create.sql index bfc98b94..b9c28c03 100644 --- a/mitmproxy/io/sql/session_create.sql +++ b/mitmproxy/io/sql/session_create.sql @@ -1,3 +1,5 @@ +PRAGMA foreign_keys = ON; + CREATE TABLE flow ( id VARCHAR(36) PRIMARY KEY, content BLOB -- cgit v1.2.3 From 4e0c10b88bc580712d45181aaf641af918457ff3 Mon Sep 17 00:00:00 2001 From: madt1m Date: Thu, 2 Aug 2018 05:55:35 +0200 Subject: tests: 97% coverage reached. Session opportunely patched after emerged defects. --- mitmproxy/addons/session.py | 218 ++++++++++++++++++++++------------ test/mitmproxy/addons/test_session.py | 153 +++++++++++++++++++++++- 2 files changed, 296 insertions(+), 75 deletions(-) diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py index c08097ee..f9d3af3f 100644 --- a/mitmproxy/addons/session.py +++ b/mitmproxy/addons/session.py @@ -1,9 +1,11 @@ +import collections import tempfile import asyncio import typing import bisect import shutil import sqlite3 +import copy import os from mitmproxy import flowfilter @@ -48,6 +50,7 @@ class SessionDB: or create a new one with optional path. :param db_path: """ + self.live_components = {} self.tempdir = None self.con = None # This is used for fast look-ups over bodies already dumped to database. @@ -71,11 +74,14 @@ class SessionDB: shutil.rmtree(self.tempdir) def __contains__(self, fid): - return fid in self._get_ids() + return any([fid == i for i in self._get_ids()]) + + def __len__(self): + ln = self.con.execute("SELECT COUNT(*) FROM flow;").fetchall()[0] + return ln[0] if ln else 0 def _get_ids(self): - with self.con as con: - return [t[0] for t in con.execute("SELECT id FROM flow;").fetchall()] + return [t[0] for t in self.con.execute("SELECT id FROM flow;").fetchall()] def _load_session(self, path): if not self.is_session_db(path): @@ -85,8 +91,8 @@ class SessionDB: def _create_session(self): script_path = pkg_data.path("io/sql/session_create.sql") qry = open(script_path, 'r').read() - with self.con: - self.con.executescript(qry) + self.con.executescript(qry) + self.con.commit() @staticmethod def is_session_db(path): @@ -110,62 +116,104 @@ class SessionDB: c.close() return False + def _disassemble(self, flow): + # Some live components of flows cannot be serialized, but they are needed to ensure correct functionality. + # We solve this by keeping a list of tuples which "save" those components for each flow id, eventually + # adding them back when needed. + self.live_components[flow.id] = ( + flow.client_conn.wfile, + flow.client_conn.rfile, + flow.client_conn.reply, + flow.server_conn.wfile, + flow.server_conn.rfile, + flow.server_conn.reply, + (flow.server_conn.via.wfile, flow.server_conn.via.rfile, + flow.server_conn.via.reply) if flow.server_conn.via else None, + flow.reply + ) + + def _reassemble(self, flow): + if flow.id in self.live_components: + cwf, crf, crp, swf, srf, srp, via, rep = self.live_components[flow.id] + flow.client_conn.wfile = cwf + flow.client_conn.rfile = crf + flow.client_conn.reply = crp + flow.server_conn.wfile = swf + flow.server_conn.rfile = srf + flow.server_conn.reply = srp + flow.reply = rep + if via: + flow.server_conn.via.rfile, flow.server_conn.via.wfile, flow.server_conn.via.reply = via + return flow + def store_flows(self, flows): body_buf = [] flow_buf = [] for flow in flows: - if len(flow.request.content) > self.content_threshold: - body_buf.append((flow.id, self.type_mappings["body"][1], flow.request.content)) - flow.request.content = b"" - self.body_ledger.add(flow.id) - if flow.response and flow.id not in self.body_ledger: - if len(flow.response.content) > self.content_threshold: - body_buf.append((flow.id, self.type_mappings["body"][2], flow.response.content)) - flow.response.content = b"" - flow_buf.append((flow.id, protobuf.dumps(flow))) - with self.con as con: - con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?)", flow_buf) - con.executemany("INSERT INTO body VALUES(?, ?, ?)", body_buf) + self._disassemble(flow) + f = copy.copy(flow) + f.request = copy.deepcopy(flow.request) + if flow.response: + f.response = copy.deepcopy(flow.response) + f.id = flow.id + if len(f.request.content) > self.content_threshold and f.id not in self.body_ledger: + body_buf.append((f.id, 1, f.request.content)) + f.request.content = b"" + self.body_ledger.add(f.id) + if f.response and f.id not in self.body_ledger: + if len(f.response.content) > self.content_threshold: + body_buf.append((f.id, 2, f.response.content)) + f.response.content = b"" + flow_buf.append((f.id, protobuf.dumps(f))) + self.con.executemany("INSERT OR REPLACE INTO flow VALUES(?, ?);", flow_buf) + if body_buf: + self.con.executemany("INSERT INTO body (flow_id, type_id, content) VALUES(?, ?, ?);", body_buf) + self.con.commit() def retrieve_flows(self, ids=None): flows = [] with self.con as con: if not ids: sql = "SELECT f.content, b.type_id, b.content " \ - "FROM flow f, body b " \ - "WHERE f.id = b.flow_id;" + "FROM flow f " \ + "LEFT OUTER JOIN body b ON f.id = b.flow_id;" rows = con.execute(sql).fetchall() else: sql = "SELECT f.content, b.type_id, b.content " \ - "FROM flow f, body b " \ - "WHERE f.id = b.flow_id" \ - f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])})" + "FROM flow f " \ + "LEFT OUTER JOIN body b ON f.id = b.flow_id " \ + f"AND f.id IN ({','.join(['?' for _ in range(len(ids))])});" rows = con.execute(sql, ids).fetchall() for row in rows: flow = protobuf.loads(row[0]) - typ = self.type_mappings["body"][row[1]] - if typ and row[2]: - setattr(getattr(flow, typ), "content", row[2]) + if row[1]: + typ = self.type_mappings["body"][row[1]] + if typ and row[2]: + setattr(getattr(flow, typ), "content", row[2]) + flow = self._reassemble(flow) flows.append(flow) return flows + def clear(self): + self.con.executescript("DELETE FROM body; DELETE FROM annotation; DELETE FROM flow;") + matchall = flowfilter.parse(".") orders = [ - ("t", "time"), - ("m", "method"), - ("u", "url"), - ("z", "size") + "time", + "method", + "url", + "size" ] class Session: def __init__(self): - self.dbstore = SessionDB(ctx.options.session_path) - self._hot_store = [] - self._view = [] + self.db_store = None + self._hot_store = collections.OrderedDict() self._live_components = {} + self._view = [] self.order = orders[0] self.filter = matchall self._flush_period = 3.0 @@ -191,6 +239,7 @@ class Session: def running(self): if not self.started: self.started = True + self.db_store = SessionDB(ctx.options.session_path) loop = asyncio.get_event_loop() tasks = (self._writer, self._tweaker) loop.create_task(asyncio.gather(*(t() for t in tasks))) @@ -201,6 +250,60 @@ class Session: if "view_filter" in updated: self.set_filter(ctx.options.view_filter) + async def _writer(self): + while True: + await asyncio.sleep(self._flush_period) + tof = [] + to_dump = min(self._flush_rate, len(self._hot_store)) + for _ in range(to_dump): + tof.append(self._hot_store.popitem(last=False)[1]) + self.db_store.store_flows(tof) + + async def _tweaker(self): + while True: + await asyncio.sleep(self._tweak_period) + if len(self._hot_store) >= 3 * self._flush_rate: + self._flush_period *= 0.9 + self._flush_rate *= 1.1 + elif len(self._hot_store) < self._flush_rate: + self._flush_period *= 1.1 + self._flush_rate *= 0.9 + + def load_view(self, ids=None): + flows = [] + ids_from_store = [] + if ids is None: + ids = [fid for _, fid in self._view] + for fid in ids: + # Flow could be at the same time in database and in hot storage. We want the most updated version. + if fid in self._hot_store: + flows.append(self._hot_store[fid]) + elif fid in self.db_store: + ids_from_store.append(fid) + else: + flows.append(None) + flows += self.db_store.retrieve_flows(ids_from_store) + return flows + + def load_storage(self): + flows = [] + flows += self.db_store.retrieve_flows() + for flow in self._hot_store.values(): + flows.append(flow) + return flows + + def clear_storage(self): + self.db_store.clear() + self._hot_store.clear() + self._view = [] + + def store_count(self): + ln = 0 + for fid in self._hot_store.keys(): + if fid not in self.db_store: + ln += 1 + return ln + len(self.db_store) + def _generate_order(self, f: http.HTTPFlow) -> typing.Union[str, int, float]: o = self.order if o == "time": @@ -225,19 +328,19 @@ class Session: if order != self.order: self.order = order newview = [ - (self._generate_order(f), f.id) for f in self.dbstore.retrieve_flows([t[0] for t in self._view]) + (self._generate_order(f), f.id) for f in self.load_view() ] self._view = sorted(newview) def _refilter(self): self._view = [] - flows = self.dbstore.retrieve_flows() + flows = self.load_storage() for f in flows: if self.filter(f): self._base_add(f) - def set_filter(self, input_filter: str) -> None: - filt = flowfilter.parse(input_filter) + def set_filter(self, input_filter: typing.Optional[str]) -> None: + filt = matchall if not input_filter else flowfilter.parse(input_filter) if not filt: raise CommandError( "Invalid interception filter: %s" % filt @@ -245,54 +348,21 @@ class Session: self.filter = filt self._refilter() - async def _writer(self): - while True: - await asyncio.sleep(self._flush_period) - tof = [] - to_dump = min(self._flush_rate, len(self._hot_store)) - for _ in range(to_dump): - tof.append(self._hot_store.pop()) - self.store(tof) - - async def _tweaker(self): - while True: - await asyncio.sleep(self._tweak_period) - if len(self._hot_store) >= self._flush_rate: - self._flush_period *= 0.9 - self._flush_rate *= 0.9 - elif len(self._hot_store) < self._flush_rate: - self._flush_period *= 1.1 - self._flush_rate *= 1.1 - - def store(self, flows: typing.Sequence[http.HTTPFlow]) -> None: - # Some live components of flows cannot be serialized, but they are needed to ensure correct functionality. - # We solve this by keeping a list of tuples which "save" those components for each flow id, eventually - # adding them back when needed. - for f in flows: - self._live_components[f.id] = ( - f.client_conn.wfile or None, - f.client_conn.rfile or None, - f.server_conn.wfile or None, - f.server_conn.rfile or None, - f.reply or None - ) - self.dbstore.store_flows(flows) - def _base_add(self, f): - if f.id not in self._view: + if not any([f.id == t[1] for t in self._view]): o = self._generate_order(f) self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id)) else: o = self._generate_order(f) - self._view = [flow for flow in self._view if flow.id != f.id] + self._view = [(order, fid) for order, fid in self._view if fid != f.id] self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id)) def update(self, flows: typing.Sequence[http.HTTPFlow]) -> None: for f in flows: + if f.id in self._hot_store: + self._hot_store.pop(f.id) + self._hot_store[f.id] = f if self.filter(f): - if f.id in [f.id for f in self._hot_store]: - self._hot_store = [flow for flow in self._hot_store if flow.id != f.id] - self._hot_store.append(f) self._base_add(f) def request(self, f): diff --git a/test/mitmproxy/addons/test_session.py b/test/mitmproxy/addons/test_session.py index d4b1109b..41e8a401 100644 --- a/test/mitmproxy/addons/test_session.py +++ b/test/mitmproxy/addons/test_session.py @@ -1,13 +1,37 @@ import sqlite3 +import asyncio import pytest import os +from mitmproxy import ctx +from mitmproxy import http +from mitmproxy.test import tflow, tutils +from mitmproxy.test import taddons from mitmproxy.addons import session -from mitmproxy.exceptions import SessionLoadException +from mitmproxy.exceptions import SessionLoadException, CommandError from mitmproxy.utils.data import pkg_data class TestSession: + + @staticmethod + def tft(*, method="GET", start=0): + f = tflow.tflow() + f.request.method = method + f.request.timestamp_start = start + return f + + @staticmethod + def start_session(fp=None): + s = session.Session() + tctx = taddons.context() + tctx.master.addons.add(s) + tctx.options.session_path = None + if fp: + s._flush_period = fp + s.running() + return s + def test_session_temporary(self): s = session.SessionDB() td = s.tempdir @@ -56,3 +80,130 @@ class TestSession: assert len(rows) == 1 con.close() os.remove(path) + + def test_session_order_generators(self): + s = session.Session() + tf = tflow.tflow(resp=True) + + s.order = "time" + assert s._generate_order(tf) == 946681200 + + s.order = "method" + assert s._generate_order(tf) == tf.request.method + + s.order = "url" + assert s._generate_order(tf) == tf.request.url + + s.order = "size" + assert s._generate_order(tf) == len(tf.request.raw_content) + len(tf.response.raw_content) + + def test_simple(self): + s = session.Session() + ctx.options = taddons.context() + ctx.options.session_path = None + s.running() + f = self.tft(start=1) + assert s.store_count() == 0 + s.request(f) + assert s._view == [(1, f.id)] + assert s.load_view([f.id]) == [f] + assert s.load_view(['nonexistent']) == [None] + + s.error(f) + s.response(f) + s.intercept(f) + s.resume(f) + s.kill(f) + + # Verify that flow has been updated, not duplicated + assert s._view == [(1, f.id)] + assert s.store_count() == 1 + + f2 = self.tft(start=3) + s.request(f2) + assert s._view == [(1, f.id), (3, f2.id)] + s.request(f2) + assert s._view == [(1, f.id), (3, f2.id)] + + f3 = self.tft(start=2) + s.request(f3) + assert s._view == [(1, f.id), (2, f3.id), (3, f2.id)] + s.request(f3) + assert s._view == [(1, f.id), (2, f3.id), (3, f2.id)] + assert s.store_count() == 3 + + s.clear_storage() + assert len(s._view) == 0 + assert s.store_count() == 0 + + def test_filter(self): + s = self.start_session() + s.request(self.tft(method="get")) + s.request(self.tft(method="put")) + s.request(self.tft(method="get")) + s.request(self.tft(method="put")) + assert len(s._view) == 4 + s.set_filter("~m get") + assert [f.request.method for f in s.load_view()] == ["GET", "GET"] + assert s.store_count() == 4 + with pytest.raises(CommandError): + s.set_filter("~notafilter") + s.set_filter(None) + assert len(s._view) == 4 + + @pytest.mark.asyncio + async def test_flush_withspecials(self): + s = self.start_session(fp=0.5) + f = self.tft() + s.request(f) + await asyncio.sleep(2) + assert len(s._hot_store) == 0 + assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_view(), [f]))]) + + f.server_conn.via = tflow.tserver_conn() + s.request(f) + await asyncio.sleep(1) + assert len(s._hot_store) == 0 + assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_view(), [f]))]) + + flows = [self.tft() for _ in range(500)] + s.update(flows) + fp = s._flush_period + fr = s._flush_rate + await asyncio.sleep(0.6) + assert s._flush_period < fp and s._flush_rate > fr + + @pytest.mark.asyncio + async def test_bodies(self): + # Need to test for configure + # Need to test for set_order + s = self.start_session(fp=0.5) + f = self.tft() + f2 = self.tft(start=1) + f.request.content = b"A"*1001 + s.request(f) + s.request(f2) + await asyncio.sleep(1.0) + content = s.db_store.con.execute( + "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f.id] + ).fetchall()[0] + assert content == (1, b"A"*1001) + assert s.db_store.body_ledger == {f.id} + f.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A"*1001)) + f2.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A"*1001)) + # Content length is wrong for some reason -- quick fix + f.response.headers['content-length'] = b"1001" + f2.response.headers['content-length'] = b"1001" + s.response(f) + s.response(f2) + await asyncio.sleep(1.0) + rows = s.db_store.con.execute( + "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f.id] + ).fetchall() + assert len(rows) == 1 + rows = s.db_store.con.execute( + "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f2.id] + ).fetchall() + assert len(rows) == 1 + assert s.db_store.body_ledger == {f.id} + assert all([lf.__dict__ == rf.__dict__ for lf, rf in list(zip(s.load_view(), [f, f2]))]) -- cgit v1.2.3 From e9c2b12dabddd8d5b26db7f877eb982859274263 Mon Sep 17 00:00:00 2001 From: madt1m Date: Thu, 2 Aug 2018 14:20:43 +0200 Subject: tests: Full coverage. Everything working, ready for review --- mitmproxy/addons/session.py | 48 +++++++++++++------------ test/mitmproxy/addons/test_session.py | 68 ++++++++++++++++++++++++++--------- 2 files changed, 77 insertions(+), 39 deletions(-) diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py index f9d3af3f..2e4d2147 100644 --- a/mitmproxy/addons/session.py +++ b/mitmproxy/addons/session.py @@ -241,8 +241,8 @@ class Session: self.started = True self.db_store = SessionDB(ctx.options.session_path) loop = asyncio.get_event_loop() - tasks = (self._writer, self._tweaker) - loop.create_task(asyncio.gather(*(t() for t in tasks))) + loop.create_task(self._writer()) + loop.create_task(self._tweaker()) def configure(self, updated): if "view_order" in updated: @@ -269,27 +269,30 @@ class Session: self._flush_period *= 1.1 self._flush_rate *= 0.9 - def load_view(self, ids=None): - flows = [] - ids_from_store = [] - if ids is None: - ids = [fid for _, fid in self._view] - for fid in ids: - # Flow could be at the same time in database and in hot storage. We want the most updated version. - if fid in self._hot_store: - flows.append(self._hot_store[fid]) - elif fid in self.db_store: - ids_from_store.append(fid) - else: - flows.append(None) - flows += self.db_store.retrieve_flows(ids_from_store) - return flows + def load_view(self): + ids = [fid for _, fid in self._view] + flows = self.load_storage(ids) + return sorted(flows, key=lambda f: self._generate_order(f)) - def load_storage(self): + def load_storage(self, ids=None): flows = [] - flows += self.db_store.retrieve_flows() - for flow in self._hot_store.values(): - flows.append(flow) + ids_from_store = [] + if ids is not None: + for fid in ids: + # A same flow could be at the same time in hot and db storage. We want the most updated version. + if fid in self._hot_store: + flows.append(self._hot_store[fid]) + elif fid in self.db_store: + ids_from_store.append(fid) + else: + flows.append(None) + flows += self.db_store.retrieve_flows(ids_from_store) + else: + for flow in self._hot_store.values(): + flows.append(flow) + for flow in self.db_store.retrieve_flows(): + if flow.id not in self._hot_store: + flows.append(flow) return flows def clear_storage(self): @@ -304,7 +307,7 @@ class Session: ln += 1 return ln + len(self.db_store) - def _generate_order(self, f: http.HTTPFlow) -> typing.Union[str, int, float]: + def _generate_order(self, f: http.HTTPFlow) -> typing.Optional[typing.Union[str, int, float]]: o = self.order if o == "time": return f.request.timestamp_start or 0 @@ -319,6 +322,7 @@ class Session: if f.response and f.response.raw_content: s += len(f.response.raw_content) return s + return None def set_order(self, order: str) -> None: if order not in orders: diff --git a/test/mitmproxy/addons/test_session.py b/test/mitmproxy/addons/test_session.py index 41e8a401..11a41a6a 100644 --- a/test/mitmproxy/addons/test_session.py +++ b/test/mitmproxy/addons/test_session.py @@ -24,9 +24,10 @@ class TestSession: @staticmethod def start_session(fp=None): s = session.Session() - tctx = taddons.context() - tctx.master.addons.add(s) - tctx.options.session_path = None + with taddons.context() as tctx: + tctx.master.addons.add(s) + tctx.options.session_path = None + tctx.options.view_filter = None if fp: s._flush_period = fp s.running() @@ -97,7 +98,10 @@ class TestSession: s.order = "size" assert s._generate_order(tf) == len(tf.request.raw_content) + len(tf.response.raw_content) - def test_simple(self): + s.order = "invalid" + assert not s._generate_order(tf) + + def test_storage_simple(self): s = session.Session() ctx.options = taddons.context() ctx.options.session_path = None @@ -106,8 +110,8 @@ class TestSession: assert s.store_count() == 0 s.request(f) assert s._view == [(1, f.id)] - assert s.load_view([f.id]) == [f] - assert s.load_view(['nonexistent']) == [None] + assert s.load_view() == [f] + assert s.load_storage(['nonexistent']) == [None] s.error(f) s.response(f) @@ -136,14 +140,17 @@ class TestSession: assert len(s._view) == 0 assert s.store_count() == 0 - def test_filter(self): + def test_storage_filter(self): s = self.start_session() s.request(self.tft(method="get")) s.request(self.tft(method="put")) s.request(self.tft(method="get")) s.request(self.tft(method="put")) assert len(s._view) == 4 - s.set_filter("~m get") + with taddons.context() as tctx: + tctx.master.addons.add(s) + tctx.options.view_filter = '~m get' + s.configure({"view_filter"}) assert [f.request.method for f in s.load_view()] == ["GET", "GET"] assert s.store_count() == 4 with pytest.raises(CommandError): @@ -152,19 +159,24 @@ class TestSession: assert len(s._view) == 4 @pytest.mark.asyncio - async def test_flush_withspecials(self): + async def test_storage_flush_with_specials(self): s = self.start_session(fp=0.5) f = self.tft() s.request(f) - await asyncio.sleep(2) + await asyncio.sleep(1) assert len(s._hot_store) == 0 - assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_view(), [f]))]) + f.response = http.HTTPResponse.wrap(tutils.tresp()) + s.response(f) + assert len(s._hot_store) == 1 + assert s.load_storage() == [f] + await asyncio.sleep(1) + assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_storage(), [f]))]) f.server_conn.via = tflow.tserver_conn() s.request(f) await asyncio.sleep(1) assert len(s._hot_store) == 0 - assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_view(), [f]))]) + assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_storage(), [f]))]) flows = [self.tft() for _ in range(500)] s.update(flows) @@ -174,23 +186,23 @@ class TestSession: assert s._flush_period < fp and s._flush_rate > fr @pytest.mark.asyncio - async def test_bodies(self): + async def test_storage_bodies(self): # Need to test for configure # Need to test for set_order s = self.start_session(fp=0.5) f = self.tft() f2 = self.tft(start=1) - f.request.content = b"A"*1001 + f.request.content = b"A" * 1001 s.request(f) s.request(f2) await asyncio.sleep(1.0) content = s.db_store.con.execute( "SELECT type_id, content FROM body WHERE body.flow_id == (?);", [f.id] ).fetchall()[0] - assert content == (1, b"A"*1001) + assert content == (1, b"A" * 1001) assert s.db_store.body_ledger == {f.id} - f.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A"*1001)) - f2.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A"*1001)) + f.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A" * 1001)) + f2.response = http.HTTPResponse.wrap(tutils.tresp(content=b"A" * 1001)) # Content length is wrong for some reason -- quick fix f.response.headers['content-length'] = b"1001" f2.response.headers['content-length'] = b"1001" @@ -207,3 +219,25 @@ class TestSession: assert len(rows) == 1 assert s.db_store.body_ledger == {f.id} assert all([lf.__dict__ == rf.__dict__ for lf, rf in list(zip(s.load_view(), [f, f2]))]) + + @pytest.mark.asyncio + async def test_storage_order(self): + s = self.start_session(fp=0.5) + s.request(self.tft(method="GET", start=4)) + s.request(self.tft(method="PUT", start=2)) + s.request(self.tft(method="GET", start=3)) + s.request(self.tft(method="PUT", start=1)) + assert [i.request.timestamp_start for i in s.load_view()] == [1, 2, 3, 4] + await asyncio.sleep(1.0) + assert [i.request.timestamp_start for i in s.load_view()] == [1, 2, 3, 4] + with taddons.context() as tctx: + tctx.master.addons.add(s) + tctx.options.view_order = "method" + s.configure({"view_order"}) + assert [i.request.method for i in s.load_view()] == ["GET", "GET", "PUT", "PUT"] + + s.set_order("time") + assert [i.request.timestamp_start for i in s.load_view()] == [1, 2, 3, 4] + + with pytest.raises(CommandError): + s.set_order("not_an_order") -- cgit v1.2.3 From a52451900c71f48bb51d777522424d8ba6944f0b Mon Sep 17 00:00:00 2001 From: madt1m Date: Sun, 5 Aug 2018 21:49:54 +0200 Subject: session: implemented changes requested after PR review. --- mitmproxy/addons/session.py | 105 +++++++++++++++++----------------- test/mitmproxy/addons/test_session.py | 40 ++++++------- 2 files changed, 71 insertions(+), 74 deletions(-) diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py index 2e4d2147..63e382ec 100644 --- a/mitmproxy/addons/session.py +++ b/mitmproxy/addons/session.py @@ -50,12 +50,13 @@ class SessionDB: or create a new one with optional path. :param db_path: """ - self.live_components = {} - self.tempdir = None - self.con = None + self.live_components: typing.Dict[str, tuple] = {} + self.tempdir: tempfile.TemporaryDirectory = None + self.con: sqlite3.Connection = None # This is used for fast look-ups over bodies already dumped to database. # This permits to enforce one-to-one relationship between flow and body table. - self.body_ledger = set() + self.body_ledger: typing.Set[str] = set() + self.id_ledger: typing.Set[str] = set() if db_path is not None and os.path.isfile(db_path): self._load_session(db_path) else: @@ -74,14 +75,10 @@ class SessionDB: shutil.rmtree(self.tempdir) def __contains__(self, fid): - return any([fid == i for i in self._get_ids()]) + return fid in self.id_ledger def __len__(self): - ln = self.con.execute("SELECT COUNT(*) FROM flow;").fetchall()[0] - return ln[0] if ln else 0 - - def _get_ids(self): - return [t[0] for t in self.con.execute("SELECT id FROM flow;").fetchall()] + return len(self.id_ledger) def _load_session(self, path): if not self.is_session_db(path): @@ -150,6 +147,7 @@ class SessionDB: body_buf = [] flow_buf = [] for flow in flows: + self.id_ledger.add(flow.id) self._disassemble(flow) f = copy.copy(flow) f.request = copy.deepcopy(flow.request) @@ -209,17 +207,21 @@ orders = [ class Session: + + _FP_RATE = 150 + _FP_DECREMENT = 0.9 + _FP_DEFAULT = 3.0 + def __init__(self): - self.db_store = None - self._hot_store = collections.OrderedDict() - self._live_components = {} - self._view = [] - self.order = orders[0] + self.db_store: SessionDB = None + self._hot_store: collections.OrderedDict = collections.OrderedDict() + self._order_store: typing.Dict[str, typing.Dict[str, typing.Union[int, float, str]]] = {} + self._view: typing.List[typing.Tuple[typing.Union[int, float, str], str]] = [] + self.order: str = orders[0] self.filter = matchall - self._flush_period = 3.0 - self._tweak_period = 0.5 - self._flush_rate = 150 - self.started = False + self._flush_period: float = self._FP_DEFAULT + self._flush_rate: int = self._FP_RATE + self.started: bool = False def load(self, loader): loader.add_option( @@ -242,7 +244,6 @@ class Session: self.db_store = SessionDB(ctx.options.session_path) loop = asyncio.get_event_loop() loop.create_task(self._writer()) - loop.create_task(self._tweaker()) def configure(self, updated): if "view_order" in updated: @@ -253,28 +254,23 @@ class Session: async def _writer(self): while True: await asyncio.sleep(self._flush_period) - tof = [] - to_dump = min(self._flush_rate, len(self._hot_store)) - for _ in range(to_dump): - tof.append(self._hot_store.popitem(last=False)[1]) - self.db_store.store_flows(tof) - - async def _tweaker(self): - while True: - await asyncio.sleep(self._tweak_period) - if len(self._hot_store) >= 3 * self._flush_rate: - self._flush_period *= 0.9 - self._flush_rate *= 1.1 - elif len(self._hot_store) < self._flush_rate: - self._flush_period *= 1.1 - self._flush_rate *= 0.9 - - def load_view(self): + batches = -(-len(self._hot_store) // self._flush_rate) + self._flush_period = self._flush_period * self._FP_DECREMENT if batches > 1 else self._FP_DEFAULT + while batches: + tof = [] + to_dump = min(len(self._hot_store), self._flush_rate) + for _ in range(to_dump): + tof.append(self._hot_store.popitem(last=False)[1]) + self.db_store.store_flows(tof) + batches -= 1 + await asyncio.sleep(0.01) + + def load_view(self) -> typing.Sequence[http.HTTPFlow]: ids = [fid for _, fid in self._view] flows = self.load_storage(ids) - return sorted(flows, key=lambda f: self._generate_order(f)) + return sorted(flows, key=lambda f: self._generate_order(self.order, f)) - def load_storage(self, ids=None): + def load_storage(self, ids=None) -> typing.Sequence[http.HTTPFlow]: flows = [] ids_from_store = [] if ids is not None: @@ -284,8 +280,6 @@ class Session: flows.append(self._hot_store[fid]) elif fid in self.db_store: ids_from_store.append(fid) - else: - flows.append(None) flows += self.db_store.retrieve_flows(ids_from_store) else: for flow in self._hot_store.values(): @@ -300,15 +294,15 @@ class Session: self._hot_store.clear() self._view = [] - def store_count(self): + def store_count(self) -> int: ln = 0 for fid in self._hot_store.keys(): if fid not in self.db_store: ln += 1 return ln + len(self.db_store) - def _generate_order(self, f: http.HTTPFlow) -> typing.Optional[typing.Union[str, int, float]]: - o = self.order + @staticmethod + def _generate_order(o: str, f: http.HTTPFlow) -> typing.Optional[typing.Union[str, int, float]]: if o == "time": return f.request.timestamp_start or 0 if o == "method": @@ -324,6 +318,11 @@ class Session: return s return None + def _store_order(self, f: http.HTTPFlow): + self._order_store[f.id] = {} + for order in orders: + self._order_store[f.id][order] = self._generate_order(order, f) + def set_order(self, order: str) -> None: if order not in orders: raise CommandError( @@ -332,7 +331,7 @@ class Session: if order != self.order: self.order = order newview = [ - (self._generate_order(f), f.id) for f in self.load_view() + (self._order_store[t[1]][order], t[1]) for t in self._view ] self._view = sorted(newview) @@ -341,7 +340,7 @@ class Session: flows = self.load_storage() for f in flows: if self.filter(f): - self._base_add(f) + self.update_view(f) def set_filter(self, input_filter: typing.Optional[str]) -> None: filt = matchall if not input_filter else flowfilter.parse(input_filter) @@ -352,22 +351,20 @@ class Session: self.filter = filt self._refilter() - def _base_add(self, f): - if not any([f.id == t[1] for t in self._view]): - o = self._generate_order(f) - self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id)) - else: - o = self._generate_order(f) + def update_view(self, f): + if any([f.id == t[1] for t in self._view]): self._view = [(order, fid) for order, fid in self._view if fid != f.id] - self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id)) + o = self._order_store[f.id][self.order] + self._view.insert(bisect.bisect_left(KeyifyList(self._view, lambda x: x[0]), o), (o, f.id)) def update(self, flows: typing.Sequence[http.HTTPFlow]) -> None: for f in flows: + self._store_order(f) if f.id in self._hot_store: self._hot_store.pop(f.id) self._hot_store[f.id] = f if self.filter(f): - self._base_add(f) + self.update_view(f) def request(self, f): self.update([f]) diff --git a/test/mitmproxy/addons/test_session.py b/test/mitmproxy/addons/test_session.py index 11a41a6a..20feb69d 100644 --- a/test/mitmproxy/addons/test_session.py +++ b/test/mitmproxy/addons/test_session.py @@ -28,8 +28,10 @@ class TestSession: tctx.master.addons.add(s) tctx.options.session_path = None tctx.options.view_filter = None + # To make tests quicker if fp: s._flush_period = fp + s._FP_DEFAULT = fp s.running() return s @@ -85,21 +87,11 @@ class TestSession: def test_session_order_generators(self): s = session.Session() tf = tflow.tflow(resp=True) - - s.order = "time" - assert s._generate_order(tf) == 946681200 - - s.order = "method" - assert s._generate_order(tf) == tf.request.method - - s.order = "url" - assert s._generate_order(tf) == tf.request.url - - s.order = "size" - assert s._generate_order(tf) == len(tf.request.raw_content) + len(tf.response.raw_content) - - s.order = "invalid" - assert not s._generate_order(tf) + assert s._generate_order('time', tf) == 946681200 + assert s._generate_order('method', tf) == tf.request.method + assert s._generate_order('url', tf) == tf.request.url + assert s._generate_order('size', tf) == len(tf.request.raw_content) + len(tf.response.raw_content) + assert not s._generate_order('invalid', tf) def test_storage_simple(self): s = session.Session() @@ -110,8 +102,12 @@ class TestSession: assert s.store_count() == 0 s.request(f) assert s._view == [(1, f.id)] + assert s._order_store[f.id]['time'] == 1 + assert s._order_store[f.id]['method'] == f.request.method + assert s._order_store[f.id]['url'] == f.request.url + assert s._order_store[f.id]['size'] == len(f.request.raw_content) assert s.load_view() == [f] - assert s.load_storage(['nonexistent']) == [None] + assert s.load_storage(['nonexistent']) == [] s.error(f) s.response(f) @@ -121,6 +117,10 @@ class TestSession: # Verify that flow has been updated, not duplicated assert s._view == [(1, f.id)] + assert s._order_store[f.id]['time'] == 1 + assert s._order_store[f.id]['method'] == f.request.method + assert s._order_store[f.id]['url'] == f.request.url + assert s._order_store[f.id]['size'] == len(f.request.raw_content) assert s.store_count() == 1 f2 = self.tft(start=3) @@ -174,16 +174,16 @@ class TestSession: f.server_conn.via = tflow.tserver_conn() s.request(f) - await asyncio.sleep(1) + await asyncio.sleep(0.6) assert len(s._hot_store) == 0 assert all([lflow.__dict__ == flow.__dict__ for lflow, flow in list(zip(s.load_storage(), [f]))]) flows = [self.tft() for _ in range(500)] s.update(flows) - fp = s._flush_period - fr = s._flush_rate await asyncio.sleep(0.6) - assert s._flush_period < fp and s._flush_rate > fr + assert s._flush_period == s._FP_DEFAULT * s._FP_DECREMENT + await asyncio.sleep(3) + assert s._flush_period == s._FP_DEFAULT @pytest.mark.asyncio async def test_storage_bodies(self): -- cgit v1.2.3