diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index 379aaa8..e02bf47 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -19,7 +19,7 @@ jobs: - name: Install the latest version of uv uses: astral-sh/setup-uv@v4 with: - version: "0.6.9" + version: "0.5.16" - name: Install run: uv sync @@ -27,26 +27,34 @@ jobs: - name: Run tests (API call) run: .venv/bin/pytest -s tests/basic.py - - name: Get version - run: echo "VERSION=$(.venv/bin/dunamai from any --style semver)" >> $GITHUB_ENV + - name: Get version with git describe + id: version + run: | + echo "version=$(git describe)" >> $GITHUB_OUTPUT + echo "$VERSION" - - name: Version - run: echo $VERSION + - name: Check if the container should be built + id: builder + env: + RUN: ${{ toJSON(inputs.build || !contains(steps.version.outputs.version, '-')) }} + run: | + echo "run=$RUN" >> $GITHUB_OUTPUT + echo "Run build: $RUN" - - name: Get distance from tag - run: echo "DISTANCE=$(.venv/bin/dunamai from any --format '{distance}')" >> $GITHUB_ENV - - - name: Distance - run: echo $DISTANCE + - name: Set the version in pyproject.toml (workaround for uv not supporting dynamic version) + if: fromJSON(steps.builder.outputs.run) + env: + VERSION: ${{ steps.version.outputs.version }} + run: sed "s/0.0.0/$VERSION/" -i pyproject.toml - name: Workaround for bug of podman-login - if: env.DISTANCE == '0' + if: fromJSON(steps.builder.outputs.run) run: | mkdir -p $HOME/.docker echo "{ \"auths\": {} }" > $HOME/.docker/config.json - name: Log in to the container registry (with another workaround) - if: env.DISTANCE == '0' + if: fromJSON(steps.builder.outputs.run) uses: actions/podman-login@v1 with: registry: ${{ vars.REGISTRY }} @@ -55,31 +63,30 @@ jobs: auth_file_path: /tmp/auth.json - name: Build the container image - if: env.DISTANCE == '0' + if: fromJSON(steps.builder.outputs.run) uses: actions/buildah-build@v1 with: image: oidc-fastapi-test oci: true labels: oidc-fastapi-test - tags: "latest ${{ env.VERSION }}" + tags: latest ${{ steps.version.outputs.version }} containerfiles: | ./Containerfile - name: Push the image to the registry - if: env.DISTANCE == '0' + if: fromJSON(steps.builder.outputs.run) uses: actions/push-to-registry@v2 with: registry: "docker://${{ vars.REGISTRY }}/${{ vars.ORGANISATION }}" image: oidc-fastapi-test - tags: "latest ${{ env.VERSION }}" + tags: latest ${{ steps.version.outputs.version }} - name: Build wheel - if: env.DISTANCE == '0' + if: fromJSON(steps.builder.outputs.run) run: uv build --wheel - name: Publish Python package (home) - if: env.DISTANCE == '0' + if: fromJSON(steps.builder.outputs.run) env: LOCAL_PYPI_TOKEN: ${{ secrets.LOCAL_PYPI_TOKEN }} run: uv publish --publish-url https://code.philo.ydns.eu/api/packages/philorg/pypi --token $LOCAL_PYPI_TOKEN - continue-on-error: true diff --git a/.forgejo/workflows/test.yaml b/.forgejo/workflows/test.yaml index f4d994e..a56a9ce 100644 --- a/.forgejo/workflows/test.yaml +++ b/.forgejo/workflows/test.yaml @@ -19,7 +19,7 @@ jobs: - name: Install the latest version of uv uses: astral-sh/setup-uv@v4 with: - version: "0.6.3" + version: "0.5.16" - name: Install run: uv sync diff --git a/Containerfile b/Containerfile index 0ec45d1..2e3fd28 100644 --- a/Containerfile +++ b/Containerfile @@ -1,4 +1,4 @@ -FROM docker.io/library/python:latest +FROM docker.io/library/python:alpine COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/ diff --git a/README.md b/README.md index 68f335d..9e00474 100644 --- a/README.md +++ b/README.md @@ -52,59 +52,31 @@ given by the OIDC providers. For example: ```yaml -secret_key: AVeryWellKeptSecret -debug_token: no -show_token: yes -log: yes - -auth: +oidc: + secret_key: "ASecretNoOneKnows" + show_session_details: yes providers: - id: auth0 name: Okta / Auth0 - url: https:// - public_key_url: https:///pem - client_id: - client_secret: client_secret_generated_by_auth0 - hint: A hint for test credentials + url: "https://" + client_id: "" + client_secret: "client_secret_generated_by_auth0" + hint: "A hint for test credentials" - id: keycloak name: Keycloak at somewhere - url: https:// - info_url: https://philo.ydns.eu/auth/realms/test - account_url_template: /account - client_id: - client_secret: - hint: A hint for test credentials - code_challenge_method: S256 - resource_provider_scopes: - - get:time - - get:bs - resource_providers: - - id: - name: A third party resource provider - base_url: https://some.example.com/ - verify_ssl: yes - resources: - - name: Public RS2 - resource_name: public - url: resource/public - - name: BS RS2 - resource_name: bs - url: resource/bs - - name: Time RS2 - resource_name: time - url: resource/time + url: "https://" + account_url_template: "/account" + client_id: "" + client_secret: "client_secret_generated_by_keycloak" + hint: "User: foo, password: foofoo" - id: codeberg - disabled: no name: Codeberg - url: https://codeberg.org - account_url_template: /user/settings - client_id: - client_secret: client_secret_generated_by_codeberg - info_url: https://codeberg.org/login/oauth/keys - session_key: sub - skip_verify_signature: no + url: "https://codeberg.org" + account_url_template: "/user/settings" + client_id: "" + client_secret: "client_secret_generated_by_codeberg" resources: - name: List of repos id: repos diff --git a/src/oidc_test/log_conf.yaml b/log_conf.yaml similarity index 100% rename from src/oidc_test/log_conf.yaml rename to log_conf.yaml diff --git a/pyproject.toml b/pyproject.toml index c44e9f3..b1e6504 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "oidc-fastapi-test" -#version = "0.0.0" -dynamic = ["version"] +version = "0.0.0" +# dynamic = ["version"] description = "Add your description here" readme = "README.md" requires-python = ">=3.13" @@ -24,21 +24,14 @@ dependencies = [ oidc-test = "oidc_test.main:main" [dependency-groups] -dev = ["dunamai>=1.23.0", "ipdb>=0.13.13", "pytest>=8.3.4"] +dev = ["ipdb>=0.13.13", "pytest>=8.3.4"] [build-system] -requires = ["hatchling", "uv-dynamic-versioning"] +requires = ["hatchling"] build-backend = "hatchling.build" -[tool.hatch.version] -source = "uv-dynamic-versioning" - [tool.hatch.build.targets.wheel] packages = ["src/oidc_test"] -package = true - -[tool.uv-dynamic-versioning] -style = "semver" [tool.uv] package = true diff --git a/src/oidc_test/__init__.py b/src/oidc_test/__init__.py index f449e2b..e69de29 100644 --- a/src/oidc_test/__init__.py +++ b/src/oidc_test/__init__.py @@ -1,11 +0,0 @@ -import importlib.metadata - -try: - from dunamai import Version, Style - - __version__ = Version.from_git().serialize(style=Style.SemVer, dirty=True) -except ImportError: - # __name__ could be used if the package name is the same - # as the directory - not the case here - # __version__ = importlib.metadata.version(__name__) - __version__ = importlib.metadata.version("oidc-fastapi-test") diff --git a/src/oidc_test/auth/provider.py b/src/oidc_test/auth/provider.py index ce288a6..17dcaa0 100644 --- a/src/oidc_test/auth/provider.py +++ b/src/oidc_test/auth/provider.py @@ -7,7 +7,7 @@ from pydantic import ConfigDict from authlib.integrations.starlette_client.apps import StarletteOAuth2App from httpx import AsyncClient -from oidc_test.settings import AuthProviderSettings, ResourceProvider, Resource, settings +from oidc_test.settings import AuthProviderSettings, settings from oidc_test.models import User logger = logging.getLogger("oidc-test") @@ -24,7 +24,6 @@ class Provider(AuthProviderSettings): authlib_client: StarletteOAuth2App = StarletteOAuth2App(None) info: dict[str, Any] = {} unknown_auth_user: User - logout_with_id_token_hint: bool = True def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]: """Decode the token with signature check""" @@ -61,34 +60,28 @@ class Provider(AuthProviderSettings): if self.info_url is not None: try: provider_info = await client.get(self.info_url) - except Exception as err: - logger.debug("Provider_info: cannot connect") - logger.exception(err) + except Exception: raise NoPublicKey try: self.info = provider_info.json() except JSONDecodeError: - logger.debug("Provider_info: cannot decode json response") raise NoPublicKey if "public_key" in self.info: # For Keycloak try: public_key = str(self.info["public_key"]) except KeyError: - logger.debug("Provider_info: cannot get public_key") raise NoPublicKey elif "keys" in self.info: # For Forgejo/Gitea try: public_key = str(self.info["keys"][0]["n"]) except KeyError: - logger.debug("Provider_info: cannot get key 0.n") raise NoPublicKey if self.public_key_url is not None: resp = await client.get(self.public_key_url) public_key = resp.text if public_key is None: - logger.debug("Provider_info: cannot determine public key") raise NoPublicKey self.public_key = "\n".join( ["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"] @@ -96,18 +89,3 @@ class Provider(AuthProviderSettings): def get_session_key(self, userinfo): return userinfo[self.session_key] - - def get_resource(self, resource_name: str) -> Resource: - return [ - resource for resource in self.resources if resource.resource_name == resource_name - ][0] - - def get_resource_url(self, resource_name: str) -> str: - return self.url + self.get_resource(resource_name).url - - def get_resource_provider(self, resource_provider_id: str) -> ResourceProvider: - return [ - provider - for provider in self.resource_providers - if provider.id == resource_provider_id - ][0] diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py index c51b039..acd68b5 100644 --- a/src/oidc_test/auth/utils.py +++ b/src/oidc_test/auth/utils.py @@ -5,9 +5,11 @@ import logging from fastapi import HTTPException, Request, Depends, status from fastapi.security import OAuth2PasswordBearer from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App -from authlib.oauth2.rfc6749 import OAuth2Token from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError +# from authlib.oauth1.auth import OAuthToken +from authlib.oauth2.rfc6749 import OAuth2Token + from oidc_test.auth.provider import Provider from oidc_test.models import User from oidc_test.database import db, TokenNotInDb, UserNotInDB @@ -20,7 +22,7 @@ logger = logging.getLogger("oidc-test") async def fetch_token(name, request): assert name is not None assert request is not None - logger.warning("TODO: fetch_token") + logger.warn("TODO: fetch_token") ... # if name in oidc_providers: # model = OAuth2Token @@ -32,10 +34,7 @@ async def fetch_token(name, request): async def update_token( - provider_id, - token, - refresh_token: str | None = None, - access_token: str | None = None, + provider_id, token, refresh_token: str | None = None, access_token: str | None = None ): """Update the token in the database""" provider = providers[provider_id] @@ -88,7 +87,6 @@ def init_providers(): authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: @@ -127,7 +125,7 @@ async def get_current_user(request: Request) -> User: """ if (user_sub := request.session.get("user_sub")) is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED) - token = await get_token_from_session(request) + token = await get_token(request) user = await db.get_user(user_sub) ## Check if the token is expired if token.is_expired(): @@ -148,16 +146,16 @@ async def get_current_user(request: Request) -> User: return user -async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None: +async def get_token_or_none(request: Request) -> OAuth2Token | None: """Return the auth token from the session or None. Can be used in Depends()""" try: - return await get_token_from_session(request) + return await get_token(request) except HTTPException: return None -async def get_token_from_session(request: Request) -> OAuth2Token: +async def get_token(request: Request) -> OAuth2Token: """Return the token from the session. Can be used in Depends()""" try: @@ -275,19 +273,15 @@ async def get_user_from_token( ) return user - async def get_user_from_token_or_none( - token: Annotated[str | None, Depends(oauth2_scheme_optional)], + token: Annotated[str, Depends(oauth2_scheme)], request: Request, ) -> User | None: - if token is None: - return None try: return await get_user_from_token(token, request) except HTTPException: return None - class UserWithRole: roles: set[str] diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index e882cda..9e8b135 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -6,9 +6,6 @@ from typing import Annotated from pathlib import Path from datetime import datetime import logging -import logging.config -import importlib.resources -from yaml import safe_load from urllib.parse import urlencode from contextlib import asynccontextmanager @@ -29,7 +26,6 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair -from oidc_test import __version__ from oidc_test.registry import registry from oidc_test.auth.provider import NoPublicKey, Provider from oidc_test.auth.utils import ( @@ -37,8 +33,8 @@ from oidc_test.auth.utils import ( get_auth_provider_or_none, get_current_user_or_none, authlib_oauth, - get_token_from_session_or_none, - get_token_from_session, + get_token_or_none, + get_token, update_token, ) from oidc_test.auth.utils import init_providers @@ -50,15 +46,6 @@ from oidc_test.resource_server import resource_server logger = logging.getLogger("oidc-test") -if settings.log: - assert __package__ is not None - with ( - importlib.resources.path(__package__) as package_path, - open(package_path / settings.log_config_file) as f, - ): - logging_config = safe_load(f) - logging.config.dictConfig(logging_config) - templates = Jinja2Templates(Path(__file__).parent / "templates") @@ -101,15 +88,15 @@ app.mount("/resource", resource_server, name="resource_server") @app.get("/") async def home( request: Request, - user: Annotated[User | None, Depends(get_current_user_or_none)], + user: Annotated[User, Depends(get_current_user_or_none)], provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)], - token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)], + token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], ) -> HTMLResponse: context = { "show_token": settings.show_token, "user": user, "now": datetime.now(), - "__version__": __version__, + "auth_provider": provider, } if provider is None or token is None: context["providers"] = providers @@ -118,29 +105,26 @@ async def home( context["access_token_parsed"] = None context["refresh_token_parsed"] = None context["resources"] = None - context["auth_provider"] = None else: - context["auth_provider"] = provider context["access_token"] = token["access_token"] + # XXX: resources defined externally? I am confused... try: access_token_parsed = provider.decode(token["access_token"], verify_signature=False) - context["access_token_parsed"] = access_token_parsed - context["access_token_scope"] = access_token_parsed.get("scope") except PyJWTError as err: - context["access_token_parsed"] = {"Cannot parse": err.__class__.__name__} + access_token_parsed = {"Cannot parse": err.__class__.__name__} + try: + context["access_token_scope"] = access_token_parsed["scope"] + except KeyError: context["access_token_scope"] = None + context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False) + context["access_token_parsed"] = access_token_parsed + context["resource_providers"] = registry.resource_providers try: - id_token_parsed = provider.decode(token["id_token"], verify_signature=False) - context["id_token_parsed"] = id_token_parsed - except PyJWTError as err: - context["id_token_parsed"] = {"Cannot parse": err.__class__.__name__} - try: - refresh_token_parsed = provider.decode(token["refresh_token"], verify_signature=False) - context["refresh_token_parsed"] = refresh_token_parsed + context["refresh_token_parsed"] = provider.decode( + token["refresh_token"], verify_signature=False + ) except PyJWTError as err: context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__} - context["resources"] = registry.resources - context["resource_providers"] = provider.resource_providers return templates.TemplateResponse(name="home.html", request=request, context=context) @@ -262,7 +246,7 @@ async def logout( if ( provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint") ) is None: - logger.warning(f"Cannot find end_session_endpoint for provider {provider.id}") + logger.warn(f"Cannot find end_session_endpoint for provider {provider.id}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") # Clear session @@ -271,15 +255,19 @@ async def logout( try: token = await db.get_token(provider, request.session.pop("sid", None)) except TokenNotInDb: - logger.warning("No session in db for the token or no token") + logger.warn("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) - url_query = { - "post_logout_redirect_uri": post_logout_uri, - "client_id": provider.client_id, - } - if provider.logout_with_id_token_hint: - url_query["id_token_hint"] = token["id_token"] - logout_url = f"{provider_logout_uri}?{urlencode(url_query)}" + logout_url = ( + provider_logout_uri + + "?" + + urlencode( + { + "post_logout_redirect_uri": post_logout_uri, + "id_token_hint": token["id_token"], + "cliend_id": "oidc_local_test", + } + ) + ) return RedirectResponse(logout_url) @@ -303,23 +291,16 @@ async def non_compliant_logout( async def refresh( request: Request, provider: Annotated[Provider, Depends(get_auth_provider)], - token: Annotated[OAuth2Token, Depends(get_token_from_session)], + token: Annotated[OAuth2Token, Depends(get_token)], ) -> RedirectResponse: """Manually refresh token""" new_token = await provider.authlib_client.fetch_access_token( refresh_token=token["refresh_token"], grant_type="refresh_token", ) - try: - await update_token(provider.id, new_token) - except PyJWTError as err: - logger.info(f"Cannot refresh token: {err.__class__.__name__}") - raise HTTPException( - status.HTTP_510_NOT_EXTENDED, f"Token refresh error: {err.__class__.__name__}" - ) + await update_token(provider.id, new_token) return RedirectResponse(url=request.url_for("home")) - # Snippet for running standalone # Mostly useful for the --version option, # as running with uvicorn is easy and provides better flexibility, eg. diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py index 3b91ad4..e9c9809 100644 --- a/src/oidc_test/registry.py +++ b/src/oidc_test/registry.py @@ -1,7 +1,8 @@ from importlib.metadata import entry_points import logging +from typing import Any -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel from oidc_test.models import User @@ -9,16 +10,14 @@ logger = logging.getLogger("registry") class ProcessResult(BaseModel): - model_config = ConfigDict( - extra="allow", - ) + result: dict[str, Any] = {} class ProcessError(Exception): pass -class Resource(BaseModel): +class ResourceProvider(BaseModel): name: str scope_required: str | None = None role_required: str | None = None @@ -29,19 +28,19 @@ class Resource(BaseModel): super().__init__() self.__id__ = name - async def process(self, user: User | None, resource_id: str | None = None) -> ProcessResult: + async def process(self, user: User, resource_id: str | None = None) -> ProcessResult: logger.warning(f"{self.__id__} should define a process method") return ProcessResult() class ResourceRegistry(BaseModel): - resources: dict[str, Resource] = {} + resource_providers: dict[str, ResourceProvider] = {} def make_registry(self): for ep in entry_points().select(group="oidc_test.resource_provider"): - ResourceClass = ep.load() - if issubclass(ResourceClass, Resource): - self.resources[ep.name] = ResourceClass(ep.name) + ResourceProviderClass = ep.load() + if issubclass(ResourceProviderClass, ResourceProvider): + self.resource_providers[ep.name] = ResourceProviderClass(ep.name) registry = ResourceRegistry() diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index ddc5762..1877875 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -1,10 +1,9 @@ from typing import Annotated, Any import logging -from json import JSONDecodeError from authlib.oauth2.rfc6749 import OAuth2Token -from httpx import AsyncClient, HTTPError -from jwt.exceptions import DecodeError, ExpiredSignatureError, InvalidTokenError +from httpx import AsyncClient +from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from fastapi import FastAPI, HTTPException, Depends, Request, status from fastapi.middleware.cors import CORSMiddleware @@ -14,13 +13,15 @@ from fastapi.middleware.cors import CORSMiddleware from oidc_test.auth.provider import Provider from oidc_test.auth.utils import ( + get_token_or_none, + get_user_from_token, + UserWithRole, get_user_from_token_or_none, - oauth2_scheme_optional, ) from oidc_test.auth_providers import providers -from oidc_test.settings import ResourceProvider, settings +from oidc_test.settings import settings from oidc_test.models import User -from oidc_test.registry import ProcessError, ProcessResult, registry +from oidc_test.registry import ProcessError, ProcessResult, ResourceProvider, registry logger = logging.getLogger("oidc-test") @@ -49,105 +50,60 @@ resource_server.add_middleware( @resource_server.get("/") async def resources() -> dict[str, dict[str, Any]]: - return {"internal": {}, "plugins": registry.resources} + return { + "internal": {}, + "plugins": registry.resource_providers + } + @resource_server.get("/{resource_name}") @resource_server.get("/{resource_name}/{resource_id}") async def get_resource( resource_name: str, - user: Annotated[User | None, Depends(get_user_from_token_or_none)], - token: Annotated[OAuth2Token | None, Depends(oauth2_scheme_optional)], + user: Annotated[User, Depends(get_user_from_token_or_none)], + token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], + request: Request, resource_id: str | None = None, -): - """Generic path for testing a resource provided by a provider. - There's no field validation (response type of ProcessResult) on purpose, - leaving the responsibility of the response validation to resource providers""" - # Get the resource if it's defined in user auth provider's resources (external) - if user is not None: - provider = providers[user.auth_provider_id] - if ":" in resource_name: - # Third-party resource provider: send the request with the request token - resource_provider_id, resource_name = resource_name.split(":", 1) - provider = providers[user.auth_provider_id] - resource_provider: ResourceProvider = provider.get_resource_provider( - resource_provider_id - ) - resource_url = resource_provider.get_resource_url(resource_name) - async with AsyncClient(verify=resource_provider.verify_ssl) as client: - try: - logger.debug(f"GET request to {resource_url}") - resp = await client.get( - resource_url, - headers={ - "Content-type": "application/json", - "Authorization": f"Bearer {token}", - "auth_provider": user.auth_provider_id, - }, - ) - except HTTPError as err: - raise HTTPException( - status.HTTP_503_SERVICE_UNAVAILABLE, err.__class__.__name__ - ) - except Exception as err: - raise HTTPException( - status.HTTP_500_INTERNAL_SERVER_ERROR, err.__class__.__name__ - ) - else: - if resp.is_success: - return resp.json() - else: - reason_str: str - try: - reason_str = resp.json().get("detail", str(resp)) - except Exception: - reason_str = str(resp.text) - raise HTTPException(resp.status_code, reason_str) - # Third party resource (provided through the auth provider) - # The token is just passed on - # XXX: is this branch valid anymore? - if resource_name in [r.resource_name for r in provider.resources]: - return await get_auth_provider_resource( - provider=provider, - resource_name=resource_name, - token=token, - user=user, - ) +) -> ProcessResult: + """Generic path for testing a resource provided by a provider""" + provider = providers[user.auth_provider_id] + # Third party resource (provided through the auth provider) + # The token is just passed on + if resource_name in [r.resource_name for r in provider.resources]: + return await get_auth_provider_resource( + provider=provider, + resource_name=resource_name, + token=token, + user=user, + ) # Internal resource (provided here) - if resource_name in registry.resources: - resource = registry.resources[resource_name] - reason: dict[str, str] = {} - if not resource.is_public: - if user is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public") - else: - if resource.scope_required is not None and not user.has_scope( - resource.scope_required - ): - reason["scope"] = ( - f"No scope {resource.scope_required} in the access token " - "but it is required for accessing this resource" - ) - if ( - resource.role_required is not None - and resource.role_required not in user.roles_as_set - ): - reason["role"] = ( - f"You don't have the role {resource.role_required} " - "but it is required for accessing this resource" - ) - if len(reason) == 0: + if resource_name in registry.resource_providers: + resource_provider = registry.resource_providers[resource_name] + reasons: dict[str, str] = {} + if not resource_provider.is_public: + if resource_provider.scope_required is not None and not user.has_scope( + resource_provider.scope_required + ): + reasons["scope"] = f"No scope {resource_provider.scope_required} in the access token " \ + "but it is required for accessing this resource" + if resource_provider.role_required is not None \ + and resource_provider.role_required not in user.roles_as_set: + reasons["role"] = f"You don't have the role {resource_provider.role_required} " \ + "but it is required for accessing this resource" + if len(reasons) == 0: try: - resp = await resource.process(user=user, resource_id=resource_id) - return resp + return await resource_provider.process(user=user, resource_id=resource_id) except ProcessError as err: raise HTTPException( status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}" ) else: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, ", ".join(reason.values())) + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, ", ".join(reasons.values()) + ) else: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Unknown resource") + raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}") # return await get_resource_(resource_name, user, **request.query_params) @@ -155,11 +111,14 @@ async def get_auth_provider_resource( provider: Provider, resource_name: str, token: OAuth2Token | None, user: User ) -> ProcessResult: if token is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth token") - access_token = token + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"No auth token" + ) + access_token = token["access_token"] + resource = [r for r in provider.resources if r.resource_name == resource_name][0] async with AsyncClient() as client: resp = await client.get( - url=provider.get_resource_url(resource_name), + url=provider.url + resource.url, headers={ "Content-type": "application/json", "Authorization": f"Bearer {access_token}", @@ -171,59 +130,51 @@ async def get_auth_provider_resource( resp_length = len(resp.text) if resp_length > 1024: return ProcessResult( - msg=f"The resource is too long ({resp_length} bytes) to show in this demo, here is just the begining in raw format", - start=resp.text[:100] + "...", + result={"msg": f"The resource is too long ({resp_length} bytes) to show here"} ) else: - try: - resp_json = resp.json() - except JSONDecodeError: - return ProcessResult(msg="The resource is not formatted in JSON", text=resp.text) - if isinstance(resp_json, dict): - return ProcessResult(**resp.json()) - elif isinstance(resp_json, list): - return ProcessResult(**{str(i): line for i, line in enumerate(resp_json)}) + return ProcessResult(result=resp.json()) -# @resource_server.get("/public") -# async def public() -> dict: +#@resource_server.get("/public") +#async def public() -> dict: # return {"msg": "Not protected"} # # -# @resource_server.get("/protected") -# async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): +#@resource_server.get("/protected") +#async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): # assert user is not None # Just to keep QA checks happy # return {"msg": "Only authenticated users can see this"} # # -# @resource_server.get("/protected-by-foorole") -# async def get_protected_by_foorole( +#@resource_server.get("/protected-by-foorole") +#async def get_protected_by_foorole( # user: Annotated[User, Depends(UserWithRole("foorole"))], -# ): +#): # assert user is not None # return {"msg": "Only users with foorole can see this"} # # -# @resource_server.get("/protected-by-barrole") -# async def get_protected_by_barrole( +#@resource_server.get("/protected-by-barrole") +#async def get_protected_by_barrole( # user: Annotated[User, Depends(UserWithRole("barrole"))], -# ): +#): # assert user is not None # return {"msg": "Protected by barrole"} # # -# @resource_server.get("/protected-by-foorole-and-barrole") -# async def get_protected_by_foorole_and_barrole( +#@resource_server.get("/protected-by-foorole-and-barrole") +#async def get_protected_by_foorole_and_barrole( # user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], -# ): +#): # assert user is not None # Just to keep QA checks happy # return {"msg": "Only users with foorole and barrole can see this"} # # -# @resource_server.get("/protected-by-foorole-or-barrole") -# async def get_protected_by_foorole_or_barrole( +#@resource_server.get("/protected-by-foorole-or-barrole") +#async def get_protected_by_foorole_or_barrole( # user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))], -# ): +#): # assert user is not None # Just to keep QA checks happy # return {"msg": "Only users with foorole or barrole can see this"} diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index ad80c06..3e7001c 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -4,7 +4,7 @@ import random from typing import Type, Tuple from pathlib import Path -from pydantic import AnyHttpUrl, BaseModel, computed_field, AnyUrl +from pydantic import BaseModel, computed_field, AnyUrl from pydantic_settings import ( BaseSettings, SettingsConfigDict, @@ -22,22 +22,6 @@ class Resource(BaseModel): url: str -class ResourceProvider(BaseModel): - id: str - name: str - base_url: AnyUrl - resources: list[Resource] = [] - verify_ssl: bool = True - - def get_resource(self, resource_name: str) -> Resource: - return [ - resource for resource in self.resources if resource.resource_name == resource_name - ][0] - - def get_resource_url(self, resource_name: str) -> str: - return f"{self.base_url}{self.get_resource(resource_name).url}" - - class AuthProviderSettings(BaseModel): """Auth provider, can also be a resource server""" @@ -61,7 +45,6 @@ class AuthProviderSettings(BaseModel): session_key: str = "sid" skip_verify_signature: bool = True disabled: bool = False - resource_providers: list[ResourceProvider] = [] @computed_field @property @@ -84,6 +67,13 @@ class AuthProviderSettings(BaseModel): return None +class ResourceProvider(BaseModel): + id: str + name: str + base_url: AnyUrl + resources: list[Resource] = [] + + class AuthSettings(BaseModel): show_session_details: bool = False providers: list[AuthProviderSettings] = [] @@ -102,14 +92,13 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_nested_delimiter="__") auth: AuthSettings = AuthSettings() + resource_providers: list[ResourceProvider] = [] secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False - log_config_file: str = "log_conf.yaml" insecure: Insecure = Insecure() cors_origins: list[str] = [] debug_token: bool = False show_token: bool = False - show_external_resource_providers_links: bool = False @classmethod def settings_customise_sources( diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index 1e8dc03..2baa748 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -21,12 +21,6 @@ hr { .hidden { display: none; } -.version { - position: absolute; - font-size: 75%; - top: 0.3em; - right: 0.3em; -} .center { text-align: center; } diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index e988dfe..978b61c 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -2,9 +2,7 @@ async function checkHref(elem, token, authProvider) { const msg = document.getElementById("msg") const resourceName = elem.getAttribute("resource-name") const resourceId = elem.getAttribute("resource-id") - const resourceProviderId = elem.getAttribute("resource-provider-id") ? elem.getAttribute("resource-provider-id") : "" - const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName - const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}` + const url = resourceId ? `resource/${resourceName}/${resourceId}` : `resource/${resourceName}` const resp = await fetch(url, { method: "GET", headers: new Headers({ @@ -32,13 +30,11 @@ function checkPerms(className, token, authProvider) { ) } -async function get_resource(resourceName, token, authProvider, resourceId, resourceProviderId) { - // BaseUrl for an external resource provider +async function get_resource(resource_name, token, authProvider, resource_id) { //if (!keycloak.keycloak) { return } const msg = document.getElementById("msg") const resourceElem = document.getElementById('resource') - const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName - const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}` + const url = resource_id ? `resource/${resource_name}/${resource_id}` : `resource/${resource_name}` const resp = await fetch(url, { method: "GET", headers: new Headers({ diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 157e26f..4cb56f5 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -5,7 +5,6 @@ -
v. {{ __version__}}

