Refactor most code, isolate authlib somehow
All checks were successful
/ build (push) Successful in 6s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-02-09 06:20:48 +01:00
parent 38b983c2a5
commit c5bb4f4319
10 changed files with 183 additions and 218 deletions

View file

@ -1,29 +0,0 @@
from datetime import datetime, timedelta
from collections import OrderedDict
from .models import User
time_keys = set(("iat", "exp", "auth_time", "updated_at"))
def pretty_details(user: User, now: datetime) -> OrderedDict:
details = OrderedDict()
# breakpoint()
for key in sorted(time_keys):
try:
dt = datetime.fromtimestamp(user.userinfo[key])
except (KeyError, TypeError):
pass
else:
td = now - dt
td = timedelta(days=td.days, seconds=td.seconds)
if td.days < 0:
ptd = f"in {-td} h:m:s"
else:
ptd = f"{td} h:m:s ago"
details[key] = f"{user.userinfo[key]} - {dt} ({ptd})"
for key in sorted(user.userinfo):
if key in time_keys:
continue
details[key] = user.userinfo[key]
return details

View file

@ -0,0 +1,43 @@
from typing import Any
from jwt import decode
import logging
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from .settings import AuthProviderSettings, settings
logger = logging.getLogger("oidc-test")
class Provider(AuthProviderSettings):
class Config:
arbitrary_types_allowed = True
authlib_client: StarletteOAuth2App = StarletteOAuth2App(None)
def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]:
"""Decode the token with signature check"""
if settings.debug_token:
decoded = decode(
token,
self.get_public_key(),
algorithms=[self.signature_alg],
audience=["account", "oidc-test", "oidc-test-web"],
options={
"verify_signature": False,
"verify_aud": False,
}, # not settings.insecure.skip_verify_signature},
)
logger.debug(str(decoded))
return decode(
token,
self.get_public_key(),
algorithms=[self.signature_alg],
audience=["account", "oidc-test", "oidc-test-web"],
options={
"verify_signature": verify_signature,
}, # not settings.insecure.skip_verify_signature},
)
providers: dict[str, Provider] = {}

View file

