diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index df52e0b..379aaa8 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -19,7 +19,7 @@ jobs: - name: Install the latest version of uv uses: astral-sh/setup-uv@v4 with: - version: "0.5.16" + version: "0.6.9" - name: Install run: uv sync @@ -27,32 +27,26 @@ jobs: - name: Run tests (API call) run: .venv/bin/pytest -s tests/basic.py - - name: Get version with git describe - id: version - run: | - echo "version=$(git describe)" >> $GITHUB_OUTPUT - echo "$VERSION" + - name: Get version + run: echo "VERSION=$(.venv/bin/dunamai from any --style semver)" >> $GITHUB_ENV - - name: Check if the container should be built - id: builder - env: - RUN: ${{ toJSON(inputs.build || !contains(steps.version.outputs.version, '-')) }} - run: | - echo "run=$RUN" >> $GITHUB_OUTPUT - echo "Run build: $RUN" + - name: Version + run: echo $VERSION - - name: Set the version in pyproject.toml (workaround for uv not supporting dynamic version) - if: fromJSON(steps.builder.outputs.run) - env: - VERSION: ${{ steps.version.outputs.version }} - run: sed "s/0.0.0/$VERSION/" -i pyproject.toml + - name: Get distance from tag + run: echo "DISTANCE=$(.venv/bin/dunamai from any --format '{distance}')" >> $GITHUB_ENV + + - name: Distance + run: echo $DISTANCE - name: Workaround for bug of podman-login + if: env.DISTANCE == '0' run: | mkdir -p $HOME/.docker echo "{ \"auths\": {} }" > $HOME/.docker/config.json - name: Log in to the container registry (with another workaround) + if: env.DISTANCE == '0' uses: actions/podman-login@v1 with: registry: ${{ vars.REGISTRY }} @@ -61,26 +55,31 @@ jobs: auth_file_path: /tmp/auth.json - name: Build the container image + if: env.DISTANCE == '0' uses: actions/buildah-build@v1 with: image: oidc-fastapi-test oci: true labels: oidc-fastapi-test - tags: latest ${{ steps.version.outputs.version }} + tags: "latest ${{ env.VERSION }}" containerfiles: | ./Containerfile - name: Push the image to the registry + if: env.DISTANCE == '0' uses: actions/push-to-registry@v2 with: registry: "docker://${{ vars.REGISTRY }}/${{ vars.ORGANISATION }}" image: oidc-fastapi-test - tags: latest ${{ steps.version.outputs.version }} + tags: "latest ${{ env.VERSION }}" - name: Build wheel + if: env.DISTANCE == '0' run: uv build --wheel - name: Publish Python package (home) + if: env.DISTANCE == '0' env: LOCAL_PYPI_TOKEN: ${{ secrets.LOCAL_PYPI_TOKEN }} run: uv publish --publish-url https://code.philo.ydns.eu/api/packages/philorg/pypi --token $LOCAL_PYPI_TOKEN + continue-on-error: true diff --git a/.forgejo/workflows/test.yaml b/.forgejo/workflows/test.yaml index a56a9ce..f4d994e 100644 --- a/.forgejo/workflows/test.yaml +++ b/.forgejo/workflows/test.yaml @@ -19,7 +19,7 @@ jobs: - name: Install the latest version of uv uses: astral-sh/setup-uv@v4 with: - version: "0.5.16" + version: "0.6.3" - name: Install run: uv sync diff --git a/Containerfile b/Containerfile index aef57f8..0ec45d1 100644 --- a/Containerfile +++ b/Containerfile @@ -1,4 +1,4 @@ -FROM docker.io/library/python:alpine +FROM docker.io/library/python:latest COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/ @@ -9,6 +9,9 @@ WORKDIR /app RUN uv pip install --system . +# Add demo plugin +RUN PIP_EXTRA_INDEX_URL=https://pypi.org/simple/ uv pip install --system --index-url https://code.philo.ydns.eu/api/packages/philorg/pypi/simple/ oidc-fastapi-test-resource-provider-demo + # Possible to run with: #CMD ["oidc-test", "--port", "80"] #CMD ["fastapi", "run", "src/oidc_test/main.py", "--port", "8873", "--root-path", "/oidc-test"] diff --git a/README.md b/README.md index 9e00474..68f335d 100644 --- a/README.md +++ b/README.md @@ -52,31 +52,59 @@ given by the OIDC providers. For example: ```yaml -oidc: - secret_key: "ASecretNoOneKnows" - show_session_details: yes +secret_key: AVeryWellKeptSecret +debug_token: no +show_token: yes +log: yes + +auth: providers: - id: auth0 name: Okta / Auth0 - url: "https://" - client_id: "" - client_secret: "client_secret_generated_by_auth0" - hint: "A hint for test credentials" + url: https:// + public_key_url: https:///pem + client_id: + client_secret: client_secret_generated_by_auth0 + hint: A hint for test credentials - id: keycloak name: Keycloak at somewhere - url: "https://" - account_url_template: "/account" - client_id: "" - client_secret: "client_secret_generated_by_keycloak" - hint: "User: foo, password: foofoo" + url: https:// + info_url: https://philo.ydns.eu/auth/realms/test + account_url_template: /account + client_id: + client_secret: + hint: A hint for test credentials + code_challenge_method: S256 + resource_provider_scopes: + - get:time + - get:bs + resource_providers: + - id: + name: A third party resource provider + base_url: https://some.example.com/ + verify_ssl: yes + resources: + - name: Public RS2 + resource_name: public + url: resource/public + - name: BS RS2 + resource_name: bs + url: resource/bs + - name: Time RS2 + resource_name: time + url: resource/time - id: codeberg + disabled: no name: Codeberg - url: "https://codeberg.org" - account_url_template: "/user/settings" - client_id: "" - client_secret: "client_secret_generated_by_codeberg" + url: https://codeberg.org + account_url_template: /user/settings + client_id: + client_secret: client_secret_generated_by_codeberg + info_url: https://codeberg.org/login/oauth/keys + session_key: sub + skip_verify_signature: no resources: - name: List of repos id: repos diff --git a/pyproject.toml b/pyproject.toml index 980bcfc..c44e9f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "oidc-fastapi-test" -version = "0.0.0" -# dynamic = ["version"] +#version = "0.0.0" +dynamic = ["version"] description = "Add your description here" readme = "README.md" requires-python = ">=3.13" @@ -24,14 +24,24 @@ dependencies = [ oidc-test = "oidc_test.main:main" [dependency-groups] -dev = ["ipdb>=0.13.13", "pytest>=8.3.4"] +dev = ["dunamai>=1.23.0", "ipdb>=0.13.13", "pytest>=8.3.4"] [build-system] -requires = ["hatchling"] +requires = ["hatchling", "uv-dynamic-versioning"] build-backend = "hatchling.build" +[tool.hatch.version] +source = "uv-dynamic-versioning" + [tool.hatch.build.targets.wheel] packages = ["src/oidc_test"] +package = true + +[tool.uv-dynamic-versioning] +style = "semver" [tool.uv] package = true + +[tool.black] +line-length = 98 diff --git a/src/oidc_test/__init__.py b/src/oidc_test/__init__.py index e69de29..f449e2b 100644 --- a/src/oidc_test/__init__.py +++ b/src/oidc_test/__init__.py @@ -0,0 +1,11 @@ +import importlib.metadata + +try: + from dunamai import Version, Style + + __version__ = Version.from_git().serialize(style=Style.SemVer, dirty=True) +except ImportError: + # __name__ could be used if the package name is the same + # as the directory - not the case here + # __version__ = importlib.metadata.version(__name__) + __version__ = importlib.metadata.version("oidc-fastapi-test") diff --git a/src/oidc_test/auth/provider.py b/src/oidc_test/auth/provider.py new file mode 100644 index 0000000..ce288a6 --- /dev/null +++ b/src/oidc_test/auth/provider.py @@ -0,0 +1,113 @@ +from json import JSONDecodeError +from typing import Any +from jwt import decode +import logging + +from pydantic import ConfigDict +from authlib.integrations.starlette_client.apps import StarletteOAuth2App +from httpx import AsyncClient + +from oidc_test.settings import AuthProviderSettings, ResourceProvider, Resource, settings +from oidc_test.models import User + +logger = logging.getLogger("oidc-test") + + +class NoPublicKey(Exception): + pass + + +class Provider(AuthProviderSettings): + # To allow authlib_client as StarletteOAuth2App + model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore + + authlib_client: StarletteOAuth2App = StarletteOAuth2App(None) + info: dict[str, Any] = {} + unknown_auth_user: User + logout_with_id_token_hint: bool = True + + def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]: + """Decode the token with signature check""" + if self.public_key is None: + raise NoPublicKey + if verify_signature is None: + verify_signature = self.skip_verify_signature + if settings.debug_token: + decoded = decode( + token, + self.public_key, + algorithms=[self.signature_alg], + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": False, + "verify_aud": False, + }, # not settings.insecure.skip_verify_signature}, + ) + logger.debug(str(decoded)) + return decode( + token, + self.public_key, + algorithms=[self.signature_alg], + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": verify_signature, + }, # not settings.insecure.skip_verify_signature}, + ) + + async def get_info(self): + # Get the public key: + async with AsyncClient() as client: + public_key: str | None = None + if self.info_url is not None: + try: + provider_info = await client.get(self.info_url) + except Exception as err: + logger.debug("Provider_info: cannot connect") + logger.exception(err) + raise NoPublicKey + try: + self.info = provider_info.json() + except JSONDecodeError: + logger.debug("Provider_info: cannot decode json response") + raise NoPublicKey + if "public_key" in self.info: + # For Keycloak + try: + public_key = str(self.info["public_key"]) + except KeyError: + logger.debug("Provider_info: cannot get public_key") + raise NoPublicKey + elif "keys" in self.info: + # For Forgejo/Gitea + try: + public_key = str(self.info["keys"][0]["n"]) + except KeyError: + logger.debug("Provider_info: cannot get key 0.n") + raise NoPublicKey + if self.public_key_url is not None: + resp = await client.get(self.public_key_url) + public_key = resp.text + if public_key is None: + logger.debug("Provider_info: cannot determine public key") + raise NoPublicKey + self.public_key = "\n".join( + ["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"] + ) + + def get_session_key(self, userinfo): + return userinfo[self.session_key] + + def get_resource(self, resource_name: str) -> Resource: + return [ + resource for resource in self.resources if resource.resource_name == resource_name + ][0] + + def get_resource_url(self, resource_name: str) -> str: + return self.url + self.get_resource(resource_name).url + + def get_resource_provider(self, resource_provider_id: str) -> ResourceProvider: + return [ + provider + for provider in self.resource_providers + if provider.id == resource_provider_id + ][0] diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py new file mode 100644 index 0000000..c51b039 --- /dev/null +++ b/src/oidc_test/auth/utils.py @@ -0,0 +1,305 @@ +from typing import Union, Annotated +from functools import wraps +import logging + +from fastapi import HTTPException, Request, Depends, status +from fastapi.security import OAuth2PasswordBearer +from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App +from authlib.oauth2.rfc6749 import OAuth2Token +from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError + +from oidc_test.auth.provider import Provider +from oidc_test.models import User +from oidc_test.database import db, TokenNotInDb, UserNotInDB +from oidc_test.settings import settings +from oidc_test.auth_providers import providers + +logger = logging.getLogger("oidc-test") + + +async def fetch_token(name, request): + assert name is not None + assert request is not None + logger.warning("TODO: fetch_token") + ... + # if name in oidc_providers: + # model = OAuth2Token + # else: + # model = OAuthToken + + # token = model.find(name=name, user=request.user) + # return token.to_token() + + +async def update_token( + provider_id, + token, + refresh_token: str | None = None, + access_token: str | None = None, +): + """Update the token in the database""" + provider = providers[provider_id] + sid: str = provider.get_session_key(provider.decode(token["id_token"])) + item = await db.get_token(provider, sid) + # update old token + item["access_token"] = token["access_token"] + item["refresh_token"] = token["refresh_token"] + item["id_token"] = token["id_token"] + item["expires_at"] = token["expires_at"] + logger.info(f"Token {sid} refreshed") + # It's a fake db and only in memory, so there's nothing to save + # await item.save() + + +def init_providers(): + """Add oidc providers to authlib from the settings + and build the providers dict""" + for provider_settings in settings.auth.providers: + provider_settings_dict = provider_settings.model_dump() + # Add an anonymous user, that cannot be identified but has provided a valid access token + provider_settings_dict["unknown_auth_user"] = User( + sub="", auth_provider_id=provider_settings.id + ) + provider = Provider(**provider_settings_dict) + if provider.disabled: + logger.info(f"{provider_settings.name} is disabled, skipping") + else: + authlib_oauth.register( + name=provider.id, + server_metadata_url=provider.openid_configuration, + client_kwargs={ + "scope": " ".join( + ["openid", "email", "offline_access", "profile"] + + provider.resource_provider_scopes + ), + }, + client_id=provider.client_id, + client_secret=provider.client_secret, + api_base_url=provider.url, + # For PKCE (not implemented yet): + # code_challenge_method="S256", + fetch_token=fetch_token, + update_token=update_token, + # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) + ) + provider.authlib_client = getattr(authlib_oauth, provider.id) + providers[provider.id] = provider + + +authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) + + +def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: + """Return the oidc_provider from a request object, from the session. + It can be used in Depends()""" + if (auth_provider_id := request.session.get("auth_provider_id")) is None: + return + return getattr(authlib_oauth, str(auth_provider_id), None) + + +def get_auth_provider_client(request: Request) -> StarletteOAuth2App: + if (oidc_provider := get_auth_provider_client_or_none(request)) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") + else: + return oidc_provider + + +def get_auth_provider_or_none(request: Request) -> Provider | None: + """Return the oidc_provider settings from a request object, from the session. + It can be used in Depends()""" + if (auth_provider_id := request.session.get("auth_provider_id")) is None: + return + return providers.get(auth_provider_id) + + +def get_auth_provider(request: Request) -> Provider: + if (provider := get_auth_provider_or_none(request)) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") + return provider + + +async def get_current_user(request: Request) -> User: + """Get the current user from a request object. + Also validates the token expiration time. + ... TODO: complete about refresh token + """ + if (user_sub := request.session.get("user_sub")) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED) + token = await get_token_from_session(request) + user = await db.get_user(user_sub) + ## Check if the token is expired + if token.is_expired(): + provider = get_auth_provider(request=request) + ## Ask a new refresh token from the provider + logger.info(f"Token expired for user {user.name}") + try: + userinfo = await provider.authlib_client.fetch_access_token( + refresh_token=token.get("refresh_token") + ) + assert userinfo is not None + except OAuthError as err: + logger.exception(err) + # raise HTTPException( + # status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh" + # ) + + return user + + +async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None: + """Return the auth token from the session or None. + Can be used in Depends()""" + try: + return await get_token_from_session(request) + except HTTPException: + return None + + +async def get_token_from_session(request: Request) -> OAuth2Token: + """Return the token from the session. + Can be used in Depends()""" + try: + provider = providers[request.session.get("auth_provider_id", "")] + except KeyError: + request.session.pop("auth_provider_id", None) + request.session.pop("user_sub", None) + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider") + try: + return await db.get_token( + provider, + request.session.get("sid"), + ) + except (TokenNotInDb, InvalidKeyError, DecodeError) as err: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, err.__class__.__name__) + + +async def get_current_user_or_none(request: Request) -> User | None: + """Return the user from a request object, from the session. + It can be used in Depends()""" + try: + return await get_current_user(request) + except HTTPException: + return None + + +def hasrole(required_roles: Union[str, list[str]] = []): + """Decorator for RBAC permissions""" + required_roles_set: set[str] + if isinstance(required_roles, str): + required_roles_set = set([required_roles]) + else: + required_roles_set = set(required_roles) + + def decorator(func): + @wraps(func) + async def wrapper(request=None, *args, **kwargs): + if request is None: + raise HTTPException( + 500, + "Functions decorated with hasrole must have a request:Request argument", + ) + user: User = await get_current_user(request) + if not any(required_roles_set.intersection(user.roles_as_set)): + raise HTTPException(status.HTTP_401_UNAUTHORIZED) + return await func(request, *args, **kwargs) + + return wrapper + + return decorator + + +def get_token_info(token: dict) -> dict: + token_info = dict() + for key in token: + if key != "userinfo": + token_info[key] = token[key] + return token_info + + +async def get_user_from_token( + token: Annotated[str, Depends(oauth2_scheme)], + request: Request, +) -> User: + try: + auth_provider_id = request.headers["auth_provider"] + except KeyError: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + "Request headers must have a 'auth_provider' field", + ) + try: + provider = providers[auth_provider_id] + except KeyError: + if auth_provider_id == "": + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth provider") + else: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" + ) + if token == "None": + request.session.pop("auth_provider_id", None) + request.session.pop("user_sub", None) + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") + try: + payload = provider.decode(token) + except ExpiredSignatureError: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + "Expired signature (token refresh not implemented yet)", + ) + except InvalidKeyError: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key") + except PyJWTError as err: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Cannot decode token: {err.__class__.__name__}" + ) + try: + user_id = payload["sub"] + except KeyError: + return provider.unknown_auth_user + try: + user = await db.get_user(user_id) + if user.access_token != token: + user.access_token = token + except UserNotInDB: + logger.info( + f"User {user_id} not found in DB, creating it (real apps can behave differently)" + ) + user = await db.add_user( + sub=payload["sub"], + user_info=payload, + auth_provider=providers[auth_provider_id], + access_token=token, + ) + return user + + +async def get_user_from_token_or_none( + token: Annotated[str | None, Depends(oauth2_scheme_optional)], + request: Request, +) -> User | None: + if token is None: + return None + try: + return await get_user_from_token(token, request) + except HTTPException: + return None + + +class UserWithRole: + roles: set[str] + + def __init__(self, roles: str | list[str] | tuple[str] | set[str]): + if isinstance(roles, str): + self.roles = set([roles]) + elif isinstance(roles, (list, tuple, set)): + self.roles = set(roles) + + def __call__(self, user: User = Depends(get_user_from_token)) -> User: + if not any(self.roles.intersection(user.roles_as_set)): + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Not of any required role {', '.join(self.roles)}" + ) + return user diff --git a/src/oidc_test/auth_misc.py b/src/oidc_test/auth_misc.py deleted file mode 100644 index a4e9ea3..0000000 --- a/src/oidc_test/auth_misc.py +++ /dev/null @@ -1,29 +0,0 @@ -from datetime import datetime, timedelta -from collections import OrderedDict - -from .models import User - -time_keys = set(("iat", "exp", "auth_time", "updated_at")) - - -def pretty_details(user: User, now: datetime) -> OrderedDict: - details = OrderedDict() - # breakpoint() - for key in sorted(time_keys): - try: - dt = datetime.fromtimestamp(user.userinfo[key]) - except (KeyError, TypeError): - pass - else: - td = now - dt - td = timedelta(days=td.days, seconds=td.seconds) - if td.days < 0: - ptd = f"in {-td} h:m:s" - else: - ptd = f"{td} h:m:s ago" - details[key] = f"{user.userinfo[key]} - {dt} ({ptd})" - for key in sorted(user.userinfo): - if key in time_keys: - continue - details[key] = user.userinfo[key] - return details diff --git a/src/oidc_test/auth_providers.py b/src/oidc_test/auth_providers.py new file mode 100644 index 0000000..1c33ae8 --- /dev/null +++ b/src/oidc_test/auth_providers.py @@ -0,0 +1,5 @@ +from collections import OrderedDict + +from oidc_test.auth.provider import Provider + +providers: OrderedDict[str, Provider] = OrderedDict() diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py deleted file mode 100644 index 4c3b98c..0000000 --- a/src/oidc_test/auth_utils.py +++ /dev/null @@ -1,227 +0,0 @@ -from typing import Union, Annotated -from functools import wraps -import logging - -from fastapi import HTTPException, Request, Depends, status -from fastapi.security import OAuth2PasswordBearer -from authlib.oauth2.rfc6749 import OAuth2Token -from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App -from jwt import ExpiredSignatureError, InvalidKeyError -from httpx import AsyncClient - -# from authlib.oauth1.auth import OAuthToken -# from authlib.oauth2.auth import OAuth2Token - -from .models import User -from .database import TokenNotInDb, db, UserNotInDB -from .settings import settings, OIDCProvider - -logger = logging.getLogger(__name__) - -oidc_providers_settings: dict[str, OIDCProvider] = dict( - [(provider.id, provider) for provider in settings.oidc.providers] -) - -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") - - -async def fetch_token(name, request): - logger.warn("TODO: fetch_token") - ... - # if name in oidc_providers: - # model = OAuth2Token - # else: - # model = OAuthToken - - # token = model.find(name=name, user=request.user) - # return token.to_token() - - -async def update_token(*args, **kwargs): - logger.warn("TODO: update_token") - ... - - -authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) - - -def init_providers(): - # Add oidc providers to authlib from the settings - for id, provider in oidc_providers_settings.items(): - authlib_oauth.register( - name=id, - server_metadata_url=provider.openid_configuration, - client_kwargs={ - "scope": "openid email offline_access profile", - }, - client_id=provider.client_id, - client_secret=provider.client_secret, - api_base_url=provider.url, - # For PKCE (not implemented yet): - # code_challenge_method="S256", - # fetch_token=fetch_token, - # update_token=update_token, - # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) - ) - - -init_providers() - - -async def get_providers_info(): - # Get the public key: - async with AsyncClient() as client: - for provider_settings in oidc_providers_settings.values(): - if provider_settings.info_url: - provider_info = await client.get(provider_settings.url) - provider_settings.info = provider_info.json() - - -def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: - """Return the oidc_provider from a request object, from the session. - It can be used in Depends()""" - if (oidc_provider_id := request.session.get("oidc_provider_id")) is None: - return - try: - return getattr(authlib_oauth, str(oidc_provider_id)) - except AttributeError: - return - - -def get_oidc_provider(request: Request) -> StarletteOAuth2App: - if (oidc_provider := get_oidc_provider_or_none(request)) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") - else: - return oidc_provider - - -async def get_current_user(request: Request) -> User: - """Get the current user from a request object. - Also validates the token expiration time. - ... TODO: complete about refresh token - """ - if (user_sub := request.session.get("user_sub")) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED) - if (token := await db.get_token(request.session["token"])) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown") - user = await db.get_user(user_sub) - ## Check if the token is expired - if token.is_expired(): - oidc_provider = get_oidc_provider(request=request) - ## Ask a new refresh token from the provider - logger.info(f"Token expired for user {user.name}") - try: - userinfo = await oidc_provider.fetch_access_token( - refresh_token=token.get("refresh_token") - ) - except OAuthError as err: - logger.exception(err) - # raise HTTPException( - # status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh" - # ) - - return user - - -async def get_token(request: Request) -> OAuth2Token: - """Return the token from a request object, from the session. - It can be used in Depends()""" - try: - return await db.get_token(request.session["token"]) - except (KeyError, TokenNotInDb): - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") - - -async def get_current_user_or_none(request: Request) -> User | None: - """Return the user from a request object, from the session. - It can be used in Depends()""" - try: - return await get_current_user(request) - except HTTPException: - return None - - -def hasrole(required_roles: Union[str, list[str]] = []): - """Decorator for RBAC permissions""" - required_roles_set: set[str] - if isinstance(required_roles, str): - required_roles_set = set([required_roles]) - else: - required_roles_set = set(required_roles) - - def decorator(func): - @wraps(func) - async def wrapper(request=None, *args, **kwargs): - if request is None: - raise HTTPException( - 500, - "Functions decorated with hasrole must have a request:Request argument", - ) - user: User = await get_current_user(request) - if not any(required_roles_set.intersection(user.roles_as_set)): - raise HTTPException(status.HTTP_401_UNAUTHORIZED) - return await func(request, *args, **kwargs) - - return wrapper - - return decorator - - -def get_token_info(token: dict) -> dict: - token_info = dict() - for key in token: - if key != "userinfo": - token_info[key] = token[key] - return token_info - - -async def get_user_from_token( - token: Annotated[str, Depends(oauth2_scheme)], - request: Request, -) -> User: - if (auth_provider_id := request.headers.get("auth_provider")) is None: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - "Request headers must have a 'auth_provider' field", - ) - if ( - auth_provider_settings := oidc_providers_settings.get(auth_provider_id) - ) is None: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" - ) - try: - payload = auth_provider_settings.decode(token) - except ExpiredSignatureError as err: - logger.info(f"Expired signature: {err}") - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - "Expired signature (refresh not implemented yet)", - ) - except InvalidKeyError as err: - logger.info(f"Invalid key: {err}") - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key") - except Exception as err: - logger.info("Cannot decode token, see below") - logger.exception(err) - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token") - if (user_id := payload.get("sub")) is None: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found" - ) - try: - user = await db.get_user(user_id) - if user.access_token != token: - user.access_token = token - except UserNotInDB: - logger.info( - f"User {user_id} not found in DB, creating it (real apps can behave differently" - ) - user = await db.add_user( - sub=payload["sub"], - user_info=payload, - oidc_provider=getattr(authlib_oauth, auth_provider_id), - user_info_from_endpoint={}, - access_token=token, - ) - return user diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 1b682ef..8d87a48 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -2,12 +2,15 @@ import logging -from authlib.integrations.starlette_client.apps import StarletteOAuth2App - -from .models import User, Role from authlib.oauth2.rfc6749 import OAuth2Token +from jwt import PyJWTError -logger = logging.getLogger(__name__) +from oidc_test.auth.provider import Provider + +from oidc_test.models import User, Role +from oidc_test.auth_providers import providers + +logger = logging.getLogger("oidc-test") class UserNotInDB(Exception): @@ -20,6 +23,7 @@ class TokenNotInDb(Exception): class Database: users: dict[str, User] = {} + # TODO: key of the token table should be provider: sid tokens: dict[str, OAuth2Token] = {} # Last sessions for the user (key: users's subject id (sub)) @@ -28,21 +32,38 @@ class Database: self, sub: str, user_info: dict, - oidc_provider: StarletteOAuth2App, - user_info_from_endpoint: dict, + auth_provider: Provider, access_token: str, + access_token_decoded: dict | None = None, ) -> User: - user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider) - user.access_token = access_token + if access_token_decoded is None: + assert auth_provider.name is not None + provider = providers[auth_provider.id] + try: + access_token_decoded = provider.decode(access_token) + except PyJWTError: + access_token_decoded = {} + user_info["auth_provider_id"] = auth_provider.id + user = User(**user_info) + user.userinfo = user_info + # user.access_token = access_token + # user.access_token_decoded = access_token_decoded + # Add roles provided in the access token + roles = set() try: - raw_roles = user_info_from_endpoint["resource_access"][ - oidc_provider.client_id - ]["roles"] - except Exception as err: - logger.debug(f"Cannot read additional roles: {err}") - raw_roles = [] - for raw_role in raw_roles: - user.roles.append(Role(name=raw_role)) + r = access_token_decoded["resource_access"][auth_provider.client_id]["roles"] + roles.update(r) + except KeyError: + pass + try: + r = access_token_decoded["realm_access"]["roles"] + if isinstance(r, str): + roles.add(r) + else: + roles.update(r) + except KeyError: + pass + user.roles = [Role(name=role_name) for role_name in roles] self.users[sub] = user return user @@ -51,12 +72,23 @@ class Database: raise UserNotInDB return self.users[sub] - async def add_token(self, token: OAuth2Token, user: User) -> None: - self.tokens[token["id_token"]] = token + async def add_token(self, provider: Provider, token: OAuth2Token) -> None: + """Store a token using as key the sid (auth provider's session id) + in the id_token""" + sid = provider.get_session_key(token["userinfo"]) + self.tokens[sid] = token - async def get_token(self, id_token: str) -> OAuth2Token: + async def get_token( + self, + provider: Provider, + sid: str | None, + ) -> OAuth2Token: + # TODO: key of the token table should be provider: sid + assert isinstance(provider, Provider) + if sid is None: + raise TokenNotInDb try: - return self.tokens[id_token] + return self.tokens[sid] except KeyError: raise TokenNotInDb diff --git a/src/oidc_test/log_conf.yaml b/src/oidc_test/log_conf.yaml new file mode 100644 index 0000000..a6bb0b4 --- /dev/null +++ b/src/oidc_test/log_conf.yaml @@ -0,0 +1,34 @@ +version: 1 +disable_existing_loggers: False +formatters: + default: + "()": uvicorn.logging.DefaultFormatter + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + access: + "()": uvicorn.logging.AccessFormatter + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +handlers: + default: + formatter: default + class: logging.StreamHandler + stream: ext://sys.stderr + access: + formatter: access + class: logging.StreamHandler + stream: ext://sys.stdout +loggers: + uvicorn.error: + level: INFO + handlers: + - default + propagate: no + uvicorn.access: + level: INFO + handlers: + - access + propagate: no + "oidc-test": + level: DEBUG + handlers: + - default + propagate: yes diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 3d95009..e882cda 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -6,15 +6,19 @@ from typing import Annotated from pathlib import Path from datetime import datetime import logging +import logging.config +import importlib.resources +from yaml import safe_load from urllib.parse import urlencode from contextlib import asynccontextmanager from httpx import HTTPError from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi.staticfiles import StaticFiles -from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse +from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates from fastapi.middleware.cors import CORSMiddleware +from jwt import PyJWTError from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.base_client import OAuthError @@ -25,32 +29,52 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair -from .settings import settings -from .models import User -from .auth_utils import ( - get_oidc_provider, - get_oidc_provider_or_none, - hasrole, +from oidc_test import __version__ +from oidc_test.registry import registry +from oidc_test.auth.provider import NoPublicKey, Provider +from oidc_test.auth.utils import ( + get_auth_provider, + get_auth_provider_or_none, get_current_user_or_none, - get_current_user, - get_user_from_token, authlib_oauth, - get_token, - oidc_providers_settings, - get_providers_info, + get_token_from_session_or_none, + get_token_from_session, + update_token, ) -from .auth_misc import pretty_details -from .database import db -from .resource_server import get_resource +from oidc_test.auth.utils import init_providers +from oidc_test.settings import settings +from oidc_test.auth_providers import providers +from oidc_test.models import User +from oidc_test.database import TokenNotInDb, db +from oidc_test.resource_server import resource_server -logger = logging.getLogger("uvicorn.error") +logger = logging.getLogger("oidc-test") + +if settings.log: + assert __package__ is not None + with ( + importlib.resources.path(__package__) as package_path, + open(package_path / settings.log_config_file) as f, + ): + logging_config = safe_load(f) + logging.config.dictConfig(logging_config) templates = Jinja2Templates(Path(__file__).parent / "templates") @asynccontextmanager async def lifespan(app: FastAPI): - await get_providers_info() + assert app is not None + init_providers() + registry.make_registry() + for provider in list(providers.values()): + if provider.disabled: + continue + try: + await provider.get_info() + except NoPublicKey: + logger.warning(f"Disable {provider.id}: public key not found") + del providers[provider.id] yield @@ -64,74 +88,79 @@ app.add_middleware( allow_headers=["*"], ) -app.mount( - "/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static" -) - # SessionMiddleware is required by authlib app.add_middleware( SessionMiddleware, secret_key=settings.secret_key, ) +app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static") +app.mount("/resource", resource_server, name="resource_server") + @app.get("/") async def home( request: Request, - user: Annotated[User, Depends(get_current_user_or_none)], - oidc_provider: Annotated[ - StarletteOAuth2App | None, Depends(get_oidc_provider_or_none) - ], + user: Annotated[User | None, Depends(get_current_user_or_none)], + provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)], + token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)], ) -> HTMLResponse: - now = datetime.now() - if oidc_provider and ( - ( - oidc_provider_settings := oidc_providers_settings.get( - request.session.get("oidc_provider_id", "") - ) - ) - is not None - ): - resources = oidc_provider_settings.resources + context = { + "show_token": settings.show_token, + "user": user, + "now": datetime.now(), + "__version__": __version__, + } + if provider is None or token is None: + context["providers"] = providers + context["access_token"] = None + context["id_token_parsed"] = None + context["access_token_parsed"] = None + context["refresh_token_parsed"] = None + context["resources"] = None + context["auth_provider"] = None else: - resources = [] - oidc_provider_settings = None - return templates.TemplateResponse( - name="home.html", - request=request, - context={ - "settings": settings.model_dump(), - "user": user, - "now": now, - "oidc_provider": oidc_provider, - "oidc_provider_settings": oidc_provider_settings, - "resources": resources, - "user_info_details": ( - pretty_details(user, now) - if user and settings.oidc.show_session_details - else None - ), - }, - ) + context["auth_provider"] = provider + context["access_token"] = token["access_token"] + try: + access_token_parsed = provider.decode(token["access_token"], verify_signature=False) + context["access_token_parsed"] = access_token_parsed + context["access_token_scope"] = access_token_parsed.get("scope") + except PyJWTError as err: + context["access_token_parsed"] = {"Cannot parse": err.__class__.__name__} + context["access_token_scope"] = None + try: + id_token_parsed = provider.decode(token["id_token"], verify_signature=False) + context["id_token_parsed"] = id_token_parsed + except PyJWTError as err: + context["id_token_parsed"] = {"Cannot parse": err.__class__.__name__} + try: + refresh_token_parsed = provider.decode(token["refresh_token"], verify_signature=False) + context["refresh_token_parsed"] = refresh_token_parsed + except PyJWTError as err: + context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__} + context["resources"] = registry.resources + context["resource_providers"] = provider.resource_providers + return templates.TemplateResponse(name="home.html", request=request, context=context) # Endpoints for the login / authorization process -@app.get("/login/{oidc_provider_id}") -async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: +@app.get("/login/{auth_provider_id}") +async def login(request: Request, auth_provider_id: str) -> RedirectResponse: """Login with the provider id, giving the browser a redirect to its authorize page. - The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url + The provider is expected to send the browser back to our own /auth/{auth_provider_id} url with the token. """ - redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id) + redirect_uri = request.url_for("auth", auth_provider_id=auth_provider_id) try: - provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) + provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") # if ( - # code_challenge_method := oidc_providers_settings[ - # oidc_provider_id + # code_challenge_method := providers[ + # auth_provider_id # ].code_challenge_method # ) is not None: # #client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method) @@ -151,217 +180,144 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") -@app.get("/auth/{oidc_provider_id}") -async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: +@app.get("/auth/{auth_provider_id}") +async def auth( + request: Request, + auth_provider_id: str, +) -> RedirectResponse: """Decrypt the auth token, store it to the session (cookie based) and response to the browser with a redirect to a "welcome user" page. """ try: - oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) - except AttributeError: + provider = providers[auth_provider_id] + except KeyError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") try: - token: OAuth2Token = await oidc_provider.authorize_access_token(request) + token: OAuth2Token = await provider.authlib_client.authorize_access_token(request) except OAuthError as error: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) - # Remember the oidc_provider in the session + # Remember the authlib_client in the session # logger.info(f"Scope: {token['scope']}") - request.session["oidc_provider_id"] = oidc_provider_id + request.session["auth_provider_id"] = auth_provider_id # # One could process the full decoded token which contains extra information # eg for updates. Here we are only interested in roles # if userinfo := token.get("userinfo"): - # Remember the oidc_provider in the session - request.session["oidc_provider_id"] = oidc_provider_id - # User id (sub) given by oidc provider + # Remember the authlib_client in the session + request.session["auth_provider_id"] = auth_provider_id + # User id (sub) given by auth provider sub = userinfo["sub"] - # Get additional data from userinfo endpoint - try: - user_info_from_endpoint = await oidc_provider.userinfo( - token=token, follow_redirects=True - ) - except Exception as err: - logger.warn(f"Cannot get userinfo from endpoint: {err}") - user_info_from_endpoint = {} + ## Get additional data from userinfo endpoint + # try: + # user_info_from_endpoint = await authlib_client.userinfo( + # token=token, follow_redirects=True + # ) + # except Exception as err: + # logger.warn(f"Cannot get userinfo from endpoint: {err}") + # user_info_from_endpoint = {} # Build and remember the user in the session request.session["user_sub"] = sub - # Store the user in the database + # Store the user in the database, which also verifies the token validity and signature try: - oidc_provider_settings = oidc_providers_settings[oidc_provider_id] - access_token = oidc_provider_settings.decode(token["access_token"]) - except Exception: + user = await db.add_user( + sub, + user_info=userinfo, + auth_provider=providers[auth_provider_id], + access_token=token["access_token"], + ) + except PyJWTError as err: raise HTTPException( status.HTTP_401_UNAUTHORIZED, - detail="Cannot decode token or verify its signature", + detail=f"Token invalid: {err.__class__.__name__}", ) - user = await db.add_user( - sub, - user_info=userinfo, - oidc_provider=oidc_provider, - user_info_from_endpoint=user_info_from_endpoint, - access_token=token["access_token"], - ) - # Add the id_token to the session - request.session["token"] = token["id_token"] + assert isinstance(user, User) + # Add the provider session id to the session + request.session["sid"] = provider.get_session_key(userinfo) # Add the token to the db because it is used for logout - await db.add_token(token, user) + await db.add_token(provider, token) # Send the user to the home: (s)he is authenticated return RedirectResponse(url=request.url_for("home")) else: # Not sure if it's correct to redirect to plain login # if no userinfo is provided - return RedirectResponse( - url=request.url_for("login", oidc_provider_id=oidc_provider_id) - ) + return RedirectResponse(url=request.url_for("login", auth_provider_id=auth_provider_id)) @app.get("/account") async def account( - request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + provider: Annotated[Provider, Depends(get_auth_provider)], ) -> RedirectResponse: - if ( - oidc_provider_settings := oidc_providers_settings.get( - request.session.get("oidc_provider_id", "") - ) - ) is None: - raise HTTPException( - status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings" - ) - return RedirectResponse(f"{oidc_provider_settings.account_url_template}") + """Redirect to the auth provider account management, + if account_url_template is in the provider's settings""" + return RedirectResponse(f"{provider.account_url_template}") @app.get("/logout") async def logout( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + provider: Annotated[Provider, Depends(get_auth_provider)], ) -> RedirectResponse: - # Clear session - request.session.pop("user_sub", None) # Get provider's endpoint if ( - provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint") + provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint") ) is None: - logger.warn( - f"Cannot find end_session_endpoint for provider {oidc_provider.name}" - ) + logger.warning(f"Cannot find end_session_endpoint for provider {provider.id}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") - if (token := await db.get_token(request.session.pop("token", None))) is None: - logger.warn("No session in db for the token") + # Clear session + request.session.pop("user_sub", None) + request.session.pop("auth_provider_id", None) + try: + token = await db.get_token(provider, request.session.pop("sid", None)) + except TokenNotInDb: + logger.warning("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) - logout_url = ( - provider_logout_uri - + "?" - + urlencode( - { - "post_logout_redirect_uri": post_logout_uri, - "id_token_hint": token["id_token"], - "cliend_id": "oidc_local_test", - } - ) - ) + url_query = { + "post_logout_redirect_uri": post_logout_uri, + "client_id": provider.client_id, + } + if provider.logout_with_id_token_hint: + url_query["id_token_hint"] = token["id_token"] + logout_url = f"{provider_logout_uri}?{urlencode(url_query)}" return RedirectResponse(logout_url) @app.get("/non-compliant-logout") async def non_compliant_logout( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + provider: Annotated[Provider, Depends(get_auth_provider)], ): """A page for non-compliant OAuth2 servers that we cannot log out.""" - # Clear the remain of the session - request.session.pop("oidc_provider_id", None) + # Clear session + request.session.pop("user_sub", None) + request.session.pop("auth_provider_id", None) return templates.TemplateResponse( name="non_compliant_logout.html", request=request, - context={"oidc_provider": oidc_provider, "home_url": request.url_for("home")}, + context={"auth_provider": provider, "home_url": request.url_for("home")}, ) -# Route for OAuth resource server - - -@app.get("/resource/{id}") -async def get_resource_( - id: str, - # user: Annotated[User, Depends(get_current_user)], - # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], - # token: Annotated[OAuth2Token, Depends(get_token)], - user: Annotated[User, Depends(get_user_from_token)], -) -> JSONResponse: - """Generic path for testing a resource provided by a provider""" - return JSONResponse(await get_resource(id, user)) - - -# Routes for RBAC based tests - - -@app.get("/public") -async def public() -> HTMLResponse: - return HTMLResponse("

