Get roles from access token, remove user info inspection, refreactorings

This commit is contained in:
phil 2025-02-06 13:30:35 +01:00
parent 5c9ed9724e
commit bc4c4128ad
6 changed files with 126 additions and 97 deletions

View file

@ -13,13 +13,10 @@ from authlib.oauth2.auth import OAuth2Token
from .models import User from .models import User
from .database import TokenNotInDb, db, UserNotInDB from .database import TokenNotInDb, db, UserNotInDB
from .settings import settings, OIDCProvider from .settings import settings, OIDCProvider, oidc_providers_settings
logger = logging.getLogger(__name__) logger = logging.getLogger("oidc-test")
oidc_providers_settings: dict[str, OIDCProvider] = dict(
[(provider.id, provider) for provider in settings.oidc.providers]
)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@ -36,19 +33,16 @@ async def fetch_token(name, request):
async def update_token(name, token, refresh_token=None, access_token=None): async def update_token(name, token, refresh_token=None, access_token=None):
breakpoint() oidc_provider_settings = oidc_providers_settings[name]
item = await db.get_token(token["id_token"]) sid: str = oidc_provider_settings.decode(token["id_token"])["sid"]
if refresh_token: item = await db.get_token(oidc_provider_settings, sid)
item = OAuth2Token.find(name=name, refresh_token=refresh_token)
elif access_token:
item = OAuth2Token.find(name=name, access_token=access_token)
else:
return
# update old token # update old token
item.access_token = token["access_token"] item["access_token"] = token.get("access_token")
item.refresh_token = token.get("refresh_token") item["refresh_token"] = token.get("refresh_token")
item.expires_at = token["expires_at"] item["expires_at"] = token["expires_at"]
item.save() logger.info(f"Token {sid} refreshed")
# It's a fake db and only in memory, so there's nothing to save
# await item.save()
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
@ -138,8 +132,17 @@ async def get_token(request: Request) -> OAuth2Token:
"""Return the token from a request object, from the session. """Return the token from a request object, from the session.
It can be used in Depends()""" It can be used in Depends()"""
try: try:
return await db.get_token(request.session.get("token")) oidc_provider_settings = oidc_providers_settings[
except TokenNotInDb: request.session.get("oidc_provider_id", "")
]
except KeyError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider")
try:
return await db.get_token(
oidc_provider_settings,
request.session.get("sid"),
)
except (TokenNotInDb, InvalidKeyError):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
@ -190,14 +193,16 @@ async def get_user_from_token(
token: Annotated[str, Depends(oauth2_scheme)], token: Annotated[str, Depends(oauth2_scheme)],
request: Request, request: Request,
) -> User: ) -> User:
if (auth_provider_id := request.headers.get("auth_provider")) is None: try:
auth_provider_id = request.headers["auth_provider"]
except KeyError:
raise HTTPException( raise HTTPException(
status.HTTP_401_UNAUTHORIZED, status.HTTP_401_UNAUTHORIZED,
"Request headers must have a 'auth_provider' field", "Request headers must have a 'auth_provider' field",
) )
if ( try:
auth_provider_settings := oidc_providers_settings.get(auth_provider_id) auth_provider_settings = oidc_providers_settings[auth_provider_id]
) is None: except KeyError:
raise HTTPException( raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
) )
@ -216,7 +221,9 @@ async def get_user_from_token(
logger.info("Cannot decode token, see below") logger.info("Cannot decode token, see below")
logger.exception(err) logger.exception(err)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
if (user_id := payload.get("sub")) is None: try:
user_id = payload["sub"]
except KeyError:
raise HTTPException( raise HTTPException(
status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found" status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found"
) )
@ -232,7 +239,6 @@ async def get_user_from_token(
sub=payload["sub"], sub=payload["sub"],
user_info=payload, user_info=payload,
oidc_provider=getattr(authlib_oauth, auth_provider_id), oidc_provider=getattr(authlib_oauth, auth_provider_id),
user_info_from_endpoint={},
access_token=token, access_token=token,
) )
return user return user

View file

@ -3,11 +3,12 @@
import logging import logging
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from .models import User, Role
from authlib.oauth2.rfc6749 import OAuth2Token from authlib.oauth2.rfc6749 import OAuth2Token
logger = logging.getLogger(__name__) from .settings import OIDCProvider, oidc_providers_settings
from .models import User, Role
logger = logging.getLogger("oidc-test")
class UserNotInDB(Exception): class UserNotInDB(Exception):
@ -29,20 +30,34 @@ class Database:
sub: str, sub: str,
user_info: dict, user_info: dict,
oidc_provider: StarletteOAuth2App, oidc_provider: StarletteOAuth2App,
user_info_from_endpoint: dict,
access_token: str, access_token: str,
access_token_decoded: dict | None = None,
) -> User: ) -> User:
user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider) if access_token_decoded is None:
assert oidc_provider.name is not None
oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
access_token_decoded = oidc_provider_settings.decode(access_token)
user = User(**user_info)
user.userinfo = user_info
user.oidc_provider = oidc_provider
user.access_token = access_token user.access_token = access_token
user.access_token_decoded = access_token_decoded
# Add roles provided in the access token
roles = set()
try: try:
raw_roles = user_info_from_endpoint["resource_access"][ r = access_token_decoded["resource_access"][oidc_provider.client_id]["roles"]
oidc_provider.client_id roles.update(r)
]["roles"] except KeyError:
except Exception as err: pass
logger.debug(f"Cannot read additional roles: {err}") try:
raw_roles = [] r = access_token_decoded["realm_access"]["roles"]
for raw_role in raw_roles: if isinstance(r, str):
user.roles.append(Role(name=raw_role)) roles.add(r)
else:
roles.update(r)
except KeyError:
pass
user.roles = [Role(name=role_name) for role_name in roles]
self.users[sub] = user self.users[sub] = user
return user return user
@ -51,14 +66,21 @@ class Database:
raise UserNotInDB raise UserNotInDB
return self.users[sub] return self.users[sub]
async def add_token(self, token: OAuth2Token, user: User) -> None: async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None:
self.tokens[token["id_token"]] = token """Store a token using as key the sid (auth provider's session id)
in the id_token"""
sid = token["userinfo"]["sid"]
self.tokens[sid] = token
async def get_token(self, id_token: str | None) -> OAuth2Token: async def get_token(
if id_token is None: self,
oidc_provider_settings: OIDCProvider,
sid: str | None,
) -> OAuth2Token:
if sid is None:
raise TokenNotInDb raise TokenNotInDb
try: try:
return self.tokens[id_token] return self.tokens[sid]
except KeyError: except KeyError:
raise TokenNotInDb raise TokenNotInDb

View file

@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from jwt import InvalidKeyError, InvalidTokenError from jwt import InvalidTokenError, PyJWTError
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.base_client import OAuthError from authlib.integrations.base_client import OAuthError
@ -26,7 +26,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token
# from fastapi.security import OpenIdConnect # from fastapi.security import OpenIdConnect
# from pkce import generate_code_verifier, generate_pkce_pair # from pkce import generate_code_verifier, generate_pkce_pair
from .settings import settings from .settings import settings, oidc_providers_settings
from .models import User from .models import User
from .auth_utils import ( from .auth_utils import (
get_oidc_provider, get_oidc_provider,
@ -37,14 +37,13 @@ from .auth_utils import (
get_user_from_token, get_user_from_token,
authlib_oauth, authlib_oauth,
get_token, get_token,
oidc_providers_settings,
get_providers_info, get_providers_info,
) )
from .auth_misc import pretty_details from .auth_misc import pretty_details
from .database import TokenNotInDb, db from .database import TokenNotInDb, db
from .resource_server import get_resource from .resource_server import get_resource
logger = logging.getLogger("uvicorn.error") logger = logging.getLogger("oidc-test")
templates = Jinja2Templates(Path(__file__).parent / "templates") templates = Jinja2Templates(Path(__file__).parent / "templates")
@ -189,43 +188,28 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
request.session["oidc_provider_id"] = oidc_provider_id request.session["oidc_provider_id"] = oidc_provider_id
# User id (sub) given by oidc provider # User id (sub) given by oidc provider
sub = userinfo["sub"] sub = userinfo["sub"]
# Get additional data from userinfo endpoint
try:
user_info_from_endpoint = await oidc_provider.userinfo(
token=token, follow_redirects=True
)
except Exception as err:
logger.warn(f"Cannot get userinfo from endpoint: {err}")
user_info_from_endpoint = {}
# Build and remember the user in the session # Build and remember the user in the session
request.session["user_sub"] = sub request.session["user_sub"] = sub
# Verify the token's signature and validity # Store the user in the database, which also verifies the token validity and signature
try: try:
oidc_provider_settings = oidc_providers_settings[oidc_provider_id] user = await db.add_user(
oidc_provider_settings.decode(token["access_token"]) sub,
except InvalidKeyError: user_info=userinfo,
oidc_provider=oidc_provider,
access_token=token["access_token"],
)
except PyJWTError as err:
raise HTTPException( raise HTTPException(
status.HTTP_401_UNAUTHORIZED, status.HTTP_401_UNAUTHORIZED,
detail="Token invalid key / signature", detail=f"Token invalid: {err.__class__.__name__}",
) )
except Exception as err: assert isinstance(user, User)
logger.exception(err) # Add the provider session id to the session
raise HTTPException( request.session["sid"] = userinfo["sid"]
status.HTTP_401_UNAUTHORIZED,
detail="Cannot decode token or verify its signature",
)
# Store the user in the database
user = await db.add_user(
sub,
user_info=userinfo,
oidc_provider=oidc_provider,
user_info_from_endpoint=user_info_from_endpoint,
access_token=token["access_token"],
)
# Add the id_token to the session
request.session["token"] = token["id_token"]
# Add the token to the db because it is used for logout # Add the token to the db because it is used for logout
await db.add_token(token, user) assert oidc_provider.name is not None
oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
await db.add_token(oidc_provider_settings, token)
# Send the user to the home: (s)he is authenticated # Send the user to the home: (s)he is authenticated
return RedirectResponse(url=request.url_for("home")) return RedirectResponse(url=request.url_for("home"))
else: else:
@ -268,8 +252,14 @@ async def logout(
) )
return RedirectResponse(request.url_for("non_compliant_logout")) return RedirectResponse(request.url_for("non_compliant_logout"))
post_logout_uri = request.url_for("home") post_logout_uri = request.url_for("home")
oidc_provider_settings = oidc_providers_settings.get(
request.session.get("oidc_provider_id", "")
)
assert oidc_provider_settings is not None
try: try:
token = await db.get_token(request.session.pop("token", None)) token = await db.get_token(
oidc_provider_settings, request.session.pop("sid", None)
)
except TokenNotInDb: except TokenNotInDb:
logger.warn("No session in db for the token or no token") logger.warn("No session in db for the token or no token")
return RedirectResponse(request.url_for("home")) return RedirectResponse(request.url_for("home"))

