oidc-fastapi-test/src/oidc_test/main.py

351 lines
11 KiB
Python
Raw Normal View History

"""
Test of OpenId Connect & OAuth2 with FastAPI
"""
2025-01-02 11:23:53 +01:00
from typing import Annotated
from pathlib import Path
2025-01-09 23:41:32 +01:00
from datetime import datetime
import logging
from urllib.parse import urlencode
2025-01-02 02:14:30 +01:00
from httpx import HTTPError
2025-01-16 05:43:26 +01:00
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
2025-01-02 02:14:30 +01:00
from fastapi.templating import Jinja2Templates
2025-01-09 23:41:32 +01:00
from fastapi.security import OpenIdConnect
2025-01-02 02:14:30 +01:00
from starlette.middleware.sessions import SessionMiddleware
2025-01-02 04:04:45 +01:00
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.base_client import OAuthError
from authlib.integrations.httpx_client import AsyncOAuth2Client
from authlib.oauth2.rfc6749 import OAuth2Token
from pkce import generate_code_verifier, generate_pkce_pair
2025-01-09 23:41:32 +01:00
2025-01-02 11:23:53 +01:00
from .settings import settings
2025-01-02 02:14:30 +01:00
from .models import User
2025-01-09 23:41:32 +01:00
from .auth_utils import (
get_provider,
hasrole,
get_current_user_or_none,
get_current_user,
authlib_oauth,
get_token,
2025-01-09 23:41:32 +01:00
)
from .auth_misc import pretty_details
from .database import db
2025-01-02 02:14:30 +01:00
# Log through uvicorn's "error" logger so messages end up in the server output
logger = logging.getLogger("uvicorn.error")

# Jinja2 templates are looked up in the "templates" directory next to this file
templates = Jinja2Templates(directory=Path(__file__).with_name("templates"))

app = FastAPI(title="OIDC auth test")

# SessionMiddleware is required by authlib; the endpoints below also keep
# the login data (provider id, user sub, token) in request.session.
app.add_middleware(SessionMiddleware, secret_key=settings.secret_key)
# Add oidc providers to authlib from the settings, and keep lookup tables
# keyed by provider id:
# - fastapi_providers: fastapi OpenIdConnect security schemes (only used by
#   commented-out code for now)
# - _providers: the provider entries from the settings
fastapi_providers = {}
_providers = {}
for provider in settings.oidc.providers:
    _client_kwargs = {"scope": "openid email"}  # offline_access profile
    # PKCE: authlib only sends a code challenge (and keeps the matching code
    # verifier for the token request) when code_challenge_method is part of
    # the client kwargs, so forward the provider setting when it is set.
    # Without this, the setting read by the login endpoint had no effect.
    if provider.code_challenge_method is not None:
        _client_kwargs["code_challenge_method"] = provider.code_challenge_method
    authlib_oauth.register(
        name=provider.id,
        server_metadata_url=provider.openid_configuration,
        client_kwargs=_client_kwargs,
        client_id=provider.client_id,
        client_secret=provider.client_secret,
        api_base_url=provider.url,
        # fetch_token=fetch_token,
        # update_token=update_token,
        # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
    )
    fastapi_providers[provider.id] = OpenIdConnect(
        openIdConnectUrl=provider.openid_configuration
    )
    _providers[provider.id] = provider
# Endpoints for the login / authorization process


@app.get("/login/{oidc_provider_id}")
async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
    """Log in with the given provider by redirecting the browser to its authorize page.

    The provider is expected to send the browser back to our own
    /auth/{oidc_provider_id} url with the token.

    Raises:
        HTTPException(401): the provider id is unknown, or the provider
            cannot be reached.
    """
    # Only accept ids registered from the settings: a bare getattr on the
    # authlib registry would also resolve its own attributes (e.g. "register")
    # and then fail with a KeyError on _providers below.
    if oidc_provider_id not in _providers:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
    try:
        provider_: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
    except AttributeError:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
    redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id)
    if _providers[oidc_provider_id].code_challenge_method is not None:
        # PKCE: generate a code verifier and hand it to authorize_redirect
        code_verifier = generate_code_verifier()
        logger.debug("TODO: PKCE")
    else:
        code_verifier = None
    try:
        return await provider_.authorize_redirect(
            request,
            redirect_uri,
            access_type="offline",
            code_verifier=code_verifier,
        )
    except HTTPError:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
