279 lines
8.8 KiB
Python
279 lines
8.8 KiB
Python
"""
|
|
Test of OpenId Connect & OAuth2 with FastAPI
|
|
"""
|
|
|
|
from typing import Annotated
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
import logging
|
|
from urllib.parse import urlencode
|
|
|
|
from httpx import HTTPError
|
|
from fastapi import Depends, FastAPI, HTTPException, Request, Response, status
|
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
|
from fastapi.templating import Jinja2Templates
|
|
from fastapi.security import OpenIdConnect
|
|
from starlette.middleware.sessions import SessionMiddleware
|
|
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
|
from authlib.integrations.starlette_client import OAuth, OAuthError
|
|
|
|
# authlib's Starlette integration does not support token revocation, hence the requests-based client below
|
|
# from authlib.integrations.requests_client import OAuth2Session
|
|
|
|
from .settings import settings
|
|
from .models import User
|
|
from .auth_utils import (
|
|
get_provider,
|
|
hasrole,
|
|
get_current_user_or_none,
|
|
get_current_user,
|
|
authlib_oauth,
|
|
)
|
|
from .auth_misc import pretty_details
|
|
from .database import db
|
|
|
|
# logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger(__name__)

# Jinja templates live in the "templates" directory next to this module.
templates = Jinja2Templates(directory=Path(__file__).parent / "templates")

app = FastAPI(title="OIDC auth test")

# authlib's Starlette client requires SessionMiddleware; the endpoints below
# also keep the provider id, user sub and token key in request.session.
app.add_middleware(SessionMiddleware, secret_key=settings.secret_key)
|
|
|
|
# Register every OIDC provider declared in the settings with authlib, and
# index the providers by their id.

# provider id -> FastAPI OpenIdConnect security scheme
fastapi_providers = {}
# provider id -> provider entry from settings.oidc.providers
_providers = {}

for provider in settings.oidc.providers:
    # The same scope list is requested from every provider.
    # fetch_token / update_token callbacks are not configured.
    authlib_oauth.register(
        name=provider.id,
        server_metadata_url=provider.openid_configuration,
        client_kwargs={"scope": "openid email offline_access profile roles"},
        client_id=provider.client_id,
        client_secret=provider.client_secret,
    )
    scheme = OpenIdConnect(openIdConnectUrl=provider.openid_configuration)
    fastapi_providers[provider.id] = scheme
    _providers[provider.id] = provider
del scheme
|
|
|
|
|
|
# Endpoints for the login / authorization process
|
|
|
|
|
|
@app.get("/login/{oidc_provider_id}")
async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
    """Start the login flow with the OIDC provider ``oidc_provider_id``.

    Respond with a redirect to the provider's authorization page. After the
    user authenticates there, the provider redirects the browser back to our
    ``/auth/{oidc_provider_id}`` endpoint (see ``auth``), which completes the
    login.

    Raises:
        HTTPException: 401 if the provider id is not registered, or if the
            provider cannot be reached (network error from httpx).
    """
    # create_client() returns None for unregistered names. Unlike
    # getattr(authlib_oauth, ...), it cannot resolve a crafted path such as
    # "register" or "create_client" to one of the registry's own methods.
    provider_ = authlib_oauth.create_client(oidc_provider_id)
    if provider_ is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")

    redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id)
    try:
        # access_type="offline" is an extra authorization-request parameter
        # (used by some providers, e.g. Google, to issue refresh tokens).
        return await provider_.authorize_redirect(
            request, redirect_uri, access_type="offline"
        )
    except HTTPError as err:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, "Cannot reach provider"
        ) from err
