oidc-fastapi-test/src/oidc_test/main.py
phil ecdd3702f8
All checks were successful
/ build (push) Successful in 5s
/ test (push) Successful in 5s
Handle token refresh error
2025-02-20 03:13:41 +01:00

350 lines
12 KiB
Python

"""
Test of OpenId Connect & OAuth2 with FastAPI
"""
from typing import Annotated
from pathlib import Path
from datetime import datetime
import logging
import logging.config
import importlib.resources
from yaml import safe_load
from urllib.parse import urlencode
from contextlib import asynccontextmanager
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from jwt import PyJWTError
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.base_client import OAuthError
from authlib.oauth2.rfc6749 import OAuth2Token
# TODO: PKCE
# from authlib.integrations.httpx_client import AsyncOAuth2Client
# from fastapi.security import OpenIdConnect
# from pkce import generate_code_verifier, generate_pkce_pair
from oidc_test.registry import registry
from oidc_test.auth.provider import NoPublicKey, Provider
from oidc_test.auth.utils import (
get_auth_provider,
get_auth_provider_or_none,
get_current_user_or_none,
authlib_oauth,
get_token_from_session_or_none,
get_token_from_session,
update_token,
)
from oidc_test.auth.utils import init_providers
from oidc_test.settings import settings
from oidc_test.auth_providers import providers
from oidc_test.models import User
from oidc_test.database import TokenNotInDb, db
from oidc_test.resource_server import resource_server
logger = logging.getLogger("oidc-test")

if settings.log:
    assert __package__ is not None
    # Locate the logging config file inside this package.
    # importlib.resources.files() works on every Python version this file supports,
    # whereas calling importlib.resources.path() with only a package argument is
    # accepted from Python 3.13 on (earlier versions require a resource name).
    log_config_path = importlib.resources.files(__package__) / settings.log_config_file
    # Explicit encoding: do not depend on the platform's locale default
    with log_config_path.open(encoding="utf-8") as f:
        logging_config = safe_load(f)
    logging.config.dictConfig(logging_config)