Not protected

") - - -@app.get("/protected") -async def get_protected( - user: Annotated[User, Depends(get_current_user)] -) -> HTMLResponse: - assert user is not None # Just to keep QA checks happy - return HTMLResponse("

Only authenticated users can see this

") - - -@app.get("/protected-by-foorole") -@hasrole("foorole") -async def get_protected_by_foorole(request: Request) -> HTMLResponse: - assert request is not None # Just to keep QA checks happy - return HTMLResponse("

Only users with foorole can see this

") - - -@app.get("/protected-by-barrole") -@hasrole("barrole") -async def get_protected_by_barrole(request: Request) -> HTMLResponse: - assert request is not None # Just to keep QA checks happy - return HTMLResponse("

Protected by barrole

") - - -@app.get("/protected-by-foorole-and-barrole") -@hasrole("barrole") -@hasrole("foorole") -async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse: - assert request is not None # Just to keep QA checks happy - return HTMLResponse("

Only users with foorole and barrole can see this

") - - -@app.get("/protected-by-foorole-or-barrole") -@hasrole(["foorole", "barrole"]) -async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse: - assert request is not None # Just to keep QA checks happy - return HTMLResponse("

Only users with foorole or barrole can see this

") - - -@app.get("/introspect") -async def get_introspect( +@app.get("/refresh") +async def refresh( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], - token: Annotated[OAuth2Token, Depends(get_token)], -) -> JSONResponse: - assert request is not None # Just to keep QA checks happy - if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None: + provider: Annotated[Provider, Depends(get_auth_provider)], + token: Annotated[OAuth2Token, Depends(get_token_from_session)], +) -> RedirectResponse: + """Manually refresh token""" + new_token = await provider.authlib_client.fetch_access_token( + refresh_token=token["refresh_token"], + grant_type="refresh_token", + ) + try: + await update_token(provider.id, new_token) + except PyJWTError as err: + logger.info(f"Cannot refresh token: {err.__class__.__name__}") raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="No intrispection endpoint found for the OIDC provider", + status.HTTP_510_NOT_EXTENDED, f"Token refresh error: {err.__class__.__name__}" ) - if ( - response := await oidc_provider.post( - url, - token=token, - data={"token": token["access_token"]}, - ) - ).is_success: - return response.json() - else: - raise HTTPException(status_code=response.status_code, detail=response.text) + return RedirectResponse(url=request.url_for("home")) # Snippet for running standalone @@ -385,9 +341,7 @@ def main(): parser.add_argument( "-p", "--port", type=int, default=80, help="Port to listen to (default: 80)" ) - parser.add_argument( - "-v", "--version", action="store_true", help="Print version and exit" - ) + parser.add_argument("-v", "--version", action="store_true", help="Print version and exit") args = parser.parse_args() if args.version: diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index db5d6ad..7b6fd0e 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -1,5 +1,6 @@ +import logging from functools import cached_property -from typing import Self +from typing import Any from pydantic import ( computed_field, @@ -7,9 +8,10 @@ from pydantic import ( EmailStr, ConfigDict, ) -from authlib.integrations.starlette_client.apps import StarletteOAuth2App from sqlmodel import SQLModel, Field +logger = logging.getLogger("oidc-test") + class Role(SQLModel, extra="ignore"): name: str @@ -33,18 +35,8 @@ class User(UserBase): ) userinfo: dict = {} access_token: str | None = None - oidc_provider: StarletteOAuth2App | None = None - - @classmethod - def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self: - user = cls(**userinfo) - user.userinfo = userinfo - user.oidc_provider = oidc_provider - # Add roles if they are provided in the token - if raw_ra := userinfo.get("realm_access"): - if raw_roles := raw_ra.get("roles"): - user.roles = [Role(name=raw_role) for raw_role in raw_roles] - return user + access_token_decoded: dict[str, Any] | None = None + auth_provider_id: str @computed_field @cached_property @@ -54,15 +46,21 @@ class User(UserBase): def has_scope(self, scope: str) -> bool: """Check if the scope is present in user info or access token""" info_scopes = self.userinfo.get("scope", "").split(" ") - access_token_scopes = self.access_token_parsed().get("scope", "").split(" ") + try: + access_token_scopes = self.decode_access_token().get("scope", "").split(" ") + except Exception as err: + logger.debug(f"Cannot find scope because the access token cannot be decoded: {err}") + access_token_scopes = [] return scope in set(info_scopes + access_token_scopes) - def access_token_parsed(self): - assert self.access_token is not None - assert self.oidc_provider is not None - assert self.oidc_provider.name is not None - from .auth_utils import oidc_providers_settings + def decode_access_token(self, verify_signature: bool = True): + assert self.access_token is not None, "no access_token" + assert self.auth_provider_id is not None, "no auth_provider_id" + from .auth_providers import providers - return oidc_providers_settings[self.oidc_provider.name].decode( - self.access_token + return providers[self.auth_provider_id].decode( + self.access_token, verify_signature=verify_signature ) + + def get_scope(self, verify_signature: bool = True): + return self.decode_access_token(verify_signature=verify_signature)["scope"] diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py new file mode 100644 index 0000000..3b91ad4 --- /dev/null +++ b/src/oidc_test/registry.py @@ -0,0 +1,47 @@ +from importlib.metadata import entry_points +import logging + +from pydantic import BaseModel, ConfigDict + +from oidc_test.models import User + +logger = logging.getLogger("registry") + + +class ProcessResult(BaseModel): + model_config = ConfigDict( + extra="allow", + ) + + +class ProcessError(Exception): + pass + + +class Resource(BaseModel): + name: str + scope_required: str | None = None + role_required: str | None = None + is_public: bool = False + default_resource_id: str | None = None + + def __init__(self, name: str): + super().__init__() + self.__id__ = name + + async def process(self, user: User | None, resource_id: str | None = None) -> ProcessResult: + logger.warning(f"{self.__id__} should define a process method") + return ProcessResult() + + +class ResourceRegistry(BaseModel): + resources: dict[str, Resource] = {} + + def make_registry(self): + for ep in entry_points().select(group="oidc_test.resource_provider"): + ResourceClass = ep.load() + if issubclass(ResourceClass, Resource): + self.resources[ep.name] = ResourceClass(ep.name) + + +registry = ResourceRegistry() diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index ecaa597..ddc5762 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -1,55 +1,310 @@ -from datetime import datetime +from typing import Annotated, Any import logging +from json import JSONDecodeError -from httpx import AsyncClient +from authlib.oauth2.rfc6749 import OAuth2Token +from httpx import AsyncClient, HTTPError +from jwt.exceptions import DecodeError, ExpiredSignatureError, InvalidTokenError +from fastapi import FastAPI, HTTPException, Depends, Request, status +from fastapi.middleware.cors import CORSMiddleware -from .models import User +# from starlette.middleware.sessions import SessionMiddleware +# from authlib.integrations.starlette_client.apps import StarletteOAuth2App +# from authlib.oauth2.rfc6749 import OAuth2Token -logger = logging.getLogger(__name__) +from oidc_test.auth.provider import Provider +from oidc_test.auth.utils import ( + get_user_from_token_or_none, + oauth2_scheme_optional, +) +from oidc_test.auth_providers import providers +from oidc_test.settings import ResourceProvider, settings +from oidc_test.models import User +from oidc_test.registry import ProcessError, ProcessResult, registry + +logger = logging.getLogger("oidc-test") + +resource_server = FastAPI() -async def get_resource(resource_id: str, user: User) -> dict: - """ - Resource processing: build an informative rely as a simple showcase - """ - pname = getattr(user.oidc_provider, "name", "?") - resp = { - "hello": f"Hi {user.name} from an OAuth resource provider", - "comment": f"I received a request for '{resource_id}' " - + f"with an access token signed by {pname}", - } - # For the demo, resource resource_id matches a scope get:resource_id, - # but this has to be refined for production - required_scope = f"get:{resource_id}" - # Check if the required scope is in the scopes allowed in userinfo - if user.has_scope(required_scope): - await process(user, resource_id, resp) +resource_server.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# SessionMiddleware is required by authlib +# resource_server.add_middleware( +# SessionMiddleware, +# secret_key=settings.secret_key, +# ) + +# Route for OAuth resource server + + +# Routes for RBAC based tests + + +@resource_server.get("/") +async def resources() -> dict[str, dict[str, Any]]: + return {"internal": {}, "plugins": registry.resources} + + +@resource_server.get("/{resource_name}") +@resource_server.get("/{resource_name}/{resource_id}") +async def get_resource( + resource_name: str, + user: Annotated[User | None, Depends(get_user_from_token_or_none)], + token: Annotated[OAuth2Token | None, Depends(oauth2_scheme_optional)], + resource_id: str | None = None, +): + """Generic path for testing a resource provided by a provider. + There's no field validation (response type of ProcessResult) on purpose, + leaving the responsibility of the response validation to resource providers""" + # Get the resource if it's defined in user auth provider's resources (external) + if user is not None: + provider = providers[user.auth_provider_id] + if ":" in resource_name: + # Third-party resource provider: send the request with the request token + resource_provider_id, resource_name = resource_name.split(":", 1) + provider = providers[user.auth_provider_id] + resource_provider: ResourceProvider = provider.get_resource_provider( + resource_provider_id + ) + resource_url = resource_provider.get_resource_url(resource_name) + async with AsyncClient(verify=resource_provider.verify_ssl) as client: + try: + logger.debug(f"GET request to {resource_url}") + resp = await client.get( + resource_url, + headers={ + "Content-type": "application/json", + "Authorization": f"Bearer {token}", + "auth_provider": user.auth_provider_id, + }, + ) + except HTTPError as err: + raise HTTPException( + status.HTTP_503_SERVICE_UNAVAILABLE, err.__class__.__name__ + ) + except Exception as err: + raise HTTPException( + status.HTTP_500_INTERNAL_SERVER_ERROR, err.__class__.__name__ + ) + else: + if resp.is_success: + return resp.json() + else: + reason_str: str + try: + reason_str = resp.json().get("detail", str(resp)) + except Exception: + reason_str = str(resp.text) + raise HTTPException(resp.status_code, reason_str) + # Third party resource (provided through the auth provider) + # The token is just passed on + # XXX: is this branch valid anymore? + if resource_name in [r.resource_name for r in provider.resources]: + return await get_auth_provider_resource( + provider=provider, + resource_name=resource_name, + token=token, + user=user, + ) + # Internal resource (provided here) + if resource_name in registry.resources: + resource = registry.resources[resource_name] + reason: dict[str, str] = {} + if not resource.is_public: + if user is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public") + else: + if resource.scope_required is not None and not user.has_scope( + resource.scope_required + ): + reason["scope"] = ( + f"No scope {resource.scope_required} in the access token " + "but it is required for accessing this resource" + ) + if ( + resource.role_required is not None + and resource.role_required not in user.roles_as_set + ): + reason["role"] = ( + f"You don't have the role {resource.role_required} " + "but it is required for accessing this resource" + ) + if len(reason) == 0: + try: + resp = await resource.process(user=user, resource_id=resource_id) + return resp + except ProcessError as err: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}" + ) + else: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, ", ".join(reason.values())) else: - ## For the showcase, giving a explanation. - ## Alternatively, raise HTTP_401_UNAUTHORIZED - resp["sorry"] = ( - f"No scope {required_scope} in the access token " - + "but it is required for accessing this resource." + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Unknown resource") + # return await get_resource_(resource_name, user, **request.query_params) + + +async def get_auth_provider_resource( + provider: Provider, resource_name: str, token: OAuth2Token | None, user: User +) -> ProcessResult: + if token is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth token") + access_token = token + async with AsyncClient() as client: + resp = await client.get( + url=provider.get_resource_url(resource_name), + headers={ + "Content-type": "application/json", + "Authorization": f"Bearer {access_token}", + }, + ) + if resp.is_error: + raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}") + # Only a demo, real application would really process the response + resp_length = len(resp.text) + if resp_length > 1024: + return ProcessResult( + msg=f"The resource is too long ({resp_length} bytes) to show in this demo, here is just the begining in raw format", + start=resp.text[:100] + "...", ) - return resp - - -async def process(user, resource_id, resp): - """ - Too simple to be serious. - It's a good fit for a plugin architecture for production - """ - assert user is not None - if resource_id == "time": - resp["time"] = datetime.now().strftime("%c") - elif resource_id == "bs": - async with AsyncClient() as client: - bs = await client.get("https://corporatebs-generator.sameerkumar.website/") - resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.") else: - resp["sorry"] = f"I don't known how to give '{resource_id}'." + try: + resp_json = resp.json() + except JSONDecodeError: + return ProcessResult(msg="The resource is not formatted in JSON", text=resp.text) + if isinstance(resp_json, dict): + return ProcessResult(**resp.json()) + elif isinstance(resp_json, list): + return ProcessResult(**{str(i): line for i, line in enumerate(resp_json)}) +# @resource_server.get("/public") +# async def public() -> dict: +# return {"msg": "Not protected"} +# +# +# @resource_server.get("/protected") +# async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): +# assert user is not None # Just to keep QA checks happy +# return {"msg": "Only authenticated users can see this"} +# +# +# @resource_server.get("/protected-by-foorole") +# async def get_protected_by_foorole( +# user: Annotated[User, Depends(UserWithRole("foorole"))], +# ): +# assert user is not None +# return {"msg": "Only users with foorole can see this"} +# +# +# @resource_server.get("/protected-by-barrole") +# async def get_protected_by_barrole( +# user: Annotated[User, Depends(UserWithRole("barrole"))], +# ): +# assert user is not None +# return {"msg": "Protected by barrole"} +# +# +# @resource_server.get("/protected-by-foorole-and-barrole") +# async def get_protected_by_foorole_and_barrole( +# user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], +# ): +# assert user is not None # Just to keep QA checks happy +# return {"msg": "Only users with foorole and barrole can see this"} +# +# +# @resource_server.get("/protected-by-foorole-or-barrole") +# async def get_protected_by_foorole_or_barrole( +# user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))], +# ): +# assert user is not None # Just to keep QA checks happy +# return {"msg": "Only users with foorole or barrole can see this"} + +# async def get_resource_(resource_id: str, user: User, **kwargs) -> dict: +# """ +# Resource processing: build an informative rely as a simple showcase +# """ +# if resource_id == "petition": +# return await sign(user, kwargs["petition_id"]) +# provider = providers[user.auth_provider_id] +# try: +# pname = provider.name +# except KeyError: +# pname = "?" +# resp = { +# "hello": f"Hi {user.name} from an OAuth resource provider", +# "comment": f"I received a request for '{resource_id}' " +# + f"with an access token signed by {pname}", +# } +# # For the demo, resource resource_id matches a scope get:resource_id, +# # but this has to be refined for production +# required_scope = f"get:{resource_id}" +# # Check if the required scope is in the scopes allowed in userinfo +# try: +# if user.has_scope(required_scope): +# await process(user, resource_id, resp) +# else: +# ## For the showcase, giving a explanation. +# ## Alternatively, raise HTTP_401_UNAUTHORIZED +# raise HTTPException( +# status.HTTP_401_UNAUTHORIZED, +# f"No scope {required_scope} in the access token " +# + "but it is required for accessing this resource", +# ) +# except ExpiredSignatureError: +# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired") +# except InvalidTokenError: +# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid") +# return resp + + +# async def process(user, resource_id, resp): +# """ +# Too simple to be serious. +# It's a good fit for a plugin architecture for production +# """ +# if resource_id == "time": +# resp["time"] = datetime.now().strftime("%c") +# elif resource_id == "bs": +# async with AsyncClient() as client: +# bs = await client.get("https://corporatebs-generator.sameerkumar.website/") +# resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.") +# else: +# raise HTTPException( +# status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'." +# ) + + +# @resource_server.get("/introspect") +# async def get_introspect( +# request: Request, +# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], +# token: Annotated[OAuth2Token, Depends(get_token)], +# ) -> JSONResponse: +# assert request is not None # Just to keep QA checks happy +# if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None: +# raise HTTPException( +# status_code=status.HTTP_401_UNAUTHORIZED, +# detail="No introspection endpoint found for the OIDC provider", +# ) +# if ( +# response := await oidc_provider.post( +# url, +# token=token, +# data={"token": token["access_token"]}, +# ) +# ).is_success: +# return response.json() +# else: +# raise HTTPException(status_code=response.status_code, detail=response.text) + # assert user.oidc_provider is not None ### Get some info (TODO: refactor) # if (auth_provider_id := user.oidc_provider.name) is None: diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 46d857d..ad80c06 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -4,8 +4,7 @@ import random from typing import Type, Tuple from pathlib import Path -from jwt import decode -from pydantic import BaseModel, computed_field, AnyUrl +from pydantic import AnyHttpUrl, BaseModel, computed_field, AnyUrl from pydantic_settings import ( BaseSettings, SettingsConfigDict, @@ -14,18 +13,33 @@ from pydantic_settings import ( ) from starlette.requests import Request -from .models import User - class Resource(BaseModel): """A resource with an URL that can be accessed with an OAuth2 access token""" + resource_name: str + name: str + url: str + + +class ResourceProvider(BaseModel): id: str name: str + base_url: AnyUrl + resources: list[Resource] = [] + verify_ssl: bool = True + + def get_resource(self, resource_name: str) -> Resource: + return [ + resource for resource in self.resources if resource.resource_name == resource_name + ][0] + + def get_resource_url(self, resource_name: str) -> str: + return f"{self.base_url}{self.get_resource(resource_name).url}" -class OIDCProvider(BaseModel): - """OIDC provider, can also be a resource server""" +class AuthProviderSettings(BaseModel): + """Auth provider, can also be a resource server""" id: str name: str @@ -40,11 +54,14 @@ class OIDCProvider(BaseModel): info_url: str | None = ( None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key) ) - info: dict[str, str | int] | None = ( - None # Info fetched from info_url, eg. public key - ) public_key: str | None = None + public_key_url: str | None = None signature_alg: str = "RS256" + resource_provider_scopes: list[str] = [] + session_key: str = "sid" + skip_verify_signature: bool = True + disabled: bool = False + resource_providers: list[ResourceProvider] = [] @computed_field @property @@ -56,56 +73,20 @@ class OIDCProvider(BaseModel): def token_url(self) -> str: return "auth/" + self.id - def get_account_url(self, request: Request, user: User) -> str | None: + def get_account_url(self, request: Request, user: dict) -> str | None: if self.account_url_template: - if not ( - self.url.endswith("/") or self.account_url_template.startswith("/") - ): + if not (self.url.endswith("/") or self.account_url_template.startswith("/")): sep = "/" else: sep = "" - return ( - self.url - + sep - + self.account_url_template.format(request=request, user=user) - ) + return self.url + sep + self.account_url_template.format(request=request, user=user) else: return None - def get_public_key(self) -> str: - """Return the public key formatted for decoding token""" - public_key = self.public_key or ( - self.info is not None and self.info["public_key"] - ) - if public_key is None: - raise AttributeError(f"Cannot get public key for {self.name}") - return f""" - -----BEGIN PUBLIC KEY----- - {public_key} - -----END PUBLIC KEY----- - """ - def decode(self, token: str) -> dict: - """Decode the token with signature check""" - return decode( - token, - self.get_public_key(), - algorithms=[self.signature_alg], - audience=["oidc-test", "oidc-test-web"], - options={"verify_signature": not settings.insecure.skip_verify_signature}, - ) - - -class ResourceProvider(BaseModel): - id: str - name: str - base_url: AnyUrl - resources: list[Resource] = [] - - -class OIDCSettings(BaseModel): +class AuthSettings(BaseModel): show_session_details: bool = False - providers: list[OIDCProvider] = [] + providers: list[AuthProviderSettings] = [] swagger_provider: str = "" @@ -120,12 +101,15 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_nested_delimiter="__") - oidc: OIDCSettings = OIDCSettings() - resource_providers: list[ResourceProvider] = [] + auth: AuthSettings = AuthSettings() secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False + log_config_file: str = "log_conf.yaml" insecure: Insecure = Insecure() cors_origins: list[str] = [] + debug_token: bool = False + show_token: bool = False + show_external_resource_providers_links: bool = False @classmethod def settings_customise_sources( @@ -144,9 +128,7 @@ class Settings(BaseSettings): settings_cls, Path( Path( - environ.get( - "OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml" - ), + environ.get("OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml"), ) ), ), diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index 4552ca0..1e8dc03 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -21,6 +21,18 @@ hr { .hidden { display: none; } +.version { + position: absolute; + font-size: 75%; + top: 0.3em; + right: 0.3em; +} +.center { + text-align: center; +} +.error { + color: darkred; +} .content { width: 100%; display: flex; @@ -111,6 +123,7 @@ hr { background-color: #8888FF80; } + /* For home */ .login-box { @@ -135,19 +148,27 @@ hr { .providers .provider { min-height: 2em; } -.providers .provider a.link { +.providers .provider .link { text-decoration: none; max-height: 2em; } -.providers .provider .link div { +.providers .provider .link { background-color: #f7c7867d; border-radius: 8px; padding: 6px; text-align: center; color: black; - font-weight: bold; + font-weight: 400; cursor: pointer; + border: 0; + width: 100%; } + +.providers .provider .link.disabled { + color: gray; + cursor: not-allowed; +} + .providers .provider .hint { font-size: 80%; max-width: 13em; @@ -163,10 +184,56 @@ hr { gap: 0.5em; flex-flow: wrap; } -.content .links-to-check a { +.content .links-to-check button { color: black; padding: 5px 10px; text-decoration: none; border-radius: 8px; + border: none; + cursor: pointer; } +.token { + overflow-wrap: anywhere; + font-family: monospace; +} + +.resourceResult { + padding: 0.5em; + display: flex; + gap: 0.5em; + width: fit-content; + align-items: center; + margin: 5px auto; + box-shadow: 0px 0px 10px #90c3eeA0; + background-color: #90c3eeA0; + border-radius: 8px; +} + +.resources { + display: flex; +} + +.resource { + text-align: center; +} + +.token-info { + margin: 0 1em; +} + +.key { + font-weight: bold; +} + +.token .key, .token .value { + display: inline; +} +.token .value { + padding-left: 1em; +} + +.msg { + text-align: center; + font-weight: bold; +} diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 142fa6e..e988dfe 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -1,40 +1,90 @@ -function checkHref(elem) { - var xmlHttp = new XMLHttpRequest() - xmlHttp.onreadystatechange = function () { - if (xmlHttp.readyState == 4) { - elem.classList.add("hasResponseStatus") - elem.classList.add("status-" + xmlHttp.status) - elem.title = "Response code: " + xmlHttp.status + " - " + xmlHttp.statusText - } - } - xmlHttp.open("GET", elem.href, true) // true for asynchronous - xmlHttp.send(null) -} - -function checkPerms(className) { - var rootElems = document.getElementsByClassName(className) - Array.from(rootElems).forEach(elem => - Array.from(elem.children).forEach(elem => checkHref(elem)) - ) -} - -async function get_resource(id, token, authProvider) { - //if (!keycloak.keycloak) { return } - const resp = await fetch("resource/" + id, { +async function checkHref(elem, token, authProvider) { + const msg = document.getElementById("msg") + const resourceName = elem.getAttribute("resource-name") + const resourceId = elem.getAttribute("resource-id") + const resourceProviderId = elem.getAttribute("resource-provider-id") ? elem.getAttribute("resource-provider-id") : "" + const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName + const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}` + const resp = await fetch(url, { method: "GET", headers: new Headers({ "Content-type": "application/json", "Authorization": `Bearer ${token}`, "auth_provider": authProvider, }), + }).catch(err => { + msg.innerHTML = "Cannot fetch resource: " + err.message + resourceElem.innerHTML = "" }) - /* - resource.value = resp['data'] - msg.value = "" + if (resp === undefined) { + return + } else { + elem.classList.add("hasResponseStatus") + elem.classList.add("status-" + resp.status) + elem.title = "Response code: " + resp.status + " - " + resp.statusText } - ).catch ( - err => msg.value = err - ) -*/ - console.log(await resp.json()) +} + +function checkPerms(className, token, authProvider) { + var rootElems = document.getElementsByClassName(className) + Array.from(rootElems).forEach(elem => + Array.from(elem.children).forEach(elem => checkHref(elem, token, authProvider)) + ) +} + +async function get_resource(resourceName, token, authProvider, resourceId, resourceProviderId) { + // BaseUrl for an external resource provider + //if (!keycloak.keycloak) { return } + const msg = document.getElementById("msg") + const resourceElem = document.getElementById('resource') + const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName + const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}` + const resp = await fetch(url, { + method: "GET", + headers: new Headers({ + "Content-type": "application/json", + "Authorization": `Bearer ${token}`, + "auth_provider": authProvider, + }), + }).catch(err => { + msg.innerHTML = "Cannot fetch resource: " + err.message + resourceElem.innerHTML = "" + }) + if (resp === undefined) { + return + } + const resource = await resp.json() + if (!resp.ok) { + msg.innerHTML = resource["detail"] + resourceElem.innerHTML = "" + return + } + msg.innerHTML = "" + resourceElem.innerHTML = "" + Object.entries(resource).forEach( + ([key, value]) => { + let r = document.createElement('div') + let kElem = document.createElement('div') + kElem.innerText = key + kElem.className = "key" + let vElem = document.createElement('div') + if (typeof value == "object") { + Object.entries(value).forEach(v => { + const ne = document.createElement('div') + ne.innerHTML = `${v[0]}: ${v[1]}` + vElem.appendChild(ne) + }) + } + else { + vElem.innerText = value + } + vElem.className = "value" + if (key == "sorry") { + vElem.classList.add("error") + } + r.appendChild(kElem) + r.appendChild(vElem) + resourceElem.appendChild(r) + } + ) } diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 3bdb3f3..157e26f 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -4,7 +4,8 @@ - + +
v. {{ __version__}}

