From 09d9bc0a7d74bdd37c12af48fd1d7d401e78e49b Mon Sep 17 00:00:00 2001
From: phil
Date: Fri, 3 Jan 2025 17:00:38 +0100
Subject: [PATCH] Add hasrole route decorator, more checks, refactor
---
src/oidc-test/auth_utils.py | 58 +++++++++++++++++++
src/oidc-test/main.py | 108 +++++++++++++-----------------------
src/templates/index.html | 8 ++-
3 files changed, 102 insertions(+), 72 deletions(-)
create mode 100644 src/oidc-test/auth_utils.py
diff --git a/src/oidc-test/auth_utils.py b/src/oidc-test/auth_utils.py
new file mode 100644
index 0000000..a322979
--- /dev/null
+++ b/src/oidc-test/auth_utils.py
@@ -0,0 +1,58 @@
+from typing import Union
+from functools import wraps
+
+from fastapi import HTTPException, Request
+
+from .models import User
+
+
+def get_current_user(request: Request) -> User:
+ auth_data = request.session.get("user")
+ if auth_data is None:
+ raise HTTPException(401, "Not authorized")
+ return User(**auth_data)
+
+
+def get_current_user_or_none(request: Request) -> User | None:
+ try:
+ return get_current_user(request)
+ except HTTPException:
+ return None
+
+
+def hasrole(
+ required_roles: Union[str, list[str]] = [],
+ roles_key: str = "roles",
+ realm: str | None = "realm_access", # Keycloak standard for realm defined roles
+):
+ required_roles_set: set[str]
+ if isinstance(required_roles, str):
+ required_roles_set = set([required_roles])
+ else:
+ required_roles_set = set(required_roles)
+
+ def decorator(func):
+ @wraps(func)
+ async def wrapper(request=None, *args, **kwargs):
+ if request is None:
+ raise HTTPException(
+ 500,
+ "Functions decorated with hasrole must have a request:Request argument",
+ )
+ if "user" not in request.session:
+ raise HTTPException(401, "Not authorized")
+ user = request.session["user"]
+ try:
+ if realm in user:
+ roles = user[realm][roles_key]
+ else:
+ roles = user[roles_key]
+ except KeyError:
+ raise HTTPException(401, "Not authorized")
+ if not any(required_roles_set.intersection(roles)):
+ raise HTTPException(401, "Not authorized")
+ return await func(request, *args, **kwargs)
+
+ return wrapper
+
+ return decorator
diff --git a/src/oidc-test/main.py b/src/oidc-test/main.py
index fce3e75..4254277 100644
--- a/src/oidc-test/main.py
+++ b/src/oidc-test/main.py
@@ -2,101 +2,41 @@ from typing import Annotated
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
-from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse
+from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
-from fastapi.security import OpenIdConnect
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.starlette_client import OAuth, OAuthError
from .settings import settings
from .models import User
+from .auth_utils import hasrole, get_current_user_or_none, get_current_user
templates = Jinja2Templates("src/templates")
-# swagger_provider = settings.oidc.get_swagger_provider()
-# if swagger_provider is not None:
-# swagger_ui_init_oauth = {
-# "clientId": settings.oidc.get_swagger_provider().client_id,
-# "scopes": ["openid"], # fill in additional scopes when necessary
-# "appName": "Test Application",
-# # "usePkceWithAuthorizationCodeGrant": True,
-# }
-# else:
-# swagger_ui_init_oauth = None
-
app = FastAPI(
title="OIDC auth test",
- # swagger_ui_init_oauth=swagger_ui_init_oauth,
)
+# SessionMiddleware is required by authlib
app.add_middleware(
SessionMiddleware,
secret_key=settings.secret_key,
)
+# Add oidc providers from the settings
authlib_oauth = OAuth()
for provider in settings.oidc.providers:
authlib_oauth.register(
name=provider.name,
server_metadata_url=provider.provider_url,
- client_kwargs={"scope": "openid email offline_access profile"},
+ client_kwargs={"scope": "openid email offline_access profile roles"},
client_id=provider.client_id,
client_secret=provider.client_secret,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
)
-# oidc_providers = dict(
-# (
-# provider.name,
-# OpenIdConnect(
-# openIdConnectUrl=provider.url,
-# scheme_name="openid",
-# auto_error=True,
-# ),
-# )
-# for provider in settings.oidc.providers
-# )
-# oidc_scheme = oidc_providers[swagger_provider.name]
-
-
-def get_current_user(request: Request) -> User:
- auth_data = request.session.get("user")
- if auth_data is None:
- raise HTTPException(401, "Not authorized")
- return User(**auth_data)
-
-
-def get_current_user_or_none(request: Request) -> User | None:
- try:
- return get_current_user(request)
- except HTTPException:
- return None
-
-
-# def fastapi_oauth2():
-# breakpoint()
-# ...
-
-
-# async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)):
-# # we could query the identity provider to give us some information about the user
-# # userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token})
-#
-# # in my case, the JWT already contains all the information so I only need to decode and verify it
-# try:
-# # note that this also validates the JWT by validating all the claims
-# user = await authlib_oauth.provider.parse_id_token(
-# request, token={"id_token": token}
-# )
-# except Exception as exp:
-# raise HTTPException(
-# status_code=status.HTTP_401_UNAUTHORIZED,
-# detail=f"Supplied authentication could not be validated ({exp})",
-# )
-# return user
-
@app.get("/login")
async def login(request: Request, provider: str) -> RedirectResponse:
@@ -104,7 +44,7 @@ async def login(request: Request, provider: str) -> RedirectResponse:
try:
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
except AttributeError:
- raise HTTPException(500, "")
+ raise HTTPException(500, "No such provider")
try:
return await provider_.authorize_redirect(request, redirect_uri)
except HTTPError:
@@ -116,7 +56,7 @@ async def auth(request: Request, provider: str) -> RedirectResponse:
try:
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
except AttributeError:
- raise HTTPException(500, "")
+ raise HTTPException(500, "No such provider")
try:
token = await provider_.authorize_access_token(request)
except OAuthError as error:
@@ -150,8 +90,38 @@ async def home(
)
+@app.get("/public")
+async def public() -> HTMLResponse:
+ return HTMLResponse("Not protected
")
+
+
@app.get("/protected")
async def get_protected(
user: Annotated[User, Depends(get_current_user)]
-) -> PlainTextResponse:
- return PlainTextResponse("Only authenticated users can see this")
+) -> HTMLResponse:
+ return HTMLResponse("Only authenticated users can see this
")
+
+
+@app.get("/protected-by-foorole")
+@hasrole("foorole")
+async def get_protected_by_foorole(request: Request) -> HTMLResponse:
+ return HTMLResponse("Only users with foorole can see this
")
+
+
+@app.get("/protected-by-barrole")
+@hasrole("barrole")
+async def get_protected_by_barrole(request: Request) -> HTMLResponse:
+ return HTMLResponse("Protected by barrole
")
+
+
+@app.get("/protected-by-foorole-and-barrole")
+@hasrole("barrole")
+@hasrole("foorole")
+async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
+ return HTMLResponse("Only users with foorole and barrole can see this
")
+
+
+@app.get("/protected-by-foorole-or-barrole")
+@hasrole(["foorole", "barrole"])
+async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
+ return HTMLResponse("Only users with foorole or barrole can see this
")
diff --git a/src/templates/index.html b/src/templates/index.html
index 9b21772..2030651 100644
--- a/src/templates/index.html
+++ b/src/templates/index.html
@@ -169,9 +169,11 @@
{% if user and settings.oidc.show_session_details %}