From b5f2e5b57bc1d62c7f5c08ba0c24aac1e4e209ba Mon Sep 17 00:00:00 2001 From: phil Date: Mon, 13 Jan 2025 05:37:55 +0100 Subject: [PATCH] Get more user info from provider's userinfo endpoint --- src/oidc_test/database.py | 20 ++++++++++++++++++-- src/oidc_test/main.py | 25 +++++++++++++++++++++---- src/oidc_test/settings.py | 5 +++-- 3 files changed, 42 insertions(+), 8 deletions(-) diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 1aae7cc..7c7bc8a 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -1,8 +1,11 @@ # Implement a fake in-memory database interface for demo purpose +import logging from authlib.integrations.starlette_client.apps import StarletteOAuth2App -from .models import User, OAuth2Token +from .models import User, OAuth2Token, Role + +logger = logging.getLogger(__name__) class Database: @@ -12,9 +15,22 @@ class Database: # Last sessions for the user (key: users's subject id (sub)) async def add_user( - self, sub: str, user_info: dict, oidc_provider: StarletteOAuth2App + self, + sub: str, + user_info: dict, + oidc_provider: StarletteOAuth2App, + user_info_from_endpoint: dict, ) -> User: user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider) + try: + raw_roles = user_info_from_endpoint["resource_access"][ + oidc_provider.client_id + ]["roles"] + except Exception as err: + logger.debug(f"Cannot read additional roles: {err}") + raw_roles = [] + for raw_role in raw_roles: + user.roles.append(Role(name=raw_role)) self.users[sub] = user return user diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 9b75985..70ff16f 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -32,8 +32,7 @@ from .auth_utils import ( from .auth_misc import pretty_details from .database import db -# logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) +logger = logging.getLogger("uvicorn.error") templates = Jinja2Templates(Path(__file__).parent / "templates") @@ -59,10 +58,11 @@ for provider in settings.oidc.providers: name=provider.id, server_metadata_url=provider.openid_configuration, client_kwargs={ - "scope": "openid email offline_access profile roles", + "scope": "openid email offline_access profile", }, client_id=provider.client_id, client_secret=provider.client_secret, + api_base_url=provider.url, # fetch_token=fetch_token, # update_token=update_token, # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) @@ -111,6 +111,7 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: except OAuthError as error: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) # Remember the oidc_provider in the session + # logger.debug(f"Scope: {token['scope']}") request.session["oidc_provider_id"] = oidc_provider_id # # One could process the full decoded token which contains extra information @@ -119,10 +120,26 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: if userinfo := token.get("userinfo"): # sub given by oidc provider sub = userinfo["sub"] + # Get additional data from userinfo endpoint + try: + user_info_url = oidc_provider.server_metadata["userinfo_endpoint"] + user_info_from_endpoint = ( + await oidc_provider.get( + user_info_url, token=token, follow_redirects=True + ) + ).json() + except Exception as err: + logger.info(f"Cannot get userinfo from endpoint: {err}") + user_info_from_endpoint = {} # Build and remember the user in the session request.session["user_sub"] = sub # Store the user in the database - user = await db.add_user(sub, user_info=userinfo, oidc_provider=oidc_provider) + user = await db.add_user( + sub, + user_info=userinfo, + oidc_provider=oidc_provider, + user_info_from_endpoint=user_info_from_endpoint, + ) request.session["token"] = userinfo["sub"] await db.add_token(token, user) return RedirectResponse(url=request.url_for("home")) diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index ba2d6a0..d01e19f 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -19,7 +19,7 @@ class OIDCProvider(BaseModel): url: str client_id: str client_secret: str = "" - hint: str = "Use your own credentials" + hint: str = "No hint" @computed_field @property @@ -29,7 +29,7 @@ class OIDCProvider(BaseModel): @computed_field @property def token_url(self) -> str: - return "auth/" + self.name + return "auth/" + self.id class OIDCSettings(BaseModel): @@ -43,6 +43,7 @@ class Settings(BaseSettings): oidc: OIDCSettings = OIDCSettings() secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) + log: bool = False model_config = SettingsConfigDict(env_nested_delimiter="__")