OIDC-test - FastAPI client

{% block content %} {% endblock %} diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index bba2f2a..167616f 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -5,25 +5,28 @@ with OpenID Connect and OAuth2 with different providers.

{% if not user %} - - {% endif %} - {% if user %} + + {% else %} {% endif %}
-

- Fetch resources from the resource server with your authentication token: -

-
- - -
-
-
-
-
{{ key }}
-
{{ value }}
-
{{ value }}
-
-
-
-
{{ msg }}
-
-

- These links should get different response codes depending on the authorization: -

- {% if resources %} -

- Resources for this provider: -

+

This application provides all these resources, eventually protected with scope or roles:

{% endif %} - {% if user_info_details %} -
-
-

User info

-
    - {% for key, value in user_info_details.items() %} -
  • - {{ key }}: {{ value }} -
  • + {% if auth_provider.resources %} +

    {{ auth_provider.name }} is also defined as a provider for these resources:

    +
-
Now is: {{ now.strftime("%T, %D") }}
+ {% endif %} + {% if resource_providers %} +

{{ auth_provider.name }} allows this application to request resources from third party resource providers:

+ {% for resource_provider in resource_providers %} + + {% endfor %} + {% endif %} +
+
+
+
+
+ {% if show_token and id_token_parsed %} +
+
+
+

id token

+
+ {% for key, value in id_token_parsed.items() %} +
+
{{ key }}
+
{{ value }}
+
+ {% endfor %} +
+

access token

+
+ {% for key, value in access_token_parsed.items() %} +
+
{{ key }}
+
{{ value }}
+
+ {% endfor %} +
+

refresh token

+
+ {% for key, value in refresh_token_parsed.items() %} +
+
{{ key }}
+
{{ value }}
+
+ {% endfor %} +
+
{% endif %} {% endblock %} diff --git a/src/oidc_test/templates/non_compliant_logout.html b/src/oidc_test/templates/non_compliant_logout.html index 24a96ae..56758de 100644 --- a/src/oidc_test/templates/non_compliant_logout.html +++ b/src/oidc_test/templates/non_compliant_logout.html @@ -6,12 +6,12 @@ authorisation to log in again without asking for credentials.

- This is because {{ oidc_provider.name }} does not provide "end_session_endpoint" in its metadata - (see: {{ oidc_provider._server_metadata_url }}). + This is because {{ auth_provider.name }} does not provide "end_session_endpoint" in its metadata + (see: {{ auth_provider.authlib_client._server_metadata_url }}).

You can just also go back to the application home page, but - it recommended to go to the OIDC provider's site + it recommended to go to the OIDC provider's site and log out explicitely from there.