@app.get("/auth/{oidc_provider_id}")
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
    """Handle the provider's redirect back after login.

    Gets the token from the authorization response with authlib, stores the
    user in the session (cookie based) and in the database, and responds to
    the browser with a redirect to the home page. When the token carries no
    userinfo, the browser is sent back to the login endpoint instead.

    Raises:
        HTTPException(401): unknown provider id, or authlib reported an
            OAuth error while getting the token.
    """
    # Only accept ids registered from the settings (see the login endpoint)
    if oidc_provider_id not in _providers:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
    try:
        oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
    except AttributeError:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
    try:
        token: OAuth2Token = await oidc_provider.authorize_access_token(request)
    except OAuthError as error:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
    # Remember the oidc_provider in the session
    request.session["oidc_provider_id"] = oidc_provider_id
    #
    # One could process the full decoded token which contains extra information
    # eg for updates. Here we only use its userinfo and id_token.
    #
    userinfo = token.get("userinfo")
    if not userinfo:
        # Not sure if it's correct to redirect to plain login
        # if no userinfo is provided
        return RedirectResponse(
            url=request.url_for("login", oidc_provider_id=oidc_provider_id)
        )
    # User id (sub) given by oidc provider
    sub = userinfo["sub"]
    # Get additional data from the userinfo endpoint. This is best-effort:
    # any failure is logged and the login proceeds with empty extra data.
    try:
        user_info_from_endpoint = await oidc_provider.userinfo(
            token=token, follow_redirects=True
        )
    except Exception as err:
        logger.warning("Cannot get userinfo from endpoint: %s", err)
        user_info_from_endpoint = {}
    # Remember the user in the session
    request.session["user_sub"] = sub
    # Store the user in the database
    user = await db.add_user(
        sub,
        user_info=userinfo,
        oidc_provider=oidc_provider,
        user_info_from_endpoint=user_info_from_endpoint,
    )
    # Add the id_token to the session
    request.session["token"] = token["id_token"]
    # Add the token to the db because it is used for logout
    await db.add_token(token, user)
    # Send the user to the home: (s)he is authenticated
    return RedirectResponse(url=request.url_for("home"))
@app.get("/logout")
async def logout(
    request: Request,
    provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
) -> RedirectResponse:
    """Log the user out locally, then redirect to the provider's end_session_endpoint.

    The redirect carries post_logout_redirect_uri (our home page),
    id_token_hint (the id_token stored in the db) and client_id.

    Falls back to:
    - the "non_compliant_logout" page when the provider metadata has no
      end_session_endpoint;
    - the home page when the session token is not found in the database.
    """
    # Clear session
    request.session.pop("user_sub", None)
    # The provider metadata is fetched lazily by authlib and may still be
    # empty here (e.g. after a server restart while the session cookie is
    # still valid); without it the provider would wrongly look non-compliant.
    try:
        await provider.load_server_metadata()
    except HTTPError as err:
        logger.warning(f"Cannot load metadata of provider {provider.name}: {err}")
    # Get provider's endpoint
    provider_logout_uri = provider.server_metadata.get("end_session_endpoint")
    if provider_logout_uri is None:
        logger.warning(
            f"Cannot find end_session_endpoint for provider {provider.name}"
        )
        return RedirectResponse(request.url_for("non_compliant_logout"))
    post_logout_uri = request.url_for("home")
    if (token := await db.get_token(request.session.pop("token", None))) is None:
        logger.warning("No session in db for the token")
        return RedirectResponse(request.url_for("home"))
    query = urlencode(
        {
            "post_logout_redirect_uri": post_logout_uri,
            "id_token_hint": token["id_token"],
            # Was misspelled "cliend_id" with a hard-coded value: send the
            # client id this provider was registered with.
            "client_id": provider.client_id,
        }
    )
    # The endpoint may already carry a query string
    separator = "&" if "?" in provider_logout_uri else "?"
    return RedirectResponse(f"{provider_logout_uri}{separator}{query}")
@app.get("/non-compliant-logout")
async def non_compliant_logout(
    request: Request,
    provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
):
    """A page for non-compliant OAuth2 servers that we cannot log out.

    The logout endpoint redirects here when the provider has no
    end_session_endpoint in its metadata.
    """
    template_context = {
        "provider": provider,
        "home_url": request.url_for("home"),
    }
    return templates.TemplateResponse(
        request=request,
        name="non_compliant_logout.html",
        context=template_context,
    )
# Home URL


@app.get("/")
async def home(
    request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
) -> HTMLResponse:
    """Render the home page for anonymous and logged-in users alike.

    Session details are only computed when a user is logged in and the
    ``settings.oidc.show_session_details`` flag is set.
    """
    now = datetime.now()
    if user and settings.oidc.show_session_details:
        details = pretty_details(user, now)
    else:
        details = None
    page_context = {
        "settings": settings.model_dump(),
        "user": user,
        "now": now,
        "user_info_details": details,
    }
    return templates.TemplateResponse(
        name="home.html", request=request, context=page_context
    )
@app.get("/public")
async def public() -> HTMLResponse:
    """A page that requires no authentication at all."""
    page = "<h1>Not protected</h1>"
    return HTMLResponse(page)
# Some URIs for testing the permissions


