Add role protection to resource servers, remove hardcoded resources

This commit is contained in:
phil 2025-02-13 18:15:26 +01:00
parent 381ce1ebc1
commit 9d3146dc1c
7 changed files with 127 additions and 110 deletions

View file

@ -8,7 +8,7 @@ from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOA
from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError
# from authlib.oauth1.auth import OAuthToken # from authlib.oauth1.auth import OAuthToken
from authlib.oauth2.auth import OAuth2Token from authlib.oauth2.rfc6749 import OAuth2Token
from oidc_test.auth.provider import Provider from oidc_test.auth.provider import Provider
from oidc_test.models import User from oidc_test.models import User
@ -60,25 +60,28 @@ def init_providers():
sub="", auth_provider_id=provider_settings.id sub="", auth_provider_id=provider_settings.id
) )
provider = Provider(**provider_settings_dict) provider = Provider(**provider_settings_dict)
authlib_oauth.register( if provider.disabled:
name=provider.id, logger.info(f"{provider_settings.name} is disabled, skipping")
server_metadata_url=provider.openid_configuration, else:
client_kwargs={ authlib_oauth.register(
"scope": " ".join( name=provider.id,
["openid", "email", "offline_access", "profile"] server_metadata_url=provider.openid_configuration,
+ provider.resource_provider_scopes client_kwargs={
), "scope": " ".join(
}, ["openid", "email", "offline_access", "profile"]
client_id=provider.client_id, + provider.resource_provider_scopes
client_secret=provider.client_secret, ),
api_base_url=provider.url, },
# For PKCE (not implemented yet): client_id=provider.client_id,
# code_challenge_method="S256", client_secret=provider.client_secret,
fetch_token=fetch_token, api_base_url=provider.url,
update_token=update_token, # For PKCE (not implemented yet):
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) # code_challenge_method="S256",
) fetch_token=fetch_token,
provider.authlib_client = getattr(authlib_oauth, provider.id) update_token=update_token,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
)
provider.authlib_client = getattr(authlib_oauth, provider.id)
providers[provider.id] = provider providers[provider.id] = provider
@ -270,6 +273,14 @@ async def get_user_from_token(
) )
return user return user
async def get_user_from_token_or_none(
token: Annotated[str, Depends(oauth2_scheme)],
request: Request,
) -> User | None:
try:
return await get_user_from_token(token, request)
except HTTPException:
return None
class UserWithRole: class UserWithRole:
roles: set[str] roles: set[str]

View file

