oidc-fastapi-test/src/oidc-test/main.py

128 lines
4 KiB
Python
Raw Normal View History

2025-01-02 11:23:53 +01:00
from typing import Annotated
2025-01-02 02:14:30 +01:00
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.responses import HTMLResponse, RedirectResponse
2025-01-02 02:14:30 +01:00
from fastapi.templating import Jinja2Templates
from starlette.middleware.sessions import SessionMiddleware
2025-01-02 04:04:45 +01:00
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
2025-01-02 02:14:30 +01:00
from authlib.integrations.starlette_client import OAuth, OAuthError
2025-01-02 11:23:53 +01:00
from .settings import settings
2025-01-02 02:14:30 +01:00
from .models import User
from .auth_utils import hasrole, get_current_user_or_none, get_current_user
2025-01-02 02:14:30 +01:00
# Template loader for the HTML pages rendered by the routes below.
templates = Jinja2Templates("src/templates")

# The FastAPI application; every route in this module is attached to it.
app = FastAPI(title="OIDC auth test")
# SessionMiddleware is required by authlib; the routes below also keep the
# logged-in user's data in ``request.session``.
app.add_middleware(SessionMiddleware, secret_key=settings.secret_key)
# Registry of OIDC clients: one entry per provider declared in the settings.
authlib_oauth = OAuth()
for provider in settings.oidc.providers:
    authlib_oauth.register(
        name=provider.name,
        client_id=provider.client_id,
        client_secret=provider.client_secret,
        # The provider URL is handed to authlib as the server metadata URL.
        server_metadata_url=provider.provider_url,
        client_kwargs={"scope": "openid email offline_access profile roles"},
        # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
    )
@app.get("/login")
async def login(request: Request, provider: str) -> RedirectResponse:
    """Start the login flow by redirecting the browser to the chosen provider.

    Args:
        request: Incoming request; used to build the callback URL of the
            ``auth`` route and passed on to authlib.
        provider: Name of a provider registered in ``authlib_oauth``
            (required query parameter).

    Returns:
        The response produced by authlib's ``authorize_redirect``.

    Raises:
        HTTPException: 404 when ``provider`` is not a registered provider,
            401 when an HTTP error occurs during ``authorize_redirect``.
    """
    # ``provider`` is user input: look it up in authlib's registry instead of
    # using getattr(), which would also hand back ordinary attributes of the
    # OAuth object (e.g. "register") and fail later with an unhandled error.
    provider_: StarletteOAuth2App | None = authlib_oauth.create_client(provider)
    if provider_ is None:
        # An unknown provider name is a client mistake, not a server failure.
        raise HTTPException(status.HTTP_404_NOT_FOUND, "No such provider")

    redirect_uri = request.url_for("auth", provider=provider)
    try:
        return await provider_.authorize_redirect(request, redirect_uri)
    except HTTPError as err:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, "Cannot reach provider"
        ) from err
@app.get("/auth/{provider}")
async def auth(request: Request, provider: str) -> RedirectResponse:
    """Handle the provider callback and store the user in the session.

    Args:
        request: Callback request coming back from the provider.
        provider: Name of a provider registered in ``authlib_oauth``
            (path parameter).

    Returns:
        A redirect to ``/`` once the user info has been saved in the session,
        or a redirect to ``/login`` when the token carries no ``userinfo``.

    Raises:
        HTTPException: 404 when ``provider`` is not a registered provider,
            401 when authlib raises ``OAuthError`` while fetching the token.
    """
    # ``provider`` is user input: look it up in authlib's registry instead of
    # using getattr(), which would also hand back ordinary attributes of the
    # OAuth object (e.g. "register") and fail later with an unhandled error.
    provider_: StarletteOAuth2App | None = authlib_oauth.create_client(provider)
    if provider_ is None:
        # An unknown provider name is a client mistake, not a server failure.
        raise HTTPException(status.HTTP_404_NOT_FOUND, "No such provider")

    try:
        token = await provider_.authorize_access_token(request)
    except OAuthError as error:
        raise HTTPException(status_code=401, detail=error.error) from error

    user = token.get("userinfo")
    if not user:
        # NOTE(review): the "/login" route declares a required ``provider``
        # query parameter, so this redirect presumably ends in a validation
        # error -- confirm the intended target.
        return RedirectResponse(url="/login")

    request.session["user"] = dict(user)
    return RedirectResponse(url="/")
2025-01-02 02:14:30 +01:00
@app.get("/logout")
async def logout(request: Request) -> RedirectResponse:
    """Drop the stored user from the session and redirect to the home page."""
    session = request.session
    # A default is supplied so that logging out twice does not raise.
    session.pop("user", None)
    return RedirectResponse(url="/")
2025-01-02 03:16:03 +01:00
@app.get("/")
async def home(
    request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
) -> HTMLResponse:
    """Render ``index.html`` with the settings and the current user's data.

    Args:
        request: Incoming request; its session is read for the stored user data.
        user: Value resolved by the ``get_current_user_or_none`` dependency
            (presumably ``None`` for anonymous visitors -- see auth_utils).

    Returns:
        The rendered ``index.html`` template response.
    """
    page_context = {
        "settings": settings.model_dump(),
        "user": user,
        # Raw user info as stored in the session by the ``auth`` route.
        "auth_data": request.session.get("user"),
    }
    return templates.TemplateResponse(
        request=request,
        name="index.html",
        context=page_context,
    )
@app.get("/public")
async def public() -> HTMLResponse:
    """Serve a static page; this route declares no authentication dependency."""
    body = "<h1>Not protected</h1>"
    return HTMLResponse(body)
2025-01-02 03:09:16 +01:00
@app.get("/protected")
async def get_protected(
    user: Annotated[User, Depends(get_current_user)]
) -> HTMLResponse:
    """Serve a page guarded by the ``get_current_user`` dependency.

    ``user`` is not read in the body; it is declared so that the dependency
    is resolved for this route (presumably rejecting anonymous requests --
    see auth_utils).
    """
    body = "<h1>Only authenticated users can see this</h1>"
    return HTMLResponse(body)
@app.get("/protected-by-foorole")
@hasrole("foorole")
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
    """Serve a page wrapped by ``hasrole("foorole")``.

    ``request`` is unused in the body (presumably it is what the ``hasrole``
    decorator inspects -- see auth_utils).
    """
    body = "<h1>Only users with foorole can see this</h1>"
    return HTMLResponse(body)
@app.get("/protected-by-barrole")
@hasrole("barrole")
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
    """Serve a page wrapped by ``hasrole("barrole")``.

    ``request`` is unused in the body (presumably it is what the ``hasrole``
    decorator inspects -- see auth_utils).
    """
    body = "<h1>Protected by barrole</h1>"
    return HTMLResponse(body)
@app.get("/protected-by-foorole-and-barrole")
@hasrole("barrole")
@hasrole("foorole")
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
    """Serve a page wrapped by two stacked ``hasrole`` decorators.

    Both ``hasrole("foorole")`` and ``hasrole("barrole")`` are applied, so
    presumably both roles are required -- see auth_utils. ``request`` is
    unused in the body.
    """
    body = "<h1>Only users with foorole and barrole can see this</h1>"
    return HTMLResponse(body)
@app.get("/protected-by-foorole-or-barrole")
@hasrole(["foorole", "barrole"])
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
    """Serve a page wrapped by ``hasrole`` called with a list of roles.

    Presumably any one of the listed roles is sufficient, as the response
    text suggests -- see auth_utils. ``request`` is unused in the body.
    """
    body = "<h1>Only users with foorole or barrole can see this</h1>"
    return HTMLResponse(body)