Add hasrole route decorator, more checks, refactor

This commit is contained in:
phil 2025-01-03 17:00:38 +01:00
parent e44d89e512
commit 09d9bc0a7d
3 changed files with 102 additions and 72 deletions

View file

@ -0,0 +1,58 @@
from typing import Union
from functools import wraps
from fastapi import HTTPException, Request
from .models import User
def get_current_user(request: Request) -> User:
auth_data = request.session.get("user")
if auth_data is None:
raise HTTPException(401, "Not authorized")
return User(**auth_data)
def get_current_user_or_none(request: Request) -> User | None:
try:
return get_current_user(request)
except HTTPException:
return None
def hasrole(
required_roles: Union[str, list[str]] = [],
roles_key: str = "roles",
realm: str | None = "realm_access", # Keycloak standard for realm defined roles
):
required_roles_set: set[str]
if isinstance(required_roles, str):
required_roles_set = set([required_roles])
else:
required_roles_set = set(required_roles)
def decorator(func):
@wraps(func)
async def wrapper(request=None, *args, **kwargs):
if request is None:
raise HTTPException(
500,
"Functions decorated with hasrole must have a request:Request argument",
)
if "user" not in request.session:
raise HTTPException(401, "Not authorized")
user = request.session["user"]
try:
if realm in user:
roles = user[realm][roles_key]
else:
roles = user[roles_key]
except KeyError:
raise HTTPException(401, "Not authorized")
if not any(required_roles_set.intersection(roles)):
raise HTTPException(401, "Not authorized")
return await func(request, *args, **kwargs)
return wrapper
return decorator

View file

@ -2,101 +2,41 @@ from typing import Annotated
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from fastapi.security import OpenIdConnect
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.starlette_client import OAuth, OAuthError
from .settings import settings
from .models import User
from .auth_utils import hasrole, get_current_user_or_none, get_current_user
templates = Jinja2Templates("src/templates")
# swagger_provider = settings.oidc.get_swagger_provider()
# if swagger_provider is not None:
# swagger_ui_init_oauth = {
# "clientId": settings.oidc.get_swagger_provider().client_id,
# "scopes": ["openid"], # fill in additional scopes when necessary
# "appName": "Test Application",
# # "usePkceWithAuthorizationCodeGrant": True,
# }
# else:
# swagger_ui_init_oauth = None
app = FastAPI(
title="OIDC auth test",
# swagger_ui_init_oauth=swagger_ui_init_oauth,
)
# SessionMiddleware is required by authlib
app.add_middleware(
SessionMiddleware,
secret_key=settings.secret_key,
)
# Add oidc providers from the settings
authlib_oauth = OAuth()
for provider in settings.oidc.providers:
authlib_oauth.register(
name=provider.name,
server_metadata_url=provider.provider_url,
client_kwargs={"scope": "openid email offline_access profile"},
client_kwargs={"scope": "openid email offline_access profile roles"},
client_id=provider.client_id,
client_secret=provider.client_secret,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
)
# oidc_providers = dict(
# (
# provider.name,
# OpenIdConnect(
# openIdConnectUrl=provider.url,
# scheme_name="openid",
# auto_error=True,
# ),
# )
# for provider in settings.oidc.providers
# )
# oidc_scheme = oidc_providers[swagger_provider.name]
def get_current_user(request: Request) -> User:
auth_data = request.session.get("user")
if auth_data is None:
raise HTTPException(401, "Not authorized")
return User(**auth_data)
def get_current_user_or_none(request: Request) -> User | None:
try:
return get_current_user(request)
except HTTPException:
return None
# def fastapi_oauth2():
# breakpoint()
# ...
# async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)):
# # we could query the identity provider to give us some information about the user
# # userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token})
#
# # in my case, the JWT already contains all the information so I only need to decode and verify it
# try:
# # note that this also validates the JWT by validating all the claims
# user = await authlib_oauth.provider.parse_id_token(
# request, token={"id_token": token}
# )
# except Exception as exp:
# raise HTTPException(
# status_code=status.HTTP_401_UNAUTHORIZED,
# detail=f"Supplied authentication could not be validated ({exp})",
# )
# return user
@app.get("/login")
async def login(request: Request, provider: str) -> RedirectResponse:
@ -104,7 +44,7 @@ async def login(request: Request, provider: str) -> RedirectResponse:
try:
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
except AttributeError:
raise HTTPException(500, "")
raise HTTPException(500, "No such provider")
try:
return await provider_.authorize_redirect(request, redirect_uri)
except HTTPError:
@ -116,7 +56,7 @@ async def auth(request: Request, provider: str) -> RedirectResponse:
try:
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
except AttributeError:
raise HTTPException(500, "")
raise HTTPException(500, "No such provider")
try:
token = await provider_.authorize_access_token(request)
except OAuthError as error:
@ -150,8 +90,38 @@ async def home(
)
@app.get("/public")
async def public() -> HTMLResponse:
return HTMLResponse("<h1>Not protected</h1>")
@app.get("/protected")
async def get_protected(
user: Annotated[User, Depends(get_current_user)]
) -> PlainTextResponse:
return PlainTextResponse("Only authenticated users can see this")
) -> HTMLResponse:
return HTMLResponse("<h1>Only authenticated users can see this</h1>")
@app.get("/protected-by-foorole")
@hasrole("foorole")
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
return HTMLResponse("<h1>Only users with foorole can see this</h1>")
@app.get("/protected-by-barrole")
@hasrole("barrole")
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
return HTMLResponse("<h1>Protected by barrole</h1>")
@app.get("/protected-by-foorole-and-barrole")
@hasrole("barrole")
@hasrole("foorole")
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>")
@app.get("/protected-by-foorole-or-barrole")
@hasrole(["foorole", "barrole"])
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")

View file

@ -169,9 +169,11 @@
</p>
<div id="links-to-check">
<a href="public">Public</a>
<a href="protected">Access protected content</a>
<a href="protected-by-foorole">Access + foorole protected content</a>
<a href="protected-by-barrole">Access + barrole protected content</a>
<a href="protected">Auth protected content</a>
<a href="protected-by-foorole">Auth + foorole protected content</a>
<a href="protected-by-barrole">Auth + barrole protected content</a>
<a href="protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a>
<a href="protected-by-foorole-or-barrole">Auth + foorole or barrole protected content</a>
<a href="other">Other</a>
</div>
{% if user and settings.oidc.show_session_details %}