aboutsummaryrefslogtreecommitdiffstats
path: root/cloud_mdir_sync/oauth.py
diff options
context:
space:
mode:
Diffstat (limited to 'cloud_mdir_sync/oauth.py')
-rw-r--r--cloud_mdir_sync/oauth.py132
1 files changed, 124 insertions, 8 deletions
diff --git a/cloud_mdir_sync/oauth.py b/cloud_mdir_sync/oauth.py
index 55f3f31..0ab30ae 100644
--- a/cloud_mdir_sync/oauth.py
+++ b/cloud_mdir_sync/oauth.py
@@ -1,23 +1,21 @@
# SPDX-License-Identifier: GPL-2.0+
import asyncio
+import base64
+import hashlib
import os
+import secrets
from abc import abstractmethod
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Dict, List, Optional
import aiohttp
import aiohttp.web
+import oauthlib
+import oauthlib.oauth2
if TYPE_CHECKING:
from . import config
-def check_scopes(token, required_scopes: List[str]) -> bool:
- if token is None:
- return False
- tscopes = set(token.get("scope", []))
- return set(required_scopes).issubset(tscopes)
-
-
class Account(object):
"""An OAUTH2 account"""
oauth_smtp = False
@@ -91,3 +89,121 @@ class WebServer(object):
for I in self.auth_redirs.values():
raise aiohttp.web.HTTPFound(I[0])
raise aiohttp.web.HTTPFound(self.url)
+
+
+class NativePublicApplicationClient(oauthlib.oauth2.WebApplicationClient):
+ """Amazingly oauthlib doesn't include client side PCKE support
+ Hack it into the WebApplicationClient"""
+ def __init__(self, client_id):
+ super().__init__(client_id)
+
+ def _code_challenge_method_s256(self, verifier):
+ return base64.urlsafe_b64encode(
+ hashlib.sha256(verifier.encode()).digest()).decode().rstrip('=')
+
+ def prepare_request_uri(self,
+ authority_uri,
+ redirect_uri,
+ scope=None,
+ state=None,
+ **kwargs):
+ self.verifier = secrets.token_urlsafe(96)
+ return super().prepare_request_uri(
+ authority_uri,
+ redirect_uri=redirect_uri,
+ scope=scope,
+ state=state,
+ code_challenge=self._code_challenge_method_s256(self.verifier),
+ code_challenge_method="S256",
+ **kwargs)
+
+ def prepare_request_body(self,
+ code=None,
+ redirect_uri=None,
+ body='',
+ include_client_id=True,
+ **kwargs):
+ return super().prepare_request_body(
+ code=code,
+ redirect_uri=redirect_uri,
+ body=body,
+ include_client_id=include_client_id,
+ code_verifier=self.verifier,
+ **kwargs)
+
+
+class OAuth2Session(object):
+ """Helper to execute OAUTH JSON queries using asyncio http"""
+ def __init__(self,
+ client_id: str,
+ client: oauthlib.oauth2.rfc6749.clients.base.Client,
+ redirect_uri: str,
+ token: Optional[Dict],
+ strict_scopes=True):
+ """strict_scopes can be True if the server always returns only the
+ scopes that were requested"""
+ self._client = client
+ self.redirect_uri = redirect_uri
+ self.strict_scopes = strict_scopes
+
+ if token is not None:
+ self._client.token = token
+ self._client.populate_token_attributes(token)
+
+ def authorization_url(self, url: str, state: str, scopes: List[str], **kwargs) -> str:
+ return self._client.prepare_request_uri(url,
+ redirect_uri=self.redirect_uri,
+ scope=scopes,
+ state=state,
+ **kwargs)
+
+ async def fetch_token(self,
+ session: aiohttp.ClientSession,
+ token_url: str,
+ include_client_id: bool,
+ scopes: List[str],
+ code: str,
+ client_secret: Optional[str] = None) -> Dict:
+ """Complete the exchange started with authorization_url"""
+ body = self._client.prepare_request_body(
+ code=code,
+ redirect_uri=self.redirect_uri,
+ include_client_id=include_client_id,
+ scope=scopes,
+ client_secret=client_secret)
+ async with session.post(
+ token_url,
+ data=dict(oauthlib.common.urldecode(body)),
+ headers={
+ "Accept": "application/json",
+ #"Content-Type":
+ #"application/x-www-form-urlencoded;charset=UTF-8",
+ }) as op:
+ self.token = self._client.parse_request_body_response(
+ await op.text(), scope=scopes if self.strict_scopes else None)
+ return self.token
+
+ async def refresh_token(self,
+ session: aiohttp.ClientSession,
+ token_url: str,
+ client_id: str,
+ scopes: List[str],
+ refresh_token: str,
+ client_secret: Optional[str] = None) -> Dict:
+ body = self._client.prepare_refresh_body(refresh_token=refresh_token,
+ scope=scopes,
+ client_id=client_id,
+ client_secret=client_secret)
+ async with session.post(
+ token_url,
+ data=dict(oauthlib.common.urldecode(body)),
+ headers={
+ "Accept": "application/json",
+ #"Content-Type":
+ #"application/x-www-form-urlencoded;charset=UTF-8",
+ }) as op:
+ self.token = self._client.parse_request_body_response(
+ await op.text(), scope=scopes if self.strict_scopes else None)
+ if not "refresh_token" in self.token:
+ self.token["refresh_token"] = refresh_token
+ return self.token