|
|
|
|
|
|
@app.get("/auth/{oidc_provider_id}")
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
    """Callback the provider redirects the browser to after login.

    Exchange the authorization response for tokens. Then remember the
    provider, the user's ``sub`` and the token key in the (cookie-based)
    session, and store the user and token in the database. Finally, redirect
    the browser to the home page.

    If the token carries no ``userinfo``, redirect back to ``login`` for the
    same provider instead.

    Raises:
        HTTPException: 401 for an unknown provider, a token exchange rejected
            by the provider (``OAuthError``), or an unreachable provider.
    """
    # See ``login``: create_client() only resolves registered provider names.
    oidc_provider = authlib_oauth.create_client(oidc_provider_id)
    if oidc_provider is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")

    try:
        token = await oidc_provider.authorize_access_token(request)
    except OAuthError as error:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, detail=error.error
        ) from error
    except HTTPError as error:
        # Network failure while talking to the provider; same answer as login.
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, "Cannot reach provider"
        ) from error

    # Remember the oidc_provider in the session.
    request.session["oidc_provider_id"] = oidc_provider_id

    # The full token holds more information (e.g. for updates); only its
    # userinfo is used here.
    userinfo = token.get("userinfo")
    if not userinfo:
        # NOTE(review): unsure whether redirecting to plain login is correct
        # when no userinfo is provided; this may loop if the provider never
        # returns one.
        return RedirectResponse(
            url=request.url_for("login", oidc_provider_id=oidc_provider_id)
        )

    # sub is the user's identifier given by the oidc provider.
    sub = userinfo["sub"]
    request.session["user_sub"] = sub
    user = await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider)
    # The sub is also the key ``logout`` passes to db.get_token().
    request.session["token"] = sub
    await db.add_token(token, user)
    return RedirectResponse(url=request.url_for("home"))
|
|
|
|
|
|
@app.get("/non-compliant-logout")
async def non_compliant_logout(
    request: Request,
    provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
):
    """Render a page for OAuth2 servers we cannot log out of.

    ``logout`` redirects here when the provider's metadata has no
    ``end_session_endpoint``. The template receives the provider and the
    home page URL.
    """
    context = {
        "provider": provider,
        "home_url": request.url_for("home"),
    }
    return templates.TemplateResponse(
        request=request,
        name="non_compliant_logout.html",
        context=context,
    )
|
|
|
|
|
|
@app.get("/logout")
async def logout(
    request: Request,
    provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
) -> RedirectResponse:
    """Log the user out locally and, when possible, at the provider.

    The user's sub and token key are removed from the session first. Then:

    - If the provider advertises no ``end_session_endpoint``, redirect to the
      ``non_compliant_logout`` page.
    - If no stored token can be found for the session, redirect home.
    - Otherwise, redirect to the provider's end-session endpoint (OIDC
      RP-Initiated Logout) with ``id_token_hint``, ``client_id`` and a
      ``post_logout_redirect_uri`` pointing back to the home page.
    """
    # Clear the local login state. The token key is read before removal
    # because it is needed to look up the id token below.
    request.session.pop("user_sub", None)
    token_key = request.session.pop("token", None)

    # server_metadata is only populated once loaded (e.g. by a login in this
    # process). Load it explicitly so a logout after a restart still works.
    try:
        metadata = await provider.load_server_metadata()
    except HTTPError:
        logger.warning("Cannot load server metadata for provider %s", provider.name)
        metadata = provider.server_metadata

    provider_logout_uri = metadata.get("end_session_endpoint")
    if provider_logout_uri is None:
        logger.warning(
            "Cannot find end_session_endpoint for provider %s", provider.name
        )
        return RedirectResponse(request.url_for("non_compliant_logout"))

    if token_key is None or (id_token := await db.get_token(token_key)) is None:
        logger.warning("No session in db for the token")
        return RedirectResponse(request.url_for("home"))

    query = urlencode(
        {
            "post_logout_redirect_uri": str(request.url_for("home")),
            "id_token_hint": id_token.raw_id_token,
            # Was misspelled "cliend_id" and hard-coded; use the id this
            # provider was registered with.
            "client_id": provider.client_id,
        }
    )
    # The endpoint may already carry query parameters.
    separator = "&" if "?" in provider_logout_uri else "?"
    return RedirectResponse(f"{provider_logout_uri}{separator}{query}")
