From d39adf41eff5dac79bde268c01602f5f6072385d Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Fri, 7 Feb 2025 13:57:17 +0100 Subject: [PATCH 01/46] Create a sub-app for resource server move all resources to resource server; use token bearer instead of session cookie for resources and use fetch instead of XMLHttpRequest for checking resource status; add UserWithRole class for fastapi depends (instead of has_role decorator); add asserts for typing QC; code formatting; comment out introspect endpoint processing --- src/oidc_test/auth_utils.py | 28 ++++++- src/oidc_test/database.py | 2 + src/oidc_test/main.py | 133 ++++-------------------------- src/oidc_test/models.py | 6 +- src/oidc_test/resource_server.py | 122 +++++++++++++++++++++++++-- src/oidc_test/static/utils.js | 32 ++++--- src/oidc_test/templates/base.html | 2 +- src/oidc_test/templates/home.html | 16 ++-- 8 files changed, 188 insertions(+), 153 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 281511d..1004527 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -13,7 +13,7 @@ from authlib.oauth2.auth import OAuth2Token from .models import User from .database import TokenNotInDb, db, UserNotInDB -from .settings import settings, OIDCProvider, oidc_providers_settings +from .settings import oidc_providers_settings logger = logging.getLogger("oidc-test") @@ -21,6 +21,8 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") async def fetch_token(name, request): + assert name is not None + assert request is not None logger.warn("TODO: fetch_token") ... # if name in oidc_providers: @@ -37,8 +39,10 @@ async def update_token(name, token, refresh_token=None, access_token=None): sid: str = oidc_provider_settings.decode(token["id_token"])["sid"] item = await db.get_token(oidc_provider_settings, sid) # update old token - item["access_token"] = token.get("access_token") - item["refresh_token"] = token.get("refresh_token") + if access_token is not None: + item["access_token"] = token.get("access_token") + if refresh_token is not None: + item["refresh_token"] = refresh_token item["expires_at"] = token["expires_at"] logger.info(f"Token {sid} refreshed") # It's a fake db and only in memory, so there's nothing to save @@ -119,6 +123,7 @@ async def get_current_user(request: Request) -> User: userinfo = await oidc_provider.fetch_access_token( refresh_token=token.get("refresh_token") ) + assert userinfo is not None except OAuthError as err: logger.exception(err) # raise HTTPException( @@ -242,3 +247,20 @@ async def get_user_from_token( access_token=token, ) return user + + +class UserWithRole: + roles: set[str] + + def __init__(self, roles: str | list[str] | tuple[str] | set[str]): + if isinstance(roles, str): + self.roles = set([roles]) + elif isinstance(roles, (list, tuple, set)): + self.roles = set(roles) + + def __call__(self, user: User = Depends(get_user_from_token)) -> User: + if not any(self.roles.intersection(user.roles_as_set)): + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Not of any required role {', '.join(self.roles)}" + ) + return user diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 360ef11..d3bdd4e 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -69,6 +69,7 @@ class Database: async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None: """Store a token using as key the sid (auth provider's session id) in the id_token""" + assert isinstance(oidc_provider_settings, OIDCProvider) sid = token["userinfo"]["sid"] self.tokens[sid] = token @@ -77,6 +78,7 @@ class Database: oidc_provider_settings: OIDCProvider, sid: str | None, ) -> OAuth2Token: + assert isinstance(oidc_provider_settings, OIDCProvider) if sid is None: raise TokenNotInDb try: diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 60482bd..47d0c39 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -12,7 +12,7 @@ from contextlib import asynccontextmanager from httpx import HTTPError from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi.staticfiles import StaticFiles -from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse +from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates from fastapi.middleware.cors import CORSMiddleware from jwt import InvalidTokenError, PyJWTError @@ -31,17 +31,13 @@ from .models import User from .auth_utils import ( get_oidc_provider, get_oidc_provider_or_none, - hasrole, get_current_user_or_none, - get_current_user, - get_user_from_token, authlib_oauth, - get_token, get_providers_info, ) from .auth_misc import pretty_details from .database import TokenNotInDb, db -from .resource_server import get_resource +from .resource_server import resource_server logger = logging.getLogger("oidc-test") @@ -50,6 +46,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates") @asynccontextmanager async def lifespan(app: FastAPI): + assert app is not None await get_providers_info() yield @@ -64,24 +61,21 @@ app.add_middleware( allow_headers=["*"], ) -app.mount( - "/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static" -) - # SessionMiddleware is required by authlib app.add_middleware( SessionMiddleware, secret_key=settings.secret_key, ) +app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static") +app.mount("/resource", resource_server, name="resource_server") + @app.get("/") async def home( request: Request, user: Annotated[User, Depends(get_current_user_or_none)], - oidc_provider: Annotated[ - StarletteOAuth2App | None, Depends(get_oidc_provider_or_none) - ], + oidc_provider: Annotated[StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)], ) -> HTMLResponse: now = datetime.now() if oidc_provider and ( @@ -119,9 +113,7 @@ async def home( "oidc_provider_settings": oidc_provider_settings, "resources": resources, "user_info_details": ( - pretty_details(user, now) - if user and settings.oidc.show_session_details - else None + pretty_details(user, now) if user and settings.oidc.show_session_details else None ), }, ) @@ -215,24 +207,19 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: else: # Not sure if it's correct to redirect to plain login # if no userinfo is provided - return RedirectResponse( - url=request.url_for("login", oidc_provider_id=oidc_provider_id) - ) + return RedirectResponse(url=request.url_for("login", oidc_provider_id=oidc_provider_id)) @app.get("/account") async def account( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], ) -> RedirectResponse: if ( oidc_provider_settings := oidc_providers_settings.get( request.session.get("oidc_provider_id", "") ) ) is None: - raise HTTPException( - status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings" - ) + raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings") return RedirectResponse(f"{oidc_provider_settings.account_url_template}") @@ -244,12 +231,8 @@ async def logout( # Clear session request.session.pop("user_sub", None) # Get provider's endpoint - if ( - provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint") - ) is None: - logger.warn( - f"Cannot find end_session_endpoint for provider {oidc_provider.name}" - ) + if (provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")) is None: + logger.warn(f"Cannot find end_session_endpoint for provider {oidc_provider.name}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") oidc_provider_settings = oidc_providers_settings.get( @@ -257,9 +240,7 @@ async def logout( ) assert oidc_provider_settings is not None try: - token = await db.get_token( - oidc_provider_settings, request.session.pop("sid", None) - ) + token = await db.get_token(oidc_provider_settings, request.session.pop("sid", None)) except TokenNotInDb: logger.warn("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) @@ -292,90 +273,6 @@ async def non_compliant_logout( ) -# Route for OAuth resource server - - -@app.get("/resource/{id}") -async def get_resource_( - id: str, - # user: Annotated[User, Depends(get_current_user)], - # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], - # token: Annotated[OAuth2Token, Depends(get_token)], - user: Annotated[User, Depends(get_user_from_token)], -) -> JSONResponse: - """Generic path for testing a resource provided by a provider""" - return JSONResponse(await get_resource(id, user)) - - -# Routes for RBAC based tests - - -@app.get("/public") -async def public() -> HTMLResponse: - return HTMLResponse("<h1>Not protected</h1>") - - -@app.get("/protected") -async def get_protected( - user: Annotated[User, Depends(get_current_user)] -) -> HTMLResponse: - assert user is not None # Just to keep QA checks happy - return HTMLResponse("<h1>Only authenticated users can see this</h1>") - - -@app.get("/protected-by-foorole") -@hasrole("foorole") -async def get_protected_by_foorole(request: Request) -> HTMLResponse: - assert request is not None # Just to keep QA checks happy - return HTMLResponse("<h1>Only users with foorole can see this</h1>") - - -@app.get("/protected-by-barrole") -@hasrole("barrole") -async def get_protected_by_barrole(request: Request) -> HTMLResponse: - assert request is not None # Just to keep QA checks happy - return HTMLResponse("<h1>Protected by barrole</h1>") - - -@app.get("/protected-by-foorole-and-barrole") -@hasrole("barrole") -@hasrole("foorole") -async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse: - assert request is not None # Just to keep QA checks happy - return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>") - - -@app.get("/protected-by-foorole-or-barrole") -@hasrole(["foorole", "barrole"]) -async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse: - assert request is not None # Just to keep QA checks happy - return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>") - - -@app.get("/introspect") -async def get_introspect( - request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], - token: Annotated[OAuth2Token, Depends(get_token)], -) -> JSONResponse: - assert request is not None # Just to keep QA checks happy - if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="No intrispection endpoint found for the OIDC provider", - ) - if ( - response := await oidc_provider.post( - url, - token=token, - data={"token": token["access_token"]}, - ) - ).is_success: - return response.json() - else: - raise HTTPException(status_code=response.status_code, detail=response.text) - - # Snippet for running standalone # Mostly useful for the --version option, # as running with uvicorn is easy and provides better flexibility, eg. @@ -397,9 +294,7 @@ def main(): parser.add_argument( "-p", "--port", type=int, default=80, help="Port to listen to (default: 80)" ) - parser.add_argument( - "-v", "--version", action="store_true", help="Print version and exit" - ) + parser.add_argument("-v", "--version", action="store_true", help="Print version and exit") args = parser.parse_args() if args.version: diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index fc0dba7..9554bd5 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -1,6 +1,6 @@ import logging from functools import cached_property -from typing import Self, Any +from typing import Any from pydantic import ( computed_field, @@ -60,6 +60,4 @@ class User(UserBase): assert self.oidc_provider.name is not None from .settings import oidc_providers_settings - return oidc_providers_settings[self.oidc_provider.name].decode( - self.access_token - ) + return oidc_providers_settings[self.oidc_provider.name].decode(self.access_token) diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 0d90533..d5e2aaa 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -1,15 +1,127 @@ from datetime import datetime +from typing import Annotated import logging from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError -from fastapi import HTTPException, status -from starlette.status import HTTP_401_UNAUTHORIZED +from fastapi import FastAPI, HTTPException, Depends, Request, status +from fastapi.responses import HTMLResponse, JSONResponse +from fastapi.middleware.cors import CORSMiddleware + +# from starlette.middleware.sessions import SessionMiddleware +# from authlib.integrations.starlette_client.apps import StarletteOAuth2App +# from authlib.oauth2.rfc6749 import OAuth2Token from .models import User +from .auth_utils import ( + get_user_from_token, + UserWithRole, + get_oidc_provider, + get_token, +) +from .settings import settings logger = logging.getLogger("oidc-test") +resource_server = FastAPI() + + +resource_server.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# SessionMiddleware is required by authlib +# resource_server.add_middleware( +# SessionMiddleware, +# secret_key=settings.secret_key, +# ) + +# Route for OAuth resource server + + +# Routes for RBAC based tests + + +@resource_server.get("/public") +async def public() -> HTMLResponse: + return HTMLResponse("<h1>Not protected</h1>") + + +@resource_server.get("/protected") +async def get_protected(user: Annotated[User, Depends(get_user_from_token)]) -> HTMLResponse: + assert user is not None # Just to keep QA checks happy + return HTMLResponse("<h1>Only authenticated users can see this</h1>") + + +@resource_server.get("/protected-by-foorole") +async def get_protected_by_foorole( + user: Annotated[User, Depends(UserWithRole("foorole"))] +) -> HTMLResponse: + return HTMLResponse("<h1>Only users with foorole can see this</h1>") + + +@resource_server.get("/protected-by-barrole") +async def get_protected_by_barrole( + user: Annotated[User, Depends(UserWithRole("barrole"))] +) -> HTMLResponse: + return HTMLResponse("<h1>Protected by barrole</h1>") + + +@resource_server.get("/protected-by-foorole-and-barrole") +async def get_protected_by_foorole_and_barrole( + user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], +) -> HTMLResponse: + assert user is not None # Just to keep QA checks happy + return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>") + + +@resource_server.get("/protected-by-foorole-or-barrole") +async def get_protected_by_foorole_or_barrole( + user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))] +) -> HTMLResponse: + assert user is not None # Just to keep QA checks happy + return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>") + + +# @resource_server.get("/introspect") +# async def get_introspect( +# request: Request, +# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], +# token: Annotated[OAuth2Token, Depends(get_token)], +# ) -> JSONResponse: +# assert request is not None # Just to keep QA checks happy +# if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None: +# raise HTTPException( +# status_code=status.HTTP_401_UNAUTHORIZED, +# detail="No introspection endpoint found for the OIDC provider", +# ) +# if ( +# response := await oidc_provider.post( +# url, +# token=token, +# data={"token": token["access_token"]}, +# ) +# ).is_success: +# return response.json() +# else: +# raise HTTPException(status_code=response.status_code, detail=response.text) + + +@resource_server.get("/{id}") +async def get_resource_( + id: str, + # user: Annotated[User, Depends(get_current_user)], + # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + # token: Annotated[OAuth2Token, Depends(get_token)], + user: Annotated[User, Depends(get_user_from_token)], +) -> JSONResponse: + """Generic path for testing a resource provided by a provider""" + return JSONResponse(await get_resource(id, user)) + async def get_resource(resource_id: str, user: User) -> dict: """ @@ -34,12 +146,10 @@ async def get_resource(resource_id: str, user: User) -> dict: raise HTTPException( status.HTTP_401_UNAUTHORIZED, f"No scope {required_scope} in the access token " - + "but it is required for accessing this resource.", + + "but it is required for accessing this resource", ) except ExpiredSignatureError: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, "The token's signature has expired" - ) + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired") except InvalidTokenError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid") return resp diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index a982267..e6c4bfc 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -1,20 +1,28 @@ -function checkHref(elem) { - var xmlHttp = new XMLHttpRequest() - xmlHttp.onreadystatechange = function () { - if (xmlHttp.readyState == 4) { - elem.classList.add("hasResponseStatus") - elem.classList.add("status-" + xmlHttp.status) - elem.title = "Response code: " + xmlHttp.status + " - " + xmlHttp.statusText - } +async function checkHref(elem, token, authProvider) { + const msg = document.getElementById("msg") + const resp = await fetch(elem.href, { + headers: new Headers({ + "Content-type": "application/json", + "Authorization": `Bearer ${token}`, + "auth_provider": authProvider, + }), + }).catch(err => { + msg.innerHTML = "Cannot fetch resource: " + err.message + resourceElem.innerHTML = "" + }) + if (resp === undefined) { + return + } else { + elem.classList.add("hasResponseStatus") + elem.classList.add("status-" + resp.status) + elem.title = "Response code: " + resp.status + " - " + resp.statusText } - xmlHttp.open("GET", elem.href, true) // true for asynchronous - xmlHttp.send(null) } -function checkPerms(className) { +function checkPerms(className, token, authProvider) { var rootElems = document.getElementsByClassName(className) Array.from(rootElems).forEach(elem => - Array.from(elem.children).forEach(elem => checkHref(elem)) + Array.from(elem.children).forEach(elem => checkHref(elem, token, authProvider)) ) } diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 3bdb3f3..0fe1a6b 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -4,7 +4,7 @@ <link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet"> <script src="{{ url_for('static', path='/utils.js') }}"></script> </head> - <body onload="checkPerms('links-to-check')"> + <body onload="checkPerms('links-to-check', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')"> <h1>OIDC-test - FastAPI client</h1> {% block content %} {% endblock %} diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index ce344cc..09c313f 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -80,14 +80,14 @@ These links should get different response codes depending on the authorization: </p> <div class="links-to-check"> - <a href="public">Public</a> - <a href="protected">Auth protected content</a> - <a href="protected-by-foorole">Auth + foorole protected content</a> - <a href="protected-by-foorole-or-barrole">Auth + foorole or barrole protected content</a> - <a href="protected-by-barrole">Auth + barrole protected content</a> - <a href="protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a> - <a href="fast_api_depends" class="hidden">Using FastAPI Depends</a> - <a href="introspect">Introspect token (401 expected)</a> + <a href="resource/public">Public</a> + <a href="resource/protected">Auth protected content</a> + <a href="resource/protected-by-foorole">Auth + foorole protected content</a> + <a href="resource/protected-by-foorole-or-barrole">Auth + foorole or barrole protected content</a> + <a href="resource/protected-by-barrole">Auth + barrole protected content</a> + <a href="resource/protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a> + <a href="resource/fast_api_depends" class="hidden">Using FastAPI Depends</a> + <!--<a href="resource/introspect">Introspect token (401 expected)</a>--> </div> {% if resources %} <p> From 3eb6dc3dcf4be7350aa4cda6b3157183fe77d5d8 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Fri, 7 Feb 2025 16:09:49 +0100 Subject: [PATCH 02/46] Migrate all resources to json contents; improve token decoding & logging error messages --- src/oidc_test/auth_utils.py | 17 ++++++++---- src/oidc_test/resource_server.py | 43 +++++++++++++++-------------- src/oidc_test/settings.py | 46 ++++++++++++------------------- src/oidc_test/static/styles.css | 10 ++----- src/oidc_test/static/utils.js | 3 +- src/oidc_test/templates/home.html | 45 ++++++++++++++---------------- 6 files changed, 77 insertions(+), 87 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 1004527..3303e58 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -5,7 +5,7 @@ import logging from fastapi import HTTPException, Request, Depends, status from fastapi.security import OAuth2PasswordBearer from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App -from jwt import ExpiredSignatureError, InvalidKeyError +from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError from httpx import AsyncClient # from authlib.oauth1.auth import OAuthToken @@ -147,8 +147,8 @@ async def get_token(request: Request) -> OAuth2Token: oidc_provider_settings, request.session.get("sid"), ) - except (TokenNotInDb, InvalidKeyError): - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") + except (TokenNotInDb, InvalidKeyError, DecodeError) as err: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, err.__class__.__name__) async def get_current_user_or_none(request: Request) -> User | None: @@ -208,9 +208,14 @@ async def get_user_from_token( try: auth_provider_settings = oidc_providers_settings[auth_provider_id] except KeyError: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" - ) + if auth_provider_id == "": + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth provider") + else: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" + ) + if token == "": + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") try: payload = auth_provider_settings.decode(token) except ExpiredSignatureError as err: diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index d5e2aaa..cb944ed 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -4,8 +4,7 @@ import logging from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError -from fastapi import FastAPI, HTTPException, Depends, Request, status -from fastapi.responses import HTMLResponse, JSONResponse +from fastapi import FastAPI, HTTPException, Depends, status from fastapi.middleware.cors import CORSMiddleware # from starlette.middleware.sessions import SessionMiddleware @@ -16,8 +15,8 @@ from .models import User from .auth_utils import ( get_user_from_token, UserWithRole, - get_oidc_provider, - get_token, + # get_oidc_provider, + # get_token, ) from .settings import settings @@ -47,44 +46,46 @@ resource_server.add_middleware( @resource_server.get("/public") -async def public() -> HTMLResponse: - return HTMLResponse("<h1>Not protected</h1>") +async def public() -> dict: + return {"msg": "Not protected"} @resource_server.get("/protected") -async def get_protected(user: Annotated[User, Depends(get_user_from_token)]) -> HTMLResponse: +async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): assert user is not None # Just to keep QA checks happy - return HTMLResponse("<h1>Only authenticated users can see this</h1>") + return {"msg": "Only authenticated users can see this"} @resource_server.get("/protected-by-foorole") async def get_protected_by_foorole( - user: Annotated[User, Depends(UserWithRole("foorole"))] -) -> HTMLResponse: - return HTMLResponse("<h1>Only users with foorole can see this</h1>") + user: Annotated[User, Depends(UserWithRole("foorole"))], +): + assert user is not None + return {"msg": "Only users with foorole can see this"} @resource_server.get("/protected-by-barrole") async def get_protected_by_barrole( - user: Annotated[User, Depends(UserWithRole("barrole"))] -) -> HTMLResponse: - return HTMLResponse("<h1>Protected by barrole</h1>") + user: Annotated[User, Depends(UserWithRole("barrole"))], +): + assert user is not None + return {"msg": "Protected by barrole"} @resource_server.get("/protected-by-foorole-and-barrole") async def get_protected_by_foorole_and_barrole( user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], -) -> HTMLResponse: +): assert user is not None # Just to keep QA checks happy - return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>") + return {"msg": "Only users with foorole and barrole can see this"} @resource_server.get("/protected-by-foorole-or-barrole") async def get_protected_by_foorole_or_barrole( - user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))] -) -> HTMLResponse: + user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))], +): assert user is not None # Just to keep QA checks happy - return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>") + return {"msg": "Only users with foorole or barrole can see this"} # @resource_server.get("/introspect") @@ -118,9 +119,9 @@ async def get_resource_( # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], # token: Annotated[OAuth2Token, Depends(get_token)], user: Annotated[User, Depends(get_user_from_token)], -) -> JSONResponse: +): """Generic path for testing a resource provided by a provider""" - return JSONResponse(await get_resource(id, user)) + return await get_resource(id, user) async def get_resource(resource_id: str, user: User) -> dict: diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 2544bd7..b601739 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -43,9 +43,7 @@ class OIDCProvider(BaseModel): info_url: str | None = ( None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key) ) - info: dict[str, str | int] | None = ( - None # Info fetched from info_url, eg. public key - ) + info: dict[str, str | int] | None = None # Info fetched from info_url, eg. public key public_key: str | None = None signature_alg: str = "RS256" resource_provider_scopes: list[str] = [] @@ -62,25 +60,17 @@ class OIDCProvider(BaseModel): def get_account_url(self, request: Request, user: User) -> str | None: if self.account_url_template: - if not ( - self.url.endswith("/") or self.account_url_template.startswith("/") - ): + if not (self.url.endswith("/") or self.account_url_template.startswith("/")): sep = "/" else: sep = "" - return ( - self.url - + sep - + self.account_url_template.format(request=request, user=user) - ) + return self.url + sep + self.account_url_template.format(request=request, user=user) else: return None def get_public_key(self) -> str: """Return the public key formatted for decoding token""" - public_key = self.public_key or ( - self.info is not None and self.info["public_key"] - ) + public_key = self.public_key or (self.info is not None and self.info["public_key"]) if public_key is None: raise AttributeError(f"Cannot get public key for {self.name}") return f""" @@ -91,17 +81,18 @@ class OIDCProvider(BaseModel): def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]: """Decode the token with signature check""" - decoded = decode( - token, - self.get_public_key(), - algorithms=[self.signature_alg], - audience=["account", "oidc-test", "oidc-test-web"], - options={ - "verify_signature": False, - "verify_aud": False, - }, # not settings.insecure.skip_verify_signature}, - ) - logger.debug(str(decoded)) + if settings.debug_token: + decoded = decode( + token, + self.get_public_key(), + algorithms=[self.signature_alg], + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": False, + "verify_aud": False, + }, # not settings.insecure.skip_verify_signature}, + ) + logger.debug(str(decoded)) return decode( token, self.get_public_key(), @@ -143,6 +134,7 @@ class Settings(BaseSettings): log: bool = False insecure: Insecure = Insecure() cors_origins: list[str] = [] + debug_token: bool = False @classmethod def settings_customise_sources( @@ -161,9 +153,7 @@ class Settings(BaseSettings): settings_cls, Path( Path( - environ.get( - "OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml" - ), + environ.get("OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml"), ) ), ), diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index 7e1260b..6262d79 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -171,11 +171,13 @@ hr { gap: 0.5em; flex-flow: wrap; } -.content .links-to-check a { +.content .links-to-check button { color: black; padding: 5px 10px; text-decoration: none; border-radius: 8px; + border: none; + cursor: pointer; } .token { @@ -183,12 +185,6 @@ hr { font-family: monospace; } -.actions { - display: flex; - justify-content: center; - gap: 0.5em; -} - .resourceResult { padding: 0.5em; gap: 0.5em; diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index e6c4bfc..6ea8da2 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -1,6 +1,7 @@ async function checkHref(elem, token, authProvider) { const msg = document.getElementById("msg") - const resp = await fetch(elem.href, { + const url = `resource/${elem.getAttribute("resource-id")}` + const resp = await fetch(url, { headers: new Headers({ "Content-type": "application/json", "Authorization": `Bearer ${token}`, diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 09c313f..92b7068 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -61,33 +61,30 @@ </div> {% endif %} <hr> - {% if user %} - <p class="center"> - Fetch resources from the resource server with your authentication token: - </p> - <div class="actions"> - <button onclick="get_resource('time', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Time</button> - <button onclick="get_resource('bs', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">BS</button> - </div> - <div class="resourceResult"> - <div id="resource" class="resource"></div> - <div id="msg" class="msg error"></div> - </div> - <hr> - {% endif %} <div class="content"> - <p> - These links should get different response codes depending on the authorization: + <p class="center"> + Resources validated by scope: </p> <div class="links-to-check"> - <a href="resource/public">Public</a> - <a href="resource/protected">Auth protected content</a> - <a href="resource/protected-by-foorole">Auth + foorole protected content</a> - <a href="resource/protected-by-foorole-or-barrole">Auth + foorole or barrole protected content</a> - <a href="resource/protected-by-barrole">Auth + barrole protected content</a> - <a href="resource/protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a> - <a href="resource/fast_api_depends" class="hidden">Using FastAPI Depends</a> - <!--<a href="resource/introspect">Introspect token (401 expected)</a>--> + <button resource-id="time" onclick="get_resource('time', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Time</button> + <button resource-id="bs" onclick="get_resource('bs', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">BS</button> + </div> + <p> + Resources validated by role: + </p> + <div class="links-to-check"> + <button resource-id="public" onclick="get_resource('public', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Public</button> + <button resource-id="protected" onclick="get_resource('protected', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Auth protected content</button> + <button resource-id="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole protected content</button> + <button resource-id="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole or barrole protected content</button> + <button resource-id="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Auth + barrole protected content</button> + <button resource-id="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole and barrole protected content</button> + <button resource-id="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Using FastAPI Depends</button> + <!--<button resource-id="introspect" onclick="get_resource('introspect', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Introspect token (401 expected)</button>--> + </div> + <div class="resourceResult"> + <div id="resource" class="resource"></div> + <div id="msg" class="msg error"></div> </div> {% if resources %} <p> From ff72f0cae585858e400a9bc8f7d3fe1727035c44 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Sat, 8 Feb 2025 01:55:36 +0100 Subject: [PATCH 03/46] Display full token info --- src/oidc_test/auth_utils.py | 14 +++++++++-- src/oidc_test/main.py | 41 +++++++++++++++++++------------ src/oidc_test/settings.py | 1 + src/oidc_test/static/styles.css | 4 +-- src/oidc_test/templates/home.html | 41 ++++++++++++++++++++++--------- 5 files changed, 70 insertions(+), 31 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 3303e58..26f3779 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -1,3 +1,4 @@ +import re from typing import Union, Annotated from functools import wraps import logging @@ -133,9 +134,18 @@ async def get_current_user(request: Request) -> User: return user +async def get_token_or_none(request: Request) -> OAuth2Token | None: + """Return the auth token from the session or None. + Can be used in Depends()""" + try: + return await get_token(request) + except HTTPException: + return None + + async def get_token(request: Request) -> OAuth2Token: - """Return the token from a request object, from the session. - It can be used in Depends()""" + """Return the token from the session. + Can be used in Depends()""" try: oidc_provider_settings = oidc_providers_settings[ request.session.get("oidc_provider_id", "") diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 47d0c39..4a037eb 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -34,6 +34,7 @@ from .auth_utils import ( get_current_user_or_none, authlib_oauth, get_providers_info, + get_token_or_none, ) from .auth_misc import pretty_details from .database import TokenNotInDb, db @@ -76,6 +77,7 @@ async def home( request: Request, user: Annotated[User, Depends(get_current_user_or_none)], oidc_provider: Annotated[StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)], + token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], ) -> HTMLResponse: now = datetime.now() if oidc_provider and ( @@ -101,22 +103,29 @@ async def home( logger.info("Invalid token") logger.exception(err) - return templates.TemplateResponse( - name="home.html", - request=request, - context={ - "settings": settings.model_dump(), - "user": user, - "access_token_scope": access_token_scope, - "now": now, - "oidc_provider": oidc_provider, - "oidc_provider_settings": oidc_provider_settings, - "resources": resources, - "user_info_details": ( - pretty_details(user, now) if user and settings.oidc.show_session_details else None - ), - }, - ) + context = { + "settings": settings.model_dump(), + "user": user, + "access_token_scope": access_token_scope, + "now": now, + "oidc_provider": oidc_provider, + "oidc_provider_settings": oidc_provider_settings, + "resources": resources, + } + if token is None: + context["id_token_parsed"] = None + context["access_token_parsed"] = None + context["refresh_token_parsed"] = None + else: + assert oidc_provider is not None + assert oidc_provider.name is not None + oidc_provider_settings = oidc_providers_settings[oidc_provider.name] + context["id_token_parsed"] = pretty_details(user, now) + context["access_token_parsed"] = oidc_provider_settings.decode(token["access_token"]) + context["refresh_token_parsed"] = oidc_provider_settings.decode( + token["refresh_token"], verify_signature=False + ) + return templates.TemplateResponse(name="home.html", request=request, context=context) # Endpoints for the login / authorization process diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index b601739..e448c1e 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -135,6 +135,7 @@ class Settings(BaseSettings): insecure: Insecure = Insecure() cors_origins: list[str] = [] debug_token: bool = False + show_token: bool = False @classmethod def settings_customise_sources( diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index 6262d79..367ea99 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -73,7 +73,6 @@ hr { } .debug-auth p { border-bottom: 1px solid black; - text-align: left; } .debug-auth ul { padding: 0; @@ -185,8 +184,9 @@ hr { font-family: monospace; } -.resourceResult { +.resource { padding: 0.5em; + display: flex; gap: 0.5em; flex-direction: column; width: fit-content; diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 92b7068..9da5392 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -97,19 +97,38 @@ </div> {% endif %} </div> - {% if user_info_details %} - <hr> - <div class="debug-auth"> - <p>User info</p> - <ul> - {% for key, value in user_info_details.items() %} - <li> - <span class="key">{{ key }}</span>: <span class="value">{{ value }}</span> - </li> + {% if settings.show_token and id_token_parsed %} + <div class="token-info"> + <hr> + <div> + <h2>id token</h2> + <div class="token"> + {% for key, value in id_token_parsed.items() %} + <div> + <div class="key">{{ key }}</div> + <div class="value">{{ value }}</div> + </div> {% endfor %} - </ul> + </div> + <h2>access token</h2> + <div class="token"> + {% for key, value in access_token_parsed.items() %} + <div> + <div class="key">{{ key }}</div> + <div class="value">{{ value }}</div> + </div> + {% endfor %} + </div> + <h2>refresh token</h2> + <div class="token"> + {% for key, value in refresh_token_parsed.items() %} + <div> + <div class="key">{{ key }}</div> + <div class="value">{{ value }}</div> + </div> + {% endfor %} + </div> </div> - <div>Now is: {{ now.strftime("%T, %D") }} </div> </div> {% endif %} {% endblock %} From 923a63f5d527e5a4128d193bb9ac7adef956cf84 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Sat, 8 Feb 2025 18:32:02 +0100 Subject: [PATCH 04/46] Add refresh token button --- src/oidc_test/auth_utils.py | 19 +++++++++++-------- src/oidc_test/main.py | 30 +++++++++++++++++++++++++++--- src/oidc_test/models.py | 9 +++++++-- src/oidc_test/static/utils.js | 1 + src/oidc_test/templates/base.html | 2 +- src/oidc_test/templates/home.html | 21 +++++++++++---------- 6 files changed, 58 insertions(+), 24 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 26f3779..cab14b2 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -13,7 +13,7 @@ from httpx import AsyncClient from authlib.oauth2.auth import OAuth2Token from .models import User -from .database import TokenNotInDb, db, UserNotInDB +from .database import db, TokenNotInDb, UserNotInDB from .settings import oidc_providers_settings logger = logging.getLogger("oidc-test") @@ -36,14 +36,14 @@ async def fetch_token(name, request): async def update_token(name, token, refresh_token=None, access_token=None): + """Update the token in the database""" oidc_provider_settings = oidc_providers_settings[name] sid: str = oidc_provider_settings.decode(token["id_token"])["sid"] item = await db.get_token(oidc_provider_settings, sid) # update old token - if access_token is not None: - item["access_token"] = token.get("access_token") - if refresh_token is not None: - item["refresh_token"] = refresh_token + item["access_token"] = token["access_token"] + item["refresh_token"] = token["refresh_token"] + item["id_token"] = token["id_token"] item["expires_at"] = token["expires_at"] logger.info(f"Token {sid} refreshed") # It's a fake db and only in memory, so there's nothing to save @@ -70,8 +70,8 @@ def init_providers(): api_base_url=provider.url, # For PKCE (not implemented yet): # code_challenge_method="S256", - # fetch_token=fetch_token, - # update_token=update_token, + fetch_token=fetch_token, + update_token=update_token, # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) ) @@ -101,7 +101,10 @@ def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: def get_oidc_provider(request: Request) -> StarletteOAuth2App: if (oidc_provider := get_oidc_provider_or_none(request)) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") + if oidc_provider is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No provider") + else: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") else: return oidc_provider diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 4a037eb..81b354f 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -35,6 +35,8 @@ from .auth_utils import ( authlib_oauth, get_providers_info, get_token_or_none, + get_token, + update_token, ) from .auth_misc import pretty_details from .database import TokenNotInDb, db @@ -97,7 +99,7 @@ async def home( access_token_scope = None else: try: - access_token_scope = user.decode_access_token()["scope"] + access_token_scope = user.get_scope(verify_signature=False) except InvalidTokenError as err: access_token_scope = None logger.info("Invalid token") @@ -113,15 +115,22 @@ async def home( "resources": resources, } if token is None: + context["access_token"] = None context["id_token_parsed"] = None context["access_token_parsed"] = None context["refresh_token_parsed"] = None else: + context["access_token"] = token["access_token"] assert oidc_provider is not None assert oidc_provider.name is not None oidc_provider_settings = oidc_providers_settings[oidc_provider.name] - context["id_token_parsed"] = pretty_details(user, now) - context["access_token_parsed"] = oidc_provider_settings.decode(token["access_token"]) + # context["id_token_parsed"] = pretty_details(user, now) + context["id_token_parsed"] = oidc_provider_settings.decode( + token["id_token"], verify_signature=False + ) + context["access_token_parsed"] = oidc_provider_settings.decode( + token["access_token"], verify_signature=False + ) context["refresh_token_parsed"] = oidc_provider_settings.decode( token["refresh_token"], verify_signature=False ) @@ -282,6 +291,21 @@ async def non_compliant_logout( ) +@app.get("/refresh") +async def refresh( + request: Request, + oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + token: Annotated[OAuth2Token, Depends(get_token)], +) -> RedirectResponse: + """Manually refresh token""" + new_token = await oidc_provider.fetch_access_token( + refresh_token=token["refresh_token"], + grant_type="refresh_token", + ) + await update_token(oidc_provider.name, new_token) + return RedirectResponse(url=request.url_for("home")) + + # Snippet for running standalone # Mostly useful for the --version option, # as running with uvicorn is easy and provides better flexibility, eg. diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 9554bd5..8aee2e6 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -54,10 +54,15 @@ class User(UserBase): access_token_scopes = [] return scope in set(info_scopes + access_token_scopes) - def decode_access_token(self): + def decode_access_token(self, verify_signature: bool = True): assert self.access_token is not None assert self.oidc_provider is not None assert self.oidc_provider.name is not None from .settings import oidc_providers_settings - return oidc_providers_settings[self.oidc_provider.name].decode(self.access_token) + return oidc_providers_settings[self.oidc_provider.name].decode( + self.access_token, verify_signature=verify_signature + ) + + def get_scope(self, verify_signature: bool = True): + return self.decode_access_token(verify_signature=verify_signature)["scope"] diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 6ea8da2..8e8ad59 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -2,6 +2,7 @@ async function checkHref(elem, token, authProvider) { const msg = document.getElementById("msg") const url = `resource/${elem.getAttribute("resource-id")}` const resp = await fetch(url, { + method: "GET", headers: new Headers({ "Content-type": "application/json", "Authorization": `Bearer ${token}`, diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 0fe1a6b..2ce758c 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -4,7 +4,7 @@ <link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet"> <script src="{{ url_for('static', path='/utils.js') }}"></script> </head> - <body onload="checkPerms('links-to-check', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')"> + <body onload="checkPerms('links-to-check', '{{ access_token }}', '{{ oidc_provider_settings.id }}')"> <h1>OIDC-test - FastAPI client</h1> {% block content %} {% endblock %} diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 9da5392..08bcf43 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -57,6 +57,7 @@ Account management </button> {% endif %} + <button onclick="location.href='{{ request.url_for("refresh") }}'" class="refresh">Refresh</button> <button onclick="location.href='{{ request.url_for("logout") }}'" class="logout">Logout</button> </div> {% endif %} @@ -66,21 +67,21 @@ Resources validated by scope: </p> <div class="links-to-check"> - <button resource-id="time" onclick="get_resource('time', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Time</button> - <button resource-id="bs" onclick="get_resource('bs', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">BS</button> + <button resource-id="time" onclick="get_resource('time', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Time</button> + <button resource-id="bs" onclick="get_resource('bs', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">BS</button> </div> <p> Resources validated by role: </p> <div class="links-to-check"> - <button resource-id="public" onclick="get_resource('public', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Public</button> - <button resource-id="protected" onclick="get_resource('protected', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Auth protected content</button> - <button resource-id="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole protected content</button> - <button resource-id="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole or barrole protected content</button> - <button resource-id="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Auth + barrole protected content</button> - <button resource-id="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole and barrole protected content</button> - <button resource-id="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Using FastAPI Depends</button> - <!--<button resource-id="introspect" onclick="get_resource('introspect', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Introspect token (401 expected)</button>--> + <button resource-id="public" onclick="get_resource('public', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Public</button> + <button resource-id="protected" onclick="get_resource('protected', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth protected content</button> + <button resource-id="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole protected content</button> + <button resource-id="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole or barrole protected content</button> + <button resource-id="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + barrole protected content</button> + <button resource-id="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole and barrole protected content</button> + <button resource-id="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Using FastAPI Depends</button> + <!--<button resource-id="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Introspect token (401 expected)</button>--> </div> <div class="resourceResult"> <div id="resource" class="resource"></div> From 38b983c2a51ff1866e3306896a9e5e960bbe984b Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Sat, 8 Feb 2025 19:05:13 +0100 Subject: [PATCH 05/46] Fix scope --- src/oidc_test/main.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 81b354f..03d13d7 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -95,20 +95,9 @@ async def home( resources = [] oidc_provider_settings = None - if user is None: - access_token_scope = None - else: - try: - access_token_scope = user.get_scope(verify_signature=False) - except InvalidTokenError as err: - access_token_scope = None - logger.info("Invalid token") - logger.exception(err) - context = { "settings": settings.model_dump(), "user": user, - "access_token_scope": access_token_scope, "now": now, "oidc_provider": oidc_provider, "oidc_provider_settings": oidc_provider_settings, @@ -124,13 +113,15 @@ async def home( assert oidc_provider is not None assert oidc_provider.name is not None oidc_provider_settings = oidc_providers_settings[oidc_provider.name] + access_token_parsed = oidc_provider_settings.decode( + token["access_token"], verify_signature=False + ) + context["access_token_scope"] = access_token_parsed["scope"] # context["id_token_parsed"] = pretty_details(user, now) context["id_token_parsed"] = oidc_provider_settings.decode( token["id_token"], verify_signature=False ) - context["access_token_parsed"] = oidc_provider_settings.decode( - token["access_token"], verify_signature=False - ) + context["access_token_parsed"] = access_token_parsed context["refresh_token_parsed"] = oidc_provider_settings.decode( token["refresh_token"], verify_signature=False ) From c5bb4f4319445ba145ab08067593a6957ce37f42 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Sun, 9 Feb 2025 06:20:48 +0100 Subject: [PATCH 06/46] Refactor most code, isolate authlib somehow --- src/oidc_test/auth_misc.py | 29 ------- src/oidc_test/auth_provider.py | 43 ++++++++++ src/oidc_test/auth_utils.py | 80 +++++++++++-------- src/oidc_test/database.py | 27 +++---- src/oidc_test/main.py | 128 ++++++++++++------------------ src/oidc_test/models.py | 10 +-- src/oidc_test/resource_server.py | 9 ++- src/oidc_test/settings.py | 45 ++--------- src/oidc_test/templates/base.html | 2 +- src/oidc_test/templates/home.html | 28 +++---- 10 files changed, 183 insertions(+), 218 deletions(-) delete mode 100644 src/oidc_test/auth_misc.py create mode 100644 src/oidc_test/auth_provider.py diff --git a/src/oidc_test/auth_misc.py b/src/oidc_test/auth_misc.py deleted file mode 100644 index a4e9ea3..0000000 --- a/src/oidc_test/auth_misc.py +++ /dev/null @@ -1,29 +0,0 @@ -from datetime import datetime, timedelta -from collections import OrderedDict - -from .models import User - -time_keys = set(("iat", "exp", "auth_time", "updated_at")) - - -def pretty_details(user: User, now: datetime) -> OrderedDict: - details = OrderedDict() - # breakpoint() - for key in sorted(time_keys): - try: - dt = datetime.fromtimestamp(user.userinfo[key]) - except (KeyError, TypeError): - pass - else: - td = now - dt - td = timedelta(days=td.days, seconds=td.seconds) - if td.days < 0: - ptd = f"in {-td} h:m:s" - else: - ptd = f"{td} h:m:s ago" - details[key] = f"{user.userinfo[key]} - {dt} ({ptd})" - for key in sorted(user.userinfo): - if key in time_keys: - continue - details[key] = user.userinfo[key] - return details diff --git a/src/oidc_test/auth_provider.py b/src/oidc_test/auth_provider.py new file mode 100644 index 0000000..bed4596 --- /dev/null +++ b/src/oidc_test/auth_provider.py @@ -0,0 +1,43 @@ +from typing import Any +from jwt import decode +import logging + +from authlib.integrations.starlette_client.apps import StarletteOAuth2App + +from .settings import AuthProviderSettings, settings + +logger = logging.getLogger("oidc-test") + + +class Provider(AuthProviderSettings): + class Config: + arbitrary_types_allowed = True + + authlib_client: StarletteOAuth2App = StarletteOAuth2App(None) + + def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]: + """Decode the token with signature check""" + if settings.debug_token: + decoded = decode( + token, + self.get_public_key(), + algorithms=[self.signature_alg], + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": False, + "verify_aud": False, + }, # not settings.insecure.skip_verify_signature}, + ) + logger.debug(str(decoded)) + return decode( + token, + self.get_public_key(), + algorithms=[self.signature_alg], + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": verify_signature, + }, # not settings.insecure.skip_verify_signature}, + ) + + +providers: dict[str, Provider] = {} diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index cab14b2..e62fe39 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -1,4 +1,3 @@ -import re from typing import Union, Annotated from functools import wraps import logging @@ -14,7 +13,8 @@ from authlib.oauth2.auth import OAuth2Token from .models import User from .database import db, TokenNotInDb, UserNotInDB -from .settings import oidc_providers_settings +from .settings import settings +from .auth_provider import providers, Provider logger = logging.getLogger("oidc-test") @@ -35,11 +35,13 @@ async def fetch_token(name, request): # return token.to_token() -async def update_token(name, token, refresh_token=None, access_token=None): +async def update_token( + provider_id, token, refresh_token: str | None = None, access_token: str | None = None +): """Update the token in the database""" - oidc_provider_settings = oidc_providers_settings[name] - sid: str = oidc_provider_settings.decode(token["id_token"])["sid"] - item = await db.get_token(oidc_provider_settings, sid) + provider = providers[provider_id] + sid: str = provider.decode(token["id_token"])["sid"] + item = await db.get_token(provider, sid) # update old token item["access_token"] = token["access_token"] item["refresh_token"] = token["refresh_token"] @@ -54,10 +56,12 @@ authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_t def init_providers(): - # Add oidc providers to authlib from the settings - for id, provider in oidc_providers_settings.items(): + """Add oidc providers to authlib from the settings + and build the providers dict""" + for provider_settings in settings.auth.providers: + provider = Provider(**provider_settings.model_dump()) authlib_oauth.register( - name=id, + name=provider.id, server_metadata_url=provider.openid_configuration, client_kwargs={ "scope": " ".join( @@ -74,6 +78,8 @@ def init_providers(): update_token=update_token, # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) ) + provider.authlib_client = getattr(authlib_oauth, provider.id) + providers[provider.id] = provider init_providers() @@ -82,33 +88,41 @@ init_providers() async def get_providers_info(): # Get the public key: async with AsyncClient() as client: - for provider_settings in oidc_providers_settings.values(): - if provider_settings.info_url: - provider_info = await client.get(provider_settings.url) - provider_settings.info = provider_info.json() + for provider in providers.values(): + if provider.info_url: + provider_info = await client.get(provider.url) + provider.info = provider_info.json() -def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: +def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: """Return the oidc_provider from a request object, from the session. It can be used in Depends()""" - if (oidc_provider_id := request.session.get("oidc_provider_id")) is None: - return - try: - return getattr(authlib_oauth, str(oidc_provider_id)) - except AttributeError: + if (auth_provider_id := request.session.get("auth_provider_id")) is None: return + return getattr(authlib_oauth, str(auth_provider_id), None) -def get_oidc_provider(request: Request) -> StarletteOAuth2App: - if (oidc_provider := get_oidc_provider_or_none(request)) is None: - if oidc_provider is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No provider") - else: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") +def get_auth_provider_client(request: Request) -> StarletteOAuth2App: + if (oidc_provider := get_auth_provider_client_or_none(request)) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") else: return oidc_provider +def get_auth_provider_or_none(request: Request) -> Provider | None: + """Return the oidc_provider settings from a request object, from the session. + It can be used in Depends()""" + if (auth_provider_id := request.session.get("auth_provider_id")) is None: + return + return providers.get(auth_provider_id) + + +def get_auth_provider(request: Request) -> Provider: + if (provider := get_auth_provider_or_none(request)) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") + return provider + + async def get_current_user(request: Request) -> User: """Get the current user from a request object. Also validates the token expiration time. @@ -120,11 +134,11 @@ async def get_current_user(request: Request) -> User: user = await db.get_user(user_sub) ## Check if the token is expired if token.is_expired(): - oidc_provider = get_oidc_provider(request=request) + provider = get_auth_provider(request=request) ## Ask a new refresh token from the provider logger.info(f"Token expired for user {user.name}") try: - userinfo = await oidc_provider.fetch_access_token( + userinfo = await provider.authlib_client.fetch_access_token( refresh_token=token.get("refresh_token") ) assert userinfo is not None @@ -150,14 +164,12 @@ async def get_token(request: Request) -> OAuth2Token: """Return the token from the session. Can be used in Depends()""" try: - oidc_provider_settings = oidc_providers_settings[ - request.session.get("oidc_provider_id", "") - ] + provider = providers[request.session.get("auth_provider_id", "")] except KeyError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider") try: return await db.get_token( - oidc_provider_settings, + provider, request.session.get("sid"), ) except (TokenNotInDb, InvalidKeyError, DecodeError) as err: @@ -219,7 +231,7 @@ async def get_user_from_token( "Request headers must have a 'auth_provider' field", ) try: - auth_provider_settings = oidc_providers_settings[auth_provider_id] + provider = providers[auth_provider_id] except KeyError: if auth_provider_id == "": raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth provider") @@ -230,7 +242,7 @@ async def get_user_from_token( if token == "": raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") try: - payload = auth_provider_settings.decode(token) + payload = provider.decode(token) except ExpiredSignatureError as err: logger.info(f"Expired signature: {err}") raise HTTPException( @@ -261,7 +273,7 @@ async def get_user_from_token( user = await db.add_user( sub=payload["sub"], user_info=payload, - oidc_provider=getattr(authlib_oauth, auth_provider_id), + auth_provider=getattr(authlib_oauth, auth_provider_id), access_token=token, ) return user diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index d3bdd4e..3493429 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -2,11 +2,10 @@ import logging -from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.oauth2.rfc6749 import OAuth2Token -from .settings import OIDCProvider, oidc_providers_settings from .models import User, Role +from .auth_provider import Provider, providers logger = logging.getLogger("oidc-test") @@ -29,23 +28,23 @@ class Database: self, sub: str, user_info: dict, - oidc_provider: StarletteOAuth2App, + auth_provider: Provider, access_token: str, access_token_decoded: dict | None = None, ) -> User: if access_token_decoded is None: - assert oidc_provider.name is not None - oidc_provider_settings = oidc_providers_settings[oidc_provider.name] - access_token_decoded = oidc_provider_settings.decode(access_token) + assert auth_provider.name is not None + provider = providers[auth_provider.id] + access_token_decoded = provider.decode(access_token) + user_info["auth_provider_id"] = auth_provider.id user = User(**user_info) user.userinfo = user_info - user.oidc_provider = oidc_provider - user.access_token = access_token - user.access_token_decoded = access_token_decoded + # user.access_token = access_token + # user.access_token_decoded = access_token_decoded # Add roles provided in the access token roles = set() try: - r = access_token_decoded["resource_access"][oidc_provider.client_id]["roles"] + r = access_token_decoded["resource_access"][auth_provider.client_id]["roles"] roles.update(r) except KeyError: pass @@ -66,19 +65,19 @@ class Database: raise UserNotInDB return self.users[sub] - async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None: + async def add_token(self, provider: Provider, token: OAuth2Token) -> None: """Store a token using as key the sid (auth provider's session id) in the id_token""" - assert isinstance(oidc_provider_settings, OIDCProvider) + assert isinstance(provider, Provider) sid = token["userinfo"]["sid"] self.tokens[sid] = token async def get_token( self, - oidc_provider_settings: OIDCProvider, + provider: Provider, sid: str | None, ) -> OAuth2Token: - assert isinstance(oidc_provider_settings, OIDCProvider) + assert isinstance(provider, Provider) if sid is None: raise TokenNotInDb try: diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 03d13d7..304df92 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates from fastapi.middleware.cors import CORSMiddleware -from jwt import InvalidTokenError, PyJWTError +from jwt import PyJWTError from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.base_client import OAuthError @@ -26,11 +26,12 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair -from .settings import settings, oidc_providers_settings +from .settings import settings +from .auth_provider import Provider, providers from .models import User from .auth_utils import ( - get_oidc_provider, - get_oidc_provider_or_none, + get_auth_provider, + get_auth_provider_or_none, get_current_user_or_none, authlib_oauth, get_providers_info, @@ -38,7 +39,6 @@ from .auth_utils import ( get_token, update_token, ) -from .auth_misc import pretty_details from .database import TokenNotInDb, db from .resource_server import resource_server @@ -78,51 +78,30 @@ app.mount("/resource", resource_server, name="resource_server") async def home( request: Request, user: Annotated[User, Depends(get_current_user_or_none)], - oidc_provider: Annotated[StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)], + provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)], token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], ) -> HTMLResponse: - now = datetime.now() - if oidc_provider and ( - ( - oidc_provider_settings := oidc_providers_settings.get( - request.session.get("oidc_provider_id", "") - ) - ) - is not None - ): - resources = oidc_provider_settings.resources - else: - resources = [] - oidc_provider_settings = None - context = { "settings": settings.model_dump(), "user": user, - "now": now, - "oidc_provider": oidc_provider, - "oidc_provider_settings": oidc_provider_settings, - "resources": resources, + "now": datetime.now(), + "auth_provider": provider, } - if token is None: + if provider is None or token is None: context["access_token"] = None context["id_token_parsed"] = None context["access_token_parsed"] = None context["refresh_token_parsed"] = None + context["resources"] = None else: context["access_token"] = token["access_token"] - assert oidc_provider is not None - assert oidc_provider.name is not None - oidc_provider_settings = oidc_providers_settings[oidc_provider.name] - access_token_parsed = oidc_provider_settings.decode( - token["access_token"], verify_signature=False - ) + context["resources"] = provider.resources + access_token_parsed = provider.decode(token["access_token"], verify_signature=False) context["access_token_scope"] = access_token_parsed["scope"] # context["id_token_parsed"] = pretty_details(user, now) - context["id_token_parsed"] = oidc_provider_settings.decode( - token["id_token"], verify_signature=False - ) + context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False) context["access_token_parsed"] = access_token_parsed - context["refresh_token_parsed"] = oidc_provider_settings.decode( + context["refresh_token_parsed"] = provider.decode( token["refresh_token"], verify_signature=False ) return templates.TemplateResponse(name="home.html", request=request, context=context) @@ -131,20 +110,20 @@ async def home( # Endpoints for the login / authorization process -@app.get("/login/{oidc_provider_id}") -async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: +@app.get("/login/{auth_provider_id}") +async def login(request: Request, auth_provider_id: str) -> RedirectResponse: """Login with the provider id, giving the browser a redirect to its authorize page. - The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url + The provider is expected to send the browser back to our own /auth/{auth_provider_id} url with the token. """ - redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id) + redirect_uri = request.url_for("auth", auth_provider_id=auth_provider_id) try: - provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) + provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") # if ( - # code_challenge_method := oidc_providers_settings[ - # oidc_provider_id + # code_challenge_method := providers[ + # auth_provider_id # ].code_challenge_method # ) is not None: # #client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method) @@ -164,30 +143,30 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") -@app.get("/auth/{oidc_provider_id}") -async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: +@app.get("/auth/{auth_provider_id}") +async def auth(request: Request, auth_provider_id: str) -> RedirectResponse: """Decrypt the auth token, store it to the session (cookie based) and response to the browser with a redirect to a "welcome user" page. """ try: - oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) + authlib_client: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") try: - token: OAuth2Token = await oidc_provider.authorize_access_token(request) + token: OAuth2Token = await authlib_client.authorize_access_token(request) except OAuthError as error: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) - # Remember the oidc_provider in the session + # Remember the authlib_client in the session # logger.info(f"Scope: {token['scope']}") - request.session["oidc_provider_id"] = oidc_provider_id + request.session["auth_provider_id"] = auth_provider_id # # One could process the full decoded token which contains extra information # eg for updates. Here we are only interested in roles # if userinfo := token.get("userinfo"): - # Remember the oidc_provider in the session - request.session["oidc_provider_id"] = oidc_provider_id - # User id (sub) given by oidc provider + # Remember the authlib_client in the session + request.session["auth_provider_id"] = auth_provider_id + # User id (sub) given by auth provider sub = userinfo["sub"] # Build and remember the user in the session request.session["user_sub"] = sub @@ -196,7 +175,7 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: user = await db.add_user( sub, user_info=userinfo, - oidc_provider=oidc_provider, + auth_provider=providers[auth_provider_id], access_token=token["access_token"], ) except PyJWTError as err: @@ -208,48 +187,41 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: # Add the provider session id to the session request.session["sid"] = userinfo["sid"] # Add the token to the db because it is used for logout - assert oidc_provider.name is not None - oidc_provider_settings = oidc_providers_settings[oidc_provider.name] - await db.add_token(oidc_provider_settings, token) + provider = providers[auth_provider_id] + await db.add_token(provider, token) # Send the user to the home: (s)he is authenticated return RedirectResponse(url=request.url_for("home")) else: # Not sure if it's correct to redirect to plain login # if no userinfo is provided - return RedirectResponse(url=request.url_for("login", oidc_provider_id=oidc_provider_id)) + return RedirectResponse(url=request.url_for("login", auth_provider_id=auth_provider_id)) @app.get("/account") async def account( - request: Request, + provider: Annotated[Provider, Depends(get_auth_provider)], ) -> RedirectResponse: - if ( - oidc_provider_settings := oidc_providers_settings.get( - request.session.get("oidc_provider_id", "") - ) - ) is None: - raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings") - return RedirectResponse(f"{oidc_provider_settings.account_url_template}") + """Redirect to the auth provider account management, + if account_url_template is in the provider's settings""" + return RedirectResponse(f"{provider.account_url_template}") @app.get("/logout") async def logout( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + provider: Annotated[Provider, Depends(get_auth_provider)], ) -> RedirectResponse: # Clear session request.session.pop("user_sub", None) # Get provider's endpoint - if (provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")) is None: - logger.warn(f"Cannot find end_session_endpoint for provider {oidc_provider.name}") + if ( + provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint") + ) is None: + logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") - oidc_provider_settings = oidc_providers_settings.get( - request.session.get("oidc_provider_id", "") - ) - assert oidc_provider_settings is not None try: - token = await db.get_token(oidc_provider_settings, request.session.pop("sid", None)) + token = await db.get_token(provider, request.session.pop("sid", None)) except TokenNotInDb: logger.warn("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) @@ -270,30 +242,30 @@ async def logout( @app.get("/non-compliant-logout") async def non_compliant_logout( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + provider: Annotated[StarletteOAuth2App, Depends(get_auth_provider)], ): """A page for non-compliant OAuth2 servers that we cannot log out.""" # Clear the remain of the session - request.session.pop("oidc_provider_id", None) + request.session.pop("auth_provider_id", None) return templates.TemplateResponse( name="non_compliant_logout.html", request=request, - context={"oidc_provider": oidc_provider, "home_url": request.url_for("home")}, + context={"oidc_provider": provider, "home_url": request.url_for("home")}, ) @app.get("/refresh") async def refresh( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + provider: Annotated[Provider, Depends(get_auth_provider)], token: Annotated[OAuth2Token, Depends(get_token)], ) -> RedirectResponse: """Manually refresh token""" - new_token = await oidc_provider.fetch_access_token( + new_token = await provider.authlib_client.fetch_access_token( refresh_token=token["refresh_token"], grant_type="refresh_token", ) - await update_token(oidc_provider.name, new_token) + await update_token(provider.id, new_token) return RedirectResponse(url=request.url_for("home")) diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 8aee2e6..eda63a6 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -8,7 +8,6 @@ from pydantic import ( EmailStr, ConfigDict, ) -from authlib.integrations.starlette_client.apps import StarletteOAuth2App from sqlmodel import SQLModel, Field logger = logging.getLogger("oidc-test") @@ -37,7 +36,7 @@ class User(UserBase): userinfo: dict = {} access_token: str | None = None access_token_decoded: dict[str, Any] | None = None - oidc_provider: StarletteOAuth2App | None = None + auth_provider_id: str @computed_field @cached_property @@ -56,11 +55,10 @@ class User(UserBase): def decode_access_token(self, verify_signature: bool = True): assert self.access_token is not None - assert self.oidc_provider is not None - assert self.oidc_provider.name is not None - from .settings import oidc_providers_settings + assert self.auth_provider_id is not None + from .auth_provider import providers - return oidc_providers_settings[self.oidc_provider.name].decode( + return providers[self.auth_provider_id].decode( self.access_token, verify_signature=verify_signature ) diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index cb944ed..e5670ed 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import Annotated import logging +from authlib.jose import Key from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from fastapi import FastAPI, HTTPException, Depends, status @@ -15,10 +16,9 @@ from .models import User from .auth_utils import ( get_user_from_token, UserWithRole, - # get_oidc_provider, - # get_token, ) from .settings import settings +from .auth_provider import providers logger = logging.getLogger("oidc-test") @@ -128,7 +128,10 @@ async def get_resource(resource_id: str, user: User) -> dict: """ Resource processing: build an informative rely as a simple showcase """ - pname = getattr(user.oidc_provider, "name", "?") + try: + pname = providers[user.auth_provider_id].name + except KeyError: + pname = "?" resp = { "hello": f"Hi {user.name} from an OAuth resource provider", "comment": f"I received a request for '{resource_id}' " diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index e448c1e..9a789a0 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -1,11 +1,9 @@ from os import environ import string import random -from typing import Type, Tuple, Any +from typing import Type, Tuple from pathlib import Path -import logging -from jwt import decode from pydantic import BaseModel, computed_field, AnyUrl from pydantic_settings import ( BaseSettings, @@ -17,8 +15,6 @@ from starlette.requests import Request from .models import User -logger = logging.getLogger("oidc-test") - class Resource(BaseModel): """A resource with an URL that can be accessed with an OAuth2 access token""" @@ -27,8 +23,8 @@ class Resource(BaseModel): name: str -class OIDCProvider(BaseModel): - """OIDC provider, can also be a resource server""" +class AuthProviderSettings(BaseModel): + """Auth provider, can also be a resource server""" id: str name: str @@ -79,30 +75,6 @@ class OIDCProvider(BaseModel): -----END PUBLIC KEY----- """ - def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]: - """Decode the token with signature check""" - if settings.debug_token: - decoded = decode( - token, - self.get_public_key(), - algorithms=[self.signature_alg], - audience=["account", "oidc-test", "oidc-test-web"], - options={ - "verify_signature": False, - "verify_aud": False, - }, # not settings.insecure.skip_verify_signature}, - ) - logger.debug(str(decoded)) - return decode( - token, - self.get_public_key(), - algorithms=[self.signature_alg], - audience=["account", "oidc-test", "oidc-test-web"], - options={ - "verify_signature": verify_signature, - }, # not settings.insecure.skip_verify_signature}, - ) - class ResourceProvider(BaseModel): id: str @@ -111,9 +83,9 @@ class ResourceProvider(BaseModel): resources: list[Resource] = [] -class OIDCSettings(BaseModel): +class AuthSettings(BaseModel): show_session_details: bool = False - providers: list[OIDCProvider] = [] + providers: list[AuthProviderSettings] = [] swagger_provider: str = "" @@ -128,7 +100,7 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_nested_delimiter="__") - oidc: OIDCSettings = OIDCSettings() + auth: AuthSettings = AuthSettings() resource_providers: list[ResourceProvider] = [] secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False @@ -163,8 +135,3 @@ class Settings(BaseSettings): settings = Settings() - - -oidc_providers_settings: dict[str, OIDCProvider] = dict( - [(provider.id, provider) for provider in settings.oidc.providers] -) diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 2ce758c..4cb56f5 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -4,7 +4,7 @@ <link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet"> <script src="{{ url_for('static', path='/utils.js') }}"></script> </head> - <body onload="checkPerms('links-to-check', '{{ access_token }}', '{{ oidc_provider_settings.id }}')"> + <body onload="checkPerms('links-to-check', '{{ access_token }}', '{{ auth_provider.id }}')"> <h1>OIDC-test - FastAPI client</h1> {% block content %} {% endblock %} diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 08bcf43..7275f2d 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -8,7 +8,7 @@ <div class="login-box"> <p class="description">Log in with:</p> <table class="providers"> - {% for provider in settings.oidc.providers %} + {% for provider in settings.auth.providers %} <tr class="provider"> <td> <a class="link" href="login/{{ provider.id }}"><div>{{ provider.name }}</div></a> @@ -32,7 +32,7 @@ <div>{{ user.email }}</div> <div> <span>Provider:</span> - {{ oidc_provider_settings.name }} + {{ auth_provider.name }} </div> {% if user.roles %} <div> @@ -50,9 +50,9 @@ {% endfor %} </div> {% endif %} - {% if oidc_provider_settings.account_url_template %} + {% if auth_provider.account_url_template %} <button - onclick="location.href='{{ oidc_provider_settings.get_account_url(request, user) }}'" + onclick="location.href='{{ auth_provider.get_account_url(request, user) }}'" class="account"> Account management </button> @@ -67,21 +67,21 @@ Resources validated by scope: </p> <div class="links-to-check"> - <button resource-id="time" onclick="get_resource('time', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Time</button> - <button resource-id="bs" onclick="get_resource('bs', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">BS</button> + <button resource-id="time" onclick="get_resource('time', '{{ access_token }}', '{{ auth_provider.id }}')">Time</button> + <button resource-id="bs" onclick="get_resource('bs', '{{ access_token }}', '{{ auth_provider.id }}')">BS</button> </div> <p> Resources validated by role: </p> <div class="links-to-check"> - <button resource-id="public" onclick="get_resource('public', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Public</button> - <button resource-id="protected" onclick="get_resource('protected', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth protected content</button> - <button resource-id="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole protected content</button> - <button resource-id="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole or barrole protected content</button> - <button resource-id="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + barrole protected content</button> - <button resource-id="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Auth + foorole and barrole protected content</button> - <button resource-id="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Using FastAPI Depends</button> - <!--<button resource-id="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ oidc_provider_settings.id }}')">Introspect token (401 expected)</button>--> + <button resource-id="public" onclick="get_resource('public', '{{ access_token }}', '{{ auth_provider.id }}')">Public</button> + <button resource-id="protected" onclick="get_resource('protected', '{{ access_token }}', '{{ auth_provider.id }}')">Auth protected content</button> + <button resource-id="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole protected content</button> + <button resource-id="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole or barrole protected content</button> + <button resource-id="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + barrole protected content</button> + <button resource-id="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole and barrole protected content</button> + <button resource-id="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ auth_provider.id }}')">Using FastAPI Depends</button> + <!--<button resource-id="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ auth_provider.id }}')">Introspect token (401 expected)</button>--> </div> <div class="resourceResult"> <div id="resource" class="resource"></div> From 496ce016e3d31fdb003146a3f84bc86726275e6d Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Mon, 10 Feb 2025 02:05:34 +0100 Subject: [PATCH 07/46] Continue refactor; fetch resources from the providers' settings --- src/oidc_test/auth_provider.py | 64 +++++++++++++++-- src/oidc_test/auth_utils.py | 45 ++++++------ src/oidc_test/database.py | 9 ++- src/oidc_test/main.py | 68 +++++++++++++------ src/oidc_test/models.py | 6 +- src/oidc_test/resource_server.py | 50 +++++++++++--- src/oidc_test/settings.py | 16 ++--- src/oidc_test/static/utils.js | 17 +++-- src/oidc_test/templates/home.html | 49 +++++++------ .../templates/non_compliant_logout.html | 6 +- 10 files changed, 217 insertions(+), 113 deletions(-) diff --git a/src/oidc_test/auth_provider.py b/src/oidc_test/auth_provider.py index bed4596..e50241c 100644 --- a/src/oidc_test/auth_provider.py +++ b/src/oidc_test/auth_provider.py @@ -1,26 +1,41 @@ +from json import JSONDecodeError from typing import Any from jwt import decode import logging +from collections import OrderedDict +from pydantic import ConfigDict from authlib.integrations.starlette_client.apps import StarletteOAuth2App +from httpx import AsyncClient from .settings import AuthProviderSettings, settings +from .models import User logger = logging.getLogger("oidc-test") +class NoPublicKey(Exception): + pass + + class Provider(AuthProviderSettings): - class Config: - arbitrary_types_allowed = True + # To allow authlib_client as StarletteOAuth2App + model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore authlib_client: StarletteOAuth2App = StarletteOAuth2App(None) + info: dict[str, Any] = {} + unknown_auth_user: User - def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]: + def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]: """Decode the token with signature check""" + if self.public_key is None: + raise NoPublicKey + if verify_signature is None: + verify_signature = self.skip_verify_signature if settings.debug_token: decoded = decode( token, - self.get_public_key(), + self.public_key, algorithms=[self.signature_alg], audience=["account", "oidc-test", "oidc-test-web"], options={ @@ -31,7 +46,7 @@ class Provider(AuthProviderSettings): logger.debug(str(decoded)) return decode( token, - self.get_public_key(), + self.public_key, algorithms=[self.signature_alg], audience=["account", "oidc-test", "oidc-test-web"], options={ @@ -39,5 +54,42 @@ class Provider(AuthProviderSettings): }, # not settings.insecure.skip_verify_signature}, ) + async def get_info(self): + # Get the public key: + async with AsyncClient() as client: + public_key: str | None = None + if self.info_url is not None: + try: + provider_info = await client.get(self.info_url) + except Exception: + raise NoPublicKey + try: + self.info = provider_info.json() + except JSONDecodeError: + raise NoPublicKey + if "public_key" in self.info: + # For Keycloak + try: + public_key = str(self.info["public_key"]) + except KeyError: + raise NoPublicKey + elif "keys" in self.info: + # For Forgejo/Gitea + try: + public_key = str(self.info["keys"][0]["n"]) + except KeyError: + raise NoPublicKey + if self.public_key_url is not None: + resp = await client.get(self.public_key_url) + public_key = resp.text + if public_key is None: + raise NoPublicKey + self.public_key = "\n".join( + ["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"] + ) -providers: dict[str, Provider] = {} + def get_session_key(self, userinfo): + return userinfo[self.session_key] + + +providers: OrderedDict[str, Provider] = OrderedDict() diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index e62fe39..8cd5028 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -5,8 +5,7 @@ import logging from fastapi import HTTPException, Request, Depends, status from fastapi.security import OAuth2PasswordBearer from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App -from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError -from httpx import AsyncClient +from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError # from authlib.oauth1.auth import OAuthToken from authlib.oauth2.auth import OAuth2Token @@ -40,7 +39,7 @@ async def update_token( ): """Update the token in the database""" provider = providers[provider_id] - sid: str = provider.decode(token["id_token"])["sid"] + sid: str = provider.get_session_key(provider.decode(token["id_token"])) item = await db.get_token(provider, sid) # update old token item["access_token"] = token["access_token"] @@ -59,7 +58,12 @@ def init_providers(): """Add oidc providers to authlib from the settings and build the providers dict""" for provider_settings in settings.auth.providers: - provider = Provider(**provider_settings.model_dump()) + provider_settings_dict = provider_settings.model_dump() + # Add an anonymous user, that cannot be identified but has provided a valid access token + provider_settings_dict["unknown_auth_user"] = User( + sub="", auth_provider_id=provider_settings.id + ) + provider = Provider(**provider_settings_dict) authlib_oauth.register( name=provider.id, server_metadata_url=provider.openid_configuration, @@ -85,15 +89,6 @@ def init_providers(): init_providers() -async def get_providers_info(): - # Get the public key: - async with AsyncClient() as client: - for provider in providers.values(): - if provider.info_url: - provider_info = await client.get(provider.url) - provider.info = provider_info.json() - - def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: """Return the oidc_provider from a request object, from the session. It can be used in Depends()""" @@ -166,6 +161,8 @@ async def get_token(request: Request) -> OAuth2Token: try: provider = providers[request.session.get("auth_provider_id", "")] except KeyError: + request.session.pop("auth_provider_id", None) + request.session.pop("user_sub", None) raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider") try: return await db.get_token( @@ -239,29 +236,27 @@ async def get_user_from_token( raise HTTPException( status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" ) - if token == "": + if token == "None": + request.session.pop("auth_provider_id", None) + request.session.pop("user_sub", None) raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") try: payload = provider.decode(token) - except ExpiredSignatureError as err: - logger.info(f"Expired signature: {err}") + except ExpiredSignatureError: raise HTTPException( status.HTTP_401_UNAUTHORIZED, "Expired signature (refresh not implemented yet)", ) - except InvalidKeyError as err: - logger.info(f"Invalid key: {err}") + except InvalidKeyError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key") - except Exception as err: - logger.info("Cannot decode token, see below") - logger.exception(err) - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token") + except PyJWTError as err: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Cannot decode token: {err.__class__.__name__}" + ) try: user_id = payload["sub"] except KeyError: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found" - ) + return provider.unknown_auth_user try: user = await db.get_user(user_id) if user.access_token != token: diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 3493429..659fd13 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -3,6 +3,7 @@ import logging from authlib.oauth2.rfc6749 import OAuth2Token +from jwt import PyJWTError from .models import User, Role from .auth_provider import Provider, providers @@ -35,7 +36,10 @@ class Database: if access_token_decoded is None: assert auth_provider.name is not None provider = providers[auth_provider.id] - access_token_decoded = provider.decode(access_token) + try: + access_token_decoded = provider.decode(access_token) + except PyJWTError: + access_token_decoded = {} user_info["auth_provider_id"] = auth_provider.id user = User(**user_info) user.userinfo = user_info @@ -68,8 +72,7 @@ class Database: async def add_token(self, provider: Provider, token: OAuth2Token) -> None: """Store a token using as key the sid (auth provider's session id) in the id_token""" - assert isinstance(provider, Provider) - sid = token["userinfo"]["sid"] + sid = provider.get_session_key(token["userinfo"]) self.tokens[sid] = token async def get_token( diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 304df92..4018997 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -27,14 +27,13 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from pkce import generate_code_verifier, generate_pkce_pair from .settings import settings -from .auth_provider import Provider, providers +from .auth_provider import NoPublicKey, Provider, providers from .models import User from .auth_utils import ( get_auth_provider, get_auth_provider_or_none, get_current_user_or_none, authlib_oauth, - get_providers_info, get_token_or_none, get_token, update_token, @@ -50,7 +49,12 @@ templates = Jinja2Templates(Path(__file__).parent / "templates") @asynccontextmanager async def lifespan(app: FastAPI): assert app is not None - await get_providers_info() + for provider in list(providers.values()): + try: + await provider.get_info() + except NoPublicKey: + logger.warn(f"Disable {provider.id}: public key not found") + del providers[provider.id] yield @@ -82,12 +86,13 @@ async def home( token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], ) -> HTMLResponse: context = { - "settings": settings.model_dump(), + "show_token": settings.show_token, "user": user, "now": datetime.now(), "auth_provider": provider, } if provider is None or token is None: + context["providers"] = providers context["access_token"] = None context["id_token_parsed"] = None context["access_token_parsed"] = None @@ -96,14 +101,23 @@ async def home( else: context["access_token"] = token["access_token"] context["resources"] = provider.resources - access_token_parsed = provider.decode(token["access_token"], verify_signature=False) - context["access_token_scope"] = access_token_parsed["scope"] + try: + access_token_parsed = provider.decode(token["access_token"], verify_signature=False) + except PyJWTError as err: + access_token_parsed = {"Cannot parse": err.__class__.__name__} + try: + context["access_token_scope"] = access_token_parsed["scope"] + except KeyError: + context["access_token_scope"] = None # context["id_token_parsed"] = pretty_details(user, now) context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False) context["access_token_parsed"] = access_token_parsed - context["refresh_token_parsed"] = provider.decode( - token["refresh_token"], verify_signature=False - ) + try: + context["refresh_token_parsed"] = provider.decode( + token["refresh_token"], verify_signature=False + ) + except PyJWTError as err: + context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__} return templates.TemplateResponse(name="home.html", request=request, context=context) @@ -144,16 +158,19 @@ async def login(request: Request, auth_provider_id: str) -> RedirectResponse: @app.get("/auth/{auth_provider_id}") -async def auth(request: Request, auth_provider_id: str) -> RedirectResponse: +async def auth( + request: Request, + auth_provider_id: str, +) -> RedirectResponse: """Decrypt the auth token, store it to the session (cookie based) and response to the browser with a redirect to a "welcome user" page. """ try: - authlib_client: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) - except AttributeError: + provider = providers[auth_provider_id] + except KeyError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") try: - token: OAuth2Token = await authlib_client.authorize_access_token(request) + token: OAuth2Token = await provider.authlib_client.authorize_access_token(request) except OAuthError as error: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) # Remember the authlib_client in the session @@ -168,6 +185,14 @@ async def auth(request: Request, auth_provider_id: str) -> RedirectResponse: request.session["auth_provider_id"] = auth_provider_id # User id (sub) given by auth provider sub = userinfo["sub"] + ## Get additional data from userinfo endpoint + # try: + # user_info_from_endpoint = await authlib_client.userinfo( + # token=token, follow_redirects=True + # ) + # except Exception as err: + # logger.warn(f"Cannot get userinfo from endpoint: {err}") + # user_info_from_endpoint = {} # Build and remember the user in the session request.session["user_sub"] = sub # Store the user in the database, which also verifies the token validity and signature @@ -185,9 +210,8 @@ async def auth(request: Request, auth_provider_id: str) -> RedirectResponse: ) assert isinstance(user, User) # Add the provider session id to the session - request.session["sid"] = userinfo["sid"] + request.session["sid"] = provider.get_session_key(userinfo) # Add the token to the db because it is used for logout - provider = providers[auth_provider_id] await db.add_token(provider, token) # Send the user to the home: (s)he is authenticated return RedirectResponse(url=request.url_for("home")) @@ -211,15 +235,16 @@ async def logout( request: Request, provider: Annotated[Provider, Depends(get_auth_provider)], ) -> RedirectResponse: - # Clear session - request.session.pop("user_sub", None) # Get provider's endpoint if ( provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint") ) is None: - logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}") + logger.warn(f"Cannot find end_session_endpoint for provider {provider.id}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") + # Clear session + request.session.pop("user_sub", None) + request.session.pop("auth_provider_id", None) try: token = await db.get_token(provider, request.session.pop("sid", None)) except TokenNotInDb: @@ -242,15 +267,16 @@ async def logout( @app.get("/non-compliant-logout") async def non_compliant_logout( request: Request, - provider: Annotated[StarletteOAuth2App, Depends(get_auth_provider)], + provider: Annotated[Provider, Depends(get_auth_provider)], ): """A page for non-compliant OAuth2 servers that we cannot log out.""" - # Clear the remain of the session + # Clear session + request.session.pop("user_sub", None) request.session.pop("auth_provider_id", None) return templates.TemplateResponse( name="non_compliant_logout.html", request=request, - context={"oidc_provider": provider, "home_url": request.url_for("home")}, + context={"auth_provider": provider, "home_url": request.url_for("home")}, ) diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index eda63a6..7c5250b 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -49,13 +49,13 @@ class User(UserBase): try: access_token_scopes = self.decode_access_token().get("scope", "").split(" ") except Exception as err: - logger.info(f"Access token cannot be decoded: {err}") + logger.debug(f"Cannot find scope because the access token cannot be decoded: {err}") access_token_scopes = [] return scope in set(info_scopes + access_token_scopes) def decode_access_token(self, verify_signature: bool = True): - assert self.access_token is not None - assert self.auth_provider_id is not None + assert self.access_token is not None, "no access_token" + assert self.auth_provider_id is not None, "no auth_provider_id" from .auth_provider import providers return providers[self.auth_provider_id].decode( diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index e5670ed..f7f0433 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Annotated import logging -from authlib.jose import Key +from authlib.oauth2.auth import OAuth2Token from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from fastapi import FastAPI, HTTPException, Depends, status @@ -14,11 +14,12 @@ from fastapi.middleware.cors import CORSMiddleware from .models import User from .auth_utils import ( + get_token_or_none, get_user_from_token, UserWithRole, ) from .settings import settings -from .auth_provider import providers +from .auth_provider import providers, Provider logger = logging.getLogger("oidc-test") @@ -113,23 +114,51 @@ async def get_protected_by_foorole_or_barrole( @resource_server.get("/{id}") -async def get_resource_( +async def get_resource( id: str, - # user: Annotated[User, Depends(get_current_user)], - # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], - # token: Annotated[OAuth2Token, Depends(get_token)], user: Annotated[User, Depends(get_user_from_token)], -): + token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], +) -> dict | list: """Generic path for testing a resource provided by a provider""" - return await get_resource(id, user) + provider = providers[user.auth_provider_id] + if id in [r.id for r in provider.resources]: + return await get_external_resource( + provider=provider, + id=id, + access_token=token["access_token"] if token else None, + user=user, + ) + return await get_resource_(id, user) -async def get_resource(resource_id: str, user: User) -> dict: +async def get_external_resource( + provider: Provider, id: str, access_token: str | None, user: User +) -> dict | list: + resource = [r for r in provider.resources if r.id == id][0] + async with AsyncClient() as client: + resp = await client.get( + url=provider.url + resource.url, + headers={ + "Content-type": "application/json", + "Authorization": f"Bearer {access_token}", + }, + ) + if resp.is_error: + raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}") + resp_length = len(resp.text) + if resp_length > 1024: + return {"msg": f"The resource is too long ({resp_length} bytes) to show here"} + else: + return resp.json() + + +async def get_resource_(resource_id: str, user: User) -> dict: """ Resource processing: build an informative rely as a simple showcase """ + provider = providers[user.auth_provider_id] try: - pname = providers[user.auth_provider_id].name + pname = provider.name except KeyError: pname = "?" resp = { @@ -164,7 +193,6 @@ async def process(user, resource_id, resp): Too simple to be serious. It's a good fit for a plugin architecture for production """ - assert user is not None if resource_id == "time": resp["time"] = datetime.now().strftime("%c") elif resource_id == "bs": diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 9a789a0..86a2b6b 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -21,6 +21,7 @@ class Resource(BaseModel): id: str name: str + url: str class AuthProviderSettings(BaseModel): @@ -39,10 +40,12 @@ class AuthProviderSettings(BaseModel): info_url: str | None = ( None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key) ) - info: dict[str, str | int] | None = None # Info fetched from info_url, eg. public key public_key: str | None = None + public_key_url: str | None = None signature_alg: str = "RS256" resource_provider_scopes: list[str] = [] + session_key: str = "sid" + skip_verify_signature: bool = True @computed_field @property @@ -64,17 +67,6 @@ class AuthProviderSettings(BaseModel): else: return None - def get_public_key(self) -> str: - """Return the public key formatted for decoding token""" - public_key = self.public_key or (self.info is not None and self.info["public_key"]) - if public_key is None: - raise AttributeError(f"Cannot get public key for {self.name}") - return f""" - -----BEGIN PUBLIC KEY----- - {public_key} - -----END PUBLIC KEY----- - """ - class ResourceProvider(BaseModel): id: str diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 8e8ad59..6c9fae4 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -55,15 +55,24 @@ async function get_resource(id, token, authProvider) { msg.innerHTML = "" resourceElem.innerHTML = "" Object.entries(resource).forEach( - ([k, v]) => { + ([key, value]) => { let r = document.createElement('div') let kElem = document.createElement('div') - kElem.innerText = k + kElem.innerText = key kElem.className = "key" let vElem = document.createElement('div') - vElem.innerText = v + if (typeof value == "object") { + Object.entries(value).forEach(v => { + const ne = document.createElement('div') + ne.innerHTML = `<span class="key">${v[0]}</span>: <span class="value">${v[1]}</span>` + vElem.appendChild(ne) + }) + } + else { + vElem.innerText = value + } vElem.className = "value" - if (k == "sorry") { + if (key == "sorry") { vElem.classList.add("error") } r.appendChild(kElem) diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 7275f2d..da513c9 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -5,25 +5,24 @@ with OpenID Connect and OAuth2 with different providers. </p> {% if not user %} - <div class="login-box"> - <p class="description">Log in with:</p> - <table class="providers"> - {% for provider in settings.auth.providers %} - <tr class="provider"> - <td> - <a class="link" href="login/{{ provider.id }}"><div>{{ provider.name }}</div></a> - </td> - <td class="hint">{{ provider.hint }}</div> - </td> - </tr> - {% else %} - <div class="error">There is no authentication provider defined. - Hint: check the settings.yaml file.</div> - {% endfor %} - </table> - </div> - {% endif %} - {% if user %} + <div class="login-box"> + <p class="description">Log in with:</p> + <table class="providers"> + {% for provider in providers.values() %} + <tr class="provider"> + <td> + <a class="link" href="login/{{ provider.id }}"><div>{{ provider.name }}</div></a> + </td> + <td class="hint">{{ provider.hint }}</div> + </td> + </tr> + {% else %} + <div class="error">There is no authentication provider defined. + Hint: check the settings.yaml file.</div> + {% endfor %} + </table> + </div> + {% else %} <div class="user-info"> <p>Hey, {{ user.name }}</p> {% if user.picture %} @@ -83,22 +82,22 @@ <button resource-id="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ auth_provider.id }}')">Using FastAPI Depends</button> <!--<button resource-id="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ auth_provider.id }}')">Introspect token (401 expected)</button>--> </div> - <div class="resourceResult"> - <div id="resource" class="resource"></div> - <div id="msg" class="msg error"></div> - </div> {% if resources %} <p> Resources for this provider: </p> <div class="links-to-check"> {% for resource in resources %} - <a href="{{ request.url_for('get_resource', id=resource.id) }}">{{ resource.name }}</a> + <button resource-id="{{ resource.id }}" onclick="get_resource('{{ resource.id }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ resource.name }}</buttona> {% endfor %} </div> {% endif %} + <div class="resourceResult"> + <div id="resource" class="resource"></div> + <div id="msg" class="msg error"></div> + </div> </div> - {% if settings.show_token and id_token_parsed %} + {% if show_token and id_token_parsed %} <div class="token-info"> <hr> <div> diff --git a/src/oidc_test/templates/non_compliant_logout.html b/src/oidc_test/templates/non_compliant_logout.html index 24a96ae..56758de 100644 --- a/src/oidc_test/templates/non_compliant_logout.html +++ b/src/oidc_test/templates/non_compliant_logout.html @@ -6,12 +6,12 @@ authorisation to log in again without asking for credentials. </p> <p> - This is because {{ oidc_provider.name }} does not provide "end_session_endpoint" in its metadata - (see: <a href="{{ oidc_provider._server_metadata_url }}">{{ oidc_provider._server_metadata_url }}</a>). + This is because {{ auth_provider.name }} does not provide "end_session_endpoint" in its metadata + (see: <a href="{{ auth_provider.authlib_client._server_metadata_url }}">{{ auth_provider.authlib_client._server_metadata_url }}</a>). </p> <p> You can just also go back to the <a href="{{ home_url }}">application home page</a>, but - it recommended to go to the <a href="{{ oidc_provider.server_metadata['issuer'] }}">OIDC provider's site</a> + it recommended to go to the <a href="{{ auth_provider.authlib_client.server_metadata['issuer'] }}">OIDC provider's site</a> and log out explicitely from there. </p> {% endblock %} From e56be3c378e10e4ac972a3901021226ed26c7c1f Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Mon, 10 Feb 2025 14:14:32 +0100 Subject: [PATCH 08/46] Continue refactor --- .../{auth_provider.py => auth/provider.py} | 8 ++----- .../{auth_utils.py => auth/utils.py} | 24 +++++++++---------- src/oidc_test/auth_providers.py | 5 ++++ src/oidc_test/database.py | 4 +++- src/oidc_test/main.py | 12 ++++++---- src/oidc_test/models.py | 2 +- src/oidc_test/resource_server.py | 8 ++++--- src/oidc_test/settings.py | 4 +--- src/oidc_test/static/styles.css | 3 +-- src/oidc_test/templates/home.html | 2 +- 10 files changed, 38 insertions(+), 34 deletions(-) rename src/oidc_test/{auth_provider.py => auth/provider.py} (94%) rename src/oidc_test/{auth_utils.py => auth/utils.py} (96%) create mode 100644 src/oidc_test/auth_providers.py diff --git a/src/oidc_test/auth_provider.py b/src/oidc_test/auth/provider.py similarity index 94% rename from src/oidc_test/auth_provider.py rename to src/oidc_test/auth/provider.py index e50241c..dab4764 100644 --- a/src/oidc_test/auth_provider.py +++ b/src/oidc_test/auth/provider.py @@ -2,14 +2,13 @@ from json import JSONDecodeError from typing import Any from jwt import decode import logging -from collections import OrderedDict from pydantic import ConfigDict from authlib.integrations.starlette_client.apps import StarletteOAuth2App from httpx import AsyncClient -from .settings import AuthProviderSettings, settings -from .models import User +from ..settings import AuthProviderSettings, settings +from ..models import User logger = logging.getLogger("oidc-test") @@ -90,6 +89,3 @@ class Provider(AuthProviderSettings): def get_session_key(self, userinfo): return userinfo[self.session_key] - - -providers: OrderedDict[str, Provider] = OrderedDict() diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth/utils.py similarity index 96% rename from src/oidc_test/auth_utils.py rename to src/oidc_test/auth/utils.py index 8cd5028..0623186 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth/utils.py @@ -10,15 +10,15 @@ from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError # from authlib.oauth1.auth import OAuthToken from authlib.oauth2.auth import OAuth2Token -from .models import User -from .database import db, TokenNotInDb, UserNotInDB -from .settings import settings -from .auth_provider import providers, Provider +from .provider import Provider + +from ..models import User +from ..database import db, TokenNotInDb, UserNotInDB +from ..settings import settings +from ..auth_providers import providers logger = logging.getLogger("oidc-test") -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") - async def fetch_token(name, request): assert name is not None @@ -51,9 +51,6 @@ async def update_token( # await item.save() -authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) - - def init_providers(): """Add oidc providers to authlib from the settings and build the providers dict""" @@ -86,7 +83,8 @@ def init_providers(): providers[provider.id] = provider -init_providers() +authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: @@ -245,7 +243,7 @@ async def get_user_from_token( except ExpiredSignatureError: raise HTTPException( status.HTTP_401_UNAUTHORIZED, - "Expired signature (refresh not implemented yet)", + "Expired signature (token refresh not implemented yet)", ) except InvalidKeyError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key") @@ -263,12 +261,12 @@ async def get_user_from_token( user.access_token = token except UserNotInDB: logger.info( - f"User {user_id} not found in DB, creating it (real apps can behave differently" + f"User {user_id} not found in DB, creating it (real apps can behave differently)" ) user = await db.add_user( sub=payload["sub"], user_info=payload, - auth_provider=getattr(authlib_oauth, auth_provider_id), + auth_provider=providers[auth_provider_id], access_token=token, ) return user diff --git a/src/oidc_test/auth_providers.py b/src/oidc_test/auth_providers.py new file mode 100644 index 0000000..45f4de6 --- /dev/null +++ b/src/oidc_test/auth_providers.py @@ -0,0 +1,5 @@ +from collections import OrderedDict + +from .auth.provider import Provider + +providers: OrderedDict[str, Provider] = OrderedDict() diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 659fd13..4704f9b 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -5,8 +5,10 @@ import logging from authlib.oauth2.rfc6749 import OAuth2Token from jwt import PyJWTError +from .auth.provider import Provider + from .models import User, Role -from .auth_provider import Provider, providers +from .auth_providers import providers logger = logging.getLogger("oidc-test") diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 4018997..f37339d 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -26,10 +26,8 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair -from .settings import settings -from .auth_provider import NoPublicKey, Provider, providers -from .models import User -from .auth_utils import ( +from .auth.provider import NoPublicKey, Provider +from .auth.utils import ( get_auth_provider, get_auth_provider_or_none, get_current_user_or_none, @@ -38,6 +36,11 @@ from .auth_utils import ( get_token, update_token, ) + +from .auth.utils import init_providers +from .settings import settings +from .auth_providers import providers +from .models import User from .database import TokenNotInDb, db from .resource_server import resource_server @@ -49,6 +52,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates") @asynccontextmanager async def lifespan(app: FastAPI): assert app is not None + init_providers() for provider in list(providers.values()): try: await provider.get_info() diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 7c5250b..7b6fd0e 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -56,7 +56,7 @@ class User(UserBase): def decode_access_token(self, verify_signature: bool = True): assert self.access_token is not None, "no access_token" assert self.auth_provider_id is not None, "no auth_provider_id" - from .auth_provider import providers + from .auth_providers import providers return providers[self.auth_provider_id].decode( self.access_token, verify_signature=verify_signature diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index f7f0433..15084bc 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -12,14 +12,16 @@ from fastapi.middleware.cors import CORSMiddleware # from authlib.integrations.starlette_client.apps import StarletteOAuth2App # from authlib.oauth2.rfc6749 import OAuth2Token -from .models import User -from .auth_utils import ( +from .auth.provider import Provider +from .auth.utils import ( get_token_or_none, get_user_from_token, UserWithRole, ) + +from .auth_providers import providers from .settings import settings -from .auth_provider import providers, Provider +from .models import User logger = logging.getLogger("oidc-test") diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 86a2b6b..f3ac8f3 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -13,8 +13,6 @@ from pydantic_settings import ( ) from starlette.requests import Request -from .models import User - class Resource(BaseModel): """A resource with an URL that can be accessed with an OAuth2 access token""" @@ -57,7 +55,7 @@ class AuthProviderSettings(BaseModel): def token_url(self) -> str: return "auth/" + self.id - def get_account_url(self, request: Request, user: User) -> str | None: + def get_account_url(self, request: Request, user: dict) -> str | None: if self.account_url_template: if not (self.url.endswith("/") or self.account_url_template.startswith("/")): sep = "/" diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index 367ea99..e163a68 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -184,11 +184,10 @@ hr { font-family: monospace; } -.resource { +.resourceResult { padding: 0.5em; display: flex; gap: 0.5em; - flex-direction: column; width: fit-content; align-items: center; margin: 5px auto; diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index da513c9..93d0bc6 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -51,7 +51,7 @@ {% endif %} {% if auth_provider.account_url_template %} <button - onclick="location.href='{{ auth_provider.get_account_url(request, user) }}'" + onclick="location.href='{{ auth_provider.get_account_url(request, user.model_dump()) }}'" class="account"> Account management </button> From 64f6a90f22c82a6813cd7e68f2ab193fdbd1980b Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 11 Feb 2025 17:27:49 +0100 Subject: [PATCH 09/46] Add resource provided registry and plugin system --- src/oidc_test/auth/provider.py | 4 +- src/oidc_test/auth/utils.py | 11 +- src/oidc_test/auth_providers.py | 2 +- src/oidc_test/database.py | 8 +- src/oidc_test/main.py | 22 +-- src/oidc_test/registry.py | 43 ++++++ src/oidc_test/resource_server.py | 229 +++++++++++++++++------------- src/oidc_test/settings.py | 2 +- src/oidc_test/static/utils.js | 9 +- src/oidc_test/templates/home.html | 41 +++--- 10 files changed, 229 insertions(+), 142 deletions(-) create mode 100644 src/oidc_test/registry.py diff --git a/src/oidc_test/auth/provider.py b/src/oidc_test/auth/provider.py index dab4764..17dcaa0 100644 --- a/src/oidc_test/auth/provider.py +++ b/src/oidc_test/auth/provider.py @@ -7,8 +7,8 @@ from pydantic import ConfigDict from authlib.integrations.starlette_client.apps import StarletteOAuth2App from httpx import AsyncClient -from ..settings import AuthProviderSettings, settings -from ..models import User +from oidc_test.settings import AuthProviderSettings, settings +from oidc_test.models import User logger = logging.getLogger("oidc-test") diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py index 0623186..9479c48 100644 --- a/src/oidc_test/auth/utils.py +++ b/src/oidc_test/auth/utils.py @@ -10,12 +10,11 @@ from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError # from authlib.oauth1.auth import OAuthToken from authlib.oauth2.auth import OAuth2Token -from .provider import Provider - -from ..models import User -from ..database import db, TokenNotInDb, UserNotInDB -from ..settings import settings -from ..auth_providers import providers +from oidc_test.auth.provider import Provider +from oidc_test.models import User +from oidc_test.database import db, TokenNotInDb, UserNotInDB +from oidc_test.settings import settings +from oidc_test.auth_providers import providers logger = logging.getLogger("oidc-test") diff --git a/src/oidc_test/auth_providers.py b/src/oidc_test/auth_providers.py index 45f4de6..1c33ae8 100644 --- a/src/oidc_test/auth_providers.py +++ b/src/oidc_test/auth_providers.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from .auth.provider import Provider +from oidc_test.auth.provider import Provider providers: OrderedDict[str, Provider] = OrderedDict() diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 4704f9b..8d87a48 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -5,10 +5,10 @@ import logging from authlib.oauth2.rfc6749 import OAuth2Token from jwt import PyJWTError -from .auth.provider import Provider +from oidc_test.auth.provider import Provider -from .models import User, Role -from .auth_providers import providers +from oidc_test.models import User, Role +from oidc_test.auth_providers import providers logger = logging.getLogger("oidc-test") @@ -23,6 +23,7 @@ class TokenNotInDb(Exception): class Database: users: dict[str, User] = {} + # TODO: key of the token table should be provider: sid tokens: dict[str, OAuth2Token] = {} # Last sessions for the user (key: users's subject id (sub)) @@ -82,6 +83,7 @@ class Database: provider: Provider, sid: str | None, ) -> OAuth2Token: + # TODO: key of the token table should be provider: sid assert isinstance(provider, Provider) if sid is None: raise TokenNotInDb diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index f37339d..28eab8a 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -26,8 +26,9 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair -from .auth.provider import NoPublicKey, Provider -from .auth.utils import ( +from oidc_test.registry import registry +from oidc_test.auth.provider import NoPublicKey, Provider +from oidc_test.auth.utils import ( get_auth_provider, get_auth_provider_or_none, get_current_user_or_none, @@ -36,13 +37,12 @@ from .auth.utils import ( get_token, update_token, ) - -from .auth.utils import init_providers -from .settings import settings -from .auth_providers import providers -from .models import User -from .database import TokenNotInDb, db -from .resource_server import resource_server +from oidc_test.auth.utils import init_providers +from oidc_test.settings import settings +from oidc_test.auth_providers import providers +from oidc_test.models import User +from oidc_test.database import TokenNotInDb, db +from oidc_test.resource_server import resource_server logger = logging.getLogger("oidc-test") @@ -53,6 +53,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates") async def lifespan(app: FastAPI): assert app is not None init_providers() + registry.make_registry() for provider in list(providers.values()): try: await provider.get_info() @@ -104,6 +105,7 @@ async def home( context["resources"] = None else: context["access_token"] = token["access_token"] + # XXX: resources defined externally? I am confused... context["resources"] = provider.resources try: access_token_parsed = provider.decode(token["access_token"], verify_signature=False) @@ -113,9 +115,9 @@ async def home( context["access_token_scope"] = access_token_parsed["scope"] except KeyError: context["access_token_scope"] = None - # context["id_token_parsed"] = pretty_details(user, now) context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False) context["access_token_parsed"] = access_token_parsed + context["resource_providers"] = registry.resource_providers try: context["refresh_token_parsed"] = provider.decode( token["refresh_token"], verify_signature=False diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py new file mode 100644 index 0000000..6db0a47 --- /dev/null +++ b/src/oidc_test/registry.py @@ -0,0 +1,43 @@ +from importlib.metadata import entry_points +import logging +from typing import Any + +from pydantic import BaseModel + +from oidc_test.models import User + +logger = logging.getLogger("registry") + + +class ProcessResult(BaseModel): + result: dict[str, Any] = {} + + +class ProcessError(Exception): + pass + + +class ResourceProvider: + name: str + scope_required: str | None = None + default_resource_id: str | None = None + + def __init__(self, name: str): + self.name = name + + async def process(self, user: User, resource_id: str | None = None) -> ProcessResult: + logger.warning(f"{self.name} should define a process method") + return ProcessResult() + + +class ResourceRegistry: + resource_providers: dict[str, ResourceProvider] = {} + + def make_registry(self): + for ep in entry_points().select(group="oidc_test.resource_provider"): + ResourceProviderClass = ep.load() + if issubclass(ResourceProviderClass, ResourceProvider): + self.resource_providers[ep.name] = ResourceProviderClass(ep.name) + + +registry = ResourceRegistry() diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 15084bc..3b89240 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -5,23 +5,23 @@ import logging from authlib.oauth2.auth import OAuth2Token from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError -from fastapi import FastAPI, HTTPException, Depends, status +from fastapi import FastAPI, HTTPException, Depends, Request, status from fastapi.middleware.cors import CORSMiddleware # from starlette.middleware.sessions import SessionMiddleware # from authlib.integrations.starlette_client.apps import StarletteOAuth2App # from authlib.oauth2.rfc6749 import OAuth2Token -from .auth.provider import Provider -from .auth.utils import ( +from oidc_test.auth.provider import Provider +from oidc_test.auth.utils import ( get_token_or_none, get_user_from_token, UserWithRole, ) - -from .auth_providers import providers -from .settings import settings -from .models import User +from oidc_test.auth_providers import providers +from oidc_test.settings import settings +from oidc_test.models import User +from oidc_test.registry import ProcessError, ProcessResult, registry logger = logging.getLogger("oidc-test") @@ -91,6 +91,128 @@ async def get_protected_by_foorole_or_barrole( return {"msg": "Only users with foorole or barrole can see this"} +@resource_server.get("/{resource_name}") +@resource_server.get("/{resource_name}/{resource_id}") +async def get_resource( + resource_name: str, + user: Annotated[User, Depends(get_user_from_token)], + token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], + request: Request, + resource_id: str | None = None, +) -> ProcessResult: + """Generic path for testing a resource provided by a provider""" + provider = providers[user.auth_provider_id] + # Third party resource (provided through the auth provider) + # The token is just passed on + if resource_name in [r.resource_name for r in provider.resources]: + return await get_auth_provider_resource( + provider=provider, + resource_name=resource_name, + access_token=token["access_token"] if token else None, + user=user, + ) + # Internal resource (provided here) + if resource_name in registry.resource_providers: + resource_provider = registry.resource_providers[resource_name] + if resource_provider.scope_required is not None and user.has_scope( + resource_provider.scope_required + ): + try: + return await resource_provider.process(user=user, resource_id=resource_id) + except ProcessError as err: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}" + ) + else: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + f"No scope {resource_provider.scope_required} in the access token " + + "but it is required for accessing this resource", + ) + else: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}") + # return await get_resource_(resource_name, user, **request.query_params) + + +async def get_auth_provider_resource( + provider: Provider, resource_name: str, access_token: str | None, user: User +) -> ProcessResult: + resource = [r for r in provider.resources if r.resource_name == resource_name][0] + async with AsyncClient() as client: + resp = await client.get( + url=provider.url + resource.url, + headers={ + "Content-type": "application/json", + "Authorization": f"Bearer {access_token}", + }, + ) + if resp.is_error: + raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}") + # Only a demo, real application would really process the response + resp_length = len(resp.text) + if resp_length > 1024: + return ProcessResult( + result={"msg": f"The resource is too long ({resp_length} bytes) to show here"} + ) + else: + return ProcessResult(result=resp.json()) + + +# async def get_resource_(resource_id: str, user: User, **kwargs) -> dict: +# """ +# Resource processing: build an informative rely as a simple showcase +# """ +# if resource_id == "petition": +# return await sign(user, kwargs["petition_id"]) +# provider = providers[user.auth_provider_id] +# try: +# pname = provider.name +# except KeyError: +# pname = "?" +# resp = { +# "hello": f"Hi {user.name} from an OAuth resource provider", +# "comment": f"I received a request for '{resource_id}' " +# + f"with an access token signed by {pname}", +# } +# # For the demo, resource resource_id matches a scope get:resource_id, +# # but this has to be refined for production +# required_scope = f"get:{resource_id}" +# # Check if the required scope is in the scopes allowed in userinfo +# try: +# if user.has_scope(required_scope): +# await process(user, resource_id, resp) +# else: +# ## For the showcase, giving a explanation. +# ## Alternatively, raise HTTP_401_UNAUTHORIZED +# raise HTTPException( +# status.HTTP_401_UNAUTHORIZED, +# f"No scope {required_scope} in the access token " +# + "but it is required for accessing this resource", +# ) +# except ExpiredSignatureError: +# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired") +# except InvalidTokenError: +# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid") +# return resp + + +# async def process(user, resource_id, resp): +# """ +# Too simple to be serious. +# It's a good fit for a plugin architecture for production +# """ +# if resource_id == "time": +# resp["time"] = datetime.now().strftime("%c") +# elif resource_id == "bs": +# async with AsyncClient() as client: +# bs = await client.get("https://corporatebs-generator.sameerkumar.website/") +# resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.") +# else: +# raise HTTPException( +# status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'." +# ) + + # @resource_server.get("/introspect") # async def get_introspect( # request: Request, @@ -114,99 +236,6 @@ async def get_protected_by_foorole_or_barrole( # else: # raise HTTPException(status_code=response.status_code, detail=response.text) - -@resource_server.get("/{id}") -async def get_resource( - id: str, - user: Annotated[User, Depends(get_user_from_token)], - token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], -) -> dict | list: - """Generic path for testing a resource provided by a provider""" - provider = providers[user.auth_provider_id] - if id in [r.id for r in provider.resources]: - return await get_external_resource( - provider=provider, - id=id, - access_token=token["access_token"] if token else None, - user=user, - ) - return await get_resource_(id, user) - - -async def get_external_resource( - provider: Provider, id: str, access_token: str | None, user: User -) -> dict | list: - resource = [r for r in provider.resources if r.id == id][0] - async with AsyncClient() as client: - resp = await client.get( - url=provider.url + resource.url, - headers={ - "Content-type": "application/json", - "Authorization": f"Bearer {access_token}", - }, - ) - if resp.is_error: - raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}") - resp_length = len(resp.text) - if resp_length > 1024: - return {"msg": f"The resource is too long ({resp_length} bytes) to show here"} - else: - return resp.json() - - -async def get_resource_(resource_id: str, user: User) -> dict: - """ - Resource processing: build an informative rely as a simple showcase - """ - provider = providers[user.auth_provider_id] - try: - pname = provider.name - except KeyError: - pname = "?" - resp = { - "hello": f"Hi {user.name} from an OAuth resource provider", - "comment": f"I received a request for '{resource_id}' " - + f"with an access token signed by {pname}", - } - # For the demo, resource resource_id matches a scope get:resource_id, - # but this has to be refined for production - required_scope = f"get:{resource_id}" - # Check if the required scope is in the scopes allowed in userinfo - try: - if user.has_scope(required_scope): - await process(user, resource_id, resp) - else: - ## For the showcase, giving a explanation. - ## Alternatively, raise HTTP_401_UNAUTHORIZED - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - f"No scope {required_scope} in the access token " - + "but it is required for accessing this resource", - ) - except ExpiredSignatureError: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired") - except InvalidTokenError: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid") - return resp - - -async def process(user, resource_id, resp): - """ - Too simple to be serious. - It's a good fit for a plugin architecture for production - """ - if resource_id == "time": - resp["time"] = datetime.now().strftime("%c") - elif resource_id == "bs": - async with AsyncClient() as client: - bs = await client.get("https://corporatebs-generator.sameerkumar.website/") - resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.") - else: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'." - ) - - # assert user.oidc_provider is not None ### Get some info (TODO: refactor) # if (auth_provider_id := user.oidc_provider.name) is None: diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index f3ac8f3..2acbc3f 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -17,7 +17,7 @@ from starlette.requests import Request class Resource(BaseModel): """A resource with an URL that can be accessed with an OAuth2 access token""" - id: str + resource_name: str name: str url: str diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 6c9fae4..978b61c 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -1,6 +1,8 @@ async function checkHref(elem, token, authProvider) { const msg = document.getElementById("msg") - const url = `resource/${elem.getAttribute("resource-id")}` + const resourceName = elem.getAttribute("resource-name") + const resourceId = elem.getAttribute("resource-id") + const url = resourceId ? `resource/${resourceName}/${resourceId}` : `resource/${resourceName}` const resp = await fetch(url, { method: "GET", headers: new Headers({ @@ -28,11 +30,12 @@ function checkPerms(className, token, authProvider) { ) } -async function get_resource(id, token, authProvider) { +async function get_resource(resource_name, token, authProvider, resource_id) { //if (!keycloak.keycloak) { return } const msg = document.getElementById("msg") const resourceElem = document.getElementById('resource') - const resp = await fetch("resource/" + id, { + const url = resource_id ? `resource/${resource_name}/${resource_id}` : `resource/${resource_name}` + const resp = await fetch(url, { method: "GET", headers: new Headers({ "Content-type": "application/json", diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 93d0bc6..790da81 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -62,33 +62,42 @@ {% endif %} <hr> <div class="content"> - <p class="center"> - Resources validated by scope: - </p> - <div class="links-to-check"> - <button resource-id="time" onclick="get_resource('time', '{{ access_token }}', '{{ auth_provider.id }}')">Time</button> - <button resource-id="bs" onclick="get_resource('bs', '{{ access_token }}', '{{ auth_provider.id }}')">BS</button> - </div> <p> Resources validated by role: </p> <div class="links-to-check"> - <button resource-id="public" onclick="get_resource('public', '{{ access_token }}', '{{ auth_provider.id }}')">Public</button> - <button resource-id="protected" onclick="get_resource('protected', '{{ access_token }}', '{{ auth_provider.id }}')">Auth protected content</button> - <button resource-id="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole protected content</button> - <button resource-id="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole or barrole protected content</button> - <button resource-id="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + barrole protected content</button> - <button resource-id="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole and barrole protected content</button> - <button resource-id="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ auth_provider.id }}')">Using FastAPI Depends</button> - <!--<button resource-id="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ auth_provider.id }}')">Introspect token (401 expected)</button>--> + <button resource-name="public" onclick="get_resource('public', '{{ access_token }}', '{{ auth_provider.id }}')">Public</button> + <button resource-name="protected" onclick="get_resource('protected', '{{ access_token }}', '{{ auth_provider.id }}')">Auth protected content</button> + <button resource-name="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole protected content</button> + <button resource-name="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole or barrole protected content</button> + <button resource-name="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + barrole protected content</button> + <button resource-name="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole and barrole protected content</button> + <button resource-name="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ auth_provider.id }}')">Using FastAPI Depends</button> + <!--<button resource-name="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ auth_provider.id }}')">Introspect token (401 expected)</button>--> </div> + <!-- XXX confused... {% if resources %} <p> Resources for this provider: </p> <div class="links-to-check"> {% for resource in resources %} - <button resource-id="{{ resource.id }}" onclick="get_resource('{{ resource.id }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ resource.name }}</buttona> + <button resource-name="{{ resource.id }}" onclick="get_resource('{{ resource.name }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ resource.name }}</buttona> + {% endfor %} + </div> + {% endif %} + --> + {% if resource_providers %} + <p> + Resource providers (validated by scope): + </p> + <div class="links-to-check"> + {% for name, resource_provider in resource_providers.items() %} + {% if resource_provider.default_resource_id %} + <button resource-name="{{ name }}" resource-id="{{ resource_provider.default_resource_id }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource_provider.default_resource_id }}')">{{ name }}</buttona> + {% else %} + <button resource-name="{{ name }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ name }}</buttona> + {% endif %} {% endfor %} </div> {% endif %} From 0464047f8a6d17590f06fc7c52c4b1d4e007b2ce Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Wed, 12 Feb 2025 03:21:06 +0100 Subject: [PATCH 10/46] Container: add demo plugin --- Containerfile | 3 +++ src/oidc_test/templates/home.html | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/Containerfile b/Containerfile index aef57f8..75fef2b 100644 --- a/Containerfile +++ b/Containerfile @@ -9,6 +9,9 @@ WORKDIR /app RUN uv pip install --system . +# Add demo plugin +RUN PIP_EXTRA_INDEX_URL=https://pypi.org/simple/ uv pip install --system --index-url https://code.philo.ydns.eu/api/packages/philorg/pypi/simple/ oidc-fastapi-test-petition + # Possible to run with: #CMD ["oidc-test", "--port", "80"] #CMD ["fastapi", "run", "src/oidc_test/main.py", "--port", "8873", "--root-path", "/oidc-test"] diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 790da81..23ba7ff 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -56,7 +56,7 @@ Account management </button> {% endif %} - <button onclick="location.href='{{ request.url_for("refresh") }}'" class="refresh">Refresh</button> + <button onclick="location.href='{{ request.url_for("refresh") }}'" class="refresh">Refresh access token</button> <button onclick="location.href='{{ request.url_for("logout") }}'" class="logout">Logout</button> </div> {% endif %} From 381ce1ebc175d899cca49de14b8b7b6a6b263866 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Thu, 13 Feb 2025 12:23:18 +0100 Subject: [PATCH 11/46] Use pydantic on ResourceServer --- src/oidc_test/main.py | 3 +-- src/oidc_test/registry.py | 10 +++++----- src/oidc_test/resource_server.py | 13 ++++++++++--- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 28eab8a..3858a08 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -58,7 +58,7 @@ async def lifespan(app: FastAPI): try: await provider.get_info() except NoPublicKey: - logger.warn(f"Disable {provider.id}: public key not found") + logger.warning(f"Disable {provider.id}: public key not found") del providers[provider.id] yield @@ -300,7 +300,6 @@ async def refresh( await update_token(provider.id, new_token) return RedirectResponse(url=request.url_for("home")) - # Snippet for running standalone # Mostly useful for the --version option, # as running with uvicorn is easy and provides better flexibility, eg. diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py index 6db0a47..a184ec0 100644 --- a/src/oidc_test/registry.py +++ b/src/oidc_test/registry.py @@ -17,20 +17,20 @@ class ProcessError(Exception): pass -class ResourceProvider: - name: str +class ResourceProvider(BaseModel): scope_required: str | None = None default_resource_id: str | None = None def __init__(self, name: str): - self.name = name + super().__init__() + self.__name__ = name async def process(self, user: User, resource_id: str | None = None) -> ProcessResult: - logger.warning(f"{self.name} should define a process method") + logger.warning(f"{self.__name__} should define a process method") return ProcessResult() -class ResourceRegistry: +class ResourceRegistry(BaseModel): resource_providers: dict[str, ResourceProvider] = {} def make_registry(self): diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 3b89240..1af0f6b 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -1,5 +1,4 @@ -from datetime import datetime -from typing import Annotated +from typing import Annotated, Any import logging from authlib.oauth2.auth import OAuth2Token @@ -21,7 +20,7 @@ from oidc_test.auth.utils import ( from oidc_test.auth_providers import providers from oidc_test.settings import settings from oidc_test.models import User -from oidc_test.registry import ProcessError, ProcessResult, registry +from oidc_test.registry import ProcessError, ProcessResult, ResourceProvider, registry logger = logging.getLogger("oidc-test") @@ -48,6 +47,14 @@ resource_server.add_middleware( # Routes for RBAC based tests +@resource_server.get("/") +async def resources() -> dict[str, dict[str, Any]]: + return { + "internal": {}, + "plugins": registry.resource_providers + } + + @resource_server.get("/public") async def public() -> dict: return {"msg": "Not protected"} From 9d3146dc1c895d7a4eebace39af6222ebbbea091 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Thu, 13 Feb 2025 18:15:26 +0100 Subject: [PATCH 12/46] Add role protection to resource servers, remove hardcoded resources --- src/oidc_test/auth/utils.py | 51 ++++++++----- src/oidc_test/main.py | 3 +- src/oidc_test/registry.py | 7 +- src/oidc_test/resource_server.py | 118 +++++++++++++++++------------- src/oidc_test/settings.py | 1 + src/oidc_test/static/styles.css | 14 +++- src/oidc_test/templates/home.html | 43 +++-------- 7 files changed, 127 insertions(+), 110 deletions(-) diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py index 9479c48..acd68b5 100644 --- a/src/oidc_test/auth/utils.py +++ b/src/oidc_test/auth/utils.py @@ -8,7 +8,7 @@ from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOA from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError # from authlib.oauth1.auth import OAuthToken -from authlib.oauth2.auth import OAuth2Token +from authlib.oauth2.rfc6749 import OAuth2Token from oidc_test.auth.provider import Provider from oidc_test.models import User @@ -60,25 +60,28 @@ def init_providers(): sub="", auth_provider_id=provider_settings.id ) provider = Provider(**provider_settings_dict) - authlib_oauth.register( - name=provider.id, - server_metadata_url=provider.openid_configuration, - client_kwargs={ - "scope": " ".join( - ["openid", "email", "offline_access", "profile"] - + provider.resource_provider_scopes - ), - }, - client_id=provider.client_id, - client_secret=provider.client_secret, - api_base_url=provider.url, - # For PKCE (not implemented yet): - # code_challenge_method="S256", - fetch_token=fetch_token, - update_token=update_token, - # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) - ) - provider.authlib_client = getattr(authlib_oauth, provider.id) + if provider.disabled: + logger.info(f"{provider_settings.name} is disabled, skipping") + else: + authlib_oauth.register( + name=provider.id, + server_metadata_url=provider.openid_configuration, + client_kwargs={ + "scope": " ".join( + ["openid", "email", "offline_access", "profile"] + + provider.resource_provider_scopes + ), + }, + client_id=provider.client_id, + client_secret=provider.client_secret, + api_base_url=provider.url, + # For PKCE (not implemented yet): + # code_challenge_method="S256", + fetch_token=fetch_token, + update_token=update_token, + # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) + ) + provider.authlib_client = getattr(authlib_oauth, provider.id) providers[provider.id] = provider @@ -270,6 +273,14 @@ async def get_user_from_token( ) return user +async def get_user_from_token_or_none( + token: Annotated[str, Depends(oauth2_scheme)], + request: Request, +) -> User | None: + try: + return await get_user_from_token(token, request) + except HTTPException: + return None class UserWithRole: roles: set[str] diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 3858a08..9e8b135 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -55,6 +55,8 @@ async def lifespan(app: FastAPI): init_providers() registry.make_registry() for provider in list(providers.values()): + if provider.disabled: + continue try: await provider.get_info() except NoPublicKey: @@ -106,7 +108,6 @@ async def home( else: context["access_token"] = token["access_token"] # XXX: resources defined externally? I am confused... - context["resources"] = provider.resources try: access_token_parsed = provider.decode(token["access_token"], verify_signature=False) except PyJWTError as err: diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py index a184ec0..e9c9809 100644 --- a/src/oidc_test/registry.py +++ b/src/oidc_test/registry.py @@ -18,15 +18,18 @@ class ProcessError(Exception): class ResourceProvider(BaseModel): + name: str scope_required: str | None = None + role_required: str | None = None + is_public: bool = False default_resource_id: str | None = None def __init__(self, name: str): super().__init__() - self.__name__ = name + self.__id__ = name async def process(self, user: User, resource_id: str | None = None) -> ProcessResult: - logger.warning(f"{self.__name__} should define a process method") + logger.warning(f"{self.__id__} should define a process method") return ProcessResult() diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 1af0f6b..03d109e 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -1,7 +1,7 @@ from typing import Annotated, Any import logging -from authlib.oauth2.auth import OAuth2Token +from authlib.oauth2.rfc6749 import OAuth2Token from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from fastapi import FastAPI, HTTPException, Depends, Request, status @@ -16,6 +16,7 @@ from oidc_test.auth.utils import ( get_token_or_none, get_user_from_token, UserWithRole, + get_user_from_token_or_none, ) from oidc_test.auth_providers import providers from oidc_test.settings import settings @@ -55,54 +56,12 @@ async def resources() -> dict[str, dict[str, Any]]: } -@resource_server.get("/public") -async def public() -> dict: - return {"msg": "Not protected"} - - -@resource_server.get("/protected") -async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): - assert user is not None # Just to keep QA checks happy - return {"msg": "Only authenticated users can see this"} - - -@resource_server.get("/protected-by-foorole") -async def get_protected_by_foorole( - user: Annotated[User, Depends(UserWithRole("foorole"))], -): - assert user is not None - return {"msg": "Only users with foorole can see this"} - - -@resource_server.get("/protected-by-barrole") -async def get_protected_by_barrole( - user: Annotated[User, Depends(UserWithRole("barrole"))], -): - assert user is not None - return {"msg": "Protected by barrole"} - - -@resource_server.get("/protected-by-foorole-and-barrole") -async def get_protected_by_foorole_and_barrole( - user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], -): - assert user is not None # Just to keep QA checks happy - return {"msg": "Only users with foorole and barrole can see this"} - - -@resource_server.get("/protected-by-foorole-or-barrole") -async def get_protected_by_foorole_or_barrole( - user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))], -): - assert user is not None # Just to keep QA checks happy - return {"msg": "Only users with foorole or barrole can see this"} - @resource_server.get("/{resource_name}") @resource_server.get("/{resource_name}/{resource_id}") async def get_resource( resource_name: str, - user: Annotated[User, Depends(get_user_from_token)], + user: Annotated[User, Depends(get_user_from_token_or_none)], token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], request: Request, resource_id: str | None = None, @@ -111,19 +70,29 @@ async def get_resource( provider = providers[user.auth_provider_id] # Third party resource (provided through the auth provider) # The token is just passed on + breakpoint() if resource_name in [r.resource_name for r in provider.resources]: return await get_auth_provider_resource( provider=provider, resource_name=resource_name, - access_token=token["access_token"] if token else None, + token=token, user=user, ) # Internal resource (provided here) if resource_name in registry.resource_providers: resource_provider = registry.resource_providers[resource_name] - if resource_provider.scope_required is not None and user.has_scope( - resource_provider.scope_required - ): + reasons: dict[str, str] = {} + if not resource_provider.is_public: + if resource_provider.scope_required is not None and not user.has_scope( + resource_provider.scope_required + ): + reasons["scope"] = f"No scope {resource_provider.scope_required} in the access token " \ + "but it is required for accessing this resource" + if resource_provider.role_required is not None \ + and resource_provider.role_required not in user.roles_as_set: + reasons["role"] = f"You don't have the role {resource_provider.role_required} " \ + "but it is required for accessing this resource" + if len(reasons) == 0: try: return await resource_provider.process(user=user, resource_id=resource_id) except ProcessError as err: @@ -132,9 +101,7 @@ async def get_resource( ) else: raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - f"No scope {resource_provider.scope_required} in the access token " - + "but it is required for accessing this resource", + status.HTTP_401_UNAUTHORIZED, ", ".join(reasons.values()) ) else: raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}") @@ -142,8 +109,13 @@ async def get_resource( async def get_auth_provider_resource( - provider: Provider, resource_name: str, access_token: str | None, user: User + provider: Provider, resource_name: str, token: OAuth2Token | None, user: User ) -> ProcessResult: + if token is None: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"No auth token" + ) + access_token = token["access_token"] resource = [r for r in provider.resources if r.resource_name == resource_name][0] async with AsyncClient() as client: resp = await client.get( @@ -165,6 +137,48 @@ async def get_auth_provider_resource( return ProcessResult(result=resp.json()) +#@resource_server.get("/public") +#async def public() -> dict: +# return {"msg": "Not protected"} +# +# +#@resource_server.get("/protected") +#async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): +# assert user is not None # Just to keep QA checks happy +# return {"msg": "Only authenticated users can see this"} +# +# +#@resource_server.get("/protected-by-foorole") +#async def get_protected_by_foorole( +# user: Annotated[User, Depends(UserWithRole("foorole"))], +#): +# assert user is not None +# return {"msg": "Only users with foorole can see this"} +# +# +#@resource_server.get("/protected-by-barrole") +#async def get_protected_by_barrole( +# user: Annotated[User, Depends(UserWithRole("barrole"))], +#): +# assert user is not None +# return {"msg": "Protected by barrole"} +# +# +#@resource_server.get("/protected-by-foorole-and-barrole") +#async def get_protected_by_foorole_and_barrole( +# user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], +#): +# assert user is not None # Just to keep QA checks happy +# return {"msg": "Only users with foorole and barrole can see this"} +# +# +#@resource_server.get("/protected-by-foorole-or-barrole") +#async def get_protected_by_foorole_or_barrole( +# user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))], +#): +# assert user is not None # Just to keep QA checks happy +# return {"msg": "Only users with foorole or barrole can see this"} + # async def get_resource_(resource_id: str, user: User, **kwargs) -> dict: # """ # Resource processing: build an informative rely as a simple showcase diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 2acbc3f..3e7001c 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -44,6 +44,7 @@ class AuthProviderSettings(BaseModel): resource_provider_scopes: list[str] = [] session_key: str = "sid" skip_verify_signature: bool = True + disabled: bool = False @computed_field @property diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index e163a68..2baa748 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -142,19 +142,27 @@ hr { .providers .provider { min-height: 2em; } -.providers .provider a.link { +.providers .provider .link { text-decoration: none; max-height: 2em; } -.providers .provider .link div { +.providers .provider .link { background-color: #f7c7867d; border-radius: 8px; padding: 6px; text-align: center; color: black; - font-weight: bold; + font-weight: 400; cursor: pointer; + border: 0; + width: 100%; } + +.providers .provider .link.disabled { + color: gray; + cursor: not-allowed; +} + .providers .provider .hint { font-size: 80%; max-width: 13em; diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 23ba7ff..ecefb0f 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -11,7 +11,11 @@ {% for provider in providers.values() %} <tr class="provider"> <td> - <a class="link" href="login/{{ provider.id }}"><div>{{ provider.name }}</div></a> + <button class="link{% if provider.disabled %} disabled{% endif %}" + {% if provider.disabled %}disabled{% endif %} + onclick="location.href='login/{{ provider.id }}'"> + {{ provider.name }} + </button> </td> <td class="hint">{{ provider.hint }}</div> </td> @@ -62,42 +66,17 @@ {% endif %} <hr> <div class="content"> - <p> - Resources validated by role: - </p> - <div class="links-to-check"> - <button resource-name="public" onclick="get_resource('public', '{{ access_token }}', '{{ auth_provider.id }}')">Public</button> - <button resource-name="protected" onclick="get_resource('protected', '{{ access_token }}', '{{ auth_provider.id }}')">Auth protected content</button> - <button resource-name="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole protected content</button> - <button resource-name="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole or barrole protected content</button> - <button resource-name="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + barrole protected content</button> - <button resource-name="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole and barrole protected content</button> - <button resource-name="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ auth_provider.id }}')">Using FastAPI Depends</button> - <!--<button resource-name="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ auth_provider.id }}')">Introspect token (401 expected)</button>--> - </div> - <!-- XXX confused... - {% if resources %} - <p> - Resources for this provider: - </p> - <div class="links-to-check"> - {% for resource in resources %} - <button resource-name="{{ resource.id }}" onclick="get_resource('{{ resource.name }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ resource.name }}</buttona> - {% endfor %} - </div> - {% endif %} - --> {% if resource_providers %} <p> - Resource providers (validated by scope): + Resource providers: </p> <div class="links-to-check"> {% for name, resource_provider in resource_providers.items() %} - {% if resource_provider.default_resource_id %} - <button resource-name="{{ name }}" resource-id="{{ resource_provider.default_resource_id }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource_provider.default_resource_id }}')">{{ name }}</buttona> - {% else %} - <button resource-name="{{ name }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ name }}</buttona> - {% endif %} + {% if resource_provider.default_resource_id %} + <button resource-name="{{ name }}" resource-id="{{ resource_provider.default_resource_id }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource_provider.default_resource_id }}')">{{ name }}</buttona> + {% else %} + <button resource-name="{{ name }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ name }}</buttona> + {% endif %} {% endfor %} </div> {% endif %} From 5bd4b8280427d4a070a724cee141ee52ed3664a4 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Thu, 13 Feb 2025 18:26:23 +0100 Subject: [PATCH 13/46] Update demo resource provider package name --- Containerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Containerfile b/Containerfile index 75fef2b..2e3fd28 100644 --- a/Containerfile +++ b/Containerfile @@ -10,7 +10,7 @@ WORKDIR /app RUN uv pip install --system . # Add demo plugin -RUN PIP_EXTRA_INDEX_URL=https://pypi.org/simple/ uv pip install --system --index-url https://code.philo.ydns.eu/api/packages/philorg/pypi/simple/ oidc-fastapi-test-petition +RUN PIP_EXTRA_INDEX_URL=https://pypi.org/simple/ uv pip install --system --index-url https://code.philo.ydns.eu/api/packages/philorg/pypi/simple/ oidc-fastapi-test-resource-provider-demo # Possible to run with: #CMD ["oidc-test", "--port", "80"] From 40ddb616363e303968f1b9d5e7291f0868ac7656 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Thu, 13 Feb 2025 18:26:48 +0100 Subject: [PATCH 14/46] Cleanup --- src/oidc_test/resource_server.py | 1 - src/oidc_test/templates/home.html | 15 ++++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 03d109e..1877875 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -70,7 +70,6 @@ async def get_resource( provider = providers[user.auth_provider_id] # Third party resource (provided through the auth provider) # The token is just passed on - breakpoint() if resource_name in [r.resource_name for r in provider.resources]: return await get_auth_provider_resource( provider=provider, diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index ecefb0f..6c4e6a6 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -68,14 +68,23 @@ <div class="content"> {% if resource_providers %} <p> - Resource providers: + {{ auth_provider.name }} provides these resources: </p> <div class="links-to-check"> {% for name, resource_provider in resource_providers.items() %} {% if resource_provider.default_resource_id %} - <button resource-name="{{ name }}" resource-id="{{ resource_provider.default_resource_id }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource_provider.default_resource_id }}')">{{ name }}</buttona> + <button resource-name="{{ name }}" + resource-id="{{ resource_provider.default_resource_id }}" + onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource_provider.default_resource_id }}')" + > + {{ resource_provider.name }} + </button> {% else %} - <button resource-name="{{ name }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ name }}</buttona> + <button resource-name="{{ name }}" + onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}')" + > + {{ resource_provider.name }} + </buttona> {% endif %} {% endfor %} </div> From c89ca4098b2165014890af89caebde7310b88db0 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Fri, 14 Feb 2025 13:21:55 +0100 Subject: [PATCH 15/46] Fix public resource access; free resource response validation; formatting --- src/oidc_test/auth/utils.py | 15 ++-- src/oidc_test/main.py | 10 +-- src/oidc_test/registry.py | 9 ++- src/oidc_test/resource_server.py | 127 ++++++++++++++++--------------- 4 files changed, 84 insertions(+), 77 deletions(-) diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py index acd68b5..7dd0e3d 100644 --- a/src/oidc_test/auth/utils.py +++ b/src/oidc_test/auth/utils.py @@ -87,6 +87,7 @@ def init_providers(): authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: @@ -125,7 +126,7 @@ async def get_current_user(request: Request) -> User: """ if (user_sub := request.session.get("user_sub")) is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED) - token = await get_token(request) + token = await get_token_from_session(request) user = await db.get_user(user_sub) ## Check if the token is expired if token.is_expired(): @@ -146,16 +147,16 @@ async def get_current_user(request: Request) -> User: return user -async def get_token_or_none(request: Request) -> OAuth2Token | None: +async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None: """Return the auth token from the session or None. Can be used in Depends()""" try: - return await get_token(request) + return await get_token_from_session(request) except HTTPException: return None -async def get_token(request: Request) -> OAuth2Token: +async def get_token_from_session(request: Request) -> OAuth2Token: """Return the token from the session. Can be used in Depends()""" try: @@ -273,15 +274,19 @@ async def get_user_from_token( ) return user + async def get_user_from_token_or_none( - token: Annotated[str, Depends(oauth2_scheme)], + token: Annotated[str | None, Depends(oauth2_scheme_optional)], request: Request, ) -> User | None: + if token is None: + return None try: return await get_user_from_token(token, request) except HTTPException: return None + class UserWithRole: roles: set[str] diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 9e8b135..9f5e746 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -33,8 +33,8 @@ from oidc_test.auth.utils import ( get_auth_provider_or_none, get_current_user_or_none, authlib_oauth, - get_token_or_none, - get_token, + get_token_from_session_or_none, + get_token_from_session, update_token, ) from oidc_test.auth.utils import init_providers @@ -88,9 +88,9 @@ app.mount("/resource", resource_server, name="resource_server") @app.get("/") async def home( request: Request, - user: Annotated[User, Depends(get_current_user_or_none)], + user: Annotated[User | None, Depends(get_current_user_or_none)], provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)], - token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], + token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)], ) -> HTMLResponse: context = { "show_token": settings.show_token, @@ -291,7 +291,7 @@ async def non_compliant_logout( async def refresh( request: Request, provider: Annotated[Provider, Depends(get_auth_provider)], - token: Annotated[OAuth2Token, Depends(get_token)], + token: Annotated[OAuth2Token, Depends(get_token_from_session)], ) -> RedirectResponse: """Manually refresh token""" new_token = await provider.authlib_client.fetch_access_token( diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py index e9c9809..794a843 100644 --- a/src/oidc_test/registry.py +++ b/src/oidc_test/registry.py @@ -1,8 +1,7 @@ from importlib.metadata import entry_points import logging -from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from oidc_test.models import User @@ -10,7 +9,9 @@ logger = logging.getLogger("registry") class ProcessResult(BaseModel): - result: dict[str, Any] = {} + model_config = ConfigDict( + extra="allow", + ) class ProcessError(Exception): @@ -28,7 +29,7 @@ class ResourceProvider(BaseModel): super().__init__() self.__id__ = name - async def process(self, user: User, resource_id: str | None = None) -> ProcessResult: + async def process(self, user: User | None, resource_id: str | None = None) -> ProcessResult: logger.warning(f"{self.__id__} should define a process method") return ProcessResult() diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 1877875..ee4ff10 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -13,15 +13,13 @@ from fastapi.middleware.cors import CORSMiddleware from oidc_test.auth.provider import Provider from oidc_test.auth.utils import ( - get_token_or_none, - get_user_from_token, - UserWithRole, get_user_from_token_or_none, + oauth2_scheme_optional, ) from oidc_test.auth_providers import providers from oidc_test.settings import settings from oidc_test.models import User -from oidc_test.registry import ProcessError, ProcessResult, ResourceProvider, registry +from oidc_test.registry import ProcessError, ProcessResult, registry logger = logging.getLogger("oidc-test") @@ -50,60 +48,67 @@ resource_server.add_middleware( @resource_server.get("/") async def resources() -> dict[str, dict[str, Any]]: - return { - "internal": {}, - "plugins": registry.resource_providers - } - + return {"internal": {}, "plugins": registry.resource_providers} @resource_server.get("/{resource_name}") @resource_server.get("/{resource_name}/{resource_id}") async def get_resource( resource_name: str, - user: Annotated[User, Depends(get_user_from_token_or_none)], - token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], - request: Request, + user: Annotated[User | None, Depends(get_user_from_token_or_none)], + token: Annotated[OAuth2Token | None, Depends(oauth2_scheme_optional)], resource_id: str | None = None, -) -> ProcessResult: - """Generic path for testing a resource provided by a provider""" - provider = providers[user.auth_provider_id] - # Third party resource (provided through the auth provider) - # The token is just passed on - if resource_name in [r.resource_name for r in provider.resources]: - return await get_auth_provider_resource( - provider=provider, - resource_name=resource_name, - token=token, - user=user, - ) +): + """Generic path for testing a resource provided by a provider. + There's no field validation (response type of ProcessResult) on purpose, + leaving the responsibility of the response validation to resource providers""" + # Get the resource if it's defined in user auth provider's resources (external) + if user is not None: + provider = providers[user.auth_provider_id] + # Third party resource (provided through the auth provider) + # The token is just passed on + if resource_name in [r.resource_name for r in provider.resources]: + return await get_auth_provider_resource( + provider=provider, + resource_name=resource_name, + token=token, + user=user, + ) # Internal resource (provided here) if resource_name in registry.resource_providers: resource_provider = registry.resource_providers[resource_name] - reasons: dict[str, str] = {} + reason: dict[str, str] = {} if not resource_provider.is_public: - if resource_provider.scope_required is not None and not user.has_scope( - resource_provider.scope_required - ): - reasons["scope"] = f"No scope {resource_provider.scope_required} in the access token " \ - "but it is required for accessing this resource" - if resource_provider.role_required is not None \ - and resource_provider.role_required not in user.roles_as_set: - reasons["role"] = f"You don't have the role {resource_provider.role_required} " \ - "but it is required for accessing this resource" - if len(reasons) == 0: + if user is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public") + else: + if resource_provider.scope_required is not None and not user.has_scope( + resource_provider.scope_required + ): + reason["scope"] = ( + f"No scope {resource_provider.scope_required} in the access token " + "but it is required for accessing this resource" + ) + if ( + resource_provider.role_required is not None + and resource_provider.role_required not in user.roles_as_set + ): + reason["role"] = ( + f"You don't have the role {resource_provider.role_required} " + "but it is required for accessing this resource" + ) + if len(reason) == 0: try: - return await resource_provider.process(user=user, resource_id=resource_id) + resp = await resource_provider.process(user=user, resource_id=resource_id) + return resp except ProcessError as err: raise HTTPException( status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}" ) else: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, ", ".join(reasons.values()) - ) + raise HTTPException(status.HTTP_401_UNAUTHORIZED, ", ".join(reason.values())) else: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}") + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Unknown resource") # return await get_resource_(resource_name, user, **request.query_params) @@ -111,9 +116,7 @@ async def get_auth_provider_resource( provider: Provider, resource_name: str, token: OAuth2Token | None, user: User ) -> ProcessResult: if token is None: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, f"No auth token" - ) + raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"No auth token") access_token = token["access_token"] resource = [r for r in provider.resources if r.resource_name == resource_name][0] async with AsyncClient() as client: @@ -129,52 +132,50 @@ async def get_auth_provider_resource( # Only a demo, real application would really process the response resp_length = len(resp.text) if resp_length > 1024: - return ProcessResult( - result={"msg": f"The resource is too long ({resp_length} bytes) to show here"} - ) + return ProcessResult(msg=f"The resource is too long ({resp_length} bytes) to show here") else: - return ProcessResult(result=resp.json()) + return ProcessResult(**resp.json()) -#@resource_server.get("/public") -#async def public() -> dict: +# @resource_server.get("/public") +# async def public() -> dict: # return {"msg": "Not protected"} # # -#@resource_server.get("/protected") -#async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): +# @resource_server.get("/protected") +# async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): # assert user is not None # Just to keep QA checks happy # return {"msg": "Only authenticated users can see this"} # # -#@resource_server.get("/protected-by-foorole") -#async def get_protected_by_foorole( +# @resource_server.get("/protected-by-foorole") +# async def get_protected_by_foorole( # user: Annotated[User, Depends(UserWithRole("foorole"))], -#): +# ): # assert user is not None # return {"msg": "Only users with foorole can see this"} # # -#@resource_server.get("/protected-by-barrole") -#async def get_protected_by_barrole( +# @resource_server.get("/protected-by-barrole") +# async def get_protected_by_barrole( # user: Annotated[User, Depends(UserWithRole("barrole"))], -#): +# ): # assert user is not None # return {"msg": "Protected by barrole"} # # -#@resource_server.get("/protected-by-foorole-and-barrole") -#async def get_protected_by_foorole_and_barrole( +# @resource_server.get("/protected-by-foorole-and-barrole") +# async def get_protected_by_foorole_and_barrole( # user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], -#): +# ): # assert user is not None # Just to keep QA checks happy # return {"msg": "Only users with foorole and barrole can see this"} # # -#@resource_server.get("/protected-by-foorole-or-barrole") -#async def get_protected_by_foorole_or_barrole( +# @resource_server.get("/protected-by-foorole-or-barrole") +# async def get_protected_by_foorole_or_barrole( # user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))], -#): +# ): # assert user is not None # Just to keep QA checks happy # return {"msg": "Only users with foorole or barrole can see this"} From 4008036bca36ae933b14fba5ed239ddd0e727d57 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Fri, 14 Feb 2025 13:36:22 +0100 Subject: [PATCH 16/46] CI: don't fail because of publish step (already exists) --- .forgejo/workflows/build.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index e02bf47..352a0a9 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -90,3 +90,4 @@ jobs: env: LOCAL_PYPI_TOKEN: ${{ secrets.LOCAL_PYPI_TOKEN }} run: uv publish --publish-url https://code.philo.ydns.eu/api/packages/philorg/pypi --token $LOCAL_PYPI_TOKEN + continue-on-error: true From 1c57944a902aa25bf7bd4527b4d83bfb6d1b81e2 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Mon, 17 Feb 2025 17:26:30 +0100 Subject: [PATCH 17/46] Fix typo --- src/oidc_test/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 9f5e746..79293a3 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -264,7 +264,7 @@ async def logout( { "post_logout_redirect_uri": post_logout_uri, "id_token_hint": token["id_token"], - "cliend_id": "oidc_local_test", + "client_id": "oidc_local_test", } ) ) @@ -301,6 +301,7 @@ async def refresh( await update_token(provider.id, new_token) return RedirectResponse(url=request.url_for("home")) + # Snippet for running standalone # Mostly useful for the --version option, # as running with uvicorn is easy and provides better flexibility, eg. From 435c11b6caf7ca59529a62eec664a2bbcf2d554a Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Wed, 19 Feb 2025 04:07:57 +0100 Subject: [PATCH 18/46] Working use as third party resource provider --- src/oidc_test/auth/provider.py | 18 ++++++++- src/oidc_test/main.py | 26 ++++++------- src/oidc_test/registry.py | 10 ++--- src/oidc_test/resource_server.py | 62 +++++++++++++++++++++++-------- src/oidc_test/settings.py | 28 +++++++++----- src/oidc_test/static/utils.js | 10 +++-- src/oidc_test/templates/home.html | 43 +++++++++++++++------ 7 files changed, 138 insertions(+), 59 deletions(-) diff --git a/src/oidc_test/auth/provider.py b/src/oidc_test/auth/provider.py index 17dcaa0..c614805 100644 --- a/src/oidc_test/auth/provider.py +++ b/src/oidc_test/auth/provider.py @@ -7,7 +7,7 @@ from pydantic import ConfigDict from authlib.integrations.starlette_client.apps import StarletteOAuth2App from httpx import AsyncClient -from oidc_test.settings import AuthProviderSettings, settings +from oidc_test.settings import AuthProviderSettings, ResourceProvider, Resource, settings from oidc_test.models import User logger = logging.getLogger("oidc-test") @@ -24,6 +24,7 @@ class Provider(AuthProviderSettings): authlib_client: StarletteOAuth2App = StarletteOAuth2App(None) info: dict[str, Any] = {} unknown_auth_user: User + logout_with_id_token_hint: bool = True def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]: """Decode the token with signature check""" @@ -89,3 +90,18 @@ class Provider(AuthProviderSettings): def get_session_key(self, userinfo): return userinfo[self.session_key] + + def get_resource(self, resource_name: str) -> Resource: + return [ + resource for resource in self.resources if resource.resource_name == resource_name + ][0] + + def get_resource_url(self, resource_name: str) -> str: + return self.url + self.get_resource(resource_name).url + + def get_resource_provider(self, resource_provider_id: str) -> ResourceProvider: + return [ + provider + for provider in self.resource_providers + if provider.id == resource_provider_id + ][0] diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 79293a3..f930e48 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -107,7 +107,6 @@ async def home( context["resources"] = None else: context["access_token"] = token["access_token"] - # XXX: resources defined externally? I am confused... try: access_token_parsed = provider.decode(token["access_token"], verify_signature=False) except PyJWTError as err: @@ -118,7 +117,8 @@ async def home( context["access_token_scope"] = None context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False) context["access_token_parsed"] = access_token_parsed - context["resource_providers"] = registry.resource_providers + context["resources"] = registry.resources + context["resource_providers"] = provider.resource_providers try: context["refresh_token_parsed"] = provider.decode( token["refresh_token"], verify_signature=False @@ -246,7 +246,7 @@ async def logout( if ( provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint") ) is None: - logger.warn(f"Cannot find end_session_endpoint for provider {provider.id}") + logger.warning(f"Cannot find end_session_endpoint for provider {provider.id}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") # Clear session @@ -255,19 +255,15 @@ async def logout( try: token = await db.get_token(provider, request.session.pop("sid", None)) except TokenNotInDb: - logger.warn("No session in db for the token or no token") + logger.warning("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) - logout_url = ( - provider_logout_uri - + "?" - + urlencode( - { - "post_logout_redirect_uri": post_logout_uri, - "id_token_hint": token["id_token"], - "client_id": "oidc_local_test", - } - ) - ) + url_query = { + "post_logout_redirect_uri": post_logout_uri, + "client_id": provider.client_id, + } + if provider.logout_with_id_token_hint: + url_query["id_token_hint"] = token["id_token"] + logout_url = f"{provider_logout_uri}?{urlencode(url_query)}" return RedirectResponse(logout_url) diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py index 794a843..3b91ad4 100644 --- a/src/oidc_test/registry.py +++ b/src/oidc_test/registry.py @@ -18,7 +18,7 @@ class ProcessError(Exception): pass -class ResourceProvider(BaseModel): +class Resource(BaseModel): name: str scope_required: str | None = None role_required: str | None = None @@ -35,13 +35,13 @@ class ResourceProvider(BaseModel): class ResourceRegistry(BaseModel): - resource_providers: dict[str, ResourceProvider] = {} + resources: dict[str, Resource] = {} def make_registry(self): for ep in entry_points().select(group="oidc_test.resource_provider"): - ResourceProviderClass = ep.load() - if issubclass(ResourceProviderClass, ResourceProvider): - self.resource_providers[ep.name] = ResourceProviderClass(ep.name) + ResourceClass = ep.load() + if issubclass(ResourceClass, Resource): + self.resources[ep.name] = ResourceClass(ep.name) registry = ResourceRegistry() diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index ee4ff10..a4d5368 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -17,7 +17,7 @@ from oidc_test.auth.utils import ( oauth2_scheme_optional, ) from oidc_test.auth_providers import providers -from oidc_test.settings import settings +from oidc_test.settings import ResourceProvider, settings from oidc_test.models import User from oidc_test.registry import ProcessError, ProcessResult, registry @@ -48,7 +48,7 @@ resource_server.add_middleware( @resource_server.get("/") async def resources() -> dict[str, dict[str, Any]]: - return {"internal": {}, "plugins": registry.resource_providers} + return {"internal": {}, "plugins": registry.resources} @resource_server.get("/{resource_name}") @@ -65,8 +65,41 @@ async def get_resource( # Get the resource if it's defined in user auth provider's resources (external) if user is not None: provider = providers[user.auth_provider_id] + if ":" in resource_name: + # Third-party resource provider: send the request with the request token + resource_provider_id, resource_name = resource_name.split(":", 1) + provider = providers[user.auth_provider_id] + resource_provider: ResourceProvider = provider.get_resource_provider( + resource_provider_id + ) + resource_url = resource_provider.get_resource_url(resource_name) + async with AsyncClient(verify=resource_provider.verify_ssl) as client: + try: + resp = await client.get( + resource_url, + headers={ + "Content-type": "application/json", + "Authorization": f"Bearer {token}", + "auth_provider": user.auth_provider_id, + }, + ) + except Exception as err: + raise HTTPException( + status.HTTP_500_INTERNAL_SERVER_ERROR, err.__class__.__name__ + ) + else: + if resp.is_success: + return resp.json() + else: + reason_str: str + try: + reason_str = resp.json().get("detail", str(resp)) + except Exception: + reason_str = str(resp.text) + raise HTTPException(resp.status_code, reason_str) # Third party resource (provided through the auth provider) # The token is just passed on + # XXX: is this branch valid anymore? if resource_name in [r.resource_name for r in provider.resources]: return await get_auth_provider_resource( provider=provider, @@ -75,31 +108,31 @@ async def get_resource( user=user, ) # Internal resource (provided here) - if resource_name in registry.resource_providers: - resource_provider = registry.resource_providers[resource_name] + if resource_name in registry.resources: + resource = registry.resources[resource_name] reason: dict[str, str] = {} - if not resource_provider.is_public: + if not resource.is_public: if user is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public") else: - if resource_provider.scope_required is not None and not user.has_scope( - resource_provider.scope_required + if resource.scope_required is not None and not user.has_scope( + resource.scope_required ): reason["scope"] = ( - f"No scope {resource_provider.scope_required} in the access token " + f"No scope {resource.scope_required} in the access token " "but it is required for accessing this resource" ) if ( - resource_provider.role_required is not None - and resource_provider.role_required not in user.roles_as_set + resource.role_required is not None + and resource.role_required not in user.roles_as_set ): reason["role"] = ( - f"You don't have the role {resource_provider.role_required} " + f"You don't have the role {resource.role_required} " "but it is required for accessing this resource" ) if len(reason) == 0: try: - resp = await resource_provider.process(user=user, resource_id=resource_id) + resp = await resource.process(user=user, resource_id=resource_id) return resp except ProcessError as err: raise HTTPException( @@ -116,12 +149,11 @@ async def get_auth_provider_resource( provider: Provider, resource_name: str, token: OAuth2Token | None, user: User ) -> ProcessResult: if token is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"No auth token") + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth token") access_token = token["access_token"] - resource = [r for r in provider.resources if r.resource_name == resource_name][0] async with AsyncClient() as client: resp = await client.get( - url=provider.url + resource.url, + url=provider.get_resource_url(resource_name), headers={ "Content-type": "application/json", "Authorization": f"Bearer {access_token}", diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 3e7001c..e549fd4 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -4,7 +4,7 @@ import random from typing import Type, Tuple from pathlib import Path -from pydantic import BaseModel, computed_field, AnyUrl +from pydantic import AnyHttpUrl, BaseModel, computed_field, AnyUrl from pydantic_settings import ( BaseSettings, SettingsConfigDict, @@ -22,6 +22,22 @@ class Resource(BaseModel): url: str +class ResourceProvider(BaseModel): + id: str + name: str + base_url: AnyUrl + resources: list[Resource] = [] + verify_ssl: bool = True + + def get_resource(self, resource_name: str) -> Resource: + return [ + resource for resource in self.resources if resource.resource_name == resource_name + ][0] + + def get_resource_url(self, resource_name: str) -> str: + return f"{self.base_url}{self.get_resource(resource_name).url}" + + class AuthProviderSettings(BaseModel): """Auth provider, can also be a resource server""" @@ -45,6 +61,7 @@ class AuthProviderSettings(BaseModel): session_key: str = "sid" skip_verify_signature: bool = True disabled: bool = False + resource_providers: list[ResourceProvider] = [] @computed_field @property @@ -67,13 +84,6 @@ class AuthProviderSettings(BaseModel): return None -class ResourceProvider(BaseModel): - id: str - name: str - base_url: AnyUrl - resources: list[Resource] = [] - - class AuthSettings(BaseModel): show_session_details: bool = False providers: list[AuthProviderSettings] = [] @@ -92,13 +102,13 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_nested_delimiter="__") auth: AuthSettings = AuthSettings() - resource_providers: list[ResourceProvider] = [] secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False insecure: Insecure = Insecure() cors_origins: list[str] = [] debug_token: bool = False show_token: bool = False + show_external_resource_providers_links: bool = False @classmethod def settings_customise_sources( diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 978b61c..e988dfe 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -2,7 +2,9 @@ async function checkHref(elem, token, authProvider) { const msg = document.getElementById("msg") const resourceName = elem.getAttribute("resource-name") const resourceId = elem.getAttribute("resource-id") - const url = resourceId ? `resource/${resourceName}/${resourceId}` : `resource/${resourceName}` + const resourceProviderId = elem.getAttribute("resource-provider-id") ? elem.getAttribute("resource-provider-id") : "" + const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName + const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}` const resp = await fetch(url, { method: "GET", headers: new Headers({ @@ -30,11 +32,13 @@ function checkPerms(className, token, authProvider) { ) } -async function get_resource(resource_name, token, authProvider, resource_id) { +async function get_resource(resourceName, token, authProvider, resourceId, resourceProviderId) { + // BaseUrl for an external resource provider //if (!keycloak.keycloak) { return } const msg = document.getElementById("msg") const resourceElem = document.getElementById('resource') - const url = resource_id ? `resource/${resource_name}/${resource_id}` : `resource/${resource_name}` + const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName + const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}` const resp = await fetch(url, { method: "GET", headers: new Headers({ diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 6c4e6a6..5bccaee 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -66,29 +66,50 @@ {% endif %} <hr> <div class="content"> - {% if resource_providers %} + <!-- + --> + {% if resources %} <p> {{ auth_provider.name }} provides these resources: </p> <div class="links-to-check"> - {% for name, resource_provider in resource_providers.items() %} - {% if resource_provider.default_resource_id %} - <button resource-name="{{ name }}" - resource-id="{{ resource_provider.default_resource_id }}" - onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource_provider.default_resource_id }}')" - > - {{ resource_provider.name }} - </button> + {% for name, resource in resources.items() %} + {% if resource.default_resource_id %} + <button resource-name="{{ name }}" + resource-id="{{ resource.default_resource_id }}" + onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource.default_resource_id }}')" + > + {{ resource.name }} + </button> {% else %} <button resource-name="{{ name }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}')" > - {{ resource_provider.name }} - </buttona> + {{ resource.name }} + </button> {% endif %} {% endfor %} </div> {% endif %} + {% if resource_providers %} + <p>{{ auth_provider.name }} can request resources from third party resource providers:</p> + {% for resource_provider in resource_providers %} + <div class="links-to-check"> + {{ resource_provider.name }} + {% for resource in resource_provider.resources %} + <button resource-name="{{ resource.resource_name }}" + resource-id="{{ resource.default_resource_id }}" + resource-provider-id="{{ resource_provider.id }}" + onclick="get_resource('{{ resource.resource_name }}', '{{ access_token }}', + '{{ auth_provider.id }}', '{{ resource.default_resource_id }}', + '{{ resource_provider.id }}')" + > + {{ resource.name }} + </button> + {% endfor %} + </div> + {% endfor %} + {% endif %} <div class="resourceResult"> <div id="resource" class="resource"></div> <div id="msg" class="msg error"></div> From e925f2176258d4f73b5b7e565123652da02d6b12 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Thu, 20 Feb 2025 02:01:18 +0100 Subject: [PATCH 19/46] Add configurable logging from settings --- log_conf.yaml => src/oidc_test/log_conf.yaml | 0 src/oidc_test/main.py | 13 +++++++++++++ src/oidc_test/settings.py | 1 + 3 files changed, 14 insertions(+) rename log_conf.yaml => src/oidc_test/log_conf.yaml (100%) diff --git a/log_conf.yaml b/src/oidc_test/log_conf.yaml similarity index 100% rename from log_conf.yaml rename to src/oidc_test/log_conf.yaml diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index f930e48..cf28f27 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -6,6 +6,9 @@ from typing import Annotated from pathlib import Path from datetime import datetime import logging +import logging.config +import importlib.resources +from yaml import safe_load from urllib.parse import urlencode from contextlib import asynccontextmanager @@ -46,6 +49,16 @@ from oidc_test.resource_server import resource_server logger = logging.getLogger("oidc-test") +if settings.log: + assert __package__ is not None + with ( + importlib.resources.path(__package__) as package_path, + open(package_path / settings.log_config_file) as f, + ): + logging_config = safe_load(f) + logging.config.dictConfig(logging_config) + +breakpoint() templates = Jinja2Templates(Path(__file__).parent / "templates") diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index e549fd4..ad80c06 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -104,6 +104,7 @@ class Settings(BaseSettings): auth: AuthSettings = AuthSettings() secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False + log_config_file: str = "log_conf.yaml" insecure: Insecure = Insecure() cors_origins: list[str] = [] debug_token: bool = False From 703985f31102705769e725b25d54dc8414bb0610 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Thu, 20 Feb 2025 02:01:33 +0100 Subject: [PATCH 20/46] Add configurable logging from settings --- src/oidc_test/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index cf28f27..8fe32d8 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -58,7 +58,6 @@ if settings.log: logging_config = safe_load(f) logging.config.dictConfig(logging_config) -breakpoint() templates = Jinja2Templates(Path(__file__).parent / "templates") From 0764b1c003151ddbf7afe6d6d17590aca0913de2 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Thu, 20 Feb 2025 02:05:15 +0100 Subject: [PATCH 21/46] Log request to resource server --- src/oidc_test/resource_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index a4d5368..604052c 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -75,6 +75,7 @@ async def get_resource( resource_url = resource_provider.get_resource_url(resource_name) async with AsyncClient(verify=resource_provider.verify_ssl) as client: try: + logger.debug(f"GET request to {resource_url}") resp = await client.get( resource_url, headers={ From d924c56b1710082690df1e6fd8a146e8f672cf6d Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Thu, 20 Feb 2025 02:56:28 +0100 Subject: [PATCH 22/46] Cosmetic --- src/oidc_test/templates/home.html | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 5bccaee..7d2b1db 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -70,7 +70,7 @@ --> {% if resources %} <p> - {{ auth_provider.name }} provides these resources: + This application provides all these resources, eventually protected with roles: </p> <div class="links-to-check"> {% for name, resource in resources.items() %} @@ -92,7 +92,7 @@ </div> {% endif %} {% if resource_providers %} - <p>{{ auth_provider.name }} can request resources from third party resource providers:</p> + <p>{{ auth_provider.name }} allows this applicaiton to request resources from third party resource providers:</p> {% for resource_provider in resource_providers %} <div class="links-to-check"> {{ resource_provider.name }} From ecdd3702f85c04e62741fa15910e25f425d5c20d Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Thu, 20 Feb 2025 03:13:41 +0100 Subject: [PATCH 23/46] Hanle token refresh error --- src/oidc_test/main.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 8fe32d8..8808562 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -306,7 +306,13 @@ async def refresh( refresh_token=token["refresh_token"], grant_type="refresh_token", ) - await update_token(provider.id, new_token) + try: + await update_token(provider.id, new_token) + except PyJWTError as err: + logger.info(f"Cannot refresh token: {err.__class__.__name__}") + raise HTTPException( + status.HTTP_510_NOT_EXTENDED, f"Token refresh error: {err.__class__.__name__}" + ) return RedirectResponse(url=request.url_for("home")) From 3f945310a4aba2b898c7647728b9840dda9796c1 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Thu, 20 Feb 2025 03:20:09 +0100 Subject: [PATCH 24/46] Cosmetic --- src/oidc_test/templates/home.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 7d2b1db..3c1ff3c 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -70,7 +70,7 @@ --> {% if resources %} <p> - This application provides all these resources, eventually protected with roles: + This application provides all these resources, eventually protected with scope or roles: </p> <div class="links-to-check"> {% for name, resource in resources.items() %} From 347c3953943a1e0b7d84d5555727fae2f1d104be Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Thu, 20 Feb 2025 21:16:43 +0100 Subject: [PATCH 25/46] Fix auth provider resources --- src/oidc_test/main.py | 25 +++++++++++++------------ src/oidc_test/resource_server.py | 25 ++++++++++++++++++++----- src/oidc_test/templates/home.html | 29 +++++++++++++++++++++++------ 3 files changed, 56 insertions(+), 23 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 8808562..54d69c5 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -108,7 +108,6 @@ async def home( "show_token": settings.show_token, "user": user, "now": datetime.now(), - "auth_provider": provider, } if provider is None or token is None: context["providers"] = providers @@ -117,26 +116,28 @@ async def home( context["access_token_parsed"] = None context["refresh_token_parsed"] = None context["resources"] = None + context["auth_provider"] = None else: + context["auth_provider"] = provider context["access_token"] = token["access_token"] try: access_token_parsed = provider.decode(token["access_token"], verify_signature=False) + context["access_token_parsed"] = access_token_parsed except PyJWTError as err: access_token_parsed = {"Cannot parse": err.__class__.__name__} try: - context["access_token_scope"] = access_token_parsed["scope"] - except KeyError: - context["access_token_scope"] = None - context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False) - context["access_token_parsed"] = access_token_parsed + id_token_parsed = provider.decode(token["id_token"], verify_signature=False) + context["id_token_parsed"] = id_token_parsed + except PyJWTError as err: + id_token_parsed = {"Cannot parse": err.__class__.__name__} + try: + refresh_token_parsed = provider.decode(token["refresh_token"], verify_signature=False) + context["refresh_token_parsed"] = refresh_token_parsed + except PyJWTError as err: + refresh_token_parsed = {"Cannot parse": err.__class__.__name__} + context["access_token_scope"] = access_token_parsed.get("scope") context["resources"] = registry.resources context["resource_providers"] = provider.resource_providers - try: - context["refresh_token_parsed"] = provider.decode( - token["refresh_token"], verify_signature=False - ) - except PyJWTError as err: - context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__} return templates.TemplateResponse(name="home.html", request=request, context=context) diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 604052c..ddc5762 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -1,9 +1,10 @@ from typing import Annotated, Any import logging +from json import JSONDecodeError from authlib.oauth2.rfc6749 import OAuth2Token -from httpx import AsyncClient -from jwt.exceptions import ExpiredSignatureError, InvalidTokenError +from httpx import AsyncClient, HTTPError +from jwt.exceptions import DecodeError, ExpiredSignatureError, InvalidTokenError from fastapi import FastAPI, HTTPException, Depends, Request, status from fastapi.middleware.cors import CORSMiddleware @@ -84,6 +85,10 @@ async def get_resource( "auth_provider": user.auth_provider_id, }, ) + except HTTPError as err: + raise HTTPException( + status.HTTP_503_SERVICE_UNAVAILABLE, err.__class__.__name__ + ) except Exception as err: raise HTTPException( status.HTTP_500_INTERNAL_SERVER_ERROR, err.__class__.__name__ @@ -151,7 +156,7 @@ async def get_auth_provider_resource( ) -> ProcessResult: if token is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth token") - access_token = token["access_token"] + access_token = token async with AsyncClient() as client: resp = await client.get( url=provider.get_resource_url(resource_name), @@ -165,9 +170,19 @@ async def get_auth_provider_resource( # Only a demo, real application would really process the response resp_length = len(resp.text) if resp_length > 1024: - return ProcessResult(msg=f"The resource is too long ({resp_length} bytes) to show here") + return ProcessResult( + msg=f"The resource is too long ({resp_length} bytes) to show in this demo, here is just the begining in raw format", + start=resp.text[:100] + "...", + ) else: - return ProcessResult(**resp.json()) + try: + resp_json = resp.json() + except JSONDecodeError: + return ProcessResult(msg="The resource is not formatted in JSON", text=resp.text) + if isinstance(resp_json, dict): + return ProcessResult(**resp.json()) + elif isinstance(resp_json, list): + return ProcessResult(**{str(i): line for i, line in enumerate(resp_json)}) # @resource_server.get("/public") diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 3c1ff3c..b4460ee 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -66,12 +66,8 @@ {% endif %} <hr> <div class="content"> - <!-- - --> {% if resources %} - <p> - This application provides all these resources, eventually protected with scope or roles: - </p> + <p>This application provides all these resources, eventually protected with scope or roles:</p> <div class="links-to-check"> {% for name, resource in resources.items() %} {% if resource.default_resource_id %} @@ -91,8 +87,29 @@ {% endfor %} </div> {% endif %} + {% if auth_provider.resources %} + <p>{{ auth_provider.name }} is also defined as a provider for these resources:</p> + <div class="links-to-check"> + {% for resource in auth_provider.resources %} + {% if resource.default_resource_id %} + <button resource-name="{{ resource.resource_name }}" + resource-id="{{ resource.default_resource_id }}" + onclick="get_resource('{{ resource.resource_name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource.default_resource_id }}')" + > + {{ resource.name }} + </button> + {% else %} + <button resource-name="{{ resource.name }}" + onclick="get_resource('{{ resource.resource_name }}', '{{ access_token }}', '{{ auth_provider.id }}')" + > + {{ resource.name }} + </button> + {% endif %} + {% endfor %} + </div> + {% endif %} {% if resource_providers %} - <p>{{ auth_provider.name }} allows this applicaiton to request resources from third party resource providers:</p> + <p>{{ auth_provider.name }} allows this application to request resources from third party resource providers:</p> {% for resource_provider in resource_providers %} <div class="links-to-check"> {{ resource_provider.name }} From 4c2b197850a1e4e16f7b4d690eeb0ecd123422c5 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Sat, 22 Feb 2025 14:02:05 +0100 Subject: [PATCH 26/46] Cosmetic --- src/oidc_test/auth/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py index 7dd0e3d..134131e 100644 --- a/src/oidc_test/auth/utils.py +++ b/src/oidc_test/auth/utils.py @@ -5,10 +5,8 @@ import logging from fastapi import HTTPException, Request, Depends, status from fastapi.security import OAuth2PasswordBearer from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App -from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError - -# from authlib.oauth1.auth import OAuthToken from authlib.oauth2.rfc6749 import OAuth2Token +from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError from oidc_test.auth.provider import Provider from oidc_test.models import User From f6a84fd3aaef82460c3a701ad4dc451aeb7ac73e Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Sat, 22 Feb 2025 18:57:25 +0100 Subject: [PATCH 27/46] Cosmetic --- src/oidc_test/auth/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py index 134131e..c51b039 100644 --- a/src/oidc_test/auth/utils.py +++ b/src/oidc_test/auth/utils.py @@ -20,7 +20,7 @@ logger = logging.getLogger("oidc-test") async def fetch_token(name, request): assert name is not None assert request is not None - logger.warn("TODO: fetch_token") + logger.warning("TODO: fetch_token") ... # if name in oidc_providers: # model = OAuth2Token @@ -32,7 +32,10 @@ async def fetch_token(name, request): async def update_token( - provider_id, token, refresh_token: str | None = None, access_token: str | None = None + provider_id, + token, + refresh_token: str | None = None, + access_token: str | None = None, ): """Update the token in the database""" provider = providers[provider_id] From 850db9f59035645cb8530625388ae42ec69103c9 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Sun, 23 Feb 2025 16:37:47 +0100 Subject: [PATCH 28/46] Fix scope cannot be determined when the access token cannot be decoded --- src/oidc_test/main.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 54d69c5..e5238c8 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -123,19 +123,20 @@ async def home( try: access_token_parsed = provider.decode(token["access_token"], verify_signature=False) context["access_token_parsed"] = access_token_parsed + context["access_token_scope"] = access_token_parsed.get("scope") except PyJWTError as err: - access_token_parsed = {"Cannot parse": err.__class__.__name__} + context["access_token_parsed"] = {"Cannot parse": err.__class__.__name__} + context["access_token_scope"] = None try: id_token_parsed = provider.decode(token["id_token"], verify_signature=False) context["id_token_parsed"] = id_token_parsed except PyJWTError as err: - id_token_parsed = {"Cannot parse": err.__class__.__name__} + context["id_token_parsed"] = {"Cannot parse": err.__class__.__name__} try: refresh_token_parsed = provider.decode(token["refresh_token"], verify_signature=False) context["refresh_token_parsed"] = refresh_token_parsed except PyJWTError as err: - refresh_token_parsed = {"Cannot parse": err.__class__.__name__} - context["access_token_scope"] = access_token_parsed.get("scope") + context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__} context["resources"] = registry.resources context["resource_providers"] = provider.resource_providers return templates.TemplateResponse(name="home.html", request=request, context=context) From 5f429797ff7a8656dee9dfc9cbc156dfe27f9c8f Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Sun, 23 Feb 2025 17:14:04 +0100 Subject: [PATCH 29/46] Fix auto check of auth provider resource (resource_name in template) --- src/oidc_test/templates/home.html | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index b4460ee..167616f 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -93,13 +93,13 @@ {% for resource in auth_provider.resources %} {% if resource.default_resource_id %} <button resource-name="{{ resource.resource_name }}" - resource-id="{{ resource.default_resource_id }}" + resource-id="{{ resource.default_resource_id }}" onclick="get_resource('{{ resource.resource_name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource.default_resource_id }}')" > {{ resource.name }} </button> {% else %} - <button resource-name="{{ resource.name }}" + <button resource-name="{{ resource.resource_name }}" onclick="get_resource('{{ resource.resource_name }}', '{{ access_token }}', '{{ auth_provider.id }}')" > {{ resource.name }} From 9249885c8080a9afaca61d22022cd9be2e6bd8bd Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Mon, 24 Feb 2025 03:29:23 +0100 Subject: [PATCH 30/46] Update README (config example) --- README.md | 60 ++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 9e00474..68f335d 100644 --- a/README.md +++ b/README.md @@ -52,31 +52,59 @@ given by the OIDC providers. For example: ```yaml -oidc: - secret_key: "ASecretNoOneKnows" - show_session_details: yes +secret_key: AVeryWellKeptSecret +debug_token: no +show_token: yes +log: yes + +auth: providers: - id: auth0 name: Okta / Auth0 - url: "https://<your_auth0_app_URL>" - client_id: "<your_auth0_client_id>" - client_secret: "client_secret_generated_by_auth0" - hint: "A hint for test credentials" + url: https://<your_auth0_app_URL> + public_key_url: https://<your_auth0_app_URL>/pem + client_id: <your_auth0_client_id> + client_secret: client_secret_generated_by_auth0 + hint: A hint for test credentials - id: keycloak name: Keycloak at somewhere - url: "https://<the_keycloak_realm_url>" - account_url_template: "/account" - client_id: "<your_keycloak_client_id>" - client_secret: "client_secret_generated_by_keycloak" - hint: "User: foo, password: foofoo" + url: https://<the_keycloak_realm_url> + info_url: https://philo.ydns.eu/auth/realms/test + account_url_template: /account + client_id: <your_keycloak_client_id> + client_secret: <client_secret_generated_by_keycloak> + hint: A hint for test credentials + code_challenge_method: S256 + resource_provider_scopes: + - get:time + - get:bs + resource_providers: + - id: <third_party_resource_provider_id> + name: A third party resource provider + base_url: https://some.example.com/ + verify_ssl: yes + resources: + - name: Public RS2 + resource_name: public + url: resource/public + - name: BS RS2 + resource_name: bs + url: resource/bs + - name: Time RS2 + resource_name: time + url: resource/time - id: codeberg + disabled: no name: Codeberg - url: "https://codeberg.org" - account_url_template: "/user/settings" - client_id: "<your_codeberg_client_id>" - client_secret: "client_secret_generated_by_codeberg" + url: https://codeberg.org + account_url_template: /user/settings + client_id: <your_codeberg_client_id> + client_secret: client_secret_generated_by_codeberg + info_url: https://codeberg.org/login/oauth/keys + session_key: sub + skip_verify_signature: no resources: - name: List of repos id: repos From 395ec1c7f757ba919224c781fb76f53a0204ea34 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Mon, 24 Feb 2025 19:56:00 +0100 Subject: [PATCH 31/46] Dynamic versioning --- pyproject.toml | 15 +++++++++++---- uv.lock | 16 +++++++++++++++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b1e6504..9c205e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "oidc-fastapi-test" -version = "0.0.0" -# dynamic = ["version"] +#version = "0.0.0" +dynamic = ["version"] description = "Add your description here" readme = "README.md" requires-python = ">=3.13" @@ -24,12 +24,19 @@ dependencies = [ oidc-test = "oidc_test.main:main" [dependency-groups] -dev = ["ipdb>=0.13.13", "pytest>=8.3.4"] +dev = [ + "dunamai>=1.23.0", + "ipdb>=0.13.13", + "pytest>=8.3.4", +] [build-system] -requires = ["hatchling"] +requires = ["hatchling", "uv-dynamic-versioning"] build-backend = "hatchling.build" +[tool.hatch.version] +source = "uv-dynamic-versioning" + [tool.hatch.build.targets.wheel] packages = ["src/oidc_test"] diff --git a/uv.lock b/uv.lock index 01b64de..0566bb5 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.13" [[package]] @@ -206,6 +207,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632 }, ] +[[package]] +name = "dunamai" +version = "1.23.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/4e/a5c8c337a1d9ac0384298ade02d322741fb5998041a5ea74d1cd2a4a1d47/dunamai-1.23.0.tar.gz", hash = "sha256:a163746de7ea5acb6dacdab3a6ad621ebc612ed1e528aaa8beedb8887fccd2c4", size = 44681 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/4c/963169386309fec4f96fd61210ac0a0666887d0fb0a50205395674d20b71/dunamai-1.23.0-py3-none-any.whl", hash = "sha256:a0906d876e92441793c6a423e16a4802752e723e9c9a5aabdc5535df02dbe041", size = 26342 }, +] + [[package]] name = "ecdsa" version = "0.19.0" @@ -482,7 +495,6 @@ wheels = [ [[package]] name = "oidc-fastapi-test" -version = "0.0.0" source = { editable = "." } dependencies = [ { name = "authlib" }, @@ -501,6 +513,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "dunamai" }, { name = "ipdb" }, { name = "pytest" }, ] @@ -523,6 +536,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "dunamai", specifier = ">=1.23.0" }, { name = "ipdb", specifier = ">=0.13.13" }, { name = "pytest", specifier = ">=8.3.4" }, ] From ef7c265d8e4b70d283d9d7d161165003b62ae7fa Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 00:38:43 +0100 Subject: [PATCH 32/46] Cleanup pyproject --- pyproject.toml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9c205e7..f634fbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,11 +24,7 @@ dependencies = [ oidc-test = "oidc_test.main:main" [dependency-groups] -dev = [ - "dunamai>=1.23.0", - "ipdb>=0.13.13", - "pytest>=8.3.4", -] +dev = ["dunamai>=1.23.0", "ipdb>=0.13.13", "pytest>=8.3.4"] [build-system] requires = ["hatchling", "uv-dynamic-versioning"] @@ -39,6 +35,12 @@ source = "uv-dynamic-versioning" [tool.hatch.build.targets.wheel] packages = ["src/oidc_test"] +package = true + +[tool.uv-dynamic-versioning] +metadata = true +dirty = true +format = semver [tool.uv] package = true From 9c1f84328399104e8759a5d8c2f12d4a7f4113b7 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 00:40:33 +0100 Subject: [PATCH 33/46] Cleanup pyproject --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f634fbe..1e23a8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ package = true [tool.uv-dynamic-versioning] metadata = true dirty = true -format = semver +style = "semver" [tool.uv] package = true From 3da485c945e221c73721f71540e0c57c12f7c22b Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 00:41:36 +0100 Subject: [PATCH 34/46] Cleanup pyproject --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1e23a8d..8770c8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,6 @@ package = true [tool.uv-dynamic-versioning] metadata = true -dirty = true style = "semver" [tool.uv] From 9c462379051155cdf834dfd64309cc97da348f4c Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 01:37:17 +0100 Subject: [PATCH 35/46] Semver versioning, show version on web page --- pyproject.toml | 1 - src/oidc_test/__init__.py | 11 +++++++++++ src/oidc_test/main.py | 2 ++ src/oidc_test/static/styles.css | 6 ++++++ src/oidc_test/templates/base.html | 1 + 5 files changed, 20 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8770c8d..c44e9f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,6 @@ packages = ["src/oidc_test"] package = true [tool.uv-dynamic-versioning] -metadata = true style = "semver" [tool.uv] diff --git a/src/oidc_test/__init__.py b/src/oidc_test/__init__.py index e69de29..f449e2b 100644 --- a/src/oidc_test/__init__.py +++ b/src/oidc_test/__init__.py @@ -0,0 +1,11 @@ +import importlib.metadata + +try: + from dunamai import Version, Style + + __version__ = Version.from_git().serialize(style=Style.SemVer, dirty=True) +except ImportError: + # __name__ could be used if the package name is the same + # as the directory - not the case here + # __version__ = importlib.metadata.version(__name__) + __version__ = importlib.metadata.version("oidc-fastapi-test") diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index e5238c8..e882cda 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -29,6 +29,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair +from oidc_test import __version__ from oidc_test.registry import registry from oidc_test.auth.provider import NoPublicKey, Provider from oidc_test.auth.utils import ( @@ -108,6 +109,7 @@ async def home( "show_token": settings.show_token, "user": user, "now": datetime.now(), + "__version__": __version__, } if provider is None or token is None: context["providers"] = providers diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index 2baa748..1e8dc03 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -21,6 +21,12 @@ hr { .hidden { display: none; } +.version { + position: absolute; + font-size: 75%; + top: 0.3em; + right: 0.3em; +} .center { text-align: center; } diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 4cb56f5..157e26f 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -5,6 +5,7 @@ <script src="{{ url_for('static', path='/utils.js') }}"></script> </head> <body onload="checkPerms('links-to-check', '{{ access_token }}', '{{ auth_provider.id }}')"> + <div class="version">v. {{ __version__}}</div> <h1>OIDC-test - FastAPI client</h1> {% block content %} {% endblock %} From b4653947660be16fbfee161fc4c52e2f4747f0f9 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 01:42:49 +0100 Subject: [PATCH 36/46] CI: WIP --- .forgejo/workflows/build.yaml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index 352a0a9..c58230e 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -27,6 +27,15 @@ jobs: - name: Run tests (API call) run: .venv/bin/pytest -s tests/basic.py + - name: Get version + uses: mtkennerly/dunamai-action@v1 + with: + env-var: VERSION + args: --style semver + + - name: Version + run: echo $VERSION + - name: Get version with git describe id: version run: | From f4b38e1c69d3a1f288ddba1b62ae84ae4184d362 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 02:20:35 +0100 Subject: [PATCH 37/46] CI: use dunamai for version --- .forgejo/workflows/build.yaml | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index c58230e..9f21750 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -30,32 +30,29 @@ jobs: - name: Get version uses: mtkennerly/dunamai-action@v1 with: - env-var: VERSION args: --style semver + env-var: VERSION - name: Version run: echo $VERSION - - name: Get version with git describe - id: version - run: | - echo "version=$(git describe)" >> $GITHUB_OUTPUT - echo "$VERSION" + - name: Get distance from tag + uses: mtkennerly/dunamai-action@v1 + with: + args: --format "{distance}" + env-var: DISTANCE - - name: Check if the container should be built + - name: Distance + run: echo $DISTANCE + + - name: Check if the container should be built (distance from git tag is 0, or force build) id: builder env: - RUN: ${{ toJSON(inputs.build || !contains(steps.version.outputs.version, '-')) }} + RUN: ${{ toJSON(inputs.build || env.DISTANCE == "0" }} run: | echo "run=$RUN" >> $GITHUB_OUTPUT echo "Run build: $RUN" - - name: Set the version in pyproject.toml (workaround for uv not supporting dynamic version) - if: fromJSON(steps.builder.outputs.run) - env: - VERSION: ${{ steps.version.outputs.version }} - run: sed "s/0.0.0/$VERSION/" -i pyproject.toml - - name: Workaround for bug of podman-login if: fromJSON(steps.builder.outputs.run) run: | From 6f060dc2bfd726eec57607ef7ca911de22aa4177 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 02:26:37 +0100 Subject: [PATCH 38/46] CI: bump uv --- .forgejo/workflows/build.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index 9f21750..18a89ee 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -19,7 +19,7 @@ jobs: - name: Install the latest version of uv uses: astral-sh/setup-uv@v4 with: - version: "0.5.16" + version: "0.6.3" - name: Install run: uv sync From 22d0a9852c3c97ffeac9a6cdbff5e4f31ee6e6a0 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 03:04:14 +0100 Subject: [PATCH 39/46] CI: not use dunamai github action as it uses plain pip, not uv pip --- .forgejo/workflows/build.yaml | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index 18a89ee..8284c15 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -28,19 +28,13 @@ jobs: run: .venv/bin/pytest -s tests/basic.py - name: Get version - uses: mtkennerly/dunamai-action@v1 - with: - args: --style semver - env-var: VERSION + run: echo "VERSION=$(.venv/bin/dunamai --style semver)" >> $GITHUB_ENV - name: Version run: echo $VERSION - name: Get distance from tag - uses: mtkennerly/dunamai-action@v1 - with: - args: --format "{distance}" - env-var: DISTANCE + run: echo "DISTANCE=$(.venv/bin/dunamai --format '{distance}')" >> $GITHUB_ENV - name: Distance run: echo $DISTANCE From 9f7b0902739534127c9e5890b40c9be7e42dfac7 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 03:12:46 +0100 Subject: [PATCH 40/46] CI: WIP --- .forgejo/workflows/build.yaml | 6 +++--- .forgejo/workflows/test.yaml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index 8284c15..645d5d2 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -28,13 +28,13 @@ jobs: run: .venv/bin/pytest -s tests/basic.py - name: Get version - run: echo "VERSION=$(.venv/bin/dunamai --style semver)" >> $GITHUB_ENV + run: echo "VERSION=$(.venv/bin/dunamai from any --style semver)" >> $GITHUB_ENV - name: Version run: echo $VERSION - name: Get distance from tag - run: echo "DISTANCE=$(.venv/bin/dunamai --format '{distance}')" >> $GITHUB_ENV + run: echo "DISTANCE=$(.venv/bin/dunamai from any --format '{distance}')" >> $GITHUB_ENV - name: Distance run: echo $DISTANCE @@ -42,7 +42,7 @@ jobs: - name: Check if the container should be built (distance from git tag is 0, or force build) id: builder env: - RUN: ${{ toJSON(inputs.build || env.DISTANCE == "0" }} + RUN: ${{ toJSON(inputs.build || env.DISTANCE == "0") }} run: | echo "run=$RUN" >> $GITHUB_OUTPUT echo "Run build: $RUN" diff --git a/.forgejo/workflows/test.yaml b/.forgejo/workflows/test.yaml index a56a9ce..f4d994e 100644 --- a/.forgejo/workflows/test.yaml +++ b/.forgejo/workflows/test.yaml @@ -19,7 +19,7 @@ jobs: - name: Install the latest version of uv uses: astral-sh/setup-uv@v4 with: - version: "0.5.16" + version: "0.6.3" - name: Install run: uv sync From 821df027587eaee24b3db792033266564f66c62c Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 04:28:04 +0100 Subject: [PATCH 41/46] CI: WIP --- .forgejo/workflows/build.yaml | 32 +++++++++++--------------------- 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index 645d5d2..cf079d9 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -28,33 +28,23 @@ jobs: run: .venv/bin/pytest -s tests/basic.py - name: Get version - run: echo "VERSION=$(.venv/bin/dunamai from any --style semver)" >> $GITHUB_ENV - - - name: Version - run: echo $VERSION + run: | + echo "VERSION=$(.venv/bin/dunamai from any --style semver)" >> $GITHUB_ENV + echo $VERSION - name: Get distance from tag - run: echo "DISTANCE=$(.venv/bin/dunamai from any --format '{distance}')" >> $GITHUB_ENV - - - name: Distance - run: echo $DISTANCE - - - name: Check if the container should be built (distance from git tag is 0, or force build) - id: builder - env: - RUN: ${{ toJSON(inputs.build || env.DISTANCE == "0") }} run: | - echo "run=$RUN" >> $GITHUB_OUTPUT - echo "Run build: $RUN" + echo "DISTANCE=$(.venv/bin/dunamai from any --format '{distance}')" >> $GITHUB_ENV + echo $DISTANCE - name: Workaround for bug of podman-login - if: fromJSON(steps.builder.outputs.run) + if: env.DISTANCE == '0' run: | mkdir -p $HOME/.docker echo "{ \"auths\": {} }" > $HOME/.docker/config.json - name: Log in to the container registry (with another workaround) - if: fromJSON(steps.builder.outputs.run) + if: env.DISTANCE == '0' uses: actions/podman-login@v1 with: registry: ${{ vars.REGISTRY }} @@ -63,7 +53,7 @@ jobs: auth_file_path: /tmp/auth.json - name: Build the container image - if: fromJSON(steps.builder.outputs.run) + if: env.DISTANCE == '0' uses: actions/buildah-build@v1 with: image: oidc-fastapi-test @@ -74,7 +64,7 @@ jobs: ./Containerfile - name: Push the image to the registry - if: fromJSON(steps.builder.outputs.run) + if: env.DISTANCE == '0' uses: actions/push-to-registry@v2 with: registry: "docker://${{ vars.REGISTRY }}/${{ vars.ORGANISATION }}" @@ -82,11 +72,11 @@ jobs: tags: latest ${{ steps.version.outputs.version }} - name: Build wheel - if: fromJSON(steps.builder.outputs.run) + if: env.DISTANCE == '0' run: uv build --wheel - name: Publish Python package (home) - if: fromJSON(steps.builder.outputs.run) + if: env.DISTANCE == '0' env: LOCAL_PYPI_TOKEN: ${{ secrets.LOCAL_PYPI_TOKEN }} run: uv publish --publish-url https://code.philo.ydns.eu/api/packages/philorg/pypi --token $LOCAL_PYPI_TOKEN From c5b1bdeda92d885749c38360a81572f2ef147087 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 04:31:31 +0100 Subject: [PATCH 42/46] CI: WIP --- .forgejo/workflows/build.yaml | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index cf079d9..06c8be1 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -28,14 +28,16 @@ jobs: run: .venv/bin/pytest -s tests/basic.py - name: Get version - run: | - echo "VERSION=$(.venv/bin/dunamai from any --style semver)" >> $GITHUB_ENV - echo $VERSION + run: echo "VERSION=$(.venv/bin/dunamai from any --style semver)" >> $GITHUB_ENV + + - name: Version + run: echo $VERSION - name: Get distance from tag - run: | - echo "DISTANCE=$(.venv/bin/dunamai from any --format '{distance}')" >> $GITHUB_ENV - echo $DISTANCE + run: echo "DISTANCE=$(.venv/bin/dunamai from any --format '{distance}')" >> $GITHUB_ENV + + - name: Distance + run: echo $DISTANCE - name: Workaround for bug of podman-login if: env.DISTANCE == '0' From c3ebad42d575b91d221f13a3b93d9cbbbcd80af8 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 04:34:19 +0100 Subject: [PATCH 43/46] CI: WIP --- .forgejo/workflows/build.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index 06c8be1..e617190 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -61,7 +61,7 @@ jobs: image: oidc-fastapi-test oci: true labels: oidc-fastapi-test - tags: latest ${{ steps.version.outputs.version }} + tags: latest env.VERSION containerfiles: | ./Containerfile @@ -71,7 +71,7 @@ jobs: with: registry: "docker://${{ vars.REGISTRY }}/${{ vars.ORGANISATION }}" image: oidc-fastapi-test - tags: latest ${{ steps.version.outputs.version }} + tags: latest env.VERSION - name: Build wheel if: env.DISTANCE == '0' From 4355e6dc423b0db8ac5ca4746287dc20fd8c8d39 Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 12:30:23 +0100 Subject: [PATCH 44/46] CI: WIP --- Containerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Containerfile b/Containerfile index 2e3fd28..0ec45d1 100644 --- a/Containerfile +++ b/Containerfile @@ -1,4 +1,4 @@ -FROM docker.io/library/python:alpine +FROM docker.io/library/python:latest COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/ From b01f2332086a08b4ad7c619d8fcb1a67e9a8e99c Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Tue, 25 Feb 2025 18:34:52 +0100 Subject: [PATCH 45/46] Add log messages for debugging connection to auth server --- src/oidc_test/auth/provider.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/oidc_test/auth/provider.py b/src/oidc_test/auth/provider.py index c614805..ce288a6 100644 --- a/src/oidc_test/auth/provider.py +++ b/src/oidc_test/auth/provider.py @@ -61,28 +61,34 @@ class Provider(AuthProviderSettings): if self.info_url is not None: try: provider_info = await client.get(self.info_url) - except Exception: + except Exception as err: + logger.debug("Provider_info: cannot connect") + logger.exception(err) raise NoPublicKey try: self.info = provider_info.json() except JSONDecodeError: + logger.debug("Provider_info: cannot decode json response") raise NoPublicKey if "public_key" in self.info: # For Keycloak try: public_key = str(self.info["public_key"]) except KeyError: + logger.debug("Provider_info: cannot get public_key") raise NoPublicKey elif "keys" in self.info: # For Forgejo/Gitea try: public_key = str(self.info["keys"][0]["n"]) except KeyError: + logger.debug("Provider_info: cannot get key 0.n") raise NoPublicKey if self.public_key_url is not None: resp = await client.get(self.public_key_url) public_key = resp.text if public_key is None: + logger.debug("Provider_info: cannot determine public key") raise NoPublicKey self.public_key = "\n".join( ["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"] From 8b3a339196d92aa36dd17480af0586783bf1121e Mon Sep 17 00:00:00 2001 From: phil <phil.dev@philome.mooo.com> Date: Sat, 22 Mar 2025 01:01:32 +0100 Subject: [PATCH 46/46] CI: fix container tag --- .forgejo/workflows/build.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index e617190..379aaa8 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -19,7 +19,7 @@ jobs: - name: Install the latest version of uv uses: astral-sh/setup-uv@v4 with: - version: "0.6.3" + version: "0.6.9" - name: Install run: uv sync @@ -61,7 +61,7 @@ jobs: image: oidc-fastapi-test oci: true labels: oidc-fastapi-test - tags: latest env.VERSION + tags: "latest ${{ env.VERSION }}" containerfiles: | ./Containerfile @@ -71,7 +71,7 @@ jobs: with: registry: "docker://${{ vars.REGISTRY }}/${{ vars.ORGANISATION }}" image: oidc-fastapi-test - tags: latest env.VERSION + tags: "latest ${{ env.VERSION }}" - name: Build wheel if: env.DISTANCE == '0'