Add fake db, properly deal with roles, improve types, etc

This commit is contained in:
phil 2025-01-05 05:06:58 +01:00
parent 522b3465df
commit 5c3d54c3f2
6 changed files with 126 additions and 52 deletions

View file

@ -1,16 +1,16 @@
from typing import Union
from functools import wraps
from fastapi import HTTPException, Request
from fastapi import HTTPException, Request, status
from .models import User
from .database import db
def get_current_user(request: Request) -> User:
auth_data = request.session.get("user")
if auth_data is None:
raise HTTPException(401, "Not authorized")
return User(**auth_data)
if (user_sub := request.session.get("user_sub")) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
return db.get_user(user_sub)
def get_current_user_or_none(request: Request) -> User | None:
@ -20,11 +20,7 @@ def get_current_user_or_none(request: Request) -> User | None:
return None
def hasrole(
required_roles: Union[str, list[str]] = [],
roles_key: str = "roles",
realm: str | None = "realm_access", # Keycloak standard for realm defined roles
):
def hasrole(required_roles: Union[str, list[str]] = []):
required_roles_set: set[str]
if isinstance(required_roles, str):
required_roles_set = set([required_roles])
@ -39,18 +35,9 @@ def hasrole(
500,
"Functions decorated with hasrole must have a request:Request argument",
)
if "user" not in request.session:
raise HTTPException(401, "Not authorized")
user = request.session["user"]
try:
if realm in user:
roles = user[realm][roles_key]
else:
roles = user[roles_key]
except KeyError:
raise HTTPException(401, "Not authorized")
if not any(required_roles_set.intersection(roles)):
raise HTTPException(401, "Not authorized")
user: User = get_current_user(request)
if not any(required_roles_set.intersection(user.roles_as_set)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
return await func(request, *args, **kwargs)
return wrapper

24
src/oidc-test/database.py Normal file
View file

@ -0,0 +1,24 @@
# Implement a fake in-memory database interface for demo purpose
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from .models import User
class Database:
users: dict[str, User] = {}
# Last sessions for the user (key: users's subject id (sub))
async def add_user(
self, sub: str, user_info: dict, oidc_provider: StarletteOAuth2App
) -> User:
user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider)
self.users[sub] = user
return user
def get_user(self, sub: str) -> User:
return self.users[sub]
db = Database()

View file

@ -11,6 +11,7 @@ from authlib.integrations.starlette_client import OAuth, OAuthError
from .settings import settings
from .models import User
from .auth_utils import hasrole, get_current_user_or_none, get_current_user
from .database import db
templates = Jinja2Templates("src/templates")
@ -25,13 +26,15 @@ app.add_middleware(
secret_key=settings.secret_key,
)
# Add oidc providers from the settings
# Add oidc providers to authlib from the settings
authlib_oauth = OAuth()
for provider in settings.oidc.providers:
authlib_oauth.register(
name=provider.name,
server_metadata_url=provider.provider_url,
client_kwargs={"scope": "openid email offline_access profile roles"},
client_kwargs={
"scope": "openid email offline_access profile roles",
},
client_id=provider.client_id,
client_secret=provider.client_secret,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
@ -40,38 +43,54 @@ for provider in settings.oidc.providers:
@app.get("/login")
async def login(request: Request, provider: str) -> RedirectResponse:
redirect_uri = request.url_for("auth", provider=provider)
redirect_uri = request.url_for("auth", oidc_provider_id=provider)
try:
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
except AttributeError:
raise HTTPException(500, "No such provider")
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try:
return await provider_.authorize_redirect(request, redirect_uri)
except HTTPError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
@app.get("/auth/{provider}")
async def auth(request: Request, provider: str) -> RedirectResponse:
@app.get("/auth/{oidc_provider_id}")
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
try:
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
except AttributeError:
raise HTTPException(500, "No such provider")
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try:
token = await provider_.authorize_access_token(request)
token = await oidc_provider.authorize_access_token(request)
except OAuthError as error:
raise HTTPException(status_code=401, detail=error.error)
user = token.get("userinfo")
if user:
request.session["user"] = dict(user)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
# Remember the oidc_provider in the session
request.session["oidc_provider_id"] = oidc_provider_id
#
# One could process the full decoded token which contains extra information
# eg for updates. Here we are only interested in roles
#
if userinfo := token.get("userinfo"):
# sub given by oidc provider
sub = userinfo["sub"]
# Build and remember the user in the session
request.session["user_sub"] = sub
# Store the user in the database
await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider)
return RedirectResponse(url="/")
else:
# Not sure if it's correct to redirect to plain login (which is not implemented anyway)
# if no userinfo is provided
return RedirectResponse(url="/login")
@app.get("/logout")
async def logout(request: Request) -> RedirectResponse:
request.session.pop("user", None)
async def logout(
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
) -> RedirectResponse:
# TODO: logout from oidc_provider
# await user.oidc_provider.logout_redirect()
request.session.pop("user_sub", None)
return RedirectResponse(url="/")
@ -84,7 +103,6 @@ async def home(
context={
"settings": settings.model_dump(),
"user": user,
"auth_data": request.session.get("user"),
},
name="index.html",
)

View file

@ -1,11 +1,48 @@
from pydantic import BaseModel, EmailStr, AnyHttpUrl
from functools import cached_property
from typing import Self
from pydantic import BaseModel, EmailStr, AnyHttpUrl, Field, computed_field
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
# from app.models import User
class User(BaseModel, extra="ignore"):
class Role(BaseModel, extra="ignore"):
name: str
class UserBase(BaseModel, extra="ignore"):
id: str | None = None
name: str
email: EmailStr | None = None
picture: AnyHttpUrl | None = None
roles: list[Role] = []
class User(UserBase):
class Config:
arbitrary_types_allowed = True
sub: str = Field(
description="""subject id of the user given by the oidc provider,
also the key for the database 'table'""",
)
userinfo: dict = {}
oidc_provider: StarletteOAuth2App | None = None
@classmethod
def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self:
user = cls(**userinfo)
user.userinfo = userinfo
user.oidc_provider = oidc_provider
# Add roles if they are provided in the token
if raw_ra := userinfo.get("realm_access"):
if raw_roles := raw_ra.get("roles"):
user.roles = [Role(name=raw_role) for raw_role in raw_roles]
return user
@computed_field
@cached_property
def roles_as_set(self) -> set[str]:
return set([role.name for role in self.roles])

View file

@ -15,7 +15,6 @@ class OIDCProvider(BaseModel):
url: str = ""
client_id: str = ""
client_secret: str = ""
is_swagger: bool = False
@computed_field
@property
@ -33,13 +32,6 @@ class OIDCSettings(BaseModel):
providers: list[OIDCProvider] = []
swagger_provider: str = ""
def get_swagger_provider(self) -> OIDCProvider:
for provider in self.providers:
if provider.is_swagger:
return provider
else:
raise UserWarning("Please define a provider for Swagger with id_swagger")
class Settings(BaseSettings):
"""Settings wil be read from an .env file"""

View file

@ -119,6 +119,10 @@
.hasResponseStatus.status-401 {
background-color: #ff000040;
}
.role {
padding: 3px 6px;
background-color: #44228840;
}
</style>
<script>
function checkHref(elem) {
@ -143,10 +147,10 @@
<h1>FastAPI test app for OIDC</h1>
{% if not user %}
<div class="login-box">
<p>Not logged in</p>
<p>Log in with one of these authentication providers:</p>
<div class="login-toolbox">
{% for provider in settings.oidc.providers %}
<a href="login?provider={{ provider.name }}">Login with: {{ provider.name }}</a>
<a href="login?provider={{ provider.name }}">{{ provider.name }}</a>
{% else %}
<span class="error">Cannot login: no oidc prodiver in settings.yaml</span>
{% endfor %}
@ -159,7 +163,19 @@
{% if user.picture %}
<img src="{{ user.picture }}" class="picture"></img>
{% endif %}
<p>{{ user.email }}</p>
<div>{{ user.email }}</div>
{% if user.roles %}
<div>
<span>Roles:</span>
{% for role in user.roles %}
<span class="role">{{ role.name }}</span>
{% endfor %}
</div>
{% endif %}
<div>
<span>Provider:</span>
{{ user.oidc_provider.name }}
</div>
<a href="logout" class="logout">Logout</a>
</div>
{% endif %}
@ -180,7 +196,7 @@
<div class="debug-auth">
<p>Session details</p>
<ul>
{% for key, value in auth_data.items() %}
{% for key, value in user.userinfo.items() %}
<li>
<span class="key">{{ key }}</span>: {{ value }}
</li>