OIDC-test - FastAPI client

{% block content %} {% endblock %} diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 167616f..6c4e6a6 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -66,67 +66,29 @@ {% endif %}
- {% if resources %} -

This application provides all these resources, eventually protected with scope or roles:

+ {% if resource_providers %} +

+ {{ auth_provider.name }} provides these resources: +

{% endif %} - {% if auth_provider.resources %} -

{{ auth_provider.name }} is also defined as a provider for these resources:

- - {% endif %} - {% if resource_providers %} -

{{ auth_provider.name }} allows this application to request resources from third party resource providers:

- {% for resource_provider in resource_providers %} - - {% endfor %} - {% endif %}
diff --git a/uv.lock b/uv.lock index 0566bb5..01b64de 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.13" [[package]] @@ -207,18 +206,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632 }, ] -[[package]] -name = "dunamai" -version = "1.23.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "packaging" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/06/4e/a5c8c337a1d9ac0384298ade02d322741fb5998041a5ea74d1cd2a4a1d47/dunamai-1.23.0.tar.gz", hash = "sha256:a163746de7ea5acb6dacdab3a6ad621ebc612ed1e528aaa8beedb8887fccd2c4", size = 44681 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/21/4c/963169386309fec4f96fd61210ac0a0666887d0fb0a50205395674d20b71/dunamai-1.23.0-py3-none-any.whl", hash = "sha256:a0906d876e92441793c6a423e16a4802752e723e9c9a5aabdc5535df02dbe041", size = 26342 }, -] - [[package]] name = "ecdsa" version = "0.19.0" @@ -495,6 +482,7 @@ wheels = [ [[package]] name = "oidc-fastapi-test" +version = "0.0.0" source = { editable = "." } dependencies = [ { name = "authlib" }, @@ -513,7 +501,6 @@ dependencies = [ [package.dev-dependencies] dev = [ - { name = "dunamai" }, { name = "ipdb" }, { name = "pytest" }, ] @@ -536,7 +523,6 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ - { name = "dunamai", specifier = ">=1.23.0" }, { name = "ipdb", specifier = ">=0.13.13" }, { name = "pytest", specifier = ">=8.3.4" }, ]