From 64f6a90f22c82a6813cd7e68f2ab193fdbd1980b Mon Sep 17 00:00:00 2001 From: phil Date: Tue, 11 Feb 2025 17:27:49 +0100 Subject: [PATCH] Add resource provided registry and plugin system --- src/oidc_test/auth/provider.py | 4 +- src/oidc_test/auth/utils.py | 11 +- src/oidc_test/auth_providers.py | 2 +- src/oidc_test/database.py | 8 +- src/oidc_test/main.py | 22 +-- src/oidc_test/registry.py | 43 ++++++ src/oidc_test/resource_server.py | 229 +++++++++++++++++------------- src/oidc_test/settings.py | 2 +- src/oidc_test/static/utils.js | 9 +- src/oidc_test/templates/home.html | 41 +++--- 10 files changed, 229 insertions(+), 142 deletions(-) create mode 100644 src/oidc_test/registry.py diff --git a/src/oidc_test/auth/provider.py b/src/oidc_test/auth/provider.py index dab4764..17dcaa0 100644 --- a/src/oidc_test/auth/provider.py +++ b/src/oidc_test/auth/provider.py @@ -7,8 +7,8 @@ from pydantic import ConfigDict from authlib.integrations.starlette_client.apps import StarletteOAuth2App from httpx import AsyncClient -from ..settings import AuthProviderSettings, settings -from ..models import User +from oidc_test.settings import AuthProviderSettings, settings +from oidc_test.models import User logger = logging.getLogger("oidc-test") diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py index 0623186..9479c48 100644 --- a/src/oidc_test/auth/utils.py +++ b/src/oidc_test/auth/utils.py @@ -10,12 +10,11 @@ from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError # from authlib.oauth1.auth import OAuthToken from authlib.oauth2.auth import OAuth2Token -from .provider import Provider - -from ..models import User -from ..database import db, TokenNotInDb, UserNotInDB -from ..settings import settings -from ..auth_providers import providers +from oidc_test.auth.provider import Provider +from oidc_test.models import User +from oidc_test.database import db, TokenNotInDb, UserNotInDB +from oidc_test.settings import settings +from oidc_test.auth_providers import providers logger = logging.getLogger("oidc-test") diff --git a/src/oidc_test/auth_providers.py b/src/oidc_test/auth_providers.py index 45f4de6..1c33ae8 100644 --- a/src/oidc_test/auth_providers.py +++ b/src/oidc_test/auth_providers.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from .auth.provider import Provider +from oidc_test.auth.provider import Provider providers: OrderedDict[str, Provider] = OrderedDict() diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 4704f9b..8d87a48 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -5,10 +5,10 @@ import logging from authlib.oauth2.rfc6749 import OAuth2Token from jwt import PyJWTError -from .auth.provider import Provider +from oidc_test.auth.provider import Provider -from .models import User, Role -from .auth_providers import providers +from oidc_test.models import User, Role +from oidc_test.auth_providers import providers logger = logging.getLogger("oidc-test") @@ -23,6 +23,7 @@ class TokenNotInDb(Exception): class Database: users: dict[str, User] = {} + # TODO: key of the token table should be provider: sid tokens: dict[str, OAuth2Token] = {} # Last sessions for the user (key: users's subject id (sub)) @@ -82,6 +83,7 @@ class Database: provider: Provider, sid: str | None, ) -> OAuth2Token: + # TODO: key of the token table should be provider: sid assert isinstance(provider, Provider) if sid is None: raise TokenNotInDb diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index f37339d..28eab8a 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -26,8 +26,9 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair -from .auth.provider import NoPublicKey, Provider -from .auth.utils import ( +from oidc_test.registry import registry +from oidc_test.auth.provider import NoPublicKey, Provider +from oidc_test.auth.utils import ( get_auth_provider, get_auth_provider_or_none, get_current_user_or_none, @@ -36,13 +37,12 @@ from .auth.utils import ( get_token, update_token, ) - -from .auth.utils import init_providers -from .settings import settings -from .auth_providers import providers -from .models import User -from .database import TokenNotInDb, db -from .resource_server import resource_server +from oidc_test.auth.utils import init_providers +from oidc_test.settings import settings +from oidc_test.auth_providers import providers +from oidc_test.models import User +from oidc_test.database import TokenNotInDb, db +from oidc_test.resource_server import resource_server logger = logging.getLogger("oidc-test") @@ -53,6 +53,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates") async def lifespan(app: FastAPI): assert app is not None init_providers() + registry.make_registry() for provider in list(providers.values()): try: await provider.get_info() @@ -104,6 +105,7 @@ async def home( context["resources"] = None else: context["access_token"] = token["access_token"] + # XXX: resources defined externally? I am confused... context["resources"] = provider.resources try: access_token_parsed = provider.decode(token["access_token"], verify_signature=False) @@ -113,9 +115,9 @@ async def home( context["access_token_scope"] = access_token_parsed["scope"] except KeyError: context["access_token_scope"] = None - # context["id_token_parsed"] = pretty_details(user, now) context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False) context["access_token_parsed"] = access_token_parsed + context["resource_providers"] = registry.resource_providers try: context["refresh_token_parsed"] = provider.decode( token["refresh_token"], verify_signature=False diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py new file mode 100644 index 0000000..6db0a47 --- /dev/null +++ b/src/oidc_test/registry.py @@ -0,0 +1,43 @@ +from importlib.metadata import entry_points +import logging +from typing import Any + +from pydantic import BaseModel + +from oidc_test.models import User + +logger = logging.getLogger("registry") + + +class ProcessResult(BaseModel): + result: dict[str, Any] = {} + + +class ProcessError(Exception): + pass + + +class ResourceProvider: + name: str + scope_required: str | None = None + default_resource_id: str | None = None + + def __init__(self, name: str): + self.name = name + + async def process(self, user: User, resource_id: str | None = None) -> ProcessResult: + logger.warning(f"{self.name} should define a process method") + return ProcessResult() + + +class ResourceRegistry: + resource_providers: dict[str, ResourceProvider] = {} + + def make_registry(self): + for ep in entry_points().select(group="oidc_test.resource_provider"): + ResourceProviderClass = ep.load() + if issubclass(ResourceProviderClass, ResourceProvider): + self.resource_providers[ep.name] = ResourceProviderClass(ep.name) + + +registry = ResourceRegistry() diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 15084bc..3b89240 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -5,23 +5,23 @@ import logging from authlib.oauth2.auth import OAuth2Token from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError -from fastapi import FastAPI, HTTPException, Depends, status +from fastapi import FastAPI, HTTPException, Depends, Request, status from fastapi.middleware.cors import CORSMiddleware # from starlette.middleware.sessions import SessionMiddleware # from authlib.integrations.starlette_client.apps import StarletteOAuth2App # from authlib.oauth2.rfc6749 import OAuth2Token -from .auth.provider import Provider -from .auth.utils import ( +from oidc_test.auth.provider import Provider +from oidc_test.auth.utils import ( get_token_or_none, get_user_from_token, UserWithRole, ) - -from .auth_providers import providers -from .settings import settings -from .models import User +from oidc_test.auth_providers import providers +from oidc_test.settings import settings +from oidc_test.models import User +from oidc_test.registry import ProcessError, ProcessResult, registry logger = logging.getLogger("oidc-test") @@ -91,6 +91,128 @@ async def get_protected_by_foorole_or_barrole( return {"msg": "Only users with foorole or barrole can see this"} +@resource_server.get("/{resource_name}") +@resource_server.get("/{resource_name}/{resource_id}") +async def get_resource( + resource_name: str, + user: Annotated[User, Depends(get_user_from_token)], + token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], + request: Request, + resource_id: str | None = None, +) -> ProcessResult: + """Generic path for testing a resource provided by a provider""" + provider = providers[user.auth_provider_id] + # Third party resource (provided through the auth provider) + # The token is just passed on + if resource_name in [r.resource_name for r in provider.resources]: + return await get_auth_provider_resource( + provider=provider, + resource_name=resource_name, + access_token=token["access_token"] if token else None, + user=user, + ) + # Internal resource (provided here) + if resource_name in registry.resource_providers: + resource_provider = registry.resource_providers[resource_name] + if resource_provider.scope_required is not None and user.has_scope( + resource_provider.scope_required + ): + try: + return await resource_provider.process(user=user, resource_id=resource_id) + except ProcessError as err: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}" + ) + else: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + f"No scope {resource_provider.scope_required} in the access token " + + "but it is required for accessing this resource", + ) + else: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}") + # return await get_resource_(resource_name, user, **request.query_params) + + +async def get_auth_provider_resource( + provider: Provider, resource_name: str, access_token: str | None, user: User +) -> ProcessResult: + resource = [r for r in provider.resources if r.resource_name == resource_name][0] + async with AsyncClient() as client: + resp = await client.get( + url=provider.url + resource.url, + headers={ + "Content-type": "application/json", + "Authorization": f"Bearer {access_token}", + }, + ) + if resp.is_error: + raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}") + # Only a demo, real application would really process the response + resp_length = len(resp.text) + if resp_length > 1024: + return ProcessResult( + result={"msg": f"The resource is too long ({resp_length} bytes) to show here"} + ) + else: + return ProcessResult(result=resp.json()) + + +# async def get_resource_(resource_id: str, user: User, **kwargs) -> dict: +# """ +# Resource processing: build an informative rely as a simple showcase +# """ +# if resource_id == "petition": +# return await sign(user, kwargs["petition_id"]) +# provider = providers[user.auth_provider_id] +# try: +# pname = provider.name +# except KeyError: +# pname = "?" +# resp = { +# "hello": f"Hi {user.name} from an OAuth resource provider", +# "comment": f"I received a request for '{resource_id}' " +# + f"with an access token signed by {pname}", +# } +# # For the demo, resource resource_id matches a scope get:resource_id, +# # but this has to be refined for production +# required_scope = f"get:{resource_id}" +# # Check if the required scope is in the scopes allowed in userinfo +# try: +# if user.has_scope(required_scope): +# await process(user, resource_id, resp) +# else: +# ## For the showcase, giving a explanation. +# ## Alternatively, raise HTTP_401_UNAUTHORIZED +# raise HTTPException( +# status.HTTP_401_UNAUTHORIZED, +# f"No scope {required_scope} in the access token " +# + "but it is required for accessing this resource", +# ) +# except ExpiredSignatureError: +# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired") +# except InvalidTokenError: +# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid") +# return resp + + +# async def process(user, resource_id, resp): +# """ +# Too simple to be serious. +# It's a good fit for a plugin architecture for production +# """ +# if resource_id == "time": +# resp["time"] = datetime.now().strftime("%c") +# elif resource_id == "bs": +# async with AsyncClient() as client: +# bs = await client.get("https://corporatebs-generator.sameerkumar.website/") +# resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.") +# else: +# raise HTTPException( +# status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'." +# ) + + # @resource_server.get("/introspect") # async def get_introspect( # request: Request, @@ -114,99 +236,6 @@ async def get_protected_by_foorole_or_barrole( # else: # raise HTTPException(status_code=response.status_code, detail=response.text) - -@resource_server.get("/{id}") -async def get_resource( - id: str, - user: Annotated[User, Depends(get_user_from_token)], - token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], -) -> dict | list: - """Generic path for testing a resource provided by a provider""" - provider = providers[user.auth_provider_id] - if id in [r.id for r in provider.resources]: - return await get_external_resource( - provider=provider, - id=id, - access_token=token["access_token"] if token else None, - user=user, - ) - return await get_resource_(id, user) - - -async def get_external_resource( - provider: Provider, id: str, access_token: str | None, user: User -) -> dict | list: - resource = [r for r in provider.resources if r.id == id][0] - async with AsyncClient() as client: - resp = await client.get( - url=provider.url + resource.url, - headers={ - "Content-type": "application/json", - "Authorization": f"Bearer {access_token}", - }, - ) - if resp.is_error: - raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}") - resp_length = len(resp.text) - if resp_length > 1024: - return {"msg": f"The resource is too long ({resp_length} bytes) to show here"} - else: - return resp.json() - - -async def get_resource_(resource_id: str, user: User) -> dict: - """ - Resource processing: build an informative rely as a simple showcase - """ - provider = providers[user.auth_provider_id] - try: - pname = provider.name - except KeyError: - pname = "?" - resp = { - "hello": f"Hi {user.name} from an OAuth resource provider", - "comment": f"I received a request for '{resource_id}' " - + f"with an access token signed by {pname}", - } - # For the demo, resource resource_id matches a scope get:resource_id, - # but this has to be refined for production - required_scope = f"get:{resource_id}" - # Check if the required scope is in the scopes allowed in userinfo - try: - if user.has_scope(required_scope): - await process(user, resource_id, resp) - else: - ## For the showcase, giving a explanation. - ## Alternatively, raise HTTP_401_UNAUTHORIZED - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - f"No scope {required_scope} in the access token " - + "but it is required for accessing this resource", - ) - except ExpiredSignatureError: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired") - except InvalidTokenError: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid") - return resp - - -async def process(user, resource_id, resp): - """ - Too simple to be serious. - It's a good fit for a plugin architecture for production - """ - if resource_id == "time": - resp["time"] = datetime.now().strftime("%c") - elif resource_id == "bs": - async with AsyncClient() as client: - bs = await client.get("https://corporatebs-generator.sameerkumar.website/") - resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.") - else: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'." - ) - - # assert user.oidc_provider is not None ### Get some info (TODO: refactor) # if (auth_provider_id := user.oidc_provider.name) is None: diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index f3ac8f3..2acbc3f 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -17,7 +17,7 @@ from starlette.requests import Request class Resource(BaseModel): """A resource with an URL that can be accessed with an OAuth2 access token""" - id: str + resource_name: str name: str url: str diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 6c9fae4..978b61c 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -1,6 +1,8 @@ async function checkHref(elem, token, authProvider) { const msg = document.getElementById("msg") - const url = `resource/${elem.getAttribute("resource-id")}` + const resourceName = elem.getAttribute("resource-name") + const resourceId = elem.getAttribute("resource-id") + const url = resourceId ? `resource/${resourceName}/${resourceId}` : `resource/${resourceName}` const resp = await fetch(url, { method: "GET", headers: new Headers({ @@ -28,11 +30,12 @@ function checkPerms(className, token, authProvider) { ) } -async function get_resource(id, token, authProvider) { +async function get_resource(resource_name, token, authProvider, resource_id) { //if (!keycloak.keycloak) { return } const msg = document.getElementById("msg") const resourceElem = document.getElementById('resource') - const resp = await fetch("resource/" + id, { + const url = resource_id ? `resource/${resource_name}/${resource_id}` : `resource/${resource_name}` + const resp = await fetch(url, { method: "GET", headers: new Headers({ "Content-type": "application/json", diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 93d0bc6..790da81 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -62,33 +62,42 @@ {% endif %}
-

- Resources validated by scope: -

-

Resources validated by role:

+ + {% if resource_providers %} +

+ Resource providers (validated by scope): +

+ {% endif %}