Fetch provider info at boot time: get public key from there instead of in settings
Some checks failed
/ build (push) Failing after 15s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-01-29 14:03:33 +01:00
parent 5b31ef888c
commit f910834736
3 changed files with 61 additions and 29 deletions

View file

@ -6,7 +6,8 @@ from fastapi import HTTPException, Request, Depends, status
from fastapi.security import OAuth2PasswordBearer
from authlib.oauth2.rfc6749 import OAuth2Token
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
from jwt import decode
from jwt import ExpiredSignatureError, InvalidKeyError, decode
from httpx import AsyncClient
# from authlib.oauth1.auth import OAuthToken
# from authlib.oauth2.auth import OAuth2Token
@ -41,23 +42,35 @@ def update_token(*args, **kwargs):
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
def init_providers():
# Add oidc providers to authlib from the settings
for id, provider in oidc_providers_settings.items():
authlib_oauth.register(
name=id,
server_metadata_url=provider.openid_configuration,
client_kwargs={
"scope": "openid email offline_access profile",
},
client_id=provider.client_id,
client_secret=provider.client_secret,
api_base_url=provider.url,
# For PKCE (not implemented yet):
# code_challenge_method="S256",
# fetch_token=fetch_token,
# update_token=update_token,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
)
for id, provider in oidc_providers_settings.items():
authlib_oauth.register(
name=id,
server_metadata_url=provider.openid_configuration,
client_kwargs={
"scope": "openid email offline_access profile",
},
client_id=provider.client_id,
client_secret=provider.client_secret,
api_base_url=provider.url,
# For PKCE (not implemented yet):
# code_challenge_method="S256",
# fetch_token=fetch_token,
# update_token=update_token,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
)
init_providers()
async def get_providers_info():
# Get the public key:
async with AsyncClient() as client:
for provider_settings in oidc_providers_settings.values():
if provider_settings.info_url:
provider_info = await client.get(provider_settings.url)
provider_settings.info = provider_info.json()
def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None:
"""Return the oidc_provider from a request object, from the session.
@ -156,20 +169,26 @@ def get_token_info(token: dict) -> dict:
return token_info
async def get_resource_user(
async def get_user_from_token(
token: Annotated[str, Depends(oauth2_scheme)],
request: Request,
) -> User:
# TODO: decode token (ah!)
# See https://fastapi.tiangolo.com/tutorial/security/oauth2-jwt/#hash-and-verify-the-passwords
if (auth_provider_id := request.headers.get("auth_provider")) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field")
if (auth_provider := oidc_providers_settings.get(auth_provider_id)) is None:
if (auth_provider_settings := oidc_providers_settings.get(auth_provider_id)) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'")
if (key := auth_provider.get_key()) is None:
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
await oidc_provider.load_server_metadata()
if (key := auth_provider_settings.get_public_key()) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown")
try:
payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test")
except ExpiredSignatureError as err:
logger.info(f"Expired signature: {err}")
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Expired signature (refresh not implemented yet)")
except InvalidKeyError as err:
logger.info(f"Invalid key: {err}")
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
except Exception as err:
logger.info("Cannot decode token, see below")
logger.exception(err)

View file

@ -7,6 +7,7 @@ from pathlib import Path
from datetime import datetime
import logging
from urllib.parse import urlencode
from contextlib import asynccontextmanager
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
@ -32,10 +33,11 @@ from .auth_utils import (
hasrole,
get_current_user_or_none,
get_current_user,
get_resource_user,
get_user_from_token,
authlib_oauth,
get_token,
oidc_providers_settings,
get_providers_info,
)
from .auth_misc import pretty_details
from .database import db
@ -51,10 +53,18 @@ origins = [
"https://philo.ydns.eu/",
]
@asynccontextmanager
async def lifespan(app: FastAPI):
await get_providers_info()
yield
app = FastAPI(
title="OIDC auth test",
lifespan=lifespan
)
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
@ -278,7 +288,7 @@ async def get_resource_(
# user: Annotated[User, Depends(get_current_user)],
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
# token: Annotated[OAuth2Token, Depends(get_token)],
user: Annotated[User, Depends(get_resource_user)],
user: Annotated[User, Depends(get_user_from_token)],
) -> JSONResponse:
"""Generic path for testing a resource provided by a provider"""
return JSONResponse(await get_resource(id, user))

View file

@ -36,7 +36,9 @@ class OIDCProvider(BaseModel):
hint: str = "No hint"
resources: list[Resource] = []
account_url_template: str | None = None
key: str | None = None
info_url: str | None = None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
info: dict[str, str | int ] | None = None # Info fetched from info_url, eg. public key
public_key: str | None = None
@computed_field
@property
@ -64,13 +66,14 @@ class OIDCProvider(BaseModel):
else:
return None
def get_key(self) -> str | None:
"""Return the public key formatted for """
if self.key is None:
def get_public_key(self) -> str | None:
"""Return the public key formatted for decoding token"""
public_key = self.public_key or (self.info is not None and self.info["public_key"])
if public_key is None:
return None
return f"""
-----BEGIN PUBLIC KEY-----
{self.key}
{public_key}
-----END PUBLIC KEY-----
"""