oidc-fastapi-test/src/oidc_test/database.py
phil ee8ba3d2df
All checks were successful
/ build (push) Successful in 5s
/ test (push) Successful in 5s
Get roles from access token, remove user info inspection, refactorings
2025-02-06 13:30:35 +01:00

88 lines
2.5 KiB
Python

"""Fake in-memory database interface for demo purpose"""
import logging
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.oauth2.rfc6749 import OAuth2Token
from .settings import OIDCProvider, oidc_providers_settings
from .models import User, Role
# Module logger, registered under the name "oidc-test"
logger = logging.getLogger("oidc-test")
class UserNotInDB(Exception):
    """Raised when no user is stored under the requested subject id."""
class TokenNotInDb(Exception):
    """Raised when no token can be found for a session id (or none was given)."""
class Database:
    """In-memory store of users and OAuth2 tokens.

    Both dictionaries are class attributes, so every ``Database`` instance
    reads and writes the same data.
    """

    # Users, keyed by the subject id (sub) given to add_user
    users: dict[str, User] = {}
    # Tokens, keyed by the auth provider's session id (sid), see add_token
    tokens: dict[str, OAuth2Token] = {}

    @staticmethod
    def _roles_at(claims, *path) -> set:
        """Return the role names stored under ``claims[path...]["roles"]``.

        A missing key, a ``None`` value or a non-mapping anywhere on the way
        yields an empty set. A single string is treated as one role name
        rather than as an iterable of characters.
        """
        node = claims
        try:
            for key in path:
                node = node[key]
            found = node["roles"]
        except (KeyError, TypeError):
            # Claim absent (KeyError) or not a mapping, e.g. null (TypeError)
            return set()
        if found is None:
            return set()
        if isinstance(found, str):
            return {found}
        return set(found)

    async def add_user(
        self,
        sub: str,
        user_info: dict,
        oidc_provider: StarletteOAuth2App,
        access_token: str,
        access_token_decoded: dict | None = None,
    ) -> User:
        """Build a user from its user info and access token, then store it.

        The user's roles are the union of the roles found in the decoded
        access token under ``resource_access.<client_id>.roles`` and
        ``realm_access.roles``. Any user already stored under ``sub`` is
        replaced.

        Args:
            sub: Subject id used as the storage key.
            user_info: User info claims, passed as keyword arguments to ``User``.
            oidc_provider: OAuth2 client the user authenticated with.
            access_token: Encoded access token.
            access_token_decoded: Decoded access token; when omitted, the
                token is decoded with the settings registered under the
                provider's name.

        Returns:
            The stored user.

        Raises:
            ValueError: The token must be decoded but the provider has no name.
            KeyError: No settings are registered under the provider's name.
        """
        if access_token_decoded is None:
            # Explicit check rather than assert, which is stripped under -O
            if oidc_provider.name is None:
                raise ValueError(
                    "Cannot decode the access token: the OIDC provider has no name"
                )
            provider_settings = oidc_providers_settings[oidc_provider.name]
            access_token_decoded = provider_settings.decode(access_token)

        user = User(**user_info)
        user.userinfo = user_info
        user.oidc_provider = oidc_provider
        user.access_token = access_token
        user.access_token_decoded = access_token_decoded

        # Add roles provided in the access token
        role_names = self._roles_at(
            access_token_decoded, "resource_access", oidc_provider.client_id
        ) | self._roles_at(access_token_decoded, "realm_access")
        # Sorted so the role order is stable between runs; key=str keeps the
        # sort safe even if a token carries non-string entries
        user.roles = [Role(name=name) for name in sorted(role_names, key=str)]

        self.users[sub] = user
        return user

    async def get_user(self, sub: str) -> User:
        """Return the user stored under the subject id ``sub``.

        Raises:
            UserNotInDB: No user is stored under ``sub``.
        """
        try:
            return self.users[sub]
        except KeyError:
            raise UserNotInDB from None

    async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None:
        """Store a token using as key the sid (auth provider's session id)
        in the id_token

        ``oidc_provider_settings`` is accepted but currently unused.

        Raises:
            KeyError: The token has no ``userinfo`` or no ``sid`` in it.
        """
        sid = token["userinfo"]["sid"]
        self.tokens[sid] = token

    async def get_token(
        self,
        oidc_provider_settings: OIDCProvider,
        sid: str | None,
    ) -> OAuth2Token:
        """Return the token stored under the session id ``sid``.

        ``oidc_provider_settings`` is accepted but currently unused.

        Raises:
            TokenNotInDb: ``sid`` is None or no token is stored under it.
        """
        if sid is None:
            raise TokenNotInDb
        try:
            return self.tokens[sid]
        except KeyError:
            raise TokenNotInDb from None
# Module-level instance of the in-memory database
db = Database()