From f9108347369a24dd2998323d140c5fa4dcab4a88 Mon Sep 17 00:00:00 2001 From: phil Date: Wed, 29 Jan 2025 14:03:33 +0100 Subject: [PATCH] Fetch provider info at boot time: get public key from there instead of in settings --- src/oidc_test/auth_utils.py | 63 ++++++++++++++++++++++++------------- src/oidc_test/main.py | 14 +++++++-- src/oidc_test/settings.py | 13 +++++--- 3 files changed, 61 insertions(+), 29 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 2b3d0fd..d64b6cf 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -6,7 +6,8 @@ from fastapi import HTTPException, Request, Depends, status from fastapi.security import OAuth2PasswordBearer from authlib.oauth2.rfc6749 import OAuth2Token from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App -from jwt import decode +from jwt import ExpiredSignatureError, InvalidKeyError, decode +from httpx import AsyncClient # from authlib.oauth1.auth import OAuthToken # from authlib.oauth2.auth import OAuth2Token @@ -41,23 +42,35 @@ def update_token(*args, **kwargs): authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) +def init_providers(): # Add oidc providers to authlib from the settings -for id, provider in oidc_providers_settings.items(): - authlib_oauth.register( - name=id, - server_metadata_url=provider.openid_configuration, - client_kwargs={ - "scope": "openid email offline_access profile", - }, - client_id=provider.client_id, - client_secret=provider.client_secret, - api_base_url=provider.url, - # For PKCE (not implemented yet): - # code_challenge_method="S256", - # fetch_token=fetch_token, - # update_token=update_token, - # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) - ) + for id, provider in oidc_providers_settings.items(): + authlib_oauth.register( + name=id, + server_metadata_url=provider.openid_configuration, + client_kwargs={ + "scope": "openid email offline_access profile", + }, + client_id=provider.client_id, + client_secret=provider.client_secret, + api_base_url=provider.url, + # For PKCE (not implemented yet): + # code_challenge_method="S256", + # fetch_token=fetch_token, + # update_token=update_token, + # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) + ) + +init_providers() + +async def get_providers_info(): + # Get the public key: + async with AsyncClient() as client: + for provider_settings in oidc_providers_settings.values(): + if provider_settings.info_url: + provider_info = await client.get(provider_settings.url) + provider_settings.info = provider_info.json() + def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: """Return the oidc_provider from a request object, from the session. @@ -156,20 +169,26 @@ def get_token_info(token: dict) -> dict: return token_info -async def get_resource_user( +async def get_user_from_token( token: Annotated[str, Depends(oauth2_scheme)], request: Request, ) -> User: - # TODO: decode token (ah!) - # See https://fastapi.tiangolo.com/tutorial/security/oauth2-jwt/#hash-and-verify-the-passwords if (auth_provider_id := request.headers.get("auth_provider")) is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field") - if (auth_provider := oidc_providers_settings.get(auth_provider_id)) is None: + if (auth_provider_settings := oidc_providers_settings.get(auth_provider_id)) is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'") - if (key := auth_provider.get_key()) is None: + oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) + await oidc_provider.load_server_metadata() + if (key := auth_provider_settings.get_public_key()) is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown") try: payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test") + except ExpiredSignatureError as err: + logger.info(f"Expired signature: {err}") + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Expired signature (refresh not implemented yet)") + except InvalidKeyError as err: + logger.info(f"Invalid key: {err}") + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key") except Exception as err: logger.info("Cannot decode token, see below") logger.exception(err) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 351cc2f..6942759 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -7,6 +7,7 @@ from pathlib import Path from datetime import datetime import logging from urllib.parse import urlencode +from contextlib import asynccontextmanager from httpx import HTTPError from fastapi import Depends, FastAPI, HTTPException, Request, status @@ -32,10 +33,11 @@ from .auth_utils import ( hasrole, get_current_user_or_none, get_current_user, - get_resource_user, + get_user_from_token, authlib_oauth, get_token, oidc_providers_settings, + get_providers_info, ) from .auth_misc import pretty_details from .database import db @@ -51,10 +53,18 @@ origins = [ "https://philo.ydns.eu/", ] +@asynccontextmanager +async def lifespan(app: FastAPI): + await get_providers_info() + yield + + app = FastAPI( title="OIDC auth test", + lifespan=lifespan ) + app.add_middleware( CORSMiddleware, allow_origins=origins, @@ -278,7 +288,7 @@ async def get_resource_( # user: Annotated[User, Depends(get_current_user)], # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], # token: Annotated[OAuth2Token, Depends(get_token)], - user: Annotated[User, Depends(get_resource_user)], + user: Annotated[User, Depends(get_user_from_token)], ) -> JSONResponse: """Generic path for testing a resource provided by a provider""" return JSONResponse(await get_resource(id, user)) diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index c511f86..00c3f23 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -36,7 +36,9 @@ class OIDCProvider(BaseModel): hint: str = "No hint" resources: list[Resource] = [] account_url_template: str | None = None - key: str | None = None + info_url: str | None = None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key) + info: dict[str, str | int ] | None = None # Info fetched from info_url, eg. public key + public_key: str | None = None @computed_field @property @@ -64,13 +66,14 @@ class OIDCProvider(BaseModel): else: return None - def get_key(self) -> str | None: - """Return the public key formatted for """ - if self.key is None: + def get_public_key(self) -> str | None: + """Return the public key formatted for decoding token""" + public_key = self.public_key or (self.info is not None and self.info["public_key"]) + if public_key is None: return None return f""" -----BEGIN PUBLIC KEY----- - {self.key} + {public_key} -----END PUBLIC KEY----- """