|
|
|
|
|
|
# Home URL
|
|
|
|
|
|
@app.get("/")
async def home(
    request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
) -> HTMLResponse:
    """Render the home page for the current user, or for nobody if anonymous.

    Session details (via ``pretty_details``) are computed only when a user is
    present and ``settings.oidc.show_session_details`` is enabled.
    """
    now = datetime.now()
    settings_dump = settings.model_dump()

    details = None
    if user and settings.oidc.show_session_details:
        details = pretty_details(user, now)

    page_context = {
        "settings": settings_dump,
        "user": user,
        "now": now,
        "user_info_details": details,
    }
    return templates.TemplateResponse(
        request=request, name="home.html", context=page_context
    )
|
|
|
|
|
|
@app.get("/public")
async def public() -> HTMLResponse:
    """Page reachable without any authentication."""
    page = "<h1>Not protected</h1>"
    return HTMLResponse(page)
|
|
|
|
|
|
# Some URIs for testing the permissions
|
|
|
|
|
|
@app.get("/protected")
async def get_protected(
    user: Annotated[User, Depends(get_current_user)]
) -> HTMLResponse:
    """Page meant for authenticated users only.

    ``user`` is not used in the body. Declaring the ``get_current_user``
    dependency is presumably what rejects anonymous requests.
    """
    page = "<h1>Only authenticated users can see this</h1>"
    return HTMLResponse(page)
|
|
|
|
|
|
@app.get("/protected-by-foorole")
@hasrole("foorole")
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
    """Page guarded by ``hasrole("foorole")``.

    ``request`` is unused in the body; presumably the decorator reads it.
    """
    page = "<h1>Only users with foorole can see this</h1>"
    return HTMLResponse(page)
|
|
|
|
|
|
@app.get("/protected-by-barrole")
@hasrole("barrole")
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
    """Page guarded by ``hasrole("barrole")``.

    ``request`` is unused in the body; presumably the decorator reads it.
    """
    page = "<h1>Protected by barrole</h1>"
    return HTMLResponse(page)
|
|
|
|
|
|
@app.get("/protected-by-foorole-and-barrole")
@hasrole("barrole")
@hasrole("foorole")
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
    """Page guarded by two stacked ``hasrole`` decorators.

    Per the response text, stacking is meant to require both roles.
    ``request`` is unused in the body; presumably the decorators read it.
    """
    page = "<h1>Only users with foorole and barrole can see this</h1>"
    return HTMLResponse(page)
|
|
|
|
|
|
@app.get("/protected-by-foorole-or-barrole")
@hasrole(["foorole", "barrole"])
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
    """Page guarded by ``hasrole`` with a list of roles.

    Per the response text, a list is meant to accept either role.
    ``request`` is unused in the body; presumably the decorator reads it.
    """
    page = "<h1>Only users with foorole or barrole can see this</h1>"
    return HTMLResponse(page)
|
|
|
|
|
|
# @app.get("/fast_api_depends")
|
|
# def fast_api_depends(
|
|
# token: Annotated[str, Depends(fastapi_providers["Keycloak"])]
|
|
# ) -> HTMLResponse:
|
|
# return HTMLResponse("You're Authenticated")
|
|
|
|
|
|
def main():
    """Command-line entry point: parse options, then serve ``app`` with uvicorn.

    Options are ``--host``/``-l`` (default 0.0.0.0) and ``--port``/``-p``
    (default 80). With ``--version``/``-v``, print the installed
    ``oidc-fastapi-test`` package version and exit with status 0 instead of
    serving.
    """
    from uvicorn import run
    from argparse import ArgumentParser
    import sys
    from importlib.metadata import version

    # The module docstring doubles as the CLI description.
    cli = ArgumentParser(description=__doc__)
    cli.add_argument(
        "-l", "--host", type=str, default="0.0.0.0",
        help="Address to listen to (default: 0.0.0.0)",
    )
    cli.add_argument(
        "-p", "--port", type=int, default=80,
        help="Port to listen to (default: 80)",
    )
    cli.add_argument(
        "-v", "--version", action="store_true",
        help="Print version and exit",
    )
    options = cli.parse_args()

    if options.version:
        print(version("oidc-fastapi-test"))
        sys.exit(0)

    run(app, host=options.host, port=options.port)
|