Use id_token for sessions
This commit is contained in:
parent
724887e133
commit
c7e5332e12
2 changed files with 23 additions and 23 deletions
|
@ -22,10 +22,7 @@ class Database:
|
||||||
return self.users[sub]
|
return self.users[sub]
|
||||||
|
|
||||||
async def add_token(self, token_dict: dict, user: User) -> None:
|
async def add_token(self, token_dict: dict, user: User) -> None:
|
||||||
# FIXME: The tokens are stored with the user.sub key, meaning that
|
self.tokens[token_dict['id_token']] = OAuth2Token.from_dict(token_dict=token_dict, user=user)
|
||||||
# sessions logged in with different clients simultanously will
|
|
||||||
# interfer with ezach others.
|
|
||||||
self.tokens[user.sub] = OAuth2Token.from_dict(token_dict=token_dict, user=user)
|
|
||||||
|
|
||||||
async def get_token(self, name) -> OAuth2Token | None:
|
async def get_token(self, name) -> OAuth2Token | None:
|
||||||
return self.tokens.get(name)
|
return self.tokens.get(name)
|
||||||
|
|
|
@ -90,9 +90,10 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||||
try:
|
try:
|
||||||
return await provider_.authorize_redirect(
|
response = await provider_.authorize_redirect(
|
||||||
request, redirect_uri, access_type="offline"
|
request, redirect_uri, access_type="offline"
|
||||||
)
|
)
|
||||||
|
return response
|
||||||
except HTTPError:
|
except HTTPError:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
|
||||||
|
|
||||||
|
@ -110,21 +111,24 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||||
token = await oidc_provider.authorize_access_token(request)
|
token = await oidc_provider.authorize_access_token(request)
|
||||||
except OAuthError as error:
|
except OAuthError as error:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
|
||||||
# Remember the oidc_provider in the session
|
|
||||||
request.session["oidc_provider_id"] = oidc_provider_id
|
|
||||||
#
|
#
|
||||||
# One could process the full decoded token which contains extra information
|
# One could process the full decoded token which contains extra information
|
||||||
# eg for updates. Here we are only interested in roles
|
# eg for updates. Here we are only interested in roles
|
||||||
#
|
#
|
||||||
if userinfo := token.get("userinfo"):
|
if userinfo := token.get("userinfo"):
|
||||||
# sub given by oidc provider
|
# Remember the oidc_provider in the session
|
||||||
|
request.session["oidc_provider_id"] = oidc_provider_id
|
||||||
|
# User id (sub) given by oidc provider
|
||||||
sub = userinfo["sub"]
|
sub = userinfo["sub"]
|
||||||
# Build and remember the user in the session
|
# Build and remember the user in the session
|
||||||
request.session["user_sub"] = sub
|
request.session["user_sub"] = sub
|
||||||
# Store the user in the database
|
# Store the user in the database
|
||||||
user = await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider)
|
user = await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider)
|
||||||
request.session["token"] = userinfo["sub"]
|
# Add the id_token to the session
|
||||||
|
request.session["token"] = token['id_token']
|
||||||
|
# Add the token to the db because it is used for logout
|
||||||
await db.add_token(token, user)
|
await db.add_token(token, user)
|
||||||
|
# Send the user to the home: (s)he is authenticated
|
||||||
return RedirectResponse(url=request.url_for("home"))
|
return RedirectResponse(url=request.url_for("home"))
|
||||||
else:
|
else:
|
||||||
# Not sure if it's correct to redirect to plain login
|
# Not sure if it's correct to redirect to plain login
|
||||||
|
@ -134,19 +138,6 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/non-compliant-logout")
|
|
||||||
async def non_compliant_logout(
|
|
||||||
request: Request,
|
|
||||||
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
|
|
||||||
):
|
|
||||||
"""A page for non-compliant OAuth2 servers that we cannot log out."""
|
|
||||||
return templates.TemplateResponse(
|
|
||||||
name="non_compliant_logout.html",
|
|
||||||
request=request,
|
|
||||||
context={"provider": provider, "home_url": request.url_for("home")},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/logout")
|
@app.get("/logout")
|
||||||
async def logout(
|
async def logout(
|
||||||
request: Request,
|
request: Request,
|
||||||
|
@ -161,7 +152,7 @@ async def logout(
|
||||||
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
|
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
|
||||||
return RedirectResponse(request.url_for("non_compliant_logout"))
|
return RedirectResponse(request.url_for("non_compliant_logout"))
|
||||||
post_logout_uri = request.url_for("home")
|
post_logout_uri = request.url_for("home")
|
||||||
if (id_token := await db.get_token(request.session["token"])) is None:
|
if (id_token := await db.get_token(request.session.pop("token", None))) is None:
|
||||||
logger.warn("No session in db for the token")
|
logger.warn("No session in db for the token")
|
||||||
return RedirectResponse(request.url_for("home"))
|
return RedirectResponse(request.url_for("home"))
|
||||||
logout_url = (
|
logout_url = (
|
||||||
|
@ -178,6 +169,18 @@ async def logout(
|
||||||
return RedirectResponse(logout_url)
|
return RedirectResponse(logout_url)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/non-compliant-logout")
|
||||||
|
async def non_compliant_logout(
|
||||||
|
request: Request,
|
||||||
|
provider: Annotated[StarletteOAuth2App, Depends(get_provider)],
|
||||||
|
):
|
||||||
|
"""A page for non-compliant OAuth2 servers that we cannot log out."""
|
||||||
|
return templates.TemplateResponse(
|
||||||
|
name="non_compliant_logout.html",
|
||||||
|
request=request,
|
||||||
|
context={"provider": provider, "home_url": request.url_for("home")},
|
||||||
|
)
|
||||||
|
|
||||||
# Home URL
|
# Home URL
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue