2025-01-02 11:23:53 +01:00
|
|
|
from typing import Annotated
|
2025-01-02 02:14:30 +01:00
|
|
|
|
|
|
|
from httpx import HTTPError
|
|
|
|
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
2025-01-03 17:00:38 +01:00
|
|
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
2025-01-02 02:14:30 +01:00
|
|
|
from fastapi.templating import Jinja2Templates
|
|
|
|
from starlette.middleware.sessions import SessionMiddleware
|
2025-01-02 04:04:45 +01:00
|
|
|
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
2025-01-02 02:14:30 +01:00
|
|
|
from authlib.integrations.starlette_client import OAuth, OAuthError
|
|
|
|
|
2025-01-02 11:23:53 +01:00
|
|
|
from .settings import settings
|
2025-01-02 02:14:30 +01:00
|
|
|
from .models import User
|
2025-01-03 17:00:38 +01:00
|
|
|
from .auth_utils import hasrole, get_current_user_or_none, get_current_user
|
2025-01-02 02:14:30 +01:00
|
|
|
|
|
|
|
# Jinja2 template loader for the HTML pages. The directory is a relative path,
# so it is resolved against the process working directory at render time.
templates = Jinja2Templates(directory="src/templates")
|
|
|
|
|
|
|
|
|
|
|
|
# The ASGI application; the title is what appears in the generated OpenAPI docs.
app = FastAPI(title="OIDC auth test")
|
|
|
|
|
2025-01-03 17:00:38 +01:00
|
|
|
# authlib's Starlette integration needs ``request.session`` during the login
# flow, and the routes below keep the logged-in user in it as well. The
# session cookie is signed with ``settings.secret_key``.
app.add_middleware(SessionMiddleware, secret_key=settings.secret_key)
|
|
|
|
|
2025-01-03 17:00:38 +01:00
|
|
|
# Register every OIDC provider listed in the settings with authlib. The
# /login and /auth/{provider} routes later look each one up by its name.
authlib_oauth = OAuth()
for provider in settings.oidc.providers:
    authlib_oauth.register(
        name=provider.name,
        client_id=provider.client_id,
        client_secret=provider.client_secret,
        # authlib reads the provider's endpoints from this metadata document.
        server_metadata_url=provider.provider_url,
        # NOTE(review): "roles" is presumably a provider-specific scope that
        # feeds the role checks in auth_utils.hasrole — confirm.
        client_kwargs={"scope": "openid email offline_access profile roles"},
    )
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/login")
async def login(request: Request, provider: str) -> RedirectResponse:
    """Start the OIDC login flow by redirecting the browser to ``provider``.

    Args:
        request: Incoming request; authlib uses its session during the flow.
        provider: Name of a provider registered from ``settings.oidc.providers``.

    Returns:
        A redirect to the provider's authorization endpoint. The callback
        URL handed to the provider is the ``auth`` route for the same
        provider.

    Raises:
        HTTPException: 404 if ``provider`` is not a registered provider;
            502 if the provider cannot be reached.
    """
    # create_client() returns None for unregistered names. getattr() would
    # also resolve OAuth's own attributes (e.g. ``?provider=register``) and
    # then crash with an unhandled AttributeError.
    provider_ = authlib_oauth.create_client(provider)
    if provider_ is None:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "No such provider")

    redirect_uri = request.url_for("auth", provider=provider)
    try:
        return await provider_.authorize_redirect(request, redirect_uri)
    except HTTPError as error:
        # A network failure towards the provider is an upstream problem,
        # not an authentication failure of the client.
        raise HTTPException(
            status.HTTP_502_BAD_GATEWAY, "Cannot reach provider"
        ) from error
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/auth/{provider}")
async def auth(request: Request, provider: str) -> RedirectResponse:
    """Finish the OIDC flow: exchange the callback for tokens and log in.

    On success the ``userinfo`` claims from the token response are stored
    in ``request.session["user"]`` and the browser is redirected to ``/``.

    Args:
        request: The callback request sent back by the provider.
        provider: Name of a provider registered from ``settings.oidc.providers``.

    Returns:
        A redirect to the home page.

    Raises:
        HTTPException: 404 if ``provider`` is not a registered provider;
            401 if authlib rejects the callback or the token response
            carries no ``userinfo``; 502 if the provider cannot be reached.
    """
    # create_client() returns None for unregistered names. getattr() would
    # also resolve OAuth's own attributes (e.g. ``/auth/register``) and then
    # crash with an unhandled AttributeError.
    provider_ = authlib_oauth.create_client(provider)
    if provider_ is None:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "No such provider")

    try:
        token = await provider_.authorize_access_token(request)
    except OAuthError as error:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail=error.error
        ) from error
    except HTTPError as error:
        # The token exchange goes over the network; failing to reach the
        # provider is an upstream problem, not a client one.
        raise HTTPException(
            status.HTTP_502_BAD_GATEWAY, "Cannot reach provider"
        ) from error

    user = token.get("userinfo")
    if not user:
        # Redirecting to /login here cannot work: that route requires a
        # ``provider`` query parameter, and restarting the flow could loop
        # if the provider never returns userinfo. Fail explicitly instead.
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, "Provider returned no user info"
        )

    request.session["user"] = dict(user)
    return RedirectResponse(url="/")
|
2025-01-02 02:14:30 +01:00
|
|
|
|
|
|
|
|
|
|
|
@app.get("/logout")
async def logout(request: Request) -> RedirectResponse:
    """Forget the logged-in user and send the browser back to ``/``.

    Only the ``"user"`` session entry is removed; other session entries are
    left untouched. Calling this while logged out is harmless.
    """
    session = request.session
    if "user" in session:
        del session["user"]
    return RedirectResponse(url="/")
|
|
|
|
|
|
|
|
|
2025-01-02 03:16:03 +01:00
|
|
|
@app.get("/")
async def home(
    request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
) -> HTMLResponse:
    """Render ``index.html`` for anyone, logged in or not.

    Template context:
        settings: The application settings without secrets (see below).
        user: Whatever ``get_current_user_or_none`` resolved (presumably
            ``None`` when nobody is logged in — see auth_utils).
        auth_data: The raw ``"user"`` entry of the session, if any.
    """
    # This page is public, so nothing secret may reach the template: drop the
    # session signing key and every provider's client secret from the dump.
    public_settings = settings.model_dump(
        exclude={
            "secret_key": True,
            "oidc": {"providers": {"__all__": {"client_secret"}}},
        }
    )
    return templates.TemplateResponse(
        request=request,
        name="index.html",
        context={
            "settings": public_settings,
            "user": user,
            "auth_data": request.session.get("user"),
        },
    )
|
|
|
|
|
|
|
|
|
2025-01-03 17:00:38 +01:00
|
|
|
@app.get("/public")
async def public() -> HTMLResponse:
    """Serve a page that requires no authentication at all."""
    page = "<h1>Not protected</h1>"
    return HTMLResponse(content=page)
|
|
|
|
|
|
|
|
|
2025-01-02 03:09:16 +01:00
|
|
|
@app.get("/protected")
async def get_protected(
    user: Annotated[User, Depends(get_current_user)]
) -> HTMLResponse:
    """Serve a page meant only for authenticated users.

    ``user`` is not read in the body. Declaring it makes FastAPI run
    ``get_current_user`` before this handler, which is what guards the route
    (presumably by raising when nobody is logged in — see auth_utils).
    """
    page = "<h1>Only authenticated users can see this</h1>"
    return HTMLResponse(content=page)
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/protected-by-foorole")
@hasrole("foorole")
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
    """Serve a page restricted by ``hasrole("foorole")``.

    ``request`` is unused here but kept in the signature; presumably the
    ``hasrole`` wrapper reads it — confirm in auth_utils.
    """
    page = "<h1>Only users with foorole can see this</h1>"
    return HTMLResponse(content=page)
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/protected-by-barrole")
@hasrole("barrole")
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
    """Serve a page restricted by ``hasrole("barrole")``.

    ``request`` is unused here but kept in the signature; presumably the
    ``hasrole`` wrapper reads it — confirm in auth_utils.
    """
    page = "<h1>Protected by barrole</h1>"
    return HTMLResponse(content=page)
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/protected-by-foorole-and-barrole")
@hasrole("barrole")
@hasrole("foorole")
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
    """Serve a page guarded by two stacked ``hasrole`` checks.

    Each decorator applies its own check, so both ``barrole`` and ``foorole``
    are required. ``request`` is unused in the body; presumably the wrappers
    read it — confirm in auth_utils.
    """
    page = "<h1>Only users with foorole and barrole can see this</h1>"
    return HTMLResponse(content=page)
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/protected-by-foorole-or-barrole")
@hasrole(["foorole", "barrole"])
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
    """Serve a page guarded by ``hasrole`` given a list of roles.

    The page text states that either role suffices; presumably ``hasrole``
    treats a list as "any of" — confirm in auth_utils. ``request`` is unused
    in the body but kept for the wrapper.
    """
    page = "<h1>Only users with foorole or barrole can see this</h1>"
    return HTMLResponse(content=page)
|