oidc-fastapi-test/src/oidc_test/main.py
phil b96bfa870a
Some checks failed
/ build (push) Failing after 14s
/ test (push) Successful in 4s
Fix token introspection link (should be 401)
2025-01-18 14:23:01 +01:00

350 lines
11 KiB
Python

"""
Test of OpenId Connect & OAuth2 with FastAPI
"""
from typing import Annotated
from pathlib import Path
from datetime import datetime
import logging
from urllib.parse import urlencode
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from fastapi.security import OpenIdConnect
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.base_client import OAuthError
from authlib.integrations.httpx_client import AsyncOAuth2Client
from authlib.oauth2.rfc6749 import OAuth2Token
from pkce import generate_code_verifier, generate_pkce_pair
from .settings import settings
from .models import User
from .auth_utils import (
get_provider,
hasrole,
get_current_user_or_none,
get_current_user,
authlib_oauth,
get_token,
)
from .auth_misc import pretty_details
from .database import db
# Log through uvicorn's "uvicorn.error" logger (presumably so that messages
# appear alongside the server's own output).
logger = logging.getLogger("uvicorn.error")

# Jinja2 templates are read from the "templates" directory next to this module
templates = Jinja2Templates(directory=Path(__file__).parent / "templates")

app = FastAPI(title="OIDC auth test")

# Cookie-based sessions: required by authlib, and also used by this module to
# keep oidc_provider_id, user_sub and the id_token between requests.
app.add_middleware(SessionMiddleware, secret_key=settings.secret_key)
# Register every OIDC provider declared in the settings in three places:
# - the authlib registry (used by the login / callback endpoints),
# - `fastapi_providers`: a FastAPI OpenIdConnect security scheme per provider,
# - `_providers`: the provider settings themselves, keyed by provider id.
fastapi_providers = {}
_providers = {}
for provider in settings.oidc.providers:
    authlib_oauth.register(
        name=provider.id,
        client_id=provider.client_id,
        client_secret=provider.client_secret,
        server_metadata_url=provider.openid_configuration,
        api_base_url=provider.url,
        # Requested scopes ("offline_access profile" not requested for now)
        client_kwargs={"scope": "openid email"},
        # PKCE is not implemented yet; it would need e.g.
        # code_challenge_method="S256", fetch_token=..., update_token=...
    )
    fastapi_providers[provider.id] = OpenIdConnect(
        openIdConnectUrl=provider.openid_configuration
    )
    _providers[provider.id] = provider
# Endpoints for the login / authorization process
@app.get("/login/{oidc_provider_id}")
async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
    """Login with the provider id, giving the browser a redirect to its authorize page.

    The provider is expected to send the browser back to our own
    /auth/{oidc_provider_id} url (the ``auth`` endpoint) with the token.

    Raises:
        HTTPException(401): the provider id is unknown, or the provider
            cannot be reached while building the redirect.
    """
    # Validate against the configured providers first: a plain getattr on the
    # authlib registry also "succeeds" for the registry's own attributes
    # (e.g. "register"), which then crashed below with a KeyError (500).
    if oidc_provider_id not in _providers:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
    try:
        provider_: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
    except AttributeError:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
    redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id)
    code_verifier = None
    if _providers[oidc_provider_id].code_challenge_method is not None:
        # TODO: PKCE is not implemented yet. A verifier is generated and
        # forwarded to authlib.
        # NOTE(review): the authlib client is registered without
        # code_challenge_method, so confirm whether the verifier has any effect.
        code_verifier = generate_code_verifier()
        logger.debug("TODO: PKCE")
    try:
        return await provider_.authorize_redirect(
            request,
            redirect_uri,
            access_type="offline",
            code_verifier=code_verifier,
        )
    except HTTPError:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
