Fix public resource access; free resource response validation; formatting
This commit is contained in:
parent
40ddb61636
commit
c89ca4098b
4 changed files with 84 additions and 77 deletions
|
@ -87,6 +87,7 @@ def init_providers():
|
||||||
|
|
||||||
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
||||||
|
|
||||||
|
|
||||||
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
|
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
|
||||||
|
@ -125,7 +126,7 @@ async def get_current_user(request: Request) -> User:
|
||||||
"""
|
"""
|
||||||
if (user_sub := request.session.get("user_sub")) is None:
|
if (user_sub := request.session.get("user_sub")) is None:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||||
token = await get_token(request)
|
token = await get_token_from_session(request)
|
||||||
user = await db.get_user(user_sub)
|
user = await db.get_user(user_sub)
|
||||||
## Check if the token is expired
|
## Check if the token is expired
|
||||||
if token.is_expired():
|
if token.is_expired():
|
||||||
|
@ -146,16 +147,16 @@ async def get_current_user(request: Request) -> User:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def get_token_or_none(request: Request) -> OAuth2Token | None:
|
async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None:
|
||||||
"""Return the auth token from the session or None.
|
"""Return the auth token from the session or None.
|
||||||
Can be used in Depends()"""
|
Can be used in Depends()"""
|
||||||
try:
|
try:
|
||||||
return await get_token(request)
|
return await get_token_from_session(request)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def get_token(request: Request) -> OAuth2Token:
|
async def get_token_from_session(request: Request) -> OAuth2Token:
|
||||||
"""Return the token from the session.
|
"""Return the token from the session.
|
||||||
Can be used in Depends()"""
|
Can be used in Depends()"""
|
||||||
try:
|
try:
|
||||||
|
@ -273,15 +274,19 @@ async def get_user_from_token(
|
||||||
)
|
)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
async def get_user_from_token_or_none(
|
async def get_user_from_token_or_none(
|
||||||
token: Annotated[str, Depends(oauth2_scheme)],
|
token: Annotated[str | None, Depends(oauth2_scheme_optional)],
|
||||||
request: Request,
|
request: Request,
|
||||||
) -> User | None:
|
) -> User | None:
|
||||||
|
if token is None:
|
||||||
|
return None
|
||||||
try:
|
try:
|
||||||
return await get_user_from_token(token, request)
|
return await get_user_from_token(token, request)
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class UserWithRole:
|
class UserWithRole:
|
||||||
roles: set[str]
|
roles: set[str]
|
||||||
|
|
||||||
|
|
|
@ -33,8 +33,8 @@ from oidc_test.auth.utils import (
|
||||||
get_auth_provider_or_none,
|
get_auth_provider_or_none,
|
||||||
get_current_user_or_none,
|
get_current_user_or_none,
|
||||||
authlib_oauth,
|
authlib_oauth,
|
||||||
get_token_or_none,
|
get_token_from_session_or_none,
|
||||||
get_token,
|
get_token_from_session,
|
||||||
update_token,
|
update_token,
|
||||||
)
|
)
|
||||||
from oidc_test.auth.utils import init_providers
|
from oidc_test.auth.utils import init_providers
|
||||||
|
@ -88,9 +88,9 @@ app.mount("/resource", resource_server, name="resource_server")
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
async def home(
|
async def home(
|
||||||
request: Request,
|
request: Request,
|
||||||
user: Annotated[User, Depends(get_current_user_or_none)],
|
user: Annotated[User | None, Depends(get_current_user_or_none)],
|
||||||
provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)],
|
provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)],
|
||||||
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
|
token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)],
|
||||||
) -> HTMLResponse:
|
) -> HTMLResponse:
|
||||||
context = {
|
context = {
|
||||||
"show_token": settings.show_token,
|
"show_token": settings.show_token,
|
||||||
|
@ -291,7 +291,7 @@ async def non_compliant_logout(
|
||||||
async def refresh(
|
async def refresh(
|
||||||
request: Request,
|
request: Request,
|
||||||
provider: Annotated[Provider, Depends(get_auth_provider)],
|
provider: Annotated[Provider, Depends(get_auth_provider)],
|
||||||
token: Annotated[OAuth2Token, Depends(get_token)],
|
token: Annotated[OAuth2Token, Depends(get_token_from_session)],
|
||||||
) -> RedirectResponse:
|
) -> RedirectResponse:
|
||||||
"""Manually refresh token"""
|
"""Manually refresh token"""
|
||||||
new_token = await provider.authlib_client.fetch_access_token(
|
new_token = await provider.authlib_client.fetch_access_token(
|
||||||
|
|
|
@ -1,8 +1,7 @@
|
||||||
from importlib.metadata import entry_points
|
from importlib.metadata import entry_points
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from oidc_test.models import User
|
from oidc_test.models import User
|
||||||
|
|
||||||
|
@ -10,7 +9,9 @@ logger = logging.getLogger("registry")
|
||||||
|
|
||||||
|
|
||||||
class ProcessResult(BaseModel):
|
class ProcessResult(BaseModel):
|
||||||
result: dict[str, Any] = {}
|
model_config = ConfigDict(
|
||||||
|
extra="allow",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ProcessError(Exception):
|
class ProcessError(Exception):
|
||||||
|
@ -28,7 +29,7 @@ class ResourceProvider(BaseModel):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.__id__ = name
|
self.__id__ = name
|
||||||
|
|
||||||
async def process(self, user: User, resource_id: str | None = None) -> ProcessResult:
|
async def process(self, user: User | None, resource_id: str | None = None) -> ProcessResult:
|
||||||
logger.warning(f"{self.__id__} should define a process method")
|
logger.warning(f"{self.__id__} should define a process method")
|
||||||
return ProcessResult()
|
return ProcessResult()
|
||||||
|
|
||||||
|
|
|
@ -13,15 +13,13 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from oidc_test.auth.provider import Provider
|
from oidc_test.auth.provider import Provider
|
||||||
from oidc_test.auth.utils import (
|
from oidc_test.auth.utils import (
|
||||||
get_token_or_none,
|
|
||||||
get_user_from_token,
|
|
||||||
UserWithRole,
|
|
||||||
get_user_from_token_or_none,
|
get_user_from_token_or_none,
|
||||||
|
oauth2_scheme_optional,
|
||||||
)
|
)
|
||||||
from oidc_test.auth_providers import providers
|
from oidc_test.auth_providers import providers
|
||||||
from oidc_test.settings import settings
|
from oidc_test.settings import settings
|
||||||
from oidc_test.models import User
|
from oidc_test.models import User
|
||||||
from oidc_test.registry import ProcessError, ProcessResult, ResourceProvider, registry
|
from oidc_test.registry import ProcessError, ProcessResult, registry
|
||||||
|
|
||||||
logger = logging.getLogger("oidc-test")
|
logger = logging.getLogger("oidc-test")
|
||||||
|
|
||||||
|
@ -50,23 +48,22 @@ resource_server.add_middleware(
|
||||||
|
|
||||||
@resource_server.get("/")
|
@resource_server.get("/")
|
||||||
async def resources() -> dict[str, dict[str, Any]]:
|
async def resources() -> dict[str, dict[str, Any]]:
|
||||||
return {
|
return {"internal": {}, "plugins": registry.resource_providers}
|
||||||
"internal": {},
|
|
||||||
"plugins": registry.resource_providers
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@resource_server.get("/{resource_name}")
|
@resource_server.get("/{resource_name}")
|
||||||
@resource_server.get("/{resource_name}/{resource_id}")
|
@resource_server.get("/{resource_name}/{resource_id}")
|
||||||
async def get_resource(
|
async def get_resource(
|
||||||
resource_name: str,
|
resource_name: str,
|
||||||
user: Annotated[User, Depends(get_user_from_token_or_none)],
|
user: Annotated[User | None, Depends(get_user_from_token_or_none)],
|
||||||
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
|
token: Annotated[OAuth2Token | None, Depends(oauth2_scheme_optional)],
|
||||||
request: Request,
|
|
||||||
resource_id: str | None = None,
|
resource_id: str | None = None,
|
||||||
) -> ProcessResult:
|
):
|
||||||
"""Generic path for testing a resource provided by a provider"""
|
"""Generic path for testing a resource provided by a provider.
|
||||||
|
There's no field validation (response type of ProcessResult) on purpose,
|
||||||
|
leaving the responsibility of the response validation to resource providers"""
|
||||||
|
# Get the resource if it's defined in user auth provider's resources (external)
|
||||||
|
if user is not None:
|
||||||
provider = providers[user.auth_provider_id]
|
provider = providers[user.auth_provider_id]
|
||||||
# Third party resource (provided through the auth provider)
|
# Third party resource (provided through the auth provider)
|
||||||
# The token is just passed on
|
# The token is just passed on
|
||||||
|
@ -80,30 +77,38 @@ async def get_resource(
|
||||||
# Internal resource (provided here)
|
# Internal resource (provided here)
|
||||||
if resource_name in registry.resource_providers:
|
if resource_name in registry.resource_providers:
|
||||||
resource_provider = registry.resource_providers[resource_name]
|
resource_provider = registry.resource_providers[resource_name]
|
||||||
reasons: dict[str, str] = {}
|
reason: dict[str, str] = {}
|
||||||
if not resource_provider.is_public:
|
if not resource_provider.is_public:
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public")
|
||||||
|
else:
|
||||||
if resource_provider.scope_required is not None and not user.has_scope(
|
if resource_provider.scope_required is not None and not user.has_scope(
|
||||||
resource_provider.scope_required
|
resource_provider.scope_required
|
||||||
):
|
):
|
||||||
reasons["scope"] = f"No scope {resource_provider.scope_required} in the access token " \
|
reason["scope"] = (
|
||||||
|
f"No scope {resource_provider.scope_required} in the access token "
|
||||||
"but it is required for accessing this resource"
|
"but it is required for accessing this resource"
|
||||||
if resource_provider.role_required is not None \
|
)
|
||||||
and resource_provider.role_required not in user.roles_as_set:
|
if (
|
||||||
reasons["role"] = f"You don't have the role {resource_provider.role_required} " \
|
resource_provider.role_required is not None
|
||||||
|
and resource_provider.role_required not in user.roles_as_set
|
||||||
|
):
|
||||||
|
reason["role"] = (
|
||||||
|
f"You don't have the role {resource_provider.role_required} "
|
||||||
"but it is required for accessing this resource"
|
"but it is required for accessing this resource"
|
||||||
if len(reasons) == 0:
|
)
|
||||||
|
if len(reason) == 0:
|
||||||
try:
|
try:
|
||||||
return await resource_provider.process(user=user, resource_id=resource_id)
|
resp = await resource_provider.process(user=user, resource_id=resource_id)
|
||||||
|
return resp
|
||||||
except ProcessError as err:
|
except ProcessError as err:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}"
|
status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, ", ".join(reason.values()))
|
||||||
status.HTTP_401_UNAUTHORIZED, ", ".join(reasons.values())
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Unknown resource")
|
||||||
# return await get_resource_(resource_name, user, **request.query_params)
|
# return await get_resource_(resource_name, user, **request.query_params)
|
||||||
|
|
||||||
|
|
||||||
|
@ -111,9 +116,7 @@ async def get_auth_provider_resource(
|
||||||
provider: Provider, resource_name: str, token: OAuth2Token | None, user: User
|
provider: Provider, resource_name: str, token: OAuth2Token | None, user: User
|
||||||
) -> ProcessResult:
|
) -> ProcessResult:
|
||||||
if token is None:
|
if token is None:
|
||||||
raise HTTPException(
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"No auth token")
|
||||||
status.HTTP_401_UNAUTHORIZED, f"No auth token"
|
|
||||||
)
|
|
||||||
access_token = token["access_token"]
|
access_token = token["access_token"]
|
||||||
resource = [r for r in provider.resources if r.resource_name == resource_name][0]
|
resource = [r for r in provider.resources if r.resource_name == resource_name][0]
|
||||||
async with AsyncClient() as client:
|
async with AsyncClient() as client:
|
||||||
|
@ -129,11 +132,9 @@ async def get_auth_provider_resource(
|
||||||
# Only a demo, real application would really process the response
|
# Only a demo, real application would really process the response
|
||||||
resp_length = len(resp.text)
|
resp_length = len(resp.text)
|
||||||
if resp_length > 1024:
|
if resp_length > 1024:
|
||||||
return ProcessResult(
|
return ProcessResult(msg=f"The resource is too long ({resp_length} bytes) to show here")
|
||||||
result={"msg": f"The resource is too long ({resp_length} bytes) to show here"}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return ProcessResult(result=resp.json())
|
return ProcessResult(**resp.json())
|
||||||
|
|
||||||
|
|
||||||
# @resource_server.get("/public")
|
# @resource_server.get("/public")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue