Git merge
Some checks failed
/ build (push) Failing after 14s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-01-13 05:45:31 +01:00
commit 831ea063c1
2 changed files with 24 additions and 21 deletions

View file

@ -38,10 +38,7 @@ class Database:
return self.users[sub]
async def add_token(self, token_dict: dict, user: User) -> None:
# FIXME: The tokens are stored with the user.sub key, meaning that
# sessions logged in with different clients simultanously will
# interfer with ezach others.
self.tokens[user.sub] = OAuth2Token.from_dict(token_dict=token_dict, user=user)
self.tokens[token_dict['id_token']] = OAuth2Token.from_dict(token_dict=token_dict, user=user)
async def get_token(self, name) -> OAuth2Token | None:
return self.tokens.get(name)

View file

@ -90,9 +90,10 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
except AttributeError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try:
return await provider_.authorize_redirect(
response = await provider_.authorize_redirect(
request, redirect_uri, access_type="offline"
)
return response
except HTTPError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
@ -118,7 +119,9 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
# eg for updates. Here we are only interested in roles
#
if userinfo := token.get("userinfo"):
# sub given by oidc provider
# Remember the oidc_provider in the session
request.session["oidc_provider_id"] = oidc_provider_id
# User id (sub) given by oidc provider
sub = userinfo["sub"]
# Get additional data from userinfo endpoint
try:
@ -140,8 +143,11 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
oidc_provider=oidc_provider,
user_info_from_endpoint=user_info_from_endpoint,
)
request.session["token"] = userinfo["sub"]
# Add the id_token to the session
request.session["token"] = token["id_token"]
# Add the token to the db because it is used for logout
await db.add_token(token, user)
# Send the user to the home: (s)he is authenticated
return RedirectResponse(url=request.url_for("home"))
else:
# Not sure if it's correct to redirect to plain login
@ -151,19 +157,6 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
)
@app.get("/non-compliant-logout")
async def non_compliant_logout(
request: Request,
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
):
"""A page for non-compliant OAuth2 servers that we cannot log out."""
return templates.TemplateResponse(
name="non_compliant_logout.html",
request=request,
context={"provider": provider, "home_url": request.url_for("home")},
)
@app.get("/logout")
async def logout(
request: Request,
@ -178,7 +171,7 @@ async def logout(
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
return RedirectResponse(request.url_for("non_compliant_logout"))
post_logout_uri = request.url_for("home")
if (id_token := await db.get_token(request.session["token"])) is None:
if (id_token := await db.get_token(request.session.pop("token", None))) is None:
logger.warn("No session in db for the token")
return RedirectResponse(request.url_for("home"))
logout_url = (
@ -195,6 +188,19 @@ async def logout(
return RedirectResponse(logout_url)
@app.get("/non-compliant-logout")
async def non_compliant_logout(
request: Request,
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
):
"""A page for non-compliant OAuth2 servers that we cannot log out."""
return templates.TemplateResponse(
name="non_compliant_logout.html",
request=request,
context={"provider": provider, "home_url": request.url_for("home")},
)
# Home URL