From 923a63f5d527e5a4128d193bb9ac7adef956cf84 Mon Sep 17 00:00:00 2001 From: phil Date: Sat, 8 Feb 2025 18:32:02 +0100 Subject: [PATCH] Add refresh token button --- src/oidc_test/auth_utils.py | 19 +++++++++++-------- src/oidc_test/main.py | 30 +++++++++++++++++++++++++++--- src/oidc_test/models.py | 9 +++++++-- src/oidc_test/static/utils.js | 1 + src/oidc_test/templates/base.html | 2 +- src/oidc_test/templates/home.html | 21 +++++++++++---------- 6 files changed, 58 insertions(+), 24 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 26f3779..cab14b2 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -13,7 +13,7 @@ from httpx import AsyncClient from authlib.oauth2.auth import OAuth2Token from .models import User -from .database import TokenNotInDb, db, UserNotInDB +from .database import db, TokenNotInDb, UserNotInDB from .settings import oidc_providers_settings logger = logging.getLogger("oidc-test") @@ -36,14 +36,14 @@ async def fetch_token(name, request): async def update_token(name, token, refresh_token=None, access_token=None): + """Update the token in the database""" oidc_provider_settings = oidc_providers_settings[name] sid: str = oidc_provider_settings.decode(token["id_token"])["sid"] item = await db.get_token(oidc_provider_settings, sid) # update old token - if access_token is not None: - item["access_token"] = token.get("access_token") - if refresh_token is not None: - item["refresh_token"] = refresh_token + item["access_token"] = token["access_token"] + item["refresh_token"] = token["refresh_token"] + item["id_token"] = token["id_token"] item["expires_at"] = token["expires_at"] logger.info(f"Token {sid} refreshed") # It's a fake db and only in memory, so there's nothing to save @@ -70,8 +70,8 @@ def init_providers(): api_base_url=provider.url, # For PKCE (not implemented yet): # code_challenge_method="S256", - # fetch_token=fetch_token, - # update_token=update_token, + fetch_token=fetch_token, + update_token=update_token, # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) ) @@ -101,7 +101,10 @@ def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: def get_oidc_provider(request: Request) -> StarletteOAuth2App: if (oidc_provider := get_oidc_provider_or_none(request)) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") + if oidc_provider is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No provider") + else: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") else: return oidc_provider diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 4a037eb..81b354f 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -35,6 +35,8 @@ from .auth_utils import ( authlib_oauth, get_providers_info, get_token_or_none, + get_token, + update_token, ) from .auth_misc import pretty_details from .database import TokenNotInDb, db @@ -97,7 +99,7 @@ async def home( access_token_scope = None else: try: - access_token_scope = user.decode_access_token()["scope"] + access_token_scope = user.get_scope(verify_signature=False) except InvalidTokenError as err: access_token_scope = None logger.info("Invalid token") @@ -113,15 +115,22 @@ async def home( "resources": resources, } if token is None: + context["access_token"] = None context["id_token_parsed"] = None context["access_token_parsed"] = None context["refresh_token_parsed"] = None else: + context["access_token"] = token["access_token"] assert oidc_provider is not None assert oidc_provider.name is not None oidc_provider_settings = oidc_providers_settings[oidc_provider.name] - context["id_token_parsed"] = pretty_details(user, now) - context["access_token_parsed"] = oidc_provider_settings.decode(token["access_token"]) + # context["id_token_parsed"] = pretty_details(user, now) + context["id_token_parsed"] = oidc_provider_settings.decode( + token["id_token"], verify_signature=False + ) + context["access_token_parsed"] = oidc_provider_settings.decode( + token["access_token"], verify_signature=False + ) context["refresh_token_parsed"] = oidc_provider_settings.decode( token["refresh_token"], verify_signature=False ) @@ -282,6 +291,21 @@ async def non_compliant_logout( ) +@app.get("/refresh") +async def refresh( + request: Request, + oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + token: Annotated[OAuth2Token, Depends(get_token)], +) -> RedirectResponse: + """Manually refresh token""" + new_token = await oidc_provider.fetch_access_token( + refresh_token=token["refresh_token"], + grant_type="refresh_token", + ) + await update_token(oidc_provider.name, new_token) + return RedirectResponse(url=request.url_for("home")) + + # Snippet for running standalone # Mostly useful for the --version option, # as running with uvicorn is easy and provides better flexibility, eg. diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 9554bd5..8aee2e6 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -54,10 +54,15 @@ class User(UserBase): access_token_scopes = [] return scope in set(info_scopes + access_token_scopes) - def decode_access_token(self): + def decode_access_token(self, verify_signature: bool = True): assert self.access_token is not None assert self.oidc_provider is not None assert self.oidc_provider.name is not None from .settings import oidc_providers_settings - return oidc_providers_settings[self.oidc_provider.name].decode(self.access_token) + return oidc_providers_settings[self.oidc_provider.name].decode( + self.access_token, verify_signature=verify_signature + ) + + def get_scope(self, verify_signature: bool = True): + return self.decode_access_token(verify_signature=verify_signature)["scope"] diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 6ea8da2..8e8ad59 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -2,6 +2,7 @@ async function checkHref(elem, token, authProvider) { const msg = document.getElementById("msg") const url = `resource/${elem.getAttribute("resource-id")}` const resp = await fetch(url, { + method: "GET", headers: new Headers({ "Content-type": "application/json", "Authorization": `Bearer ${token}`, diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 0fe1a6b..2ce758c 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -4,7 +4,7 @@ - +

OIDC-test - FastAPI client

{% block content %} {% endblock %} diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 9da5392..08bcf43 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -57,6 +57,7 @@ Account management {% endif %} + {% endif %} @@ -66,21 +67,21 @@ Resources validated by scope:

Resources validated by role: