Add hasrole route decorator, more checks, refactor
This commit is contained in:
parent
e44d89e512
commit
09d9bc0a7d
3 changed files with 102 additions and 72 deletions
58
src/oidc-test/auth_utils.py
Normal file
58
src/oidc-test/auth_utils.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
from typing import Union
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from .models import User
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> User:
|
||||
auth_data = request.session.get("user")
|
||||
if auth_data is None:
|
||||
raise HTTPException(401, "Not authorized")
|
||||
return User(**auth_data)
|
||||
|
||||
|
||||
def get_current_user_or_none(request: Request) -> User | None:
|
||||
try:
|
||||
return get_current_user(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
def hasrole(
|
||||
required_roles: Union[str, list[str]] = [],
|
||||
roles_key: str = "roles",
|
||||
realm: str | None = "realm_access", # Keycloak standard for realm defined roles
|
||||
):
|
||||
required_roles_set: set[str]
|
||||
if isinstance(required_roles, str):
|
||||
required_roles_set = set([required_roles])
|
||||
else:
|
||||
required_roles_set = set(required_roles)
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(request=None, *args, **kwargs):
|
||||
if request is None:
|
||||
raise HTTPException(
|
||||
500,
|
||||
"Functions decorated with hasrole must have a request:Request argument",
|
||||
)
|
||||
if "user" not in request.session:
|
||||
raise HTTPException(401, "Not authorized")
|
||||
user = request.session["user"]
|
||||
try:
|
||||
if realm in user:
|
||||
roles = user[realm][roles_key]
|
||||
else:
|
||||
roles = user[roles_key]
|
||||
except KeyError:
|
||||
raise HTTPException(401, "Not authorized")
|
||||
if not any(required_roles_set.intersection(roles)):
|
||||
raise HTTPException(401, "Not authorized")
|
||||
return await func(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
|
@ -2,101 +2,41 @@ from typing import Annotated
|
|||
|
||||
from httpx import HTTPError
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.security import OpenIdConnect
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||
|
||||
from .settings import settings
|
||||
from .models import User
|
||||
from .auth_utils import hasrole, get_current_user_or_none, get_current_user
|
||||
|
||||
templates = Jinja2Templates("src/templates")
|
||||
|
||||
|
||||
# swagger_provider = settings.oidc.get_swagger_provider()
|
||||
# if swagger_provider is not None:
|
||||
# swagger_ui_init_oauth = {
|
||||
# "clientId": settings.oidc.get_swagger_provider().client_id,
|
||||
# "scopes": ["openid"], # fill in additional scopes when necessary
|
||||
# "appName": "Test Application",
|
||||
# # "usePkceWithAuthorizationCodeGrant": True,
|
||||
# }
|
||||
# else:
|
||||
# swagger_ui_init_oauth = None
|
||||
|
||||
app = FastAPI(
|
||||
title="OIDC auth test",
|
||||
# swagger_ui_init_oauth=swagger_ui_init_oauth,
|
||||
)
|
||||
|
||||
# SessionMiddleware is required by authlib
|
||||
app.add_middleware(
|
||||
SessionMiddleware,
|
||||
secret_key=settings.secret_key,
|
||||
)
|
||||
|
||||
# Add oidc providers from the settings
|
||||
authlib_oauth = OAuth()
|
||||
for provider in settings.oidc.providers:
|
||||
authlib_oauth.register(
|
||||
name=provider.name,
|
||||
server_metadata_url=provider.provider_url,
|
||||
client_kwargs={"scope": "openid email offline_access profile"},
|
||||
client_kwargs={"scope": "openid email offline_access profile roles"},
|
||||
client_id=provider.client_id,
|
||||
client_secret=provider.client_secret,
|
||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||
)
|
||||
|
||||
# oidc_providers = dict(
|
||||
# (
|
||||
# provider.name,
|
||||
# OpenIdConnect(
|
||||
# openIdConnectUrl=provider.url,
|
||||
# scheme_name="openid",
|
||||
# auto_error=True,
|
||||
# ),
|
||||
# )
|
||||
# for provider in settings.oidc.providers
|
||||
# )
|
||||
# oidc_scheme = oidc_providers[swagger_provider.name]
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> User:
|
||||
auth_data = request.session.get("user")
|
||||
if auth_data is None:
|
||||
raise HTTPException(401, "Not authorized")
|
||||
return User(**auth_data)
|
||||
|
||||
|
||||
def get_current_user_or_none(request: Request) -> User | None:
|
||||
try:
|
||||
return get_current_user(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
# def fastapi_oauth2():
|
||||
# breakpoint()
|
||||
# ...
|
||||
|
||||
|
||||
# async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)):
|
||||
# # we could query the identity provider to give us some information about the user
|
||||
# # userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token})
|
||||
#
|
||||
# # in my case, the JWT already contains all the information so I only need to decode and verify it
|
||||
# try:
|
||||
# # note that this also validates the JWT by validating all the claims
|
||||
# user = await authlib_oauth.provider.parse_id_token(
|
||||
# request, token={"id_token": token}
|
||||
# )
|
||||
# except Exception as exp:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
# detail=f"Supplied authentication could not be validated ({exp})",
|
||||
# )
|
||||
# return user
|
||||
|
||||
|
||||
@app.get("/login")
|
||||
async def login(request: Request, provider: str) -> RedirectResponse:
|
||||
|
@ -104,7 +44,7 @@ async def login(request: Request, provider: str) -> RedirectResponse:
|
|||
try:
|
||||
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
|
||||
except AttributeError:
|
||||
raise HTTPException(500, "")
|
||||
raise HTTPException(500, "No such provider")
|
||||
try:
|
||||
return await provider_.authorize_redirect(request, redirect_uri)
|
||||
except HTTPError:
|
||||
|
@ -116,7 +56,7 @@ async def auth(request: Request, provider: str) -> RedirectResponse:
|
|||
try:
|
||||
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
|
||||
except AttributeError:
|
||||
raise HTTPException(500, "")
|
||||
raise HTTPException(500, "No such provider")
|
||||
try:
|
||||
token = await provider_.authorize_access_token(request)
|
||||
except OAuthError as error:
|
||||
|
@ -150,8 +90,38 @@ async def home(
|
|||
)
|
||||
|
||||
|
||||
@app.get("/public")
|
||||
async def public() -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Not protected</h1>")
|
||||
|
||||
|
||||
@app.get("/protected")
|
||||
async def get_protected(
|
||||
user: Annotated[User, Depends(get_current_user)]
|
||||
) -> PlainTextResponse:
|
||||
return PlainTextResponse("Only authenticated users can see this")
|
||||
) -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Only authenticated users can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-foorole")
|
||||
@hasrole("foorole")
|
||||
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Only users with foorole can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-barrole")
|
||||
@hasrole("barrole")
|
||||
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Protected by barrole</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-foorole-and-barrole")
|
||||
@hasrole("barrole")
|
||||
@hasrole("foorole")
|
||||
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-foorole-or-barrole")
|
||||
@hasrole(["foorole", "barrole"])
|
||||
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
|
||||
|
|
|
@ -169,9 +169,11 @@
|
|||
</p>
|
||||
<div id="links-to-check">
|
||||
<a href="public">Public</a>
|
||||
<a href="protected">Access protected content</a>
|
||||
<a href="protected-by-foorole">Access + foorole protected content</a>
|
||||
<a href="protected-by-barrole">Access + barrole protected content</a>
|
||||
<a href="protected">Auth protected content</a>
|
||||
<a href="protected-by-foorole">Auth + foorole protected content</a>
|
||||
<a href="protected-by-barrole">Auth + barrole protected content</a>
|
||||
<a href="protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a>
|
||||
<a href="protected-by-foorole-or-barrole">Auth + foorole or barrole protected content</a>
|
||||
<a href="other">Other</a>
|
||||
</div>
|
||||
{% if user and settings.oidc.show_session_details %}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue