diff options
| -rw-r--r-- | mitmproxy/addons/session.py | 14 | ||||
| -rw-r--r-- | test/mitmproxy/addons/test_session.py | 8 | 
2 files changed, 12 insertions, 10 deletions
| diff --git a/mitmproxy/addons/session.py b/mitmproxy/addons/session.py index 7f1c0025..c49b95c4 100644 --- a/mitmproxy/addons/session.py +++ b/mitmproxy/addons/session.py @@ -1,4 +1,5 @@  import tempfile +import shutil  import sqlite3  import os @@ -20,7 +21,7 @@ class SessionDB:          or create a new one with optional path.          :param db_path:          """ -        self.temp = None +        self.tempdir = None          self.con = None          if db_path is not None and os.path.isfile(db_path):              self._load_session(db_path) @@ -28,19 +29,16 @@ class SessionDB:              if db_path:                  path = db_path              else: -                # We use tempfile only to generate a path, since we demand file creation to sqlite, and removal to os. -                self.temp = tempfile.NamedTemporaryFile() -                path = self.temp.name -                self.temp.close() +                self.tempdir = tempfile.mkdtemp() +                path = os.path.join(self.tempdir, 'tmp.sqlite')              self.con = sqlite3.connect(path)              self._create_session()      def __del__(self):          if self.con:              self.con.close() -        if self.temp: -            # This is a workaround to ensure portability -            os.remove(self.temp.name) +        if self.tempdir: +            shutil.rmtree(self.tempdir)      def _load_session(self, path):          if not self.is_session_db(path): diff --git a/test/mitmproxy/addons/test_session.py b/test/mitmproxy/addons/test_session.py index cb36e283..d4b1109b 100644 --- a/test/mitmproxy/addons/test_session.py +++ b/test/mitmproxy/addons/test_session.py @@ -8,10 +8,14 @@ from mitmproxy.utils.data import pkg_data  class TestSession: -    def test_session_temporary(self, tdata): +    def test_session_temporary(self):          s = session.SessionDB() -        filename = s.temp.name +        td = s.tempdir +        filename = os.path.join(td, 'tmp.sqlite')          assert session.SessionDB.is_session_db(filename) +        assert os.path.isdir(td) +        del s +        assert not os.path.isdir(td)      def test_session_not_valid(self, tdata):          path = tdata.path('mitmproxy/data/') + '/test_snv.sqlite' | 