@app.get("/auth/{oidc_provider_id}")
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
    """Handle the provider's callback once the user has authorized us.

    The authorization response is exchanged for a token. The provider id,
    the user's ``sub`` and the id_token are stored in the (cookie based)
    session, and the user and token are stored in the database. The browser
    is then redirected to the "welcome user" (home) page.

    If the token carries no "userinfo", the browser is sent back to the
    login endpoint of the same provider.

    Raises:
        HTTPException(401): the provider id is unknown, or authlib rejected
            the authorization response (the OAuth error code is the detail).
    """
    # Same guard as in `login`: only configured provider ids are accepted.
    if oidc_provider_id not in _providers:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
    try:
        oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
    except AttributeError:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
    try:
        token: OAuth2Token = await oidc_provider.authorize_access_token(request)
    except OAuthError as error:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail=error.error
        ) from error
    # Remember the oidc_provider in the session
    request.session["oidc_provider_id"] = oidc_provider_id
    # The full token could be processed for extra information (eg for
    # updates); only its "userinfo" entry (presumably the decoded id_token
    # claims set by authlib) is used here.
    userinfo = token.get("userinfo")
    if not userinfo:
        # NOTE(review): not sure redirecting to plain login is correct when no
        # userinfo is provided; it may loop if the provider never sends one.
        return RedirectResponse(
            url=request.url_for("login", oidc_provider_id=oidc_provider_id)
        )
    # User id (sub) given by the oidc provider
    sub = userinfo["sub"]
    # Get additional data from the userinfo endpoint. This is best-effort:
    # any failure is logged and an empty dict is used instead.
    try:
        user_info_from_endpoint = await oidc_provider.userinfo(
            token=token, follow_redirects=True
        )
    except Exception as err:
        logger.warning("Cannot get userinfo from endpoint: %s", err)
        user_info_from_endpoint = {}
    # Remember the authenticated user in the session
    request.session["user_sub"] = sub
    # Store the user in the database
    user = await db.add_user(
        sub,
        user_info=userinfo,
        oidc_provider=oidc_provider,
        user_info_from_endpoint=user_info_from_endpoint,
    )
    # Add the id_token to the session
    request.session["token"] = token["id_token"]
    # Add the token to the db because it is used for logout
    await db.add_token(token, user)
    # Send the user to the home: (s)he is authenticated
    return RedirectResponse(url=request.url_for("home"))
@app.get("/logout")
async def logout(
    request: Request,
    provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
) -> RedirectResponse:
    """Log the user out locally, then at the provider (RP-initiated logout).

    The user's ``sub`` is always removed from the session. What happens next
    depends on the provider:

    - If it advertises an ``end_session_endpoint``, the browser is
      redirected there. The redirect carries the stored id_token as
      ``id_token_hint``, our client id, and our home page as
      ``post_logout_redirect_uri``.
    - Otherwise the browser goes to the "non-compliant logout" page.
    - If no token is found in the database for the session, the browser
      goes straight back home.
    """
    # Clear session
    request.session.pop("user_sub", None)
    # Make sure the provider metadata is loaded. Cookie sessions survive a
    # server restart, and before any login the cached metadata is empty, which
    # previously made every provider look non-compliant.
    try:
        metadata = await provider.load_server_metadata()
    except HTTPError as err:
        logger.warning(
            "Cannot load metadata for provider %s: %s", provider.name, err
        )
        metadata = provider.server_metadata
    # Get provider's endpoint
    if (provider_logout_uri := metadata.get("end_session_endpoint")) is None:
        logger.warning(
            f"Cannot find end_session_endpoint for provider {provider.name}"
        )
        return RedirectResponse(request.url_for("non_compliant_logout"))
    post_logout_uri = request.url_for("home")
    if (token := await db.get_token(request.session.pop("token", None))) is None:
        logger.warning("No session in db for the token")
        return RedirectResponse(request.url_for("home"))
    query = urlencode(
        {
            "post_logout_redirect_uri": post_logout_uri,
            "id_token_hint": token["id_token"],
            # Was a misspelled "cliend_id" with a hard-coded value; the
            # parameter is "client_id" and must be this provider's client id.
            "client_id": provider.client_id,
        }
    )
    # The endpoint URL may already carry a query string
    separator = "&" if "?" in provider_logout_uri else "?"
    return RedirectResponse(provider_logout_uri + separator + query)
@app.get("/non-compliant-logout")
async def non_compliant_logout(
    request: Request,
    provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
):
    """A page for non-compliant OAuth2 servers that we cannot log out.

    /logout redirects here when the provider metadata has no
    ``end_session_endpoint``. The template receives the provider and a link
    back to the home page.
    """
    page_context = {
        "provider": provider,
        "home_url": request.url_for("home"),
    }
    return templates.TemplateResponse(
        request=request,
        name="non_compliant_logout.html",
        context=page_context,
    )
# Home URL
@app.get("/")
async def home(
    request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
) -> HTMLResponse:
    """Render the home page for both anonymous and logged-in visitors.

    Session details are only computed when a user is logged in AND
    ``settings.oidc.show_session_details`` is enabled; otherwise the template
    gets ``None`` for them.
    """
    now = datetime.now()
    details = None
    if user and settings.oidc.show_session_details:
        details = pretty_details(user, now)
    page_context = {
        "settings": settings.model_dump(),
        "user": user,
        "now": now,
        "user_info_details": details,
    }
    return templates.TemplateResponse(
        request=request, name="home.html", context=page_context
    )
@app.get("/public")
async def public() -> HTMLResponse:
    """Page reachable without any authentication."""
    return HTMLResponse(content="<h1>Not protected</h1>")
