Fix token error handling
All checks were successful
/ build (push) Successful in 5s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-02-04 18:03:17 +01:00
parent aa86f81358
commit 31a783cbf1
3 changed files with 11 additions and 8 deletions

View file

@ -113,8 +113,7 @@ async def get_current_user(request: Request) -> User:
""" """
if (user_sub := request.session.get("user_sub")) is None: if (user_sub := request.session.get("user_sub")) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED) raise HTTPException(status.HTTP_401_UNAUTHORIZED)
if (token := await db.get_token(request.session["token"])) is None: token = await get_token(request)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown")
user = await db.get_user(user_sub) user = await db.get_user(user_sub)
## Check if the token is expired ## Check if the token is expired
if token.is_expired(): if token.is_expired():
@ -138,8 +137,8 @@ async def get_token(request: Request) -> OAuth2Token:
"""Return the token from a request object, from the session. """Return the token from a request object, from the session.
It can be used in Depends()""" It can be used in Depends()"""
try: try:
return await db.get_token(request.session["token"]) return await db.get_token(request.session.get("token"))
except (KeyError, TokenNotInDb): except TokenNotInDb:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")

View file

@ -54,7 +54,9 @@ class Database:
async def add_token(self, token: OAuth2Token, user: User) -> None: async def add_token(self, token: OAuth2Token, user: User) -> None:
self.tokens[token["id_token"]] = token self.tokens[token["id_token"]] = token
async def get_token(self, id_token: str) -> OAuth2Token: async def get_token(self, id_token: str | None) -> OAuth2Token:
if id_token is None:
raise TokenNotInDb
try: try:
return self.tokens[id_token] return self.tokens[id_token]
except KeyError: except KeyError:

View file

@ -41,7 +41,7 @@ from .auth_utils import (
get_providers_info, get_providers_info,
) )
from .auth_misc import pretty_details from .auth_misc import pretty_details
from .database import db from .database import TokenNotInDb, db
from .resource_server import get_resource from .resource_server import get_resource
logger = logging.getLogger("uvicorn.error") logger = logging.getLogger("uvicorn.error")
@ -268,8 +268,10 @@ async def logout(
) )
return RedirectResponse(request.url_for("non_compliant_logout")) return RedirectResponse(request.url_for("non_compliant_logout"))
post_logout_uri = request.url_for("home") post_logout_uri = request.url_for("home")
if (token := await db.get_token(request.session.pop("token", None))) is None: try:
logger.warn("No session in db for the token") token = await db.get_token(request.session.pop("token", None))
except TokenNotInDb:
logger.warn("No session in db for the token or no token")
return RedirectResponse(request.url_for("home")) return RedirectResponse(request.url_for("home"))
logout_url = ( logout_url = (
provider_logout_uri provider_logout_uri