# HTML templates are read from the "templates" directory next to this module
templates = Jinja2Templates(Path(__file__).parent / "templates")
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook of the application.

    At startup, initialize the auth providers and the resource registry, then fetch
    the info of every enabled provider; a provider whose public key cannot be found
    is removed from ``providers``. Nothing is done at shutdown.
    """
    assert app is not None
    init_providers()
    registry.make_registry()
    # Iterate over a snapshot of the providers, since entries may be deleted below.
    # The generator keeps the `disabled` check lazy, evaluated right before each fetch.
    enabled = (candidate for candidate in list(providers.values()) if not candidate.disabled)
    for candidate in enabled:
        try:
            await candidate.get_info()
        except NoPublicKey:
            logger.warning(f"Disable {candidate.id}: public key not found")
            del providers[candidate.id]
    yield
# The FastAPI application; `lifespan` runs the provider/registry setup at startup
app = FastAPI(title="OIDC auth test", lifespan=lifespan)
# CORS: allowed origins come from the settings, with credentials, any method and header.
# NOTE(review): the order of add_middleware calls defines the middleware stack;
# confirm the intended nesting before reordering them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# SessionMiddleware is required by authlib
# (and it backs request.session, used throughout this module)
app.add_middleware(
    SessionMiddleware,
    secret_key=settings.secret_key,
)
# Static files served from the "static" directory next to this module
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
# The resource server is a separate application mounted under /resource
app.mount("/resource", resource_server, name="resource_server")
def _decode_token_or_error(provider: Provider, token: OAuth2Token, key: str) -> dict | None:
    """Decode ``token[key]`` without verifying its signature, for display only.

    Returns None when the token has no such entry (eg. no refresh token was issued),
    or ``{"Cannot parse": <exception class name>}`` when it cannot be decoded.
    """
    encoded = token.get(key)
    if encoded is None:
        return None
    try:
        return provider.decode(encoded, verify_signature=False)
    except PyJWTError as err:
        return {"Cannot parse": err.__class__.__name__}


@app.get("/")
async def home(
    request: Request,
    user: Annotated[User | None, Depends(get_current_user_or_none)],
    provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)],
    token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)],
) -> HTMLResponse:
    """Render the home page.

    Without a provider or token in the session, the page gets the list of providers
    to log in with. Otherwise it gets the session's tokens (decoded without signature
    verification, for display) and the registry's resources. A token that is absent
    (eg. no refresh token) is shown as None and one that cannot be decoded as
    ``{"Cannot parse": ...}`` instead of failing the whole page.
    """
    context = {
        "show_token": settings.show_token,
        "user": user,
        "now": datetime.now(),
        "auth_provider": provider,
    }
    if provider is None or token is None:
        context["providers"] = providers
        context["access_token"] = None
        context["id_token_parsed"] = None
        context["access_token_parsed"] = None
        context["refresh_token_parsed"] = None
        context["resources"] = None
    else:
        context["access_token"] = token["access_token"]
        access_token_parsed = _decode_token_or_error(provider, token, "access_token")
        # The scope claim is optional, and absent from the "Cannot parse" placeholder
        context["access_token_scope"] = (access_token_parsed or {}).get("scope")
        context["access_token_parsed"] = access_token_parsed
        context["id_token_parsed"] = _decode_token_or_error(provider, token, "id_token")
        context["refresh_token_parsed"] = _decode_token_or_error(provider, token, "refresh_token")
        context["resources"] = registry.resources
        context["resource_providers"] = provider.resource_providers
    return templates.TemplateResponse(name="home.html", request=request, context=context)
# Endpoints for the login / authorization process
@app.get("/login/{auth_provider_id}")
async def login(request: Request, auth_provider_id: str) -> RedirectResponse:
    """Login with the provider id, giving the browser a redirect to its authorize page.

    The provider is expected to send the browser back to our own /auth/{auth_provider_id}
    url with the token.

    Raises:
        HTTPException (401): unknown provider id, or the provider cannot be reached.
    """
    # `auth_provider_id` comes from the URL. Only accept it if it resolves to an OAuth2
    # client: presumably authlib_oauth is an authlib OAuth registry, for which a plain
    # getattr would also return its own attributes/methods (eg. "register").
    provider = getattr(authlib_oauth, auth_provider_id, None)
    if not isinstance(provider, StarletteOAuth2App):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
    redirect_uri = request.url_for("auth", auth_provider_id=auth_provider_id)
    # TODO: PKCE, when the provider has a code_challenge_method: generate a code
    # verifier (pkce.generate_code_verifier) and pass it below instead of None.
    try:
        return await provider.authorize_redirect(
            request,
            redirect_uri,
            access_type="offline",
            code_verifier=None,
        )
    except HTTPError as err:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") from err
@app.get("/auth/{auth_provider_id}")
async def auth(
    request: Request,
    auth_provider_id: str,
) -> RedirectResponse:
    """Callback of the provider: get the auth token, verify it and log the user in.

    On success, the user and the token are stored in the db, the provider id, user sub
    and provider session id are stored in the (cookie based) session, and the browser
    is redirected to the home page. If the token carries no userinfo, the browser is
    redirected to the provider's login url again.

    Raises:
        HTTPException (401): unknown provider, authorization error reported by the
            provider, or a token that fails verification.
    """
    try:
        provider = providers[auth_provider_id]
    except KeyError:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") from None
    try:
        token: OAuth2Token = await provider.authlib_client.authorize_access_token(request)
    except OAuthError as error:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) from error
    #
    # One could process the full decoded token which contains extra information
    # eg for updates. Here we are only interested in roles
    #
    userinfo = token.get("userinfo")
    if not userinfo:
        # Not sure if it's correct to redirect to plain login if no userinfo is provided.
        # NOTE(review): if the provider never returns userinfo, this loops login -> auth.
        return RedirectResponse(url=request.url_for("login", auth_provider_id=auth_provider_id))
    # User id (sub) given by auth provider.
    # Additional data could be fetched from the provider's userinfo endpoint
    # (authlib_client.userinfo) if needed.
    sub = userinfo["sub"]
    # Store the user in the database, which also verifies the token validity and signature
    try:
        user = await db.add_user(
            sub,
            user_info=userinfo,
            auth_provider=provider,
            access_token=token["access_token"],
        )
    except PyJWTError as err:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED,
            detail=f"Token invalid: {err.__class__.__name__}",
        ) from err
    assert isinstance(user, User)
    # Remember the provider and the user in the session only once the token is verified:
    # the session may still be persisted with an error response, which would otherwise
    # leave an unverified user_sub in the browser's session.
    request.session["auth_provider_id"] = auth_provider_id
    request.session["user_sub"] = sub
    # Add the provider session id to the session
    request.session["sid"] = provider.get_session_key(userinfo)
    # Add the token to the db because it is used for logout
    await db.add_token(provider, token)
    # Send the user to the home: (s)he is authenticated
    return RedirectResponse(url=request.url_for("home"))
@app.get("/account")
async def account(
    provider: Annotated[Provider, Depends(get_auth_provider)],
) -> RedirectResponse:
    """Redirect to the auth provider account management page.

    Raises:
        HTTPException (404): the provider has no ``account_url_template`` in its
            settings (previously this redirected to the literal url "None").
    """
    account_url = provider.account_url_template
    if not account_url:
        raise HTTPException(
            status.HTTP_404_NOT_FOUND, "No account management page for this provider"
        )
    return RedirectResponse(f"{account_url}")
@app.get("/logout")
async def logout(
    request: Request,
    provider: Annotated[Provider, Depends(get_auth_provider)],
) -> RedirectResponse:
    """Clear the local session, then send the browser to the provider's logout page.

    Providers whose metadata has no ``end_session_endpoint`` are handled by the
    non-compliant logout page instead. When no token is found in the db for the
    session id, the browser is simply sent back home.
    """
    end_session_endpoint = provider.authlib_client.server_metadata.get("end_session_endpoint")
    if end_session_endpoint is None:
        logger.warning(f"Cannot find end_session_endpoint for provider {provider.id}")
        return RedirectResponse(request.url_for("non_compliant_logout"))
    home_url = request.url_for("home")
    session = request.session
    # Forget the user locally; the provider session id is consumed by the db lookup
    for session_key in ("user_sub", "auth_provider_id"):
        session.pop(session_key, None)
    try:
        stored_token = await db.get_token(provider, session.pop("sid", None))
    except TokenNotInDb:
        logger.warning("No session in db for the token or no token")
        return RedirectResponse(home_url)
    query = dict(post_logout_redirect_uri=home_url, client_id=provider.client_id)
    if provider.logout_with_id_token_hint:
        query["id_token_hint"] = stored_token["id_token"]
    return RedirectResponse(f"{end_session_endpoint}?{urlencode(query)}")
@app.get("/non-compliant-logout")
async def non_compliant_logout(
    request: Request,
    provider: Annotated[Provider, Depends(get_auth_provider)],
):
    """A page for non-compliant OAuth2 servers that we cannot log out.

    Only the local session (user sub and provider id) is cleared; the page is
    rendered with the provider and a link back home.
    """
    session = request.session
    for session_key in ("user_sub", "auth_provider_id"):
        session.pop(session_key, None)
    page_context = {
        "auth_provider": provider,
        "home_url": request.url_for("home"),
    }
    return templates.TemplateResponse(
        name="non_compliant_logout.html", request=request, context=page_context
    )
@app.get("/refresh")
async def refresh(
    request: Request,
    provider: Annotated[Provider, Depends(get_auth_provider)],
    token: Annotated[OAuth2Token, Depends(get_token_from_session)],
) -> RedirectResponse:
    """Manually refresh the session's token, then redirect to the home page.

    Raises:
        HTTPException (510): the token has no refresh token, the provider refuses the
            refresh (eg. expired refresh token) or cannot be reached, or the new token
            is rejected by ``update_token``.
    """
    refresh_token = token.get("refresh_token")
    if refresh_token is None:
        logger.info("Cannot refresh token: no refresh token")
        raise HTTPException(status.HTTP_510_NOT_EXTENDED, "Token refresh error: no refresh token")
    try:
        new_token = await provider.authlib_client.fetch_access_token(
            refresh_token=refresh_token,
            grant_type="refresh_token",
        )
    except (OAuthError, HTTPError) as err:
        # OAuthError: error returned by the provider (eg. invalid_grant);
        # HTTPError: the provider cannot be reached
        logger.info(f"Cannot refresh token: {err.__class__.__name__}")
        raise HTTPException(
            status.HTTP_510_NOT_EXTENDED, f"Token refresh error: {err.__class__.__name__}"
        ) from err
    try:
        await update_token(provider.id, new_token)
    except PyJWTError as err:
        logger.info(f"Cannot refresh token: {err.__class__.__name__}")
        raise HTTPException(
            status.HTTP_510_NOT_EXTENDED, f"Token refresh error: {err.__class__.__name__}"
        ) from err
    return RedirectResponse(url=request.url_for("home"))
# Snippet for running standalone
# Mostly useful for the --version option,
# as running with uvicorn is easy and provides better flexibility, eg.
# uvicorn --host foo oidc_test.main:app --reload
def main():
    """Command line entry point: serve the app with uvicorn, or print the version."""
    from argparse import ArgumentParser
    from uvicorn import run

    cli = ArgumentParser(description=__doc__)
    cli.add_argument(
        "-l",
        "--host",
        type=str,
        default="0.0.0.0",
        help="Address to listen to (default: 0.0.0.0)",
    )
    cli.add_argument(
        "-p",
        "--port",
        type=int,
        default=80,
        help="Port to listen to (default: 80)",
    )
    cli.add_argument("-v", "--version", action="store_true", help="Print version and exit")
    options = cli.parse_args()
    if not options.version:
        run(app, host=options.host, port=options.port)
        return
    import sys
    from importlib.metadata import version

    print(version("oidc-fastapi-test"))
    sys.exit(0)