Fix public resource access; free resource response validation; formatting
All checks were successful
/ build (push) Successful in 5s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-02-14 13:21:55 +01:00
parent 40ddb61636
commit c89ca4098b
4 changed files with 84 additions and 77 deletions

View file

@ -87,6 +87,7 @@ def init_providers():
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
@ -125,7 +126,7 @@ async def get_current_user(request: Request) -> User:
"""
if (user_sub := request.session.get("user_sub")) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
token = await get_token(request)
token = await get_token_from_session(request)
user = await db.get_user(user_sub)
## Check if the token is expired
if token.is_expired():
@ -146,16 +147,16 @@ async def get_current_user(request: Request) -> User:
return user
async def get_token_or_none(request: Request) -> OAuth2Token | None:
async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None:
"""Return the auth token from the session or None.
Can be used in Depends()"""
try:
return await get_token(request)
return await get_token_from_session(request)
except HTTPException:
return None
async def get_token(request: Request) -> OAuth2Token:
async def get_token_from_session(request: Request) -> OAuth2Token:
"""Return the token from the session.
Can be used in Depends()"""
try:
@ -273,15 +274,19 @@ async def get_user_from_token(
)
return user
async def get_user_from_token_or_none(
token: Annotated[str, Depends(oauth2_scheme)],
token: Annotated[str | None, Depends(oauth2_scheme_optional)],
request: Request,
) -> User | None:
if token is None:
return None
try:
return await get_user_from_token(token, request)
except HTTPException:
return None
class UserWithRole:
roles: set[str]

View file

@ -33,8 +33,8 @@ from oidc_test.auth.utils import (
get_auth_provider_or_none,
get_current_user_or_none,
authlib_oauth,
get_token_or_none,
get_token,
get_token_from_session_or_none,
get_token_from_session,
update_token,
)
from oidc_test.auth.utils import init_providers
@ -88,9 +88,9 @@ app.mount("/resource", resource_server, name="resource_server")
@app.get("/")
async def home(
request: Request,
user: Annotated[User, Depends(get_current_user_or_none)],
user: Annotated[User | None, Depends(get_current_user_or_none)],
provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)],
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)],
) -> HTMLResponse:
context = {
"show_token": settings.show_token,
@ -291,7 +291,7 @@ async def non_compliant_logout(
async def refresh(
request: Request,
provider: Annotated[Provider, Depends(get_auth_provider)],
token: Annotated[OAuth2Token, Depends(get_token)],
token: Annotated[OAuth2Token, Depends(get_token_from_session)],
) -> RedirectResponse:
"""Manually refresh token"""
new_token = await provider.authlib_client.fetch_access_token(

View file

@ -1,8 +1,7 @@
from importlib.metadata import entry_points
import logging
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from oidc_test.models import User
@ -10,7 +9,9 @@ logger = logging.getLogger("registry")
class ProcessResult(BaseModel):
result: dict[str, Any] = {}
model_config = ConfigDict(
extra="allow",
)
class ProcessError(Exception):
@ -28,7 +29,7 @@ class ResourceProvider(BaseModel):
super().__init__()
self.__id__ = name
async def process(self, user: User, resource_id: str | None = None) -> ProcessResult:
async def process(self, user: User | None, resource_id: str | None = None) -> ProcessResult:
logger.warning(f"{self.__id__} should define a process method")
return ProcessResult()

View file

@ -13,15 +13,13 @@ from fastapi.middleware.cors import CORSMiddleware
from oidc_test.auth.provider import Provider
from oidc_test.auth.utils import (
get_token_or_none,
get_user_from_token,
UserWithRole,
get_user_from_token_or_none,
oauth2_scheme_optional,
)
from oidc_test.auth_providers import providers
from oidc_test.settings import settings
from oidc_test.models import User
from oidc_test.registry import ProcessError, ProcessResult, ResourceProvider, registry
from oidc_test.registry import ProcessError, ProcessResult, registry
logger = logging.getLogger("oidc-test")
@ -50,60 +48,67 @@ resource_server.add_middleware(
@resource_server.get("/")
async def resources() -> dict[str, dict[str, Any]]:
return {
"internal": {},
"plugins": registry.resource_providers
}
return {"internal": {}, "plugins": registry.resource_providers}
@resource_server.get("/{resource_name}")
@resource_server.get("/{resource_name}/{resource_id}")
async def get_resource(
resource_name: str,
user: Annotated[User, Depends(get_user_from_token_or_none)],
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
request: Request,
user: Annotated[User | None, Depends(get_user_from_token_or_none)],
token: Annotated[OAuth2Token | None, Depends(oauth2_scheme_optional)],
resource_id: str | None = None,
) -> ProcessResult:
"""Generic path for testing a resource provided by a provider"""
provider = providers[user.auth_provider_id]
# Third party resource (provided through the auth provider)
# The token is just passed on
if resource_name in [r.resource_name for r in provider.resources]:
return await get_auth_provider_resource(
provider=provider,
resource_name=resource_name,
token=token,
user=user,
)
):
"""Generic path for testing a resource provided by a provider.
There's no field validation (response type of ProcessResult) on purpose,
leaving the responsibility of the response validation to resource providers"""
# Get the resource if it's defined in user auth provider's resources (external)
if user is not None:
provider = providers[user.auth_provider_id]
# Third party resource (provided through the auth provider)
# The token is just passed on
if resource_name in [r.resource_name for r in provider.resources]:
return await get_auth_provider_resource(
provider=provider,
resource_name=resource_name,
token=token,
user=user,
)
# Internal resource (provided here)
if resource_name in registry.resource_providers:
resource_provider = registry.resource_providers[resource_name]
reasons: dict[str, str] = {}
reason: dict[str, str] = {}
if not resource_provider.is_public:
if resource_provider.scope_required is not None and not user.has_scope(
resource_provider.scope_required
):
reasons["scope"] = f"No scope {resource_provider.scope_required} in the access token " \
"but it is required for accessing this resource"
if resource_provider.role_required is not None \
and resource_provider.role_required not in user.roles_as_set:
reasons["role"] = f"You don't have the role {resource_provider.role_required} " \
"but it is required for accessing this resource"
if len(reasons) == 0:
if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public")
else:
if resource_provider.scope_required is not None and not user.has_scope(
resource_provider.scope_required
):
reason["scope"] = (
f"No scope {resource_provider.scope_required} in the access token "
"but it is required for accessing this resource"
)
if (
resource_provider.role_required is not None
and resource_provider.role_required not in user.roles_as_set
):
reason["role"] = (
f"You don't have the role {resource_provider.role_required} "
"but it is required for accessing this resource"
)
if len(reason) == 0:
try:
return await resource_provider.process(user=user, resource_id=resource_id)
resp = await resource_provider.process(user=user, resource_id=resource_id)
return resp
except ProcessError as err:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}"
)
else:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, ", ".join(reasons.values())
)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, ", ".join(reason.values()))
else:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}")
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Unknown resource")
# return await get_resource_(resource_name, user, **request.query_params)
@ -111,9 +116,7 @@ async def get_auth_provider_resource(
provider: Provider, resource_name: str, token: OAuth2Token | None, user: User
) -> ProcessResult:
if token is None:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"No auth token"
)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"No auth token")
access_token = token["access_token"]
resource = [r for r in provider.resources if r.resource_name == resource_name][0]
async with AsyncClient() as client:
@ -129,52 +132,50 @@ async def get_auth_provider_resource(
# Only a demo, real application would really process the response
resp_length = len(resp.text)
if resp_length > 1024:
return ProcessResult(
result={"msg": f"The resource is too long ({resp_length} bytes) to show here"}
)
return ProcessResult(msg=f"The resource is too long ({resp_length} bytes) to show here")
else:
return ProcessResult(result=resp.json())
return ProcessResult(**resp.json())
#@resource_server.get("/public")
#async def public() -> dict:
# @resource_server.get("/public")
# async def public() -> dict:
# return {"msg": "Not protected"}
#
#
#@resource_server.get("/protected")
#async def get_protected(user: Annotated[User, Depends(get_user_from_token)]):
# @resource_server.get("/protected")
# async def get_protected(user: Annotated[User, Depends(get_user_from_token)]):
# assert user is not None # Just to keep QA checks happy
# return {"msg": "Only authenticated users can see this"}
#
#
#@resource_server.get("/protected-by-foorole")
#async def get_protected_by_foorole(
# @resource_server.get("/protected-by-foorole")
# async def get_protected_by_foorole(
# user: Annotated[User, Depends(UserWithRole("foorole"))],
#):
# ):
# assert user is not None
# return {"msg": "Only users with foorole can see this"}
#
#
#@resource_server.get("/protected-by-barrole")
#async def get_protected_by_barrole(
# @resource_server.get("/protected-by-barrole")
# async def get_protected_by_barrole(
# user: Annotated[User, Depends(UserWithRole("barrole"))],
#):
# ):
# assert user is not None
# return {"msg": "Protected by barrole"}
#
#
#@resource_server.get("/protected-by-foorole-and-barrole")
#async def get_protected_by_foorole_and_barrole(
# @resource_server.get("/protected-by-foorole-and-barrole")
# async def get_protected_by_foorole_and_barrole(
# user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))],
#):
# ):
# assert user is not None # Just to keep QA checks happy
# return {"msg": "Only users with foorole and barrole can see this"}
#
#
#@resource_server.get("/protected-by-foorole-or-barrole")
#async def get_protected_by_foorole_or_barrole(
# @resource_server.get("/protected-by-foorole-or-barrole")
# async def get_protected_by_foorole_or_barrole(
# user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))],
#):
# ):
# assert user is not None # Just to keep QA checks happy
# return {"msg": "Only users with foorole or barrole can see this"}