Get roles from access token, remove user info inspection, refreactorings
All checks were successful
/ build (push) Successful in 5s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-02-06 13:30:35 +01:00
parent 5c9ed9724e
commit ee8ba3d2df
6 changed files with 126 additions and 97 deletions

View file

@ -13,13 +13,10 @@ from authlib.oauth2.auth import OAuth2Token
from .models import User
from .database import TokenNotInDb, db, UserNotInDB
from .settings import settings, OIDCProvider
from .settings import settings, OIDCProvider, oidc_providers_settings
logger = logging.getLogger(__name__)
logger = logging.getLogger("oidc-test")
oidc_providers_settings: dict[str, OIDCProvider] = dict(
[(provider.id, provider) for provider in settings.oidc.providers]
)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
@ -36,19 +33,16 @@ async def fetch_token(name, request):
async def update_token(name, token, refresh_token=None, access_token=None):
breakpoint()
item = await db.get_token(token["id_token"])
if refresh_token:
item = OAuth2Token.find(name=name, refresh_token=refresh_token)
elif access_token:
item = OAuth2Token.find(name=name, access_token=access_token)
else:
return
oidc_provider_settings = oidc_providers_settings[name]
sid: str = oidc_provider_settings.decode(token["id_token"])["sid"]
item = await db.get_token(oidc_provider_settings, sid)
# update old token
item.access_token = token["access_token"]
item.refresh_token = token.get("refresh_token")
item.expires_at = token["expires_at"]
item.save()
item["access_token"] = token.get("access_token")
item["refresh_token"] = token.get("refresh_token")
item["expires_at"] = token["expires_at"]
logger.info(f"Token {sid} refreshed")
# It's a fake db and only in memory, so there's nothing to save
# await item.save()
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
@ -138,8 +132,17 @@ async def get_token(request: Request) -> OAuth2Token:
"""Return the token from a request object, from the session.
It can be used in Depends()"""
try:
return await db.get_token(request.session.get("token"))
except TokenNotInDb:
oidc_provider_settings = oidc_providers_settings[
request.session.get("oidc_provider_id", "")
]
except KeyError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider")
try:
return await db.get_token(
oidc_provider_settings,
request.session.get("sid"),
)
except (TokenNotInDb, InvalidKeyError):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
@ -190,14 +193,16 @@ async def get_user_from_token(
token: Annotated[str, Depends(oauth2_scheme)],
request: Request,
) -> User:
if (auth_provider_id := request.headers.get("auth_provider")) is None:
try:
auth_provider_id = request.headers["auth_provider"]
except KeyError:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
"Request headers must have a 'auth_provider' field",
)
if (
auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
) is None:
try:
auth_provider_settings = oidc_providers_settings[auth_provider_id]
except KeyError:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
)
@ -216,7 +221,9 @@ async def get_user_from_token(
logger.info("Cannot decode token, see below")
logger.exception(err)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
if (user_id := payload.get("sub")) is None:
try:
user_id = payload["sub"]
except KeyError:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found"
)
@ -232,7 +239,6 @@ async def get_user_from_token(
sub=payload["sub"],
user_info=payload,
oidc_provider=getattr(authlib_oauth, auth_provider_id),
user_info_from_endpoint={},
access_token=token,
)
return user

View file