View file

@ -1,6 +1,6 @@
import logging import logging
from functools import cached_property from functools import cached_property
from typing import Self from typing import Self, Any
from pydantic import ( from pydantic import (
computed_field, computed_field,
@ -11,7 +11,7 @@ from pydantic import (
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from sqlmodel import SQLModel, Field from sqlmodel import SQLModel, Field
logger = logging.getLogger(__name__) logger = logging.getLogger("oidc-test")
class Role(SQLModel, extra="ignore"): class Role(SQLModel, extra="ignore"):
@ -36,19 +36,9 @@ class User(UserBase):
) )
userinfo: dict = {} userinfo: dict = {}
access_token: str | None = None access_token: str | None = None
access_token_decoded: dict[str, Any] | None = None
oidc_provider: StarletteOAuth2App | None = None oidc_provider: StarletteOAuth2App | None = None
@classmethod
def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self:
user = cls(**userinfo)
user.userinfo = userinfo
user.oidc_provider = oidc_provider
# Add roles if they are provided in the token
if raw_ra := userinfo.get("realm_access"):
if raw_roles := raw_ra.get("roles"):
user.roles = [Role(name=raw_role) for raw_role in raw_roles]
return user
@computed_field @computed_field
@cached_property @cached_property
def roles_as_set(self) -> set[str]: def roles_as_set(self) -> set[str]:
@ -68,7 +58,7 @@ class User(UserBase):
assert self.access_token is not None assert self.access_token is not None
assert self.oidc_provider is not None assert self.oidc_provider is not None
assert self.oidc_provider.name is not None assert self.oidc_provider.name is not None
from .auth_utils import oidc_providers_settings from .settings import oidc_providers_settings
return oidc_providers_settings[self.oidc_provider.name].decode( return oidc_providers_settings[self.oidc_provider.name].decode(
self.access_token self.access_token

View file

@ -8,7 +8,7 @@ from starlette.status import HTTP_401_UNAUTHORIZED
from .models import User from .models import User
logger = logging.getLogger(__name__) logger = logging.getLogger("oidc-test")
async def get_resource(resource_id: str, user: User) -> dict: async def get_resource(resource_id: str, user: User) -> dict:

View file

@ -1,8 +1,9 @@
from os import environ from os import environ
import string import string
import random import random
from typing import Type, Tuple from typing import Type, Tuple, Any
from pathlib import Path from pathlib import Path
import logging
from jwt import decode from jwt import decode
from pydantic import BaseModel, computed_field, AnyUrl from pydantic import BaseModel, computed_field, AnyUrl
@ -16,6 +17,8 @@ from starlette.requests import Request
from .models import User from .models import User
logger = logging.getLogger("oidc-test")
class Resource(BaseModel): class Resource(BaseModel):
"""A resource with an URL that can be accessed with an OAuth2 access token""" """A resource with an URL that can be accessed with an OAuth2 access token"""
@ -86,14 +89,27 @@ class OIDCProvider(BaseModel):
-----END PUBLIC KEY----- -----END PUBLIC KEY-----
""" """
def decode(self, token: str) -> dict: def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]:
"""Decode the token with signature check""" """Decode the token with signature check"""
decoded = decode(
token,
self.get_public_key(),
algorithms=[self.signature_alg],
audience=["account", "oidc-test", "oidc-test-web"],
options={
"verify_signature": False,
"verify_aud": False,
}, # not settings.insecure.skip_verify_signature},
)
logger.debug(str(decoded))
return decode( return decode(
token, token,
self.get_public_key(), self.get_public_key(),
algorithms=[self.signature_alg], algorithms=[self.signature_alg],
audience=["oidc-test", "oidc-test-web"], audience=["account", "oidc-test", "oidc-test-web"],
options={"verify_signature": not settings.insecure.skip_verify_signature}, options={
"verify_signature": verify_signature,
}, # not settings.insecure.skip_verify_signature},
) )
@ -156,3 +172,8 @@ class Settings(BaseSettings):
settings = Settings() settings = Settings()
oidc_providers_settings: dict[str, OIDCProvider] = dict(
[(provider.id, provider) for provider in settings.oidc.providers]
)