diff options
Diffstat (limited to 'cloud_mdir_sync/oauth.py')
-rw-r--r-- | cloud_mdir_sync/oauth.py | 132 |
1 files changed, 124 insertions, 8 deletions
diff --git a/cloud_mdir_sync/oauth.py b/cloud_mdir_sync/oauth.py index 55f3f31..0ab30ae 100644 --- a/cloud_mdir_sync/oauth.py +++ b/cloud_mdir_sync/oauth.py @@ -1,23 +1,21 @@ # SPDX-License-Identifier: GPL-2.0+ import asyncio +import base64 +import hashlib import os +import secrets from abc import abstractmethod -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, Dict, List, Optional import aiohttp import aiohttp.web +import oauthlib +import oauthlib.oauth2 if TYPE_CHECKING: from . import config -def check_scopes(token, required_scopes: List[str]) -> bool: - if token is None: - return False - tscopes = set(token.get("scope", [])) - return set(required_scopes).issubset(tscopes) - - class Account(object): """An OAUTH2 account""" oauth_smtp = False @@ -91,3 +89,121 @@ class WebServer(object): for I in self.auth_redirs.values(): raise aiohttp.web.HTTPFound(I[0]) raise aiohttp.web.HTTPFound(self.url) + + +class NativePublicApplicationClient(oauthlib.oauth2.WebApplicationClient): + """Amazingly oauthlib doesn't include client side PCKE support + Hack it into the WebApplicationClient""" + def __init__(self, client_id): + super().__init__(client_id) + + def _code_challenge_method_s256(self, verifier): + return base64.urlsafe_b64encode( + hashlib.sha256(verifier.encode()).digest()).decode().rstrip('=') + + def prepare_request_uri(self, + authority_uri, + redirect_uri, + scope=None, + state=None, + **kwargs): + self.verifier = secrets.token_urlsafe(96) + return super().prepare_request_uri( + authority_uri, + redirect_uri=redirect_uri, + scope=scope, + state=state, + code_challenge=self._code_challenge_method_s256(self.verifier), + code_challenge_method="S256", + **kwargs) + + def prepare_request_body(self, + code=None, + redirect_uri=None, + body='', + include_client_id=True, + **kwargs): + return super().prepare_request_body( + code=code, + redirect_uri=redirect_uri, + body=body, + include_client_id=include_client_id, + code_verifier=self.verifier, + **kwargs) + + +class OAuth2Session(object): + """Helper to execute OAUTH JSON queries using asyncio http""" + def __init__(self, + client_id: str, + client: oauthlib.oauth2.rfc6749.clients.base.Client, + redirect_uri: str, + token: Optional[Dict], + strict_scopes=True): + """strict_scopes can be True if the server always returns only the + scopes that were requested""" + self._client = client + self.redirect_uri = redirect_uri + self.strict_scopes = strict_scopes + + if token is not None: + self._client.token = token + self._client.populate_token_attributes(token) + + def authorization_url(self, url: str, state: str, scopes: List[str], **kwargs) -> str: + return self._client.prepare_request_uri(url, + redirect_uri=self.redirect_uri, + scope=scopes, + state=state, + **kwargs) + + async def fetch_token(self, + session: aiohttp.ClientSession, + token_url: str, + include_client_id: bool, + scopes: List[str], + code: str, + client_secret: Optional[str] = None) -> Dict: + """Complete the exchange started with authorization_url""" + body = self._client.prepare_request_body( + code=code, + redirect_uri=self.redirect_uri, + include_client_id=include_client_id, + scope=scopes, + client_secret=client_secret) + async with session.post( + token_url, + data=dict(oauthlib.common.urldecode(body)), + headers={ + "Accept": "application/json", + #"Content-Type": + #"application/x-www-form-urlencoded;charset=UTF-8", + }) as op: + self.token = self._client.parse_request_body_response( + await op.text(), scope=scopes if self.strict_scopes else None) + return self.token + + async def refresh_token(self, + session: aiohttp.ClientSession, + token_url: str, + client_id: str, + scopes: List[str], + refresh_token: str, + client_secret: Optional[str] = None) -> Dict: + body = self._client.prepare_refresh_body(refresh_token=refresh_token, + scope=scopes, + client_id=client_id, + client_secret=client_secret) + async with session.post( + token_url, + data=dict(oauthlib.common.urldecode(body)), + headers={ + "Accept": "application/json", + #"Content-Type": + #"application/x-www-form-urlencoded;charset=UTF-8", + }) as op: + self.token = self._client.parse_request_body_response( + await op.text(), scope=scopes if self.strict_scopes else None) + if not "refresh_token" in self.token: + self.token["refresh_token"] = refresh_token + return self.token |