From aa86f8135843232739df3790aa137b3da4d78f85 Mon Sep 17 00:00:00 2001 From: phil Date: Tue, 4 Feb 2025 03:38:33 +0100 Subject: [PATCH] Fix home when token cannot be decoded --- src/oidc_test/main.py | 27 +++++++++++++++++++++++---- src/oidc_test/models.py | 11 +++++++++-- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index e14b4a8..aac258b 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -15,6 +15,7 @@ from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.templating import Jinja2Templates from fastapi.middleware.cors import CORSMiddleware +from jwt import InvalidKeyError, InvalidTokenError from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.base_client import OAuthError @@ -96,13 +97,24 @@ async def home( else: resources = [] oidc_provider_settings = None + + if user is None: + access_token_scope = None + else: + try: + access_token_scope = user.decode_access_token()["scope"] + except InvalidTokenError as err: + access_token_scope = None + logger.info("Invalid token") + logger.exception(err) + return templates.TemplateResponse( name="home.html", request=request, context={ "settings": settings.model_dump(), "user": user, - "access_token_scope": user.access_token_parsed()["scope"] if user else None, + "access_token_scope": access_token_scope, "now": now, "oidc_provider": oidc_provider, "oidc_provider_settings": oidc_provider_settings, @@ -187,15 +199,22 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: user_info_from_endpoint = {} # Build and remember the user in the session request.session["user_sub"] = sub - # Store the user in the database + # Verify the token's signature and validity try: oidc_provider_settings = oidc_providers_settings[oidc_provider_id] - access_token = oidc_provider_settings.decode(token["access_token"]) - except Exception: + oidc_provider_settings.decode(token["access_token"]) + except InvalidKeyError: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + detail="Token invalid key / signature", + ) + except Exception as err: + logger.exception(err) raise HTTPException( status.HTTP_401_UNAUTHORIZED, detail="Cannot decode token or verify its signature", ) + # Store the user in the database user = await db.add_user( sub, user_info=userinfo, diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index db5d6ad..4b1c064 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -1,3 +1,4 @@ +import logging from functools import cached_property from typing import Self @@ -10,6 +11,8 @@ from pydantic import ( from authlib.integrations.starlette_client.apps import StarletteOAuth2App from sqlmodel import SQLModel, Field +logger = logging.getLogger(__name__) + class Role(SQLModel, extra="ignore"): name: str @@ -54,10 +57,14 @@ class User(UserBase): def has_scope(self, scope: str) -> bool: """Check if the scope is present in user info or access token""" info_scopes = self.userinfo.get("scope", "").split(" ") - access_token_scopes = self.access_token_parsed().get("scope", "").split(" ") + try: + access_token_scopes = self.decode_access_token().get("scope", "").split(" ") + except Exception as err: + logger.info(f"Access token cannot be decoded: {err}") + access_token_scopes = [] return scope in set(info_scopes + access_token_scopes) - def access_token_parsed(self): + def decode_access_token(self): assert self.access_token is not None assert self.oidc_provider is not None assert self.oidc_provider.name is not None