diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py index 9479c48..acd68b5 100644 --- a/src/oidc_test/auth/utils.py +++ b/src/oidc_test/auth/utils.py @@ -8,7 +8,7 @@ from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOA from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError # from authlib.oauth1.auth import OAuthToken -from authlib.oauth2.auth import OAuth2Token +from authlib.oauth2.rfc6749 import OAuth2Token from oidc_test.auth.provider import Provider from oidc_test.models import User @@ -60,25 +60,28 @@ def init_providers(): sub="", auth_provider_id=provider_settings.id ) provider = Provider(**provider_settings_dict) - authlib_oauth.register( - name=provider.id, - server_metadata_url=provider.openid_configuration, - client_kwargs={ - "scope": " ".join( - ["openid", "email", "offline_access", "profile"] - + provider.resource_provider_scopes - ), - }, - client_id=provider.client_id, - client_secret=provider.client_secret, - api_base_url=provider.url, - # For PKCE (not implemented yet): - # code_challenge_method="S256", - fetch_token=fetch_token, - update_token=update_token, - # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) - ) - provider.authlib_client = getattr(authlib_oauth, provider.id) + if provider.disabled: + logger.info(f"{provider_settings.name} is disabled, skipping") + else: + authlib_oauth.register( + name=provider.id, + server_metadata_url=provider.openid_configuration, + client_kwargs={ + "scope": " ".join( + ["openid", "email", "offline_access", "profile"] + + provider.resource_provider_scopes + ), + }, + client_id=provider.client_id, + client_secret=provider.client_secret, + api_base_url=provider.url, + # For PKCE (not implemented yet): + # code_challenge_method="S256", + fetch_token=fetch_token, + update_token=update_token, + # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) + ) + provider.authlib_client = getattr(authlib_oauth, provider.id) providers[provider.id] = provider @@ -270,6 +273,14 @@ async def get_user_from_token( ) return user +async def get_user_from_token_or_none( + token: Annotated[str, Depends(oauth2_scheme)], + request: Request, +) -> User | None: + try: + return await get_user_from_token(token, request) + except HTTPException: + return None class UserWithRole: roles: set[str] diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 3858a08..9e8b135 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -55,6 +55,8 @@ async def lifespan(app: FastAPI): init_providers() registry.make_registry() for provider in list(providers.values()): + if provider.disabled: + continue try: await provider.get_info() except NoPublicKey: @@ -106,7 +108,6 @@ async def home( else: context["access_token"] = token["access_token"] # XXX: resources defined externally? I am confused... - context["resources"] = provider.resources try: access_token_parsed = provider.decode(token["access_token"], verify_signature=False) except PyJWTError as err: diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py index a184ec0..e9c9809 100644 --- a/src/oidc_test/registry.py +++ b/src/oidc_test/registry.py @@ -18,15 +18,18 @@ class ProcessError(Exception): class ResourceProvider(BaseModel): + name: str scope_required: str | None = None + role_required: str | None = None + is_public: bool = False default_resource_id: str | None = None def __init__(self, name: str): super().__init__() - self.__name__ = name + self.__id__ = name async def process(self, user: User, resource_id: str | None = None) -> ProcessResult: - logger.warning(f"{self.__name__} should define a process method") + logger.warning(f"{self.__id__} should define a process method") return ProcessResult() diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 1af0f6b..03d109e 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -1,7 +1,7 @@ from typing import Annotated, Any import logging -from authlib.oauth2.auth import OAuth2Token +from authlib.oauth2.rfc6749 import OAuth2Token from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from fastapi import FastAPI, HTTPException, Depends, Request, status @@ -16,6 +16,7 @@ from oidc_test.auth.utils import ( get_token_or_none, get_user_from_token, UserWithRole, + get_user_from_token_or_none, ) from oidc_test.auth_providers import providers from oidc_test.settings import settings @@ -55,54 +56,12 @@ async def resources() -> dict[str, dict[str, Any]]: } -@resource_server.get("/public") -async def public() -> dict: - return {"msg": "Not protected"} - - -@resource_server.get("/protected") -async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): - assert user is not None # Just to keep QA checks happy - return {"msg": "Only authenticated users can see this"} - - -@resource_server.get("/protected-by-foorole") -async def get_protected_by_foorole( - user: Annotated[User, Depends(UserWithRole("foorole"))], -): - assert user is not None - return {"msg": "Only users with foorole can see this"} - - -@resource_server.get("/protected-by-barrole") -async def get_protected_by_barrole( - user: Annotated[User, Depends(UserWithRole("barrole"))], -): - assert user is not None - return {"msg": "Protected by barrole"} - - -@resource_server.get("/protected-by-foorole-and-barrole") -async def get_protected_by_foorole_and_barrole( - user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], -): - assert user is not None # Just to keep QA checks happy - return {"msg": "Only users with foorole and barrole can see this"} - - -@resource_server.get("/protected-by-foorole-or-barrole") -async def get_protected_by_foorole_or_barrole( - user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))], -): - assert user is not None # Just to keep QA checks happy - return {"msg": "Only users with foorole or barrole can see this"} - @resource_server.get("/{resource_name}") @resource_server.get("/{resource_name}/{resource_id}") async def get_resource( resource_name: str, - user: Annotated[User, Depends(get_user_from_token)], + user: Annotated[User, Depends(get_user_from_token_or_none)], token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], request: Request, resource_id: str | None = None, @@ -111,19 +70,29 @@ async def get_resource( provider = providers[user.auth_provider_id] # Third party resource (provided through the auth provider) # The token is just passed on + breakpoint() if resource_name in [r.resource_name for r in provider.resources]: return await get_auth_provider_resource( provider=provider, resource_name=resource_name, - access_token=token["access_token"] if token else None, + token=token, user=user, ) # Internal resource (provided here) if resource_name in registry.resource_providers: resource_provider = registry.resource_providers[resource_name] - if resource_provider.scope_required is not None and user.has_scope( - resource_provider.scope_required - ): + reasons: dict[str, str] = {} + if not resource_provider.is_public: + if resource_provider.scope_required is not None and not user.has_scope( + resource_provider.scope_required + ): + reasons["scope"] = f"No scope {resource_provider.scope_required} in the access token " \ + "but it is required for accessing this resource" + if resource_provider.role_required is not None \ + and resource_provider.role_required not in user.roles_as_set: + reasons["role"] = f"You don't have the role {resource_provider.role_required} " \ + "but it is required for accessing this resource" + if len(reasons) == 0: try: return await resource_provider.process(user=user, resource_id=resource_id) except ProcessError as err: @@ -132,9 +101,7 @@ async def get_resource( ) else: raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - f"No scope {resource_provider.scope_required} in the access token " - + "but it is required for accessing this resource", + status.HTTP_401_UNAUTHORIZED, ", ".join(reasons.values()) ) else: raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}") @@ -142,8 +109,13 @@ async def get_resource( async def get_auth_provider_resource( - provider: Provider, resource_name: str, access_token: str | None, user: User + provider: Provider, resource_name: str, token: OAuth2Token | None, user: User ) -> ProcessResult: + if token is None: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"No auth token" + ) + access_token = token["access_token"] resource = [r for r in provider.resources if r.resource_name == resource_name][0] async with AsyncClient() as client: resp = await client.get( @@ -165,6 +137,48 @@ async def get_auth_provider_resource( return ProcessResult(result=resp.json()) +#@resource_server.get("/public") +#async def public() -> dict: +# return {"msg": "Not protected"} +# +# +#@resource_server.get("/protected") +#async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): +# assert user is not None # Just to keep QA checks happy +# return {"msg": "Only authenticated users can see this"} +# +# +#@resource_server.get("/protected-by-foorole") +#async def get_protected_by_foorole( +# user: Annotated[User, Depends(UserWithRole("foorole"))], +#): +# assert user is not None +# return {"msg": "Only users with foorole can see this"} +# +# +#@resource_server.get("/protected-by-barrole") +#async def get_protected_by_barrole( +# user: Annotated[User, Depends(UserWithRole("barrole"))], +#): +# assert user is not None +# return {"msg": "Protected by barrole"} +# +# +#@resource_server.get("/protected-by-foorole-and-barrole") +#async def get_protected_by_foorole_and_barrole( +# user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], +#): +# assert user is not None # Just to keep QA checks happy +# return {"msg": "Only users with foorole and barrole can see this"} +# +# +#@resource_server.get("/protected-by-foorole-or-barrole") +#async def get_protected_by_foorole_or_barrole( +# user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))], +#): +# assert user is not None # Just to keep QA checks happy +# return {"msg": "Only users with foorole or barrole can see this"} + # async def get_resource_(resource_id: str, user: User, **kwargs) -> dict: # """ # Resource processing: build an informative rely as a simple showcase diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 2acbc3f..3e7001c 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -44,6 +44,7 @@ class AuthProviderSettings(BaseModel): resource_provider_scopes: list[str] = [] session_key: str = "sid" skip_verify_signature: bool = True + disabled: bool = False @computed_field @property diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index e163a68..2baa748 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -142,19 +142,27 @@ hr { .providers .provider { min-height: 2em; } -.providers .provider a.link { +.providers .provider .link { text-decoration: none; max-height: 2em; } -.providers .provider .link div { +.providers .provider .link { background-color: #f7c7867d; border-radius: 8px; padding: 6px; text-align: center; color: black; - font-weight: bold; + font-weight: 400; cursor: pointer; + border: 0; + width: 100%; } + +.providers .provider .link.disabled { + color: gray; + cursor: not-allowed; +} + .providers .provider .hint { font-size: 80%; max-width: 13em; diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 23ba7ff..ecefb0f 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -11,7 +11,11 @@ {% for provider in providers.values() %} -
{{ provider.name }}
+ {{ provider.hint }} @@ -62,42 +66,17 @@ {% endif %}
-

- Resources validated by role: -

- - {% if resource_providers %}

- Resource providers (validated by scope): + Resource providers:

{% endif %}