Fix public resource access; free resource response validation; formatting
All checks were successful
/ build (push) Successful in 5s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-02-14 13:21:55 +01:00
parent 40ddb61636
commit c89ca4098b
4 changed files with 84 additions and 77 deletions

View file

@ -87,6 +87,7 @@ def init_providers():
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
@ -125,7 +126,7 @@ async def get_current_user(request: Request) -> User:
""" """
if (user_sub := request.session.get("user_sub")) is None: if (user_sub := request.session.get("user_sub")) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED) raise HTTPException(status.HTTP_401_UNAUTHORIZED)
token = await get_token(request) token = await get_token_from_session(request)
user = await db.get_user(user_sub) user = await db.get_user(user_sub)
## Check if the token is expired ## Check if the token is expired
if token.is_expired(): if token.is_expired():
@ -146,16 +147,16 @@ async def get_current_user(request: Request) -> User:
return user return user
async def get_token_or_none(request: Request) -> OAuth2Token | None: async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None:
"""Return the auth token from the session or None. """Return the auth token from the session or None.
Can be used in Depends()""" Can be used in Depends()"""
try: try:
return await get_token(request) return await get_token_from_session(request)
except HTTPException: except HTTPException:
return None return None
async def get_token(request: Request) -> OAuth2Token: async def get_token_from_session(request: Request) -> OAuth2Token:
"""Return the token from the session. """Return the token from the session.
Can be used in Depends()""" Can be used in Depends()"""
try: try:
@ -273,15 +274,19 @@ async def get_user_from_token(
) )
return user return user
async def get_user_from_token_or_none( async def get_user_from_token_or_none(
token: Annotated[str, Depends(oauth2_scheme)], token: Annotated[str | None, Depends(oauth2_scheme_optional)],
request: Request, request: Request,
) -> User | None: ) -> User | None:
if token is None:
return None
try: try:
return await get_user_from_token(token, request) return await get_user_from_token(token, request)
except HTTPException: except HTTPException:
return None return None
class UserWithRole: class UserWithRole:
roles: set[str] roles: set[str]

View file

@ -33,8 +33,8 @@ from oidc_test.auth.utils import (
get_auth_provider_or_none, get_auth_provider_or_none,
get_current_user_or_none, get_current_user_or_none,
authlib_oauth, authlib_oauth,
get_token_or_none, get_token_from_session_or_none,
get_token, get_token_from_session,
update_token, update_token,
) )
from oidc_test.auth.utils import init_providers from oidc_test.auth.utils import init_providers
@ -88,9 +88,9 @@ app.mount("/resource", resource_server, name="resource_server")
@app.get("/") @app.get("/")
async def home( async def home(
request: Request, request: Request,
user: Annotated[User, Depends(get_current_user_or_none)], user: Annotated[User | None, Depends(get_current_user_or_none)],
provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)], provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)],
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)],
) -> HTMLResponse: ) -> HTMLResponse:
context = { context = {
"show_token": settings.show_token, "show_token": settings.show_token,
@ -291,7 +291,7 @@ async def non_compliant_logout(
async def refresh( async def refresh(
request: Request, request: Request,
provider: Annotated[Provider, Depends(get_auth_provider)], provider: Annotated[Provider, Depends(get_auth_provider)],
token: Annotated[OAuth2Token, Depends(get_token)], token: Annotated[OAuth2Token, Depends(get_token_from_session)],
) -> RedirectResponse: ) -> RedirectResponse:
"""Manually refresh token""" """Manually refresh token"""
new_token = await provider.authlib_client.fetch_access_token( new_token = await provider.authlib_client.fetch_access_token(

View file

@ -1,8 +1,7 @@
from importlib.metadata import entry_points from importlib.metadata import entry_points
import logging import logging
from typing import Any
from pydantic import BaseModel from pydantic import BaseModel, ConfigDict
from oidc_test.models import User from oidc_test.models import User
@ -10,7 +9,9 @@ logger = logging.getLogger("registry")
class ProcessResult(BaseModel): class ProcessResult(BaseModel):
result: dict[str, Any] = {} model_config = ConfigDict(
extra="allow",
)
class ProcessError(Exception): class ProcessError(Exception):
@ -28,7 +29,7 @@ class ResourceProvider(BaseModel):
super().__init__() super().__init__()
self.__id__ = name self.__id__ = name
async def process(self, user: User, resource_id: str | None = None) -> ProcessResult: async def process(self, user: User | None, resource_id: str | None = None) -> ProcessResult:
logger.warning(f"{self.__id__} should define a process method") logger.warning(f"{self.__id__} should define a process method")
return ProcessResult() return ProcessResult()

View file

@ -13,15 +13,13 @@ from fastapi.middleware.cors import CORSMiddleware
from oidc_test.auth.provider import Provider from oidc_test.auth.provider import Provider
from oidc_test.auth.utils import ( from oidc_test.auth.utils import (
get_token_or_none,
get_user_from_token,
UserWithRole,
get_user_from_token_or_none, get_user_from_token_or_none,
oauth2_scheme_optional,
) )
from oidc_test.auth_providers import providers from oidc_test.auth_providers import providers
from oidc_test.settings import settings from oidc_test.settings import settings
from oidc_test.models import User from oidc_test.models import User
from oidc_test.registry import ProcessError, ProcessResult, ResourceProvider, registry from oidc_test.registry import ProcessError, ProcessResult, registry
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
@ -50,23 +48,22 @@ resource_server.add_middleware(
@resource_server.get("/") @resource_server.get("/")
async def resources() -> dict[str, dict[str, Any]]: async def resources() -> dict[str, dict[str, Any]]:
return { return {"internal": {}, "plugins": registry.resource_providers}
"internal": {},
"plugins": registry.resource_providers
}
@resource_server.get("/{resource_name}") @resource_server.get("/{resource_name}")
@resource_server.get("/{resource_name}/{resource_id}") @resource_server.get("/{resource_name}/{resource_id}")
async def get_resource( async def get_resource(
resource_name: str, resource_name: str,
user: Annotated[User, Depends(get_user_from_token_or_none)], user: Annotated[User | None, Depends(get_user_from_token_or_none)],
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], token: Annotated[OAuth2Token | None, Depends(oauth2_scheme_optional)],
request: Request,
resource_id: str | None = None, resource_id: str | None = None,
) -> ProcessResult: ):
"""Generic path for testing a resource provided by a provider""" """Generic path for testing a resource provided by a provider.
There's no field validation (response type of ProcessResult) on purpose,
leaving the responsibility of the response validation to resource providers"""
# Get the resource if it's defined in user auth provider's resources (external)
if user is not None:
provider = providers[user.auth_provider_id] provider = providers[user.auth_provider_id]
# Third party resource (provided through the auth provider) # Third party resource (provided through the auth provider)
# The token is just passed on # The token is just passed on
@ -80,30 +77,38 @@ async def get_resource(
# Internal resource (provided here) # Internal resource (provided here)
if resource_name in registry.resource_providers: if resource_name in registry.resource_providers:
resource_provider = registry.resource_providers[resource_name] resource_provider = registry.resource_providers[resource_name]
reasons: dict[str, str] = {} reason: dict[str, str] = {}
if not resource_provider.is_public: if not resource_provider.is_public:
if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public")
else:
if resource_provider.scope_required is not None and not user.has_scope( if resource_provider.scope_required is not None and not user.has_scope(
resource_provider.scope_required resource_provider.scope_required
): ):
reasons["scope"] = f"No scope {resource_provider.scope_required} in the access token " \ reason["scope"] = (
f"No scope {resource_provider.scope_required} in the access token "
"but it is required for accessing this resource" "but it is required for accessing this resource"
if resource_provider.role_required is not None \ )
and resource_provider.role_required not in user.roles_as_set: if (
reasons["role"] = f"You don't have the role {resource_provider.role_required} " \ resource_provider.role_required is not None
and resource_provider.role_required not in user.roles_as_set
):
reason["role"] = (
f"You don't have the role {resource_provider.role_required} "
"but it is required for accessing this resource" "but it is required for accessing this resource"
if len(reasons) == 0: )
if len(reason) == 0:
try: try:
return await resource_provider.process(user=user, resource_id=resource_id) resp = await resource_provider.process(user=user, resource_id=resource_id)
return resp
except ProcessError as err: except ProcessError as err:
raise HTTPException( raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}" status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}"
) )
else: else:
raise HTTPException( raise HTTPException(status.HTTP_401_UNAUTHORIZED, ", ".join(reason.values()))
status.HTTP_401_UNAUTHORIZED, ", ".join(reasons.values())
)
else: else:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Unknown resource")
# return await get_resource_(resource_name, user, **request.query_params) # return await get_resource_(resource_name, user, **request.query_params)
@ -111,9 +116,7 @@ async def get_auth_provider_resource(
provider: Provider, resource_name: str, token: OAuth2Token | None, user: User provider: Provider, resource_name: str, token: OAuth2Token | None, user: User
) -> ProcessResult: ) -> ProcessResult:
if token is None: if token is None:
raise HTTPException( raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"No auth token")
status.HTTP_401_UNAUTHORIZED, f"No auth token"
)
access_token = token["access_token"] access_token = token["access_token"]
resource = [r for r in provider.resources if r.resource_name == resource_name][0] resource = [r for r in provider.resources if r.resource_name == resource_name][0]
async with AsyncClient() as client: async with AsyncClient() as client:
@ -129,11 +132,9 @@ async def get_auth_provider_resource(
# Only a demo, real application would really process the response # Only a demo, real application would really process the response
resp_length = len(resp.text) resp_length = len(resp.text)
if resp_length > 1024: if resp_length > 1024:
return ProcessResult( return ProcessResult(msg=f"The resource is too long ({resp_length} bytes) to show here")
result={"msg": f"The resource is too long ({resp_length} bytes) to show here"}
)
else: else:
return ProcessResult(result=resp.json()) return ProcessResult(**resp.json())
# @resource_server.get("/public") # @resource_server.get("/public")