@app.get("/protected")
async def get_protected(
    user: Annotated[User, Depends(get_current_user)]
) -> HTMLResponse:
    """A page for authenticated users only.

    ``user`` is not used in the body: the access check is left to the
    get_current_user dependency (presumably it rejects anonymous requests).
    """
    page = "<h1>Only authenticated users can see this</h1>"
    return HTMLResponse(page)
@app.get("/protected-by-foorole")
@hasrole("foorole")
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
    """A page guarded by the hasrole("foorole") decorator.

    ``request`` is unused here; presumably the decorator needs it.
    """
    label = "Only users with foorole can see this"
    return HTMLResponse(f"<h1>{label}</h1>")
@app.get("/protected-by-barrole")
@hasrole("barrole")
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
    """A page guarded by the hasrole("barrole") decorator.

    ``request`` is unused here; presumably the decorator needs it.
    """
    label = "Protected by barrole"
    return HTMLResponse(f"<h1>{label}</h1>")
@app.get("/protected-by-foorole-and-barrole")
@hasrole("barrole")
@hasrole("foorole")
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
    """A page guarded by two stacked hasrole decorators (barrole and foorole).

    Stacking presumably requires both roles; ``request`` is unused here and
    presumably needed by the decorators.
    """
    label = "Only users with foorole and barrole can see this"
    return HTMLResponse(f"<h1>{label}</h1>")
@app.get("/protected-by-foorole-or-barrole")
@hasrole(["foorole", "barrole"])
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
    """A page guarded by hasrole with a list of roles (foorole, barrole).

    A list presumably means any one of the roles is enough; ``request`` is
    unused here and presumably needed by the decorator.
    """
    label = "Only users with foorole or barrole can see this"
    return HTMLResponse(f"<h1>{label}</h1>")
@app.get("/introspect")
async def get_introspect(
    request: Request,
    provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
    token: Annotated[OAuth2Token, Depends(get_token)],
) -> JSONResponse:
    """Post the current access token to the provider's introspection endpoint.

    Returns the provider's JSON answer.

    Raises:
        HTTPException(501): the provider metadata has no introspection_endpoint
            (previously a KeyError, i.e. a 500).
        HTTPException(<provider status>): the provider rejected the request;
            its status code and body are forwarded.
    """
    # authlib fetches the metadata lazily: make sure it is loaded before
    # reading the endpoint from it.
    await provider.load_server_metadata()
    introspection_endpoint = provider.server_metadata.get("introspection_endpoint")
    if introspection_endpoint is None:
        raise HTTPException(
            status.HTTP_501_NOT_IMPLEMENTED,
            "Provider has no introspection endpoint",
        )
    response = await provider.post(
        introspection_endpoint,
        token=token,
        data={"token": token["access_token"]},
    )
    if not response.is_success:
        raise HTTPException(status_code=response.status_code, detail=response.text)
    return JSONResponse(response.json())
@app.get("/oauth2-forgejo-test")
async def get_forgejo_user_info(
    request: Request,
    user: Annotated[User, Depends(get_current_user)],
    provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
    token: Annotated[OAuth2Token, Depends(get_token)],
) -> HTMLResponse:
    """List the current user's repositories with a Forgejo API call.

    The relative URL is requested through the provider app, which was
    registered with api_base_url=provider.url (presumably the Forgejo server).

    Raises:
        HTTPException(<provider status>): the API call failed; its status
            code and body are forwarded.
    """
    import html

    response = await provider.get(
        "/api/v1/user/repos",
        # headers={"Authorization": f"token {token['access_token']}"},
        token=token,
    )
    if not response.is_success:
        raise HTTPException(status_code=response.status_code, detail=response.text)
    repos = response.json()
    # User name and repo names come from outside: escape them before
    # embedding them in HTML.
    names = ", ".join(html.escape(repo["name"]) for repo in repos)
    user_name = html.escape(str(user.name))
    return HTMLResponse(f"{user_name} has {len(repos)} repos: {names}")
# @app.get("/fast_api_depends")
# def fast_api_depends(
# token: Annotated[str, Depends(fastapi_providers["Keycloak"])]
# ) -> HTMLResponse:
# return HTMLResponse("You're Authenticated")
def main():
    """Command line entry point: parse the options, then serve the app with uvicorn.

    With -v/--version, print the installed package version and exit instead.
    """
    from uvicorn import run
    from argparse import ArgumentParser

    parser = ArgumentParser(description=__doc__)
    option_specs = (
        (
            ("-l", "--host"),
            dict(
                type=str,
                default="0.0.0.0",
                help="Address to listen to (default: 0.0.0.0)",
            ),
        ),
        (
            ("-p", "--port"),
            dict(type=int, default=80, help="Port to listen to (default: 80)"),
        ),
        (
            ("-v", "--version"),
            dict(action="store_true", help="Print version and exit"),
        ),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    cli = parser.parse_args()
    if cli.version:
        import sys
        from importlib.metadata import version

        print(version("oidc-fastapi-test"))
        sys.exit(0)
    run(app, host=cli.host, port=cli.port)