88 lines
2.5 KiB
Python
88 lines
2.5 KiB
Python
"""Fake in-memory database interface for demo purposes."""
|
|
|
|
import logging
|
|
|
|
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
|
from authlib.oauth2.rfc6749 import OAuth2Token
|
|
|
|
from .settings import OIDCProvider, oidc_providers_settings
|
|
from .models import User, Role
|
|
|
|
# Logger named "oidc-test" rather than __name__; presumably shared with the
# rest of the application under that name — confirm.
logger = logging.getLogger("oidc-test")
|
|
|
|
|
|
class UserNotInDB(Exception):
    """Raised by ``Database.get_user`` when no user is stored for a subject id."""
|
|
|
|
|
|
class TokenNotInDb(Exception):
    """Raised by ``Database.get_token`` when no token is stored for a session id
    (or when no session id is given at all)."""
|
|
|
|
|
|
class Database:
    """Fake in-memory store for users and OAuth2 tokens (demo only).

    All data lives in plain dicts on the instance; nothing is persisted.
    """

    # Users keyed by their subject id (``sub``).
    users: dict[str, User]
    # OAuth2 tokens keyed by the auth provider's session id (``sid``).
    tokens: dict[str, OAuth2Token]

    def __init__(self) -> None:
        # Create the dicts per instance: class-level ``{}`` defaults would be
        # one shared object mutated by every Database instance.
        self.users = {}
        self.tokens = {}

    @staticmethod
    def _roles_from_token(access_token_decoded: dict, client_id: str | None) -> list[str]:
        """Collect role names from a decoded access token, sorted.

        Reads ``resource_access[client_id]["roles"]`` and
        ``realm_access["roles"]``. Either claim may be missing (ignored), and
        either may hold a single string or a collection of strings.
        """
        roles: set[str] = set()

        def collect(value) -> None:
            # A bare string is one role; ``set.update`` on a str would add
            # each of its characters as a separate "role".
            if isinstance(value, str):
                roles.add(value)
            else:
                roles.update(value)

        try:
            collect(access_token_decoded["resource_access"][client_id]["roles"])
        except KeyError:
            pass  # Token carries no client roles: nothing to add.
        try:
            collect(access_token_decoded["realm_access"]["roles"])
        except KeyError:
            pass  # Token carries no realm roles: nothing to add.
        # Sort so the order does not depend on set iteration order, which for
        # str varies between processes (hash randomization).
        return sorted(roles)

    async def add_user(
        self,
        sub: str,
        user_info: dict,
        oidc_provider: StarletteOAuth2App,
        access_token: str,
        access_token_decoded: dict | None = None,
    ) -> User:
        """Build a ``User`` from OIDC userinfo and store it under ``sub``.

        Args:
            sub: the user's subject id; used as the storage key. An existing
                entry with the same key is replaced.
            user_info: userinfo claims; passed as keyword arguments to ``User``
                and also kept verbatim on ``user.userinfo``.
            oidc_provider: the authlib client the user authenticated with.
            access_token: the raw access token.
            access_token_decoded: already-decoded access token claims. When
                omitted, the token is decoded with the settings registered in
                ``oidc_providers_settings`` under ``oidc_provider.name``.

        Returns:
            The stored user, with ``roles`` taken from the token's
            ``resource_access[<client_id>].roles`` and ``realm_access.roles``.

        Raises:
            ValueError: if the token must be decoded but the provider has no name.
        """
        if access_token_decoded is None:
            # Explicit check rather than ``assert``, which ``python -O`` strips.
            if oidc_provider.name is None:
                raise ValueError("Cannot decode access token: OIDC provider has no name")
            oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
            access_token_decoded = oidc_provider_settings.decode(access_token)

        user = User(**user_info)
        user.userinfo = user_info
        user.oidc_provider = oidc_provider
        user.access_token = access_token
        user.access_token_decoded = access_token_decoded
        user.roles = [
            Role(name=role_name)
            for role_name in self._roles_from_token(
                access_token_decoded, oidc_provider.client_id
            )
        ]
        self.users[sub] = user
        return user

    async def get_user(self, sub: str) -> User:
        """Return the user stored under subject id ``sub``.

        Raises:
            UserNotInDB: if no user is stored under ``sub``.
        """
        try:
            return self.users[sub]
        except KeyError:
            raise UserNotInDB(sub) from None

    async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None:
        """Store a token keyed by the auth provider's session id (``sid``).

        The ``sid`` is read from ``token["userinfo"]`` (presumably the parsed
        id_token claims authlib attaches — confirm). A token already stored
        under the same ``sid`` is replaced.

        ``oidc_provider_settings`` is not used here; it is kept for interface
        compatibility with callers.

        Raises:
            KeyError: if the token has no ``userinfo`` or no ``sid`` in it.
        """
        sid = token["userinfo"]["sid"]
        self.tokens[sid] = token

    async def get_token(
        self,
        oidc_provider_settings: OIDCProvider,
        sid: str | None,
    ) -> OAuth2Token:
        """Return the token stored under session id ``sid``.

        ``oidc_provider_settings`` is not used here; it is kept for interface
        compatibility with callers.

        Raises:
            TokenNotInDb: if ``sid`` is None or no token is stored under it.
        """
        if sid is None:
            raise TokenNotInDb("no session id given")
        try:
            return self.tokens[sid]
        except KeyError:
            raise TokenNotInDb(sid) from None
|
|
|
|
|
# Module-level instance: every module that imports ``db`` shares this same
# in-memory store.
db = Database()
|