diff --git a/src/oidc-test/auth_utils.py b/src/oidc-test/auth_utils.py index a322979..668bfd8 100644 --- a/src/oidc-test/auth_utils.py +++ b/src/oidc-test/auth_utils.py @@ -1,16 +1,16 @@ from typing import Union from functools import wraps -from fastapi import HTTPException, Request +from fastapi import HTTPException, Request, status from .models import User +from .database import db def get_current_user(request: Request) -> User: - auth_data = request.session.get("user") - if auth_data is None: - raise HTTPException(401, "Not authorized") - return User(**auth_data) + if (user_sub := request.session.get("user_sub")) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED) + return db.get_user(user_sub) def get_current_user_or_none(request: Request) -> User | None: @@ -20,11 +20,7 @@ def get_current_user_or_none(request: Request) -> User | None: return None -def hasrole( - required_roles: Union[str, list[str]] = [], - roles_key: str = "roles", - realm: str | None = "realm_access", # Keycloak standard for realm defined roles -): +def hasrole(required_roles: Union[str, list[str]] = []): required_roles_set: set[str] if isinstance(required_roles, str): required_roles_set = set([required_roles]) @@ -39,18 +35,9 @@ def hasrole( 500, "Functions decorated with hasrole must have a request:Request argument", ) - if "user" not in request.session: - raise HTTPException(401, "Not authorized") - user = request.session["user"] - try: - if realm in user: - roles = user[realm][roles_key] - else: - roles = user[roles_key] - except KeyError: - raise HTTPException(401, "Not authorized") - if not any(required_roles_set.intersection(roles)): - raise HTTPException(401, "Not authorized") + user: User = get_current_user(request) + if not any(required_roles_set.intersection(user.roles_as_set)): + raise HTTPException(status.HTTP_401_UNAUTHORIZED) return await func(request, *args, **kwargs) return wrapper diff --git a/src/oidc-test/database.py b/src/oidc-test/database.py new file mode 100644 index 0000000..fb0b167 --- /dev/null +++ b/src/oidc-test/database.py @@ -0,0 +1,24 @@ +# Implement a fake in-memory database interface for demo purpose + +from authlib.integrations.starlette_client.apps import StarletteOAuth2App + +from .models import User + + +class Database: + users: dict[str, User] = {} + + # Last sessions for the user (key: users's subject id (sub)) + + async def add_user( + self, sub: str, user_info: dict, oidc_provider: StarletteOAuth2App + ) -> User: + user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider) + self.users[sub] = user + return user + + def get_user(self, sub: str) -> User: + return self.users[sub] + + +db = Database() diff --git a/src/oidc-test/main.py b/src/oidc-test/main.py index 4254277..832b101 100644 --- a/src/oidc-test/main.py +++ b/src/oidc-test/main.py @@ -11,6 +11,7 @@ from authlib.integrations.starlette_client import OAuth, OAuthError from .settings import settings from .models import User from .auth_utils import hasrole, get_current_user_or_none, get_current_user +from .database import db templates = Jinja2Templates("src/templates") @@ -25,13 +26,15 @@ app.add_middleware( secret_key=settings.secret_key, ) -# Add oidc providers from the settings +# Add oidc providers to authlib from the settings authlib_oauth = OAuth() for provider in settings.oidc.providers: authlib_oauth.register( name=provider.name, server_metadata_url=provider.provider_url, - client_kwargs={"scope": "openid email offline_access profile roles"}, + client_kwargs={ + "scope": "openid email offline_access profile roles", + }, client_id=provider.client_id, client_secret=provider.client_secret, # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) @@ -40,38 +43,54 @@ for provider in settings.oidc.providers: @app.get("/login") async def login(request: Request, provider: str) -> RedirectResponse: - redirect_uri = request.url_for("auth", provider=provider) + redirect_uri = request.url_for("auth", oidc_provider_id=provider) try: provider_: StarletteOAuth2App = getattr(authlib_oauth, provider) except AttributeError: - raise HTTPException(500, "No such provider") + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") try: return await provider_.authorize_redirect(request, redirect_uri) except HTTPError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") -@app.get("/auth/{provider}") -async def auth(request: Request, provider: str) -> RedirectResponse: +@app.get("/auth/{oidc_provider_id}") +async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: try: - provider_: StarletteOAuth2App = getattr(authlib_oauth, provider) + oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) except AttributeError: - raise HTTPException(500, "No such provider") + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") try: - token = await provider_.authorize_access_token(request) + token = await oidc_provider.authorize_access_token(request) except OAuthError as error: - raise HTTPException(status_code=401, detail=error.error) - user = token.get("userinfo") - if user: - request.session["user"] = dict(user) + raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) + # Remember the oidc_provider in the session + request.session["oidc_provider_id"] = oidc_provider_id + # + # One could process the full decoded token which contains extra information + # eg for updates. Here we are only interested in roles + # + if userinfo := token.get("userinfo"): + # sub given by oidc provider + sub = userinfo["sub"] + # Build and remember the user in the session + request.session["user_sub"] = sub + # Store the user in the database + await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider) return RedirectResponse(url="/") else: + # Not sure if it's correct to redirect to plain login (which is not implemented anyway) + # if no userinfo is provided return RedirectResponse(url="/login") @app.get("/logout") -async def logout(request: Request) -> RedirectResponse: - request.session.pop("user", None) +async def logout( + request: Request, user: Annotated[User, Depends(get_current_user_or_none)] +) -> RedirectResponse: + # TODO: logout from oidc_provider + # await user.oidc_provider.logout_redirect() + request.session.pop("user_sub", None) return RedirectResponse(url="/") @@ -84,7 +103,6 @@ async def home( context={ "settings": settings.model_dump(), "user": user, - "auth_data": request.session.get("user"), }, name="index.html", ) diff --git a/src/oidc-test/models.py b/src/oidc-test/models.py index 904fe43..2a2dd0e 100644 --- a/src/oidc-test/models.py +++ b/src/oidc-test/models.py @@ -1,11 +1,48 @@ -from pydantic import BaseModel, EmailStr, AnyHttpUrl +from functools import cached_property +from typing import Self + +from pydantic import BaseModel, EmailStr, AnyHttpUrl, Field, computed_field +from authlib.integrations.starlette_client.apps import StarletteOAuth2App # from app.models import User -class User(BaseModel, extra="ignore"): +class Role(BaseModel, extra="ignore"): + name: str + + +class UserBase(BaseModel, extra="ignore"): id: str | None = None name: str email: EmailStr | None = None picture: AnyHttpUrl | None = None + roles: list[Role] = [] + + +class User(UserBase): + class Config: + arbitrary_types_allowed = True + + sub: str = Field( + description="""subject id of the user given by the oidc provider, + also the key for the database 'table'""", + ) + userinfo: dict = {} + oidc_provider: StarletteOAuth2App | None = None + + @classmethod + def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self: + user = cls(**userinfo) + user.userinfo = userinfo + user.oidc_provider = oidc_provider + # Add roles if they are provided in the token + if raw_ra := userinfo.get("realm_access"): + if raw_roles := raw_ra.get("roles"): + user.roles = [Role(name=raw_role) for raw_role in raw_roles] + return user + + @computed_field + @cached_property + def roles_as_set(self) -> set[str]: + return set([role.name for role in self.roles]) diff --git a/src/oidc-test/settings.py b/src/oidc-test/settings.py index e188b82..6054799 100644 --- a/src/oidc-test/settings.py +++ b/src/oidc-test/settings.py @@ -15,7 +15,6 @@ class OIDCProvider(BaseModel): url: str = "" client_id: str = "" client_secret: str = "" - is_swagger: bool = False @computed_field @property @@ -33,13 +32,6 @@ class OIDCSettings(BaseModel): providers: list[OIDCProvider] = [] swagger_provider: str = "" - def get_swagger_provider(self) -> OIDCProvider: - for provider in self.providers: - if provider.is_swagger: - return provider - else: - raise UserWarning("Please define a provider for Swagger with id_swagger") - class Settings(BaseSettings): """Settings wil be read from an .env file""" diff --git a/src/templates/index.html b/src/templates/index.html index 2030651..5f1ff42 100644 --- a/src/templates/index.html +++ b/src/templates/index.html @@ -119,6 +119,10 @@ .hasResponseStatus.status-401 { background-color: #ff000040; } + .role { + padding: 3px 6px; + background-color: #44228840; + }