Add fake db, properly deal with roles, improve types, etc

This commit is contained in:
phil 2025-01-05 05:06:58 +01:00
parent 522b3465df
commit 5c3d54c3f2
6 changed files with 126 additions and 52 deletions

View file

@ -1,16 +1,16 @@
from typing import Union from typing import Union
from functools import wraps from functools import wraps
from fastapi import HTTPException, Request from fastapi import HTTPException, Request, status
from .models import User from .models import User
from .database import db
def get_current_user(request: Request) -> User: def get_current_user(request: Request) -> User:
auth_data = request.session.get("user") if (user_sub := request.session.get("user_sub")) is None:
if auth_data is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED)
raise HTTPException(401, "Not authorized") return db.get_user(user_sub)
return User(**auth_data)
def get_current_user_or_none(request: Request) -> User | None: def get_current_user_or_none(request: Request) -> User | None:
@ -20,11 +20,7 @@ def get_current_user_or_none(request: Request) -> User | None:
return None return None
def hasrole( def hasrole(required_roles: Union[str, list[str]] = []):
required_roles: Union[str, list[str]] = [],
roles_key: str = "roles",
realm: str | None = "realm_access", # Keycloak standard for realm defined roles
):
required_roles_set: set[str] required_roles_set: set[str]
if isinstance(required_roles, str): if isinstance(required_roles, str):
required_roles_set = set([required_roles]) required_roles_set = set([required_roles])
@ -39,18 +35,9 @@ def hasrole(
500, 500,
"Functions decorated with hasrole must have a request:Request argument", "Functions decorated with hasrole must have a request:Request argument",
) )
if "user" not in request.session: user: User = get_current_user(request)
raise HTTPException(401, "Not authorized") if not any(required_roles_set.intersection(user.roles_as_set)):
user = request.session["user"] raise HTTPException(status.HTTP_401_UNAUTHORIZED)
try:
if realm in user:
roles = user[realm][roles_key]
else:
roles = user[roles_key]
except KeyError:
raise HTTPException(401, "Not authorized")
if not any(required_roles_set.intersection(roles)):
raise HTTPException(401, "Not authorized")
return await func(request, *args, **kwargs) return await func(request, *args, **kwargs)
return wrapper return wrapper

24
src/oidc-test/database.py Normal file
View file

@ -0,0 +1,24 @@
# Implement a fake in-memory database interface for demo purpose
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from .models import User
class Database:
users: dict[str, User] = {}
# Last sessions for the user (key: users's subject id (sub))
async def add_user(
self, sub: str, user_info: dict, oidc_provider: StarletteOAuth2App
) -> User:
user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider)
self.users[sub] = user
return user
def get_user(self, sub: str) -> User:
return self.users[sub]
db = Database()

View file

@ -11,6 +11,7 @@ from authlib.integrations.starlette_client import OAuth, OAuthError
from .settings import settings from .settings import settings
from .models import User from .models import User
from .auth_utils import hasrole, get_current_user_or_none, get_current_user from .auth_utils import hasrole, get_current_user_or_none, get_current_user
from .database import db
templates = Jinja2Templates("src/templates") templates = Jinja2Templates("src/templates")
@ -25,13 +26,15 @@ app.add_middleware(
secret_key=settings.secret_key, secret_key=settings.secret_key,
) )
# Add oidc providers from the settings # Add oidc providers to authlib from the settings
authlib_oauth = OAuth() authlib_oauth = OAuth()
for provider in settings.oidc.providers: for provider in settings.oidc.providers:
authlib_oauth.register( authlib_oauth.register(
name=provider.name, name=provider.name,
server_metadata_url=provider.provider_url, server_metadata_url=provider.provider_url,
client_kwargs={"scope": "openid email offline_access profile roles"}, client_kwargs={
"scope": "openid email offline_access profile roles",
},
client_id=provider.client_id, client_id=provider.client_id,
client_secret=provider.client_secret, client_secret=provider.client_secret,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
@ -40,38 +43,54 @@ for provider in settings.oidc.providers:
@app.get("/login") @app.get("/login")
async def login(request: Request, provider: str) -> RedirectResponse: async def login(request: Request, provider: str) -> RedirectResponse:
redirect_uri = request.url_for("auth", provider=provider) redirect_uri = request.url_for("auth", oidc_provider_id=provider)
try: try:
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider) provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
except AttributeError: except AttributeError:
raise HTTPException(500, "No such provider") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try: try:
return await provider_.authorize_redirect(request, redirect_uri) return await provider_.authorize_redirect(request, redirect_uri)
except HTTPError: except HTTPError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
@app.get("/auth/{provider}") @app.get("/auth/{oidc_provider_id}")
async def auth(request: Request, provider: str) -> RedirectResponse: async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
try: try:
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider) oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
except AttributeError: except AttributeError:
raise HTTPException(500, "No such provider") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try: try:
token = await provider_.authorize_access_token(request) token = await oidc_provider.authorize_access_token(request)
except OAuthError as error: except OAuthError as error:
raise HTTPException(status_code=401, detail=error.error) raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
user = token.get("userinfo") # Remember the oidc_provider in the session
if user: request.session["oidc_provider_id"] = oidc_provider_id
request.session["user"] = dict(user) #
# One could process the full decoded token which contains extra information
# eg for updates. Here we are only interested in roles
#
if userinfo := token.get("userinfo"):
# sub given by oidc provider
sub = userinfo["sub"]
# Build and remember the user in the session
request.session["user_sub"] = sub
# Store the user in the database
await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider)
return RedirectResponse(url="/") return RedirectResponse(url="/")
else: else:
# Not sure if it's correct to redirect to plain login (which is not implemented anyway)
# if no userinfo is provided
return RedirectResponse(url="/login") return RedirectResponse(url="/login")
@app.get("/logout") @app.get("/logout")
async def logout(request: Request) -> RedirectResponse: async def logout(
request.session.pop("user", None) request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
) -> RedirectResponse:
# TODO: logout from oidc_provider
# await user.oidc_provider.logout_redirect()
request.session.pop("user_sub", None)
return RedirectResponse(url="/") return RedirectResponse(url="/")
@ -84,7 +103,6 @@ async def home(
context={ context={
"settings": settings.model_dump(), "settings": settings.model_dump(),
"user": user, "user": user,
"auth_data": request.session.get("user"),
}, },
name="index.html", name="index.html",
) )

