diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index e02bf47..379aaa8 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -19,7 +19,7 @@ jobs: - name: Install the latest version of uv uses: astral-sh/setup-uv@v4 with: - version: "0.5.16" + version: "0.6.9" - name: Install run: uv sync @@ -27,34 +27,26 @@ jobs: - name: Run tests (API call) run: .venv/bin/pytest -s tests/basic.py - - name: Get version with git describe - id: version - run: | - echo "version=$(git describe)" >> $GITHUB_OUTPUT - echo "$VERSION" + - name: Get version + run: echo "VERSION=$(.venv/bin/dunamai from any --style semver)" >> $GITHUB_ENV - - name: Check if the container should be built - id: builder - env: - RUN: ${{ toJSON(inputs.build || !contains(steps.version.outputs.version, '-')) }} - run: | - echo "run=$RUN" >> $GITHUB_OUTPUT - echo "Run build: $RUN" + - name: Version + run: echo $VERSION - - name: Set the version in pyproject.toml (workaround for uv not supporting dynamic version) - if: fromJSON(steps.builder.outputs.run) - env: - VERSION: ${{ steps.version.outputs.version }} - run: sed "s/0.0.0/$VERSION/" -i pyproject.toml + - name: Get distance from tag + run: echo "DISTANCE=$(.venv/bin/dunamai from any --format '{distance}')" >> $GITHUB_ENV + + - name: Distance + run: echo $DISTANCE - name: Workaround for bug of podman-login - if: fromJSON(steps.builder.outputs.run) + if: env.DISTANCE == '0' run: | mkdir -p $HOME/.docker echo "{ \"auths\": {} }" > $HOME/.docker/config.json - name: Log in to the container registry (with another workaround) - if: fromJSON(steps.builder.outputs.run) + if: env.DISTANCE == '0' uses: actions/podman-login@v1 with: registry: ${{ vars.REGISTRY }} @@ -63,30 +55,31 @@ jobs: auth_file_path: /tmp/auth.json - name: Build the container image - if: fromJSON(steps.builder.outputs.run) + if: env.DISTANCE == '0' uses: actions/buildah-build@v1 with: image: oidc-fastapi-test oci: true labels: oidc-fastapi-test - tags: latest ${{ steps.version.outputs.version }} + tags: "latest ${{ env.VERSION }}" containerfiles: | ./Containerfile - name: Push the image to the registry - if: fromJSON(steps.builder.outputs.run) + if: env.DISTANCE == '0' uses: actions/push-to-registry@v2 with: registry: "docker://${{ vars.REGISTRY }}/${{ vars.ORGANISATION }}" image: oidc-fastapi-test - tags: latest ${{ steps.version.outputs.version }} + tags: "latest ${{ env.VERSION }}" - name: Build wheel - if: fromJSON(steps.builder.outputs.run) + if: env.DISTANCE == '0' run: uv build --wheel - name: Publish Python package (home) - if: fromJSON(steps.builder.outputs.run) + if: env.DISTANCE == '0' env: LOCAL_PYPI_TOKEN: ${{ secrets.LOCAL_PYPI_TOKEN }} run: uv publish --publish-url https://code.philo.ydns.eu/api/packages/philorg/pypi --token $LOCAL_PYPI_TOKEN + continue-on-error: true diff --git a/.forgejo/workflows/test.yaml b/.forgejo/workflows/test.yaml index a56a9ce..f4d994e 100644 --- a/.forgejo/workflows/test.yaml +++ b/.forgejo/workflows/test.yaml @@ -19,7 +19,7 @@ jobs: - name: Install the latest version of uv uses: astral-sh/setup-uv@v4 with: - version: "0.5.16" + version: "0.6.3" - name: Install run: uv sync diff --git a/Containerfile b/Containerfile index aef57f8..0ec45d1 100644 --- a/Containerfile +++ b/Containerfile @@ -1,4 +1,4 @@ -FROM docker.io/library/python:alpine +FROM docker.io/library/python:latest COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/ @@ -9,6 +9,9 @@ WORKDIR /app RUN uv pip install --system . +# Add demo plugin +RUN PIP_EXTRA_INDEX_URL=https://pypi.org/simple/ uv pip install --system --index-url https://code.philo.ydns.eu/api/packages/philorg/pypi/simple/ oidc-fastapi-test-resource-provider-demo + # Possible to run with: #CMD ["oidc-test", "--port", "80"] #CMD ["fastapi", "run", "src/oidc_test/main.py", "--port", "8873", "--root-path", "/oidc-test"] diff --git a/README.md b/README.md index 9e00474..68f335d 100644 --- a/README.md +++ b/README.md @@ -52,31 +52,59 @@ given by the OIDC providers. For example: ```yaml -oidc: - secret_key: "ASecretNoOneKnows" - show_session_details: yes +secret_key: AVeryWellKeptSecret +debug_token: no +show_token: yes +log: yes + +auth: providers: - id: auth0 name: Okta / Auth0 - url: "https://" - client_id: "" - client_secret: "client_secret_generated_by_auth0" - hint: "A hint for test credentials" + url: https:// + public_key_url: https:///pem + client_id: + client_secret: client_secret_generated_by_auth0 + hint: A hint for test credentials - id: keycloak name: Keycloak at somewhere - url: "https://" - account_url_template: "/account" - client_id: "" - client_secret: "client_secret_generated_by_keycloak" - hint: "User: foo, password: foofoo" + url: https:// + info_url: https://philo.ydns.eu/auth/realms/test + account_url_template: /account + client_id: + client_secret: + hint: A hint for test credentials + code_challenge_method: S256 + resource_provider_scopes: + - get:time + - get:bs + resource_providers: + - id: + name: A third party resource provider + base_url: https://some.example.com/ + verify_ssl: yes + resources: + - name: Public RS2 + resource_name: public + url: resource/public + - name: BS RS2 + resource_name: bs + url: resource/bs + - name: Time RS2 + resource_name: time + url: resource/time - id: codeberg + disabled: no name: Codeberg - url: "https://codeberg.org" - account_url_template: "/user/settings" - client_id: "" - client_secret: "client_secret_generated_by_codeberg" + url: https://codeberg.org + account_url_template: /user/settings + client_id: + client_secret: client_secret_generated_by_codeberg + info_url: https://codeberg.org/login/oauth/keys + session_key: sub + skip_verify_signature: no resources: - name: List of repos id: repos diff --git a/pyproject.toml b/pyproject.toml index b1e6504..c44e9f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "oidc-fastapi-test" -version = "0.0.0" -# dynamic = ["version"] +#version = "0.0.0" +dynamic = ["version"] description = "Add your description here" readme = "README.md" requires-python = ">=3.13" @@ -24,14 +24,21 @@ dependencies = [ oidc-test = "oidc_test.main:main" [dependency-groups] -dev = ["ipdb>=0.13.13", "pytest>=8.3.4"] +dev = ["dunamai>=1.23.0", "ipdb>=0.13.13", "pytest>=8.3.4"] [build-system] -requires = ["hatchling"] +requires = ["hatchling", "uv-dynamic-versioning"] build-backend = "hatchling.build" +[tool.hatch.version] +source = "uv-dynamic-versioning" + [tool.hatch.build.targets.wheel] packages = ["src/oidc_test"] +package = true + +[tool.uv-dynamic-versioning] +style = "semver" [tool.uv] package = true diff --git a/src/oidc_test/__init__.py b/src/oidc_test/__init__.py index e69de29..f449e2b 100644 --- a/src/oidc_test/__init__.py +++ b/src/oidc_test/__init__.py @@ -0,0 +1,11 @@ +import importlib.metadata + +try: + from dunamai import Version, Style + + __version__ = Version.from_git().serialize(style=Style.SemVer, dirty=True) +except ImportError: + # __name__ could be used if the package name is the same + # as the directory - not the case here + # __version__ = importlib.metadata.version(__name__) + __version__ = importlib.metadata.version("oidc-fastapi-test") diff --git a/src/oidc_test/auth/provider.py b/src/oidc_test/auth/provider.py index 17dcaa0..ce288a6 100644 --- a/src/oidc_test/auth/provider.py +++ b/src/oidc_test/auth/provider.py @@ -7,7 +7,7 @@ from pydantic import ConfigDict from authlib.integrations.starlette_client.apps import StarletteOAuth2App from httpx import AsyncClient -from oidc_test.settings import AuthProviderSettings, settings +from oidc_test.settings import AuthProviderSettings, ResourceProvider, Resource, settings from oidc_test.models import User logger = logging.getLogger("oidc-test") @@ -24,6 +24,7 @@ class Provider(AuthProviderSettings): authlib_client: StarletteOAuth2App = StarletteOAuth2App(None) info: dict[str, Any] = {} unknown_auth_user: User + logout_with_id_token_hint: bool = True def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]: """Decode the token with signature check""" @@ -60,28 +61,34 @@ class Provider(AuthProviderSettings): if self.info_url is not None: try: provider_info = await client.get(self.info_url) - except Exception: + except Exception as err: + logger.debug("Provider_info: cannot connect") + logger.exception(err) raise NoPublicKey try: self.info = provider_info.json() except JSONDecodeError: + logger.debug("Provider_info: cannot decode json response") raise NoPublicKey if "public_key" in self.info: # For Keycloak try: public_key = str(self.info["public_key"]) except KeyError: + logger.debug("Provider_info: cannot get public_key") raise NoPublicKey elif "keys" in self.info: # For Forgejo/Gitea try: public_key = str(self.info["keys"][0]["n"]) except KeyError: + logger.debug("Provider_info: cannot get key 0.n") raise NoPublicKey if self.public_key_url is not None: resp = await client.get(self.public_key_url) public_key = resp.text if public_key is None: + logger.debug("Provider_info: cannot determine public key") raise NoPublicKey self.public_key = "\n".join( ["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"] @@ -89,3 +96,18 @@ class Provider(AuthProviderSettings): def get_session_key(self, userinfo): return userinfo[self.session_key] + + def get_resource(self, resource_name: str) -> Resource: + return [ + resource for resource in self.resources if resource.resource_name == resource_name + ][0] + + def get_resource_url(self, resource_name: str) -> str: + return self.url + self.get_resource(resource_name).url + + def get_resource_provider(self, resource_provider_id: str) -> ResourceProvider: + return [ + provider + for provider in self.resource_providers + if provider.id == resource_provider_id + ][0] diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py index 9479c48..c51b039 100644 --- a/src/oidc_test/auth/utils.py +++ b/src/oidc_test/auth/utils.py @@ -5,11 +5,9 @@ import logging from fastapi import HTTPException, Request, Depends, status from fastapi.security import OAuth2PasswordBearer from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App +from authlib.oauth2.rfc6749 import OAuth2Token from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError -# from authlib.oauth1.auth import OAuthToken -from authlib.oauth2.auth import OAuth2Token - from oidc_test.auth.provider import Provider from oidc_test.models import User from oidc_test.database import db, TokenNotInDb, UserNotInDB @@ -22,7 +20,7 @@ logger = logging.getLogger("oidc-test") async def fetch_token(name, request): assert name is not None assert request is not None - logger.warn("TODO: fetch_token") + logger.warning("TODO: fetch_token") ... # if name in oidc_providers: # model = OAuth2Token @@ -34,7 +32,10 @@ async def fetch_token(name, request): async def update_token( - provider_id, token, refresh_token: str | None = None, access_token: str | None = None + provider_id, + token, + refresh_token: str | None = None, + access_token: str | None = None, ): """Update the token in the database""" provider = providers[provider_id] @@ -60,30 +61,34 @@ def init_providers(): sub="", auth_provider_id=provider_settings.id ) provider = Provider(**provider_settings_dict) - authlib_oauth.register( - name=provider.id, - server_metadata_url=provider.openid_configuration, - client_kwargs={ - "scope": " ".join( - ["openid", "email", "offline_access", "profile"] - + provider.resource_provider_scopes - ), - }, - client_id=provider.client_id, - client_secret=provider.client_secret, - api_base_url=provider.url, - # For PKCE (not implemented yet): - # code_challenge_method="S256", - fetch_token=fetch_token, - update_token=update_token, - # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) - ) - provider.authlib_client = getattr(authlib_oauth, provider.id) + if provider.disabled: + logger.info(f"{provider_settings.name} is disabled, skipping") + else: + authlib_oauth.register( + name=provider.id, + server_metadata_url=provider.openid_configuration, + client_kwargs={ + "scope": " ".join( + ["openid", "email", "offline_access", "profile"] + + provider.resource_provider_scopes + ), + }, + client_id=provider.client_id, + client_secret=provider.client_secret, + api_base_url=provider.url, + # For PKCE (not implemented yet): + # code_challenge_method="S256", + fetch_token=fetch_token, + update_token=update_token, + # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) + ) + provider.authlib_client = getattr(authlib_oauth, provider.id) providers[provider.id] = provider authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: @@ -122,7 +127,7 @@ async def get_current_user(request: Request) -> User: """ if (user_sub := request.session.get("user_sub")) is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED) - token = await get_token(request) + token = await get_token_from_session(request) user = await db.get_user(user_sub) ## Check if the token is expired if token.is_expired(): @@ -143,16 +148,16 @@ async def get_current_user(request: Request) -> User: return user -async def get_token_or_none(request: Request) -> OAuth2Token | None: +async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None: """Return the auth token from the session or None. Can be used in Depends()""" try: - return await get_token(request) + return await get_token_from_session(request) except HTTPException: return None -async def get_token(request: Request) -> OAuth2Token: +async def get_token_from_session(request: Request) -> OAuth2Token: """Return the token from the session. Can be used in Depends()""" try: @@ -271,6 +276,18 @@ async def get_user_from_token( return user +async def get_user_from_token_or_none( + token: Annotated[str | None, Depends(oauth2_scheme_optional)], + request: Request, +) -> User | None: + if token is None: + return None + try: + return await get_user_from_token(token, request) + except HTTPException: + return None + + class UserWithRole: roles: set[str] diff --git a/log_conf.yaml b/src/oidc_test/log_conf.yaml similarity index 100% rename from log_conf.yaml rename to src/oidc_test/log_conf.yaml diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 28eab8a..e882cda 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -6,6 +6,9 @@ from typing import Annotated from pathlib import Path from datetime import datetime import logging +import logging.config +import importlib.resources +from yaml import safe_load from urllib.parse import urlencode from contextlib import asynccontextmanager @@ -26,6 +29,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair +from oidc_test import __version__ from oidc_test.registry import registry from oidc_test.auth.provider import NoPublicKey, Provider from oidc_test.auth.utils import ( @@ -33,8 +37,8 @@ from oidc_test.auth.utils import ( get_auth_provider_or_none, get_current_user_or_none, authlib_oauth, - get_token_or_none, - get_token, + get_token_from_session_or_none, + get_token_from_session, update_token, ) from oidc_test.auth.utils import init_providers @@ -46,6 +50,15 @@ from oidc_test.resource_server import resource_server logger = logging.getLogger("oidc-test") +if settings.log: + assert __package__ is not None + with ( + importlib.resources.path(__package__) as package_path, + open(package_path / settings.log_config_file) as f, + ): + logging_config = safe_load(f) + logging.config.dictConfig(logging_config) + templates = Jinja2Templates(Path(__file__).parent / "templates") @@ -55,10 +68,12 @@ async def lifespan(app: FastAPI): init_providers() registry.make_registry() for provider in list(providers.values()): + if provider.disabled: + continue try: await provider.get_info() except NoPublicKey: - logger.warn(f"Disable {provider.id}: public key not found") + logger.warning(f"Disable {provider.id}: public key not found") del providers[provider.id] yield @@ -86,15 +101,15 @@ app.mount("/resource", resource_server, name="resource_server") @app.get("/") async def home( request: Request, - user: Annotated[User, Depends(get_current_user_or_none)], + user: Annotated[User | None, Depends(get_current_user_or_none)], provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)], - token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], + token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)], ) -> HTMLResponse: context = { "show_token": settings.show_token, "user": user, "now": datetime.now(), - "auth_provider": provider, + "__version__": __version__, } if provider is None or token is None: context["providers"] = providers @@ -103,27 +118,29 @@ async def home( context["access_token_parsed"] = None context["refresh_token_parsed"] = None context["resources"] = None + context["auth_provider"] = None else: + context["auth_provider"] = provider context["access_token"] = token["access_token"] - # XXX: resources defined externally? I am confused... - context["resources"] = provider.resources try: access_token_parsed = provider.decode(token["access_token"], verify_signature=False) + context["access_token_parsed"] = access_token_parsed + context["access_token_scope"] = access_token_parsed.get("scope") except PyJWTError as err: - access_token_parsed = {"Cannot parse": err.__class__.__name__} - try: - context["access_token_scope"] = access_token_parsed["scope"] - except KeyError: + context["access_token_parsed"] = {"Cannot parse": err.__class__.__name__} context["access_token_scope"] = None - context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False) - context["access_token_parsed"] = access_token_parsed - context["resource_providers"] = registry.resource_providers try: - context["refresh_token_parsed"] = provider.decode( - token["refresh_token"], verify_signature=False - ) + id_token_parsed = provider.decode(token["id_token"], verify_signature=False) + context["id_token_parsed"] = id_token_parsed + except PyJWTError as err: + context["id_token_parsed"] = {"Cannot parse": err.__class__.__name__} + try: + refresh_token_parsed = provider.decode(token["refresh_token"], verify_signature=False) + context["refresh_token_parsed"] = refresh_token_parsed except PyJWTError as err: context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__} + context["resources"] = registry.resources + context["resource_providers"] = provider.resource_providers return templates.TemplateResponse(name="home.html", request=request, context=context) @@ -245,7 +262,7 @@ async def logout( if ( provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint") ) is None: - logger.warn(f"Cannot find end_session_endpoint for provider {provider.id}") + logger.warning(f"Cannot find end_session_endpoint for provider {provider.id}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") # Clear session @@ -254,19 +271,15 @@ async def logout( try: token = await db.get_token(provider, request.session.pop("sid", None)) except TokenNotInDb: - logger.warn("No session in db for the token or no token") + logger.warning("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) - logout_url = ( - provider_logout_uri - + "?" - + urlencode( - { - "post_logout_redirect_uri": post_logout_uri, - "id_token_hint": token["id_token"], - "cliend_id": "oidc_local_test", - } - ) - ) + url_query = { + "post_logout_redirect_uri": post_logout_uri, + "client_id": provider.client_id, + } + if provider.logout_with_id_token_hint: + url_query["id_token_hint"] = token["id_token"] + logout_url = f"{provider_logout_uri}?{urlencode(url_query)}" return RedirectResponse(logout_url) @@ -290,14 +303,20 @@ async def non_compliant_logout( async def refresh( request: Request, provider: Annotated[Provider, Depends(get_auth_provider)], - token: Annotated[OAuth2Token, Depends(get_token)], + token: Annotated[OAuth2Token, Depends(get_token_from_session)], ) -> RedirectResponse: """Manually refresh token""" new_token = await provider.authlib_client.fetch_access_token( refresh_token=token["refresh_token"], grant_type="refresh_token", ) - await update_token(provider.id, new_token) + try: + await update_token(provider.id, new_token) + except PyJWTError as err: + logger.info(f"Cannot refresh token: {err.__class__.__name__}") + raise HTTPException( + status.HTTP_510_NOT_EXTENDED, f"Token refresh error: {err.__class__.__name__}" + ) return RedirectResponse(url=request.url_for("home")) diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py index 6db0a47..3b91ad4 100644 --- a/src/oidc_test/registry.py +++ b/src/oidc_test/registry.py @@ -1,8 +1,7 @@ from importlib.metadata import entry_points import logging -from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from oidc_test.models import User @@ -10,34 +9,39 @@ logger = logging.getLogger("registry") class ProcessResult(BaseModel): - result: dict[str, Any] = {} + model_config = ConfigDict( + extra="allow", + ) class ProcessError(Exception): pass -class ResourceProvider: +class Resource(BaseModel): name: str scope_required: str | None = None + role_required: str | None = None + is_public: bool = False default_resource_id: str | None = None def __init__(self, name: str): - self.name = name + super().__init__() + self.__id__ = name - async def process(self, user: User, resource_id: str | None = None) -> ProcessResult: - logger.warning(f"{self.name} should define a process method") + async def process(self, user: User | None, resource_id: str | None = None) -> ProcessResult: + logger.warning(f"{self.__id__} should define a process method") return ProcessResult() -class ResourceRegistry: - resource_providers: dict[str, ResourceProvider] = {} +class ResourceRegistry(BaseModel): + resources: dict[str, Resource] = {} def make_registry(self): for ep in entry_points().select(group="oidc_test.resource_provider"): - ResourceProviderClass = ep.load() - if issubclass(ResourceProviderClass, ResourceProvider): - self.resource_providers[ep.name] = ResourceProviderClass(ep.name) + ResourceClass = ep.load() + if issubclass(ResourceClass, Resource): + self.resources[ep.name] = ResourceClass(ep.name) registry = ResourceRegistry() diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 3b89240..ddc5762 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -1,10 +1,10 @@ -from datetime import datetime -from typing import Annotated +from typing import Annotated, Any import logging +from json import JSONDecodeError -from authlib.oauth2.auth import OAuth2Token -from httpx import AsyncClient -from jwt.exceptions import ExpiredSignatureError, InvalidTokenError +from authlib.oauth2.rfc6749 import OAuth2Token +from httpx import AsyncClient, HTTPError +from jwt.exceptions import DecodeError, ExpiredSignatureError, InvalidTokenError from fastapi import FastAPI, HTTPException, Depends, Request, status from fastapi.middleware.cors import CORSMiddleware @@ -14,12 +14,11 @@ from fastapi.middleware.cors import CORSMiddleware from oidc_test.auth.provider import Provider from oidc_test.auth.utils import ( - get_token_or_none, - get_user_from_token, - UserWithRole, + get_user_from_token_or_none, + oauth2_scheme_optional, ) from oidc_test.auth_providers import providers -from oidc_test.settings import settings +from oidc_test.settings import ResourceProvider, settings from oidc_test.models import User from oidc_test.registry import ProcessError, ProcessResult, registry @@ -48,99 +47,119 @@ resource_server.add_middleware( # Routes for RBAC based tests -@resource_server.get("/public") -async def public() -> dict: - return {"msg": "Not protected"} - - -@resource_server.get("/protected") -async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): - assert user is not None # Just to keep QA checks happy - return {"msg": "Only authenticated users can see this"} - - -@resource_server.get("/protected-by-foorole") -async def get_protected_by_foorole( - user: Annotated[User, Depends(UserWithRole("foorole"))], -): - assert user is not None - return {"msg": "Only users with foorole can see this"} - - -@resource_server.get("/protected-by-barrole") -async def get_protected_by_barrole( - user: Annotated[User, Depends(UserWithRole("barrole"))], -): - assert user is not None - return {"msg": "Protected by barrole"} - - -@resource_server.get("/protected-by-foorole-and-barrole") -async def get_protected_by_foorole_and_barrole( - user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], -): - assert user is not None # Just to keep QA checks happy - return {"msg": "Only users with foorole and barrole can see this"} - - -@resource_server.get("/protected-by-foorole-or-barrole") -async def get_protected_by_foorole_or_barrole( - user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))], -): - assert user is not None # Just to keep QA checks happy - return {"msg": "Only users with foorole or barrole can see this"} +@resource_server.get("/") +async def resources() -> dict[str, dict[str, Any]]: + return {"internal": {}, "plugins": registry.resources} @resource_server.get("/{resource_name}") @resource_server.get("/{resource_name}/{resource_id}") async def get_resource( resource_name: str, - user: Annotated[User, Depends(get_user_from_token)], - token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], - request: Request, + user: Annotated[User | None, Depends(get_user_from_token_or_none)], + token: Annotated[OAuth2Token | None, Depends(oauth2_scheme_optional)], resource_id: str | None = None, -) -> ProcessResult: - """Generic path for testing a resource provided by a provider""" - provider = providers[user.auth_provider_id] - # Third party resource (provided through the auth provider) - # The token is just passed on - if resource_name in [r.resource_name for r in provider.resources]: - return await get_auth_provider_resource( - provider=provider, - resource_name=resource_name, - access_token=token["access_token"] if token else None, - user=user, - ) +): + """Generic path for testing a resource provided by a provider. + There's no field validation (response type of ProcessResult) on purpose, + leaving the responsibility of the response validation to resource providers""" + # Get the resource if it's defined in user auth provider's resources (external) + if user is not None: + provider = providers[user.auth_provider_id] + if ":" in resource_name: + # Third-party resource provider: send the request with the request token + resource_provider_id, resource_name = resource_name.split(":", 1) + provider = providers[user.auth_provider_id] + resource_provider: ResourceProvider = provider.get_resource_provider( + resource_provider_id + ) + resource_url = resource_provider.get_resource_url(resource_name) + async with AsyncClient(verify=resource_provider.verify_ssl) as client: + try: + logger.debug(f"GET request to {resource_url}") + resp = await client.get( + resource_url, + headers={ + "Content-type": "application/json", + "Authorization": f"Bearer {token}", + "auth_provider": user.auth_provider_id, + }, + ) + except HTTPError as err: + raise HTTPException( + status.HTTP_503_SERVICE_UNAVAILABLE, err.__class__.__name__ + ) + except Exception as err: + raise HTTPException( + status.HTTP_500_INTERNAL_SERVER_ERROR, err.__class__.__name__ + ) + else: + if resp.is_success: + return resp.json() + else: + reason_str: str + try: + reason_str = resp.json().get("detail", str(resp)) + except Exception: + reason_str = str(resp.text) + raise HTTPException(resp.status_code, reason_str) + # Third party resource (provided through the auth provider) + # The token is just passed on + # XXX: is this branch valid anymore? + if resource_name in [r.resource_name for r in provider.resources]: + return await get_auth_provider_resource( + provider=provider, + resource_name=resource_name, + token=token, + user=user, + ) # Internal resource (provided here) - if resource_name in registry.resource_providers: - resource_provider = registry.resource_providers[resource_name] - if resource_provider.scope_required is not None and user.has_scope( - resource_provider.scope_required - ): + if resource_name in registry.resources: + resource = registry.resources[resource_name] + reason: dict[str, str] = {} + if not resource.is_public: + if user is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public") + else: + if resource.scope_required is not None and not user.has_scope( + resource.scope_required + ): + reason["scope"] = ( + f"No scope {resource.scope_required} in the access token " + "but it is required for accessing this resource" + ) + if ( + resource.role_required is not None + and resource.role_required not in user.roles_as_set + ): + reason["role"] = ( + f"You don't have the role {resource.role_required} " + "but it is required for accessing this resource" + ) + if len(reason) == 0: try: - return await resource_provider.process(user=user, resource_id=resource_id) + resp = await resource.process(user=user, resource_id=resource_id) + return resp except ProcessError as err: raise HTTPException( status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}" ) else: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - f"No scope {resource_provider.scope_required} in the access token " - + "but it is required for accessing this resource", - ) + raise HTTPException(status.HTTP_401_UNAUTHORIZED, ", ".join(reason.values())) else: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}") + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Unknown resource") # return await get_resource_(resource_name, user, **request.query_params) async def get_auth_provider_resource( - provider: Provider, resource_name: str, access_token: str | None, user: User + provider: Provider, resource_name: str, token: OAuth2Token | None, user: User ) -> ProcessResult: - resource = [r for r in provider.resources if r.resource_name == resource_name][0] + if token is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth token") + access_token = token async with AsyncClient() as client: resp = await client.get( - url=provider.url + resource.url, + url=provider.get_resource_url(resource_name), headers={ "Content-type": "application/json", "Authorization": f"Bearer {access_token}", @@ -152,12 +171,62 @@ async def get_auth_provider_resource( resp_length = len(resp.text) if resp_length > 1024: return ProcessResult( - result={"msg": f"The resource is too long ({resp_length} bytes) to show here"} + msg=f"The resource is too long ({resp_length} bytes) to show in this demo, here is just the begining in raw format", + start=resp.text[:100] + "...", ) else: - return ProcessResult(result=resp.json()) + try: + resp_json = resp.json() + except JSONDecodeError: + return ProcessResult(msg="The resource is not formatted in JSON", text=resp.text) + if isinstance(resp_json, dict): + return ProcessResult(**resp.json()) + elif isinstance(resp_json, list): + return ProcessResult(**{str(i): line for i, line in enumerate(resp_json)}) +# @resource_server.get("/public") +# async def public() -> dict: +# return {"msg": "Not protected"} +# +# +# @resource_server.get("/protected") +# async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): +# assert user is not None # Just to keep QA checks happy +# return {"msg": "Only authenticated users can see this"} +# +# +# @resource_server.get("/protected-by-foorole") +# async def get_protected_by_foorole( +# user: Annotated[User, Depends(UserWithRole("foorole"))], +# ): +# assert user is not None +# return {"msg": "Only users with foorole can see this"} +# +# +# @resource_server.get("/protected-by-barrole") +# async def get_protected_by_barrole( +# user: Annotated[User, Depends(UserWithRole("barrole"))], +# ): +# assert user is not None +# return {"msg": "Protected by barrole"} +# +# +# @resource_server.get("/protected-by-foorole-and-barrole") +# async def get_protected_by_foorole_and_barrole( +# user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], +# ): +# assert user is not None # Just to keep QA checks happy +# return {"msg": "Only users with foorole and barrole can see this"} +# +# +# @resource_server.get("/protected-by-foorole-or-barrole") +# async def get_protected_by_foorole_or_barrole( +# user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))], +# ): +# assert user is not None # Just to keep QA checks happy +# return {"msg": "Only users with foorole or barrole can see this"} + # async def get_resource_(resource_id: str, user: User, **kwargs) -> dict: # """ # Resource processing: build an informative rely as a simple showcase diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 2acbc3f..ad80c06 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -4,7 +4,7 @@ import random from typing import Type, Tuple from pathlib import Path -from pydantic import BaseModel, computed_field, AnyUrl +from pydantic import AnyHttpUrl, BaseModel, computed_field, AnyUrl from pydantic_settings import ( BaseSettings, SettingsConfigDict, @@ -22,6 +22,22 @@ class Resource(BaseModel): url: str +class ResourceProvider(BaseModel): + id: str + name: str + base_url: AnyUrl + resources: list[Resource] = [] + verify_ssl: bool = True + + def get_resource(self, resource_name: str) -> Resource: + return [ + resource for resource in self.resources if resource.resource_name == resource_name + ][0] + + def get_resource_url(self, resource_name: str) -> str: + return f"{self.base_url}{self.get_resource(resource_name).url}" + + class AuthProviderSettings(BaseModel): """Auth provider, can also be a resource server""" @@ -44,6 +60,8 @@ class AuthProviderSettings(BaseModel): resource_provider_scopes: list[str] = [] session_key: str = "sid" skip_verify_signature: bool = True + disabled: bool = False + resource_providers: list[ResourceProvider] = [] @computed_field @property @@ -66,13 +84,6 @@ class AuthProviderSettings(BaseModel): return None -class ResourceProvider(BaseModel): - id: str - name: str - base_url: AnyUrl - resources: list[Resource] = [] - - class AuthSettings(BaseModel): show_session_details: bool = False providers: list[AuthProviderSettings] = [] @@ -91,13 +102,14 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_nested_delimiter="__") auth: AuthSettings = AuthSettings() - resource_providers: list[ResourceProvider] = [] secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False + log_config_file: str = "log_conf.yaml" insecure: Insecure = Insecure() cors_origins: list[str] = [] debug_token: bool = False show_token: bool = False + show_external_resource_providers_links: bool = False @classmethod def settings_customise_sources( diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index e163a68..1e8dc03 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -21,6 +21,12 @@ hr { .hidden { display: none; } +.version { + position: absolute; + font-size: 75%; + top: 0.3em; + right: 0.3em; +} .center { text-align: center; } @@ -142,19 +148,27 @@ hr { .providers .provider { min-height: 2em; } -.providers .provider a.link { +.providers .provider .link { text-decoration: none; max-height: 2em; } -.providers .provider .link div { +.providers .provider .link { background-color: #f7c7867d; border-radius: 8px; padding: 6px; text-align: center; color: black; - font-weight: bold; + font-weight: 400; cursor: pointer; + border: 0; + width: 100%; } + +.providers .provider .link.disabled { + color: gray; + cursor: not-allowed; +} + .providers .provider .hint { font-size: 80%; max-width: 13em; diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 978b61c..e988dfe 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -2,7 +2,9 @@ async function checkHref(elem, token, authProvider) { const msg = document.getElementById("msg") const resourceName = elem.getAttribute("resource-name") const resourceId = elem.getAttribute("resource-id") - const url = resourceId ? `resource/${resourceName}/${resourceId}` : `resource/${resourceName}` + const resourceProviderId = elem.getAttribute("resource-provider-id") ? elem.getAttribute("resource-provider-id") : "" + const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName + const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}` const resp = await fetch(url, { method: "GET", headers: new Headers({ @@ -30,11 +32,13 @@ function checkPerms(className, token, authProvider) { ) } -async function get_resource(resource_name, token, authProvider, resource_id) { +async function get_resource(resourceName, token, authProvider, resourceId, resourceProviderId) { + // BaseUrl for an external resource provider //if (!keycloak.keycloak) { return } const msg = document.getElementById("msg") const resourceElem = document.getElementById('resource') - const url = resource_id ? `resource/${resource_name}/${resource_id}` : `resource/${resource_name}` + const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName + const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}` const resp = await fetch(url, { method: "GET", headers: new Headers({ diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 4cb56f5..157e26f 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -5,6 +5,7 @@ +
v. {{ __version__}}

OIDC-test - FastAPI client

{% block content %} {% endblock %} diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 790da81..167616f 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -11,7 +11,11 @@ {% for provider in providers.values() %} -
{{ provider.name }}
+ {{ provider.hint }} @@ -56,51 +60,73 @@ Account management {% endif %} - + {% endif %}
-

- Resources validated by role: -

- - - {% if resource_providers %} -

- Resource providers (validated by scope): -

+ {% if auth_provider.resources %} +

{{ auth_provider.name }} is also defined as a provider for these resources:

{% endif %} + {% if resource_providers %} +

{{ auth_provider.name }} allows this application to request resources from third party resource providers:

+ {% for resource_provider in resource_providers %} + + {% endfor %} + {% endif %}
diff --git a/uv.lock b/uv.lock index 01b64de..0566bb5 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.13" [[package]] @@ -206,6 +207,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632 }, ] +[[package]] +name = "dunamai" +version = "1.23.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/4e/a5c8c337a1d9ac0384298ade02d322741fb5998041a5ea74d1cd2a4a1d47/dunamai-1.23.0.tar.gz", hash = "sha256:a163746de7ea5acb6dacdab3a6ad621ebc612ed1e528aaa8beedb8887fccd2c4", size = 44681 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/4c/963169386309fec4f96fd61210ac0a0666887d0fb0a50205395674d20b71/dunamai-1.23.0-py3-none-any.whl", hash = "sha256:a0906d876e92441793c6a423e16a4802752e723e9c9a5aabdc5535df02dbe041", size = 26342 }, +] + [[package]] name = "ecdsa" version = "0.19.0" @@ -482,7 +495,6 @@ wheels = [ [[package]] name = "oidc-fastapi-test" -version = "0.0.0" source = { editable = "." } dependencies = [ { name = "authlib" }, @@ -501,6 +513,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "dunamai" }, { name = "ipdb" }, { name = "pytest" }, ] @@ -523,6 +536,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "dunamai", specifier = ">=1.23.0" }, { name = "ipdb", specifier = ">=0.13.13" }, { name = "pytest", specifier = ">=8.3.4" }, ]