# Some URIs for testing the permissions
@app.get("/protected")
async def get_protected(
    user: Annotated[User, Depends(get_current_user)]
) -> HTMLResponse:
    """Page reserved to authenticated users.

    Protection comes from the ``get_current_user`` dependency (presumably it
    rejects anonymous requests; see auth_utils). The user itself is unused.
    """
    return HTMLResponse(content="<h1>Only authenticated users can see this</h1>")
@app.get("/protected-by-foorole")
@hasrole("foorole")
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
    """Test page guarded by ``hasrole("foorole")``."""
    return HTMLResponse(content="<h1>Only users with foorole can see this</h1>")
@app.get("/protected-by-barrole")
@hasrole("barrole")
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
    """Test page guarded by ``hasrole("barrole")``."""
    return HTMLResponse(content="<h1>Protected by barrole</h1>")
@app.get("/protected-by-foorole-and-barrole")
@hasrole("barrole")
@hasrole("foorole")
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
    """Test page guarded by two stacked ``hasrole`` decorators.

    Presumably each decorator checks its own role, so both are required.
    """
    return HTMLResponse(
        content="<h1>Only users with foorole and barrole can see this</h1>"
    )
@app.get("/protected-by-foorole-or-barrole")
@hasrole(["foorole", "barrole"])
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
    """Test page guarded by ``hasrole`` given a list of roles.

    Presumably any one of the listed roles grants access.
    """
    return HTMLResponse(
        content="<h1>Only users with foorole or barrole can see this</h1>"
    )
@app.get("/introspect")
async def get_introspect(
    request: Request,
    provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
    token: Annotated[OAuth2Token, Depends(get_token)],
) -> JSONResponse:
    """Ask the provider to introspect the current access token.

    Returns the provider's JSON answer unchanged.

    Raises:
        HTTPException(501): the provider advertises no introspection_endpoint.
        HTTPException(401): the provider cannot be reached.
        HTTPException(<provider status>): the provider answered with an
            error; its response body is used as the detail.
    """
    try:
        # Load (and cache) the provider metadata if needed: indexing the
        # possibly still empty `server_metadata` previously raised a KeyError.
        metadata = await provider.load_server_metadata()
    except HTTPError:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
    if (introspection_endpoint := metadata.get("introspection_endpoint")) is None:
        raise HTTPException(
            status.HTTP_501_NOT_IMPLEMENTED,
            "Provider does not support token introspection",
        )
    try:
        response = await provider.post(
            introspection_endpoint,
            token=token,
            data={"token": token["access_token"]},
        )
    except HTTPError:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
    if not response.is_success:
        raise HTTPException(status_code=response.status_code, detail=response.text)
    return response.json()
@app.get("/oauth2-forgejo-test")
async def get_forgejo_user_info(
    request: Request,
    user: Annotated[User, Depends(get_current_user)],
    provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
    token: Annotated[OAuth2Token, Depends(get_token)],
) -> HTMLResponse:
    """List the logged-in user's repositories through the provider API.

    The path is relative; presumably authlib resolves it against the
    provider's ``api_base_url``.

    Raises:
        HTTPException(<provider status>): the API call failed; the response
            body is used as the detail.
    """
    response = await provider.get("/api/v1/user/repos", token=token)
    if not response.is_success:
        raise HTTPException(status_code=response.status_code, detail=response.text)
    repos = response.json()
    repo_names = ", ".join(repo["name"] for repo in repos)
    return HTMLResponse(f"{user.name} has {len(repos)} repos: {repo_names}")
# @app.get("/fast_api_depends")
# def fast_api_depends(
# token: Annotated[str, Depends(fastapi_providers["Keycloak"])]
# ) -> HTMLResponse:
# return HTMLResponse("You're Authenticated")
def main():
    """Command line entry point: serve the app with uvicorn.

    Options are -l/--host (default 0.0.0.0) and -p/--port (default 80).
    -v/--version prints the installed package version and exits.
    """
    from argparse import ArgumentParser
    from uvicorn import run

    # `__doc__` here resolves to the module docstring (a global), not this one
    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "-l", "--host", type=str, default="0.0.0.0",
        help="Address to listen to (default: 0.0.0.0)",
    )
    parser.add_argument(
        "-p", "--port", type=int, default=80,
        help="Port to listen to (default: 80)",
    )
    parser.add_argument(
        "-v", "--version", action="store_true",
        help="Print version and exit",
    )
    options = parser.parse_args()

    if options.version:
        import sys
        from importlib.metadata import version

        print(version("oidc-fastapi-test"))
        sys.exit(0)
    else:
        run(app, host=options.host, port=options.port)