From d1877aaf5a791204dd7222f50fc33b5f802c7e94 Mon Sep 17 00:00:00 2001 From: Jason Gunthorpe Date: Fri, 7 Feb 2020 11:52:43 -0400 Subject: Don't leave asyncio tasks running unexpectedly All cases where gather is called intend that the tasks will complete successfully or all cancel at the first error. Add a little wrapper to achieve this. Signed-off-by: Jason Gunthorpe --- cloud_mdir_sync/gmail.py | 5 +++-- cloud_mdir_sync/main.py | 17 +++++++++-------- cloud_mdir_sync/office365.py | 7 ++++--- cloud_mdir_sync/util.py | 13 +++++++++++++ 4 files changed, 29 insertions(+), 13 deletions(-) diff --git a/cloud_mdir_sync/gmail.py b/cloud_mdir_sync/gmail.py index 49468d4..949798b 100644 --- a/cloud_mdir_sync/gmail.py +++ b/cloud_mdir_sync/gmail.py @@ -15,6 +15,7 @@ import oauthlib import requests_oauthlib from . import config, mailbox, messages, util +from .util import asyncio_complete class NativePublicApplicationClient(oauthlib.oauth2.WebApplicationClient): @@ -424,7 +425,7 @@ class GMailMailbox(mailbox.Mailbox): msgs.append(msg) if todo: start_history_id = await todo[0] - await asyncio.gather(*todo) + await asyncio_complete(*todo) return (msgs, start_history_id) @@ -505,7 +506,7 @@ class GMailMailbox(mailbox.Mailbox): msg.received_time = omsg.received_time assert self.msgdb.have_content(msg) msgs.append(msg) - await asyncio.gather(*todo) + await asyncio_complete(*todo) return (msgs, next_history_id) @util.log_progress(lambda self: f"Updating Message List for {self.name}", diff --git a/cloud_mdir_sync/main.py b/cloud_mdir_sync/main.py index af179cb..d810956 100644 --- a/cloud_mdir_sync/main.py +++ b/cloud_mdir_sync/main.py @@ -9,6 +9,7 @@ import aiohttp import pyinotify from . import config, mailbox, messages, oauth +from .util import asyncio_complete def route_cloud_messages(cfg: config.Config) -> messages.MBoxDict_Type: @@ -48,7 +49,7 @@ async def update_cloud_from_local(cfg: config.Config, if lmsg is None and offline_mode: continue msgs_by_cloud[cloud_msg.mailbox][ch] = (lmsg, cloud_msg) - await asyncio.gather(*( + await asyncio_complete(*( mbox.merge_content(msgdict) for mbox, msgdict in msgs_by_cloud.items() if not mbox.same_messages(msgdict, tuple_form=True))) @@ -59,15 +60,15 @@ async def synchronize_mail(cfg: config.Config): try: await cfg.web_app.go() - await asyncio.gather(*(mbox.setup_mbox() - for mbox in cfg.all_mboxes())) + await asyncio_complete(*(mbox.setup_mbox() + for mbox in cfg.all_mboxes())) msgs = None while True: try: - await asyncio.gather(*(mbox.update_message_list() - for mbox in cfg.all_mboxes() - if mbox.need_update)) + await asyncio_complete(*(mbox.update_message_list() + for mbox in cfg.all_mboxes() + if mbox.need_update)) nmsgs = route_cloud_messages(cfg) if msgs is not None: @@ -92,8 +93,8 @@ async def synchronize_mail(cfg: config.Config): cfg.msgdb.cleanup_msgs(msgs) cfg.logger.debug("Changed event, looping") finally: - await asyncio.gather(*(domain.close() - for domain in cfg.domains.values())) + await asyncio_complete(*(domain.close() + for domain in cfg.domains.values())) cfg.domains = {} await cfg.web_app.close() diff --git a/cloud_mdir_sync/office365.py b/cloud_mdir_sync/office365.py index f96ef29..9cfaf0d 100644 --- a/cloud_mdir_sync/office365.py +++ b/cloud_mdir_sync/office365.py @@ -13,6 +13,7 @@ import aiohttp import requests from . import config, mailbox, messages, util +from .util import asyncio_complete def _retry_protect(func): @@ -488,7 +489,7 @@ class O365Mailbox(mailbox.Mailbox): asyncio.create_task(self._fetch_message(msg))) msgs.append(msg) - await asyncio.gather(*todo) + await asyncio_complete(*todo) res = {} for msg in msgs: @@ -652,8 +653,8 @@ class O365Mailbox(mailbox.Mailbox): body={"destinationId": "deleteditems"})) del self.messages[ch] - await asyncio.gather(*todo_flags) + await asyncio_complete(*todo_flags) # Delete must be temporally after move as move will change the mailbox # id. - await asyncio.gather(*todo_del) + await asyncio_complete(*todo_del) self.last_merge_len = len(todo_flags) + len(todo_del) diff --git a/cloud_mdir_sync/util.py b/cloud_mdir_sync/util.py index 7799356..df8ce59 100644 --- a/cloud_mdir_sync/util.py +++ b/cloud_mdir_sync/util.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: GPL-2.0+ +import asyncio import contextlib import functools import inspect @@ -68,3 +69,15 @@ def sizeof_fmt(num, suffix='B'): def pj(json_dict): print(json.dumps(json_dict, indent=4, sort_keys=True)) + + +async def asyncio_complete(*awo_list): + """This is like asyncio.gather but it always ensures that the list of + awaitable objects is completed upon return. For instance if an exception + is thrown then all the awaitables are canceled""" + g = asyncio.gather(*awo_list) + try: + await g + finally: + g.cancel() + await asyncio.gather(*awo_list, return_exceptions=True) -- cgit v1.2.3