From c7e5332e12c38b508e4337b299e163a504dd67af Mon Sep 17 00:00:00 2001 From: phil Date: Sat, 11 Jan 2025 20:41:33 +0100 Subject: [PATCH] Use id_token for sessions --- src/oidc_test/database.py | 5 +---- src/oidc_test/main.py | 41 +++++++++++++++++++++------------------ 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 1aae7cc..df4966d 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -22,10 +22,7 @@ class Database: return self.users[sub] async def add_token(self, token_dict: dict, user: User) -> None: - # FIXME: The tokens are stored with the user.sub key, meaning that - # sessions logged in with different clients simultanously will - # interfer with ezach others. - self.tokens[user.sub] = OAuth2Token.from_dict(token_dict=token_dict, user=user) + self.tokens[token_dict['id_token']] = OAuth2Token.from_dict(token_dict=token_dict, user=user) async def get_token(self, name) -> OAuth2Token | None: return self.tokens.get(name) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 9b75985..5854f00 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -90,9 +90,10 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") try: - return await provider_.authorize_redirect( + response = await provider_.authorize_redirect( request, redirect_uri, access_type="offline" ) + return response except HTTPError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") @@ -110,21 +111,24 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: token = await oidc_provider.authorize_access_token(request) except OAuthError as error: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) - # Remember the oidc_provider in the session - request.session["oidc_provider_id"] = oidc_provider_id # # One could process the full decoded token which contains extra information # eg for updates. Here we are only interested in roles # if userinfo := token.get("userinfo"): - # sub given by oidc provider + # Remember the oidc_provider in the session + request.session["oidc_provider_id"] = oidc_provider_id + # User id (sub) given by oidc provider sub = userinfo["sub"] # Build and remember the user in the session request.session["user_sub"] = sub # Store the user in the database user = await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider) - request.session["token"] = userinfo["sub"] + # Add the id_token to the session + request.session["token"] = token['id_token'] + # Add the token to the db because it is used for logout await db.add_token(token, user) + # Send the user to the home: (s)he is authenticated return RedirectResponse(url=request.url_for("home")) else: # Not sure if it's correct to redirect to plain login @@ -134,19 +138,6 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: ) -@app.get("/non-compliant-logout") -async def non_compliant_logout( - request: Request, - provider: Annotated[StarletteOAuth2App, Depends(get_provider)], -): - """A page for non-compliant OAuth2 servers that we cannot log out.""" - return templates.TemplateResponse( - name="non_compliant_logout.html", - request=request, - context={"provider": provider, "home_url": request.url_for("home")}, - ) - - @app.get("/logout") async def logout( request: Request, @@ -161,7 +152,7 @@ async def logout( logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") - if (id_token := await db.get_token(request.session["token"])) is None: + if (id_token := await db.get_token(request.session.pop("token", None))) is None: logger.warn("No session in db for the token") return RedirectResponse(request.url_for("home")) logout_url = ( @@ -178,6 +169,18 @@ async def logout( return RedirectResponse(logout_url) +@app.get("/non-compliant-logout") +async def non_compliant_logout( + request: Request, + provider: Annotated[StarletteOAuth2App, Depends(get_provider)], +): + """A page for non-compliant OAuth2 servers that we cannot log out.""" + return templates.TemplateResponse( + name="non_compliant_logout.html", + request=request, + context={"provider": provider, "home_url": request.url_for("home")}, + ) + # Home URL