From c5bb4f4319445ba145ab08067593a6957ce37f42 Mon Sep 17 00:00:00 2001 From: phil Date: Sun, 9 Feb 2025 06:20:48 +0100 Subject: [PATCH] Refactor most code, isolate authlib somehow --- src/oidc_test/auth_misc.py | 29 ------- src/oidc_test/auth_provider.py | 43 ++++++++++ src/oidc_test/auth_utils.py | 80 +++++++++++-------- src/oidc_test/database.py | 27 +++---- src/oidc_test/main.py | 128 ++++++++++++------------------ src/oidc_test/models.py | 10 +-- src/oidc_test/resource_server.py | 9 ++- src/oidc_test/settings.py | 45 ++--------- src/oidc_test/templates/base.html | 2 +- src/oidc_test/templates/home.html | 28 +++---- 10 files changed, 183 insertions(+), 218 deletions(-) delete mode 100644 src/oidc_test/auth_misc.py create mode 100644 src/oidc_test/auth_provider.py diff --git a/src/oidc_test/auth_misc.py b/src/oidc_test/auth_misc.py deleted file mode 100644 index a4e9ea3..0000000 --- a/src/oidc_test/auth_misc.py +++ /dev/null @@ -1,29 +0,0 @@ -from datetime import datetime, timedelta -from collections import OrderedDict - -from .models import User - -time_keys = set(("iat", "exp", "auth_time", "updated_at")) - - -def pretty_details(user: User, now: datetime) -> OrderedDict: - details = OrderedDict() - # breakpoint() - for key in sorted(time_keys): - try: - dt = datetime.fromtimestamp(user.userinfo[key]) - except (KeyError, TypeError): - pass - else: - td = now - dt - td = timedelta(days=td.days, seconds=td.seconds) - if td.days < 0: - ptd = f"in {-td} h:m:s" - else: - ptd = f"{td} h:m:s ago" - details[key] = f"{user.userinfo[key]} - {dt} ({ptd})" - for key in sorted(user.userinfo): - if key in time_keys: - continue - details[key] = user.userinfo[key] - return details diff --git a/src/oidc_test/auth_provider.py b/src/oidc_test/auth_provider.py new file mode 100644 index 0000000..bed4596 --- /dev/null +++ b/src/oidc_test/auth_provider.py @@ -0,0 +1,43 @@ +from typing import Any +from jwt import decode +import logging + +from authlib.integrations.starlette_client.apps import StarletteOAuth2App + +from .settings import AuthProviderSettings, settings + +logger = logging.getLogger("oidc-test") + + +class Provider(AuthProviderSettings): + class Config: + arbitrary_types_allowed = True + + authlib_client: StarletteOAuth2App = StarletteOAuth2App(None) + + def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]: + """Decode the token with signature check""" + if settings.debug_token: + decoded = decode( + token, + self.get_public_key(), + algorithms=[self.signature_alg], + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": False, + "verify_aud": False, + }, # not settings.insecure.skip_verify_signature}, + ) + logger.debug(str(decoded)) + return decode( + token, + self.get_public_key(), + algorithms=[self.signature_alg], + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": verify_signature, + }, # not settings.insecure.skip_verify_signature}, + ) + + +providers: dict[str, Provider] = {} diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index cab14b2..e62fe39 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -1,4 +1,3 @@ -import re from typing import Union, Annotated from functools import wraps import logging @@ -14,7 +13,8 @@ from authlib.oauth2.auth import OAuth2Token from .models import User from .database import db, TokenNotInDb, UserNotInDB -from .settings import oidc_providers_settings +from .settings import settings +from .auth_provider import providers, Provider logger = logging.getLogger("oidc-test") @@ -35,11 +35,13 @@ async def fetch_token(name, request): # return token.to_token() -async def update_token(name, token, refresh_token=None, access_token=None): +async def update_token( + provider_id, token, refresh_token: str | None = None, access_token: str | None = None +): """Update the token in the database""" - oidc_provider_settings = oidc_providers_settings[name] - sid: str = oidc_provider_settings.decode(token["id_token"])["sid"] - item = await db.get_token(oidc_provider_settings, sid) + provider = providers[provider_id] + sid: str = provider.decode(token["id_token"])["sid"] + item = await db.get_token(provider, sid) # update old token item["access_token"] = token["access_token"] item["refresh_token"] = token["refresh_token"] @@ -54,10 +56,12 @@ authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_t def init_providers(): - # Add oidc providers to authlib from the settings - for id, provider in oidc_providers_settings.items(): + """Add oidc providers to authlib from the settings + and build the providers dict""" + for provider_settings in settings.auth.providers: + provider = Provider(**provider_settings.model_dump()) authlib_oauth.register( - name=id, + name=provider.id, server_metadata_url=provider.openid_configuration, client_kwargs={ "scope": " ".join( @@ -74,6 +78,8 @@ def init_providers(): update_token=update_token, # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) ) + provider.authlib_client = getattr(authlib_oauth, provider.id) + providers[provider.id] = provider init_providers() @@ -82,33 +88,41 @@ init_providers() async def get_providers_info(): # Get the public key: async with AsyncClient() as client: - for provider_settings in oidc_providers_settings.values(): - if provider_settings.info_url: - provider_info = await client.get(provider_settings.url) - provider_settings.info = provider_info.json() + for provider in providers.values(): + if provider.info_url: + provider_info = await client.get(provider.url) + provider.info = provider_info.json() -def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: +def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: """Return the oidc_provider from a request object, from the session. It can be used in Depends()""" - if (oidc_provider_id := request.session.get("oidc_provider_id")) is None: - return - try: - return getattr(authlib_oauth, str(oidc_provider_id)) - except AttributeError: + if (auth_provider_id := request.session.get("auth_provider_id")) is None: return + return getattr(authlib_oauth, str(auth_provider_id), None) -def get_oidc_provider(request: Request) -> StarletteOAuth2App: - if (oidc_provider := get_oidc_provider_or_none(request)) is None: - if oidc_provider is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No provider") - else: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") +def get_auth_provider_client(request: Request) -> StarletteOAuth2App: + if (oidc_provider := get_auth_provider_client_or_none(request)) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") else: return oidc_provider +def get_auth_provider_or_none(request: Request) -> Provider | None: + """Return the oidc_provider settings from a request object, from the session. + It can be used in Depends()""" + if (auth_provider_id := request.session.get("auth_provider_id")) is None: + return + return providers.get(auth_provider_id) + + +def get_auth_provider(request: Request) -> Provider: + if (provider := get_auth_provider_or_none(request)) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") + return provider + + async def get_current_user(request: Request) -> User: """Get the current user from a request object. Also validates the token expiration time. @@ -120,11 +134,11 @@ async def get_current_user(request: Request) -> User: user = await db.get_user(user_sub) ## Check if the token is expired if token.is_expired(): - oidc_provider = get_oidc_provider(request=request) + provider = get_auth_provider(request=request) ## Ask a new refresh token from the provider logger.info(f"Token expired for user {user.name}") try: - userinfo = await oidc_provider.fetch_access_token( + userinfo = await provider.authlib_client.fetch_access_token( refresh_token=token.get("refresh_token") ) assert userinfo is not None @@ -150,14 +164,12 @@ async def get_token(request: Request) -> OAuth2Token: """Return the token from the session. Can be used in Depends()""" try: - oidc_provider_settings = oidc_providers_settings[ - request.session.get("oidc_provider_id", "") - ] + provider = providers[request.session.get("auth_provider_id", "")] except KeyError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider") try: return await db.get_token( - oidc_provider_settings, + provider, request.session.get("sid"), ) except (TokenNotInDb, InvalidKeyError, DecodeError) as err: @@ -219,7 +231,7 @@ async def get_user_from_token( "Request headers must have a 'auth_provider' field", ) try: - auth_provider_settings = oidc_providers_settings[auth_provider_id] + provider = providers[auth_provider_id] except KeyError: if auth_provider_id == "": raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth provider") @@ -230,7 +242,7 @@ async def get_user_from_token( if token == "": raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") try: - payload = auth_provider_settings.decode(token) + payload = provider.decode(token) except ExpiredSignatureError as err: logger.info(f"Expired signature: {err}") raise HTTPException( @@ -261,7 +273,7 @@ async def get_user_from_token( user = await db.add_user( sub=payload["sub"], user_info=payload, - oidc_provider=getattr(authlib_oauth, auth_provider_id), + auth_provider=getattr(authlib_oauth, auth_provider_id), access_token=token, ) return user diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index d3bdd4e..3493429 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -2,11 +2,10 @@ import logging -from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.oauth2.rfc6749 import OAuth2Token -from .settings import OIDCProvider, oidc_providers_settings from .models import User, Role +from .auth_provider import Provider, providers logger = logging.getLogger("oidc-test") @@ -29,23 +28,23 @@ class Database: self, sub: str, user_info: dict, - oidc_provider: StarletteOAuth2App, + auth_provider: Provider, access_token: str, access_token_decoded: dict | None = None, ) -> User: if access_token_decoded is None: - assert oidc_provider.name is not None - oidc_provider_settings = oidc_providers_settings[oidc_provider.name] - access_token_decoded = oidc_provider_settings.decode(access_token) + assert auth_provider.name is not None + provider = providers[auth_provider.id] + access_token_decoded = provider.decode(access_token) + user_info["auth_provider_id"] = auth_provider.id user = User(**user_info) user.userinfo = user_info - user.oidc_provider = oidc_provider - user.access_token = access_token - user.access_token_decoded = access_token_decoded + # user.access_token = access_token + # user.access_token_decoded = access_token_decoded # Add roles provided in the access token roles = set() try: - r = access_token_decoded["resource_access"][oidc_provider.client_id]["roles"] + r = access_token_decoded["resource_access"][auth_provider.client_id]["roles"] roles.update(r) except KeyError: pass @@ -66,19 +65,19 @@ class Database: raise UserNotInDB return self.users[sub] - async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None: + async def add_token(self, provider: Provider, token: OAuth2Token) -> None: """Store a token using as key the sid (auth provider's session id) in the id_token""" - assert isinstance(oidc_provider_settings, OIDCProvider) + assert isinstance(provider, Provider) sid = token["userinfo"]["sid"] self.tokens[sid] = token async def get_token( self, - oidc_provider_settings: OIDCProvider, + provider: Provider, sid: str | None, ) -> OAuth2Token: - assert isinstance(oidc_provider_settings, OIDCProvider) + assert isinstance(provider, Provider) if sid is None: raise TokenNotInDb try: diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 03d13d7..304df92 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates from fastapi.middleware.cors import CORSMiddleware -from jwt import InvalidTokenError, PyJWTError +from jwt import PyJWTError from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.base_client import OAuthError @@ -26,11 +26,12 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair -from .settings import settings, oidc_providers_settings +from .settings import settings +from .auth_provider import Provider, providers from .models import User from .auth_utils import ( - get_oidc_provider, - get_oidc_provider_or_none, + get_auth_provider, + get_auth_provider_or_none, get_current_user_or_none, authlib_oauth, get_providers_info, @@ -38,7 +39,6 @@ from .auth_utils import ( get_token, update_token, ) -from .auth_misc import pretty_details from .database import TokenNotInDb, db from .resource_server import resource_server @@ -78,51 +78,30 @@ app.mount("/resource", resource_server, name="resource_server") async def home( request: Request, user: Annotated[User, Depends(get_current_user_or_none)], - oidc_provider: Annotated[StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)], + provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)], token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], ) -> HTMLResponse: - now = datetime.now() - if oidc_provider and ( - ( - oidc_provider_settings := oidc_providers_settings.get( - request.session.get("oidc_provider_id", "") - ) - ) - is not None - ): - resources = oidc_provider_settings.resources - else: - resources = [] - oidc_provider_settings = None - context = { "settings": settings.model_dump(), "user": user, - "now": now, - "oidc_provider": oidc_provider, - "oidc_provider_settings": oidc_provider_settings, - "resources": resources, + "now": datetime.now(), + "auth_provider": provider, } - if token is None: + if provider is None or token is None: context["access_token"] = None context["id_token_parsed"] = None context["access_token_parsed"] = None context["refresh_token_parsed"] = None + context["resources"] = None else: context["access_token"] = token["access_token"] - assert oidc_provider is not None - assert oidc_provider.name is not None - oidc_provider_settings = oidc_providers_settings[oidc_provider.name] - access_token_parsed = oidc_provider_settings.decode( - token["access_token"], verify_signature=False - ) + context["resources"] = provider.resources + access_token_parsed = provider.decode(token["access_token"], verify_signature=False) context["access_token_scope"] = access_token_parsed["scope"] # context["id_token_parsed"] = pretty_details(user, now) - context["id_token_parsed"] = oidc_provider_settings.decode( - token["id_token"], verify_signature=False - ) + context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False) context["access_token_parsed"] = access_token_parsed - context["refresh_token_parsed"] = oidc_provider_settings.decode( + context["refresh_token_parsed"] = provider.decode( token["refresh_token"], verify_signature=False ) return templates.TemplateResponse(name="home.html", request=request, context=context) @@ -131,20 +110,20 @@ async def home( # Endpoints for the login / authorization process -@app.get("/login/{oidc_provider_id}") -async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: +@app.get("/login/{auth_provider_id}") +async def login(request: Request, auth_provider_id: str) -> RedirectResponse: """Login with the provider id, giving the browser a redirect to its authorize page. - The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url + The provider is expected to send the browser back to our own /auth/{auth_provider_id} url with the token. """ - redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id) + redirect_uri = request.url_for("auth", auth_provider_id=auth_provider_id) try: - provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) + provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") # if ( - # code_challenge_method := oidc_providers_settings[ - # oidc_provider_id + # code_challenge_method := providers[ + # auth_provider_id # ].code_challenge_method # ) is not None: # #client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method) @@ -164,30 +143,30 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") -@app.get("/auth/{oidc_provider_id}") -async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: +@app.get("/auth/{auth_provider_id}") +async def auth(request: Request, auth_provider_id: str) -> RedirectResponse: """Decrypt the auth token, store it to the session (cookie based) and response to the browser with a redirect to a "welcome user" page. """ try: - oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) + authlib_client: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") try: - token: OAuth2Token = await oidc_provider.authorize_access_token(request) + token: OAuth2Token = await authlib_client.authorize_access_token(request) except OAuthError as error: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) - # Remember the oidc_provider in the session + # Remember the authlib_client in the session # logger.info(f"Scope: {token['scope']}") - request.session["oidc_provider_id"] = oidc_provider_id + request.session["auth_provider_id"] = auth_provider_id # # One could process the full decoded token which contains extra information # eg for updates. Here we are only interested in roles # if userinfo := token.get("userinfo"): - # Remember the oidc_provider in the session - request.session["oidc_provider_id"] = oidc_provider_id - # User id (sub) given by oidc provider + # Remember the authlib_client in the session + request.session["auth_provider_id"] = auth_provider_id + # User id (sub) given by auth provider sub = userinfo["sub"] # Build and remember the user in the session request.session["user_sub"] = sub @@ -196,7 +175,7 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: user = await db.add_user( sub, user_info=userinfo, - oidc_provider=oidc_provider, + auth_provider=providers[auth_provider_id], access_token=token["access_token"], ) except PyJWTError as err: @@ -208,48 +187,41 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: # Add the provider session id to the session request.session["sid"] = userinfo["sid"] # Add the token to the db because it is used for logout - assert oidc_provider.name is not None - oidc_provider_settings = oidc_providers_settings[oidc_provider.name] - await db.add_token(oidc_provider_settings, token) + provider = providers[auth_provider_id] + await db.add_token(provider, token) # Send the user to the home: (s)he is authenticated return RedirectResponse(url=request.url_for("home")) else: # Not sure if it's correct to redirect to plain login # if no userinfo is provided - return RedirectResponse(url=request.url_for("login", oidc_provider_id=oidc_provider_id)) + return RedirectResponse(url=request.url_for("login", auth_provider_id=auth_provider_id)) @app.get("/account") async def account( - request: Request, + provider: Annotated[Provider, Depends(get_auth_provider)], ) -> RedirectResponse: - if ( - oidc_provider_settings := oidc_providers_settings.get( - request.session.get("oidc_provider_id", "") - ) - ) is None: - raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings") - return RedirectResponse(f"{oidc_provider_settings.account_url_template}") + """Redirect to the auth provider account management, + if account_url_template is in the provider's settings""" + return RedirectResponse(f"{provider.account_url_template}") @app.get("/logout") async def logout( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + provider: Annotated[Provider, Depends(get_auth_provider)], ) -> RedirectResponse: # Clear session request.session.pop("user_sub", None) # Get provider's endpoint - if (provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")) is None: - logger.warn(f"Cannot find end_session_endpoint for provider {oidc_provider.name}") + if ( + provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint") + ) is None: + logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") - oidc_provider_settings = oidc_providers_settings.get( - request.session.get("oidc_provider_id", "") - ) - assert oidc_provider_settings is not None try: - token = await db.get_token(oidc_provider_settings, request.session.pop("sid", None)) + token = await db.get_token(provider, request.session.pop("sid", None)) except TokenNotInDb: logger.warn("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) @@ -270,30 +242,30 @@ async def logout( @app.get("/non-compliant-logout") async def non_compliant_logout( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + provider: Annotated[StarletteOAuth2App, Depends(get_auth_provider)], ): """A page for non-compliant OAuth2 servers that we cannot log out.""" # Clear the remain of the session - request.session.pop("oidc_provider_id", None) + request.session.pop("auth_provider_id", None) return templates.TemplateResponse( name="non_compliant_logout.html", request=request, - context={"oidc_provider": oidc_provider, "home_url": request.url_for("home")}, + context={"oidc_provider": provider, "home_url": request.url_for("home")}, ) @app.get("/refresh") async def refresh( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + provider: Annotated[Provider, Depends(get_auth_provider)], token: Annotated[OAuth2Token, Depends(get_token)], ) -> RedirectResponse: """Manually refresh token""" - new_token = await oidc_provider.fetch_access_token( + new_token = await provider.authlib_client.fetch_access_token( refresh_token=token["refresh_token"], grant_type="refresh_token", ) - await update_token(oidc_provider.name, new_token) + await update_token(provider.id, new_token) return RedirectResponse(url=request.url_for("home")) diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 8aee2e6..eda63a6 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -8,7 +8,6 @@ from pydantic import ( EmailStr, ConfigDict, ) -from authlib.integrations.starlette_client.apps import StarletteOAuth2App from sqlmodel import SQLModel, Field logger = logging.getLogger("oidc-test") @@ -37,7 +36,7 @@ class User(UserBase): userinfo: dict = {} access_token: str | None = None access_token_decoded: dict[str, Any] | None = None - oidc_provider: StarletteOAuth2App | None = None + auth_provider_id: str @computed_field @cached_property @@ -56,11 +55,10 @@ class User(UserBase): def decode_access_token(self, verify_signature: bool = True): assert self.access_token is not None - assert self.oidc_provider is not None - assert self.oidc_provider.name is not None - from .settings import oidc_providers_settings + assert self.auth_provider_id is not None + from .auth_provider import providers - return oidc_providers_settings[self.oidc_provider.name].decode( + return providers[self.auth_provider_id].decode( self.access_token, verify_signature=verify_signature ) diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index cb944ed..e5670ed 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import Annotated import logging +from authlib.jose import Key from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from fastapi import FastAPI, HTTPException, Depends, status @@ -15,10 +16,9 @@ from .models import User from .auth_utils import ( get_user_from_token, UserWithRole, - # get_oidc_provider, - # get_token, ) from .settings import settings +from .auth_provider import providers logger = logging.getLogger("oidc-test") @@ -128,7 +128,10 @@ async def get_resource(resource_id: str, user: User) -> dict: """ Resource processing: build an informative rely as a simple showcase """ - pname = getattr(user.oidc_provider, "name", "?") + try: + pname = providers[user.auth_provider_id].name + except KeyError: + pname = "?" resp = { "hello": f"Hi {user.name} from an OAuth resource provider", "comment": f"I received a request for '{resource_id}' " diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index e448c1e..9a789a0 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -1,11 +1,9 @@ from os import environ import string import random -from typing import Type, Tuple, Any +from typing import Type, Tuple from pathlib import Path -import logging -from jwt import decode from pydantic import BaseModel, computed_field, AnyUrl from pydantic_settings import ( BaseSettings, @@ -17,8 +15,6 @@ from starlette.requests import Request from .models import User -logger = logging.getLogger("oidc-test") - class Resource(BaseModel): """A resource with an URL that can be accessed with an OAuth2 access token""" @@ -27,8 +23,8 @@ class Resource(BaseModel): name: str -class OIDCProvider(BaseModel): - """OIDC provider, can also be a resource server""" +class AuthProviderSettings(BaseModel): + """Auth provider, can also be a resource server""" id: str name: str @@ -79,30 +75,6 @@ class OIDCProvider(BaseModel): -----END PUBLIC KEY----- """ - def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]: - """Decode the token with signature check""" - if settings.debug_token: - decoded = decode( - token, - self.get_public_key(), - algorithms=[self.signature_alg], - audience=["account", "oidc-test", "oidc-test-web"], - options={ - "verify_signature": False, - "verify_aud": False, - }, # not settings.insecure.skip_verify_signature}, - ) - logger.debug(str(decoded)) - return decode( - token, - self.get_public_key(), - algorithms=[self.signature_alg], - audience=["account", "oidc-test", "oidc-test-web"], - options={ - "verify_signature": verify_signature, - }, # not settings.insecure.skip_verify_signature}, - ) - class ResourceProvider(BaseModel): id: str @@ -111,9 +83,9 @@ class ResourceProvider(BaseModel): resources: list[Resource] = [] -class OIDCSettings(BaseModel): +class AuthSettings(BaseModel): show_session_details: bool = False - providers: list[OIDCProvider] = [] + providers: list[AuthProviderSettings] = [] swagger_provider: str = "" @@ -128,7 +100,7 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_nested_delimiter="__") - oidc: OIDCSettings = OIDCSettings() + auth: AuthSettings = AuthSettings() resource_providers: list[ResourceProvider] = [] secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False @@ -163,8 +135,3 @@ class Settings(BaseSettings): settings = Settings() - - -oidc_providers_settings: dict[str, OIDCProvider] = dict( - [(provider.id, provider) for provider in settings.oidc.providers] -) diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 2ce758c..4cb56f5 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -4,7 +4,7 @@ - +

OIDC-test - FastAPI client

{% block content %} {% endblock %} diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 08bcf43..7275f2d 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -8,7 +8,7 @@

Log in with:

- {% for provider in settings.oidc.providers %} + {% for provider in settings.auth.providers %}
{{ provider.name }}
@@ -32,7 +32,7 @@
{{ user.email }}
Provider: - {{ oidc_provider_settings.name }} + {{ auth_provider.name }}
{% if user.roles %}
@@ -50,9 +50,9 @@ {% endfor %}
{% endif %} - {% if oidc_provider_settings.account_url_template %} + {% if auth_provider.account_url_template %} @@ -67,21 +67,21 @@ Resources validated by scope:

Resources validated by role: