Fix home when token cannot be decoded
All checks were successful
/ build (push) Successful in 5s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-02-04 03:38:33 +01:00
parent fefe44acfe
commit aa86f81358
2 changed files with 32 additions and 6 deletions

View file

@ -15,6 +15,7 @@ from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from jwt import InvalidKeyError, InvalidTokenError
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.base_client import OAuthError from authlib.integrations.base_client import OAuthError
@ -96,13 +97,24 @@ async def home(
else: else:
resources = [] resources = []
oidc_provider_settings = None oidc_provider_settings = None
if user is None:
access_token_scope = None
else:
try:
access_token_scope = user.decode_access_token()["scope"]
except InvalidTokenError as err:
access_token_scope = None
logger.info("Invalid token")
logger.exception(err)
return templates.TemplateResponse( return templates.TemplateResponse(
name="home.html", name="home.html",
request=request, request=request,
context={ context={
"settings": settings.model_dump(), "settings": settings.model_dump(),
"user": user, "user": user,
"access_token_scope": user.access_token_parsed()["scope"] if user else None, "access_token_scope": access_token_scope,
"now": now, "now": now,
"oidc_provider": oidc_provider, "oidc_provider": oidc_provider,
"oidc_provider_settings": oidc_provider_settings, "oidc_provider_settings": oidc_provider_settings,
@ -187,15 +199,22 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
user_info_from_endpoint = {} user_info_from_endpoint = {}
# Build and remember the user in the session # Build and remember the user in the session
request.session["user_sub"] = sub request.session["user_sub"] = sub
# Store the user in the database # Verify the token's signature and validity
try: try:
oidc_provider_settings = oidc_providers_settings[oidc_provider_id] oidc_provider_settings = oidc_providers_settings[oidc_provider_id]
access_token = oidc_provider_settings.decode(token["access_token"]) oidc_provider_settings.decode(token["access_token"])
except Exception: except InvalidKeyError:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
detail="Token invalid key / signature",
)
except Exception as err:
logger.exception(err)
raise HTTPException( raise HTTPException(
status.HTTP_401_UNAUTHORIZED, status.HTTP_401_UNAUTHORIZED,
detail="Cannot decode token or verify its signature", detail="Cannot decode token or verify its signature",
) )
# Store the user in the database
user = await db.add_user( user = await db.add_user(
sub, sub,
user_info=userinfo, user_info=userinfo,

View file

@ -1,3 +1,4 @@
import logging
from functools import cached_property from functools import cached_property
from typing import Self from typing import Self
@ -10,6 +11,8 @@ from pydantic import (
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from sqlmodel import SQLModel, Field from sqlmodel import SQLModel, Field
logger = logging.getLogger(__name__)
class Role(SQLModel, extra="ignore"): class Role(SQLModel, extra="ignore"):
name: str name: str
@ -54,10 +57,14 @@ class User(UserBase):
def has_scope(self, scope: str) -> bool: def has_scope(self, scope: str) -> bool:
"""Check if the scope is present in user info or access token""" """Check if the scope is present in user info or access token"""
info_scopes = self.userinfo.get("scope", "").split(" ") info_scopes = self.userinfo.get("scope", "").split(" ")
access_token_scopes = self.access_token_parsed().get("scope", "").split(" ") try:
access_token_scopes = self.decode_access_token().get("scope", "").split(" ")
except Exception as err:
logger.info(f"Access token cannot be decoded: {err}")
access_token_scopes = []
return scope in set(info_scopes + access_token_scopes) return scope in set(info_scopes + access_token_scopes)
def access_token_parsed(self): def decode_access_token(self):
assert self.access_token is not None assert self.access_token is not None
assert self.oidc_provider is not None assert self.oidc_provider is not None
assert self.oidc_provider.name is not None assert self.oidc_provider.name is not None