@ -55,6 +55,8 @@ async def lifespan(app: FastAPI):
init_providers() init_providers()
registry.make_registry() registry.make_registry()
for provider in list(providers.values()): for provider in list(providers.values()):
if provider.disabled:
continue
try: try:
await provider.get_info() await provider.get_info()
except NoPublicKey: except NoPublicKey:
@ -106,7 +108,6 @@ async def home(
else: else:
context["access_token"] = token["access_token"] context["access_token"] = token["access_token"]
# XXX: resources defined externally? I am confused... # XXX: resources defined externally? I am confused...
context["resources"] = provider.resources
try: try:
access_token_parsed = provider.decode(token["access_token"], verify_signature=False) access_token_parsed = provider.decode(token["access_token"], verify_signature=False)
except PyJWTError as err: except PyJWTError as err:

View file

@ -18,15 +18,18 @@ class ProcessError(Exception):
class ResourceProvider(BaseModel): class ResourceProvider(BaseModel):
name: str
scope_required: str | None = None scope_required: str | None = None
role_required: str | None = None
is_public: bool = False
default_resource_id: str | None = None default_resource_id: str | None = None
def __init__(self, name: str): def __init__(self, name: str):
super().__init__() super().__init__()
self.__name__ = name self.__id__ = name
async def process(self, user: User, resource_id: str | None = None) -> ProcessResult: async def process(self, user: User, resource_id: str | None = None) -> ProcessResult:
logger.warning(f"{self.__name__} should define a process method") logger.warning(f"{self.__id__} should define a process method")
return ProcessResult() return ProcessResult()

View file

@ -1,7 +1,7 @@
from typing import Annotated, Any from typing import Annotated, Any
import logging import logging
from authlib.oauth2.auth import OAuth2Token from authlib.oauth2.rfc6749 import OAuth2Token
from httpx import AsyncClient from httpx import AsyncClient
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from fastapi import FastAPI, HTTPException, Depends, Request, status from fastapi import FastAPI, HTTPException, Depends, Request, status
@ -16,6 +16,7 @@ from oidc_test.auth.utils import (
get_token_or_none, get_token_or_none,
get_user_from_token, get_user_from_token,
UserWithRole, UserWithRole,
get_user_from_token_or_none,
) )
from oidc_test.auth_providers import providers from oidc_test.auth_providers import providers
from oidc_test.settings import settings from oidc_test.settings import settings
@ -55,54 +56,12 @@ async def resources() -> dict[str, dict[str, Any]]:
} }
@resource_server.get("/public")
async def public() -> dict:
return {"msg": "Not protected"}
@resource_server.get("/protected")
async def get_protected(user: Annotated[User, Depends(get_user_from_token)]):
assert user is not None # Just to keep QA checks happy
return {"msg": "Only authenticated users can see this"}
@resource_server.get("/protected-by-foorole")
async def get_protected_by_foorole(
user: Annotated[User, Depends(UserWithRole("foorole"))],
):
assert user is not None
return {"msg": "Only users with foorole can see this"}
@resource_server.get("/protected-by-barrole")
async def get_protected_by_barrole(
user: Annotated[User, Depends(UserWithRole("barrole"))],
):
assert user is not None
return {"msg": "Protected by barrole"}
@resource_server.get("/protected-by-foorole-and-barrole")
async def get_protected_by_foorole_and_barrole(
user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))],
):
assert user is not None # Just to keep QA checks happy
return {"msg": "Only users with foorole and barrole can see this"}
@resource_server.get("/protected-by-foorole-or-barrole")
async def get_protected_by_foorole_or_barrole(
user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))],
):
assert user is not None # Just to keep QA checks happy
return {"msg": "Only users with foorole or barrole can see this"}
@resource_server.get("/{resource_name}") @resource_server.get("/{resource_name}")
@resource_server.get("/{resource_name}/{resource_id}") @resource_server.get("/{resource_name}/{resource_id}")
async def get_resource( async def get_resource(
resource_name: str, resource_name: str,
user: Annotated[User, Depends(get_user_from_token)], user: Annotated[User, Depends(get_user_from_token_or_none)],
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
request: Request, request: Request,
resource_id: str | None = None, resource_id: str | None = None,
@ -111,19 +70,29 @@ async def get_resource(
provider = providers[user.auth_provider_id] provider = providers[user.auth_provider_id]
# Third party resource (provided through the auth provider) # Third party resource (provided through the auth provider)
# The token is just passed on # The token is just passed on
breakpoint()
if resource_name in [r.resource_name for r in provider.resources]: if resource_name in [r.resource_name for r in provider.resources]:
return await get_auth_provider_resource( return await get_auth_provider_resource(
provider=provider, provider=provider,
resource_name=resource_name, resource_name=resource_name,
access_token=token["access_token"] if token else None, token=token,
user=user, user=user,
) )
# Internal resource (provided here) # Internal resource (provided here)
if resource_name in registry.resource_providers: if resource_name in registry.resource_providers:
resource_provider = registry.resource_providers[resource_name] resource_provider = registry.resource_providers[resource_name]
if resource_provider.scope_required is not None and user.has_scope( reasons: dict[str, str] = {}
resource_provider.scope_required if not resource_provider.is_public:
): if resource_provider.scope_required is not None and not user.has_scope(
resource_provider.scope_required
):
reasons["scope"] = f"No scope {resource_provider.scope_required} in the access token " \
"but it is required for accessing this resource"
if resource_provider.role_required is not None \
and resource_provider.role_required not in user.roles_as_set:
reasons["role"] = f"You don't have the role {resource_provider.role_required} " \
"but it is required for accessing this resource"
if len(reasons) == 0:
try: try:
return await resource_provider.process(user=user, resource_id=resource_id) return await resource_provider.process(user=user, resource_id=resource_id)
except ProcessError as err: except ProcessError as err:
@ -132,9 +101,7 @@ async def get_resource(
) )
else: else:
raise HTTPException( raise HTTPException(
status.HTTP_401_UNAUTHORIZED, status.HTTP_401_UNAUTHORIZED, ", ".join(reasons.values())
f"No scope {resource_provider.scope_required} in the access token "
+ "but it is required for accessing this resource",
) )
else: else:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}") raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}")
@ -142,8 +109,13 @@ async def get_resource(
async def get_auth_provider_resource( async def get_auth_provider_resource(
provider: Provider, resource_name: str, access_token: str | None, user: User provider: Provider, resource_name: str, token: OAuth2Token | None, user: User
) -> ProcessResult: ) -> ProcessResult:
if token is None:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"No auth token"
)
access_token = token["access_token"]
resource = [r for r in provider.resources if r.resource_name == resource_name][0] resource = [r for r in provider.resources if r.resource_name == resource_name][0]
async with AsyncClient() as client: async with AsyncClient() as client:
resp = await client.get( resp = await client.get(
@ -165,6 +137,48 @@ async def get_auth_provider_resource(
return ProcessResult(result=resp.json()) return ProcessResult(result=resp.json())
#@resource_server.get("/public")
#async def public() -> dict:
# return {"msg": "Not protected"}
#
#
#@resource_server.get("/protected")
#async def get_protected(user: Annotated[User, Depends(get_user_from_token)]):
# assert user is not None # Just to keep QA checks happy
# return {"msg": "Only authenticated users can see this"}
#
#
#@resource_server.get("/protected-by-foorole")
#async def get_protected_by_foorole(
# user: Annotated[User, Depends(UserWithRole("foorole"))],
#):
# assert user is not None
# return {"msg": "Only users with foorole can see this"}
#
#
#@resource_server.get("/protected-by-barrole")
#async def get_protected_by_barrole(
# user: Annotated[User, Depends(UserWithRole("barrole"))],
#):
# assert user is not None
# return {"msg": "Protected by barrole"}
#
#
#@resource_server.get("/protected-by-foorole-and-barrole")
#async def get_protected_by_foorole_and_barrole(
# user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))],
#):
# assert user is not None # Just to keep QA checks happy
# return {"msg": "Only users with foorole and barrole can see this"}
#
#
#@resource_server.get("/protected-by-foorole-or-barrole")
#async def get_protected_by_foorole_or_barrole(
# user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))],
#):
# assert user is not None # Just to keep QA checks happy
# return {"msg": "Only users with foorole or barrole can see this"}
# async def get_resource_(resource_id: str, user: User, **kwargs) -> dict: # async def get_resource_(resource_id: str, user: User, **kwargs) -> dict:
# """ # """
# Resource processing: build an informative rely as a simple showcase # Resource processing: build an informative rely as a simple showcase

View file

@ -44,6 +44,7 @@ class AuthProviderSettings(BaseModel):
resource_provider_scopes: list[str] = [] resource_provider_scopes: list[str] = []
session_key: str = "sid" session_key: str = "sid"
skip_verify_signature: bool = True skip_verify_signature: bool = True
disabled: bool = False
@computed_field @computed_field
@property @property

View file

@ -142,19 +142,27 @@ hr {
.providers .provider { .providers .provider {
min-height: 2em; min-height: 2em;
} }
.providers .provider a.link { .providers .provider .link {
text-decoration: none; text-decoration: none;
max-height: 2em; max-height: 2em;
} }
.providers .provider .link div { .providers .provider .link {
background-color: #f7c7867d; background-color: #f7c7867d;
border-radius: 8px; border-radius: 8px;
padding: 6px; padding: 6px;
text-align: center; text-align: center;
color: black; color: black;
font-weight: bold; font-weight: 400;
cursor: pointer; cursor: pointer;
border: 0;
width: 100%;
} }
.providers .provider .link.disabled {
color: gray;
cursor: not-allowed;
}
.providers .provider .hint { .providers .provider .hint {
font-size: 80%; font-size: 80%;
max-width: 13em; max-width: 13em;

View file

@ -11,7 +11,11 @@
{% for provider in providers.values() %} {% for provider in providers.values() %}
<tr class="provider"> <tr class="provider">
<td> <td>
<a class="link" href="login/{{ provider.id }}"><div>{{ provider.name }}</div></a> <button class="link{% if provider.disabled %} disabled{% endif %}"
{% if provider.disabled %}disabled{% endif %}
onclick="location.href='login/{{ provider.id }}'">
{{ provider.name }}
</button>
</td> </td>
<td class="hint">{{ provider.hint }}</div> <td class="hint">{{ provider.hint }}</div>
</td> </td>
@ -62,42 +66,17 @@
{% endif %} {% endif %}
<hr> <hr>
<div class="content"> <div class="content">
<p>
Resources validated by role:
</p>
<div class="links-to-check">
<button resource-name="public" onclick="get_resource('public', '{{ access_token }}', '{{ auth_provider.id }}')">Public</button>
<button resource-name="protected" onclick="get_resource('protected', '{{ access_token }}', '{{ auth_provider.id }}')">Auth protected content</button>
<button resource-name="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole protected content</button>
<button resource-name="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole or barrole protected content</button>
<button resource-name="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + barrole protected content</button>
<button resource-name="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole and barrole protected content</button>
<button resource-name="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ auth_provider.id }}')">Using FastAPI Depends</button>
<!--<button resource-name="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ auth_provider.id }}')">Introspect token (401 expected)</button>-->
</div>
<!-- XXX confused...
{% if resources %}
<p>
Resources for this provider:
</p>
<div class="links-to-check">
{% for resource in resources %}
<button resource-name="{{ resource.id }}" onclick="get_resource('{{ resource.name }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ resource.name }}</buttona>
{% endfor %}
</div>
{% endif %}
-->
{% if resource_providers %} {% if resource_providers %}
<p> <p>
Resource providers (validated by scope): Resource providers:
</p> </p>
<div class="links-to-check"> <div class="links-to-check">
{% for name, resource_provider in resource_providers.items() %} {% for name, resource_provider in resource_providers.items() %}
{% if resource_provider.default_resource_id %} {% if resource_provider.default_resource_id %}
<button resource-name="{{ name }}" resource-id="{{ resource_provider.default_resource_id }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource_provider.default_resource_id }}')">{{ name }}</buttona> <button resource-name="{{ name }}" resource-id="{{ resource_provider.default_resource_id }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource_provider.default_resource_id }}')">{{ name }}</buttona>
{% else %} {% else %}
<button resource-name="{{ name }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ name }}</buttona> <button resource-name="{{ name }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ name }}</buttona>
{% endif %} {% endif %}
{% endfor %} {% endfor %}
</div> </div>
{% endif %} {% endif %}