Get roles from access token, remove user info inspection, refreactorings
This commit is contained in:
parent
5c9ed9724e
commit
ee8ba3d2df
6 changed files with 126 additions and 97 deletions
|
@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles
|
|||
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from jwt import InvalidKeyError, InvalidTokenError
|
||||
from jwt import InvalidTokenError, PyJWTError
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from authlib.integrations.base_client import OAuthError
|
||||
|
@ -26,7 +26,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token
|
|||
# from fastapi.security import OpenIdConnect
|
||||
# from pkce import generate_code_verifier, generate_pkce_pair
|
||||
|
||||
from .settings import settings
|
||||
from .settings import settings, oidc_providers_settings
|
||||
from .models import User
|
||||
from .auth_utils import (
|
||||
get_oidc_provider,
|
||||
|
@ -37,14 +37,13 @@ from .auth_utils import (
|
|||
get_user_from_token,
|
||||
authlib_oauth,
|
||||
get_token,
|
||||
oidc_providers_settings,
|
||||
get_providers_info,
|
||||
)
|
||||
from .auth_misc import pretty_details
|
||||
from .database import TokenNotInDb, db
|
||||
from .resource_server import get_resource
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
templates = Jinja2Templates(Path(__file__).parent / "templates")
|
||||
|
||||
|
@ -189,43 +188,28 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
|||
request.session["oidc_provider_id"] = oidc_provider_id
|
||||
# User id (sub) given by oidc provider
|
||||
sub = userinfo["sub"]
|
||||
# Get additional data from userinfo endpoint
|
||||
try:
|
||||
user_info_from_endpoint = await oidc_provider.userinfo(
|
||||
token=token, follow_redirects=True
|
||||
)
|
||||
except Exception as err:
|
||||
logger.warn(f"Cannot get userinfo from endpoint: {err}")
|
||||
user_info_from_endpoint = {}
|
||||
# Build and remember the user in the session
|
||||
request.session["user_sub"] = sub
|
||||
# Verify the token's signature and validity
|
||||
# Store the user in the database, which also verifies the token validity and signature
|
||||
try:
|
||||
oidc_provider_settings = oidc_providers_settings[oidc_provider_id]
|
||||
oidc_provider_settings.decode(token["access_token"])
|
||||
except InvalidKeyError:
|
||||
user = await db.add_user(
|
||||
sub,
|
||||
user_info=userinfo,
|
||||
oidc_provider=oidc_provider,
|
||||
access_token=token["access_token"],
|
||||
)
|
||||
except PyJWTError as err:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token invalid key / signature",
|
||||
detail=f"Token invalid: {err.__class__.__name__}",
|
||||
)
|
||||
except Exception as err:
|
||||
logger.exception(err)
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Cannot decode token or verify its signature",
|
||||
)
|
||||
# Store the user in the database
|
||||
user = await db.add_user(
|
||||
sub,
|
||||
user_info=userinfo,
|
||||
oidc_provider=oidc_provider,
|
||||
user_info_from_endpoint=user_info_from_endpoint,
|
||||
access_token=token["access_token"],
|
||||
)
|
||||
# Add the id_token to the session
|
||||
request.session["token"] = token["id_token"]
|
||||
assert isinstance(user, User)
|
||||
# Add the provider session id to the session
|
||||
request.session["sid"] = userinfo["sid"]
|
||||
# Add the token to the db because it is used for logout
|
||||
await db.add_token(token, user)
|
||||
assert oidc_provider.name is not None
|
||||
oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
|
||||
await db.add_token(oidc_provider_settings, token)
|
||||
# Send the user to the home: (s)he is authenticated
|
||||
return RedirectResponse(url=request.url_for("home"))
|
||||
else:
|
||||
|
@ -268,8 +252,14 @@ async def logout(
|
|||
)
|
||||
return RedirectResponse(request.url_for("non_compliant_logout"))
|
||||
post_logout_uri = request.url_for("home")
|
||||
oidc_provider_settings = oidc_providers_settings.get(
|
||||
request.session.get("oidc_provider_id", "")
|
||||
)
|
||||
assert oidc_provider_settings is not None
|
||||
try:
|
||||
token = await db.get_token(request.session.pop("token", None))
|
||||
token = await db.get_token(
|
||||
oidc_provider_settings, request.session.pop("sid", None)
|
||||
)
|
||||
except TokenNotInDb:
|
||||
logger.warn("No session in db for the token or no token")
|
||||
return RedirectResponse(request.url_for("home"))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue