From 923a63f5d527e5a4128d193bb9ac7adef956cf84 Mon Sep 17 00:00:00 2001
From: phil
Date: Sat, 8 Feb 2025 18:32:02 +0100
Subject: [PATCH] Add refresh token button
---
src/oidc_test/auth_utils.py | 19 +++++++++++--------
src/oidc_test/main.py | 30 +++++++++++++++++++++++++++---
src/oidc_test/models.py | 9 +++++++--
src/oidc_test/static/utils.js | 1 +
src/oidc_test/templates/base.html | 2 +-
src/oidc_test/templates/home.html | 21 +++++++++++----------
6 files changed, 58 insertions(+), 24 deletions(-)
diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py
index 26f3779..cab14b2 100644
--- a/src/oidc_test/auth_utils.py
+++ b/src/oidc_test/auth_utils.py
@@ -13,7 +13,7 @@ from httpx import AsyncClient
from authlib.oauth2.auth import OAuth2Token
from .models import User
-from .database import TokenNotInDb, db, UserNotInDB
+from .database import db, TokenNotInDb, UserNotInDB
from .settings import oidc_providers_settings
logger = logging.getLogger("oidc-test")
@@ -36,14 +36,14 @@ async def fetch_token(name, request):
async def update_token(name, token, refresh_token=None, access_token=None):
+ """Update the token in the database"""
oidc_provider_settings = oidc_providers_settings[name]
sid: str = oidc_provider_settings.decode(token["id_token"])["sid"]
item = await db.get_token(oidc_provider_settings, sid)
# update old token
- if access_token is not None:
- item["access_token"] = token.get("access_token")
- if refresh_token is not None:
- item["refresh_token"] = refresh_token
+ item["access_token"] = token["access_token"]
+ item["refresh_token"] = token["refresh_token"]
+ item["id_token"] = token["id_token"]
item["expires_at"] = token["expires_at"]
logger.info(f"Token {sid} refreshed")
# It's a fake db and only in memory, so there's nothing to save
@@ -70,8 +70,8 @@ def init_providers():
api_base_url=provider.url,
# For PKCE (not implemented yet):
# code_challenge_method="S256",
- # fetch_token=fetch_token,
- # update_token=update_token,
+ fetch_token=fetch_token,
+ update_token=update_token,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
)
@@ -101,7 +101,10 @@ def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None:
def get_oidc_provider(request: Request) -> StarletteOAuth2App:
if (oidc_provider := get_oidc_provider_or_none(request)) is None:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
+ if oidc_provider is None:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No provider")
+ else:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
else:
return oidc_provider
diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py
index 4a037eb..81b354f 100644
--- a/src/oidc_test/main.py
+++ b/src/oidc_test/main.py
@@ -35,6 +35,8 @@ from .auth_utils import (
authlib_oauth,
get_providers_info,
get_token_or_none,
+ get_token,
+ update_token,
)
from .auth_misc import pretty_details
from .database import TokenNotInDb, db
@@ -97,7 +99,7 @@ async def home(
access_token_scope = None
else:
try:
- access_token_scope = user.decode_access_token()["scope"]
+ access_token_scope = user.get_scope(verify_signature=False)
except InvalidTokenError as err:
access_token_scope = None
logger.info("Invalid token")
@@ -113,15 +115,22 @@ async def home(
"resources": resources,
}
if token is None:
+ context["access_token"] = None
context["id_token_parsed"] = None
context["access_token_parsed"] = None
context["refresh_token_parsed"] = None
else:
+ context["access_token"] = token["access_token"]
assert oidc_provider is not None
assert oidc_provider.name is not None
oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
- context["id_token_parsed"] = pretty_details(user, now)
- context["access_token_parsed"] = oidc_provider_settings.decode(token["access_token"])
+ # context["id_token_parsed"] = pretty_details(user, now)
+ context["id_token_parsed"] = oidc_provider_settings.decode(
+ token["id_token"], verify_signature=False
+ )
+ context["access_token_parsed"] = oidc_provider_settings.decode(
+ token["access_token"], verify_signature=False
+ )
context["refresh_token_parsed"] = oidc_provider_settings.decode(
token["refresh_token"], verify_signature=False
)
@@ -282,6 +291,21 @@ async def non_compliant_logout(
)
+@app.get("/refresh")
+async def refresh(
+ request: Request,
+ oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
+ token: Annotated[OAuth2Token, Depends(get_token)],
+) -> RedirectResponse:
+ """Manually refresh token"""
+ new_token = await oidc_provider.fetch_access_token(
+ refresh_token=token["refresh_token"],
+ grant_type="refresh_token",
+ )
+ await update_token(oidc_provider.name, new_token)
+ return RedirectResponse(url=request.url_for("home"))
+
+
# Snippet for running standalone
# Mostly useful for the --version option,
# as running with uvicorn is easy and provides better flexibility, eg.
diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py
index 9554bd5..8aee2e6 100644
--- a/src/oidc_test/models.py
+++ b/src/oidc_test/models.py
@@ -54,10 +54,15 @@ class User(UserBase):
access_token_scopes = []
return scope in set(info_scopes + access_token_scopes)
- def decode_access_token(self):
+ def decode_access_token(self, verify_signature: bool = True):
assert self.access_token is not None
assert self.oidc_provider is not None
assert self.oidc_provider.name is not None
from .settings import oidc_providers_settings
- return oidc_providers_settings[self.oidc_provider.name].decode(self.access_token)
+ return oidc_providers_settings[self.oidc_provider.name].decode(
+ self.access_token, verify_signature=verify_signature
+ )
+
+ def get_scope(self, verify_signature: bool = True):
+ return self.decode_access_token(verify_signature=verify_signature)["scope"]
diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js
index 6ea8da2..8e8ad59 100644
--- a/src/oidc_test/static/utils.js
+++ b/src/oidc_test/static/utils.js
@@ -2,6 +2,7 @@ async function checkHref(elem, token, authProvider) {
const msg = document.getElementById("msg")
const url = `resource/${elem.getAttribute("resource-id")}`
const resp = await fetch(url, {
+ method: "GET",
headers: new Headers({
"Content-type": "application/json",
"Authorization": `Bearer ${token}`,
diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html
index 0fe1a6b..2ce758c 100644
--- a/src/oidc_test/templates/base.html
+++ b/src/oidc_test/templates/base.html
@@ -4,7 +4,7 @@
-
+
OIDC-test - FastAPI client
{% block content %}
{% endblock %}
diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html
index 9da5392..08bcf43 100644
--- a/src/oidc_test/templates/home.html
+++ b/src/oidc_test/templates/home.html
@@ -57,6 +57,7 @@
Account management
{% endif %}
+
{% endif %}
@@ -66,21 +67,21 @@
Resources validated by scope:
-
-
+
+
Resources validated by role:
-
-
-
-
-
-
-
-
+
+
+
+
+
+
+
+