oidc-fastapi-test/src/oidc_test/main.py

351 lines
12 KiB
Python
Raw Normal View History

"""
Test of OpenID Connect & OAuth2 with FastAPI
"""
2025-01-02 11:23:53 +01:00
from typing import Annotated
from pathlib import Path
2025-01-09 23:41:32 +01:00
from datetime import datetime
import logging
2025-02-20 02:01:18 +01:00
import logging.config
import importlib.resources
from yaml import safe_load
2025-01-09 23:41:32 +01:00
from urllib.parse import urlencode
from contextlib import asynccontextmanager
2025-01-02 02:14:30 +01:00
from httpx import HTTPError
2025-01-16 05:43:26 +01:00
from fastapi import Depends, FastAPI, HTTPException, Request, status
2025-01-19 01:48:00 +01:00
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, RedirectResponse
2025-01-02 02:14:30 +01:00
from fastapi.templating import Jinja2Templates
2025-01-28 19:48:35 +01:00
from fastapi.middleware.cors import CORSMiddleware
from jwt import PyJWTError
2025-01-02 02:14:30 +01:00
from starlette.middleware.sessions import SessionMiddleware
2025-01-02 04:04:45 +01:00
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.base_client import OAuthError
from authlib.oauth2.rfc6749 import OAuth2Token
2025-01-19 16:27:12 +01:00
# TODO: PKCE
# from authlib.integrations.httpx_client import AsyncOAuth2Client
# from fastapi.security import OpenIdConnect
# from pkce import generate_code_verifier, generate_pkce_pair
2025-01-09 23:41:32 +01:00
from oidc_test.registry import registry
from oidc_test.auth.provider import NoPublicKey, Provider
from oidc_test.auth.utils import (
get_auth_provider,
get_auth_provider_or_none,
2025-01-09 23:41:32 +01:00
get_current_user_or_none,
authlib_oauth,
get_token_from_session_or_none,
get_token_from_session,
2025-02-08 18:32:02 +01:00
update_token,
2025-01-09 23:41:32 +01:00
)
from oidc_test.auth.utils import init_providers
from oidc_test.settings import settings
from oidc_test.auth_providers import providers
from oidc_test.models import User
from oidc_test.database import TokenNotInDb, db
from oidc_test.resource_server import resource_server
2025-01-02 02:14:30 +01:00
logger = logging.getLogger("oidc-test")

if settings.log:
    # Configure logging from the YAML file shipped inside this package.
    # importlib.resources.files() is available on all supported Python versions,
    # whereas calling importlib.resources.path() without a resource name
    # requires the Python 3.13+ signature.
    assert __package__ is not None
    with (importlib.resources.files(__package__) / settings.log_config_file).open(
        encoding="utf-8"
    ) as f:
        logging_config = safe_load(f)
    logging.config.dictConfig(logging_config)

templates = Jinja2Templates(Path(__file__).parent / "templates")
2025-01-02 02:14:30 +01:00
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup: set up providers and the resource registry.

    Each enabled provider fetches its info; a provider whose public key cannot
    be found is logged and removed from `providers`.
    """
    assert app is not None
    init_providers()
    registry.make_registry()
    # Iterate over a snapshot, since failing providers get deleted from the dict
    snapshot = list(providers.values())
    for candidate in snapshot:
        if not candidate.disabled:
            try:
                await candidate.get_info()
            except NoPublicKey:
                logger.warning(f"Disable {candidate.id}: public key not found")
                del providers[candidate.id]
    yield
app = FastAPI(title="OIDC auth test", lifespan=lifespan)

# Allow browsers on the configured origins to call this app with credentials.
# NOTE(review): with allow_credentials=True, settings.cors_origins should list
# explicit origins rather than "*" - confirm the deployed configuration.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# authlib requires SessionMiddleware; this app also keeps the user in request.session
app.add_middleware(SessionMiddleware, secret_key=settings.secret_key)

app.mount(
    path="/static",
    app=StaticFiles(directory=Path(__file__).parent / "static"),
    name="static",
)
app.mount(path="/resource", app=resource_server, name="resource_server")
2025-01-02 02:14:30 +01:00
2025-01-19 01:48:00 +01:00
def _decode_token_unverified(provider: Provider, raw_token: str | None):
    """Decode a JWT for display only, without verifying its signature.

    Returns None when there is no token, or a {"Cannot parse": <error class name>}
    placeholder when the provider cannot decode it (eg. opaque, non-JWT tokens).
    """
    if raw_token is None:
        return None
    try:
        return provider.decode(raw_token, verify_signature=False)
    except PyJWTError as err:
        return {"Cannot parse": err.__class__.__name__}


@app.get("/")
async def home(
    request: Request,
    user: Annotated[User | None, Depends(get_current_user_or_none)],
    provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)],
    token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)],
) -> HTMLResponse:
    """Render the home page.

    Without a provider or a token in the session, the page gets the list of
    providers to log in with. Otherwise, the tokens are decoded (unverified,
    for display) along with the resources the user may query.
    """
    context = {
        "show_token": settings.show_token,
        "user": user,
        "now": datetime.now(),
        "auth_provider": provider,
    }
    if provider is None or token is None:
        context.update(
            providers=providers,
            access_token=None,
            id_token_parsed=None,
            access_token_parsed=None,
            refresh_token_parsed=None,
            resources=None,
        )
    else:
        access_token_parsed = _decode_token_unverified(provider, token["access_token"])
        try:
            access_token_scope = access_token_parsed["scope"]
        except KeyError:
            # Also the case for the "Cannot parse" placeholder
            access_token_scope = None
        context.update(
            access_token=token["access_token"],
            access_token_scope=access_token_scope,
            access_token_parsed=access_token_parsed,
            # Not every provider issues a refresh token (or an id token):
            # .get() avoids a KeyError, ie. an HTTP 500, in that case.
            id_token_parsed=_decode_token_unverified(provider, token.get("id_token")),
            refresh_token_parsed=_decode_token_unverified(
                provider, token.get("refresh_token")
            ),
            resources=registry.resources,
            resource_providers=provider.resource_providers,
        )
    return templates.TemplateResponse(name="home.html", request=request, context=context)
2025-01-19 01:48:00 +01:00
2025-01-09 23:41:32 +01:00
# Endpoints for the login / authorization process
2025-01-02 02:14:30 +01:00
2025-01-09 23:41:32 +01:00
@app.get("/login/{auth_provider_id}")
async def login(request: Request, auth_provider_id: str) -> RedirectResponse:
    """Login with the provider id, giving the browser a redirect to its authorize page.

    The provider is expected to send the browser back to our own
    /auth/{auth_provider_id} url with the authorization result.

    Raises HTTPException 401 if the provider is unknown (including providers
    removed at startup) or cannot be reached.
    """
    # Only accept ids still present in `providers`. Providers removed at startup
    # (see lifespan) would be refused by /auth anyway, and the check keeps raw
    # user input from reaching arbitrary attributes of the authlib registry.
    if auth_provider_id not in providers:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
    try:
        client: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
    except AttributeError as err:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") from err
    redirect_uri = request.url_for("auth", auth_provider_id=auth_provider_id)
    # TODO: PKCE, when providers[auth_provider_id].code_challenge_method is set
    try:
        return await client.authorize_redirect(
            request,
            redirect_uri,
            access_type="offline",
            code_verifier=None,
        )
    except HTTPError as err:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") from err
@app.get("/auth/{auth_provider_id}")
async def auth(
    request: Request,
    auth_provider_id: str,
) -> RedirectResponse:
    """Handle the browser coming back from the provider after login.

    The endpoint proceeds as follows:
    - Get the token from the provider.
    - Store the user in the database; this step also verifies the token.
    - Remember the user in the (cookie based) session.
    - Store the token, which is needed for logout.
    - Redirect the browser to the home page.

    Raises HTTPException 401 if the provider is unknown, the token request
    fails, or the token is invalid.
    """
    try:
        provider = providers[auth_provider_id]
    except KeyError as err:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") from err
    try:
        token: OAuth2Token = await provider.authlib_client.authorize_access_token(request)
    except OAuthError as error:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) from error
    userinfo = token.get("userinfo")
    if not userinfo:
        # Not sure if it's correct to redirect to plain login if no userinfo
        # is provided.
        # NOTE(review): this can loop if the provider never returns userinfo.
        return RedirectResponse(url=request.url_for("login", auth_provider_id=auth_provider_id))
    # User id (sub) given by the auth provider
    sub = userinfo["sub"]
    # Store the user in the database, which also verifies the token validity and signature
    try:
        user = await db.add_user(
            sub,
            user_info=userinfo,
            auth_provider=provider,
            access_token=token["access_token"],
        )
    except PyJWTError as err:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED,
            detail=f"Token invalid: {err.__class__.__name__}",
        ) from err
    assert isinstance(user, User)
    # Only write the session once the token is verified: the session cookie is
    # sent even with an error response, so writing it earlier would leave the
    # user "remembered" for a rejected token.
    request.session["auth_provider_id"] = auth_provider_id
    request.session["user_sub"] = sub
    # The provider's session id, used at logout to find the token back
    request.session["sid"] = provider.get_session_key(userinfo)
    # Add the token to the db because it is used for logout
    await db.add_token(provider, token)
    # Send the user to the home: (s)he is authenticated
    return RedirectResponse(url=request.url_for("home"))
2025-01-09 23:41:32 +01:00
@app.get("/account")
async def account(
    provider: Annotated[Provider, Depends(get_auth_provider)],
) -> RedirectResponse:
    """Redirect to the auth provider account management,
    if account_url_template is in the provider's settings.

    Raises HTTPException 404 when the provider has no account_url_template,
    instead of redirecting the browser to the literal url "None".
    """
    if not provider.account_url_template:
        raise HTTPException(
            status.HTTP_404_NOT_FOUND, "No account management page for this provider"
        )
    return RedirectResponse(f"{provider.account_url_template}")
2025-01-02 02:14:30 +01:00
@app.get("/logout")
async def logout(
    request: Request,
    provider: Annotated[Provider, Depends(get_auth_provider)],
) -> RedirectResponse:
    """Log out locally, then through the provider's end_session_endpoint.

    Providers without an end_session_endpoint are sent to the
    non_compliant_logout page. If no token is found in the db, the browser
    goes straight back home.
    """
    metadata = provider.authlib_client.server_metadata
    provider_logout_uri = metadata.get("end_session_endpoint")
    if provider_logout_uri is None:
        logger.warning(f"Cannot find end_session_endpoint for provider {provider.id}")
        return RedirectResponse(request.url_for("non_compliant_logout"))
    post_logout_uri = request.url_for("home")
    # Forget the user in this app's session
    for session_key in ("user_sub", "auth_provider_id"):
        request.session.pop(session_key, None)
    provider_session_id = request.session.pop("sid", None)
    try:
        token = await db.get_token(provider, provider_session_id)
    except TokenNotInDb:
        logger.warning("No session in db for the token or no token")
        return RedirectResponse(request.url_for("home"))
    url_query = {
        "post_logout_redirect_uri": post_logout_uri,
        "client_id": provider.client_id,
    }
    if provider.logout_with_id_token_hint:
        url_query["id_token_hint"] = token["id_token"]
    return RedirectResponse(f"{provider_logout_uri}?{urlencode(url_query)}")
2025-01-11 20:41:33 +01:00
@app.get("/non-compliant-logout")
async def non_compliant_logout(
    request: Request,
    provider: Annotated[Provider, Depends(get_auth_provider)],
):
    """A page for non-compliant OAuth2 servers that we cannot log out.

    Only this app's session is cleared. The user may still be logged in at
    the provider.
    """
    # Clear session, including the provider session id, as /logout does
    request.session.pop("user_sub", None)
    request.session.pop("auth_provider_id", None)
    request.session.pop("sid", None)
    return templates.TemplateResponse(
        name="non_compliant_logout.html",
        request=request,
        context={"auth_provider": provider, "home_url": request.url_for("home")},
    )
2025-01-13 05:45:31 +01:00
2025-02-08 18:32:02 +01:00
@app.get("/refresh")
async def refresh(
    request: Request,
    provider: Annotated[Provider, Depends(get_auth_provider)],
    token: Annotated[OAuth2Token, Depends(get_token_from_session)],
) -> RedirectResponse:
    """Manually refresh the token using its refresh token, then redirect home.

    Raises HTTPException:
    - 401 if the token has no refresh token, the provider refuses the refresh,
      or the provider cannot be reached;
    - 510 if the new token cannot be verified by update_token.
    """
    # Some providers issue no refresh token: avoid a KeyError (HTTP 500)
    refresh_token = token.get("refresh_token")
    if refresh_token is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No refresh token")
    try:
        new_token = await provider.authlib_client.fetch_access_token(
            refresh_token=refresh_token,
            grant_type="refresh_token",
        )
    except OAuthError as error:
        # eg. expired or revoked refresh token
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) from error
    except HTTPError as err:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") from err
    try:
        await update_token(provider.id, new_token)
    except PyJWTError as err:
        logger.info(f"Cannot refresh token: {err.__class__.__name__}")
        raise HTTPException(
            status.HTTP_510_NOT_EXTENDED, f"Token refresh error: {err.__class__.__name__}"
        ) from err
    return RedirectResponse(url=request.url_for("home"))
2025-02-17 17:26:30 +01:00
2025-01-19 01:48:00 +01:00
# Snippet for running standalone
# Mostly useful for the --version option,
2025-01-19 16:45:21 +01:00
# as running with uvicorn is easy and provides better flexibility, eg.
# uvicorn --host foo oidc_test.main:app --reload
2025-01-09 23:41:32 +01:00
def main():
    """Command line entry point: serve the app with uvicorn, or print the version."""
    from argparse import ArgumentParser
    from uvicorn import run

    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "-l",
        "--host",
        type=str,
        default="0.0.0.0",
        help="Address to listen to (default: 0.0.0.0)",
    )
    parser.add_argument(
        "-p",
        "--port",
        type=int,
        default=80,
        help="Port to listen to (default: 80)",
    )
    parser.add_argument(
        "-v",
        "--version",
        action="store_true",
        help="Print version and exit",
    )
    options = parser.parse_args()
    if not options.version:
        run(app, host=options.host, port=options.port)
        return
    import sys
    from importlib.metadata import version

    print(version("oidc-fastapi-test"))
    sys.exit(0)