@ -1,4 +1,3 @@
import re
from typing import Union, Annotated from typing import Union, Annotated
from functools import wraps from functools import wraps
import logging import logging
@ -14,7 +13,8 @@ from authlib.oauth2.auth import OAuth2Token
from .models import User from .models import User
from .database import db, TokenNotInDb, UserNotInDB from .database import db, TokenNotInDb, UserNotInDB
from .settings import oidc_providers_settings from .settings import settings
from .auth_provider import providers, Provider
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
@ -35,11 +35,13 @@ async def fetch_token(name, request):
# return token.to_token() # return token.to_token()
async def update_token(name, token, refresh_token=None, access_token=None): async def update_token(
provider_id, token, refresh_token: str | None = None, access_token: str | None = None
):
"""Update the token in the database""" """Update the token in the database"""
oidc_provider_settings = oidc_providers_settings[name] provider = providers[provider_id]
sid: str = oidc_provider_settings.decode(token["id_token"])["sid"] sid: str = provider.decode(token["id_token"])["sid"]
item = await db.get_token(oidc_provider_settings, sid) item = await db.get_token(provider, sid)
# update old token # update old token
item["access_token"] = token["access_token"] item["access_token"] = token["access_token"]
item["refresh_token"] = token["refresh_token"] item["refresh_token"] = token["refresh_token"]
@ -54,10 +56,12 @@ authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_t
def init_providers(): def init_providers():
# Add oidc providers to authlib from the settings """Add oidc providers to authlib from the settings
for id, provider in oidc_providers_settings.items(): and build the providers dict"""
for provider_settings in settings.auth.providers:
provider = Provider(**provider_settings.model_dump())
authlib_oauth.register( authlib_oauth.register(
name=id, name=provider.id,
server_metadata_url=provider.openid_configuration, server_metadata_url=provider.openid_configuration,
client_kwargs={ client_kwargs={
"scope": " ".join( "scope": " ".join(
@ -74,6 +78,8 @@ def init_providers():
update_token=update_token, update_token=update_token,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
) )
provider.authlib_client = getattr(authlib_oauth, provider.id)
providers[provider.id] = provider
init_providers() init_providers()
@ -82,33 +88,41 @@ init_providers()
async def get_providers_info(): async def get_providers_info():
# Get the public key: # Get the public key:
async with AsyncClient() as client: async with AsyncClient() as client:
for provider_settings in oidc_providers_settings.values(): for provider in providers.values():
if provider_settings.info_url: if provider.info_url:
provider_info = await client.get(provider_settings.url) provider_info = await client.get(provider.url)
provider_settings.info = provider_info.json() provider.info = provider_info.json()
def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
"""Return the oidc_provider from a request object, from the session. """Return the oidc_provider from a request object, from the session.
It can be used in Depends()""" It can be used in Depends()"""
if (oidc_provider_id := request.session.get("oidc_provider_id")) is None: if (auth_provider_id := request.session.get("auth_provider_id")) is None:
return
try:
return getattr(authlib_oauth, str(oidc_provider_id))
except AttributeError:
return return
return getattr(authlib_oauth, str(auth_provider_id), None)
def get_oidc_provider(request: Request) -> StarletteOAuth2App: def get_auth_provider_client(request: Request) -> StarletteOAuth2App:
if (oidc_provider := get_oidc_provider_or_none(request)) is None: if (oidc_provider := get_auth_provider_client_or_none(request)) is None:
if oidc_provider is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No provider")
else:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
else: else:
return oidc_provider return oidc_provider
def get_auth_provider_or_none(request: Request) -> Provider | None:
"""Return the oidc_provider settings from a request object, from the session.
It can be used in Depends()"""
if (auth_provider_id := request.session.get("auth_provider_id")) is None:
return
return providers.get(auth_provider_id)
def get_auth_provider(request: Request) -> Provider:
if (provider := get_auth_provider_or_none(request)) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
return provider
async def get_current_user(request: Request) -> User: async def get_current_user(request: Request) -> User:
"""Get the current user from a request object. """Get the current user from a request object.
Also validates the token expiration time. Also validates the token expiration time.
@ -120,11 +134,11 @@ async def get_current_user(request: Request) -> User:
user = await db.get_user(user_sub) user = await db.get_user(user_sub)
## Check if the token is expired ## Check if the token is expired
if token.is_expired(): if token.is_expired():
oidc_provider = get_oidc_provider(request=request) provider = get_auth_provider(request=request)
## Ask a new refresh token from the provider ## Ask a new refresh token from the provider
logger.info(f"Token expired for user {user.name}") logger.info(f"Token expired for user {user.name}")
try: try:
userinfo = await oidc_provider.fetch_access_token( userinfo = await provider.authlib_client.fetch_access_token(
refresh_token=token.get("refresh_token") refresh_token=token.get("refresh_token")
) )
assert userinfo is not None assert userinfo is not None
@ -150,14 +164,12 @@ async def get_token(request: Request) -> OAuth2Token:
"""Return the token from the session. """Return the token from the session.
Can be used in Depends()""" Can be used in Depends()"""
try: try:
oidc_provider_settings = oidc_providers_settings[ provider = providers[request.session.get("auth_provider_id", "")]
request.session.get("oidc_provider_id", "")
]
except KeyError: except KeyError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider")
try: try:
return await db.get_token( return await db.get_token(
oidc_provider_settings, provider,
request.session.get("sid"), request.session.get("sid"),
) )
except (TokenNotInDb, InvalidKeyError, DecodeError) as err: except (TokenNotInDb, InvalidKeyError, DecodeError) as err:
@ -219,7 +231,7 @@ async def get_user_from_token(
"Request headers must have a 'auth_provider' field", "Request headers must have a 'auth_provider' field",
) )
try: try:
auth_provider_settings = oidc_providers_settings[auth_provider_id] provider = providers[auth_provider_id]
except KeyError: except KeyError:
if auth_provider_id == "": if auth_provider_id == "":
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth provider") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth provider")
@ -230,7 +242,7 @@ async def get_user_from_token(
if token == "": if token == "":
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
try: try:
payload = auth_provider_settings.decode(token) payload = provider.decode(token)
except ExpiredSignatureError as err: except ExpiredSignatureError as err:
logger.info(f"Expired signature: {err}") logger.info(f"Expired signature: {err}")
raise HTTPException( raise HTTPException(
@ -261,7 +273,7 @@ async def get_user_from_token(
user = await db.add_user( user = await db.add_user(
sub=payload["sub"], sub=payload["sub"],
user_info=payload, user_info=payload,
oidc_provider=getattr(authlib_oauth, auth_provider_id), auth_provider=getattr(authlib_oauth, auth_provider_id),
access_token=token, access_token=token,
) )
return user return user

View file

@ -2,11 +2,10 @@
import logging import logging
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.oauth2.rfc6749 import OAuth2Token from authlib.oauth2.rfc6749 import OAuth2Token
from .settings import OIDCProvider, oidc_providers_settings
from .models import User, Role from .models import User, Role
from .auth_provider import Provider, providers
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
@ -29,23 +28,23 @@ class Database:
self, self,
sub: str, sub: str,
user_info: dict, user_info: dict,
oidc_provider: StarletteOAuth2App, auth_provider: Provider,
access_token: str, access_token: str,
access_token_decoded: dict | None = None, access_token_decoded: dict | None = None,
) -> User: ) -> User:
if access_token_decoded is None: if access_token_decoded is None:
assert oidc_provider.name is not None assert auth_provider.name is not None
oidc_provider_settings = oidc_providers_settings[oidc_provider.name] provider = providers[auth_provider.id]
access_token_decoded = oidc_provider_settings.decode(access_token) access_token_decoded = provider.decode(access_token)
user_info["auth_provider_id"] = auth_provider.id
user = User(**user_info) user = User(**user_info)
user.userinfo = user_info user.userinfo = user_info
user.oidc_provider = oidc_provider # user.access_token = access_token
user.access_token = access_token # user.access_token_decoded = access_token_decoded
user.access_token_decoded = access_token_decoded
# Add roles provided in the access token # Add roles provided in the access token
roles = set() roles = set()
try: try:
r = access_token_decoded["resource_access"][oidc_provider.client_id]["roles"] r = access_token_decoded["resource_access"][auth_provider.client_id]["roles"]
roles.update(r) roles.update(r)
except KeyError: except KeyError:
pass pass
@ -66,19 +65,19 @@ class Database:
raise UserNotInDB raise UserNotInDB
return self.users[sub] return self.users[sub]
async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None: async def add_token(self, provider: Provider, token: OAuth2Token) -> None:
"""Store a token using as key the sid (auth provider's session id) """Store a token using as key the sid (auth provider's session id)
in the id_token""" in the id_token"""
assert isinstance(oidc_provider_settings, OIDCProvider) assert isinstance(provider, Provider)
sid = token["userinfo"]["sid"] sid = token["userinfo"]["sid"]
self.tokens[sid] = token self.tokens[sid] = token
async def get_token( async def get_token(
self, self,
oidc_provider_settings: OIDCProvider, provider: Provider,
sid: str | None, sid: str | None,
) -> OAuth2Token: ) -> OAuth2Token:
assert isinstance(oidc_provider_settings, OIDCProvider) assert isinstance(provider, Provider)
if sid is None: if sid is None:
raise TokenNotInDb raise TokenNotInDb
try: try:

View file

@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from jwt import InvalidTokenError, PyJWTError from jwt import PyJWTError
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.base_client import OAuthError from authlib.integrations.base_client import OAuthError
@ -26,11 +26,12 @@ from authlib.oauth2.rfc6749 import OAuth2Token
# from fastapi.security import OpenIdConnect # from fastapi.security import OpenIdConnect
# from pkce import generate_code_verifier, generate_pkce_pair # from pkce import generate_code_verifier, generate_pkce_pair
from .settings import settings, oidc_providers_settings from .settings import settings
from .auth_provider import Provider, providers
from .models import User from .models import User
from .auth_utils import ( from .auth_utils import (
get_oidc_provider, get_auth_provider,
get_oidc_provider_or_none, get_auth_provider_or_none,
get_current_user_or_none, get_current_user_or_none,
authlib_oauth, authlib_oauth,
get_providers_info, get_providers_info,
@ -38,7 +39,6 @@ from .auth_utils import (
get_token, get_token,
update_token, update_token,
) )
from .auth_misc import pretty_details
from .database import TokenNotInDb, db from .database import TokenNotInDb, db
from .resource_server import resource_server from .resource_server import resource_server
@ -78,51 +78,30 @@ app.mount("/resource", resource_server, name="resource_server")
async def home( async def home(
request: Request, request: Request,
user: Annotated[User, Depends(get_current_user_or_none)], user: Annotated[User, Depends(get_current_user_or_none)],
oidc_provider: Annotated[StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)], provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)],
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
) -> HTMLResponse: ) -> HTMLResponse:
now = datetime.now()
if oidc_provider and (
(
oidc_provider_settings := oidc_providers_settings.get(
request.session.get("oidc_provider_id", "")
)
)
is not None
):
resources = oidc_provider_settings.resources
else:
resources = []
oidc_provider_settings = None
context = { context = {
"settings": settings.model_dump(), "settings": settings.model_dump(),
"user": user, "user": user,
"now": now, "now": datetime.now(),
"oidc_provider": oidc_provider, "auth_provider": provider,
"oidc_provider_settings": oidc_provider_settings,
"resources": resources,
} }
if token is None: if provider is None or token is None:
context["access_token"] = None context["access_token"] = None
context["id_token_parsed"] = None context["id_token_parsed"] = None
context["access_token_parsed"] = None context["access_token_parsed"] = None
context["refresh_token_parsed"] = None context["refresh_token_parsed"] = None
context["resources"] = None
else: else:
context["access_token"] = token["access_token"] context["access_token"] = token["access_token"]
assert oidc_provider is not None context["resources"] = provider.resources
assert oidc_provider.name is not None access_token_parsed = provider.decode(token["access_token"], verify_signature=False)
oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
access_token_parsed = oidc_provider_settings.decode(
token["access_token"], verify_signature=False
)
context["access_token_scope"] = access_token_parsed["scope"] context["access_token_scope"] = access_token_parsed["scope"]
# context["id_token_parsed"] = pretty_details(user, now) # context["id_token_parsed"] = pretty_details(user, now)
context["id_token_parsed"] = oidc_provider_settings.decode( context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False)
token["id_token"], verify_signature=False
)
context["access_token_parsed"] = access_token_parsed context["access_token_parsed"] = access_token_parsed
context["refresh_token_parsed"] = oidc_provider_settings.decode( context["refresh_token_parsed"] = provider.decode(
token["refresh_token"], verify_signature=False token["refresh_token"], verify_signature=False
) )
return templates.TemplateResponse(name="home.html", request=request, context=context) return templates.TemplateResponse(name="home.html", request=request, context=context)
@ -131,20 +110,20 @@ async def home(
# Endpoints for the login / authorization process # Endpoints for the login / authorization process
@app.get("/login/{oidc_provider_id}") @app.get("/login/{auth_provider_id}")
async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: async def login(request: Request, auth_provider_id: str) -> RedirectResponse:
"""Login with the provider id, giving the browser a redirect to its authorize page. """Login with the provider id, giving the browser a redirect to its authorize page.
The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url The provider is expected to send the browser back to our own /auth/{auth_provider_id} url
with the token. with the token.
""" """
redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id) redirect_uri = request.url_for("auth", auth_provider_id=auth_provider_id)
try: try:
provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
except AttributeError: except AttributeError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
# if ( # if (
# code_challenge_method := oidc_providers_settings[ # code_challenge_method := providers[
# oidc_provider_id # auth_provider_id
# ].code_challenge_method # ].code_challenge_method
# ) is not None: # ) is not None:
# #client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method) # #client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method)
@ -164,30 +143,30 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
@app.get("/auth/{oidc_provider_id}") @app.get("/auth/{auth_provider_id}")
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: async def auth(request: Request, auth_provider_id: str) -> RedirectResponse:
"""Decrypt the auth token, store it to the session (cookie based) """Decrypt the auth token, store it to the session (cookie based)
and response to the browser with a redirect to a "welcome user" page. and response to the browser with a redirect to a "welcome user" page.
""" """
try: try:
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) authlib_client: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
except AttributeError: except AttributeError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try: try:
token: OAuth2Token = await oidc_provider.authorize_access_token(request) token: OAuth2Token = await authlib_client.authorize_access_token(request)
except OAuthError as error: except OAuthError as error:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
# Remember the oidc_provider in the session # Remember the authlib_client in the session
# logger.info(f"Scope: {token['scope']}") # logger.info(f"Scope: {token['scope']}")
request.session["oidc_provider_id"] = oidc_provider_id request.session["auth_provider_id"] = auth_provider_id
# #
# One could process the full decoded token which contains extra information # One could process the full decoded token which contains extra information
# eg for updates. Here we are only interested in roles # eg for updates. Here we are only interested in roles
# #
if userinfo := token.get("userinfo"): if userinfo := token.get("userinfo"):
# Remember the oidc_provider in the session # Remember the authlib_client in the session
request.session["oidc_provider_id"] = oidc_provider_id request.session["auth_provider_id"] = auth_provider_id
# User id (sub) given by oidc provider # User id (sub) given by auth provider
sub = userinfo["sub"] sub = userinfo["sub"]
# Build and remember the user in the session # Build and remember the user in the session
request.session["user_sub"] = sub request.session["user_sub"] = sub
@ -196,7 +175,7 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
user = await db.add_user( user = await db.add_user(
sub, sub,
user_info=userinfo, user_info=userinfo,
oidc_provider=oidc_provider, auth_provider=providers[auth_provider_id],
access_token=token["access_token"], access_token=token["access_token"],
) )
except PyJWTError as err: except PyJWTError as err:
@ -208,48 +187,41 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
# Add the provider session id to the session # Add the provider session id to the session
request.session["sid"] = userinfo["sid"] request.session["sid"] = userinfo["sid"]
# Add the token to the db because it is used for logout # Add the token to the db because it is used for logout
assert oidc_provider.name is not None provider = providers[auth_provider_id]
oidc_provider_settings = oidc_providers_settings[oidc_provider.name] await db.add_token(provider, token)
await db.add_token(oidc_provider_settings, token)
# Send the user to the home: (s)he is authenticated # Send the user to the home: (s)he is authenticated
return RedirectResponse(url=request.url_for("home")) return RedirectResponse(url=request.url_for("home"))
else: else:
# Not sure if it's correct to redirect to plain login # Not sure if it's correct to redirect to plain login
# if no userinfo is provided # if no userinfo is provided
return RedirectResponse(url=request.url_for("login", oidc_provider_id=oidc_provider_id)) return RedirectResponse(url=request.url_for("login", auth_provider_id=auth_provider_id))
@app.get("/account") @app.get("/account")
async def account( async def account(
request: Request, provider: Annotated[Provider, Depends(get_auth_provider)],
) -> RedirectResponse: ) -> RedirectResponse:
if ( """Redirect to the auth provider account management,
oidc_provider_settings := oidc_providers_settings.get( if account_url_template is in the provider's settings"""
request.session.get("oidc_provider_id", "") return RedirectResponse(f"{provider.account_url_template}")
)
) is None:
raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings")
return RedirectResponse(f"{oidc_provider_settings.account_url_template}")
@app.get("/logout") @app.get("/logout")
async def logout( async def logout(
request: Request, request: Request,
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], provider: Annotated[Provider, Depends(get_auth_provider)],
) -> RedirectResponse: ) -> RedirectResponse:
# Clear session # Clear session
request.session.pop("user_sub", None) request.session.pop("user_sub", None)
# Get provider's endpoint # Get provider's endpoint
if (provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")) is None: if (
logger.warn(f"Cannot find end_session_endpoint for provider {oidc_provider.name}") provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint")
) is None:
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
return RedirectResponse(request.url_for("non_compliant_logout")) return RedirectResponse(request.url_for("non_compliant_logout"))
post_logout_uri = request.url_for("home") post_logout_uri = request.url_for("home")
oidc_provider_settings = oidc_providers_settings.get(
request.session.get("oidc_provider_id", "")
)
assert oidc_provider_settings is not None
try: try:
token = await db.get_token(oidc_provider_settings, request.session.pop("sid", None)) token = await db.get_token(provider, request.session.pop("sid", None))
except TokenNotInDb: except TokenNotInDb:
logger.warn("No session in db for the token or no token") logger.warn("No session in db for the token or no token")
return RedirectResponse(request.url_for("home")) return RedirectResponse(request.url_for("home"))
@ -270,30 +242,30 @@ async def logout(
@app.get("/non-compliant-logout") @app.get("/non-compliant-logout")
async def non_compliant_logout( async def non_compliant_logout(
request: Request, request: Request,
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], provider: Annotated[StarletteOAuth2App, Depends(get_auth_provider)],
): ):
"""A page for non-compliant OAuth2 servers that we cannot log out.""" """A page for non-compliant OAuth2 servers that we cannot log out."""
# Clear the remain of the session # Clear the remain of the session
request.session.pop("oidc_provider_id", None) request.session.pop("auth_provider_id", None)
return templates.TemplateResponse( return templates.TemplateResponse(
name="non_compliant_logout.html", name="non_compliant_logout.html",
request=request, request=request,
context={"oidc_provider": oidc_provider, "home_url": request.url_for("home")}, context={"oidc_provider": provider, "home_url": request.url_for("home")},
) )
@app.get("/refresh") @app.get("/refresh")
async def refresh( async def refresh(
request: Request, request: Request,
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], provider: Annotated[Provider, Depends(get_auth_provider)],
token: Annotated[OAuth2Token, Depends(get_token)], token: Annotated[OAuth2Token, Depends(get_token)],
) -> RedirectResponse: ) -> RedirectResponse:
"""Manually refresh token""" """Manually refresh token"""
new_token = await oidc_provider.fetch_access_token( new_token = await provider.authlib_client.fetch_access_token(
refresh_token=token["refresh_token"], refresh_token=token["refresh_token"],
grant_type="refresh_token", grant_type="refresh_token",
) )
await update_token(oidc_provider.name, new_token) await update_token(provider.id, new_token)
return RedirectResponse(url=request.url_for("home")) return RedirectResponse(url=request.url_for("home"))

