""" Test of OpenId Connect & OAuth2 with FastAPI """

from typing import Annotated
from pathlib import Path
from datetime import datetime
from html import escape
import logging
from urllib.parse import urlencode

from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.base_client import OAuthError
from authlib.oauth2.rfc6749 import OAuth2Token

# TODO: PKCE
# from authlib.integrations.httpx_client import AsyncOAuth2Client
# from fastapi.security import OpenIdConnect
# from pkce import generate_code_verifier, generate_pkce_pair

from .settings import settings, OIDCProvider
from .models import User
from .auth_utils import (
    get_oidc_provider,
    get_oidc_provider_or_none,
    hasrole,
    get_current_user_or_none,
    get_current_user,
    authlib_oauth,
    get_token,
)
from .auth_misc import pretty_details
from .database import db

logger = logging.getLogger("uvicorn.error")

templates = Jinja2Templates(Path(__file__).parent / "templates")

app = FastAPI(
    title="OIDC auth test",
)

app.mount(
    "/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static"
)

# SessionMiddleware is required by authlib
app.add_middleware(
    SessionMiddleware,
    secret_key=settings.secret_key,
)

# Add oidc providers to authlib from the settings
# fastapi_providers: dict[str, OpenIdConnect] = {}
# Maps each provider id to its settings; also serves as the list of valid ids.
providers_settings: dict[str, OIDCProvider] = {}
for provider in settings.oidc.providers:
    authlib_oauth.register(
        name=provider.id,
        server_metadata_url=provider.openid_configuration,
        client_kwargs={
            "scope": "openid email",  # offline_access profile",
        },
        client_id=provider.client_id,
        client_secret=provider.client_secret,
        api_base_url=provider.url,
        # For PKCE (not implemented yet):
        # code_challenge_method="S256",
        # fetch_token=fetch_token,
        # update_token=update_token,
        # client_id="some-client-id",  # if enabled, authlib will also check that the access token belongs to this client id (audience)
    )
    # fastapi_providers[provider.id] = OpenIdConnect(
    #     openIdConnectUrl=provider.openid_configuration
    # )
    providers_settings[provider.id] = provider


def _get_registered_client(oidc_provider_id: str) -> StarletteOAuth2App:
    """Return the authlib client registered from the settings under this id.

    Raises:
        HTTPException: 401 "No such provider" if the id is not a configured
            provider.
    """
    # A bare getattr() on the authlib registry would also resolve unrelated
    # attributes (e.g. its "register" method), so only accept ids that were
    # registered above from the settings.
    if oidc_provider_id not in providers_settings:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
    try:
        return getattr(authlib_oauth, oidc_provider_id)
    except AttributeError:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")


@app.get("/")
async def home(
    request: Request,
    user: Annotated[User, Depends(get_current_user_or_none)],
    oidc_provider: Annotated[
        StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)
    ],
) -> HTMLResponse:
    """Render the home page.

    The listed resources are those configured for the provider stored in the
    session; the list is empty when there is no such provider.
    """
    now = datetime.now()
    if oidc_provider and (
        (
            provider := providers_settings.get(
                request.session.get("oidc_provider_id", "")
            )
        )
        is not None
    ):
        resources = provider.resources
    else:
        resources = []
    return templates.TemplateResponse(
        name="home.html",
        request=request,
        context={
            "settings": settings.model_dump(),
            "user": user,
            "now": now,
            "resources": resources,
            "user_info_details": (
                pretty_details(user, now)
                if user and settings.oidc.show_session_details
                else None
            ),
        },
    )


# Endpoints for the login / authorization process


@app.get("/login/{oidc_provider_id}")
async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
    """Login with the provider id, giving the browser a redirect to its authorize page.

    The provider is expected to send the browser back to our own
    /auth/{oidc_provider_id} url with the token.

    Raises:
        HTTPException: 401 if the provider is unknown or cannot be reached.
    """
    redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id)
    provider = _get_registered_client(oidc_provider_id)
    # if (
    #     code_challenge_method := providers_settings[
    #         oidc_provider_id
    #     ].code_challenge_method
    # ) is not None:
    #     #client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method)
    #     code_verifier = generate_code_verifier()
    #     logger.debug("TODO: PKCE")
    # else:
    #     code_verifier = None
    try:
        return await provider.authorize_redirect(
            request,
            redirect_uri,
            access_type="offline",
            code_verifier=None,
        )
    except HTTPError:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")


@app.get("/auth/{oidc_provider_id}")
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
    """Handle the provider's redirect back to us after the user authorized.

    Exchange the authorization response for the tokens. If the tokens carry
    userinfo, store the user (session cookie and database) and the token
    (database), then redirect the browser to the home page. Otherwise
    redirect back to the login page of the same provider.

    Raises:
        HTTPException: 401 if the provider is unknown or the token exchange
            fails.
    """
    oidc_provider = _get_registered_client(oidc_provider_id)
    try:
        token: OAuth2Token = await oidc_provider.authorize_access_token(request)
    except OAuthError as error:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
    # Remember the oidc_provider in the session
    # logger.debug(f"Scope: {token['scope']}")
    request.session["oidc_provider_id"] = oidc_provider_id
    #
    # One could process the full decoded token which contains extra information
    # eg for updates. Here we are only interested in roles
    #
    if userinfo := token.get("userinfo"):
        # User id (sub) given by oidc provider
        sub = userinfo["sub"]
        # Get additional data from userinfo endpoint. This is best effort:
        # a failure is logged and the login proceeds without that data.
        try:
            user_info_from_endpoint = await oidc_provider.userinfo(
                token=token, follow_redirects=True
            )
        except Exception as err:
            logger.warning("Cannot get userinfo from endpoint: %s", err)
            user_info_from_endpoint = {}
        # Build and remember the user in the session
        request.session["user_sub"] = sub
        # Store the user in the database
        user = await db.add_user(
            sub,
            user_info=userinfo,
            oidc_provider=oidc_provider,
            user_info_from_endpoint=user_info_from_endpoint,
        )
        # Add the id_token to the session
        request.session["token"] = token["id_token"]
        # Add the token to the db because it is used for logout
        await db.add_token(token, user)
        # Send the user to the home: (s)he is authenticated
        return RedirectResponse(url=request.url_for("home"))
    # Not sure if it's correct to redirect to plain login
    # if no userinfo is provided
    return RedirectResponse(
        url=request.url_for("login", oidc_provider_id=oidc_provider_id)
    )


@app.get("/logout")
async def logout(
    request: Request,
    oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
) -> RedirectResponse:
    """Log the user out locally, then at the provider when it supports it.

    Redirect the browser to the provider's end_session_endpoint (with
    id_token_hint, client_id and post_logout_redirect_uri), to the
    non-compliant-logout page when the provider advertises no such endpoint,
    or to home when the token cannot be found in the database.
    """
    # Clear session
    request.session.pop("user_sub", None)
    # Get provider's endpoint. authlib fetches the provider metadata lazily,
    # so make sure it is loaded (e.g. after a restart of this server).
    metadata = await oidc_provider.load_server_metadata()
    if (provider_logout_uri := metadata.get("end_session_endpoint")) is None:
        logger.warning(
            "Cannot find end_session_endpoint for provider %s", oidc_provider.name
        )
        return RedirectResponse(request.url_for("non_compliant_logout"))
    post_logout_uri = request.url_for("home")
    if (token := await db.get_token(request.session.pop("token", None))) is None:
        logger.warning("No session in db for the token")
        return RedirectResponse(request.url_for("home"))
    logout_url = (
        provider_logout_uri
        + "?"
        + urlencode(
            {
                "post_logout_redirect_uri": post_logout_uri,
                "id_token_hint": token["id_token"],
                # Parameter name per OIDC RP-Initiated Logout (was misspelled
                # "cliend_id" with a hard-coded value).
                "client_id": oidc_provider.client_id,
            }
        )
    )
    return RedirectResponse(logout_url)


@app.get("/non-compliant-logout")
async def non_compliant_logout(
    request: Request,
    oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
):
    """A page for non-compliant OAuth2 servers that we cannot log out."""
    return templates.TemplateResponse(
        name="non_compliant_logout.html",
        request=request,
        context={"oidc_provider": oidc_provider, "home_url": request.url_for("home")},
    )


# Route for OAuth resource server


@app.get("/resource/{id}")
async def get_resource(
    id: str,
    request: Request,
    user: Annotated[User, Depends(get_current_user)],
    oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
    token: Annotated[OAuth2Token, Depends(get_token)],
) -> JSONResponse:
    """Generic path for testing a resource provided by a provider.

    Fetch the url of the resource ``id`` configured for the session's
    provider with the user's token and relay its JSON body.

    Raises:
        HTTPException: 406 if the provider or the resource is unknown, or
            the upstream status code and body if the request failed.
    """
    assert user is not None
    if oidc_provider is None:
        raise HTTPException(
            status.HTTP_406_NOT_ACCEPTABLE, detail="No such oidc provider"
        )
    if (
        provider := providers_settings.get(request.session.get("oidc_provider_id", ""))
    ) is None:
        raise HTTPException(
            status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting"
        )
    try:
        resource = next(x for x in provider.resources if x.id == id)
    except StopIteration:
        raise HTTPException(
            status.HTTP_406_NOT_ACCEPTABLE, detail="No such resource for this provider"
        )
    if (
        response := await oidc_provider.get(
            resource.url,
            # headers={"Authorization": f"token {token['access_token']}"},
            token=token,
        )
    ).is_success:
        return JSONResponse(response.json())
    else:
        raise HTTPException(status_code=response.status_code, detail=response.text)


