From c89ca4098b2165014890af89caebde7310b88db0 Mon Sep 17 00:00:00 2001 From: phil Date: Fri, 14 Feb 2025 13:21:55 +0100 Subject: [PATCH] Fix public resource access; free resource response validation; formatting --- src/oidc_test/auth/utils.py | 15 ++-- src/oidc_test/main.py | 10 +-- src/oidc_test/registry.py | 9 ++- src/oidc_test/resource_server.py | 127 ++++++++++++++++--------------- 4 files changed, 84 insertions(+), 77 deletions(-) diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py index acd68b5..7dd0e3d 100644 --- a/src/oidc_test/auth/utils.py +++ b/src/oidc_test/auth/utils.py @@ -87,6 +87,7 @@ def init_providers(): authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: @@ -125,7 +126,7 @@ async def get_current_user(request: Request) -> User: """ if (user_sub := request.session.get("user_sub")) is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED) - token = await get_token(request) + token = await get_token_from_session(request) user = await db.get_user(user_sub) ## Check if the token is expired if token.is_expired(): @@ -146,16 +147,16 @@ async def get_current_user(request: Request) -> User: return user -async def get_token_or_none(request: Request) -> OAuth2Token | None: +async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None: """Return the auth token from the session or None. Can be used in Depends()""" try: - return await get_token(request) + return await get_token_from_session(request) except HTTPException: return None -async def get_token(request: Request) -> OAuth2Token: +async def get_token_from_session(request: Request) -> OAuth2Token: """Return the token from the session. Can be used in Depends()""" try: @@ -273,15 +274,19 @@ async def get_user_from_token( ) return user + async def get_user_from_token_or_none( - token: Annotated[str, Depends(oauth2_scheme)], + token: Annotated[str | None, Depends(oauth2_scheme_optional)], request: Request, ) -> User | None: + if token is None: + return None try: return await get_user_from_token(token, request) except HTTPException: return None + class UserWithRole: roles: set[str] diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 9e8b135..9f5e746 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -33,8 +33,8 @@ from oidc_test.auth.utils import ( get_auth_provider_or_none, get_current_user_or_none, authlib_oauth, - get_token_or_none, - get_token, + get_token_from_session_or_none, + get_token_from_session, update_token, ) from oidc_test.auth.utils import init_providers @@ -88,9 +88,9 @@ app.mount("/resource", resource_server, name="resource_server") @app.get("/") async def home( request: Request, - user: Annotated[User, Depends(get_current_user_or_none)], + user: Annotated[User | None, Depends(get_current_user_or_none)], provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)], - token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], + token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)], ) -> HTMLResponse: context = { "show_token": settings.show_token, @@ -291,7 +291,7 @@ async def non_compliant_logout( async def refresh( request: Request, provider: Annotated[Provider, Depends(get_auth_provider)], - token: Annotated[OAuth2Token, Depends(get_token)], + token: Annotated[OAuth2Token, Depends(get_token_from_session)], ) -> RedirectResponse: """Manually refresh token""" new_token = await provider.authlib_client.fetch_access_token( diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py index e9c9809..794a843 100644 --- a/src/oidc_test/registry.py +++ b/src/oidc_test/registry.py @@ -1,8 +1,7 @@ from importlib.metadata import entry_points import logging -from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from oidc_test.models import User @@ -10,7 +9,9 @@ logger = logging.getLogger("registry") class ProcessResult(BaseModel): - result: dict[str, Any] = {} + model_config = ConfigDict( + extra="allow", + ) class ProcessError(Exception): @@ -28,7 +29,7 @@ class ResourceProvider(BaseModel): super().__init__() self.__id__ = name - async def process(self, user: User, resource_id: str | None = None) -> ProcessResult: + async def process(self, user: User | None, resource_id: str | None = None) -> ProcessResult: logger.warning(f"{self.__id__} should define a process method") return ProcessResult() diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 1877875..ee4ff10 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -13,15 +13,13 @@ from fastapi.middleware.cors import CORSMiddleware from oidc_test.auth.provider import Provider from oidc_test.auth.utils import ( - get_token_or_none, - get_user_from_token, - UserWithRole, get_user_from_token_or_none, + oauth2_scheme_optional, ) from oidc_test.auth_providers import providers from oidc_test.settings import settings from oidc_test.models import User -from oidc_test.registry import ProcessError, ProcessResult, ResourceProvider, registry +from oidc_test.registry import ProcessError, ProcessResult, registry logger = logging.getLogger("oidc-test") @@ -50,60 +48,67 @@ resource_server.add_middleware( @resource_server.get("/") async def resources() -> dict[str, dict[str, Any]]: - return { - "internal": {}, - "plugins": registry.resource_providers - } - + return {"internal": {}, "plugins": registry.resource_providers} @resource_server.get("/{resource_name}") @resource_server.get("/{resource_name}/{resource_id}") async def get_resource( resource_name: str, - user: Annotated[User, Depends(get_user_from_token_or_none)], - token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], - request: Request, + user: Annotated[User | None, Depends(get_user_from_token_or_none)], + token: Annotated[OAuth2Token | None, Depends(oauth2_scheme_optional)], resource_id: str | None = None, -) -> ProcessResult: - """Generic path for testing a resource provided by a provider""" - provider = providers[user.auth_provider_id] - # Third party resource (provided through the auth provider) - # The token is just passed on - if resource_name in [r.resource_name for r in provider.resources]: - return await get_auth_provider_resource( - provider=provider, - resource_name=resource_name, - token=token, - user=user, - ) +): + """Generic path for testing a resource provided by a provider. + There's no field validation (response type of ProcessResult) on purpose, + leaving the responsibility of the response validation to resource providers""" + # Get the resource if it's defined in user auth provider's resources (external) + if user is not None: + provider = providers[user.auth_provider_id] + # Third party resource (provided through the auth provider) + # The token is just passed on + if resource_name in [r.resource_name for r in provider.resources]: + return await get_auth_provider_resource( + provider=provider, + resource_name=resource_name, + token=token, + user=user, + ) # Internal resource (provided here) if resource_name in registry.resource_providers: resource_provider = registry.resource_providers[resource_name] - reasons: dict[str, str] = {} + reason: dict[str, str] = {} if not resource_provider.is_public: - if resource_provider.scope_required is not None and not user.has_scope( - resource_provider.scope_required - ): - reasons["scope"] = f"No scope {resource_provider.scope_required} in the access token " \ - "but it is required for accessing this resource" - if resource_provider.role_required is not None \ - and resource_provider.role_required not in user.roles_as_set: - reasons["role"] = f"You don't have the role {resource_provider.role_required} " \ - "but it is required for accessing this resource" - if len(reasons) == 0: + if user is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public") + else: + if resource_provider.scope_required is not None and not user.has_scope( + resource_provider.scope_required + ): + reason["scope"] = ( + f"No scope {resource_provider.scope_required} in the access token " + "but it is required for accessing this resource" + ) + if ( + resource_provider.role_required is not None + and resource_provider.role_required not in user.roles_as_set + ): + reason["role"] = ( + f"You don't have the role {resource_provider.role_required} " + "but it is required for accessing this resource" + ) + if len(reason) == 0: try: - return await resource_provider.process(user=user, resource_id=resource_id) + resp = await resource_provider.process(user=user, resource_id=resource_id) + return resp except ProcessError as err: raise HTTPException( status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}" ) else: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, ", ".join(reasons.values()) - ) + raise HTTPException(status.HTTP_401_UNAUTHORIZED, ", ".join(reason.values())) else: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}") + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Unknown resource") # return await get_resource_(resource_name, user, **request.query_params) @@ -111,9 +116,7 @@ async def get_auth_provider_resource( provider: Provider, resource_name: str, token: OAuth2Token | None, user: User ) -> ProcessResult: if token is None: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, f"No auth token" - ) + raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"No auth token") access_token = token["access_token"] resource = [r for r in provider.resources if r.resource_name == resource_name][0] async with AsyncClient() as client: @@ -129,52 +132,50 @@ async def get_auth_provider_resource( # Only a demo, real application would really process the response resp_length = len(resp.text) if resp_length > 1024: - return ProcessResult( - result={"msg": f"The resource is too long ({resp_length} bytes) to show here"} - ) + return ProcessResult(msg=f"The resource is too long ({resp_length} bytes) to show here") else: - return ProcessResult(result=resp.json()) + return ProcessResult(**resp.json()) -#@resource_server.get("/public") -#async def public() -> dict: +# @resource_server.get("/public") +# async def public() -> dict: # return {"msg": "Not protected"} # # -#@resource_server.get("/protected") -#async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): +# @resource_server.get("/protected") +# async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): # assert user is not None # Just to keep QA checks happy # return {"msg": "Only authenticated users can see this"} # # -#@resource_server.get("/protected-by-foorole") -#async def get_protected_by_foorole( +# @resource_server.get("/protected-by-foorole") +# async def get_protected_by_foorole( # user: Annotated[User, Depends(UserWithRole("foorole"))], -#): +# ): # assert user is not None # return {"msg": "Only users with foorole can see this"} # # -#@resource_server.get("/protected-by-barrole") -#async def get_protected_by_barrole( +# @resource_server.get("/protected-by-barrole") +# async def get_protected_by_barrole( # user: Annotated[User, Depends(UserWithRole("barrole"))], -#): +# ): # assert user is not None # return {"msg": "Protected by barrole"} # # -#@resource_server.get("/protected-by-foorole-and-barrole") -#async def get_protected_by_foorole_and_barrole( +# @resource_server.get("/protected-by-foorole-and-barrole") +# async def get_protected_by_foorole_and_barrole( # user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], -#): +# ): # assert user is not None # Just to keep QA checks happy # return {"msg": "Only users with foorole and barrole can see this"} # # -#@resource_server.get("/protected-by-foorole-or-barrole") -#async def get_protected_by_foorole_or_barrole( +# @resource_server.get("/protected-by-foorole-or-barrole") +# async def get_protected_by_foorole_or_barrole( # user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))], -#): +# ): # assert user is not None # Just to keep QA checks happy # return {"msg": "Only users with foorole or barrole can see this"}