@ -3,11 +3,12 @@
import logging
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from .models import User, Role
from authlib.oauth2.rfc6749 import OAuth2Token
logger = logging.getLogger(__name__)
from .settings import OIDCProvider, oidc_providers_settings
from .models import User, Role
logger = logging.getLogger("oidc-test")
class UserNotInDB(Exception):
@ -29,20 +30,34 @@ class Database:
sub: str,
user_info: dict,
oidc_provider: StarletteOAuth2App,
user_info_from_endpoint: dict,
access_token: str,
access_token_decoded: dict | None = None,
) -> User:
user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider)
if access_token_decoded is None:
assert oidc_provider.name is not None
oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
access_token_decoded = oidc_provider_settings.decode(access_token)
user = User(**user_info)
user.userinfo = user_info
user.oidc_provider = oidc_provider
user.access_token = access_token
user.access_token_decoded = access_token_decoded
# Add roles provided in the access token
roles = set()
try:
raw_roles = user_info_from_endpoint["resource_access"][
oidc_provider.client_id
]["roles"]
except Exception as err:
logger.debug(f"Cannot read additional roles: {err}")
raw_roles = []
for raw_role in raw_roles:
user.roles.append(Role(name=raw_role))
r = access_token_decoded["resource_access"][oidc_provider.client_id]["roles"]
roles.update(r)
except KeyError:
pass
try:
r = access_token_decoded["realm_access"]["roles"]
if isinstance(r, str):
roles.add(r)
else:
roles.update(r)
except KeyError:
pass
user.roles = [Role(name=role_name) for role_name in roles]
self.users[sub] = user
return user
@ -51,14 +66,21 @@ class Database:
raise UserNotInDB
return self.users[sub]
async def add_token(self, token: OAuth2Token, user: User) -> None:
self.tokens[token["id_token"]] = token
async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None:
"""Store a token using as key the sid (auth provider's session id)
in the id_token"""
sid = token["userinfo"]["sid"]
self.tokens[sid] = token
async def get_token(self, id_token: str | None) -> OAuth2Token:
if id_token is None:
async def get_token(
self,
oidc_provider_settings: OIDCProvider,
sid: str | None,
) -> OAuth2Token:
if sid is None:
raise TokenNotInDb
try:
return self.tokens[id_token]
return self.tokens[sid]
except KeyError:
raise TokenNotInDb

View file

@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from jwt import InvalidKeyError, InvalidTokenError
from jwt import InvalidTokenError, PyJWTError
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.base_client import OAuthError
@ -26,7 +26,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token
# from fastapi.security import OpenIdConnect
# from pkce import generate_code_verifier, generate_pkce_pair
from .settings import settings
from .settings import settings, oidc_providers_settings
from .models import User
from .auth_utils import (
get_oidc_provider,
@ -37,14 +37,13 @@ from .auth_utils import (
get_user_from_token,
authlib_oauth,
get_token,
oidc_providers_settings,
get_providers_info,
)
from .auth_misc import pretty_details
from .database import TokenNotInDb, db
from .resource_server import get_resource
logger = logging.getLogger("uvicorn.error")
logger = logging.getLogger("oidc-test")
templates = Jinja2Templates(Path(__file__).parent / "templates")
@ -189,43 +188,28 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
request.session["oidc_provider_id"] = oidc_provider_id
# User id (sub) given by oidc provider
sub = userinfo["sub"]
# Get additional data from userinfo endpoint
try:
user_info_from_endpoint = await oidc_provider.userinfo(
token=token, follow_redirects=True
)
except Exception as err:
logger.warn(f"Cannot get userinfo from endpoint: {err}")
user_info_from_endpoint = {}
# Build and remember the user in the session
request.session["user_sub"] = sub
# Verify the token's signature and validity
# Store the user in the database, which also verifies the token validity and signature
try:
oidc_provider_settings = oidc_providers_settings[oidc_provider_id]
oidc_provider_settings.decode(token["access_token"])
except InvalidKeyError:
user = await db.add_user(
sub,
user_info=userinfo,
oidc_provider=oidc_provider,
access_token=token["access_token"],
)
except PyJWTError as err:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
detail="Token invalid key / signature",
detail=f"Token invalid: {err.__class__.__name__}",
)
except Exception as err:
logger.exception(err)
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
detail="Cannot decode token or verify its signature",
)
# Store the user in the database
user = await db.add_user(
sub,
user_info=userinfo,
oidc_provider=oidc_provider,
user_info_from_endpoint=user_info_from_endpoint,
access_token=token["access_token"],
)
# Add the id_token to the session
request.session["token"] = token["id_token"]
assert isinstance(user, User)
# Add the provider session id to the session
request.session["sid"] = userinfo["sid"]
# Add the token to the db because it is used for logout
await db.add_token(token, user)
assert oidc_provider.name is not None
oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
await db.add_token(oidc_provider_settings, token)
# Send the user to the home: (s)he is authenticated
return RedirectResponse(url=request.url_for("home"))
else:
@ -268,8 +252,14 @@ async def logout(
)
return RedirectResponse(request.url_for("non_compliant_logout"))
post_logout_uri = request.url_for("home")
oidc_provider_settings = oidc_providers_settings.get(
request.session.get("oidc_provider_id", "")
)
assert oidc_provider_settings is not None
try:
token = await db.get_token(request.session.pop("token", None))
token = await db.get_token(
oidc_provider_settings, request.session.pop("sid", None)
)
except TokenNotInDb:
logger.warn("No session in db for the token or no token")
return RedirectResponse(request.url_for("home"))

View file

@ -1,6 +1,6 @@
import logging
from functools import cached_property
from typing import Self
from typing import Self, Any
from pydantic import (
computed_field,
@ -11,7 +11,7 @@ from pydantic import (
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from sqlmodel import SQLModel, Field
logger = logging.getLogger(__name__)
logger = logging.getLogger("oidc-test")
class Role(SQLModel, extra="ignore"):
@ -36,19 +36,9 @@ class User(UserBase):
)
userinfo: dict = {}
access_token: str | None = None
access_token_decoded: dict[str, Any] | None = None
oidc_provider: StarletteOAuth2App | None = None
@classmethod
def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self:
user = cls(**userinfo)
user.userinfo = userinfo
user.oidc_provider = oidc_provider
# Add roles if they are provided in the token
if raw_ra := userinfo.get("realm_access"):
if raw_roles := raw_ra.get("roles"):
user.roles = [Role(name=raw_role) for raw_role in raw_roles]
return user
@computed_field
@cached_property
def roles_as_set(self) -> set[str]:
@ -68,7 +58,7 @@ class User(UserBase):
assert self.access_token is not None
assert self.oidc_provider is not None
assert self.oidc_provider.name is not None
from .auth_utils import oidc_providers_settings
from .settings import oidc_providers_settings
return oidc_providers_settings[self.oidc_provider.name].decode(
self.access_token

View file

@ -8,7 +8,7 @@ from starlette.status import HTTP_401_UNAUTHORIZED
from .models import User
logger = logging.getLogger(__name__)
logger = logging.getLogger("oidc-test")
async def get_resource(resource_id: str, user: User) -> dict:

View file

@ -1,8 +1,9 @@
from os import environ
import string
import random
from typing import Type, Tuple
from typing import Type, Tuple, Any
from pathlib import Path
import logging
from jwt import decode
from pydantic import BaseModel, computed_field, AnyUrl
@ -16,6 +17,8 @@ from starlette.requests import Request
from .models import User
logger = logging.getLogger("oidc-test")
class Resource(BaseModel):
"""A resource with an URL that can be accessed with an OAuth2 access token"""
@ -86,14 +89,27 @@ class OIDCProvider(BaseModel):
-----END PUBLIC KEY-----
"""
def decode(self, token: str) -> dict:
def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]:
"""Decode the token with signature check"""
decoded = decode(
token,
self.get_public_key(),
algorithms=[self.signature_alg],
audience=["account", "oidc-test", "oidc-test-web"],
options={
"verify_signature": False,
"verify_aud": False,
}, # not settings.insecure.skip_verify_signature},
)
logger.debug(str(decoded))
return decode(
token,
self.get_public_key(),
algorithms=[self.signature_alg],
audience=["oidc-test", "oidc-test-web"],
options={"verify_signature": not settings.insecure.skip_verify_signature},
audience=["account", "oidc-test", "oidc-test-web"],
options={
"verify_signature": verify_signature,
}, # not settings.insecure.skip_verify_signature},
)
@ -156,3 +172,8 @@ class Settings(BaseSettings):
settings = Settings()
oidc_providers_settings: dict[str, OIDCProvider] = dict(
[(provider.id, provider) for provider in settings.oidc.providers]
)