Refactor; add services in settings

This commit is contained in:
phil 2025-01-19 01:48:00 +01:00
parent 17fabd21c9
commit f14d8d3114
7 changed files with 272 additions and 224 deletions

View file

@ -10,6 +10,7 @@ from urllib.parse import urlencode
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from fastapi.security import OpenIdConnect
@ -23,7 +24,7 @@ from pkce import generate_code_verifier, generate_pkce_pair
from .settings import settings
from .models import User
from .auth_utils import (
get_provider,
get_oidc_provider,
hasrole,
get_current_user_or_none,
get_current_user,
@ -42,6 +43,9 @@ app = FastAPI(
title="OIDC auth test",
)
app.mount(
"/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static"
)
# SessionMiddleware is required by authlib
app.add_middleware(
@ -76,6 +80,27 @@ for provider in settings.oidc.providers:
_providers[provider.id] = provider
@app.get("/")
async def home(
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
) -> HTMLResponse:
now = datetime.now()
return templates.TemplateResponse(
name="home.html",
request=request,
context={
"settings": settings.model_dump(),
"user": user,
"now": now,
"user_info_details": (
pretty_details(user, now)
if user and settings.oidc.show_session_details
else None
),
},
)
# Endpoints for the login / authorization process
@ -169,13 +194,13 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
@app.get("/logout")
async def logout(
request: Request,
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
) -> RedirectResponse:
# Clear session
request.session.pop("user_sub", None)
# Get provider's endpoint
if (
provider_logout_uri := provider.server_metadata.get("end_session_endpoint")
provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")
) is None:
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
return RedirectResponse(request.url_for("non_compliant_logout"))
@ -200,7 +225,7 @@ async def logout(
@app.get("/non-compliant-logout")
async def non_compliant_logout(
request: Request,
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
):
"""A page for non-compliant OAuth2 servers that we cannot log out."""
return templates.TemplateResponse(
@ -210,28 +235,34 @@ async def non_compliant_logout(
)
# Home URL
# Route for OAuth resource server
@app.get("/")
async def home(
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
@app.get("/resource/{name}")
async def get_resource(
name: str,
request: Request,
user: Annotated[User, Depends(get_current_user)],
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
token: Annotated[OAuth2Token, Depends(get_token)],
) -> HTMLResponse:
now = datetime.now()
return templates.TemplateResponse(
name="home.html",
request=request,
context={
"settings": settings.model_dump(),
"user": user,
"now": now,
"user_info_details": (
pretty_details(user, now)
if user and settings.oidc.show_session_details
else None
),
},
)
"""Generic path for testing a resource provided by a provider"""
provider = _providers[oidc_provider.name]
if (
response := await oidc_provider.get(
"/api/v1/user/repos",
# headers={"Authorization": f"token {token['access_token']}"},
token=token,
)
).is_success:
repos = response.json()
names = [repo["name"] for repo in repos]
return HTMLResponse(f"{user.name} has {len(repos)} repos: {', '.join(names)}")
else:
raise HTTPException(status_code=response.status_code, detail=response.text)
# Routes for test
@app.get("/public")
@ -239,9 +270,6 @@ async def public() -> HTMLResponse:
return HTMLResponse("<h1>Not protected</h1>")
# Some URIs for testing the permissions
@app.get("/protected")
async def get_protected(
user: Annotated[User, Depends(get_current_user)]
@ -277,12 +305,12 @@ async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
@app.get("/introspect")
async def get_introspect(
request: Request,
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
token: Annotated[OAuth2Token, Depends(get_token)],
) -> JSONResponse:
if (
response := await provider.post(
provider.server_metadata["introspection_endpoint"],
response := await oidc_provider.post(
oidc_provider.server_metadata["introspection_endpoint"],
token=token,
data={"token": token["access_token"]},
)
@ -296,11 +324,11 @@ async def get_introspect(
async def get_forgejo_user_info(
request: Request,
user: Annotated[User, Depends(get_current_user)],
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
token: Annotated[OAuth2Token, Depends(get_token)],
) -> HTMLResponse:
if (
response := await provider.get(
response := await oidc_provider.get(
"/api/v1/user/repos",
# headers={"Authorization": f"token {token['access_token']}"},
token=token,
@ -313,11 +341,9 @@ async def get_forgejo_user_info(
raise HTTPException(status_code=response.status_code, detail=response.text)
# @app.get("/fast_api_depends")
# def fast_api_depends(
# token: Annotated[str, Depends(fastapi_providers["Keycloak"])]
# ) -> HTMLResponse:
# return HTMLResponse("You're Authenticated")
# Snippet for running standalone
# Mostly useful for the --version option,
# as running with uvicorn is easy and provides flaxibility
def main():