From 31a783cbf19be0fe2e78f1fdb053c485d379d3a1 Mon Sep 17 00:00:00 2001 From: phil Date: Tue, 4 Feb 2025 18:03:17 +0100 Subject: [PATCH] Fix token error handling --- src/oidc_test/auth_utils.py | 7 +++---- src/oidc_test/database.py | 4 +++- src/oidc_test/main.py | 8 +++++--- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index d96aba9..fd82ecd 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -113,8 +113,7 @@ async def get_current_user(request: Request) -> User: """ if (user_sub := request.session.get("user_sub")) is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED) - if (token := await db.get_token(request.session["token"])) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown") + token = await get_token(request) user = await db.get_user(user_sub) ## Check if the token is expired if token.is_expired(): @@ -138,8 +137,8 @@ async def get_token(request: Request) -> OAuth2Token: """Return the token from a request object, from the session. It can be used in Depends()""" try: - return await db.get_token(request.session["token"]) - except (KeyError, TokenNotInDb): + return await db.get_token(request.session.get("token")) + except TokenNotInDb: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 1b682ef..0a30e9c 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -54,7 +54,9 @@ class Database: async def add_token(self, token: OAuth2Token, user: User) -> None: self.tokens[token["id_token"]] = token - async def get_token(self, id_token: str) -> OAuth2Token: + async def get_token(self, id_token: str | None) -> OAuth2Token: + if id_token is None: + raise TokenNotInDb try: return self.tokens[id_token] except KeyError: diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index aac258b..e9ba5b1 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -41,7 +41,7 @@ from .auth_utils import ( get_providers_info, ) from .auth_misc import pretty_details -from .database import db +from .database import TokenNotInDb, db from .resource_server import get_resource logger = logging.getLogger("uvicorn.error") @@ -268,8 +268,10 @@ async def logout( ) return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") - if (token := await db.get_token(request.session.pop("token", None))) is None: - logger.warn("No session in db for the token") + try: + token = await db.get_token(request.session.pop("token", None)) + except TokenNotInDb: + logger.warn("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) logout_url = ( provider_logout_uri