2025-01-10 17:33:10 +01:00
|
|
|
"""
|
|
|
|
Test of OpenId Connect & OAuth2 with FastAPI
|
|
|
|
"""
|
|
|
|
|
2025-01-02 11:23:53 +01:00
|
|
|
from typing import Annotated
|
2025-01-10 17:33:10 +01:00
|
|
|
from pathlib import Path
|
2025-01-09 23:41:32 +01:00
|
|
|
from datetime import datetime
|
|
|
|
import logging
|
|
|
|
from urllib.parse import urlencode
|
2025-01-29 14:03:33 +01:00
|
|
|
from contextlib import asynccontextmanager
|
2025-01-02 02:14:30 +01:00
|
|
|
|
|
|
|
from httpx import HTTPError
|
2025-01-16 05:43:26 +01:00
|
|
|
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
2025-01-19 01:48:00 +01:00
|
|
|
from fastapi.staticfiles import StaticFiles
|
2025-02-07 13:57:17 +01:00
|
|
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
2025-01-02 02:14:30 +01:00
|
|
|
from fastapi.templating import Jinja2Templates
|
2025-01-28 19:48:35 +01:00
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
2025-02-06 13:30:35 +01:00
|
|
|
from jwt import InvalidTokenError, PyJWTError
|
2025-01-02 02:14:30 +01:00
|
|
|
from starlette.middleware.sessions import SessionMiddleware
|
2025-01-02 04:04:45 +01:00
|
|
|
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
2025-01-18 06:20:44 +01:00
|
|
|
from authlib.integrations.base_client import OAuthError
|
|
|
|
from authlib.oauth2.rfc6749 import OAuth2Token
|
2025-01-19 16:27:12 +01:00
|
|
|
|
|
|
|
# TODO: PKCE
|
|
|
|
# from authlib.integrations.httpx_client import AsyncOAuth2Client
|
|
|
|
# from fastapi.security import OpenIdConnect
|
|
|
|
# from pkce import generate_code_verifier, generate_pkce_pair
|
2025-01-09 23:41:32 +01:00
|
|
|
|
2025-02-06 13:30:35 +01:00
|
|
|
from .settings import settings, oidc_providers_settings
|
2025-01-02 02:14:30 +01:00
|
|
|
from .models import User
|
2025-01-09 23:41:32 +01:00
|
|
|
from .auth_utils import (
|
2025-01-19 01:48:00 +01:00
|
|
|
get_oidc_provider,
|
2025-01-19 14:26:54 +01:00
|
|
|
get_oidc_provider_or_none,
|
2025-01-09 23:41:32 +01:00
|
|
|
get_current_user_or_none,
|
|
|
|
authlib_oauth,
|
2025-01-29 14:03:33 +01:00
|
|
|
get_providers_info,
|
2025-02-08 01:55:36 +01:00
|
|
|
get_token_or_none,
|
2025-01-09 23:41:32 +01:00
|
|
|
)
|
|
|
|
from .auth_misc import pretty_details
|
2025-02-04 18:03:17 +01:00
|
|
|
from .database import TokenNotInDb, db
|
2025-02-07 13:57:17 +01:00
|
|
|
from .resource_server import resource_server
|
2025-01-02 02:14:30 +01:00
|
|
|
|
2025-02-06 13:30:35 +01:00
|
|
|
# Logger shared by all the endpoints of this module
logger = logging.getLogger("oidc-test")