View file

@ -8,7 +8,6 @@ from pydantic import (
EmailStr, EmailStr,
ConfigDict, ConfigDict,
) )
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from sqlmodel import SQLModel, Field from sqlmodel import SQLModel, Field
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
@ -37,7 +36,7 @@ class User(UserBase):
userinfo: dict = {} userinfo: dict = {}
access_token: str | None = None access_token: str | None = None
access_token_decoded: dict[str, Any] | None = None access_token_decoded: dict[str, Any] | None = None
oidc_provider: StarletteOAuth2App | None = None auth_provider_id: str
@computed_field @computed_field
@cached_property @cached_property
@ -56,11 +55,10 @@ class User(UserBase):
def decode_access_token(self, verify_signature: bool = True): def decode_access_token(self, verify_signature: bool = True):
assert self.access_token is not None assert self.access_token is not None
assert self.oidc_provider is not None assert self.auth_provider_id is not None
assert self.oidc_provider.name is not None from .auth_provider import providers
from .settings import oidc_providers_settings
return oidc_providers_settings[self.oidc_provider.name].decode( return providers[self.auth_provider_id].decode(
self.access_token, verify_signature=verify_signature self.access_token, verify_signature=verify_signature
) )

View file

@ -2,6 +2,7 @@ from datetime import datetime
from typing import Annotated from typing import Annotated
import logging import logging
from authlib.jose import Key
from httpx import AsyncClient from httpx import AsyncClient
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from fastapi import FastAPI, HTTPException, Depends, status from fastapi import FastAPI, HTTPException, Depends, status
@ -15,10 +16,9 @@ from .models import User
from .auth_utils import ( from .auth_utils import (
get_user_from_token, get_user_from_token,
UserWithRole, UserWithRole,
# get_oidc_provider,
# get_token,
) )
from .settings import settings from .settings import settings
from .auth_provider import providers
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
@ -128,7 +128,10 @@ async def get_resource(resource_id: str, user: User) -> dict:
""" """
Resource processing: build an informative rely as a simple showcase Resource processing: build an informative rely as a simple showcase
""" """
pname = getattr(user.oidc_provider, "name", "?") try:
pname = providers[user.auth_provider_id].name
except KeyError:
pname = "?"
resp = { resp = {
"hello": f"Hi {user.name} from an OAuth resource provider", "hello": f"Hi {user.name} from an OAuth resource provider",
"comment": f"I received a request for '{resource_id}' " "comment": f"I received a request for '{resource_id}' "

View file

@ -1,11 +1,9 @@
from os import environ from os import environ
import string import string
import random import random
from typing import Type, Tuple, Any from typing import Type, Tuple
from pathlib import Path from pathlib import Path
import logging
from jwt import decode
from pydantic import BaseModel, computed_field, AnyUrl from pydantic import BaseModel, computed_field, AnyUrl
from pydantic_settings import ( from pydantic_settings import (
BaseSettings, BaseSettings,
@ -17,8 +15,6 @@ from starlette.requests import Request
from .models import User from .models import User
logger = logging.getLogger("oidc-test")
class Resource(BaseModel): class Resource(BaseModel):
"""A resource with an URL that can be accessed with an OAuth2 access token""" """A resource with an URL that can be accessed with an OAuth2 access token"""
@ -27,8 +23,8 @@ class Resource(BaseModel):
name: str name: str
class OIDCProvider(BaseModel): class AuthProviderSettings(BaseModel):
"""OIDC provider, can also be a resource server""" """Auth provider, can also be a resource server"""
id: str id: str
name: str name: str
@ -79,30 +75,6 @@ class OIDCProvider(BaseModel):
-----END PUBLIC KEY----- -----END PUBLIC KEY-----
""" """
def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]:
"""Decode the token with signature check"""
if settings.debug_token:
decoded = decode(
token,
self.get_public_key(),
algorithms=[self.signature_alg],
audience=["account", "oidc-test", "oidc-test-web"],
options={
"verify_signature": False,
"verify_aud": False,
}, # not settings.insecure.skip_verify_signature},
)
logger.debug(str(decoded))
return decode(
token,
self.get_public_key(),
algorithms=[self.signature_alg],
audience=["account", "oidc-test", "oidc-test-web"],
options={
"verify_signature": verify_signature,
}, # not settings.insecure.skip_verify_signature},
)
class ResourceProvider(BaseModel): class ResourceProvider(BaseModel):
id: str id: str
@ -111,9 +83,9 @@ class ResourceProvider(BaseModel):
resources: list[Resource] = [] resources: list[Resource] = []
class OIDCSettings(BaseModel): class AuthSettings(BaseModel):
show_session_details: bool = False show_session_details: bool = False
providers: list[OIDCProvider] = [] providers: list[AuthProviderSettings] = []
swagger_provider: str = "" swagger_provider: str = ""
@ -128,7 +100,7 @@ class Settings(BaseSettings):
model_config = SettingsConfigDict(env_nested_delimiter="__") model_config = SettingsConfigDict(env_nested_delimiter="__")
oidc: OIDCSettings = OIDCSettings() auth: AuthSettings = AuthSettings()
resource_providers: list[ResourceProvider] = [] resource_providers: list[ResourceProvider] = []
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
log: bool = False log: bool = False
@ -163,8 +135,3 @@ class Settings(BaseSettings):
settings = Settings() settings = Settings()
oidc_providers_settings: dict[str, OIDCProvider] = dict(
[(provider.id, provider) for provider in settings.oidc.providers]
)

View file

@ -4,7 +4,7 @@
<link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet"> <link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet">
<script src="{{ url_for('static', path='/utils.js') }}"></script> <script src="{{ url_for('static', path='/utils.js') }}"></script>
</head> </head>
<body onload="checkPerms('links-to-check', '{{ access_token }}', '{{ oidc_provider_settings.id }}')"> <body onload="checkPerms('links-to-check', '{{ access_token }}', '{{ auth_provider.id }}')">
<h1>OIDC-test - FastAPI client</h1> <h1>OIDC-test - FastAPI client</h1>
{% block content %} {% block content %}
{% endblock %} {% endblock %}

View file

@ -8,7 +8,7 @@
<div class="login-box"> <div class="login-box">
<p class="description">Log in with:</p> <p class="description">Log in with:</p>
<table class="providers"> <table class="providers">
{% for provider in settings.oidc.providers %} {% for provider in settings.auth.providers %}
<tr class="provider"> <tr class="provider">
<td> <td>
<a class="link" href="login/{{ provider.id }}"><div>{{ provider.name }}</div></a> <a class="link" href="login/{{ provider.id }}"><div>{{ provider.name }}</div></a>
@ -32,7 +32,7 @@
<div>{{ user.email }}</div> <div>{{ user.email }}</div>
<div> <div>
<span>Provider:</span> <span>Provider:</span>
{{ oidc_provider_settings.name }} {{ auth_provider.name }}
</div> </div>
{% if user.roles %} {% if user.roles %}
<div> <div>
@ -50,9 +50,9 @@
{% endfor %} {% endfor %}
</div> </div>
{% endif %} {% endif %}
{% if oidc_provider_settings.account_url_template %} {% if auth_provider.account_url_template %}
<button <button
onclick="location.href='{{ oidc_provider_settings.get_account_url(request, user) }}'" onclick="location.href='{{ auth_provider.get_account_url(request, user) }}'"
class="account"> class="account">
Account management Account management
</button> </button>
@ -67,21 +67,21 @@
Resources validated by scope: Resources validated by scope:
</p> </p>
<div class="links-to-check"> <div class="links-to-check">
<button resource-id="time" onclick="get_resource('time', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Time</button> <button resource-id="time" onclick="get_resource('time', '{{ access_token }}', '{{ auth_provider.id }}')">Time</button>
<button resource-id="bs" onclick="get_resource('bs', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">BS</button> <button resource-id="bs" onclick="get_resource('bs', '{{ access_token }}', '{{ auth_provider.id }}')">BS</button>
</div> </div>
<p> <p>
Resources validated by role: Resources validated by role:
</p> </p>
<div class="links-to-check"> <div class="links-to-check">
<button resource-id="public" onclick="get_resource('public', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Public</button> <button resource-id="public" onclick="get_resource('public', '{{ access_token }}', '{{ auth_provider.id }}')">Public</button>
<button resource-id="protected" onclick="get_resource('protected', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth protected content</button> <button resource-id="protected" onclick="get_resource('protected', '{{ access_token }}', '{{ auth_provider.id }}')">Auth protected content</button>
<button resource-id="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole protected content</button> <button resource-id="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole protected content</button>
<button resource-id="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole or barrole protected content</button> <button resource-id="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole or barrole protected content</button>
<button resource-id="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + barrole protected content</button> <button resource-id="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + barrole protected content</button>
<button resource-id="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole and barrole protected content</button> <button resource-id="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole and barrole protected content</button>
<button resource-id="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Using FastAPI Depends</button> <button resource-id="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ auth_provider.id }}')">Using FastAPI Depends</button>
<!--<button resource-id="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Introspect token (401 expected)</button>--> <!--<button resource-id="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ auth_provider.id }}')">Introspect token (401 expected)</button>-->
</div> </div>
<div class="resourceResult"> <div class="resourceResult">
<div id="resource" class="resource"></div> <div id="resource" class="resource"></div>