Many updates, too log to list

This commit is contained in:
phil 2025-01-09 19:39:20 +01:00
parent 5c3d54c3f2
commit 23f180e521
6 changed files with 0 additions and 0 deletions

View file

View file

@ -0,0 +1,45 @@
from typing import Union
from functools import wraps
from fastapi import HTTPException, Request, status
from .models import User
from .database import db
def get_current_user(request: Request) -> User:
if (user_sub := request.session.get("user_sub")) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
return db.get_user(user_sub)
def get_current_user_or_none(request: Request) -> User | None:
try:
return get_current_user(request)
except HTTPException:
return None
def hasrole(required_roles: Union[str, list[str]] = []):
required_roles_set: set[str]
if isinstance(required_roles, str):
required_roles_set = set([required_roles])
else:
required_roles_set = set(required_roles)
def decorator(func):
@wraps(func)
async def wrapper(request=None, *args, **kwargs):
if request is None:
raise HTTPException(
500,
"Functions decorated with hasrole must have a request:Request argument",
)
user: User = get_current_user(request)
if not any(required_roles_set.intersection(user.roles_as_set)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
return await func(request, *args, **kwargs)
return wrapper
return decorator

24
src/oidc_test/database.py Normal file
View file

@ -0,0 +1,24 @@
# Implement a fake in-memory database interface for demo purpose
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from .models import User
class Database:
users: dict[str, User] = {}
# Last sessions for the user (key: users's subject id (sub))
async def add_user(
self, sub: str, user_info: dict, oidc_provider: StarletteOAuth2App
) -> User:
user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider)
self.users[sub] = user
return user
def get_user(self, sub: str) -> User:
return self.users[sub]
db = Database()

145
src/oidc_test/main.py Normal file
View file

@ -0,0 +1,145 @@
from typing import Annotated
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.starlette_client import OAuth, OAuthError
from .settings import settings
from .models import User
from .auth_utils import hasrole, get_current_user_or_none, get_current_user
from .database import db
templates = Jinja2Templates("src/templates")
app = FastAPI(
title="OIDC auth test",
)
# SessionMiddleware is required by authlib
app.add_middleware(
SessionMiddleware,
secret_key=settings.secret_key,
)
# Add oidc providers to authlib from the settings
authlib_oauth = OAuth()
for provider in settings.oidc.providers:
authlib_oauth.register(
name=provider.name,
server_metadata_url=provider.provider_url,
client_kwargs={
"scope": "openid email offline_access profile roles",
},
client_id=provider.client_id,
client_secret=provider.client_secret,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
)
@app.get("/login")
async def login(request: Request, provider: str) -> RedirectResponse:
redirect_uri = request.url_for("auth", oidc_provider_id=provider)
try:
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
except AttributeError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try:
return await provider_.authorize_redirect(request, redirect_uri)
except HTTPError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
@app.get("/auth/{oidc_provider_id}")
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
try:
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
except AttributeError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try:
token = await oidc_provider.authorize_access_token(request)
except OAuthError as error:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
# Remember the oidc_provider in the session
request.session["oidc_provider_id"] = oidc_provider_id
#
# One could process the full decoded token which contains extra information
# eg for updates. Here we are only interested in roles
#
if userinfo := token.get("userinfo"):
# sub given by oidc provider
sub = userinfo["sub"]
# Build and remember the user in the session
request.session["user_sub"] = sub
# Store the user in the database
await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider)
return RedirectResponse(url="/")
else:
# Not sure if it's correct to redirect to plain login (which is not implemented anyway)
# if no userinfo is provided
return RedirectResponse(url="/login")
@app.get("/logout")
async def logout(
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
) -> RedirectResponse:
# TODO: logout from oidc_provider
# await user.oidc_provider.logout_redirect()
request.session.pop("user_sub", None)
return RedirectResponse(url="/")
@app.get("/")
async def home(
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
) -> HTMLResponse:
return templates.TemplateResponse(
request=request,
context={
"settings": settings.model_dump(),
"user": user,
},
name="index.html",
)
@app.get("/public")
async def public() -> HTMLResponse:
return HTMLResponse("<h1>Not protected</h1>")
@app.get("/protected")
async def get_protected(
user: Annotated[User, Depends(get_current_user)]
) -> HTMLResponse:
return HTMLResponse("<h1>Only authenticated users can see this</h1>")
@app.get("/protected-by-foorole")
@hasrole("foorole")
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
return HTMLResponse("<h1>Only users with foorole can see this</h1>")
@app.get("/protected-by-barrole")
@hasrole("barrole")
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
return HTMLResponse("<h1>Protected by barrole</h1>")
@app.get("/protected-by-foorole-and-barrole")
@hasrole("barrole")
@hasrole("foorole")
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>")
@app.get("/protected-by-foorole-or-barrole")
@hasrole(["foorole", "barrole"])
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")

48
src/oidc_test/models.py Normal file
View file

@ -0,0 +1,48 @@
from functools import cached_property
from typing import Self
from pydantic import BaseModel, EmailStr, AnyHttpUrl, Field, computed_field
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
# from app.models import User
class Role(BaseModel, extra="ignore"):
name: str
class UserBase(BaseModel, extra="ignore"):
id: str | None = None
name: str
email: EmailStr | None = None
picture: AnyHttpUrl | None = None
roles: list[Role] = []
class User(UserBase):
class Config:
arbitrary_types_allowed = True
sub: str = Field(
description="""subject id of the user given by the oidc provider,
also the key for the database 'table'""",
)
userinfo: dict = {}
oidc_provider: StarletteOAuth2App | None = None
@classmethod
def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self:
user = cls(**userinfo)
user.userinfo = userinfo
user.oidc_provider = oidc_provider
# Add roles if they are provided in the token
if raw_ra := userinfo.get("realm_access"):
if raw_roles := raw_ra.get("roles"):
user.roles = [Role(name=raw_role) for raw_role in raw_roles]
return user
@computed_field
@cached_property
def roles_as_set(self) -> set[str]:
return set([role.name for role in self.roles])

63
src/oidc_test/settings.py Normal file
View file

@ -0,0 +1,63 @@
import string
import random
from typing import Type, Tuple
from pydantic import BaseModel, computed_field
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
YamlConfigSettingsSource,
)
class OIDCProvider(BaseModel):
name: str = ""
url: str = ""
client_id: str = ""
client_secret: str = ""
@computed_field
@property
def provider_url(self) -> str:
return self.url + "/.well-known/openid-configuration"
@computed_field
@property
def token_url(self) -> str:
return "auth/" + self.name
class OIDCSettings(BaseModel):
show_session_details: bool = False
providers: list[OIDCProvider] = []
swagger_provider: str = ""
class Settings(BaseSettings):
"""Settings wil be read from an .env file"""
oidc: OIDCSettings = OIDCSettings()
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
class Config:
env_nested_delimiter = "__"
@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
file_secret_settings,
YamlConfigSettingsSource(settings_cls, "settings.yaml"),
dotenv_settings,
)
settings = Settings()