 mitmproxy/addons/session.py           |  73
 mitmproxy/exceptions.py               |   4
 mitmproxy/io/sql/session_create.sql   |  20
 test/bench/serialization-bm.py        | 116
 test/mitmproxy/addons/test_session.py |  58
 5 files changed, 271 insertions(+), 0 deletions(-)
diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py
new file mode 100644
index 00000000..c49b95c4
--- /dev/null
+++ b/mitmproxy/addons/session.py
@@ -0,0 +1,73 @@
+import tempfile
+import shutil
+import sqlite3
+import os
+
+from mitmproxy.exceptions import SessionLoadException
+from mitmproxy.utils.data import pkg_data
+
+
+# Could be implemented using async libraries
+class SessionDB:
+    """
+    This class wraps the SQLite connection for Sessions
+    and handles table creation, retrieval and insertion.
+    """
+
+    def __init__(self, db_path=None):
+        """
+        Connect to an existing session database,
+        or create a new one at the optional path.
+        :param db_path: path to the session database file.
+        """
+        self.tempdir = None
+        self.con = None
+        if db_path is not None and os.path.isfile(db_path):
+            self._load_session(db_path)
+        else:
+            if db_path:
+                path = db_path
+            else:
+                self.tempdir = tempfile.mkdtemp()
+                path = os.path.join(self.tempdir, 'tmp.sqlite')
+            self.con = sqlite3.connect(path)
+            self._create_session()
+
+    def __del__(self):
+        if self.con:
+            self.con.close()
+        if self.tempdir:
+            shutil.rmtree(self.tempdir)
+
+    def _load_session(self, path):
+        if not self.is_session_db(path):
+            raise SessionLoadException('Given path does not point to a valid Session')
+        self.con = sqlite3.connect(path)
+
+    def _create_session(self):
+        script_path = pkg_data.path("io/sql/session_create.sql")
+        with open(script_path) as f:
+            qry = f.read()
+        with self.con:
+            self.con.executescript(qry)
+
+    @staticmethod
+    def is_session_db(path):
+        """
+        Check if the database at the given path
+        is a valid Session SQLite DB.
+        :return: True if valid, False otherwise.
+        """
+        c = None
+        try:
+            c = sqlite3.connect(f'file:{path}?mode=rw', uri=True)
+            cursor = c.cursor()
+            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
+            rows = cursor.fetchall()
+            tables = [('flow',), ('body',), ('annotation',)]
+            return all(elem in rows for elem in tables)
+        except sqlite3.Error:
+            return False
+        finally:
+            if c:
+                c.close()
diff --git a/mitmproxy/exceptions.py b/mitmproxy/exceptions.py
index d568898b..9f0a8c30 100644
--- a/mitmproxy/exceptions.py
+++ b/mitmproxy/exceptions.py
@@ -129,6 +129,10 @@ class NetlibException(MitmproxyException):
         super().__init__(message)
 
 
+class SessionLoadException(MitmproxyException):
+    pass
+
+
 class Disconnect:
     """Immediate EOF"""
 
diff --git a/mitmproxy/io/sql/session_create.sql b/mitmproxy/io/sql/session_create.sql
new file mode 100644
index 00000000..bfc98b94
--- /dev/null
+++ b/mitmproxy/io/sql/session_create.sql
@@ -0,0 +1,20 @@
+CREATE TABLE flow (
+    id VARCHAR(36) PRIMARY KEY,
+    content BLOB
+);
+
+CREATE TABLE body (
+    id INTEGER PRIMARY KEY,
+    flow_id VARCHAR(36),
+    type_id INTEGER,
+    content BLOB,
+    FOREIGN KEY(flow_id) REFERENCES flow(id)
+);
+
+CREATE TABLE annotation (
+    id INTEGER PRIMARY KEY,
+    flow_id VARCHAR(36),
+    type VARCHAR(16),
+    content BLOB,
+    FOREIGN KEY(flow_id) REFERENCES flow(id)
+);
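
The SessionDB API above is small enough to exercise directly. A minimal usage sketch; the on-disk path below is hypothetical, not part of the change:

    # Sketch only: drives the SessionDB API introduced in this commit.
    from mitmproxy.addons.session import SessionDB

    # No path: a throwaway DB is created in a private tempdir and
    # removed again when the object is garbage-collected.
    ephemeral = SessionDB()

    # Path: load an existing valid session, or create a persistent one;
    # a file that is not a valid session raises SessionLoadException.
    persistent = SessionDB(db_path='/tmp/example_session.sqlite')
    assert SessionDB.is_session_db('/tmp/example_session.sqlite')
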
diff --git a/test/bench/serialization-bm.py b/test/bench/serialization-bm.py
new file mode 100644
index 00000000..665b72cb
--- /dev/null
+++ b/test/bench/serialization-bm.py
@@ -0,0 +1,116 @@
+import tempfile
+import asyncio
+import typing
+import time
+
+from statistics import mean
+
+from mitmproxy import ctx
+from mitmproxy.io import db
+from mitmproxy.test import tflow
+
+
+class StreamTester:
+
+ """
+ Generates a constant stream of flows and
+    measures protobuf dumping throughput.
+ """
+
+ def __init__(self):
+ self.dbh = None
+ self.streaming = False
+ self.tf = None
+ self.out = None
+ self.hot_flows = []
+ self.results = []
+ self._flushes = 0
+ self._stream_period = 0.001
+ self._flush_period = 3.0
+ self._flush_rate = 150
+ self._target = 2000
+ self.loop = asyncio.get_event_loop()
+ self.queue = asyncio.Queue(maxsize=self._flush_rate * 3, loop=self.loop)
+ self.temp = tempfile.NamedTemporaryFile()
+
+ def load(self, loader):
+ loader.add_option(
+ "testflow_size",
+ int,
+ 1000,
+ "Length in bytes of test flow content"
+ )
+ loader.add_option(
+ "benchmark_save_path",
+ typing.Optional[str],
+ None,
+ "Destination for the stats result file"
+ )
+
+ def _log(self, msg):
+ if self.out:
+ self.out.write(msg + '\n')
+ else:
+ ctx.log(msg)
+
+ def running(self):
+ if not self.streaming:
+ ctx.log("<== Serialization Benchmark Enabled ==>")
+ self.tf = tflow.tflow()
+ self.tf.request.content = b'A' * ctx.options.testflow_size
+ ctx.log(f"With content size: {len(self.tf.request.content)} B")
+ if ctx.options.benchmark_save_path:
+ ctx.log(f"Storing results to {ctx.options.benchmark_save_path}")
+ self.out = open(ctx.options.benchmark_save_path, "w")
+ self.dbh = db.DBHandler(self.temp.name, mode='write')
+ self.streaming = True
+ tasks = (self.stream, self.writer, self.stats)
+ self.loop.create_task(asyncio.gather(*(t() for t in tasks)))
+
+ async def stream(self):
+ while True:
+ await self.queue.put(self.tf)
+ await asyncio.sleep(self._stream_period)
+
+ async def writer(self):
+ while True:
+ await asyncio.sleep(self._flush_period)
+ count = 1
+ f = await self.queue.get()
+ self.hot_flows.append(f)
+ while count < self._flush_rate:
+ try:
+ self.hot_flows.append(self.queue.get_nowait())
+ count += 1
+ except asyncio.QueueEmpty:
+                    break
+ start = time.perf_counter()
+ n = self._fflush()
+ end = time.perf_counter()
+ self._log(f"dumps/time ratio: {n} / {end-start} -> {n/(end-start)}")
+ self.results.append(n / (end - start))
+ self._flushes += n
+ self._log(f"Flows dumped: {self._flushes}")
+ ctx.log(f"Progress: {min(100.0, 100.0 * (self._flushes / self._target))}%")
+
+ async def stats(self):
+ while True:
+ await asyncio.sleep(1.0)
+ if self._flushes >= self._target:
+ self._log(f"AVG : {mean(self.results)}")
+                ctx.log("<== Benchmark Ended. Shutting down... ==>")
+ if self.out:
+ self.out.close()
+ self.temp.close()
+ ctx.master.shutdown()
+
+ def _fflush(self):
+ self.dbh.store(self.hot_flows)
+ n = len(self.hot_flows)
+ self.hot_flows = []
+ return n
+
+
+addons = [
+ StreamTester()
+]
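
The addon above is meant to be loaded as a script (e.g. mitmdump -s serialization-bm.py, with testflow_size set via --set). Its storage path can also be driven without the event-loop machinery; a hedged sketch, assuming the db.DBHandler interface referenced above, which is not part of this diff:

    # Sketch only: mirrors StreamTester._fflush() without the event loop.
    import tempfile

    from mitmproxy.io import db
    from mitmproxy.test import tflow

    tmp = tempfile.NamedTemporaryFile()
    handler = db.DBHandler(tmp.name, mode='write')

    flows = [tflow.tflow() for _ in range(150)]  # one batch at the default flush rate
    handler.store(flows)                         # serialize and write the batch
    tmp.close()
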
diff --git a/test/mitmproxy/addons/test_session.py b/test/mitmproxy/addons/test_session.py
new file mode 100644
index 00000000..d4b1109b
--- /dev/null
+++ b/test/mitmproxy/addons/test_session.py
@@ -0,0 +1,58 @@
+import sqlite3
+import pytest
+import os
+
+from mitmproxy.addons import session
+from mitmproxy.exceptions import SessionLoadException
+from mitmproxy.utils.data import pkg_data
+
+
+class TestSession:
+    def test_session_temporary(self):
+        s = session.SessionDB()
+        td = s.tempdir
+        filename = os.path.join(td, 'tmp.sqlite')
+        assert session.SessionDB.is_session_db(filename)
+        assert os.path.isdir(td)
+        del s
+        assert not os.path.isdir(td)
+
+    def test_session_not_valid(self, tdata):
+        path = os.path.join(tdata.path('mitmproxy/data'), 'test_snv.sqlite')
+        if os.path.isfile(path):
+            os.remove(path)
+        with open(path, 'w') as handle:
+            handle.write("Not valid data")
+        with pytest.raises(SessionLoadException):
+            session.SessionDB(path)
+        os.remove(path)
+
+    def test_session_new_persistent(self, tdata):
+        path = os.path.join(tdata.path('mitmproxy/data'), 'test_np.sqlite')
+        if os.path.isfile(path):
+            os.remove(path)
+        session.SessionDB(path)
+        assert session.SessionDB.is_session_db(path)
+        os.remove(path)
+
+    def test_session_load_existing(self, tdata):
+        path = os.path.join(tdata.path('mitmproxy/data'), 'test_le.sqlite')
+        if os.path.isfile(path):
+            os.remove(path)
+        con = sqlite3.connect(path)
+        script_path = pkg_data.path("io/sql/session_create.sql")
+        with open(script_path) as f:
+            qry = f.read()
+        with con:
+            con.executescript(qry)
+            con.execute('INSERT INTO flow VALUES(?, ?);', ('1', b'blob_of_data'))
+        con.close()
+        session.SessionDB(path)
+        con = sqlite3.connect(path)
+        with con:
+            cur = con.cursor()
+            cur.execute('SELECT * FROM flow;')
+            rows = cur.fetchall()
+            assert len(rows) == 1
+        con.close()
+        os.remove(path)
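
The tests only touch the flow table; a short sketch of how body rows hang off flow via flow_id under the schema in session_create.sql (table and column names come straight from the schema, the values are illustrative):

    # Sketch only: exercises the new schema in an in-memory DB.
    import sqlite3

    from mitmproxy.utils.data import pkg_data

    con = sqlite3.connect(':memory:')
    with open(pkg_data.path("io/sql/session_create.sql")) as f:
        con.executescript(f.read())

    with con:
        con.execute('INSERT INTO flow VALUES(?, ?);', ('uuid-1', b'flow-blob'))
        # body.id is INTEGER PRIMARY KEY, so it is filled from the rowid.
        con.execute(
            'INSERT INTO body(flow_id, type_id, content) VALUES(?, ?, ?);',
            ('uuid-1', 1, b'body-blob')
        )

    rows = con.execute(
        'SELECT flow.id, body.content FROM flow JOIN body ON body.flow_id = flow.id;'
    ).fetchall()
    assert rows == [('uuid-1', b'body-blob')]
    con.close()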