{% endblock %} diff --git a/uv.lock b/uv.lock index 01b64de..0566bb5 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.13" [[package]] @@ -206,6 +207,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632 }, ] +[[package]] +name = "dunamai" +version = "1.23.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/4e/a5c8c337a1d9ac0384298ade02d322741fb5998041a5ea74d1cd2a4a1d47/dunamai-1.23.0.tar.gz", hash = "sha256:a163746de7ea5acb6dacdab3a6ad621ebc612ed1e528aaa8beedb8887fccd2c4", size = 44681 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/4c/963169386309fec4f96fd61210ac0a0666887d0fb0a50205395674d20b71/dunamai-1.23.0-py3-none-any.whl", hash = "sha256:a0906d876e92441793c6a423e16a4802752e723e9c9a5aabdc5535df02dbe041", size = 26342 }, +] + [[package]] name = "ecdsa" version = "0.19.0" @@ -482,7 +495,6 @@ wheels = [ [[package]] name = "oidc-fastapi-test" -version = "0.0.0" source = { editable = "." } dependencies = [ { name = "authlib" }, @@ -501,6 +513,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "dunamai" }, { name = "ipdb" }, { name = "pytest" }, ] @@ -523,6 +536,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "dunamai", specifier = ">=1.23.0" }, { name = "ipdb", specifier = ">=0.13.13" }, { name = "pytest", specifier = ">=8.3.4" }, ]