Fetch provider info at boot time: get public key from there instead of in settings
Some checks failed
/ build (push) Failing after 15s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-01-29 14:03:33 +01:00
parent 5b31ef888c
commit f910834736
3 changed files with 61 additions and 29 deletions

View file

@ -6,7 +6,8 @@ from fastapi import HTTPException, Request, Depends, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from authlib.oauth2.rfc6749 import OAuth2Token from authlib.oauth2.rfc6749 import OAuth2Token
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
from jwt import decode from jwt import ExpiredSignatureError, InvalidKeyError, decode
from httpx import AsyncClient
# from authlib.oauth1.auth import OAuthToken # from authlib.oauth1.auth import OAuthToken
# from authlib.oauth2.auth import OAuth2Token # from authlib.oauth2.auth import OAuth2Token
@ -41,23 +42,35 @@ def update_token(*args, **kwargs):
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
def init_providers():
# Add oidc providers to authlib from the settings # Add oidc providers to authlib from the settings
for id, provider in oidc_providers_settings.items(): for id, provider in oidc_providers_settings.items():
authlib_oauth.register( authlib_oauth.register(
name=id, name=id,
server_metadata_url=provider.openid_configuration, server_metadata_url=provider.openid_configuration,
client_kwargs={ client_kwargs={
"scope": "openid email offline_access profile", "scope": "openid email offline_access profile",
}, },
client_id=provider.client_id, client_id=provider.client_id,
client_secret=provider.client_secret, client_secret=provider.client_secret,
api_base_url=provider.url, api_base_url=provider.url,
# For PKCE (not implemented yet): # For PKCE (not implemented yet):
# code_challenge_method="S256", # code_challenge_method="S256",
# fetch_token=fetch_token, # fetch_token=fetch_token,
# update_token=update_token, # update_token=update_token,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
) )
init_providers()
async def get_providers_info():
# Get the public key:
async with AsyncClient() as client:
for provider_settings in oidc_providers_settings.values():
if provider_settings.info_url:
provider_info = await client.get(provider_settings.url)
provider_settings.info = provider_info.json()
def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None:
"""Return the oidc_provider from a request object, from the session. """Return the oidc_provider from a request object, from the session.
@ -156,20 +169,26 @@ def get_token_info(token: dict) -> dict:
return token_info return token_info
async def get_resource_user( async def get_user_from_token(
token: Annotated[str, Depends(oauth2_scheme)], token: Annotated[str, Depends(oauth2_scheme)],
request: Request, request: Request,
) -> User: ) -> User:
# TODO: decode token (ah!)
# See https://fastapi.tiangolo.com/tutorial/security/oauth2-jwt/#hash-and-verify-the-passwords
if (auth_provider_id := request.headers.get("auth_provider")) is None: if (auth_provider_id := request.headers.get("auth_provider")) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field")
if (auth_provider := oidc_providers_settings.get(auth_provider_id)) is None: if (auth_provider_settings := oidc_providers_settings.get(auth_provider_id)) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'") raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'")
if (key := auth_provider.get_key()) is None: oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
await oidc_provider.load_server_metadata()
if (key := auth_provider_settings.get_public_key()) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown") raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown")
try: try:
payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test") payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test")
except ExpiredSignatureError as err:
logger.info(f"Expired signature: {err}")
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Expired signature (refresh not implemented yet)")
except InvalidKeyError as err:
logger.info(f"Invalid key: {err}")
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
except Exception as err: except Exception as err:
logger.info("Cannot decode token, see below") logger.info("Cannot decode token, see below")
logger.exception(err) logger.exception(err)

View file

@ -7,6 +7,7 @@ from pathlib import Path
from datetime import datetime from datetime import datetime
import logging import logging
from urllib.parse import urlencode from urllib.parse import urlencode
from contextlib import asynccontextmanager
from httpx import HTTPError from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi import Depends, FastAPI, HTTPException, Request, status
@ -32,10 +33,11 @@ from .auth_utils import (
hasrole, hasrole,
get_current_user_or_none, get_current_user_or_none,
get_current_user, get_current_user,
get_resource_user, get_user_from_token,
authlib_oauth, authlib_oauth,
get_token, get_token,
oidc_providers_settings, oidc_providers_settings,
get_providers_info,
) )
from .auth_misc import pretty_details from .auth_misc import pretty_details
from .database import db from .database import db
@ -51,10 +53,18 @@ origins = [
"https://philo.ydns.eu/", "https://philo.ydns.eu/",
] ]
@asynccontextmanager
async def lifespan(app: FastAPI):
await get_providers_info()
yield
app = FastAPI( app = FastAPI(
title="OIDC auth test", title="OIDC auth test",
lifespan=lifespan
) )
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=origins, allow_origins=origins,
@ -278,7 +288,7 @@ async def get_resource_(
# user: Annotated[User, Depends(get_current_user)], # user: Annotated[User, Depends(get_current_user)],
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
# token: Annotated[OAuth2Token, Depends(get_token)], # token: Annotated[OAuth2Token, Depends(get_token)],
user: Annotated[User, Depends(get_resource_user)], user: Annotated[User, Depends(get_user_from_token)],
) -> JSONResponse: ) -> JSONResponse:
"""Generic path for testing a resource provided by a provider""" """Generic path for testing a resource provided by a provider"""
return JSONResponse(await get_resource(id, user)) return JSONResponse(await get_resource(id, user))

View file

@ -36,7 +36,9 @@ class OIDCProvider(BaseModel):
hint: str = "No hint" hint: str = "No hint"
resources: list[Resource] = [] resources: list[Resource] = []
account_url_template: str | None = None account_url_template: str | None = None
key: str | None = None info_url: str | None = None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
info: dict[str, str | int ] | None = None # Info fetched from info_url, eg. public key
public_key: str | None = None
@computed_field @computed_field
@property @property
@ -64,13 +66,14 @@ class OIDCProvider(BaseModel):
else: else:
return None return None
def get_key(self) -> str | None: def get_public_key(self) -> str | None:
"""Return the public key formatted for """ """Return the public key formatted for decoding token"""
if self.key is None: public_key = self.public_key or (self.info is not None and self.info["public_key"])
if public_key is None:
return None return None
return f""" return f"""
-----BEGIN PUBLIC KEY----- -----BEGIN PUBLIC KEY-----
{self.key} {public_key}
-----END PUBLIC KEY----- -----END PUBLIC KEY-----
""" """