Refactor; add services in settings
This commit is contained in:
parent
17fabd21c9
commit
f14d8d3114
7 changed files with 272 additions and 224 deletions
|
@ -10,6 +10,7 @@ from urllib.parse import urlencode
|
|||
|
||||
from httpx import HTTPError
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.security import OpenIdConnect
|
||||
|
@ -23,7 +24,7 @@ from pkce import generate_code_verifier, generate_pkce_pair
|
|||
from .settings import settings
|
||||
from .models import User
|
||||
from .auth_utils import (
|
||||
get_provider,
|
||||
get_oidc_provider,
|
||||
hasrole,
|
||||
get_current_user_or_none,
|
||||
get_current_user,
|
||||
|
@ -42,6 +43,9 @@ app = FastAPI(
|
|||
title="OIDC auth test",
|
||||
)
|
||||
|
||||
app.mount(
|
||||
"/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static"
|
||||
)
|
||||
|
||||
# SessionMiddleware is required by authlib
|
||||
app.add_middleware(
|
||||
|
@ -76,6 +80,27 @@ for provider in settings.oidc.providers:
|
|||
_providers[provider.id] = provider
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def home(
|
||||
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
|
||||
) -> HTMLResponse:
|
||||
now = datetime.now()
|
||||
return templates.TemplateResponse(
|
||||
name="home.html",
|
||||
request=request,
|
||||
context={
|
||||
"settings": settings.model_dump(),
|
||||
"user": user,
|
||||
"now": now,
|
||||
"user_info_details": (
|
||||
pretty_details(user, now)
|
||||
if user and settings.oidc.show_session_details
|
||||
else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Endpoints for the login / authorization process
|
||||
|
||||
|
||||
|
@ -169,13 +194,13 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
|||
@app.get("/logout")
|
||||
async def logout(
|
||||
request: Request,
|
||||
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
) -> RedirectResponse:
|
||||
# Clear session
|
||||
request.session.pop("user_sub", None)
|
||||
# Get provider's endpoint
|
||||
if (
|
||||
provider_logout_uri := provider.server_metadata.get("end_session_endpoint")
|
||||
provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")
|
||||
) is None:
|
||||
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
|
||||
return RedirectResponse(request.url_for("non_compliant_logout"))
|
||||
|
@ -200,7 +225,7 @@ async def logout(
|
|||
@app.get("/non-compliant-logout")
|
||||
async def non_compliant_logout(
|
||||
request: Request,
|
||||
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
):
|
||||
"""A page for non-compliant OAuth2 servers that we cannot log out."""
|
||||
return templates.TemplateResponse(
|
||||
|
@ -210,28 +235,34 @@ async def non_compliant_logout(
|
|||
)
|
||||
|
||||
|
||||
# Home URL
|
||||
# Route for OAuth resource server
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def home(
|
||||
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
|
||||
@app.get("/resource/{name}")
|
||||
async def get_resource(
|
||||
name: str,
|
||||
request: Request,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
token: Annotated[OAuth2Token, Depends(get_token)],
|
||||
) -> HTMLResponse:
|
||||
now = datetime.now()
|
||||
return templates.TemplateResponse(
|
||||
name="home.html",
|
||||
request=request,
|
||||
context={
|
||||
"settings": settings.model_dump(),
|
||||
"user": user,
|
||||
"now": now,
|
||||
"user_info_details": (
|
||||
pretty_details(user, now)
|
||||
if user and settings.oidc.show_session_details
|
||||
else None
|
||||
),
|
||||
},
|
||||
)
|
||||
"""Generic path for testing a resource provided by a provider"""
|
||||
provider = _providers[oidc_provider.name]
|
||||
if (
|
||||
response := await oidc_provider.get(
|
||||
"/api/v1/user/repos",
|
||||
# headers={"Authorization": f"token {token['access_token']}"},
|
||||
token=token,
|
||||
)
|
||||
).is_success:
|
||||
repos = response.json()
|
||||
names = [repo["name"] for repo in repos]
|
||||
return HTMLResponse(f"{user.name} has {len(repos)} repos: {', '.join(names)}")
|
||||
else:
|
||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
|
||||
|
||||
# Routes for test
|
||||
|
||||
|
||||
@app.get("/public")
|
||||
|
@ -239,9 +270,6 @@ async def public() -> HTMLResponse:
|
|||
return HTMLResponse("<h1>Not protected</h1>")
|
||||
|
||||
|
||||
# Some URIs for testing the permissions
|
||||
|
||||
|
||||
@app.get("/protected")
|
||||
async def get_protected(
|
||||
user: Annotated[User, Depends(get_current_user)]
|
||||
|
@ -277,12 +305,12 @@ async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
|
|||
@app.get("/introspect")
|
||||
async def get_introspect(
|
||||
request: Request,
|
||||
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
token: Annotated[OAuth2Token, Depends(get_token)],
|
||||
) -> JSONResponse:
|
||||
if (
|
||||
response := await provider.post(
|
||||
provider.server_metadata["introspection_endpoint"],
|
||||
response := await oidc_provider.post(
|
||||
oidc_provider.server_metadata["introspection_endpoint"],
|
||||
token=token,
|
||||
data={"token": token["access_token"]},
|
||||
)
|
||||
|
@ -296,11 +324,11 @@ async def get_introspect(
|
|||
async def get_forgejo_user_info(
|
||||
request: Request,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
token: Annotated[OAuth2Token, Depends(get_token)],
|
||||
) -> HTMLResponse:
|
||||
if (
|
||||
response := await provider.get(
|
||||
response := await oidc_provider.get(
|
||||
"/api/v1/user/repos",
|
||||
# headers={"Authorization": f"token {token['access_token']}"},
|
||||
token=token,
|
||||
|
@ -313,11 +341,9 @@ async def get_forgejo_user_info(
|
|||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
|
||||
|
||||
# @app.get("/fast_api_depends")
|
||||
# def fast_api_depends(
|
||||
# token: Annotated[str, Depends(fastapi_providers["Keycloak"])]
|
||||
# ) -> HTMLResponse:
|
||||
# return HTMLResponse("You're Authenticated")
|
||||
# Snippet for running standalone
|
||||
# Mostly useful for the --version option,
|
||||
# as running with uvicorn is easy and provides flaxibility
|
||||
|
||||
|
||||
def main():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue