author    Jason Gunthorpe <jgg@mellanox.com>  2020-01-10 16:38:52 -0400
committer Jason Gunthorpe <jgg@mellanox.com>  2020-01-10 16:38:52 -0400
commit    2e325c38031bc88568dc065821722dd3e22259cb (patch)
tree      04cabe4f38118c483c3e477fc2d980d9b9d45cb4 /cloud_mdir_sync
Initial commit of cloud_mdir_sync
I have been using this for a few months now with no ill effects.

Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
Diffstat (limited to 'cloud_mdir_sync')
-rw-r--r--  cloud_mdir_sync/__init__.py     6
-rw-r--r--  cloud_mdir_sync/config.py     103
-rw-r--r--  cloud_mdir_sync/mailbox.py     80
-rw-r--r--  cloud_mdir_sync/maildir.py    206
-rw-r--r--  cloud_mdir_sync/main.py       119
-rw-r--r--  cloud_mdir_sync/messages.py   326
-rw-r--r--  cloud_mdir_sync/oauth.py       60
-rw-r--r--  cloud_mdir_sync/office365.py  637
-rw-r--r--  cloud_mdir_sync/util.py        70
9 files changed, 1607 insertions, 0 deletions
diff --git a/cloud_mdir_sync/__init__.py b/cloud_mdir_sync/__init__.py
new file mode 100644
index 0000000..8382999
--- /dev/null
+++ b/cloud_mdir_sync/__init__.py
@@ -0,0 +1,6 @@
+# SPDX-License-Identifier: GPL-2.0+
+import asyncio
+
+# Python 3.6 compatibility
+if "create_task" not in dir(asyncio):
+ asyncio.create_task = asyncio.ensure_future
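
The shim above aliases asyncio.create_task to asyncio.ensure_future on Python 3.6, so the rest of the package can use the 3.7+ spelling unconditionally. A minimal sketch of what that enables (the demo coroutine is hypothetical):

    import asyncio
    import cloud_mdir_sync  # installs the create_task alias on Python 3.6

    async def demo():
        # On 3.7+ this is the native API; on 3.6 it falls back to ensure_future
        task = asyncio.create_task(asyncio.sleep(0.1))
        await task

    asyncio.get_event_loop().run_until_complete(demo())
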
diff --git a/cloud_mdir_sync/config.py b/cloud_mdir_sync/config.py
new file mode 100644
index 0000000..bdec863
--- /dev/null
+++ b/cloud_mdir_sync/config.py
@@ -0,0 +1,103 @@
+# SPDX-License-Identifier: GPL-2.0+
+import asyncio
+import itertools
+import logging
+import os
+from typing import TYPE_CHECKING, Any, Dict, List
+
+import pyinotify
+
+if TYPE_CHECKING:
+ from . import messages, mailbox, oauth
+
+logger: logging.Logger
+
+
+class Config(object):
+ """Program configuration and general global state"""
+ message_db_dir = "~/mail/.cms/"
+ domains: Dict[str, Any] = {}
+ trace_file: Any
+ web_app: "oauth.WebServer"
+ logger: logging.Logger
+ loop: asyncio.AbstractEventLoop
+ watch_manager: pyinotify.WatchManager
+ msgdb: "messages.MessageDB"
+ cloud_mboxes: "List[mailbox.Mailbox]"
+ local_mboxes: "List[mailbox.Mailbox]"
+
+ def _create_logger(self):
+ global logger
+ logger = logging.getLogger('cloud-mdir-sync')
+ logger.setLevel(logging.DEBUG)
+ ch = logging.StreamHandler()
+ ch.setFormatter(
+ logging.Formatter(fmt='%(asctime)s %(levelname)-8s %(message)s',
+ datefmt='%m-%d %H:%M:%S'))
+ ch.setLevel(logging.DEBUG)
+ logger.addHandler(ch)
+ self.logger = logger
+
+ def __init__(self):
+ self._create_logger()
+ self.cloud_mboxes = []
+ self.local_mboxes = []
+ self.message_db_dir = os.path.expanduser(self.message_db_dir)
+ self.direct_message = self._direct_message
+
+ def load_config(self, fn):
+ """The configuration file is a python script that we execute with
+ capitalized functions of this class injected into it"""
+ fn = os.path.expanduser(fn)
+ with open(fn, "r") as F:
+ pyc = compile(source=F.read(), filename=fn, mode="exec")
+
+ g = {"cfg": self}
+ for k in dir(self):
+ if k[0].isupper():
+ g[k] = getattr(self, k)
+ eval(pyc, g)
+
+ @property
+ def storage_key(self):
+ """The storage key is used with fernet to manage the authentication
+ data, which is stored to disk using symmetric encryption. The
+ decryption key is held by the system keyring in some secure storage.
+ On Linux desktop systems this is likely to be something like
+ gnome-keyring."""
+ import keyring
+ from cryptography.fernet import Fernet
+
+ ring = keyring.get_keyring()
+ res = ring.get_password("cloud_mdir_sync", "storage")
+ if res is None:
+ res = Fernet.generate_key()
+ ring.set_password("cloud_mdir_sync", "storage", res)
+ return res
+
+ def all_mboxes(self):
+ return itertools.chain(self.local_mboxes, self.cloud_mboxes)
+
+ def Office365_Account(self, user=None, tenant="common"):
+ """Define an Office365 account credential. If user is left as None
+ then the browser will prompt for the user and the choice will be
+ cached. To lock the account to a single tenant specify the Azure
+ Directory name, ie 'contoso.onmicrosoft.com', or the GUID."""
+ return (user, tenant)
+
+ def Office365(self, mailbox, account):
+ """Create a cloud mailbox for Office365. Mailbox is the name of O365
+ mailbox to use, account should be the result of Office365_Account"""
+ from .office365 import O365Mailbox
+ self.cloud_mboxes.append(O365Mailbox(mailbox, user=account[0],
+ tenant=account[1]))
+ return self.cloud_mboxes[-1]
+
+ def MailDir(self, directory):
+ """Create a local maildir to hold messages"""
+ from .maildir import MailDirMailbox
+ self.local_mboxes.append(MailDirMailbox(directory))
+ return self.local_mboxes[-1]
+
+ def _direct_message(self, msg):
+ return self.local_mboxes[0]
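
load_config() above executes the configuration file with the capitalized methods of Config injected as globals (plus "cfg" itself), so a working cms.cfg is just a few calls to those functions. A hedged sketch; the account, tenant and maildir path are placeholders:

    # cms.cfg -- executed by Config.load_config(); Office365_Account,
    # Office365 and MailDir are injected by config.py, "cfg" is the Config.
    account = Office365_Account(user="user@example.com",
                                tenant="contoso.onmicrosoft.com")
    Office365("Inbox", account)      # cloud mailbox to mirror
    MailDir("~/mail/cloud-inbox")    # local maildir that receives the messages
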
diff --git a/cloud_mdir_sync/mailbox.py b/cloud_mdir_sync/mailbox.py
new file mode 100644
index 0000000..24c64cd
--- /dev/null
+++ b/cloud_mdir_sync/mailbox.py
@@ -0,0 +1,80 @@
+# SPDX-License-Identifier: GPL-2.0+
+import asyncio
+import functools
+import inspect
+from abc import abstractmethod
+from typing import TYPE_CHECKING, Dict
+
+if TYPE_CHECKING:
+ from . import config
+ from .messages import MessageDB
+ from .messages import CHMsgDict_Type
+ from .messages import CHMsgMappingDict_Type
+
+
+def update_on_failure(func):
+ """Decorator for mailbox class methods that cause the mailbox to need a full
+ update if the method throws an exception."""
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ try:
+ return func(self, *args, **kwargs)
+ except:
+ self.need_update = True
+ Mailbox.changed_event.set()
+ raise
+
+ @functools.wraps(func)
+ async def async_wrapper(self, *args, **kwargs):
+ try:
+ return await func(self, *args, **kwargs)
+ except:
+ self.need_update = True
+ Mailbox.changed_event.set()
+ raise
+
+ if inspect.iscoroutinefunction(func):
+ return async_wrapper
+ return wrapper
+
+
+class Mailbox(object):
+ messages: "CHMsgDict_Type" = {}
+ changed_event = asyncio.Event()
+ need_update = True
+
+ @abstractmethod
+ async def setup_mbox(self, cfg: "config.Config") -> None:
+ pass
+
+ @abstractmethod
+ def force_content(self, msgdb: "MessageDB",
+ msgs: "CHMsgDict_Type") -> None:
+ pass
+
+ @abstractmethod
+ async def merge_content(self, msgs: "CHMsgMappingDict_Type") -> None:
+ pass
+
+ def same_messages(self,
+ mdict: "CHMsgMappingDict_Type",
+ tuple_form=False) -> bool:
+ """Return true if mdict is the same as the local messages"""
+ if len(self.messages) != len(mdict):
+ return False
+
+ for ch, mmsg in self.messages.items():
+ omsg = mdict.get(ch)
+ if omsg is None:
+ return False
+
+ # update_cloud_from_local uses a different dict format
+ if tuple_form:
+ omsg = omsg[0] # Check the local mbox
+ if omsg is None:
+ return False
+
+ if (mmsg.content_hash != omsg.content_hash
+ or mmsg.flags != omsg.flags):
+ return False
+ return True
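
The decorator and the abstract Mailbox above define the contract every backend follows: methods that touch remote or on-disk state are wrapped in update_on_failure so any exception flags the mailbox for a full refresh and wakes the main loop via changed_event. A rough sketch of a hypothetical backend (the class and its behaviour are illustrative, not part of this commit):

    from cloud_mdir_sync import mailbox

    class NullMailbox(mailbox.Mailbox):
        """Hypothetical backend showing the expected shape of a Mailbox."""
        storage_kind = "null"

        async def setup_mbox(self, cfg):
            self.cfg = cfg

        @mailbox.update_on_failure
        async def update_message_list(self, msgdb):
            # A real backend fills self.messages keyed by content_hash;
            # raising here would set need_update and signal changed_event.
            self.messages = {}
            self.need_update = False

        def force_content(self, msgdb, msgs):
            raise RuntimeError("read-only example")

        async def merge_content(self, msgs):
            raise RuntimeError("read-only example")
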
diff --git a/cloud_mdir_sync/maildir.py b/cloud_mdir_sync/maildir.py
new file mode 100644
index 0000000..fdc8a90
--- /dev/null
+++ b/cloud_mdir_sync/maildir.py
@@ -0,0 +1,206 @@
+# SPDX-License-Identifier: GPL-2.0+
+import logging
+import os
+import pickle
+import re
+import time
+
+import pyinotify
+
+from . import config, mailbox, messages, util
+
+
+def unfold_header(s):
+ # Hrm, I wonder if this is the right way to normalize a header?
+ return re.sub(r"\n[ \t]+", " ", s)
+
+
+class MailDirMailbox(mailbox.Mailbox):
+ """Local MailDir mail directory"""
+ storage_kind = "maildir"
+ cfg: config.Config
+
+ def __init__(self, directory):
+ super().__init__()
+ self.dfn = os.path.expanduser(directory)
+ for sub in ["tmp", "cur", "new"]:
+ os.makedirs(os.path.join(self.dfn, sub), mode=0o700, exist_ok=True)
+
+ async def setup_mbox(self, cfg: config.Config):
+ self.cfg = cfg
+ cfg.watch_manager.add_watch(
+ path=[
+ os.path.join(self.dfn, "cur"),
+ os.path.join(self.dfn, "new")
+ ],
+ proc_fun=self._dir_changed,
+ mask=(pyinotify.IN_ATTRIB | pyinotify.IN_MOVED_FROM
+ | pyinotify.IN_MOVED_TO
+ | pyinotify.IN_CREATE | pyinotify.IN_DELETE
+ | pyinotify.IN_ONLYDIR),
+ quiet=False)
+
+ def _dir_changed(self, notifier):
+ self.need_update = True
+ self.changed_event.set()
+
+ def _msg_to_flags(self, msg: messages.Message):
+ """Return the desired maildir flags from a message"""
+ # See https://cr.yp.to/proto/maildir.html
+ res = set()
+ if msg.flags & messages.Message.FLAG_REPLIED:
+ res.add("R")
+ if msg.flags & messages.Message.FLAG_READ:
+ res.add("S")
+ if msg.flags & messages.Message.FLAG_FLAGGED:
+ res.add("F")
+ return res
+
+ def _decode_msg_filename(self, fn):
+ """Return the base maildir filename, message flags, and maildir flag
+ letters"""
+ fn = os.path.basename(fn)
+ if ":2," not in fn:
+ return (fn, set(), 0)
+ fn, _, flags = fn.partition(":2,")
+ flags = set(flags)
+ mflags = 0
+ if "R" in flags:
+ mflags |= messages.Message.FLAG_REPLIED
+ if "S" in flags:
+ mflags |= messages.Message.FLAG_READ
+ if "F" in flags:
+ mflags |= messages.Message.FLAG_FLAGGED
+ assert ":2," not in fn
+ return (fn, flags, mflags)
+
+ def _load_message(self, msgdb: messages.MessageDB, fn, ffn):
+ sid, _, mflags = self._decode_msg_filename(fn)
+ msg = messages.Message(mailbox=self, storage_id=sid)
+ msg.flags = mflags
+ msgdb.msg_from_file(msg, ffn)
+ return msg
+
+ def _update_message_dir(self, res, msgdb: messages.MessageDB, dfn):
+ for fn in os.listdir(dfn):
+ if fn.startswith("."):
+ continue
+ msg = self._load_message(msgdb, fn, os.path.join(dfn, fn))
+ res[msg.content_hash] = msg
+
+ @util.log_progress(lambda self: f"Updating Message List for {self.dfn}",
+ lambda self: f", {len(self.messages)} msgs",
+ level=logging.DEBUG)
+ @mailbox.update_on_failure
+ async def update_message_list(self, msgdb: messages.MessageDB):
+ """Read the message list from the maildir and compute the content hashes"""
+ res: messages.CHMsgDict_Type = {}
+ st = {}
+ for sd in ["cur", "new"]:
+ st[sd] = os.stat(os.path.join(self.dfn, sd))
+ for sd in ["cur", "new"]:
+ self._update_message_dir(res, msgdb, os.path.join(self.dfn, sd))
+ for sd in ["cur", "new"]:
+ fn = os.path.join(self.dfn, sd)
+ # Retry if the dirs changed while trying to read them
+ if os.stat(fn).st_mtime != st[sd].st_mtime:
+ raise IOError(f"Maildir {fn} changed during listing")
+
+ self.messages = res
+ self.need_update = False
+ if self.cfg.trace_file is not None:
+ pickle.dump(["update_message_list", self.dfn, self.messages],
+ self.cfg.trace_file)
+
+ def _new_maildir_id(self, msg: messages.Message):
+ """Return a unique maildir filename for the given message"""
+ tm = time.clock_gettime(time.CLOCK_REALTIME)
+ base = f"{int(tm)}.M{int((tm%1)*1000*1000)}-{msg.content_hash}"
+ flags = self._msg_to_flags(msg)
+ if flags:
+ fn = os.path.join(self.dfn, "cur",
+ base + ":2," + "".join(sorted(flags)))
+ else:
+ fn = os.path.join(self.dfn, "new", base)
+ return base, fn
+
+ def _store_msg(self, msgdb: messages.MessageDB,
+ cloudmsg: messages.Message):
+ """Apply a delta from the cloud: New message from cloud"""
+ sid, fn = self._new_maildir_id(cloudmsg)
+ msg = messages.Message(mailbox=self,
+ storage_id=sid,
+ email_id=cloudmsg.email_id)
+ msg.flags = cloudmsg.flags
+ msg.content_hash = cloudmsg.content_hash
+ assert msg.content_hash is not None
+ msg.fn = fn
+
+ msgdb.write_content(cloudmsg.content_hash, msg.fn)
+
+ # It isn't clear if we need to do this, but make the local timestamps
+ # match when the message would have been received if the local MTA
+ # delivered it.
+ if cloudmsg.received_time is not None:
+ os.utime(fn, (time.time(), cloudmsg.received_time.timestamp()))
+ self.messages[msg.content_hash] = msg
+
+ def _set_flags(self, mymsg: messages.Message, cloudmsg: messages.Message):
+ """Apply a delta from the cloud: Same message in cloud, synchronize flags"""
+ if mymsg.flags == cloudmsg.flags:
+ return
+
+ cloud_flags = self._msg_to_flags(cloudmsg)
+
+ base, letters, _ = self._decode_msg_filename(mymsg.fn)
+ nflags = (letters - set(("R", "S", "F"))) | cloud_flags
+ if letters == nflags:
+ return
+ if nflags:
+ nfn = os.path.join(self.dfn, "cur",
+ base + ":2," + "".join(sorted(nflags)))
+ else:
+ nfn = os.path.join(self.dfn, "new", base)
+ os.rename(mymsg.fn, nfn)
+ mymsg.fn = nfn
+ mymsg.flags = cloudmsg.flags
+
+ def _remove_msg(self, mymsg: messages.Message):
+ """Apply a delta from the cloud: Message deleted in cloud"""
+ assert mymsg.content_hash is not None
+ os.unlink(mymsg.fn)
+ del self.messages[mymsg.content_hash]
+
+ @util.log_progress(
+ lambda self: f"Applying cloud changes for {self.dfn}", lambda self:
+ f", {self.last_force_new} added, {self.last_force_rm} removed, {self.last_force_kept} same"
+ )
+ @mailbox.update_on_failure
+ def force_content(self, msgdb: messages.MessageDB, msgs: messages.CHMsgDict_Type):
+ """Force this mailbox to contain the message list msgs (from cloud), including
+ all the flags and state"""
+ self.last_force_kept = 0
+ self.last_force_new = 0
+ self.last_force_rm = 0
+
+ have = set(self.messages.keys())
+ want = set(msgs.keys())
+
+ for content_hash in want.intersection(have):
+ self.last_force_kept += 1
+ self._set_flags(self.messages[content_hash], msgs[content_hash])
+
+ for content_hash in want - have:
+ self.last_force_new += 1
+ self._store_msg(msgdb, msgs[content_hash])
+
+ for content_hash in have - want:
+ self.last_force_rm += 1
+ self._remove_msg(self.messages[content_hash])
+
+ if self.cfg.trace_file is not None:
+ pickle.dump(["force_content", self.dfn, self.messages, msgs],
+ self.cfg.trace_file)
+
+ async def merge_content(self, msgs):
+ raise RuntimeError("Cannot merge local changes into a local mailbox")
diff --git a/cloud_mdir_sync/main.py b/cloud_mdir_sync/main.py
new file mode 100644
index 0000000..27fc711
--- /dev/null
+++ b/cloud_mdir_sync/main.py
@@ -0,0 +1,119 @@
+# SPDX-License-Identifier: GPL-2.0+
+import argparse
+import asyncio
+import contextlib
+import os
+from typing import Dict, Optional, Tuple
+
+import aiohttp
+import pyinotify
+
+from . import config, mailbox, messages, oauth, office365
+
+
+def force_local_to_cloud(cfg: config.Config) -> messages.MBoxDict_Type:
+ """Make all the local mailboxes match their cloud content, overwriting any
+ local changes."""
+
+ # For every cloud message figure out which local mailbox it belongs to
+ msgs: messages.MBoxDict_Type = {}
+ for mbox in cfg.local_mboxes:
+ msgs[mbox] = {}
+ for mbox in cfg.cloud_mboxes:
+ for ch,msg in mbox.messages.items():
+ dest = cfg.direct_message(msg)
+ msgs[dest][ch] = msg
+
+ for mbox, msgdict in msgs.items():
+ if not mbox.same_messages(msgdict):
+ mbox.force_content(cfg.msgdb, msgdict)
+ return msgs
+
+
+async def update_cloud_from_local(cfg: config.Config,
+ msgs_by_local: messages.MBoxDict_Type):
+ """Detect differences made by the local mailboxes and upload them to the
+ cloud."""
+ msgs_by_cloud: Dict[mailbox.Mailbox, messages.CHMsgMappingDict_Type] = {}
+ for mbox in cfg.cloud_mboxes:
+ msgs_by_cloud[mbox] = {}
+ for local_mbox, msgdict in msgs_by_local.items():
+ for ch, cloud_msg in msgdict.items():
+ msgs_by_cloud[cloud_msg.mailbox][ch] = (
+ local_mbox.messages.get(ch), cloud_msg)
+ await asyncio.gather(*(
+ mbox.merge_content(msgdict) for mbox, msgdict in msgs_by_cloud.items()
+ if not mbox.same_messages(msgdict, tuple_form=True)))
+
+
+async def synchronize_mail(cfg: config.Config):
+ """Main synchronizing loop"""
+ cfg.web_app = oauth.WebServer()
+ try:
+ await cfg.web_app.go()
+
+ await asyncio.gather(*(mbox.setup_mbox(cfg)
+ for mbox in cfg.all_mboxes()))
+
+ msgs = None
+ while True:
+ try:
+ await asyncio.gather(*(mbox.update_message_list(cfg.msgdb)
+ for mbox in cfg.all_mboxes()
+ if mbox.need_update))
+
+ if msgs is not None:
+ await update_cloud_from_local(cfg, msgs)
+
+ msgs = force_local_to_cloud(cfg)
+ except (FileNotFoundError, asyncio.TimeoutError,
+ aiohttp.client_exceptions.ClientError, IOError,
+ RuntimeError):
+ cfg.logger.exception(
+ "Failed update cycle, sleeping then retrying")
+ await asyncio.sleep(10)
+ continue
+
+ await mailbox.Mailbox.changed_event.wait()
+ mailbox.Mailbox.changed_event.clear()
+ cfg.msgdb.cleanup_msgs(msgs)
+ cfg.logger.debug("Changed event, looping")
+ finally:
+ await asyncio.gather(*(domain.close()
+ for domain in cfg.domains.values()))
+ cfg.domains = {}
+ await cfg.web_app.close()
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description=
+ """Cloud MailDir Sync is able to download email messages from a cloud
+ provider and store them in a local maildir. It uses the REST interface
+ from the cloud provider rather than IMAP and uses OAUTH to
+ authenticate. Once downloaded the tool tracks changes in the local
+ mail dir and uploads them back to the cloud.""")
+ parser.add_argument("-c",
+ dest="CFG",
+ default="cms.cfg",
+ help="Configuration file to use")
+ args = parser.parse_args()
+
+ cfg = config.Config()
+ cfg.load_config(args.CFG)
+ cfg.loop = asyncio.get_event_loop()
+ with contextlib.closing(pyinotify.WatchManager()) as wm, \
+ contextlib.closing(messages.MessageDB(cfg)) as msgdb, \
+ open("trace", "wb") as trace:
+ pyinotify.AsyncioNotifier(wm, cfg.loop)
+ cfg.watch_manager = wm
+ cfg.msgdb = msgdb
+ cfg.trace_file = trace
+ cfg.loop.run_until_complete(synchronize_mail(cfg))
+
+ cfg.loop.run_until_complete(cfg.loop.shutdown_asyncgens())
+ cfg.loop.close()
+
+
+if __name__ == "__main__":
+ main()
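
force_local_to_cloud routes every cloud message through cfg.direct_message to pick the destination maildir; the default (_direct_message in config.py) simply returns the first local mailbox. Because the configuration file runs with cfg in scope, routing can be changed there. A hedged sketch with placeholder mailbox names:

    # cms.cfg fragment: give one O365 folder its own maildir
    account = Office365_Account(user="user@example.com")
    inbox = Office365("Inbox", account)
    sent = Office365("Sent Items", account)
    md_inbox = MailDir("~/mail/inbox")
    md_sent = MailDir("~/mail/sent")

    def direct_message(msg):
        # msg.mailbox is the cloud mailbox the message was found in
        return md_sent if msg.mailbox is sent else md_inbox

    cfg.direct_message = direct_message
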
diff --git a/cloud_mdir_sync/messages.py b/cloud_mdir_sync/messages.py
new file mode 100644
index 0000000..d87e547
--- /dev/null
+++ b/cloud_mdir_sync/messages.py
@@ -0,0 +1,326 @@
+# SPDX-License-Identifier: GPL-2.0+
+import collections
+import datetime
+import email
+import email.parser
+import hashlib
+import logging
+import os
+import pickle
+import re
+import stat
+import subprocess
+import sys
+import tempfile
+from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+
+import cryptography
+import cryptography.exceptions
+from cryptography.fernet import Fernet
+
+from . import config, util
+
+if TYPE_CHECKING:
+ from . import mailbox
+
+ContentHash_Type = str
+CID_Type = tuple
+MBoxDict_Type = Dict["mailbox.Mailbox", Dict[ContentHash_Type,
+ "Message"]]
+CHMsgDict_Type = Dict[ContentHash_Type, "Message"]
+CHMsgMappingDict_Type = Dict[ContentHash_Type, Tuple[Optional["Message"],
+ "Message"]]
+
+
+class Message(object):
+ """A single message in the system"""
+ content_hash: Optional[ContentHash_Type] = None
+ received_time: Optional[datetime.datetime] = None
+ flags = 0
+ FLAG_REPLIED = 1 << 0
+ FLAG_READ = 1 << 1
+ FLAG_FLAGGED = 1 << 2
+ ALL_FLAGS = FLAG_REPLIED | FLAG_READ | FLAG_FLAGGED
+ fn: str
+
+ def __init__(self, mailbox, storage_id, email_id=None):
+ assert storage_id
+ self.mailbox = mailbox
+ self.storage_id = storage_id
+ self.email_id = email_id
+
+ def cid(self):
+ """The unique content ID of the message. This is scoped within the
+ mailbox and is used to search for the content_hash"""
+ return (self.mailbox.storage_kind, self.storage_id, self.email_id)
+
+ def __getstate__(self):
+ return {
+ "content_hash": self.content_hash,
+ "received_time": self.received_time,
+ "flags": self.flags,
+ "storage_id": self.storage_id,
+ "email_id": self.email_id
+ }
+
+
+class MessageDB(object):
+ """The persistent state associated with the message database. This holds:
+ - A directory of content_hash files for mailbox content
+ - A set of files of pickles storing the mapping of CID to content_hash
+ """
+ content_hashes: Dict[CID_Type, ContentHash_Type]
+ content_msgid: Dict[ContentHash_Type, str]
+ alt_file_hashes: Dict[ContentHash_Type, set]
+ inode_hashes: Dict[tuple, ContentHash_Type]
+ file_hashes: Set[str]
+ authenticators_to_save: Set[str]
+ authenticators: Dict[str, Tuple[int, bytes]]
+
+ @util.log_progress(
+ "Loading cached state",
+ lambda self:
+ f", {len(self.file_hashes)} msgs, {len(self.content_hashes)} cached ids",
+ level=logging.DEBUG)
+ def __init__(self, cfg: config.Config):
+ self.cfg = cfg
+ self.content_hashes = {} # [cid] = content_hash
+ self.content_msgid = {} # [hash] = message_id
+ self.file_hashes = set()
+ self.alt_file_hashes = collections.defaultdict(
+ set) # [hash] = set(fns)
+ self.inode_hashes = {} # [inode] = content_hash
+ self.authenticators_to_save = set()
+ self.authenticators = {} # [did] = (serial, blob)
+
+ self.state_dir = os.path.expanduser(cfg.message_db_dir)
+ self.hashes_dir = os.path.join(self.state_dir, "hashes")
+ os.makedirs(self.hashes_dir, exist_ok=True)
+ self._load_file_hashes(self.hashes_dir)
+ self._load_content_hashes()
+
+ def close(self):
+ try:
+ self._save_content_hashes()
+ except IOError:
+ pass
+
+ def _save_content_hashes(self):
+ """Store the current content_hash dictionary in a file named after its
+ content. This allows us to be safe against FS problems on loading"""
+ data = pickle.dumps({
+ "content_hashes":
+ self.content_hashes,
+ "authenticators_enc":
+ self._encrypt_authenticators(),
+ })
+ m = hashlib.sha1()
+ m.update(data)
+ with open(os.path.join(self.state_dir, "ch-" + m.hexdigest()),
+ "xb") as F:
+ F.write(data)
+
+ def _load_content_hash_fn(self, fn, dfn):
+ with open(dfn, "rb") as F:
+ data = F.read()
+ st = os.fstat(F.fileno())
+
+ m = hashlib.sha1()
+ m.update(data)
+ if fn != "ch-" + m.hexdigest():
+ os.unlink(dfn)
+ return ({}, None)
+ return (pickle.loads(data), st[stat.ST_CTIME])
+
+ def _load_content_hashes(self):
+ """Load every available content hash file and union their content."""
+ states = []
+ res: Dict[CID_Type, ContentHash_Type] = {}
+ blacklist = set()
+ for fn in os.listdir(self.state_dir):
+ if not fn.startswith("ch-"):
+ continue
+
+ dfn = os.path.join(self.state_dir, fn)
+ try:
+ state, ctime = self._load_content_hash_fn(fn, dfn)
+ except (IOError, pickle.PickleError):
+ os.unlink(dfn)
+ continue
+
+ if ctime is not None:
+ states.append((ctime, dfn))
+ for k, v in state.get("content_hashes", state).items():
+ if res.get(k, v) != v:
+ blacklist.add(k)
+ res[k] = v
+ self._load_authenticators(state.get("authenticators_enc"))
+
+ # Keep the 5 latest state files
+ states.sort(reverse=True)
+ for I in states[5:]:
+ os.unlink(I[1])
+
+ for k in blacklist:
+ del res[k]
+ for cid, ch in res.items():
+ self.content_msgid[ch] = cid[2]
+ self.content_hashes = res
+
+ def _sha1_fn(self, fn):
+ return subprocess.check_output(["sha1sum",
+ fn]).partition(b' ')[0].decode()
+
+ def _load_file_hashes(self, hashes_dir):
+ """All files in a directory into the content_hash cache. This figures out what
+ stuff we have already downloaded and is crash safe as we rehash every
+ file. Accidental duplicates are pruned along the way."""
+ hashes = set()
+ for fn in os.listdir(hashes_dir):
+ if fn.startswith("."):
+ continue
+
+ # Since we don't use sync the files can be corrupted, check them.
+ ffn = os.path.join(hashes_dir, fn)
+ ch = self._sha1_fn(ffn)
+ if fn == ch:
+ hashes.add(ch)
+ st = os.stat(ffn)
+ inode = (st.st_ino, st.st_size, st.st_mtime, st.st_ctime)
+ self.inode_hashes[inode] = ch
+ else:
+ os.unlink(ffn)
+ self.file_hashes.update(hashes)
+
+ def have_content(self, msg: Message):
+ """True if we have the message contents for msg locally, based on the
+ storage_id and email_id"""
+ if msg.content_hash is None:
+ msg.content_hash = self.content_hashes.get(msg.cid())
+
+ # If we have this in some other file, link it back to the hashes dir
+ if (msg.content_hash is not None
+ and msg.content_hash not in self.file_hashes):
+ for fn in self.alt_file_hashes.get(msg.content_hash, []):
+ hfn = os.path.join(self.hashes_dir, msg.content_hash)
+ try:
+ os.link(fn, hfn)
+ self.file_hashes.add(msg.content_hash)
+ except FileNotFoundError:
+ continue
+
+ return (msg.content_hash is not None
+ and msg.content_hash in self.file_hashes)
+
+ def msg_from_file(self, msg, fn):
+ """Setup msg from a local file, ie in a Maildir. This also records that we
+ have this message in the DB"""
+ st = os.stat(fn)
+ inode = (st.st_ino, st.st_size, st.st_mtime, st.st_ctime)
+ msg.content_hash = self.inode_hashes.get(inode)
+ if msg.content_hash is None:
+ msg.content_hash = self._sha1_fn(fn)
+ self.inode_hashes[inode] = msg.content_hash
+
+ if msg.email_id is None:
+ msg.email_id = self.content_msgid.get(msg.content_hash)
+ if msg.email_id is None:
+ with open(fn, "rb") as F:
+ emsg = email.parser.BytesParser().parsebytes(F.read())
+ # Hrm, I wonder if this is the right way to normalize a header?
+ msg.email_id = re.sub(r"\n[ \t]+", " ",
+ emsg["message-id"]).strip()
+ self.alt_file_hashes[msg.content_hash].add(fn)
+ assert self.content_msgid.get(msg.content_hash,
+ msg.email_id) == msg.email_id
+ self.content_msgid[msg.content_hash] = msg.email_id
+ msg.fn = fn
+
+ def write_content(self, content_hash, dest_fn):
+ """Make the filename dest_fn contain content_hash's content"""
+ assert content_hash in self.file_hashes
+ os.link(os.path.join(self.hashes_dir, content_hash), dest_fn)
+
+ def get_temp(self):
+ """Return a file for later use by store_hashed_file"""
+ return tempfile.NamedTemporaryFile(dir=self.hashes_dir)
+
+ def store_hashed_msg(self, msg, tmpf):
+ """Retain the content tmpf in the hashed file database"""
+ tmpf.flush()
+ ch = self._sha1_fn(tmpf.name)
+ if ch not in self.file_hashes:
+ # Adopt the tmpfile into the hashes storage
+ fn = os.path.join(self.hashes_dir, ch)
+ os.link(tmpf.name, fn)
+ self.file_hashes.add(ch)
+ st = os.stat(fn)
+ inode = (st.st_ino, st.st_size, st.st_mtime, st.st_ctime)
+ self.inode_hashes[inode] = ch
+
+ msg.content_hash = ch
+ if msg.email_id is not None:
+ assert self.content_msgid.get(ch, msg.email_id) == msg.email_id
+ self.content_msgid[ch] = msg.email_id
+ self.content_hashes[msg.cid()] = ch
+ assert self.have_content(msg)
+ return ch
+
+ def cleanup_msgs(self, msgs_by_local: MBoxDict_Type):
+ """Clean our various caches to only have current messages"""
+ all_chs: Set[ContentHash_Type] = set()
+ for msgdict in msgs_by_local.values():
+ all_chs.update(msgdict.keys())
+ for ch in self.file_hashes - all_chs:
+ try:
+ os.unlink(os.path.join(self.hashes_dir, ch))
+ except FileNotFoundError:
+ pass
+ self.file_hashes.remove(ch)
+
+ # Remove obsolete items in the inode cache
+ to_del = []
+ for ino, ch in self.inode_hashes.items():
+ if ch not in all_chs:
+ to_del.append(ino)
+ for ino in to_del:
+ del self.inode_hashes[ino]
+
+ def _encrypt_authenticators(self):
+ crypto = Fernet(self.cfg.storage_key)
+ return crypto.encrypt(
+ pickle.dumps({
+ k: v
+ for k, v in self.authenticators.items()
+ if k in self.authenticators_to_save
+ }))
+
+ def _load_authenticators(self, data):
+ if data is None:
+ return
+ crypto = Fernet(self.cfg.storage_key)
+ try:
+ plain_data = crypto.decrypt(data)
+ except (cryptography.exceptions.InvalidSignature,
+ cryptography.fernet.InvalidToken):
+ return
+ for k, v in pickle.loads(plain_data).items():
+ if v[0] > self.authenticators.get(k, (0, ))[0]:
+ self.authenticators[k] = v
+
+ def get_authenticator(self, domain_id):
+ """Return the stored authenticator data for the domain_id"""
+ auth = self.authenticators.get(domain_id)
+ if auth is None:
+ return None
+ return auth[1]
+
+ def set_authenticator(self, domain_id, value):
+ """Store authenticator data for the domain_id. The data will persist
+ across reloads of the message db. Usually this will be the OAUTH
+ refresh token."""
+ self.authenticators_to_save.add(domain_id)
+ serial, cur = self.authenticators.get(domain_id, (0, None))
+ if cur == value:
+ return
+ self.authenticators[domain_id] = (serial + 1, value)
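
MessageDB identifies message content purely by the SHA-1 of the file, computed in _sha1_fn by shelling out to sha1sum. For reference, an equivalent in-process computation with hashlib would look roughly like this (a sketch, not what the commit uses):

    import hashlib

    def sha1_of_file(path, chunk=65536):
        """Hex SHA-1 of a file, matching what `sha1sum <path>` prints."""
        h = hashlib.sha1()
        with open(path, "rb") as f:
            for block in iter(lambda: f.read(chunk), b""):
                h.update(block)
        return h.hexdigest()
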
diff --git a/cloud_mdir_sync/oauth.py b/cloud_mdir_sync/oauth.py
new file mode 100644
index 0000000..449d16c
--- /dev/null
+++ b/cloud_mdir_sync/oauth.py
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: GPL-2.0+
+import asyncio
+
+import aiohttp
+import aiohttp.web
+
+
+class WebServer(object):
+ """A small web server is used to manage oauth requests. The user should point a browser
+ window at localhost. The program will generate redirects for the browser to point at
+ OAUTH servers when interactive authentication is required."""
+ url = "http://localhost:8080/"
+ runner = None
+
+ def __init__(self):
+ self.auth_redirs = {}
+ self.web_app = aiohttp.web.Application()
+ self.web_app.router.add_get("/", self._start)
+ self.web_app.router.add_get("/oauth2/msal", self._oauth2_msal)
+
+ async def go(self):
+ self.runner = aiohttp.web.AppRunner(self.web_app)
+ await self.runner.setup()
+ site = aiohttp.web.TCPSite(self.runner, 'localhost', 8080)
+ await site.start()
+
+ async def close(self):
+ if self.runner:
+ await self.runner.cleanup()
+
+ async def auth_redir(self, url, state):
+ """Call as part of an OAUTH flow to hand the URL off to interactive browser
+ based authentication. The flow will resume when the OAUTH server
+ redirects back to the localhost server. The final query parameters
+ will be returned by this function"""
+ queue = asyncio.Queue()
+ self.auth_redirs[state] = (url, queue)
+ return await queue.get()
+
+ def _start(self, request):
+ """Feed redirects to the web browser until all authing is done. FIXME: Some
+ fancy JavaScript should be used to fetch new interactive auth
+ requests"""
+ for I in self.auth_redirs.values():
+ raise aiohttp.web.HTTPFound(I[0])
+ return aiohttp.web.Response(text="Authentication done")
+
+ def _oauth2_msal(self, request):
+ """Use for the Azure AD authentication response redirection"""
+ state = request.query["state"]
+ try:
+ queue = self.auth_redirs[state][1]
+ del self.auth_redirs[state]
+ queue.put_nowait(request.query)
+ except KeyError:
+ pass
+
+ for I in self.auth_redirs.values():
+ raise aiohttp.web.HTTPFound(I[0])
+ raise aiohttp.web.HTTPFound(self.url)
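
The WebServer hands an interactive OAUTH flow off to the browser: the caller registers its authorization URL and state with auth_redir(), the user visiting '/' is redirected to the provider, and the provider's redirect back to '/oauth2/msal' resolves the awaited queue with the query parameters. A condensed sketch of that hand-off, modelled on how office365.py drives it (make_auth_url is a hypothetical helper):

    import secrets
    from cloud_mdir_sync import oauth

    async def interactive_login(make_auth_url):
        """make_auth_url(state, redirect_uri) -> provider authorization URL."""
        web_app = oauth.WebServer()
        await web_app.go()
        try:
            state = secrets.token_urlsafe(8)
            url = make_auth_url(state, web_app.url + "oauth2/msal")
            # The user opens http://localhost:8080/ and is bounced to `url`;
            # the provider must redirect back to /oauth2/msal?state=...&code=...
            query = await web_app.auth_redir(url, state)
            return query["code"]
        finally:
            await web_app.close()
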
diff --git a/cloud_mdir_sync/office365.py b/cloud_mdir_sync/office365.py
new file mode 100644
index 0000000..0afaa3f
--- /dev/null
+++ b/cloud_mdir_sync/office365.py
@@ -0,0 +1,637 @@
+# SPDX-License-Identifier: GPL-2.0+
+import asyncio
+import datetime
+import functools
+import logging
+import os
+import pickle
+import secrets
+import webbrowser
+from typing import Any, Dict, Union
+
+import aiohttp
+import requests
+
+from . import config, mailbox, messages, util
+
+
+def _retry_protect(func):
+ # Graph can return various error codes, see:
+ # https://docs.microsoft.com/en-us/graph/errors
+ @functools.wraps(func)
+ async def async_wrapper(self, *args, **kwargs):
+ while True:
+ while (self.graph_token is None or self.owa_token is None):
+ await self.authenticate()
+
+ try:
+ return await func(self, *args, **kwargs)
+ except aiohttp.ClientResponseError as e:
+ self.cfg.logger.debug(
+ f"Got HTTP Error {e.code} in {func} for {e.request_info.url!r}"
+ )
+ if (e.code == 401 or # Unauthorized
+ e.code == 403): # Forbidden
+ self.graph_token = None
+ self.owa_token = None
+ await self.authenticate()
+ continue
+ if (e.code == 503 or # Service Unavailable
+ e.code == 509 or # Bandwidth Limit Exceeded
+ e.code == 429 or # Too Many Requests
+ e.code == 504 or # Gateway Timeout
+ e.code == 200): # Success, but error JSON
+ self.cfg.logger.error(f"Graph returns {e}, delaying")
+ await asyncio.sleep(10)
+ continue
+ if (e.code == 400 or # Bad Request
+ e.code == 405 or # Method Not Allowed
+ e.code == 406 or # Not Acceptable
+ e.code == 411 or # Length Required
+ e.code == 413 or # Request Entity Too Large
+ e.code == 415 or # Unsupported Media Type
+ e.code == 422 or # Unprocessable Entity
+ e.code == 501): # Not implemented
+ self.cfg.logger.exception(f"Graph call failed {e.body!r}")
+ raise RuntimeError(f"Graph call failed {e!r}")
+
+ # Other errors we retry after resetting the mailbox
+ raise
+ except (asyncio.TimeoutError,
+ aiohttp.client_exceptions.ClientError):
+ self.cfg.logger.debug(f"Got non-HTTP Error in {func}")
+ await asyncio.sleep(10)
+ continue
+
+ return async_wrapper
+
+
+class GraphAPI(object):
+ """An OAUTH2 authenticated session to the Microsoft Graph API"""
+ graph_scopes = [
+ "https://graph.microsoft.com/User.Read",
+ "https://graph.microsoft.com/Mail.ReadWrite"
+ ]
+ graph_token = None
+ owa_scopes = ["https://outlook.office.com/mail.read"]
+ owa_token = None
+ authenticator = None
+
+ def __init__(self, cfg, domain_id, user, tenant):
+ import msal
+ self.msl_cache = msal.SerializableTokenCache()
+ auth = cfg.msgdb.get_authenticator(domain_id)
+ if auth is not None:
+ self.msl_cache.deserialize(auth)
+
+ self.domain_id = domain_id
+ self.cfg = cfg
+ self.user = user
+ self.web_app = cfg.web_app
+
+ if self.user is not None:
+ self.name = f"{self.user}//{tenant}"
+ else:
+ self.name = f"//{tenant}"
+
+ connector = aiohttp.connector.TCPConnector(limit=20, limit_per_host=5)
+ self.session = aiohttp.ClientSession(connector=connector,
+ raise_for_status=False)
+ self.headers = {}
+ self.owa_headers = {}
+
+ # Use the new-format immutable ids, these will work better
+ # with our caching scheme. See
+ # https://docs.microsoft.com/en-us/graph/outlook-immutable-id
+ self.headers["Prefer"] = 'IdType="ImmutableId"'
+
+ # FIXME: tenant/authority
+ self.msal = msal.PublicClientApplication(
+ client_id="122f4826-adf9-465d-8e84-e9d00bc9f234",
+ authority=f"https://login.microsoftonline.com/{tenant}",
+ token_cache=self.msl_cache)
+
+ def _cached_authenticate(self):
+ accounts = self.msal.get_accounts(self.user)
+ if len(accounts) != 1:
+ return False
+
+ try:
+ if self.graph_token is None:
+ self.graph_token = self.msal.acquire_token_silent(
+ scopes=self.graph_scopes, account=accounts[0])
+ if self.graph_token is None or "access_token" not in self.graph_token:
+ self.graph_token = None
+ return False
+
+ if self.owa_token is None:
+ self.owa_token = self.msal.acquire_token_silent(
+ scopes=self.owa_scopes, account=accounts[0])
+ if self.owa_token is None or "access_token" not in self.owa_token:
+ self.owa_token = None
+ return False
+ except requests.RequestException as e:
+ self.cfg.logger.error(f"msal failed on request {e}")
+ self.graph_token = None
+ self.owa_token = None
+ return False
+
+ self.headers["Authorization"] = self.graph_token[
+ "token_type"] + " " + self.graph_token["access_token"]
+ self.owa_headers["Authorization"] = self.owa_token[
+ "token_type"] + " " + self.owa_token["access_token"]
+ self.cfg.msgdb.set_authenticator(self.domain_id,
+ self.msl_cache.serialize())
+ return True
+
+ @util.log_progress(lambda self: f"Azure AD Authentication for {self.name}")
+ async def _do_authenticate(self):
+ while not self._cached_authenticate():
+ self.graph_token = None
+ self.owa_token = None
+
+ redirect_url = self.web_app.url + "oauth2/msal"
+ state = hex(id(self)) + secrets.token_urlsafe(8)
+ url = self.msal.get_authorization_request_url(
+ scopes=self.graph_scopes + self.owa_scopes,
+ state=state,
+ login_hint=self.user,
+ redirect_uri=redirect_url)
+
+ print(
+ f"Goto {self.cfg.web_app.url} in a web browser to authenticate"
+ )
+ webbrowser.open(url)
+ q = await self.cfg.web_app.auth_redir(url, state)
+ code = q["code"]
+
+ try:
+ self.graph_token = self.msal.acquire_token_by_authorization_code(
+ code=code,
+ scopes=self.graph_scopes,
+ redirect_uri=redirect_url)
+ except requests.RequestException as e:
+ self.cfg.logger.error(f"msal failed on request {e}")
+ await asyncio.sleep(10)
+
+ async def authenticate(self):
+ """Obtain OAUTH bearer tokens for MS services. For users this has to be done
+ interactively via the browser. A cache is used for tokens that have
+ not expired and they can be refreshed non-interactively into active
+ tokens within some limited time period."""
+ # Ensure we only ever have one authentication open at once. Other
+ # threads will all block here on the single authenticator.
+ if self.authenticator is None:
+ self.authenticator = asyncio.create_task(self._do_authenticate())
+ auth = self.authenticator
+ await auth
+ if self.authenticator is auth:
+ self.authenticator = None
+
+ async def _check_op(self, op):
+ if op.status >= 200 and op.status <= 299:
+ return
+ e = aiohttp.ClientResponseError(op.request_info,
+ op.history,
+ code=op.status,
+ message=op.reason,
+ headers=op.headers)
+ try:
+ e.body = await op.json()
+ except:
+ pass
+ raise e
+
+ async def _check_json(self, op):
+ """Check an operation for errors and convert errors to exceptions. Graph can
+ return an HTTP failure code, or (rarely) a JSON error message and a 200 success."""
+ await self._check_op(op)
+
+ res = await op.json()
+ if "error" in res:
+ e = aiohttp.ClientResponseError(op.request_info,
+ op.history,
+ code=op.status,
+ message=op.reason,
+ headers=op.headers)
+ e.body = res
+ raise e
+ return res
+
+ @_retry_protect
+ async def get_to_file(self, outf, ver, path, params=None, dos2unix=False):
+ """Copy the response of a GET operation into outf"""
+ async with self.session.get(f"https://graph.microsoft.com/{ver}{path}",
+ headers=self.headers,
+ params=params) as op:
+ await self._check_op(op)
+ carry = b""
+ async for data in op.content.iter_any():
+ if dos2unix:
+ if carry:
+ data = carry + data
+ data = data.replace(b"\r\n", b"\n")
+ if data.endswith(b"\r"):
+ carry = data[-1:]
+ data = data[:-1]
+ else:
+ carry = b""
+ outf.write(data)
+ if dos2unix and carry:
+ outf.write(carry)
+
+ @_retry_protect
+ async def get_json(self, ver, path, params=None):
+ """Return the JSON dictionary from the GET operation"""
+ async with self.session.get(f"https://graph.microsoft.com/{ver}{path}",
+ headers=self.headers,
+ params=params) as op:
+ return await self._check_json(op)
+
+ @_retry_protect
+ async def post_json(self, ver, path, body, params=None):
+ """Return the JSON dictionary from the POST operation"""
+ async with self.session.post(
+ f"https://graph.microsoft.com/{ver}{path}",
+ headers=self.headers,
+ json=body,
+ params=params) as op:
+ return await self._check_json(op)
+
+ @_retry_protect
+ async def patch_json(self, ver, path, body, params=None):
+ """Return the JSON dictionary from the PATCH operation"""
+ async with self.session.patch(
+ f"https://graph.microsoft.com/{ver}{path}",
+ headers=self.headers,
+ json=body,
+ params=params) as op:
+ return await self._check_json(op)
+
+ @_retry_protect
+ async def delete(self, ver, path):
+ """Issue a delete. For Messages delete doesn't put it in the Deleted Items
+ folder, it is just deleted."""
+ async with self.session.delete(
+ f"https://graph.microsoft.com/{ver}{path}",
+ headers=self.headers) as op:
+ await self._check_op(op)
+ async for _ in op.content.iter_any():
+ pass
+
+ async def get_json_paged(self, ver, path, params=None):
+ """Return an iterator that iterates over every JSON element in a paged
+ result"""
+ # See https://docs.microsoft.com/en-us/graph/paging
+ resp = await self.get_json(ver, path, params)
+ while True:
+ for I in resp["value"]:
+ yield I
+ uri = resp.get("@odata.nextLink")
+ if uri is None:
+ break
+ async with self.session.get(uri, headers=self.headers) as op:
+ resp = await self._check_json(op)
+
+ @_retry_protect
+ async def owa_subscribe(self, resource, changetype):
+ """Graph does not support streaming subscriptions, so we use the OWA interface
+ instead. See
+
+ https://docs.microsoft.com/en-us/previous-versions/office/office-365-api/api/beta/notify-streaming-rest-operations"""
+ body = {
+ "@odata.type": "#Microsoft.OutlookServices.StreamingSubscription",
+ "Resource": resource,
+ "ChangeType": changetype
+ }
+
+ async with self.session.post(
+ f"https://outlook.office.com/api/beta/me/subscriptions",
+ headers=self.owa_headers,
+ json=body) as op:
+ return await self._check_json(op)
+
+ async def owa_get_notifications(self, subscription_id):
+ """Return the notifications as an async iterator"""
+ body = {
+ "ConnectionTimeoutInMinutes": 2,
+ "KeepAliveNotificationIntervalInSeconds": 10,
+ "SubscriptionIds": [subscription_id]
+ }
+ timeout = aiohttp.ClientTimeout(sock_read=20)
+ # FIXME: fine tune timeouts https://docs.aiohttp.org/en/stable/client_quickstart.html#timeouts
+ # FIXME: retry protect for this
+ async with self.session.post(
+ f"https://outlook.office.com/api/beta/Me/GetNotifications",
+ headers=self.owa_headers,
+ json=body,
+ timeout=timeout) as op:
+ await self._check_op(op)
+
+ # There seems to be no relation between http chunks and json fragments,
+ # other than the last chunk before sleeping terminates all the
+ # jsons. I guess this is supposed to be parsed using a fancy
+ # parser. FIXME: We do need to parse this to exclude the keep alives
+ first = True
+ buf = b""
+ async for data, chunk_end in op.content.iter_chunks():
+ buf += data
+ if not chunk_end:
+ continue
+
+ # Last, but probably not reliably so
+ if buf == b']}':
+ return
+
+ if not first:
+ yield buf
+ else:
+ first = False
+ buf = b""
+
+ async def close(self):
+ await self.session.close()
+
+
+class O365Mailbox(mailbox.Mailbox):
+ """Cloud Office365 mailbox using the Microsoft Graph RESET API for data access"""
+ storage_kind = "o365_v0"
+ loop: asyncio.AbstractEventLoop
+ timer = None
+ use_owa_subscribe = True
+ cfg: config.Config
+ graph: GraphAPI
+
+ def __init__(self, mailbox, user=None, tenant="common"):
+ super().__init__()
+ self.mailbox = mailbox
+ self.tenant = tenant
+ self.user = user
+
+ async def setup_mbox(self, cfg):
+ """Setup access to the authenticated API domain for this endpoint"""
+ self.cfg = cfg
+ self.loop = cfg.loop
+ did = f"o365-{self.user}-{self.tenant}"
+ self.graph = cfg.domains.get(did)
+ if self.graph is None:
+ self.graph = GraphAPI(cfg, did, self.user, self.tenant)
+ cfg.domains[did] = self.graph
+
+ self.name = f"{self.graph.name}:{self.mailbox}"
+
+ json = await self.graph.get_json(
+ "v1.0",
+ f"/me/mailFolders",
+ params={"$filter": f"displayName eq '{self.mailbox}'"})
+ if len(json["value"]) != 1:
+ raise ValueError(f"Invalid mailbox name {self.mailbox!r}")
+ self.json = json["value"][0]
+
+ self.mailbox_id = self.json["id"]
+ if self.use_owa_subscribe:
+ asyncio.create_task(self._monitor_changes())
+
+ @mailbox.update_on_failure
+ async def _fetch_message(self, msg, msgdb):
+ with util.log_progress_ctx(logging.DEBUG,
+ f"Downloading {msg.email_id}",
+ lambda msg: f" {util.sizeof_fmt(msg.size)}",
+ msg), msgdb.get_temp() as F:
+ # For some reason this returns a message with dos line
+ # endings. Really weird.
+ await self.graph.get_to_file(
+ F,
+ "v1.0",
+ f"/me/messages/{msg.storage_id}/$value",
+ dos2unix=True)
+ msg.size = F.tell()
+ msg.content_hash = msgdb.store_hashed_msg(msg, F)
+
+ def _json_to_flags(self, jmsg):
+ """This is was remarkably difficult to find out, and seems completely
+ undocumented."""
+ flags = 0
+ # First class properties are easy
+ if bool(jmsg["isRead"]):
+ flags |= messages.Message.FLAG_READ
+ if jmsg["flag"]["flagStatus"] == "flagged":
+ flags |= messages.Message.FLAG_FLAGGED
+
+ # 'Replied' is not a concept in MAPI, at least not a consistent concept.
+ for prop in jmsg.get("singleValueExtendedProperties", []):
+ if prop["id"] == "Integer 0x1080":
+ # Closely matches OWA and the Outlook App
+ # PidTagIconIndex
+ # https://docs.microsoft.com/en-us/openspecs/exchange_server_protocols/ms-oxprops/eeca3a02-14e7-419b-8918-986275a2fac0
+ val = int(prop["value"])
+ if (val == 0x105 or # Replied mail
+ val == 0x106): # Forwarded mail
+ flags |= messages.Message.FLAG_REPLIED
+ elif prop["id"] == "Integer 0x1081":
+ # Sort of matches OWA and the Outlook App
+ # PidTagLastVerbExecuted
+ # https://docs.microsoft.com/en-us/openspecs/exchange_server_protocols/ms-oxprops/4ec55eac-14b3-4dfa-adf3-340c0dcccd44
+ val = int(prop["value"])
+ if (val == 102 or # NOTEIVERB_REPLYTOSENDER
+ val == 103 or # NOTEIVERB_REPLYTOALL
+ val == 104): # NOTEIVERB_FORWARD
+ flags |= messages.Message.FLAG_REPLIED
+ elif prop["id"] == "Integer 0xe17":
+ # This is what IMAP uses but we can't set it
+ # PidTagMessageStatus
+ # https://docs.microsoft.com/en-us/openspecs/exchange_server_protocols/ms-oxprops/5d00fe2b-9548-4953-97ba-89b1aa0ba5ac
+ if int(prop["value"]) & 0x200: # MSGSTATUS_ANSWERED
+ flags |= messages.Message.FLAG_REPLIED
+ else:
+ util.pj(prop)
+ return flags
+
+ @util.log_progress(lambda self: f"Updating Message List for {self.name}",
+ lambda self: f", {len(self.messages)} msgs")
+ @mailbox.update_on_failure
+ async def update_message_list(self, msgdb):
+ """Retrieve the list of all messages and store all the message content in the
+ content_hash message database"""
+ todo = []
+ msgs = []
+
+ async for jmsg in self.graph.get_json_paged(
+ "v1.0",
+ f"/me/mailFolders/{self.mailbox_id}/messages",
+ params=
+ {
+ "$select":
+ "internetMessageId,isRead,Flag,receivedDateTime,singleValueExtendedProperties",
+ "$expand":
+ "SingleValueExtendedProperties($filter=(id eq 'Integer 0xe17') or"
+ " (id eq 'Integer 0x1080'))",
+ }):
+ msg = messages.Message(mailbox=self,
+ storage_id=jmsg["id"],
+ email_id=jmsg["internetMessageId"])
+ msg.received_time = datetime.datetime.strptime(
+ jmsg["receivedDateTime"], '%Y-%m-%dT%H:%M:%SZ')
+ msg.flags = self._json_to_flags(jmsg)
+
+ if not msgdb.have_content(msg):
+ todo.append(
+ asyncio.create_task(self._fetch_message(msg, msgdb)))
+
+ msgs.append(msg)
+ await asyncio.gather(*todo)
+
+ res = {}
+ for msg in msgs:
+ # Something went wrong?
+ if msg.content_hash is not None:
+ res[msg.content_hash] = msg
+ self.messages = res
+ self.need_update = False
+ if not self.use_owa_subscribe:
+ if self.timer:
+ self.timer.cancel()
+ self.timer = None
+ self.timer = self.loop.call_later(60, self._timer)
+ if self.cfg.trace_file is not None:
+ pickle.dump(["0365_update_message_list", self.name, self.messages],
+ self.cfg.trace_file)
+
+ async def _monitor_changes(self):
+ """Keep a persistent PUT that returns data when there are changes."""
+ r = None
+ while True:
+ if r is None:
+ r = await self.graph.owa_subscribe(
+ f"https://outlook.office.com/api/beta/me/mailfolders('{self.mailbox_id}')/Messages",
+ "Created,Updated,Deleted")
+ try:
+ # This should use a single notification channel per graph,
+ # however until we can parse the incremental json it can't be
+ # done.
+ async for data in self.graph.owa_get_notifications(r["Id"]):
+ # hacky hacky
+ if (data ==
+ b'{"@odata.type":"#Microsoft.OutlookServices.KeepAliveNotification","Status":"Ok"}'
+ or data ==
+ b',{"@odata.type":"#Microsoft.OutlookServices.KeepAliveNotification","Status":"Ok"}'
+ ):
+ continue
+ self.need_update = True
+ self.changed_event.set()
+ except (asyncio.TimeoutError,
+ aiohttp.client_exceptions.ClientError):
+ r = None
+ continue
+
+ def _timer(self):
+ self.need_update = True
+ self.changed_event.set()
+
+ def force_content(self, msgdb, msgs):
+ raise RuntimeError("Cannot move messages into the Cloud")
+
+ @util.log_progress(lambda self: f"Uploading local changes for {self.name}",
+ lambda self: f", {self.last_merge_len} changes ")
+ @mailbox.update_on_failure
+ async def merge_content(self, msgs):
+ # There is a batching API for this kind of stuff as well:
+ # https://docs.microsoft.com/en-us/graph/json-batching
+ self.last_merge_len = 0
+ todo = []
+ if self.cfg.trace_file is not None:
+ pickle.dump(["merge_content", self.name, self.messages, msgs],
+ self.cfg.trace_file)
+ for ch, mpair in msgs.items():
+ # lmsg is the message in the local mailbox
+ # cmsg is the current cloud message in this class
+ # old_cmsg is the original cloud message from the last sync
+ lmsg, old_cmsg = mpair
+ cmsg = self.messages.get(ch)
+
+ # Cloud message was deleted, cloud takes priority
+ if cmsg is None:
+ continue
+ if lmsg is None:
+ # Debugging that the message really is to be deleted
+ assert os.stat(os.path.join(self.cfg.msgdb.hashes_dir,
+ ch)).st_nlink == 1
+ # Delete cloud message
+ todo.append(
+ self.graph.post_json(
+ "v1.0",
+ f"/me/mailFolders/{self.mailbox}/messages/{cmsg.storage_id}/move",
+ body={"destinationId": "deleteditems"}))
+ # FIXME: This should be after the operation completes?
+ del self.messages[ch]
+ continue
+
+ if (lmsg.flags == old_cmsg.flags or lmsg.flags == cmsg.flags):
+ continue
+
+ cloud_flags = cmsg.flags ^ old_cmsg.flags
+ flag_mask = messages.Message.ALL_FLAGS ^ cloud_flags
+ nflags = (lmsg.flags & flag_mask) | (cmsg.flags & cloud_flags)
+ modified_flags = nflags ^ cmsg.flags
+
+ # FIXME: https://docs.microsoft.com/en-us/graph/best-practices-concept#getting-minimal-responses
+ # FIXME: Does the ID change?
+ patch: Dict[str, Any] = {}
+ if modified_flags & messages.Message.FLAG_READ:
+ patch["isRead"] = bool(nflags & messages.Message.FLAG_READ)
+ if modified_flags & messages.Message.FLAG_FLAGGED:
+ patch["flag"] = {
+ "flagStatus":
+ "flagged" if nflags
+ & messages.Message.FLAG_FLAGGED else "notFlagged"
+ }
+ if modified_flags & messages.Message.FLAG_REPLIED:
+ # This can only be described as an undocumented disaster.
+ # Different clients set different things. The Icon shows up in
+ # OWS and the Mobile app. The MessageStatus shows up in
+ # IMAP. IMAP sets the MessageStatus but otherwise does not
+ # interact with the other two. We can't seem to set
+ # MessageStatus over REST because it needs RopSetMessageStatus.
+ if nflags & messages.Message.FLAG_REPLIED:
+ now = datetime.datetime.utcnow().strftime(
+ "%Y-%m-%dT%H:%M:%SZ")
+ patch["singleValueExtendedProperties"] = [
+ # PidTagLastVerbExecuted
+ {
+ "id": "Integer 0x1081",
+ "value": "103"
+ },
+ # PidTagLastVerbExecutionTime
+ {
+ "id": "SystemTime 0x1082",
+ "value": now
+ },
+ # PidTagIconIndex
+ {
+ "id": "Integer 0x1080",
+ "value": "261"
+ },
+ ]
+ else:
+ # Rarely does anything undo a replied flag, but it is
+ # useful for testing.
+ patch["singleValueExtendedProperties"] = [
+ {
+ "id":
+ "Integer 0x1080", # PidTagIconIndex
+ "value":
+ "256" if nflags
+ & messages.Message.FLAG_READ else "-1"
+ },
+ ]
+
+ if patch:
+ todo.append(
+ self.graph.patch_json(
+ "v1.0",
+ f"/me/mailFolders/{self.mailbox}/messages/{cmsg.storage_id}",
+ body=patch))
+ cmsg.flags = nflags
+
+ await asyncio.gather(*todo)
+ self.last_merge_len = len(todo)
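
merge_content resolves flags with a three-way merge: bits the cloud changed since the last sync (cmsg ^ old_cmsg) win, every other bit comes from the local copy. A tiny worked example of that bit arithmetic using the FLAG_* constants (the scenario is invented purely to illustrate):

    from cloud_mdir_sync import messages

    M = messages.Message
    old_cmsg_flags = M.FLAG_READ                  # state at the last sync
    cmsg_flags = M.FLAG_READ | M.FLAG_FLAGGED     # cloud flagged it since then
    lmsg_flags = 0                                # local side marked it unread

    cloud_flags = cmsg_flags ^ old_cmsg_flags     # only FLAG_FLAGGED changed in cloud
    flag_mask = M.ALL_FLAGS ^ cloud_flags
    nflags = (lmsg_flags & flag_mask) | (cmsg_flags & cloud_flags)
    assert nflags == M.FLAG_FLAGGED               # cloud's new flag kept, local unread kept
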
diff --git a/cloud_mdir_sync/util.py b/cloud_mdir_sync/util.py
new file mode 100644
index 0000000..7799356
--- /dev/null
+++ b/cloud_mdir_sync/util.py
@@ -0,0 +1,70 @@
+# SPDX-License-Identifier: GPL-2.0+
+import contextlib
+import functools
+import inspect
+import json
+import logging
+import time
+
+from . import config
+
+
+@contextlib.contextmanager
+def log_progress_ctx(level, start_msg, end_msg, *args):
+ if inspect.isfunction(start_msg):
+ start_msg = start_msg(*args)
+ if end_msg is None:
+ end_msg = " "
+
+ config.logger.log(level, f"Starting {start_msg}")
+ st = time.perf_counter()
+ try:
+ yield
+ et = time.perf_counter()
+ except Exception as e:
+ if inspect.isfunction(end_msg):
+ end_msg = end_msg(*args)
+ config.logger.warning(f"FAILED({e!r}): {start_msg}")
+ raise
+
+ if inspect.isfunction(end_msg):
+ end_msg = end_msg(*args)
+ if end_msg.startswith("-"):
+ start_msg = ""
+ config.logger.info(
+ f"Completed {start_msg}{end_msg} (took {et-st:.4f} secs)")
+
+
+def log_progress(start_msg, end_msg=None, level=logging.INFO):
+ """Decorator to log the start/end and duration of a method"""
+ def inner(func):
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ with log_progress_ctx(level, start_msg, end_msg, self):
+ res = func(self, *args, **kwargs)
+ return res
+
+ @functools.wraps(func)
+ async def async_wrapper(self, *args, **kwargs):
+ with log_progress_ctx(level, start_msg, end_msg, self):
+ res = await func(self, *args, **kwargs)
+ return res
+
+ if inspect.iscoroutinefunction(func):
+ return async_wrapper
+ return wrapper
+
+ return inner
+
+
+# https://stackoverflow.com/questions/1094841/reusable-library-to-get-human-readable-version-of-file-size
+def sizeof_fmt(num, suffix='B'):
+ for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
+ if abs(num) < 1024.0:
+ return "%3.1f%s%s" % (num, unit, suffix)
+ num /= 1024.0
+ return "%.1f%s%s" % (num, 'Yi', suffix)
+
+
+def pj(json_dict):
+ print(json.dumps(json_dict, indent=4, sort_keys=True))
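
log_progress wraps a method (sync or async) so its start, completion time and an optional end summary are logged; start_msg and end_msg may be callables that receive self. A short hedged usage sketch (the Fetcher class is hypothetical):

    import asyncio
    import logging
    from cloud_mdir_sync import config, util

    class Fetcher:
        count = 0

        @util.log_progress(lambda self: "Fetching messages",
                           lambda self: f", {self.count} msgs",
                           level=logging.DEBUG)
        async def fetch(self):
            await asyncio.sleep(0)   # stand-in for real work
            self.count = 42

    cfg = config.Config()            # sets up the module-level logger util uses
    asyncio.get_event_loop().run_until_complete(Fetcher().fetch())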