aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJason Gunthorpe <jgg@mellanox.com>2020-02-04 14:45:48 -0400
committerJason Gunthorpe <jgg@mellanox.com>2020-02-06 15:01:40 -0400
commitf54472e2013f271146adf591843fed26ca7c3392 (patch)
tree810f85eefd0e937866ccb3c6bb3fb70c68a697bc
parent5c037ad30837f7e94d8dcfe160ad10e51056512c (diff)
downloadcloud_mdir_sync-f54472e2013f271146adf591843fed26ca7c3392.tar.gz
cloud_mdir_sync-f54472e2013f271146adf591843fed26ca7c3392.tar.bz2
cloud_mdir_sync-f54472e2013f271146adf591843fed26ca7c3392.zip
GMail support
Basic support for GMail using the REST API Signed-off-by: Jason Gunthorpe <jgg@mellanox.com>
-rw-r--r--cloud_mdir_sync/config.py13
-rw-r--r--cloud_mdir_sync/gmail.py639
-rw-r--r--cloud_mdir_sync/main.py2
-rw-r--r--cloud_mdir_sync/oauth.py26
-rw-r--r--cloud_mdir_sync/office365.py2
-rwxr-xr-xsetup.py2
6 files changed, 673 insertions, 11 deletions
diff --git a/cloud_mdir_sync/config.py b/cloud_mdir_sync/config.py
index f0b5796..85351d5 100644
--- a/cloud_mdir_sync/config.py
+++ b/cloud_mdir_sync/config.py
@@ -93,6 +93,19 @@ class Config(object):
tenant=account[1]))
return self.cloud_mboxes[-1]
+ def GMail_Account(self, user):
+ """Define a GMail account credential. The user must be specified as a
+ fully qualified Google Account email address. This supports both
+ consumer GMail accounts, and accounts linked to a G-Suite account."""
+ return (user,)
+
+ def GMail(self, label, account):
+ """Create a cloud mailbox for Office365. Mailbox is the name of O365
+ mailbox to use, account should be the result of Office365_Account"""
+ from .gmail import GMailMailbox
+ self.cloud_mboxes.append(GMailMailbox(label, user=account[0]))
+ return self.cloud_mboxes[-1]
+
def MailDir(self, directory):
"""Create a local maildir to hold messages"""
from .maildir import MailDirMailbox
diff --git a/cloud_mdir_sync/gmail.py b/cloud_mdir_sync/gmail.py
new file mode 100644
index 0000000..1975ab1
--- /dev/null
+++ b/cloud_mdir_sync/gmail.py
@@ -0,0 +1,639 @@
+# SPDX-License-Identifier: GPL-2.0+
+import asyncio
+import base64
+import collections
+import datetime
+import functools
+import hashlib
+import logging
+import secrets
+import webbrowser
+from typing import Dict, List, Optional, Set
+
+import aiohttp
+import oauthlib
+import requests_oauthlib
+
+from . import config, mailbox, messages, util
+
+
+class NativePublicApplicationClient(oauthlib.oauth2.WebApplicationClient):
+ """Amazingly oauthlib doesn't include client size PCKE support
+ Hack it into the WebApplicationClient"""
+ def __init__(self, client_id):
+ super().__init__(client_id)
+
+ def _code_challenge_method_s256(self, verifier):
+ return base64.urlsafe_b64encode(
+ hashlib.sha256(verifier.encode()).digest()).decode().rstrip('=')
+
+ def prepare_request_uri(self,
+ authority_uri,
+ redirect_uri,
+ scope=None,
+ state=None,
+ **kwargs):
+ self.verifier = secrets.token_urlsafe(96)
+ return super().prepare_request_uri(
+ authority_uri,
+ redirect_uri=redirect_uri,
+ scope=scope,
+ state=state,
+ code_challenge=self._code_challenge_method_s256(self.verifier),
+ code_challenge_method="S256",
+ **kwargs)
+
+ def prepare_request_body(self,
+ code=None,
+ redirect_uri=None,
+ body='',
+ include_client_id=True,
+ **kwargs):
+ return super().prepare_request_body(
+ code=code,
+ redirect_uri=redirect_uri,
+ body=body,
+ include_client_id=include_client_id,
+ code_verifier=self.verifier,
+ **kwargs)
+
+
+def _retry_protect(func):
+ @functools.wraps(func)
+ async def async_wrapper(self, *args, **kwargs):
+ while True:
+ while self.headers is None:
+ await self.authenticate()
+
+ try:
+ return await func(self, *args, **kwargs)
+ except aiohttp.ClientResponseError as e:
+ self.cfg.logger.debug(
+ f"Got HTTP Error {e.code} in {func} for {e.request_info.url!r}"
+ )
+ if (e.code == 401 or # Unauthorized
+ e.code == 403): # Forbidden
+ self.headers = None
+ await self.authenticate()
+ continue
+ if (e.code == 503 or # Service Unavailable
+ e.code == 400 or # Bad Request
+ e.code == 509 or # Bandwidth Limit Exceeded
+ e.code == 429 or # Too Many Requests
+ e.code == 504 or # Gateway Timeout
+ e.code == 200): # Success, but error JSON
+ self.cfg.logger.error(f"Gmail returns {e}, delaying")
+ await asyncio.sleep(10)
+ continue
+ if (e.code == 400 or # Bad Request
+ e.code == 405 or # Method Not Allowed
+ e.code == 406 or # Not Acceptable
+ e.code == 411 or # Length Required
+ e.code == 413 or # Request Entity Too Large
+ e.code == 415 or # Unsupported Media Type
+ e.code == 422 or # Unprocessable Entity
+ e.code == 501): # Not implemented
+ self.cfg.logger.exception(f"Gmail call failed {e.body!r}")
+ raise RuntimeError(f"Gmail call failed {e!r}")
+
+ # Other errors we retry after resetting the mailbox
+ raise
+ except (asyncio.TimeoutError,
+ aiohttp.client_exceptions.ClientError):
+ self.cfg.logger.debug(f"Got non-HTTP Error in {func}")
+ await asyncio.sleep(10)
+ continue
+
+ return async_wrapper
+
+
+class GmailAPI(object):
+ """An OAUTH2 authenticated session to the Google gmail API"""
+ # From ziepe.ca
+ client_id = "14979213351-bik90v3b8b9f22160ura3oah71u3l113.apps.googleusercontent.com"
+ # Google doesn't follow RFC8252 8.5 and does require the client_secret,
+ # but it is not secret.
+ client_secret = "cLICGg-LVQuMAPTh3VxTC42p"
+ authenticator = None
+ headers: Optional[Dict[str, str]] = None
+
+ def __init__(self, cfg: config.Config, domain_id: str, user: str):
+ self.domain_id = domain_id
+ self.cfg = cfg
+ self.user = user
+
+ connector = aiohttp.connector.TCPConnector(limit=20, limit_per_host=5)
+ self.session = aiohttp.ClientSession(connector=connector,
+ raise_for_status=False)
+
+ self.redirect_url = cfg.web_app.url + "oauth2/gmail"
+ self.api_token = cfg.msgdb.get_authenticator(domain_id)
+ self.oauth = requests_oauthlib.OAuth2Session(
+ client_id=self.client_id,
+ client=NativePublicApplicationClient(self.client_id),
+ redirect_uri=self.redirect_url,
+ token=self.api_token,
+ scope=[
+ "https://www.googleapis.com/auth/gmail.modify",
+ # This one is needed for SMTP ?
+ #"https://mail.google.com/",
+ ])
+
+ if self.api_token:
+ self._set_token()
+
+ def _set_token(self):
+ self.cfg.msgdb.set_authenticator(self.domain_id, self.api_token)
+ # We expect to only use a Authorization header
+ self.headers = {}
+ try:
+ _, headers, _ = self.oauth._client.add_token(
+ uri="https://foo/",
+ http_method="GET",
+ headers={},
+ token_placement=oauthlib.oauth2.rfc6749.clients.AUTH_HEADER)
+ assert headers
+ except oauthlib.oauth2.TokenExpiredError:
+ return
+ except oauthlib.oauth2.OAuth2Error:
+ self.api_token = None
+ self.headers = headers
+
+ def _refresh_authenticate(self):
+ if not self.api_token:
+ return False
+
+ try:
+ self.api_token = self.oauth.refresh_token(
+ token_url='https://oauth2.googleapis.com/token',
+ client_id=self.oauth.client_id,
+ client_secret=self.client_secret,
+ refresh_token=self.api_token["refresh_token"])
+ except oauthlib.oauth2.OAuth2Error:
+ self.api_token = None
+ return False
+ self._set_token()
+ return bool(self.headers)
+
+ @util.log_progress(lambda self: f"Google Authentication for {self.user}")
+ async def _do_authenticate(self):
+ while not self._refresh_authenticate():
+ self.api_token = None
+
+ # This flow follows the directions of
+ # https://developers.google.com/identity/protocols/OAuth2InstalledApp
+ state = hex(id(self)) + secrets.token_urlsafe(8)
+ url, state = self.oauth.authorization_url(
+ 'https://accounts.google.com/o/oauth2/v2/auth',
+ state=state,
+ access_type="offline",
+ login_hint=self.user)
+
+ print(
+ f"Goto {self.cfg.web_app.url} in a web browser to authenticate"
+ )
+ webbrowser.open(url)
+ q = await self.cfg.web_app.auth_redir(url, state,
+ self.redirect_url)
+
+ self.api_token = self.oauth.fetch_token(
+ 'https://oauth2.googleapis.com/token',
+ include_client_id=True,
+ client_secret=self.client_secret,
+ code=q["code"])
+ self._set_token()
+
+ async def authenticate(self):
+ """Obtain OAUTH bearer tokens for MS services. For users this has to be done
+ interactively via the browser. A cache is used for tokens that have
+ not expired and they can be refreshed non-interactively into active
+ tokens within some limited time period."""
+ # Ensure we only ever have one authentication open at once. Other
+ # threads will all block here on the single authenticator.
+ if self.authenticator is None:
+ self.authenticator = asyncio.create_task(self._do_authenticate())
+ auth = self.authenticator
+ await auth
+ if self.authenticator is auth:
+ self.authenticator = None
+
+ async def _check_op(self, op: aiohttp.ClientResponse):
+ if op.status >= 200 and op.status <= 299:
+ return
+ e = aiohttp.ClientResponseError(op.request_info,
+ op.history,
+ code=op.status,
+ message=op.reason,
+ headers=op.headers)
+ try:
+ e.body = await op.json()
+ except:
+ pass
+ raise e
+
+ async def _check_json(self, op: aiohttp.ClientResponse):
+ await self._check_op(op)
+ return await op.json()
+
+ async def _check_empty(self, op: aiohttp.ClientResponse):
+ await self._check_op(op)
+ d = await op.text()
+ if d:
+ e = aiohttp.ClientResponseError(
+ op.request_info,
+ op.history,
+ code=op.status,
+ message="POST returned data, not empty",
+ headers=op.headers)
+ raise e
+
+ @_retry_protect
+ async def get_json(self, ver, path, params=None):
+ """Return the JSON dictionary from the GET operation"""
+ async with self.session.get(
+ f"https://www.googleapis.com/gmail/{ver}{path}",
+ headers=self.headers,
+ params=params) as op:
+ return await self._check_json(op)
+
+ @_retry_protect
+ async def post_json(self, ver, path, body, params=None):
+ """Return the JSON dictionary from the POST operation"""
+ async with self.session.post(
+ f"https://www.googleapis.com/gmail/{ver}{path}",
+ headers=self.headers,
+ json=body,
+ params=params) as op:
+ return await self._check_empty(op)
+
+ async def get_json_paged(self,
+ ver,
+ path,
+ key,
+ params=None,
+ last_json=None):
+ """Return an iterator that iterates over every JSON element in a paged
+ result. last_json is a list that will contain only the last json dict
+ returned"""
+ params = dict(params)
+ resp = await self.get_json(ver, path, params)
+ while True:
+ for I in resp.get(key, []):
+ yield I
+ token = resp.get("nextPageToken")
+ if token is None:
+ if last_json is not None:
+ last_json[:] = [resp]
+ return
+ # FIXME: Is this right, or should we drop the other params?
+ params["pageToken"] = token
+ resp = await self.get_json(ver, path, params=params)
+
+ async def close(self):
+ await self.session.close()
+
+
+class GMailMessage(messages.Message):
+ gmail_labels: Optional[Set[str]] = None
+
+ def __init__(self, mailbox, gmail_id, gmail_labels=None):
+ super().__init__(mailbox=mailbox, storage_id=gmail_id)
+ # GMail does not return the email_id, but it does have a stable REST
+ # ID, so if we have the REST ID in the database then we can compute
+ # the email_id
+ self.content_hash = mailbox.cfg.msgdb.content_hashes_cloud.get(
+ self.cid())
+ if self.content_hash:
+ self.email_id = mailbox.cfg.msgdb.content_msgid[self.content_hash]
+ self.gmail_labels = gmail_labels
+ if self.gmail_labels:
+ self._labels_to_flags()
+
+ def _labels_to_flags(self):
+ assert self.gmail_labels is not None
+ flags = 0
+ if "UNREAD" not in self.gmail_labels:
+ flags |= messages.Message.FLAG_READ
+ if "STARRED" in self.gmail_labels:
+ flags |= messages.Message.FLAG_FLAGGED
+ # Unfortunately other IMAP flags do not seem to be available through
+ # the REST interface
+ self.flags = flags
+
+ def update_from_json(self, jmsg):
+ self.gmail_labels = set(jmsg["labelIds"])
+ internal_date = int(jmsg["internalDate"])
+ self.received_time = datetime.datetime.fromtimestamp(internal_date /
+ 1000.0)
+
+ self._labels_to_flags()
+ if "payload" in jmsg:
+ for hdr in jmsg["payload"]["headers"]:
+ if hdr["name"].lower() == "message-id":
+ if self.email_id is None:
+ self.email_id = hdr["value"]
+ else:
+ assert self.email_id == hdr["value"]
+ break
+
+
+class GMailMailbox(mailbox.Mailbox):
+ """Cloud GMail mailbox using the GMail RESET API for data access"""
+ storage_kind = "gmail_v1"
+ supported_flags = (messages.Message.FLAG_READ
+ | messages.Message.FLAG_FLAGGED
+ | messages.Message.FLAG_DELETED)
+ timer = None
+ cfg: config.Config
+ gmail: GmailAPI
+ gmail_messages: Dict[str, GMailMessage]
+ history_delta = None
+ delete_action = "archive" # or delete
+
+ def __init__(self, label: str, user: str):
+ super().__init__()
+ self.label_name = label
+ self.user = user
+ self.gmail_messages = {}
+
+ async def setup_mbox(self, cfg: config.Config):
+ """Setup access to the authenticated API domain for this endpoint"""
+ self.cfg = cfg
+ did = f"gmail-{self.user}"
+ self.name = f"{self.user}:{self.label_name}"
+ gmail = cfg.domains.get(did)
+ if gmail is None:
+ self.gmail = GmailAPI(cfg, did, self.user)
+ cfg.domains[did] = self.gmail
+ else:
+ self.gmail = gmail
+
+ # Verify the label exists
+ jmsg = await self.gmail.get_json("v1", f"/users/me/labels")
+ for I in jmsg["labels"]:
+ if I["name"] == self.label_name:
+ self.label = I["id"]
+ break
+ else:
+ raise ValueError(f"GMail label {self.label_name!r} not found")
+
+ async def _fetch_metadata(self, msg: GMailMessage):
+ params = {"format": "metadata"}
+ if msg.email_id is None:
+ params["metadataHeaders"] = "message-id"
+ jmsg = await self.gmail.get_json(
+ "v1", f"/users/me/messages/{msg.storage_id}", params=params)
+ msg.update_from_json(jmsg)
+ return jmsg["historyId"]
+
+ async def _fetch_message(self, msg: GMailMessage,
+ msgdb: messages.MessageDB):
+ with util.log_progress_ctx(logging.DEBUG,
+ f"Downloading {msg.storage_id}",
+ lambda msg: f" {util.sizeof_fmt(msg.size)}",
+ msg), msgdb.get_temp() as F:
+ jmsg = await self.gmail.get_json(
+ "v1",
+ f"/users/me/messages/{msg.storage_id}",
+ params={
+ "format": "raw",
+ })
+ data = base64.urlsafe_b64decode(jmsg["raw"])
+ data = data.replace(b"\r\n", b"\n")
+ F.write(data)
+ msg.size = F.tell()
+ msg.update_from_json(jmsg)
+ msg.content_hash = msgdb.store_hashed_msg(msg, F)
+ return jmsg["historyId"]
+
+ async def _fetch_all_messages(self, msgdb: messages.MessageDB):
+ """Perform a full synchronization of the mailbox"""
+ start_history_id = None
+ todo = []
+ msgs = []
+ async for jmsg in self.gmail.get_json_paged(
+ "v1",
+ "/users/me/messages",
+ key="messages",
+ params={"labelIds": self.label}):
+ msg = GMailMessage(mailbox=self, gmail_id=jmsg["id"])
+ if not msgdb.have_content(msg):
+ todo.append(
+ asyncio.create_task(self._fetch_message(msg, msgdb)))
+ else:
+ todo.append(asyncio.create_task(self._fetch_metadata(msg)))
+ msgs.append(msg)
+ if todo:
+ start_history_id = await todo[0]
+ await asyncio.gather(*todo)
+
+ return (msgs, start_history_id)
+
+ async def _fetch_delta_messages(self, old_msgs: List[GMailMessage],
+ start_history_id: Optional[str],
+ msgdb: messages.MessageDB):
+ # Mailbox is empty
+ if start_history_id is None:
+ assert not old_msgs
+ return old_msgs, None
+
+ gmsgs = {msg.storage_id: set(msg.gmail_labels) for msg in old_msgs}
+
+ def add_message(jmsg):
+ jmsg = jmsg["message"]
+ gmail_id = jmsg["id"]
+ if "labelIds" in jmsg:
+ gmsgs[gmail_id] = labels = set(jmsg["labelIds"])
+ else:
+ if gmail_id not in msgs:
+ gmsgs[gmail_id] = labels = set()
+ else:
+ labels = gmsgs[gmail_id]
+ return gmail_id, labels
+
+ last_history = []
+ async for jhistory in self.gmail.get_json_paged(
+ "v1",
+ "/users/me/history",
+ key="history",
+ params={
+ "labelId": self.label,
+ "startHistoryId": start_history_id
+ },
+ last_json=last_history):
+ jf = jhistory.get("messagesAdded")
+ if jf:
+ for jmsg in jf:
+ gmail_id, _ = add_message(jmsg)
+ jf = jhistory.get("labelsAdded")
+ if jf:
+ for jmsg in jf:
+ _, labels = add_message(jmsg)
+ labels.update(jmsg["labelIds"])
+ jf = jhistory.get("labelsRemoved")
+ if jf:
+ for jmsg in jf:
+ _, labels = add_message(jmsg)
+ for I in jmsg["labelIds"]:
+ labels.discard(I)
+ # Deleted means permanently deleted
+ jf = jhistory.get("messagesDeleted")
+ if jf:
+ for jmsg in jf:
+ gmail_id, labels = add_message(jmsg)
+ gmsgs.pop(gmail_id, None)
+
+ next_history_id = last_history[0]["historyId"]
+ old_msgs_map = {msg.storage_id: msg for msg in old_msgs}
+ todo = []
+ msgs = []
+ for gmail_id, gmail_labels in gmsgs.items():
+ if self.label not in gmail_labels:
+ continue
+ omsg = old_msgs_map.get(gmail_id)
+ if omsg is None:
+ msg = GMailMessage(mailbox=self,
+ gmail_id=gmail_id,
+ gmail_labels=gmail_labels)
+ if not msgdb.have_content(msg):
+ todo.append(
+ asyncio.create_task(self._fetch_message(msg, msgdb)))
+ else:
+ todo.append(asyncio.create_task(self._fetch_metadata(msg)))
+ else:
+ msg = GMailMessage(mailbox=self,
+ gmail_id=gmail_id,
+ gmail_labels=gmail_labels)
+ msg.received_time = omsg.received_time
+ assert msgdb.have_content(msg)
+ msgs.append(msg)
+ await asyncio.gather(*todo)
+ return (msgs, next_history_id)
+
+ @util.log_progress(lambda self: f"Updating Message List for {self.name}",
+ lambda self: f", {len(self.messages)} msgs")
+ @mailbox.update_on_failure
+ async def update_message_list(self, msgdb: messages.MessageDB):
+ """Retrieve the list of all messages and store all the message content
+ in the content_hash message database"""
+ if self.history_delta is None or self.history_delta[1] is None:
+ # For whatever reason, there is usually more history than is
+ # suggested by the history_id from the messages.list, so always
+ # drain it out.
+ self.history_delta = await self._fetch_all_messages(msgdb)
+
+ self.history_delta = await self._fetch_delta_messages(
+ start_history_id=self.history_delta[1],
+ old_msgs=self.history_delta[0],
+ msgdb=msgdb)
+
+ self.messages = {
+ msg.content_hash: msg
+ for msg in self.history_delta[0] if msg.content_hash is not None
+ }
+ self.need_update = False
+ if self.timer:
+ self.timer.cancel()
+ self.timer = None
+ self.timer = self.cfg.loop.call_later(60, self._timer)
+
+ def _timer(self):
+ self.need_update = True
+ self.changed_event.set()
+
+ def force_content(self, msgdb, msgs):
+ raise RuntimeError("Cannot move messages into the Cloud")
+
+ def _update_msg_flags(self, cmsg: messages.Message, old_cmsg_flags: int,
+ lmsg: messages.Message, label_edits):
+ lflags = lmsg.flags & (messages.Message.ALL_FLAGS
+ ^ messages.Message.FLAG_DELETED)
+ if lflags == old_cmsg_flags or lflags == cmsg.flags:
+ return None
+
+ cloud_flags = cmsg.flags ^ old_cmsg_flags
+ flag_mask = messages.Message.ALL_FLAGS ^ cloud_flags
+ nflags = (lflags & flag_mask) | (cmsg.flags & cloud_flags)
+ modified_flags = nflags ^ cmsg.flags
+ if modified_flags & messages.Message.FLAG_READ:
+ label_edits[("-" if nflags & messages.Message.FLAG_READ else "+") +
+ "UNREAD"].add(cmsg.storage_id)
+ if modified_flags & messages.Message.FLAG_FLAGGED:
+ label_edits[("+" if nflags
+ & messages.Message.FLAG_FLAGGED else "-") +
+ "STARRED"].add(cmsg.storage_id)
+ # FLAG_REPLIED is not supported
+ cmsg.flags = nflags
+
+ @staticmethod
+ def _next_edit(label_edits):
+ """Break up the edit list into groups of IDs. The algorithm picks
+ groupings of IDs that have matching label changes, and returns every
+ ID exactly once."""
+ sets = list(label_edits.values())
+ while True:
+ gmail_ids = functools.reduce(lambda x, y: x & y, sets)
+ if gmail_ids:
+ if len(gmail_ids) > 50:
+ return set(sorted(gmail_ids)[:50])
+ return set(gmail_ids)
+
+ # Pick an arbitrary ID and advance its group of labels. The above
+ # reduction must return at least todo_gmail_id.
+ todo_gmail_id = next(iter(sets[0]))
+ sets = [I for I in sets if todo_gmail_id in I]
+
+ @util.log_progress(lambda self: f"Uploading local changes for {self.name}",
+ lambda self: f", {self.last_merge_len} changes ")
+ @mailbox.update_on_failure
+ async def merge_content(self, msgs: messages.CHMsgMappingDict_Type):
+ self.last_merge_len = 0
+ label_edits: Dict[str, Set[str]] = collections.defaultdict(set)
+ for ch, mpair in msgs.items():
+ # lmsg is the message in the local mailbox
+ # cmsg is the current cloud message in this class
+ # old_cmsg is the original cloud message from the last sync
+ lmsg, old_cmsg = mpair
+ cmsg = self.messages.get(ch)
+ assert old_cmsg is not None
+
+ # Update flags
+ if cmsg is not None and old_cmsg is not None and lmsg is not None:
+ self._update_msg_flags(cmsg, old_cmsg.flags, lmsg, label_edits)
+
+ if cmsg is not None and (lmsg is None or lmsg.flags
+ & messages.Message.FLAG_DELETED):
+ # To archive skip the +TRASH
+ if self.delete_action == "delete":
+ label_edits["+TRASH"].add(cmsg.storage_id)
+ label_edits["-" + self.label].add(cmsg.storage_id)
+ del self.messages[ch]
+
+ empty: Set[str] = set()
+ self.last_merge_len = len(
+ functools.reduce(lambda x, y: x | y, label_edits.values(), empty))
+
+ # Group all the label changes for a single ID together and then batch
+ # them
+ while label_edits:
+ gmail_ids = self._next_edit(label_edits)
+ labels = []
+ for k, v in list(label_edits.items()):
+ if gmail_ids.issubset(v):
+ labels.append(k)
+ v.difference_update(gmail_ids)
+ if not v:
+ del label_edits[k]
+
+ labels.sort()
+ body = {"ids": sorted(gmail_ids)}
+ add_labels = [I[1:] for I in labels if I[0] == "+"]
+ if add_labels:
+ body["addLabelIds"] = add_labels
+ remove_labels = [I[1:] for I in labels if I[0] == "-"]
+ if remove_labels:
+ body["removeLabelIds"] = remove_labels
+ await self.gmail.post_json("v1",
+ f"/users/me/messages/batchModify",
+ body=body)
diff --git a/cloud_mdir_sync/main.py b/cloud_mdir_sync/main.py
index fdff700..816cf20 100644
--- a/cloud_mdir_sync/main.py
+++ b/cloud_mdir_sync/main.py
@@ -8,7 +8,7 @@ from typing import Dict, Optional, Tuple
import aiohttp
import pyinotify
-from . import config, mailbox, messages, oauth, office365
+from . import config, mailbox, messages, oauth
def route_cloud_messages(cfg: config.Config) -> messages.MBoxDict_Type:
diff --git a/cloud_mdir_sync/oauth.py b/cloud_mdir_sync/oauth.py
index 449d16c..163dcba 100644
--- a/cloud_mdir_sync/oauth.py
+++ b/cloud_mdir_sync/oauth.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: GPL-2.0+
import asyncio
+import os
import aiohttp
import aiohttp.web
@@ -9,35 +10,36 @@ class WebServer(object):
"""A small web server is used to manage oauth requests. The user should point a browser
window at localhost. The program will generate redirects for the browser to point at
OAUTH servers when interactive authentication is required."""
- url = "http://localhost:8080/"
+ url = "http://127.0.0.1:8080/"
runner = None
def __init__(self):
self.auth_redirs = {}
self.web_app = aiohttp.web.Application()
self.web_app.router.add_get("/", self._start)
- self.web_app.router.add_get("/oauth2/msal", self._oauth2_msal)
+ self.web_app.router.add_get("/oauth2/msal", self._oauth2_redirect)
+ self.web_app.router.add_get("/oauth2/gmail", self._oauth2_redirect)
async def go(self):
self.runner = aiohttp.web.AppRunner(self.web_app)
await self.runner.setup()
- site = aiohttp.web.TCPSite(self.runner, 'localhost', 8080)
+ site = aiohttp.web.TCPSite(self.runner, '127.0.0.1', 8080)
await site.start()
async def close(self):
if self.runner:
await self.runner.cleanup()
- async def auth_redir(self, url, state):
+ async def auth_redir(self, url: str, state: str, redir_url: str):
"""Call as part of an OAUTH flow to hand the URL off to interactive browser
based authentication. The flow will resume when the OAUTH server
redirects back to the localhost server. The final query parameters
will be returned by this function"""
queue = asyncio.Queue()
- self.auth_redirs[state] = (url, queue)
+ self.auth_redirs[state] = (url, queue, redir_url)
return await queue.get()
- def _start(self, request):
+ def _start(self, request: aiohttp.web.Request):
"""Feed redirects to the web browser until all authing is done. FIXME: Some
fancy java script should be used to fetch new interactive auth
requests"""
@@ -45,11 +47,17 @@ class WebServer(object):
raise aiohttp.web.HTTPFound(I[0])
return aiohttp.web.Response(text="Authentication done")
- def _oauth2_msal(self, request):
+ def _oauth2_redirect(self, request: aiohttp.web.Request):
"""Use for the Azure AD authentication response redirection"""
- state = request.query["state"]
+ state = request.query.get("state", None)
+ if state is None:
+ raise aiohttp.web.HTTPBadRequest(text="No state parameter")
try:
- queue = self.auth_redirs[state][1]
+ _, queue, redir_url = self.auth_redirs[state]
+ # RFC8252 8.10
+ if redir_url != self.url[:-1] + request.path:
+ raise aiohttp.web.HTTPBadRequest(
+ text="Invalid redirection path")
del self.auth_redirs[state]
queue.put_nowait(request.query)
except KeyError:
diff --git a/cloud_mdir_sync/office365.py b/cloud_mdir_sync/office365.py
index 0b77de4..a13c7bb 100644
--- a/cloud_mdir_sync/office365.py
+++ b/cloud_mdir_sync/office365.py
@@ -161,7 +161,7 @@ class GraphAPI(object):
f"Goto {self.cfg.web_app.url} in a web browser to authenticate"
)
webbrowser.open(url)
- q = await self.cfg.web_app.auth_redir(url, state)
+ q = await self.cfg.web_app.auth_redir(url, state, redirect_url)
code = q["code"]
try:
diff --git a/setup.py b/setup.py
index 8d77bb2..cfa9f50 100755
--- a/setup.py
+++ b/setup.py
@@ -33,8 +33,10 @@ setup(
'cryptography>=2.8',
'keyring>=21',
'msal>=1.0',
+ 'oauthlib>=3.1',
'pyinotify>=0.9.6',
'requests>=2.18',
+ 'requests_oauthlib>=1.3',
],
include_package_data=True,
zip_safe=False)