Add hasrole route decorator, more checks, refactor
This commit is contained in:
parent
e44d89e512
commit
09d9bc0a7d
3 changed files with 102 additions and 72 deletions
58
src/oidc-test/auth_utils.py
Normal file
58
src/oidc-test/auth_utils.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
from typing import Union
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
from .models import User
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user(request: Request) -> User:
|
||||||
|
auth_data = request.session.get("user")
|
||||||
|
if auth_data is None:
|
||||||
|
raise HTTPException(401, "Not authorized")
|
||||||
|
return User(**auth_data)
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user_or_none(request: Request) -> User | None:
|
||||||
|
try:
|
||||||
|
return get_current_user(request)
|
||||||
|
except HTTPException:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def hasrole(
|
||||||
|
required_roles: Union[str, list[str]] = [],
|
||||||
|
roles_key: str = "roles",
|
||||||
|
realm: str | None = "realm_access", # Keycloak standard for realm defined roles
|
||||||
|
):
|
||||||
|
required_roles_set: set[str]
|
||||||
|
if isinstance(required_roles, str):
|
||||||
|
required_roles_set = set([required_roles])
|
||||||
|
else:
|
||||||
|
required_roles_set = set(required_roles)
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(request=None, *args, **kwargs):
|
||||||
|
if request is None:
|
||||||
|
raise HTTPException(
|
||||||
|
500,
|
||||||
|
"Functions decorated with hasrole must have a request:Request argument",
|
||||||
|
)
|
||||||
|
if "user" not in request.session:
|
||||||
|
raise HTTPException(401, "Not authorized")
|
||||||
|
user = request.session["user"]
|
||||||
|
try:
|
||||||
|
if realm in user:
|
||||||
|
roles = user[realm][roles_key]
|
||||||
|
else:
|
||||||
|
roles = user[roles_key]
|
||||||
|
except KeyError:
|
||||||
|
raise HTTPException(401, "Not authorized")
|
||||||
|
if not any(required_roles_set.intersection(roles)):
|
||||||
|
raise HTTPException(401, "Not authorized")
|
||||||
|
return await func(request, *args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
|
@ -2,101 +2,41 @@ from typing import Annotated
|
||||||
|
|
||||||
from httpx import HTTPError
|
from httpx import HTTPError
|
||||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||||
from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
from fastapi.security import OpenIdConnect
|
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||||
|
|
||||||
from .settings import settings
|
from .settings import settings
|
||||||
from .models import User
|
from .models import User
|
||||||
|
from .auth_utils import hasrole, get_current_user_or_none, get_current_user
|
||||||
|
|
||||||
templates = Jinja2Templates("src/templates")
|
templates = Jinja2Templates("src/templates")
|
||||||
|
|
||||||
|
|
||||||
# swagger_provider = settings.oidc.get_swagger_provider()
|
|
||||||
# if swagger_provider is not None:
|
|
||||||
# swagger_ui_init_oauth = {
|
|
||||||
# "clientId": settings.oidc.get_swagger_provider().client_id,
|
|
||||||
# "scopes": ["openid"], # fill in additional scopes when necessary
|
|
||||||
# "appName": "Test Application",
|
|
||||||
# # "usePkceWithAuthorizationCodeGrant": True,
|
|
||||||
# }
|
|
||||||
# else:
|
|
||||||
# swagger_ui_init_oauth = None
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="OIDC auth test",
|
title="OIDC auth test",
|
||||||
# swagger_ui_init_oauth=swagger_ui_init_oauth,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# SessionMiddleware is required by authlib
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
SessionMiddleware,
|
SessionMiddleware,
|
||||||
secret_key=settings.secret_key,
|
secret_key=settings.secret_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add oidc providers from the settings
|
||||||
authlib_oauth = OAuth()
|
authlib_oauth = OAuth()
|
||||||
for provider in settings.oidc.providers:
|
for provider in settings.oidc.providers:
|
||||||
authlib_oauth.register(
|
authlib_oauth.register(
|
||||||
name=provider.name,
|
name=provider.name,
|
||||||
server_metadata_url=provider.provider_url,
|
server_metadata_url=provider.provider_url,
|
||||||
client_kwargs={"scope": "openid email offline_access profile"},
|
client_kwargs={"scope": "openid email offline_access profile roles"},
|
||||||
client_id=provider.client_id,
|
client_id=provider.client_id,
|
||||||
client_secret=provider.client_secret,
|
client_secret=provider.client_secret,
|
||||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||||
)
|
)
|
||||||
|
|
||||||
# oidc_providers = dict(
|
|
||||||
# (
|
|
||||||
# provider.name,
|
|
||||||
# OpenIdConnect(
|
|
||||||
# openIdConnectUrl=provider.url,
|
|
||||||
# scheme_name="openid",
|
|
||||||
# auto_error=True,
|
|
||||||
# ),
|
|
||||||
# )
|
|
||||||
# for provider in settings.oidc.providers
|
|
||||||
# )
|
|
||||||
# oidc_scheme = oidc_providers[swagger_provider.name]
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(request: Request) -> User:
|
|
||||||
auth_data = request.session.get("user")
|
|
||||||
if auth_data is None:
|
|
||||||
raise HTTPException(401, "Not authorized")
|
|
||||||
return User(**auth_data)
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_user_or_none(request: Request) -> User | None:
|
|
||||||
try:
|
|
||||||
return get_current_user(request)
|
|
||||||
except HTTPException:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# def fastapi_oauth2():
|
|
||||||
# breakpoint()
|
|
||||||
# ...
|
|
||||||
|
|
||||||
|
|
||||||
# async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)):
|
|
||||||
# # we could query the identity provider to give us some information about the user
|
|
||||||
# # userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token})
|
|
||||||
#
|
|
||||||
# # in my case, the JWT already contains all the information so I only need to decode and verify it
|
|
||||||
# try:
|
|
||||||
# # note that this also validates the JWT by validating all the claims
|
|
||||||
# user = await authlib_oauth.provider.parse_id_token(
|
|
||||||
# request, token={"id_token": token}
|
|
||||||
# )
|
|
||||||
# except Exception as exp:
|
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
# detail=f"Supplied authentication could not be validated ({exp})",
|
|
||||||
# )
|
|
||||||
# return user
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/login")
|
@app.get("/login")
|
||||||
async def login(request: Request, provider: str) -> RedirectResponse:
|
async def login(request: Request, provider: str) -> RedirectResponse:
|
||||||
|
@ -104,7 +44,7 @@ async def login(request: Request, provider: str) -> RedirectResponse:
|
||||||
try:
|
try:
|
||||||
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
|
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise HTTPException(500, "")
|
raise HTTPException(500, "No such provider")
|
||||||
try:
|
try:
|
||||||
return await provider_.authorize_redirect(request, redirect_uri)
|
return await provider_.authorize_redirect(request, redirect_uri)
|
||||||
except HTTPError:
|
except HTTPError:
|
||||||
|
@ -116,7 +56,7 @@ async def auth(request: Request, provider: str) -> RedirectResponse:
|
||||||
try:
|
try:
|
||||||
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
|
provider_: StarletteOAuth2App = getattr(authlib_oauth, provider)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise HTTPException(500, "")
|
raise HTTPException(500, "No such provider")
|
||||||
try:
|
try:
|
||||||
token = await provider_.authorize_access_token(request)
|
token = await provider_.authorize_access_token(request)
|
||||||
except OAuthError as error:
|
except OAuthError as error:
|
||||||
|
@ -150,8 +90,38 @@ async def home(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/public")
|
||||||
|
async def public() -> HTMLResponse:
|
||||||
|
return HTMLResponse("<h1>Not protected</h1>")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/protected")
|
@app.get("/protected")
|
||||||
async def get_protected(
|
async def get_protected(
|
||||||
user: Annotated[User, Depends(get_current_user)]
|
user: Annotated[User, Depends(get_current_user)]
|
||||||
) -> PlainTextResponse:
|
) -> HTMLResponse:
|
||||||
return PlainTextResponse("Only authenticated users can see this")
|
return HTMLResponse("<h1>Only authenticated users can see this</h1>")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/protected-by-foorole")
|
||||||
|
@hasrole("foorole")
|
||||||
|
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
|
||||||
|
return HTMLResponse("<h1>Only users with foorole can see this</h1>")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/protected-by-barrole")
|
||||||
|
@hasrole("barrole")
|
||||||
|
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
|
||||||
|
return HTMLResponse("<h1>Protected by barrole</h1>")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/protected-by-foorole-and-barrole")
|
||||||
|
@hasrole("barrole")
|
||||||
|
@hasrole("foorole")
|
||||||
|
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
|
||||||
|
return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/protected-by-foorole-or-barrole")
|
||||||
|
@hasrole(["foorole", "barrole"])
|
||||||
|
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
|
||||||
|
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
|
||||||
|
|
|
@ -169,9 +169,11 @@
|
||||||
</p>
|
</p>
|
||||||
<div id="links-to-check">
|
<div id="links-to-check">
|
||||||
<a href="public">Public</a>
|
<a href="public">Public</a>
|
||||||
<a href="protected">Access protected content</a>
|
<a href="protected">Auth protected content</a>
|
||||||
<a href="protected-by-foorole">Access + foorole protected content</a>
|
<a href="protected-by-foorole">Auth + foorole protected content</a>
|
||||||
<a href="protected-by-barrole">Access + barrole protected content</a>
|
<a href="protected-by-barrole">Auth + barrole protected content</a>
|
||||||
|
<a href="protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a>
|
||||||
|
<a href="protected-by-foorole-or-barrole">Auth + foorole or barrole protected content</a>
|
||||||
<a href="other">Other</a>
|
<a href="other">Other</a>
|
||||||
</div>
|
</div>
|
||||||
{% if user and settings.oidc.show_session_details %}
|
{% if user and settings.oidc.show_session_details %}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue