Fix public resource access; free resource response validation; formatting
This commit is contained in:
parent
40ddb61636
commit
c89ca4098b
4 changed files with 84 additions and 77 deletions
|
@ -87,6 +87,7 @@ def init_providers():
|
|||
|
||||
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
||||
|
||||
|
||||
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
|
||||
|
@ -125,7 +126,7 @@ async def get_current_user(request: Request) -> User:
|
|||
"""
|
||||
if (user_sub := request.session.get("user_sub")) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
token = await get_token(request)
|
||||
token = await get_token_from_session(request)
|
||||
user = await db.get_user(user_sub)
|
||||
## Check if the token is expired
|
||||
if token.is_expired():
|
||||
|
@ -146,16 +147,16 @@ async def get_current_user(request: Request) -> User:
|
|||
return user
|
||||
|
||||
|
||||
async def get_token_or_none(request: Request) -> OAuth2Token | None:
|
||||
async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None:
|
||||
"""Return the auth token from the session or None.
|
||||
Can be used in Depends()"""
|
||||
try:
|
||||
return await get_token(request)
|
||||
return await get_token_from_session(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
async def get_token(request: Request) -> OAuth2Token:
|
||||
async def get_token_from_session(request: Request) -> OAuth2Token:
|
||||
"""Return the token from the session.
|
||||
Can be used in Depends()"""
|
||||
try:
|
||||
|
@ -273,15 +274,19 @@ async def get_user_from_token(
|
|||
)
|
||||
return user
|
||||
|
||||
|
||||
async def get_user_from_token_or_none(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
token: Annotated[str | None, Depends(oauth2_scheme_optional)],
|
||||
request: Request,
|
||||
) -> User | None:
|
||||
if token is None:
|
||||
return None
|
||||
try:
|
||||
return await get_user_from_token(token, request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
class UserWithRole:
|
||||
roles: set[str]
|
||||
|
||||
|
|
|
@ -33,8 +33,8 @@ from oidc_test.auth.utils import (
|
|||
get_auth_provider_or_none,
|
||||
get_current_user_or_none,
|
||||
authlib_oauth,
|
||||
get_token_or_none,
|
||||
get_token,
|
||||
get_token_from_session_or_none,
|
||||
get_token_from_session,
|
||||
update_token,
|
||||
)
|
||||
from oidc_test.auth.utils import init_providers
|
||||
|
@ -88,9 +88,9 @@ app.mount("/resource", resource_server, name="resource_server")
|
|||
@app.get("/")
|
||||
async def home(
|
||||
request: Request,
|
||||
user: Annotated[User, Depends(get_current_user_or_none)],
|
||||
user: Annotated[User | None, Depends(get_current_user_or_none)],
|
||||
provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)],
|
||||
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
|
||||
token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)],
|
||||
) -> HTMLResponse:
|
||||
context = {
|
||||
"show_token": settings.show_token,
|
||||
|
@ -291,7 +291,7 @@ async def non_compliant_logout(
|
|||
async def refresh(
|
||||
request: Request,
|
||||
provider: Annotated[Provider, Depends(get_auth_provider)],
|
||||
token: Annotated[OAuth2Token, Depends(get_token)],
|
||||
token: Annotated[OAuth2Token, Depends(get_token_from_session)],
|
||||
) -> RedirectResponse:
|
||||
"""Manually refresh token"""
|
||||
new_token = await provider.authlib_client.fetch_access_token(
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
from importlib.metadata import entry_points
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from oidc_test.models import User
|
||||
|
||||
|
@ -10,7 +9,9 @@ logger = logging.getLogger("registry")
|
|||
|
||||
|
||||
class ProcessResult(BaseModel):
|
||||
result: dict[str, Any] = {}
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
)
|
||||
|
||||
|
||||
class ProcessError(Exception):
|
||||
|
@ -28,7 +29,7 @@ class ResourceProvider(BaseModel):
|
|||
super().__init__()
|
||||
self.__id__ = name
|
||||
|
||||
async def process(self, user: User, resource_id: str | None = None) -> ProcessResult:
|
||||
async def process(self, user: User | None, resource_id: str | None = None) -> ProcessResult:
|
||||
logger.warning(f"{self.__id__} should define a process method")
|
||||
return ProcessResult()
|
||||
|
||||
|
|
|
@ -13,15 +13,13 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
|
||||
from oidc_test.auth.provider import Provider
|
||||
from oidc_test.auth.utils import (
|
||||
get_token_or_none,
|
||||
get_user_from_token,
|
||||
UserWithRole,
|
||||
get_user_from_token_or_none,
|
||||
oauth2_scheme_optional,
|
||||
)
|
||||
from oidc_test.auth_providers import providers
|
||||
from oidc_test.settings import settings
|
||||
from oidc_test.models import User
|
||||
from oidc_test.registry import ProcessError, ProcessResult, ResourceProvider, registry
|
||||
from oidc_test.registry import ProcessError, ProcessResult, registry
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
@ -50,60 +48,67 @@ resource_server.add_middleware(
|
|||
|
||||
@resource_server.get("/")
|
||||
async def resources() -> dict[str, dict[str, Any]]:
|
||||
return {
|
||||
"internal": {},
|
||||
"plugins": registry.resource_providers
|
||||
}
|
||||
|
||||
return {"internal": {}, "plugins": registry.resource_providers}
|
||||
|
||||
|
||||
@resource_server.get("/{resource_name}")
|
||||
@resource_server.get("/{resource_name}/{resource_id}")
|
||||
async def get_resource(
|
||||
resource_name: str,
|
||||
user: Annotated[User, Depends(get_user_from_token_or_none)],
|
||||
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
|
||||
request: Request,
|
||||
user: Annotated[User | None, Depends(get_user_from_token_or_none)],
|
||||
token: Annotated[OAuth2Token | None, Depends(oauth2_scheme_optional)],
|
||||
resource_id: str | None = None,
|
||||
) -> ProcessResult:
|
||||
"""Generic path for testing a resource provided by a provider"""
|
||||
provider = providers[user.auth_provider_id]
|
||||
# Third party resource (provided through the auth provider)
|
||||
# The token is just passed on
|
||||
if resource_name in [r.resource_name for r in provider.resources]:
|
||||
return await get_auth_provider_resource(
|
||||
provider=provider,
|
||||
resource_name=resource_name,
|
||||
token=token,
|
||||
user=user,
|
||||
)
|
||||
):
|
||||
"""Generic path for testing a resource provided by a provider.
|
||||
There's no field validation (response type of ProcessResult) on purpose,
|
||||
leaving the responsibility of the response validation to resource providers"""
|
||||
# Get the resource if it's defined in user auth provider's resources (external)
|
||||
if user is not None:
|
||||
provider = providers[user.auth_provider_id]
|
||||
# Third party resource (provided through the auth provider)
|
||||
# The token is just passed on
|
||||
if resource_name in [r.resource_name for r in provider.resources]:
|
||||
return await get_auth_provider_resource(
|
||||
provider=provider,
|
||||
resource_name=resource_name,
|
||||
token=token,
|
||||
user=user,
|
||||
)
|
||||
# Internal resource (provided here)
|
||||
if resource_name in registry.resource_providers:
|
||||
resource_provider = registry.resource_providers[resource_name]
|
||||
reasons: dict[str, str] = {}
|
||||
reason: dict[str, str] = {}
|
||||
if not resource_provider.is_public:
|
||||
if resource_provider.scope_required is not None and not user.has_scope(
|
||||
resource_provider.scope_required
|
||||
):
|
||||
reasons["scope"] = f"No scope {resource_provider.scope_required} in the access token " \
|
||||
"but it is required for accessing this resource"
|
||||
if resource_provider.role_required is not None \
|
||||
and resource_provider.role_required not in user.roles_as_set:
|
||||
reasons["role"] = f"You don't have the role {resource_provider.role_required} " \
|
||||
"but it is required for accessing this resource"
|
||||
if len(reasons) == 0:
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public")
|
||||
else:
|
||||
if resource_provider.scope_required is not None and not user.has_scope(
|
||||
resource_provider.scope_required
|
||||
):
|
||||
reason["scope"] = (
|
||||
f"No scope {resource_provider.scope_required} in the access token "
|
||||
"but it is required for accessing this resource"
|
||||
)
|
||||
if (
|
||||
resource_provider.role_required is not None
|
||||
and resource_provider.role_required not in user.roles_as_set
|
||||
):
|
||||
reason["role"] = (
|
||||
f"You don't have the role {resource_provider.role_required} "
|
||||
"but it is required for accessing this resource"
|
||||
)
|
||||
if len(reason) == 0:
|
||||
try:
|
||||
return await resource_provider.process(user=user, resource_id=resource_id)
|
||||
resp = await resource_provider.process(user=user, resource_id=resource_id)
|
||||
return resp
|
||||
except ProcessError as err:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, ", ".join(reasons.values())
|
||||
)
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, ", ".join(reason.values()))
|
||||
else:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}")
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Unknown resource")
|
||||
# return await get_resource_(resource_name, user, **request.query_params)
|
||||
|
||||
|
||||
|
@ -111,9 +116,7 @@ async def get_auth_provider_resource(
|
|||
provider: Provider, resource_name: str, token: OAuth2Token | None, user: User
|
||||
) -> ProcessResult:
|
||||
if token is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, f"No auth token"
|
||||
)
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"No auth token")
|
||||
access_token = token["access_token"]
|
||||
resource = [r for r in provider.resources if r.resource_name == resource_name][0]
|
||||
async with AsyncClient() as client:
|
||||
|
@ -129,52 +132,50 @@ async def get_auth_provider_resource(
|
|||
# Only a demo, real application would really process the response
|
||||
resp_length = len(resp.text)
|
||||
if resp_length > 1024:
|
||||
return ProcessResult(
|
||||
result={"msg": f"The resource is too long ({resp_length} bytes) to show here"}
|
||||
)
|
||||
return ProcessResult(msg=f"The resource is too long ({resp_length} bytes) to show here")
|
||||
else:
|
||||
return ProcessResult(result=resp.json())
|
||||
return ProcessResult(**resp.json())
|
||||
|
||||
|
||||
#@resource_server.get("/public")
|
||||
#async def public() -> dict:
|
||||
# @resource_server.get("/public")
|
||||
# async def public() -> dict:
|
||||
# return {"msg": "Not protected"}
|
||||
#
|
||||
#
|
||||
#@resource_server.get("/protected")
|
||||
#async def get_protected(user: Annotated[User, Depends(get_user_from_token)]):
|
||||
# @resource_server.get("/protected")
|
||||
# async def get_protected(user: Annotated[User, Depends(get_user_from_token)]):
|
||||
# assert user is not None # Just to keep QA checks happy
|
||||
# return {"msg": "Only authenticated users can see this"}
|
||||
#
|
||||
#
|
||||
#@resource_server.get("/protected-by-foorole")
|
||||
#async def get_protected_by_foorole(
|
||||
# @resource_server.get("/protected-by-foorole")
|
||||
# async def get_protected_by_foorole(
|
||||
# user: Annotated[User, Depends(UserWithRole("foorole"))],
|
||||
#):
|
||||
# ):
|
||||
# assert user is not None
|
||||
# return {"msg": "Only users with foorole can see this"}
|
||||
#
|
||||
#
|
||||
#@resource_server.get("/protected-by-barrole")
|
||||
#async def get_protected_by_barrole(
|
||||
# @resource_server.get("/protected-by-barrole")
|
||||
# async def get_protected_by_barrole(
|
||||
# user: Annotated[User, Depends(UserWithRole("barrole"))],
|
||||
#):
|
||||
# ):
|
||||
# assert user is not None
|
||||
# return {"msg": "Protected by barrole"}
|
||||
#
|
||||
#
|
||||
#@resource_server.get("/protected-by-foorole-and-barrole")
|
||||
#async def get_protected_by_foorole_and_barrole(
|
||||
# @resource_server.get("/protected-by-foorole-and-barrole")
|
||||
# async def get_protected_by_foorole_and_barrole(
|
||||
# user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))],
|
||||
#):
|
||||
# ):
|
||||
# assert user is not None # Just to keep QA checks happy
|
||||
# return {"msg": "Only users with foorole and barrole can see this"}
|
||||
#
|
||||
#
|
||||
#@resource_server.get("/protected-by-foorole-or-barrole")
|
||||
#async def get_protected_by_foorole_or_barrole(
|
||||
# @resource_server.get("/protected-by-foorole-or-barrole")
|
||||
# async def get_protected_by_foorole_or_barrole(
|
||||
# user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))],
|
||||
#):
|
||||
# ):
|
||||
# assert user is not None # Just to keep QA checks happy
|
||||
# return {"msg": "Only users with foorole or barrole can see this"}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue