From 09d9bc0a7d74bdd37c12af48fd1d7d401e78e49b Mon Sep 17 00:00:00 2001 From: phil Date: Fri, 3 Jan 2025 17:00:38 +0100 Subject: [PATCH] Add hasrole route decorator, more checks, refactor --- src/oidc-test/auth_utils.py | 58 +++++++++++++++++++ src/oidc-test/main.py | 108 +++++++++++++----------------------- src/templates/index.html | 8 ++- 3 files changed, 102 insertions(+), 72 deletions(-) create mode 100644 src/oidc-test/auth_utils.py diff --git a/src/oidc-test/auth_utils.py b/src/oidc-test/auth_utils.py new file mode 100644 index 0000000..a322979 --- /dev/null +++ b/src/oidc-test/auth_utils.py @@ -0,0 +1,58 @@ +from typing import Union +from functools import wraps + +from fastapi import HTTPException, Request + +from .models import User + + +def get_current_user(request: Request) -> User: + auth_data = request.session.get("user") + if auth_data is None: + raise HTTPException(401, "Not authorized") + return User(**auth_data) + + +def get_current_user_or_none(request: Request) -> User | None: + try: + return get_current_user(request) + except HTTPException: + return None + + +def hasrole( + required_roles: Union[str, list[str]] = [], + roles_key: str = "roles", + realm: str | None = "realm_access", # Keycloak standard for realm defined roles +): + required_roles_set: set[str] + if isinstance(required_roles, str): + required_roles_set = set([required_roles]) + else: + required_roles_set = set(required_roles) + + def decorator(func): + @wraps(func) + async def wrapper(request=None, *args, **kwargs): + if request is None: + raise HTTPException( + 500, + "Functions decorated with hasrole must have a request:Request argument", + ) + if "user" not in request.session: + raise HTTPException(401, "Not authorized") + user = request.session["user"] + try: + if realm in user: + roles = user[realm][roles_key] + else: + roles = user[roles_key] + except KeyError: + raise HTTPException(401, "Not authorized") + if not any(required_roles_set.intersection(roles)): + raise HTTPException(401, "Not authorized") + return await func(request, *args, **kwargs) + + return wrapper + + return decorator diff --git a/src/oidc-test/main.py b/src/oidc-test/main.py index fce3e75..4254277 100644 --- a/src/oidc-test/main.py +++ b/src/oidc-test/main.py @@ -2,101 +2,41 @@ from typing import Annotated from httpx import HTTPError from fastapi import Depends, FastAPI, HTTPException, Request, status -from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse +from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates -from fastapi.security import OpenIdConnect from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client import OAuth, OAuthError from .settings import settings from .models import User +from .auth_utils import hasrole, get_current_user_or_none, get_current_user templates = Jinja2Templates("src/templates") -# swagger_provider = settings.oidc.get_swagger_provider() -# if swagger_provider is not None: -# swagger_ui_init_oauth = { -# "clientId": settings.oidc.get_swagger_provider().client_id, -# "scopes": ["openid"], # fill in additional scopes when necessary -# "appName": "Test Application", -# # "usePkceWithAuthorizationCodeGrant": True, -# } -# else: -# swagger_ui_init_oauth = None - app = FastAPI( title="OIDC auth test", - # swagger_ui_init_oauth=swagger_ui_init_oauth, ) +# SessionMiddleware is required by authlib app.add_middleware( SessionMiddleware, secret_key=settings.secret_key, ) +# Add oidc providers from the settings authlib_oauth = OAuth() for provider in settings.oidc.providers: authlib_oauth.register( name=provider.name, server_metadata_url=provider.provider_url, - client_kwargs={"scope": "openid email offline_access profile"}, + client_kwargs={"scope": "openid email offline_access profile roles"}, client_id=provider.client_id, client_secret=provider.client_secret, # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) ) -# oidc_providers = dict( -# ( -# provider.name, -# OpenIdConnect( -# openIdConnectUrl=provider.url, -# scheme_name="openid", -# auto_error=True, -# ), -# ) -# for provider in settings.oidc.providers -# ) -# oidc_scheme = oidc_providers[swagger_provider.name] - - -def get_current_user(request: Request) -> User: - auth_data = request.session.get("user") - if auth_data is None: - raise HTTPException(401, "Not authorized") - return User(**auth_data) - - -def get_current_user_or_none(request: Request) -> User | None: - try: - return get_current_user(request) - except HTTPException: - return None - - -# def fastapi_oauth2(): -# breakpoint() -# ... - - -# async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)): -# # we could query the identity provider to give us some information about the user -# # userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token}) -# -# # in my case, the JWT already contains all the information so I only need to decode and verify it -# try: -# # note that this also validates the JWT by validating all the claims -# user = await authlib_oauth.provider.parse_id_token( -# request, token={"id_token": token} -# ) -# except Exception as exp: -# raise HTTPException( -# status_code=status.HTTP_401_UNAUTHORIZED, -# detail=f"Supplied authentication could not be validated ({exp})", -# ) -# return user - @app.get("/login") async def login(request: Request, provider: str) -> RedirectResponse: @@ -104,7 +44,7 @@ async def login(request: Request, provider: str) -> RedirectResponse: try: provider_: StarletteOAuth2App = getattr(authlib_oauth, provider) except AttributeError: - raise HTTPException(500, "") + raise HTTPException(500, "No such provider") try: return await provider_.authorize_redirect(request, redirect_uri) except HTTPError: @@ -116,7 +56,7 @@ async def auth(request: Request, provider: str) -> RedirectResponse: try: provider_: StarletteOAuth2App = getattr(authlib_oauth, provider) except AttributeError: - raise HTTPException(500, "") + raise HTTPException(500, "No such provider") try: token = await provider_.authorize_access_token(request) except OAuthError as error: @@ -150,8 +90,38 @@ async def home( ) +@app.get("/public") +async def public() -> HTMLResponse: + return HTMLResponse("

Not protected

") + + @app.get("/protected") async def get_protected( user: Annotated[User, Depends(get_current_user)] -) -> PlainTextResponse: - return PlainTextResponse("Only authenticated users can see this") +) -> HTMLResponse: + return HTMLResponse("

Only authenticated users can see this

") + + +@app.get("/protected-by-foorole") +@hasrole("foorole") +async def get_protected_by_foorole(request: Request) -> HTMLResponse: + return HTMLResponse("

Only users with foorole can see this

") + + +@app.get("/protected-by-barrole") +@hasrole("barrole") +async def get_protected_by_barrole(request: Request) -> HTMLResponse: + return HTMLResponse("

Protected by barrole

") + + +@app.get("/protected-by-foorole-and-barrole") +@hasrole("barrole") +@hasrole("foorole") +async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse: + return HTMLResponse("

Only users with foorole and barrole can see this

") + + +@app.get("/protected-by-foorole-or-barrole") +@hasrole(["foorole", "barrole"]) +async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse: + return HTMLResponse("

Only users with foorole or barrole can see this

") diff --git a/src/templates/index.html b/src/templates/index.html index 9b21772..2030651 100644 --- a/src/templates/index.html +++ b/src/templates/index.html @@ -169,9 +169,11 @@

{% if user and settings.oidc.show_session_details %}