{{ auth_provider.name }} provides these resources:
{{ auth_provider.name }} can request resources from third party resource providers:
+ {% for resource_provider in resource_providers %} +diff --git a/src/oidc_test/auth/provider.py b/src/oidc_test/auth/provider.py index 17dcaa0..c614805 100644 --- a/src/oidc_test/auth/provider.py +++ b/src/oidc_test/auth/provider.py @@ -7,7 +7,7 @@ from pydantic import ConfigDict from authlib.integrations.starlette_client.apps import StarletteOAuth2App from httpx import AsyncClient -from oidc_test.settings import AuthProviderSettings, settings +from oidc_test.settings import AuthProviderSettings, ResourceProvider, Resource, settings from oidc_test.models import User logger = logging.getLogger("oidc-test") @@ -24,6 +24,7 @@ class Provider(AuthProviderSettings): authlib_client: StarletteOAuth2App = StarletteOAuth2App(None) info: dict[str, Any] = {} unknown_auth_user: User + logout_with_id_token_hint: bool = True def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]: """Decode the token with signature check""" @@ -89,3 +90,18 @@ class Provider(AuthProviderSettings): def get_session_key(self, userinfo): return userinfo[self.session_key] + + def get_resource(self, resource_name: str) -> Resource: + return [ + resource for resource in self.resources if resource.resource_name == resource_name + ][0] + + def get_resource_url(self, resource_name: str) -> str: + return self.url + self.get_resource(resource_name).url + + def get_resource_provider(self, resource_provider_id: str) -> ResourceProvider: + return [ + provider + for provider in self.resource_providers + if provider.id == resource_provider_id + ][0] diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 79293a3..f930e48 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -107,7 +107,6 @@ async def home( context["resources"] = None else: context["access_token"] = token["access_token"] - # XXX: resources defined externally? I am confused... try: access_token_parsed = provider.decode(token["access_token"], verify_signature=False) except PyJWTError as err: @@ -118,7 +117,8 @@ async def home( context["access_token_scope"] = None context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False) context["access_token_parsed"] = access_token_parsed - context["resource_providers"] = registry.resource_providers + context["resources"] = registry.resources + context["resource_providers"] = provider.resource_providers try: context["refresh_token_parsed"] = provider.decode( token["refresh_token"], verify_signature=False @@ -246,7 +246,7 @@ async def logout( if ( provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint") ) is None: - logger.warn(f"Cannot find end_session_endpoint for provider {provider.id}") + logger.warning(f"Cannot find end_session_endpoint for provider {provider.id}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") # Clear session @@ -255,19 +255,15 @@ async def logout( try: token = await db.get_token(provider, request.session.pop("sid", None)) except TokenNotInDb: - logger.warn("No session in db for the token or no token") + logger.warning("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) - logout_url = ( - provider_logout_uri - + "?" - + urlencode( - { - "post_logout_redirect_uri": post_logout_uri, - "id_token_hint": token["id_token"], - "client_id": "oidc_local_test", - } - ) - ) + url_query = { + "post_logout_redirect_uri": post_logout_uri, + "client_id": provider.client_id, + } + if provider.logout_with_id_token_hint: + url_query["id_token_hint"] = token["id_token"] + logout_url = f"{provider_logout_uri}?{urlencode(url_query)}" return RedirectResponse(logout_url) diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py index 794a843..3b91ad4 100644 --- a/src/oidc_test/registry.py +++ b/src/oidc_test/registry.py @@ -18,7 +18,7 @@ class ProcessError(Exception): pass -class ResourceProvider(BaseModel): +class Resource(BaseModel): name: str scope_required: str | None = None role_required: str | None = None @@ -35,13 +35,13 @@ class ResourceProvider(BaseModel): class ResourceRegistry(BaseModel): - resource_providers: dict[str, ResourceProvider] = {} + resources: dict[str, Resource] = {} def make_registry(self): for ep in entry_points().select(group="oidc_test.resource_provider"): - ResourceProviderClass = ep.load() - if issubclass(ResourceProviderClass, ResourceProvider): - self.resource_providers[ep.name] = ResourceProviderClass(ep.name) + ResourceClass = ep.load() + if issubclass(ResourceClass, Resource): + self.resources[ep.name] = ResourceClass(ep.name) registry = ResourceRegistry() diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index ee4ff10..a4d5368 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -17,7 +17,7 @@ from oidc_test.auth.utils import ( oauth2_scheme_optional, ) from oidc_test.auth_providers import providers -from oidc_test.settings import settings +from oidc_test.settings import ResourceProvider, settings from oidc_test.models import User from oidc_test.registry import ProcessError, ProcessResult, registry @@ -48,7 +48,7 @@ resource_server.add_middleware( @resource_server.get("/") async def resources() -> dict[str, dict[str, Any]]: - return {"internal": {}, "plugins": registry.resource_providers} + return {"internal": {}, "plugins": registry.resources} @resource_server.get("/{resource_name}") @@ -65,8 +65,41 @@ async def get_resource( # Get the resource if it's defined in user auth provider's resources (external) if user is not None: provider = providers[user.auth_provider_id] + if ":" in resource_name: + # Third-party resource provider: send the request with the request token + resource_provider_id, resource_name = resource_name.split(":", 1) + provider = providers[user.auth_provider_id] + resource_provider: ResourceProvider = provider.get_resource_provider( + resource_provider_id + ) + resource_url = resource_provider.get_resource_url(resource_name) + async with AsyncClient(verify=resource_provider.verify_ssl) as client: + try: + resp = await client.get( + resource_url, + headers={ + "Content-type": "application/json", + "Authorization": f"Bearer {token}", + "auth_provider": user.auth_provider_id, + }, + ) + except Exception as err: + raise HTTPException( + status.HTTP_500_INTERNAL_SERVER_ERROR, err.__class__.__name__ + ) + else: + if resp.is_success: + return resp.json() + else: + reason_str: str + try: + reason_str = resp.json().get("detail", str(resp)) + except Exception: + reason_str = str(resp.text) + raise HTTPException(resp.status_code, reason_str) # Third party resource (provided through the auth provider) # The token is just passed on + # XXX: is this branch valid anymore? if resource_name in [r.resource_name for r in provider.resources]: return await get_auth_provider_resource( provider=provider, @@ -75,31 +108,31 @@ async def get_resource( user=user, ) # Internal resource (provided here) - if resource_name in registry.resource_providers: - resource_provider = registry.resource_providers[resource_name] + if resource_name in registry.resources: + resource = registry.resources[resource_name] reason: dict[str, str] = {} - if not resource_provider.is_public: + if not resource.is_public: if user is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public") else: - if resource_provider.scope_required is not None and not user.has_scope( - resource_provider.scope_required + if resource.scope_required is not None and not user.has_scope( + resource.scope_required ): reason["scope"] = ( - f"No scope {resource_provider.scope_required} in the access token " + f"No scope {resource.scope_required} in the access token " "but it is required for accessing this resource" ) if ( - resource_provider.role_required is not None - and resource_provider.role_required not in user.roles_as_set + resource.role_required is not None + and resource.role_required not in user.roles_as_set ): reason["role"] = ( - f"You don't have the role {resource_provider.role_required} " + f"You don't have the role {resource.role_required} " "but it is required for accessing this resource" ) if len(reason) == 0: try: - resp = await resource_provider.process(user=user, resource_id=resource_id) + resp = await resource.process(user=user, resource_id=resource_id) return resp except ProcessError as err: raise HTTPException( @@ -116,12 +149,11 @@ async def get_auth_provider_resource( provider: Provider, resource_name: str, token: OAuth2Token | None, user: User ) -> ProcessResult: if token is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"No auth token") + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth token") access_token = token["access_token"] - resource = [r for r in provider.resources if r.resource_name == resource_name][0] async with AsyncClient() as client: resp = await client.get( - url=provider.url + resource.url, + url=provider.get_resource_url(resource_name), headers={ "Content-type": "application/json", "Authorization": f"Bearer {access_token}", diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 3e7001c..e549fd4 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -4,7 +4,7 @@ import random from typing import Type, Tuple from pathlib import Path -from pydantic import BaseModel, computed_field, AnyUrl +from pydantic import AnyHttpUrl, BaseModel, computed_field, AnyUrl from pydantic_settings import ( BaseSettings, SettingsConfigDict, @@ -22,6 +22,22 @@ class Resource(BaseModel): url: str +class ResourceProvider(BaseModel): + id: str + name: str + base_url: AnyUrl + resources: list[Resource] = [] + verify_ssl: bool = True + + def get_resource(self, resource_name: str) -> Resource: + return [ + resource for resource in self.resources if resource.resource_name == resource_name + ][0] + + def get_resource_url(self, resource_name: str) -> str: + return f"{self.base_url}{self.get_resource(resource_name).url}" + + class AuthProviderSettings(BaseModel): """Auth provider, can also be a resource server""" @@ -45,6 +61,7 @@ class AuthProviderSettings(BaseModel): session_key: str = "sid" skip_verify_signature: bool = True disabled: bool = False + resource_providers: list[ResourceProvider] = [] @computed_field @property @@ -67,13 +84,6 @@ class AuthProviderSettings(BaseModel): return None -class ResourceProvider(BaseModel): - id: str - name: str - base_url: AnyUrl - resources: list[Resource] = [] - - class AuthSettings(BaseModel): show_session_details: bool = False providers: list[AuthProviderSettings] = [] @@ -92,13 +102,13 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_nested_delimiter="__") auth: AuthSettings = AuthSettings() - resource_providers: list[ResourceProvider] = [] secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False insecure: Insecure = Insecure() cors_origins: list[str] = [] debug_token: bool = False show_token: bool = False + show_external_resource_providers_links: bool = False @classmethod def settings_customise_sources( diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 978b61c..e988dfe 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -2,7 +2,9 @@ async function checkHref(elem, token, authProvider) { const msg = document.getElementById("msg") const resourceName = elem.getAttribute("resource-name") const resourceId = elem.getAttribute("resource-id") - const url = resourceId ? `resource/${resourceName}/${resourceId}` : `resource/${resourceName}` + const resourceProviderId = elem.getAttribute("resource-provider-id") ? elem.getAttribute("resource-provider-id") : "" + const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName + const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}` const resp = await fetch(url, { method: "GET", headers: new Headers({ @@ -30,11 +32,13 @@ function checkPerms(className, token, authProvider) { ) } -async function get_resource(resource_name, token, authProvider, resource_id) { +async function get_resource(resourceName, token, authProvider, resourceId, resourceProviderId) { + // BaseUrl for an external resource provider //if (!keycloak.keycloak) { return } const msg = document.getElementById("msg") const resourceElem = document.getElementById('resource') - const url = resource_id ? `resource/${resource_name}/${resource_id}` : `resource/${resource_name}` + const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName + const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}` const resp = await fetch(url, { method: "GET", headers: new Headers({ diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 6c4e6a6..5bccaee 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -66,29 +66,50 @@ {% endif %}
{{ auth_provider.name }} provides these resources:
{{ auth_provider.name }} can request resources from third party resource providers:
+ {% for resource_provider in resource_providers %} +