# Routes for test


@app.get("/public")
async def public() -> HTMLResponse:
    """Return a page that requires no authentication."""
    return HTMLResponse("<p>Not protected</p>")


@app.get("/protected")
async def get_protected(
    user: Annotated[User, Depends(get_current_user)]
) -> HTMLResponse:
    """Return a page guarded by the get_current_user dependency."""
    assert user is not None
    return HTMLResponse("<p>Only authenticated users can see this</p>")


@app.get("/protected-by-foorole")
@hasrole("foorole")
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
    """Return a page guarded by ``hasrole("foorole")``."""
    assert request is not None
    return HTMLResponse("<p>Only users with foorole can see this</p>")


@app.get("/protected-by-barrole")
@hasrole("barrole")
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
    """Return a page guarded by ``hasrole("barrole")``."""
    assert request is not None
    return HTMLResponse("<p>Protected by barrole</p>")


@app.get("/protected-by-foorole-and-barrole")
@hasrole("barrole")
@hasrole("foorole")
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
    """Return a page guarded by two stacked hasrole decorators."""
    assert request is not None
    return HTMLResponse("<p>Only users with foorole and barrole can see this</p>")


@app.get("/protected-by-foorole-or-barrole")
@hasrole(["foorole", "barrole"])
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
    """Return a page guarded by hasrole with a list of roles."""
    assert request is not None
    return HTMLResponse("<p>Only users with foorole or barrole can see this</p>")


@app.get("/introspect")
async def get_introspect(
    request: Request,
    oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
    token: Annotated[OAuth2Token, Depends(get_token)],
) -> JSONResponse:
    """Introspect the session's access token at the provider and relay the result.

    Raises:
        HTTPException: 501 if the provider advertises no introspection
            endpoint, or the upstream status code and body if the request
            failed.
    """
    assert request is not None
    # authlib fetches the provider metadata lazily: make sure it is loaded
    metadata = await oidc_provider.load_server_metadata()
    if (introspection_endpoint := metadata.get("introspection_endpoint")) is None:
        raise HTTPException(
            status.HTTP_501_NOT_IMPLEMENTED,
            detail="No introspection endpoint for this provider",
        )
    if (
        response := await oidc_provider.post(
            introspection_endpoint,
            token=token,
            data={"token": token["access_token"]},
        )
    ).is_success:
        return JSONResponse(response.json())
    else:
        raise HTTPException(status_code=response.status_code, detail=response.text)


@app.get("/oauth2-forgejo-test")
async def get_forgejo_user_info(
    request: Request,
    user: Annotated[User, Depends(get_current_user)],
    oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
    token: Annotated[OAuth2Token, Depends(get_token)],
) -> HTMLResponse:
    """List the user's repositories from a Forgejo provider's API.

    Raises:
        HTTPException: the upstream status code and body if the request failed.
    """
    assert request is not None
    if (
        response := await oidc_provider.get(
            "/api/v1/user/repos",
            # headers={"Authorization": f"token {token['access_token']}"},
            token=token,
        )
    ).is_success:
        repos = response.json()
        # Names come from the provider: escape them before embedding in HTML.
        names = ", ".join(escape(str(repo["name"])) for repo in repos)
        return HTMLResponse(
            f"{escape(str(user.name))} has {len(repos)} repos: {names}"
        )
    else:
        raise HTTPException(status_code=response.status_code, detail=response.text)


# Snippet for running standalone
# Mostly useful for the --version option,
# as running with uvicorn is easy and provides flexibility
def main():
    """Parse the command line, then print the version or run the app with uvicorn."""
    from uvicorn import run
    from argparse import ArgumentParser

    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "-l",
        "--host",
        type=str,
        default="0.0.0.0",
        help="Address to listen to (default: 0.0.0.0)",
    )
    parser.add_argument(
        "-p", "--port", type=int, default=80, help="Port to listen to (default: 80)"
    )
    parser.add_argument(
        "-v", "--version", action="store_true", help="Print version and exit"
    )
    args = parser.parse_args()

    if args.version:
        import sys
        from importlib.metadata import version

        print(version("oidc-fastapi-test"))
        sys.exit(0)

    run(app, host=args.host, port=args.port)