diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 0c8dcc7..281511d 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -13,13 +13,10 @@ from authlib.oauth2.auth import OAuth2Token from .models import User from .database import TokenNotInDb, db, UserNotInDB -from .settings import settings, OIDCProvider +from .settings import settings, OIDCProvider, oidc_providers_settings -logger = logging.getLogger(__name__) +logger = logging.getLogger("oidc-test") -oidc_providers_settings: dict[str, OIDCProvider] = dict( - [(provider.id, provider) for provider in settings.oidc.providers] -) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -36,19 +33,16 @@ async def fetch_token(name, request): async def update_token(name, token, refresh_token=None, access_token=None): - breakpoint() - item = await db.get_token(token["id_token"]) - if refresh_token: - item = OAuth2Token.find(name=name, refresh_token=refresh_token) - elif access_token: - item = OAuth2Token.find(name=name, access_token=access_token) - else: - return + oidc_provider_settings = oidc_providers_settings[name] + sid: str = oidc_provider_settings.decode(token["id_token"])["sid"] + item = await db.get_token(oidc_provider_settings, sid) # update old token - item.access_token = token["access_token"] - item.refresh_token = token.get("refresh_token") - item.expires_at = token["expires_at"] - item.save() + item["access_token"] = token.get("access_token") + item["refresh_token"] = token.get("refresh_token") + item["expires_at"] = token["expires_at"] + logger.info(f"Token {sid} refreshed") + # It's a fake db and only in memory, so there's nothing to save + # await item.save() authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) @@ -138,8 +132,17 @@ async def get_token(request: Request) -> OAuth2Token: """Return the token from a request object, from the session. It can be used in Depends()""" try: - return await db.get_token(request.session.get("token")) - except TokenNotInDb: + oidc_provider_settings = oidc_providers_settings[ + request.session.get("oidc_provider_id", "") + ] + except KeyError: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider") + try: + return await db.get_token( + oidc_provider_settings, + request.session.get("sid"), + ) + except (TokenNotInDb, InvalidKeyError): raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") @@ -190,14 +193,16 @@ async def get_user_from_token( token: Annotated[str, Depends(oauth2_scheme)], request: Request, ) -> User: - if (auth_provider_id := request.headers.get("auth_provider")) is None: + try: + auth_provider_id = request.headers["auth_provider"] + except KeyError: raise HTTPException( status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field", ) - if ( - auth_provider_settings := oidc_providers_settings.get(auth_provider_id) - ) is None: + try: + auth_provider_settings = oidc_providers_settings[auth_provider_id] + except KeyError: raise HTTPException( status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" ) @@ -216,7 +221,9 @@ async def get_user_from_token( logger.info("Cannot decode token, see below") logger.exception(err) raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token") - if (user_id := payload.get("sub")) is None: + try: + user_id = payload["sub"] + except KeyError: raise HTTPException( status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found" ) @@ -232,7 +239,6 @@ async def get_user_from_token( sub=payload["sub"], user_info=payload, oidc_provider=getattr(authlib_oauth, auth_provider_id), - user_info_from_endpoint={}, access_token=token, ) return user diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 0a30e9c..360ef11 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -3,11 +3,12 @@ import logging from authlib.integrations.starlette_client.apps import StarletteOAuth2App - -from .models import User, Role from authlib.oauth2.rfc6749 import OAuth2Token -logger = logging.getLogger(__name__) +from .settings import OIDCProvider, oidc_providers_settings +from .models import User, Role + +logger = logging.getLogger("oidc-test") class UserNotInDB(Exception): @@ -29,20 +30,34 @@ class Database: sub: str, user_info: dict, oidc_provider: StarletteOAuth2App, - user_info_from_endpoint: dict, access_token: str, + access_token_decoded: dict | None = None, ) -> User: - user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider) + if access_token_decoded is None: + assert oidc_provider.name is not None + oidc_provider_settings = oidc_providers_settings[oidc_provider.name] + access_token_decoded = oidc_provider_settings.decode(access_token) + user = User(**user_info) + user.userinfo = user_info + user.oidc_provider = oidc_provider user.access_token = access_token + user.access_token_decoded = access_token_decoded + # Add roles provided in the access token + roles = set() try: - raw_roles = user_info_from_endpoint["resource_access"][ - oidc_provider.client_id - ]["roles"] - except Exception as err: - logger.debug(f"Cannot read additional roles: {err}") - raw_roles = [] - for raw_role in raw_roles: - user.roles.append(Role(name=raw_role)) + r = access_token_decoded["resource_access"][oidc_provider.client_id]["roles"] + roles.update(r) + except KeyError: + pass + try: + r = access_token_decoded["realm_access"]["roles"] + if isinstance(r, str): + roles.add(r) + else: + roles.update(r) + except KeyError: + pass + user.roles = [Role(name=role_name) for role_name in roles] self.users[sub] = user return user @@ -51,14 +66,21 @@ class Database: raise UserNotInDB return self.users[sub] - async def add_token(self, token: OAuth2Token, user: User) -> None: - self.tokens[token["id_token"]] = token + async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None: + """Store a token using as key the sid (auth provider's session id) + in the id_token""" + sid = token["userinfo"]["sid"] + self.tokens[sid] = token - async def get_token(self, id_token: str | None) -> OAuth2Token: - if id_token is None: + async def get_token( + self, + oidc_provider_settings: OIDCProvider, + sid: str | None, + ) -> OAuth2Token: + if sid is None: raise TokenNotInDb try: - return self.tokens[id_token] + return self.tokens[sid] except KeyError: raise TokenNotInDb diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index e9ba5b1..60482bd 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.templating import Jinja2Templates from fastapi.middleware.cors import CORSMiddleware -from jwt import InvalidKeyError, InvalidTokenError +from jwt import InvalidTokenError, PyJWTError from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.base_client import OAuthError @@ -26,7 +26,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair -from .settings import settings +from .settings import settings, oidc_providers_settings from .models import User from .auth_utils import ( get_oidc_provider, @@ -37,14 +37,13 @@ from .auth_utils import ( get_user_from_token, authlib_oauth, get_token, - oidc_providers_settings, get_providers_info, ) from .auth_misc import pretty_details from .database import TokenNotInDb, db from .resource_server import get_resource -logger = logging.getLogger("uvicorn.error") +logger = logging.getLogger("oidc-test") templates = Jinja2Templates(Path(__file__).parent / "templates") @@ -189,43 +188,28 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: request.session["oidc_provider_id"] = oidc_provider_id # User id (sub) given by oidc provider sub = userinfo["sub"] - # Get additional data from userinfo endpoint - try: - user_info_from_endpoint = await oidc_provider.userinfo( - token=token, follow_redirects=True - ) - except Exception as err: - logger.warn(f"Cannot get userinfo from endpoint: {err}") - user_info_from_endpoint = {} # Build and remember the user in the session request.session["user_sub"] = sub - # Verify the token's signature and validity + # Store the user in the database, which also verifies the token validity and signature try: - oidc_provider_settings = oidc_providers_settings[oidc_provider_id] - oidc_provider_settings.decode(token["access_token"]) - except InvalidKeyError: + user = await db.add_user( + sub, + user_info=userinfo, + oidc_provider=oidc_provider, + access_token=token["access_token"], + ) + except PyJWTError as err: raise HTTPException( status.HTTP_401_UNAUTHORIZED, - detail="Token invalid key / signature", + detail=f"Token invalid: {err.__class__.__name__}", ) - except Exception as err: - logger.exception(err) - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - detail="Cannot decode token or verify its signature", - ) - # Store the user in the database - user = await db.add_user( - sub, - user_info=userinfo, - oidc_provider=oidc_provider, - user_info_from_endpoint=user_info_from_endpoint, - access_token=token["access_token"], - ) - # Add the id_token to the session - request.session["token"] = token["id_token"] + assert isinstance(user, User) + # Add the provider session id to the session + request.session["sid"] = userinfo["sid"] # Add the token to the db because it is used for logout - await db.add_token(token, user) + assert oidc_provider.name is not None + oidc_provider_settings = oidc_providers_settings[oidc_provider.name] + await db.add_token(oidc_provider_settings, token) # Send the user to the home: (s)he is authenticated return RedirectResponse(url=request.url_for("home")) else: @@ -268,8 +252,14 @@ async def logout( ) return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") + oidc_provider_settings = oidc_providers_settings.get( + request.session.get("oidc_provider_id", "") + ) + assert oidc_provider_settings is not None try: - token = await db.get_token(request.session.pop("token", None)) + token = await db.get_token( + oidc_provider_settings, request.session.pop("sid", None) + ) except TokenNotInDb: logger.warn("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 4b1c064..fc0dba7 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -1,6 +1,6 @@ import logging from functools import cached_property -from typing import Self +from typing import Self, Any from pydantic import ( computed_field, @@ -11,7 +11,7 @@ from pydantic import ( from authlib.integrations.starlette_client.apps import StarletteOAuth2App from sqlmodel import SQLModel, Field -logger = logging.getLogger(__name__) +logger = logging.getLogger("oidc-test") class Role(SQLModel, extra="ignore"): @@ -36,19 +36,9 @@ class User(UserBase): ) userinfo: dict = {} access_token: str | None = None + access_token_decoded: dict[str, Any] | None = None oidc_provider: StarletteOAuth2App | None = None - @classmethod - def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self: - user = cls(**userinfo) - user.userinfo = userinfo - user.oidc_provider = oidc_provider - # Add roles if they are provided in the token - if raw_ra := userinfo.get("realm_access"): - if raw_roles := raw_ra.get("roles"): - user.roles = [Role(name=raw_role) for raw_role in raw_roles] - return user - @computed_field @cached_property def roles_as_set(self) -> set[str]: @@ -68,7 +58,7 @@ class User(UserBase): assert self.access_token is not None assert self.oidc_provider is not None assert self.oidc_provider.name is not None - from .auth_utils import oidc_providers_settings + from .settings import oidc_providers_settings return oidc_providers_settings[self.oidc_provider.name].decode( self.access_token diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 635a91b..0d90533 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -8,7 +8,7 @@ from starlette.status import HTTP_401_UNAUTHORIZED from .models import User -logger = logging.getLogger(__name__) +logger = logging.getLogger("oidc-test") async def get_resource(resource_id: str, user: User) -> dict: diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 4d08ada..2544bd7 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -1,8 +1,9 @@ from os import environ import string import random -from typing import Type, Tuple +from typing import Type, Tuple, Any from pathlib import Path +import logging from jwt import decode from pydantic import BaseModel, computed_field, AnyUrl @@ -16,6 +17,8 @@ from starlette.requests import Request from .models import User +logger = logging.getLogger("oidc-test") + class Resource(BaseModel): """A resource with an URL that can be accessed with an OAuth2 access token""" @@ -86,14 +89,27 @@ class OIDCProvider(BaseModel): -----END PUBLIC KEY----- """ - def decode(self, token: str) -> dict: + def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]: """Decode the token with signature check""" + decoded = decode( + token, + self.get_public_key(), + algorithms=[self.signature_alg], + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": False, + "verify_aud": False, + }, # not settings.insecure.skip_verify_signature}, + ) + logger.debug(str(decoded)) return decode( token, self.get_public_key(), algorithms=[self.signature_alg], - audience=["oidc-test", "oidc-test-web"], - options={"verify_signature": not settings.insecure.skip_verify_signature}, + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": verify_signature, + }, # not settings.insecure.skip_verify_signature}, ) @@ -156,3 +172,8 @@ class Settings(BaseSettings): settings = Settings() + + +oidc_providers_settings: dict[str, OIDCProvider] = dict( + [(provider.id, provider) for provider in settings.oidc.providers] +)