View file

@ -1,11 +1,48 @@
from pydantic import BaseModel, EmailStr, AnyHttpUrl from functools import cached_property
from typing import Self
from pydantic import BaseModel, EmailStr, AnyHttpUrl, Field, computed_field
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
# from app.models import User # from app.models import User
class User(BaseModel, extra="ignore"): class Role(BaseModel, extra="ignore"):
name: str
class UserBase(BaseModel, extra="ignore"):
id: str | None = None id: str | None = None
name: str name: str
email: EmailStr | None = None email: EmailStr | None = None
picture: AnyHttpUrl | None = None picture: AnyHttpUrl | None = None
roles: list[Role] = []
class User(UserBase):
class Config:
arbitrary_types_allowed = True
sub: str = Field(
description="""subject id of the user given by the oidc provider,
also the key for the database 'table'""",
)
userinfo: dict = {}
oidc_provider: StarletteOAuth2App | None = None
@classmethod
def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self:
user = cls(**userinfo)
user.userinfo = userinfo
user.oidc_provider = oidc_provider
# Add roles if they are provided in the token
if raw_ra := userinfo.get("realm_access"):
if raw_roles := raw_ra.get("roles"):
user.roles = [Role(name=raw_role) for raw_role in raw_roles]
return user
@computed_field
@cached_property
def roles_as_set(self) -> set[str]:
return set([role.name for role in self.roles])

View file

@ -15,7 +15,6 @@ class OIDCProvider(BaseModel):
url: str = "" url: str = ""
client_id: str = "" client_id: str = ""
client_secret: str = "" client_secret: str = ""
is_swagger: bool = False
@computed_field @computed_field
@property @property
@ -33,13 +32,6 @@ class OIDCSettings(BaseModel):
providers: list[OIDCProvider] = [] providers: list[OIDCProvider] = []
swagger_provider: str = "" swagger_provider: str = ""
def get_swagger_provider(self) -> OIDCProvider:
for provider in self.providers:
if provider.is_swagger:
return provider
else:
raise UserWarning("Please define a provider for Swagger with id_swagger")
class Settings(BaseSettings): class Settings(BaseSettings):
"""Settings wil be read from an .env file""" """Settings wil be read from an .env file"""

View file

@ -119,6 +119,10 @@
.hasResponseStatus.status-401 { .hasResponseStatus.status-401 {
background-color: #ff000040; background-color: #ff000040;
} }
.role {
padding: 3px 6px;
background-color: #44228840;
}
</style> </style>
<script> <script>
function checkHref(elem) { function checkHref(elem) {
@ -143,10 +147,10 @@
<h1>FastAPI test app for OIDC</h1> <h1>FastAPI test app for OIDC</h1>
{% if not user %} {% if not user %}
<div class="login-box"> <div class="login-box">
<p>Not logged in</p> <p>Log in with one of these authentication providers:</p>
<div class="login-toolbox"> <div class="login-toolbox">
{% for provider in settings.oidc.providers %} {% for provider in settings.oidc.providers %}
<a href="login?provider={{ provider.name }}">Login with: {{ provider.name }}</a> <a href="login?provider={{ provider.name }}">{{ provider.name }}</a>
{% else %} {% else %}
<span class="error">Cannot login: no oidc prodiver in settings.yaml</span> <span class="error">Cannot login: no oidc prodiver in settings.yaml</span>
{% endfor %} {% endfor %}
@ -159,7 +163,19 @@
{% if user.picture %} {% if user.picture %}
<img src="{{ user.picture }}" class="picture"></img> <img src="{{ user.picture }}" class="picture"></img>
{% endif %} {% endif %}
<p>{{ user.email }}</p> <div>{{ user.email }}</div>
{% if user.roles %}
<div>
<span>Roles:</span>
{% for role in user.roles %}
<span class="role">{{ role.name }}</span>
{% endfor %}
</div>
{% endif %}
<div>
<span>Provider:</span>
{{ user.oidc_provider.name }}
</div>
<a href="logout" class="logout">Logout</a> <a href="logout" class="logout">Logout</a>
</div> </div>
{% endif %} {% endif %}
@ -180,7 +196,7 @@
<div class="debug-auth"> <div class="debug-auth">
<p>Session details</p> <p>Session details</p>
<ul> <ul>
{% for key, value in auth_data.items() %} {% for key, value in user.userinfo.items() %}
<li> <li>
<span class="key">{{ key }}</span>: {{ value }} <span class="key">{{ key }}</span>: {{ value }}
</li> </li>