Diffstat (limited to 'cloud_mdir_sync/messages.py')
-rw-r--r--  cloud_mdir_sync/messages.py  326
1 file changed, 326 insertions, 0 deletions
diff --git a/cloud_mdir_sync/messages.py b/cloud_mdir_sync/messages.py
new file mode 100644
index 0000000..d87e547
--- /dev/null
+++ b/cloud_mdir_sync/messages.py
@@ -0,0 +1,326 @@
+# SPDX-License-Identifier: GPL-2.0+
+import collections
+import datetime
+import email
+import email.parser
+import hashlib
+import logging
+import os
+import pickle
+import re
+import stat
+import subprocess
+import tempfile
+from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+
+import cryptography
+import cryptography.exceptions
+from cryptography.fernet import Fernet
+
+from . import config, util
+
+if TYPE_CHECKING:
+ from . import mailbox
+
+ContentHash_Type = str
+CID_Type = tuple
+MBoxDict_Type = Dict["mailbox.Mailbox", Dict[ContentHash_Type,
+ "Message"]]
+CHMsgDict_Type = Dict[ContentHash_Type, "Message"]
+CHMsgMappingDict_Type = Dict[ContentHash_Type, Tuple[Optional["Message"],
+ "Message"]]
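+# For illustration, an MBoxDict_Type value maps each mailbox to its messages
+# keyed by content hash, e.g. {<Mailbox>: {"da39a3ee...": <Message>}}
+# (hash shortened; values hypothetical).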
+
+
+class Message(object):
+ """A single message in the system"""
+ content_hash: Optional[ContentHash_Type] = None
+ received_time: Optional[datetime.datetime] = None
+ flags = 0
+ FLAG_REPLIED = 1 << 0
+ FLAG_READ = 1 << 1
+ FLAG_FLAGGED = 1 << 2
+ ALL_FLAGS = FLAG_REPLIED | FLAG_READ | FLAG_FLAGGED
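+    # Flags combine bitwise: a message that is both read and replied-to has
+    # flags == FLAG_READ | FLAG_REPLIED, tested with msg.flags & FLAG_READ.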
+ fn: str
+
+ def __init__(self, mailbox, storage_id, email_id=None):
+ assert storage_id
+ self.mailbox = mailbox
+ self.storage_id = storage_id
+ self.email_id = email_id
+
+ def cid(self):
+ """The unique content ID of the message. This is scoped within the
+ mailbox and is used to search for the content_hash"""
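+        # Shape: (storage_kind, storage_id, email_id); the concrete values
+        # depend on the mailbox backend.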
+ return (self.mailbox.storage_kind, self.storage_id, self.email_id)
+
+ def __getstate__(self):
+ return {
+ "content_hash": self.content_hash,
+ "received_time": self.received_time,
+ "flags": self.flags,
+ "storage_id": self.storage_id,
+ "email_id": self.email_id
+ }
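+        # The mailbox back-reference and the local filename fn are deliberately
+        # omitted; only the durable identity and state are pickled.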
+
+
+class MessageDB(object):
+ """The persistent state associated with the message database. This holds:
+ - A directory of content_hash files for mailbox content
+ - A set of files of pickles storing the mapping of CID to content_hash
+ """
+ content_hashes: Dict[CID_Type, ContentHash_Type]
+ content_msgid: Dict[ContentHash_Type, str]
+ alt_file_hashes: Dict[ContentHash_Type, set]
+ inode_hashes: Dict[tuple, ContentHash_Type]
+ file_hashes: Set[str]
+ authenticators_to_save: Set[str]
+ authenticators: Dict[str, Tuple[int, bytes]]
+
+ @util.log_progress(
+ "Loading cached state",
+ lambda self:
+ f", {len(self.file_hashes)} msgs, {len(self.content_hashes)} cached ids",
+ level=logging.DEBUG)
+ def __init__(self, cfg: config.Config):
+ self.cfg = cfg
+ self.content_hashes = {} # [cid] = content_hash
+ self.content_msgid = {} # [hash] = message_id
+ self.file_hashes = set()
+ self.alt_file_hashes = collections.defaultdict(
+ set) # [hash] = set(fns)
+ self.inode_hashes = {} # [inode] = content_hash
+ self.authenticators_to_save = set()
+ self.authenticators = {} # [did] = (serial, blob)
+
+ self.state_dir = os.path.expanduser(cfg.message_db_dir)
+ self.hashes_dir = os.path.join(self.state_dir, "hashes")
+ os.makedirs(self.hashes_dir, exist_ok=True)
+ self._load_file_hashes(self.hashes_dir)
+ self._load_content_hashes()
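+        # Resulting on-disk layout: <message_db_dir>/hashes/<sha1> holds raw
+        # message bodies, while <message_db_dir>/ch-<sha1> holds pickled state.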
+
+ def close(self):
+ try:
+ self._save_content_hashes()
+ except IOError:
+ pass
+
+ def _save_content_hashes(self):
+ """Store the current content_hash dictionary in a file named after its
+ content. This allows us to be safe against FS problems on loading"""
+ data = pickle.dumps({
+ "content_hashes":
+ self.content_hashes,
+ "authenticators_enc":
+ self._encrypt_authenticators(),
+ })
+ m = hashlib.sha1()
+ m.update(data)
+ with open(os.path.join(self.state_dir, "ch-" + m.hexdigest()),
+ "xb") as F:
+ F.write(data)
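+        # Because the file is named after the SHA-1 of its own contents and is
+        # created exclusively ("xb"), a torn or corrupt write simply fails
+        # verification in _load_content_hash_fn() and is unlinked on load.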
+
+ def _load_content_hash_fn(self, fn, dfn):
+ with open(dfn, "rb") as F:
+ data = F.read()
+ st = os.fstat(F.fileno())
+
+ m = hashlib.sha1()
+ m.update(data)
+ if fn != "ch-" + m.hexdigest():
+ os.unlink(dfn)
+ return ({}, None)
+ return (pickle.loads(data), st[stat.ST_CTIME])
+
+ def _load_content_hashes(self):
+ """Load every available content hash file and union their content."""
+ states = []
+ res: Dict[CID_Type, ContentHash_Type] = {}
+ blacklist = set()
+ for fn in os.listdir(self.state_dir):
+ if not fn.startswith("ch-"):
+ continue
+
+ dfn = os.path.join(self.state_dir, fn)
+ try:
+ state, ctime = self._load_content_hash_fn(fn, dfn)
+ except (IOError, pickle.PickleError):
+ os.unlink(dfn)
+
+ if ctime is not None:
+ states.append((ctime, dfn))
+ for k, v in state.get("content_hashes", state).items():
+ if res.get(k, v) != v:
+ blacklist.add(k)
+ res[k] = v
+ self._load_authenticators(state.get("authenticators_enc"))
+
+ # Keep the 5 latest state files
+ states.sort(reverse=True)
+ for I in states[5:]:
+ os.unlink(I[1])
+
+ for k in blacklist:
+ del res[k]
+ for cid, ch in res.items():
+ self.content_msgid[ch] = cid[2]
+ self.content_hashes = res
+
+ def _sha1_fn(self, fn):
+ return subprocess.check_output(["sha1sum",
+ fn]).partition(b' ')[0].decode()
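+    # A hypothetical in-process equivalent, avoiding the subprocess:
+    #   h = hashlib.sha1()
+    #   with open(fn, "rb") as f:
+    #       for chunk in iter(lambda: f.read(65536), b""):
+    #           h.update(chunk)
+    #   return h.hexdigest()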
+
+ def _load_file_hashes(self, hashes_dir):
+ """All files in a directory into the content_hash cache. This figures out what
+ stuff we have already downloaded and is crash safe as we rehash every
+ file. Accidental duplicates are pruned along the way."""
+ hashes = set()
+ for fn in os.listdir(hashes_dir):
+ if fn.startswith("."):
+ continue
+
+            # Since we don't fsync, the files can be corrupted; check them.
+ ffn = os.path.join(hashes_dir, fn)
+ ch = self._sha1_fn(ffn)
+ if fn == ch:
+ hashes.add(ch)
+ st = os.stat(ffn)
+ inode = (st.st_ino, st.st_size, st.st_mtime, st.st_ctime)
+ self.inode_hashes[inode] = ch
+            else:
+                os.unlink(ffn)
+ self.file_hashes.update(hashes)
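+        # The (ino, size, mtime, ctime) tuple is a cheap fingerprint: when none
+        # of these fields change, msg_from_file() reuses the cached hash rather
+        # than rehashing the file.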
+
+ def have_content(self, msg: Message):
+ """True if we have the message contents for msg locally, based on the
+ storage_id and email_id"""
+ if msg.content_hash is None:
+ msg.content_hash = self.content_hashes.get(msg.cid())
+
+ # If we have this in some other file, link it back to the hashes dir
+ if (msg.content_hash is not None
+ and msg.content_hash not in self.file_hashes):
+ for fn in self.alt_file_hashes.get(msg.content_hash, []):
+ hfn = os.path.join(self.hashes_dir, msg.content_hash)
+ try:
+ os.link(fn, hfn)
+ self.file_hashes.add(msg.content_hash)
+ except FileNotFoundError:
+ continue
+
+ return (msg.content_hash is not None
+ and msg.content_hash in self.file_hashes)
+
+ def msg_from_file(self, msg, fn):
+ """Setup msg from a local file, ie in a Maildir. This also records that we
+ have this message in the DB"""
+ st = os.stat(fn)
+ inode = (st.st_ino, st.st_size, st.st_mtime, st.st_ctime)
+ msg.content_hash = self.inode_hashes.get(inode)
+ if msg.content_hash is None:
+ msg.content_hash = self._sha1_fn(fn)
+ self.inode_hashes[inode] = msg.content_hash
+
+ if msg.email_id is None:
+ msg.email_id = self.content_msgid.get(msg.content_hash)
+ if msg.email_id is None:
+ with open(fn, "rb") as F:
+ emsg = email.parser.BytesParser().parsebytes(F.read())
+ # Hrm, I wonder if this is the right way to normalize a header?
+ msg.email_id = re.sub(r"\n[ \t]+", " ",
+ emsg["message-id"]).strip()
+ self.alt_file_hashes[msg.content_hash].add(fn)
+ assert self.content_msgid.get(msg.content_hash,
+ msg.email_id) == msg.email_id
+ self.content_msgid[msg.content_hash] = msg.email_id
+ msg.fn = fn
+
+ def write_content(self, content_hash, dest_fn):
+ """Make the filename dest_fn contain content_hash's content"""
+ assert content_hash in self.file_hashes
+ os.link(os.path.join(self.hashes_dir, content_hash), dest_fn)
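+        # Hard-linking lets every Maildir copy share the single hashed file;
+        # note os.link() requires dest_fn to be on the same filesystem as
+        # hashes_dir.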
+
+ def get_temp(self):
+ """Return a file for later use by store_hashed_file"""
+ return tempfile.NamedTemporaryFile(dir=self.hashes_dir)
+
+ def store_hashed_msg(self, msg, tmpf):
+ """Retain the content tmpf in the hashed file database"""
+ tmpf.flush()
+ ch = self._sha1_fn(tmpf.name)
+ if ch not in self.file_hashes:
+ # Adopt the tmpfile into the hashes storage
+ fn = os.path.join(self.hashes_dir, ch)
+ os.link(tmpf.name, fn)
+ self.file_hashes.add(ch)
+ st = os.stat(fn)
+ inode = (st.st_ino, st.st_size, st.st_mtime, st.st_ctime)
+ self.inode_hashes[inode] = ch
+
+ msg.content_hash = ch
+ if msg.email_id is not None:
+ assert self.content_msgid.get(ch, msg.email_id) == msg.email_id
+ self.content_msgid[ch] = msg.email_id
+ self.content_hashes[msg.cid()] = ch
+ assert self.have_content(msg)
+ return ch
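+    # Hypothetical usage from a mailbox backend (names illustrative):
+    #   with msgdb.get_temp() as tmpf:
+    #       tmpf.write(raw_rfc822_bytes)
+    #       msgdb.store_hashed_msg(msg, tmpf)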
+
+ def cleanup_msgs(self, msgs_by_local: MBoxDict_Type):
+ """Clean our various caches to only have current messages"""
+ all_chs: Set[ContentHash_Type] = set()
+ for msgdict in msgs_by_local.values():
+ all_chs.update(msgdict.keys())
+ for ch in self.file_hashes - all_chs:
+ try:
+ os.unlink(os.path.join(self.hashes_dir, ch))
+ except FileNotFoundError:
+ pass
+ self.file_hashes.remove(ch)
+
+ # Remove obsolete items in the inode cache
+ to_del = []
+ for ino, ch in self.inode_hashes.items():
+ if ch not in all_chs:
+ to_del.append(ino)
+ for ino in to_del:
+ del self.inode_hashes[ino]
+
+ def _encrypt_authenticators(self):
+ crypto = Fernet(self.cfg.storage_key)
+ return crypto.encrypt(
+ pickle.dumps({
+ k: v
+ for k, v in self.authenticators.items()
+ if k in self.authenticators_to_save
+ }))
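+        # Fernet provides authenticated symmetric encryption (AES-128-CBC plus
+        # an HMAC-SHA256 tag), so tampering surfaces as InvalidToken when
+        # _load_authenticators() decrypts the blob.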
+
+ def _load_authenticators(self, data):
+ if data is None:
+ return
+ crypto = Fernet(self.cfg.storage_key)
+ try:
+ plain_data = crypto.decrypt(data)
+ except (cryptography.exceptions.InvalidSignature,
+ cryptography.fernet.InvalidToken):
+ return
+ for k, v in pickle.loads(plain_data).items():
+ if v[0] > self.authenticators.get(k, (0, ))[0]:
+ self.authenticators[k] = v
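+        # Comparing serials makes the merge monotonic: when several ch-* state
+        # files are loaded, the highest-serial value for each domain_id wins.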
+
+ def get_authenticator(self, domain_id):
+ """Return the stored authenticator data for the domain_id"""
+ auth = self.authenticators.get(domain_id)
+ if auth is None:
+ return None
+ return auth[1]
+
+ def set_authenticator(self, domain_id, value):
+ """Store authenticator data for the domain_id. The data will persist
+ across reloads of the message db. Usually this will be the OAUTH
+ refresh token."""
+ self.authenticators_to_save.add(domain_id)
+ serial, cur = self.authenticators.get(domain_id, (0, None))
+ if cur == value:
+ return
+ self.authenticators[domain_id] = (serial + 1, value)