# Jinja2 templates are looked up in the "templates" directory beside this file
templates = Jinja2Templates(directory=Path(__file__).parent / "templates")
|
2025-01-02 02:14:30 +01:00
|
|
|
|
|
|
|
|
2025-01-29 14:03:33 +01:00
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler.

    At startup, awaits ``get_providers_info()`` once before the application
    serves requests. Nothing is done at shutdown.
    """
    # FastAPI passes the application; it is not otherwise used here
    assert app is not None
    await get_providers_info()
    yield
|
|
|
|
|
|
|
|
|
2025-01-30 20:40:04 +01:00
|
|
|
app = FastAPI(title="OIDC auth test", lifespan=lifespan)

# Allow cross-origin requests (with credentials) from the configured origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# authlib stores its OAuth state in the Starlette session, and the endpoints of
# this module keep the user sub / provider id there too: SessionMiddleware is required
app.add_middleware(SessionMiddleware, secret_key=settings.secret_key)

# Static files are served from the "static" directory beside this file
app.mount(
    "/static",
    StaticFiles(directory=Path(__file__).parent / "static"),
    name="static",
)
# The resource server sub-application answers under /resource
app.mount("/resource", resource_server, name="resource_server")
|
|
|
|
|
2025-01-02 02:14:30 +01:00
|
|
|
|
2025-01-19 01:48:00 +01:00
|
|
|
@app.get("/")
async def home(
    request: Request,
    user: Annotated[User | None, Depends(get_current_user_or_none)],
    oidc_provider: Annotated[StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)],
    token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
) -> HTMLResponse:
    """Render the home page.

    The ``home.html`` template receives the settings, the current user (or
    None), the scope read from the user's access token, the OIDC provider of
    the session with its settings and resources, and the parsed id, access and
    refresh tokens when a token is available (None otherwise).
    """
    now = datetime.now()

    # Provider settings (and its resources) are found from the provider id
    # remembered in the session
    oidc_provider_settings = (
        oidc_providers_settings.get(request.session.get("oidc_provider_id", ""))
        if oidc_provider
        else None
    )
    resources = [] if oidc_provider_settings is None else oidc_provider_settings.resources

    access_token_scope = None
    if user is not None:
        try:
            access_token_scope = user.decode_access_token()["scope"]
        except InvalidTokenError as err:
            logger.info("Invalid token")
            logger.exception(err)

    context = {
        "settings": settings.model_dump(),
        "user": user,
        "access_token_scope": access_token_scope,
        "now": now,
        "oidc_provider": oidc_provider,
        "oidc_provider_settings": oidc_provider_settings,
        "resources": resources,
    }
    if token is None:
        context["id_token_parsed"] = None
        context["access_token_parsed"] = None
        context["refresh_token_parsed"] = None
    else:
        # assumes get_token_or_none only returns a token along with a provider — TODO confirm
        assert oidc_provider is not None
        assert oidc_provider.name is not None
        # Separate name: the settings passed to the template above are left untouched
        token_provider_settings = oidc_providers_settings[oidc_provider.name]
        context["id_token_parsed"] = pretty_details(user, now)
        context["access_token_parsed"] = token_provider_settings.decode(token["access_token"])
        # A refresh token is optional (RFC 6749): don't crash when there is none
        refresh_token = token.get("refresh_token")
        context["refresh_token_parsed"] = (
            None
            if refresh_token is None
            else token_provider_settings.decode(refresh_token, verify_signature=False)
        )
    return templates.TemplateResponse(name="home.html", request=request, context=context)
|
2025-01-19 01:48:00 +01:00
|
|
|
|
|
|
|
|
2025-01-09 23:41:32 +01:00
|
|
|
# Endpoints for the login / authorization process
|
2025-01-02 02:14:30 +01:00
|
|
|
|
2025-01-09 23:41:32 +01:00
|
|
|
|
|
|
|
@app.get("/login/{oidc_provider_id}")
async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
    """Login with the provider id, giving the browser a redirect to its authorize page.

    The provider is expected to send the browser back to our own
    /auth/{oidc_provider_id} url with the token.

    Raises HTTPException 401 when the provider id is not registered, or when
    the provider cannot be reached.
    """
    redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id)
    # Query authlib's client registry instead of using getattr(): the id comes
    # from the URL path, and getattr() would also resolve any attribute or
    # method of the registry object itself (eg. /login/register).
    # create_client() returns None for names that are not registered.
    provider: StarletteOAuth2App | None = authlib_oauth.create_client(oidc_provider_id)
    if provider is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
    # TODO: PKCE, driven by oidc_providers_settings[oidc_provider_id].code_challenge_method
    try:
        return await provider.authorize_redirect(
            request,
            redirect_uri,
            access_type="offline",
            code_verifier=None,
        )
    except HTTPError as err:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") from err
|
|
|
|
|
|
|
|
|
2025-01-05 05:06:58 +01:00
|
|
|
@app.get("/auth/{oidc_provider_id}")
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
    """Handle the browser coming back from the provider's authorize page.

    Fetch the token from the provider, store the user in the database (which
    verifies the token), remember the user in the session (cookie based), and
    respond to the browser with a redirect to the home page.

    Raises HTTPException 401 when the provider id is not registered, when the
    token cannot be obtained, or when the token is invalid.
    """
    # Registry lookup rather than getattr(): the id comes from the URL path and
    # create_client() returns None for anything that is not a registered client
    oidc_provider: StarletteOAuth2App | None = authlib_oauth.create_client(oidc_provider_id)
    if oidc_provider is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
    try:
        token: OAuth2Token = await oidc_provider.authorize_access_token(request)
    except OAuthError as error:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) from error
    # Remember the oidc_provider in the session
    request.session["oidc_provider_id"] = oidc_provider_id
    #
    # One could process the full decoded token which contains extra information
    # eg for updates. Here we are only interested in roles
    #
    userinfo = token.get("userinfo")
    if not userinfo:
        # Not sure if it's correct to redirect to plain login
        # if no userinfo is provided
        return RedirectResponse(url=request.url_for("login", oidc_provider_id=oidc_provider_id))
    # User id (sub) given by oidc provider
    sub = userinfo["sub"]
    # Store the user in the database, which also verifies the token validity and signature
    try:
        user = await db.add_user(
            sub,
            user_info=userinfo,
            oidc_provider=oidc_provider,
            access_token=token["access_token"],
        )
    except PyJWTError as err:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED,
            detail=f"Token invalid: {err.__class__.__name__}",
        ) from err
    assert isinstance(user, User)
    # Remember the user in the session only once the token has been verified:
    # the session cookie is also sent with error responses
    request.session["user_sub"] = sub
    # Add the provider session id to the session. The "sid" claim is optional,
    # so drop any stale value when the provider does not send one
    if (sid := userinfo.get("sid")) is not None:
        request.session["sid"] = sid
    else:
        request.session.pop("sid", None)
    # Add the token to the db because it is used for logout
    assert oidc_provider.name is not None
    oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
    await db.add_token(oidc_provider_settings, token)
    # Send the user to the home: (s)he is authenticated
    return RedirectResponse(url=request.url_for("home"))
|
2025-01-09 23:41:32 +01:00
|
|
|
|
|
|
|
|
2025-01-26 19:08:49 +01:00
|
|
|
@app.get("/account")
async def account(
    request: Request,
) -> RedirectResponse:
    """Redirect the browser to the account page of the session's OIDC provider.

    Raises HTTPException 406 when the session holds no known provider id.
    """
    provider_id = request.session.get("oidc_provider_id", "")
    provider_settings = oidc_providers_settings.get(provider_id)
    if provider_settings is None:
        raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings")
    account_url = f"{provider_settings.account_url_template}"
    return RedirectResponse(account_url)
|
2025-01-26 19:08:49 +01:00
|
|
|
|
|
|
|
|
2025-01-02 02:14:30 +01:00
|
|
|
@app.get("/logout")
async def logout(
    request: Request,
    oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
) -> RedirectResponse:
    """Log the user out locally, then redirect the browser to the provider's logout.

    - If the provider advertises no ``end_session_endpoint``, redirect to the
      non-compliant logout page.
    - If no token is found in the db for the session id, redirect home.
    - Otherwise, redirect to the provider's end session endpoint with
      ``post_logout_redirect_uri``, ``id_token_hint`` and ``client_id``.
    """
    # Clear session
    request.session.pop("user_sub", None)
    # Get provider's endpoint
    provider_logout_uri = oidc_provider.server_metadata.get("end_session_endpoint")
    if provider_logout_uri is None:
        logger.warning("Cannot find end_session_endpoint for provider %s", oidc_provider.name)
        return RedirectResponse(request.url_for("non_compliant_logout"))
    post_logout_uri = request.url_for("home")
    oidc_provider_settings = oidc_providers_settings.get(
        request.session.get("oidc_provider_id", "")
    )
    assert oidc_provider_settings is not None
    try:
        token = await db.get_token(oidc_provider_settings, request.session.pop("sid", None))
    except TokenNotInDb:
        logger.warning("No session in db for the token or no token")
        return RedirectResponse(request.url_for("home"))
    logout_params = {
        "post_logout_redirect_uri": post_logout_uri,
        "id_token_hint": token["id_token"],
    }
    # Send this application's own client id (the parameter was previously
    # misspelled, hence ignored, and hard-coded)
    if oidc_provider.client_id:
        logout_params["client_id"] = oidc_provider.client_id
    # The endpoint URL may already carry a query string
    separator = "&" if "?" in provider_logout_uri else "?"
    return RedirectResponse(f"{provider_logout_uri}{separator}{urlencode(logout_params)}")
|
|
|
|
|
|
|
|
|
2025-01-11 20:41:33 +01:00
|
|
|
@app.get("/non-compliant-logout")
async def non_compliant_logout(
    request: Request,
    oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
):
    """Render a page for non-compliant OAuth2 servers that we cannot log out from."""
    # Clear what remains of the session: the provider id
    request.session.pop("oidc_provider_id", None)
    page_context = {
        "oidc_provider": oidc_provider,
        "home_url": request.url_for("home"),
    }
    return templates.TemplateResponse(
        name="non_compliant_logout.html",
        request=request,
        context=page_context,
    )
|
|
|
|
|
2025-01-13 05:45:31 +01:00
|
|
|
|
2025-01-19 01:48:00 +01:00
|
|
|
# Snippet for running standalone
|
|
|
|
# Mostly useful for the --version option,
|
2025-01-19 16:45:21 +01:00
|
|
|
# as running with uvicorn is easy and provides better flexibility, eg.
|
|
|
|
# uvicorn --host foo oidc_test.main:app --reload
|
2025-01-09 23:41:32 +01:00
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Command line entry point.

    With ``-v/--version``, print the installed package version and exit with
    status 0; otherwise run the application with uvicorn on the given host and
    port.
    """
    from uvicorn import run
    from argparse import ArgumentParser

    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "-l",
        "--host",
        type=str,
        default="0.0.0.0",
        help="Address to listen to (default: 0.0.0.0)",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=80,
        help="Port to listen to (default: 80)",
    )
    parser.add_argument(
        "-v",
        "--version",
        action="store_true",
        help="Print version and exit",
    )
    options = parser.parse_args()

    if options.version:
        from importlib.metadata import version

        print(version("oidc-fastapi-test"))
        # Same as sys.exit(0)
        raise SystemExit(0)

    run(app, host=options.host, port=options.port)
|