diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index 379aaa8..df52e0b 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -19,7 +19,7 @@ jobs: - name: Install the latest version of uv uses: astral-sh/setup-uv@v4 with: - version: "0.6.9" + version: "0.5.16" - name: Install run: uv sync @@ -27,26 +27,32 @@ jobs: - name: Run tests (API call) run: .venv/bin/pytest -s tests/basic.py - - name: Get version - run: echo "VERSION=$(.venv/bin/dunamai from any --style semver)" >> $GITHUB_ENV + - name: Get version with git describe + id: version + run: | + echo "version=$(git describe)" >> $GITHUB_OUTPUT + echo "$VERSION" - - name: Version - run: echo $VERSION + - name: Check if the container should be built + id: builder + env: + RUN: ${{ toJSON(inputs.build || !contains(steps.version.outputs.version, '-')) }} + run: | + echo "run=$RUN" >> $GITHUB_OUTPUT + echo "Run build: $RUN" - - name: Get distance from tag - run: echo "DISTANCE=$(.venv/bin/dunamai from any --format '{distance}')" >> $GITHUB_ENV - - - name: Distance - run: echo $DISTANCE + - name: Set the version in pyproject.toml (workaround for uv not supporting dynamic version) + if: fromJSON(steps.builder.outputs.run) + env: + VERSION: ${{ steps.version.outputs.version }} + run: sed "s/0.0.0/$VERSION/" -i pyproject.toml - name: Workaround for bug of podman-login - if: env.DISTANCE == '0' run: | mkdir -p $HOME/.docker echo "{ \"auths\": {} }" > $HOME/.docker/config.json - name: Log in to the container registry (with another workaround) - if: env.DISTANCE == '0' uses: actions/podman-login@v1 with: registry: ${{ vars.REGISTRY }} @@ -55,31 +61,26 @@ jobs: auth_file_path: /tmp/auth.json - name: Build the container image - if: env.DISTANCE == '0' uses: actions/buildah-build@v1 with: image: oidc-fastapi-test oci: true labels: oidc-fastapi-test - tags: "latest ${{ env.VERSION }}" + tags: latest ${{ steps.version.outputs.version }} containerfiles: | ./Containerfile - name: Push the image to the registry - if: env.DISTANCE == '0' uses: actions/push-to-registry@v2 with: registry: "docker://${{ vars.REGISTRY }}/${{ vars.ORGANISATION }}" image: oidc-fastapi-test - tags: "latest ${{ env.VERSION }}" + tags: latest ${{ steps.version.outputs.version }} - name: Build wheel - if: env.DISTANCE == '0' run: uv build --wheel - name: Publish Python package (home) - if: env.DISTANCE == '0' env: LOCAL_PYPI_TOKEN: ${{ secrets.LOCAL_PYPI_TOKEN }} run: uv publish --publish-url https://code.philo.ydns.eu/api/packages/philorg/pypi --token $LOCAL_PYPI_TOKEN - continue-on-error: true diff --git a/.forgejo/workflows/test.yaml b/.forgejo/workflows/test.yaml index f4d994e..a56a9ce 100644 --- a/.forgejo/workflows/test.yaml +++ b/.forgejo/workflows/test.yaml @@ -19,7 +19,7 @@ jobs: - name: Install the latest version of uv uses: astral-sh/setup-uv@v4 with: - version: "0.6.3" + version: "0.5.16" - name: Install run: uv sync diff --git a/Containerfile b/Containerfile index 0ec45d1..aef57f8 100644 --- a/Containerfile +++ b/Containerfile @@ -1,4 +1,4 @@ -FROM docker.io/library/python:latest +FROM docker.io/library/python:alpine COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/ @@ -9,9 +9,6 @@ WORKDIR /app RUN uv pip install --system . -# Add demo plugin -RUN PIP_EXTRA_INDEX_URL=https://pypi.org/simple/ uv pip install --system --index-url https://code.philo.ydns.eu/api/packages/philorg/pypi/simple/ oidc-fastapi-test-resource-provider-demo - # Possible to run with: #CMD ["oidc-test", "--port", "80"] #CMD ["fastapi", "run", "src/oidc_test/main.py", "--port", "8873", "--root-path", "/oidc-test"] diff --git a/README.md b/README.md index 68f335d..0dc11be 100644 --- a/README.md +++ b/README.md @@ -16,12 +16,6 @@ as a template for integration in other FastAPI/SQLModel applications. Feedback welcome. -## Resource server - -It also functions as a resource server in a OAuth architecture. -See a sibling test project, a web based OIDC/OAuth: -[oidc-vue-test](https://code.philo.ydns.eu/philorg/oidc-vue-test). - ## RBAC The application is also a playground for RBAC (Role Based Access control) @@ -51,78 +45,36 @@ given by the OIDC providers. For example: -```yaml -secret_key: AVeryWellKeptSecret -debug_token: no -show_token: yes -log: yes - -auth: +```text +oidc: + secret_key: "ASecretNoOneKnows" + show_session_details: yes providers: - id: auth0 name: Okta / Auth0 - url: https:// - public_key_url: https:///pem - client_id: - client_secret: client_secret_generated_by_auth0 - hint: A hint for test credentials + url: "https://" + client_id: "" + client_secret: "client_secret_generated_by_auth0" + hint: "A hint for test credentials" - id: keycloak name: Keycloak at somewhere - url: https:// - info_url: https://philo.ydns.eu/auth/realms/test - account_url_template: /account - client_id: - client_secret: - hint: A hint for test credentials - code_challenge_method: S256 - resource_provider_scopes: - - get:time - - get:bs - resource_providers: - - id: - name: A third party resource provider - base_url: https://some.example.com/ - verify_ssl: yes - resources: - - name: Public RS2 - resource_name: public - url: resource/public - - name: BS RS2 - resource_name: bs - url: resource/bs - - name: Time RS2 - resource_name: time - url: resource/time + url: "https://" + client_id: "" + client_secret: "client_secret_generated_by_keycloak" + hint: "User: foo, password: foofoo" - id: codeberg - disabled: no name: Codeberg - url: https://codeberg.org - account_url_template: /user/settings - client_id: - client_secret: client_secret_generated_by_codeberg - info_url: https://codeberg.org/login/oauth/keys - session_key: sub - skip_verify_signature: no - resources: - - name: List of repos - id: repos - url: /api/v1/user/repos - - name: List of OAuth2 applications - id: oauth2_applications - url: /api/v1/user/applications/oauth2 - -cors_origins: - - https://some.client - - https://localhost:8000 + url: "https://codeberg.org" + client_id: "" + client_secret: "client_secret_generated_by_codeberg" ``` The application reads the `OIDC_TEST_SETTINGS_FILE` environment variable to determine the location of this file at startup. -For example, to run on port 8000 in a container, -with the setting file in the current working directory: +For example, to run on port 8000 in a container, with the setting file in the current working directory: ```sh podman run -p 8000:80 --env OIDC_TEST_CONFIG_FILE=/app/settings.yaml --mount type=bind,source=settings.yaml,destination=/app/settings.yaml code.philo.ydns.eu/philorg/oidc-fastapi-test:latest diff --git a/TODO b/TODO index 93e80a9..5d7e575 100644 --- a/TODO +++ b/TODO @@ -1,5 +1,3 @@ https://docs.authlib.org/en/latest/oauth/2/intro.html#intro-oauth2 https://www.keycloak.org/docs/latest/authorization_services/index.html - -https://thinhdanggroup.github.io/oauth2-python/ diff --git a/pyproject.toml b/pyproject.toml index c44e9f3..4509e5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "oidc-fastapi-test" -#version = "0.0.0" -dynamic = ["version"] +version = "0.0.0" +# dynamic = ["version"] description = "Add your description here" readme = "README.md" requires-python = ">=3.13" @@ -9,12 +9,10 @@ dependencies = [ "authlib>=1.4.0", "cachetools>=5.5.0", "fastapi[standard]>=0.115.6", - "httpx>=0.28.1", "itsdangerous>=2.2.0", "passlib[bcrypt]>=1.7.4", "pkce>=1.0.3", "pydantic-settings>=2.7.1", - "pyjwt>=2.10.1", "python-jose[cryptography]>=3.3.0", "requests>=2.32.3", "sqlmodel>=0.0.22", @@ -24,24 +22,14 @@ dependencies = [ oidc-test = "oidc_test.main:main" [dependency-groups] -dev = ["dunamai>=1.23.0", "ipdb>=0.13.13", "pytest>=8.3.4"] +dev = ["ipdb>=0.13.13", "pytest>=8.3.4"] [build-system] -requires = ["hatchling", "uv-dynamic-versioning"] +requires = ["hatchling"] build-backend = "hatchling.build" -[tool.hatch.version] -source = "uv-dynamic-versioning" - [tool.hatch.build.targets.wheel] packages = ["src/oidc_test"] -package = true - -[tool.uv-dynamic-versioning] -style = "semver" [tool.uv] package = true - -[tool.black] -line-length = 98 diff --git a/src/oidc_test/__init__.py b/src/oidc_test/__init__.py index f449e2b..e69de29 100644 --- a/src/oidc_test/__init__.py +++ b/src/oidc_test/__init__.py @@ -1,11 +0,0 @@ -import importlib.metadata - -try: - from dunamai import Version, Style - - __version__ = Version.from_git().serialize(style=Style.SemVer, dirty=True) -except ImportError: - # __name__ could be used if the package name is the same - # as the directory - not the case here - # __version__ = importlib.metadata.version(__name__) - __version__ = importlib.metadata.version("oidc-fastapi-test") diff --git a/src/oidc_test/auth/provider.py b/src/oidc_test/auth/provider.py deleted file mode 100644 index ce288a6..0000000 --- a/src/oidc_test/auth/provider.py +++ /dev/null @@ -1,113 +0,0 @@ -from json import JSONDecodeError -from typing import Any -from jwt import decode -import logging - -from pydantic import ConfigDict -from authlib.integrations.starlette_client.apps import StarletteOAuth2App -from httpx import AsyncClient - -from oidc_test.settings import AuthProviderSettings, ResourceProvider, Resource, settings -from oidc_test.models import User - -logger = logging.getLogger("oidc-test") - - -class NoPublicKey(Exception): - pass - - -class Provider(AuthProviderSettings): - # To allow authlib_client as StarletteOAuth2App - model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore - - authlib_client: StarletteOAuth2App = StarletteOAuth2App(None) - info: dict[str, Any] = {} - unknown_auth_user: User - logout_with_id_token_hint: bool = True - - def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]: - """Decode the token with signature check""" - if self.public_key is None: - raise NoPublicKey - if verify_signature is None: - verify_signature = self.skip_verify_signature - if settings.debug_token: - decoded = decode( - token, - self.public_key, - algorithms=[self.signature_alg], - audience=["account", "oidc-test", "oidc-test-web"], - options={ - "verify_signature": False, - "verify_aud": False, - }, # not settings.insecure.skip_verify_signature}, - ) - logger.debug(str(decoded)) - return decode( - token, - self.public_key, - algorithms=[self.signature_alg], - audience=["account", "oidc-test", "oidc-test-web"], - options={ - "verify_signature": verify_signature, - }, # not settings.insecure.skip_verify_signature}, - ) - - async def get_info(self): - # Get the public key: - async with AsyncClient() as client: - public_key: str | None = None - if self.info_url is not None: - try: - provider_info = await client.get(self.info_url) - except Exception as err: - logger.debug("Provider_info: cannot connect") - logger.exception(err) - raise NoPublicKey - try: - self.info = provider_info.json() - except JSONDecodeError: - logger.debug("Provider_info: cannot decode json response") - raise NoPublicKey - if "public_key" in self.info: - # For Keycloak - try: - public_key = str(self.info["public_key"]) - except KeyError: - logger.debug("Provider_info: cannot get public_key") - raise NoPublicKey - elif "keys" in self.info: - # For Forgejo/Gitea - try: - public_key = str(self.info["keys"][0]["n"]) - except KeyError: - logger.debug("Provider_info: cannot get key 0.n") - raise NoPublicKey - if self.public_key_url is not None: - resp = await client.get(self.public_key_url) - public_key = resp.text - if public_key is None: - logger.debug("Provider_info: cannot determine public key") - raise NoPublicKey - self.public_key = "\n".join( - ["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"] - ) - - def get_session_key(self, userinfo): - return userinfo[self.session_key] - - def get_resource(self, resource_name: str) -> Resource: - return [ - resource for resource in self.resources if resource.resource_name == resource_name - ][0] - - def get_resource_url(self, resource_name: str) -> str: - return self.url + self.get_resource(resource_name).url - - def get_resource_provider(self, resource_provider_id: str) -> ResourceProvider: - return [ - provider - for provider in self.resource_providers - if provider.id == resource_provider_id - ][0] diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py deleted file mode 100644 index c51b039..0000000 --- a/src/oidc_test/auth/utils.py +++ /dev/null @@ -1,305 +0,0 @@ -from typing import Union, Annotated -from functools import wraps -import logging - -from fastapi import HTTPException, Request, Depends, status -from fastapi.security import OAuth2PasswordBearer -from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App -from authlib.oauth2.rfc6749 import OAuth2Token -from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError - -from oidc_test.auth.provider import Provider -from oidc_test.models import User -from oidc_test.database import db, TokenNotInDb, UserNotInDB -from oidc_test.settings import settings -from oidc_test.auth_providers import providers - -logger = logging.getLogger("oidc-test") - - -async def fetch_token(name, request): - assert name is not None - assert request is not None - logger.warning("TODO: fetch_token") - ... - # if name in oidc_providers: - # model = OAuth2Token - # else: - # model = OAuthToken - - # token = model.find(name=name, user=request.user) - # return token.to_token() - - -async def update_token( - provider_id, - token, - refresh_token: str | None = None, - access_token: str | None = None, -): - """Update the token in the database""" - provider = providers[provider_id] - sid: str = provider.get_session_key(provider.decode(token["id_token"])) - item = await db.get_token(provider, sid) - # update old token - item["access_token"] = token["access_token"] - item["refresh_token"] = token["refresh_token"] - item["id_token"] = token["id_token"] - item["expires_at"] = token["expires_at"] - logger.info(f"Token {sid} refreshed") - # It's a fake db and only in memory, so there's nothing to save - # await item.save() - - -def init_providers(): - """Add oidc providers to authlib from the settings - and build the providers dict""" - for provider_settings in settings.auth.providers: - provider_settings_dict = provider_settings.model_dump() - # Add an anonymous user, that cannot be identified but has provided a valid access token - provider_settings_dict["unknown_auth_user"] = User( - sub="", auth_provider_id=provider_settings.id - ) - provider = Provider(**provider_settings_dict) - if provider.disabled: - logger.info(f"{provider_settings.name} is disabled, skipping") - else: - authlib_oauth.register( - name=provider.id, - server_metadata_url=provider.openid_configuration, - client_kwargs={ - "scope": " ".join( - ["openid", "email", "offline_access", "profile"] - + provider.resource_provider_scopes - ), - }, - client_id=provider.client_id, - client_secret=provider.client_secret, - api_base_url=provider.url, - # For PKCE (not implemented yet): - # code_challenge_method="S256", - fetch_token=fetch_token, - update_token=update_token, - # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) - ) - provider.authlib_client = getattr(authlib_oauth, provider.id) - providers[provider.id] = provider - - -authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) - - -def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: - """Return the oidc_provider from a request object, from the session. - It can be used in Depends()""" - if (auth_provider_id := request.session.get("auth_provider_id")) is None: - return - return getattr(authlib_oauth, str(auth_provider_id), None) - - -def get_auth_provider_client(request: Request) -> StarletteOAuth2App: - if (oidc_provider := get_auth_provider_client_or_none(request)) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") - else: - return oidc_provider - - -def get_auth_provider_or_none(request: Request) -> Provider | None: - """Return the oidc_provider settings from a request object, from the session. - It can be used in Depends()""" - if (auth_provider_id := request.session.get("auth_provider_id")) is None: - return - return providers.get(auth_provider_id) - - -def get_auth_provider(request: Request) -> Provider: - if (provider := get_auth_provider_or_none(request)) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") - return provider - - -async def get_current_user(request: Request) -> User: - """Get the current user from a request object. - Also validates the token expiration time. - ... TODO: complete about refresh token - """ - if (user_sub := request.session.get("user_sub")) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED) - token = await get_token_from_session(request) - user = await db.get_user(user_sub) - ## Check if the token is expired - if token.is_expired(): - provider = get_auth_provider(request=request) - ## Ask a new refresh token from the provider - logger.info(f"Token expired for user {user.name}") - try: - userinfo = await provider.authlib_client.fetch_access_token( - refresh_token=token.get("refresh_token") - ) - assert userinfo is not None - except OAuthError as err: - logger.exception(err) - # raise HTTPException( - # status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh" - # ) - - return user - - -async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None: - """Return the auth token from the session or None. - Can be used in Depends()""" - try: - return await get_token_from_session(request) - except HTTPException: - return None - - -async def get_token_from_session(request: Request) -> OAuth2Token: - """Return the token from the session. - Can be used in Depends()""" - try: - provider = providers[request.session.get("auth_provider_id", "")] - except KeyError: - request.session.pop("auth_provider_id", None) - request.session.pop("user_sub", None) - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider") - try: - return await db.get_token( - provider, - request.session.get("sid"), - ) - except (TokenNotInDb, InvalidKeyError, DecodeError) as err: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, err.__class__.__name__) - - -async def get_current_user_or_none(request: Request) -> User | None: - """Return the user from a request object, from the session. - It can be used in Depends()""" - try: - return await get_current_user(request) - except HTTPException: - return None - - -def hasrole(required_roles: Union[str, list[str]] = []): - """Decorator for RBAC permissions""" - required_roles_set: set[str] - if isinstance(required_roles, str): - required_roles_set = set([required_roles]) - else: - required_roles_set = set(required_roles) - - def decorator(func): - @wraps(func) - async def wrapper(request=None, *args, **kwargs): - if request is None: - raise HTTPException( - 500, - "Functions decorated with hasrole must have a request:Request argument", - ) - user: User = await get_current_user(request) - if not any(required_roles_set.intersection(user.roles_as_set)): - raise HTTPException(status.HTTP_401_UNAUTHORIZED) - return await func(request, *args, **kwargs) - - return wrapper - - return decorator - - -def get_token_info(token: dict) -> dict: - token_info = dict() - for key in token: - if key != "userinfo": - token_info[key] = token[key] - return token_info - - -async def get_user_from_token( - token: Annotated[str, Depends(oauth2_scheme)], - request: Request, -) -> User: - try: - auth_provider_id = request.headers["auth_provider"] - except KeyError: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - "Request headers must have a 'auth_provider' field", - ) - try: - provider = providers[auth_provider_id] - except KeyError: - if auth_provider_id == "": - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth provider") - else: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" - ) - if token == "None": - request.session.pop("auth_provider_id", None) - request.session.pop("user_sub", None) - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") - try: - payload = provider.decode(token) - except ExpiredSignatureError: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - "Expired signature (token refresh not implemented yet)", - ) - except InvalidKeyError: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key") - except PyJWTError as err: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, f"Cannot decode token: {err.__class__.__name__}" - ) - try: - user_id = payload["sub"] - except KeyError: - return provider.unknown_auth_user - try: - user = await db.get_user(user_id) - if user.access_token != token: - user.access_token = token - except UserNotInDB: - logger.info( - f"User {user_id} not found in DB, creating it (real apps can behave differently)" - ) - user = await db.add_user( - sub=payload["sub"], - user_info=payload, - auth_provider=providers[auth_provider_id], - access_token=token, - ) - return user - - -async def get_user_from_token_or_none( - token: Annotated[str | None, Depends(oauth2_scheme_optional)], - request: Request, -) -> User | None: - if token is None: - return None - try: - return await get_user_from_token(token, request) - except HTTPException: - return None - - -class UserWithRole: - roles: set[str] - - def __init__(self, roles: str | list[str] | tuple[str] | set[str]): - if isinstance(roles, str): - self.roles = set([roles]) - elif isinstance(roles, (list, tuple, set)): - self.roles = set(roles) - - def __call__(self, user: User = Depends(get_user_from_token)) -> User: - if not any(self.roles.intersection(user.roles_as_set)): - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, f"Not of any required role {', '.join(self.roles)}" - ) - return user diff --git a/src/oidc_test/auth_misc.py b/src/oidc_test/auth_misc.py new file mode 100644 index 0000000..a4e9ea3 --- /dev/null +++ b/src/oidc_test/auth_misc.py @@ -0,0 +1,29 @@ +from datetime import datetime, timedelta +from collections import OrderedDict + +from .models import User + +time_keys = set(("iat", "exp", "auth_time", "updated_at")) + + +def pretty_details(user: User, now: datetime) -> OrderedDict: + details = OrderedDict() + # breakpoint() + for key in sorted(time_keys): + try: + dt = datetime.fromtimestamp(user.userinfo[key]) + except (KeyError, TypeError): + pass + else: + td = now - dt + td = timedelta(days=td.days, seconds=td.seconds) + if td.days < 0: + ptd = f"in {-td} h:m:s" + else: + ptd = f"{td} h:m:s ago" + details[key] = f"{user.userinfo[key]} - {dt} ({ptd})" + for key in sorted(user.userinfo): + if key in time_keys: + continue + details[key] = user.userinfo[key] + return details diff --git a/src/oidc_test/auth_providers.py b/src/oidc_test/auth_providers.py deleted file mode 100644 index 1c33ae8..0000000 --- a/src/oidc_test/auth_providers.py +++ /dev/null @@ -1,5 +0,0 @@ -from collections import OrderedDict - -from oidc_test.auth.provider import Provider - -providers: OrderedDict[str, Provider] = OrderedDict() diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py new file mode 100644 index 0000000..1f026b3 --- /dev/null +++ b/src/oidc_test/auth_utils.py @@ -0,0 +1,127 @@ +from typing import Union +from functools import wraps +from datetime import datetime +import logging + +from fastapi import HTTPException, Request, status +from authlib.oauth2.rfc6749 import OAuth2Token +from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App + +# from authlib.oauth1.auth import OAuthToken +# from authlib.oauth2.auth import OAuth2Token + +from .models import User +from .database import db +from .settings import settings + +logger = logging.getLogger(__name__) + +OIDC_PROVIDERS = set([provider.id for provider in settings.oidc.providers]) + + +def get_provider(request: Request) -> StarletteOAuth2App: + """Return the oidc_provider from a request object, from the session. + It can be used in Depends()""" + if (oidc_provider_id := request.session.get("oidc_provider_id")) is None: + raise HTTPException( + status.HTTP_503_SERVICE_UNAVAILABLE, + "Not logged in (no provider in session)", + ) + try: + return getattr(authlib_oauth, str(oidc_provider_id)) + except AttributeError: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") + + +async def get_current_user(request: Request) -> User: + """Get the current user from a request object. + Also validates the token expiration time. + ... TODO: complete about refresh token + """ + if (user_sub := request.session.get("user_sub")) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED) + if (token := await db.get_token(request.session["token"])) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown") + user = await db.get_user(user_sub) + ## Check if the token is expired + if token.is_expired(): + oidc_provider = get_provider(request=request) + ## Ask a new refresh token from the provider + logger.info(f"Token expired for user {user.name}") + try: + userinfo = await oidc_provider.fetch_access_token( + refresh_token=token.refresh_token + ) + except OAuthError as err: + logger.exception(err) + # raise HTTPException( + # status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh" + # ) + + return user + + +async def get_current_user_or_none(request: Request) -> User | None: + try: + return await get_current_user(request) + except HTTPException: + return None + + +def hasrole(required_roles: Union[str, list[str]] = []): + required_roles_set: set[str] + if isinstance(required_roles, str): + required_roles_set = set([required_roles]) + else: + required_roles_set = set(required_roles) + + def decorator(func): + @wraps(func) + async def wrapper(request=None, *args, **kwargs): + if request is None: + raise HTTPException( + 500, + "Functions decorated with hasrole must have a request:Request argument", + ) + user: User = await get_current_user(request) + if not any(required_roles_set.intersection(user.roles_as_set)): + raise HTTPException(status.HTTP_401_UNAUTHORIZED) + return await func(request, *args, **kwargs) + + return wrapper + + return decorator + + +def get_token_info(token: dict) -> dict: + token_info = dict() + for key in token: + if key != "userinfo": + token_info[key] = token[key] + return token_info + + +def fetch_token(name, request): + breakpoint() + ... + # if name in OIDC_PROVIDERS: + # model = OAuth2Token + # else: + # model = OAuthToken + + # token = model.find(name=name, user=request.user) + # return token.to_token() + + +def update_token(*args, **kwargs): + breakpoint() + ... + + +async def get_token(request: Request) -> OAuth2Token: + if (token := await db.get_token(request.session.get("token"))) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") + return token + + +authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 8d87a48..4b3f529 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -2,28 +2,16 @@ import logging +from authlib.integrations.starlette_client.apps import StarletteOAuth2App + +from .models import User, Role from authlib.oauth2.rfc6749 import OAuth2Token -from jwt import PyJWTError -from oidc_test.auth.provider import Provider - -from oidc_test.models import User, Role -from oidc_test.auth_providers import providers - -logger = logging.getLogger("oidc-test") - - -class UserNotInDB(Exception): - pass - - -class TokenNotInDb(Exception): - pass +logger = logging.getLogger(__name__) class Database: users: dict[str, User] = {} - # TODO: key of the token table should be provider: sid tokens: dict[str, OAuth2Token] = {} # Last sessions for the user (key: users's subject id (sub)) @@ -32,65 +20,30 @@ class Database: self, sub: str, user_info: dict, - auth_provider: Provider, - access_token: str, - access_token_decoded: dict | None = None, + oidc_provider: StarletteOAuth2App, + user_info_from_endpoint: dict, ) -> User: - if access_token_decoded is None: - assert auth_provider.name is not None - provider = providers[auth_provider.id] - try: - access_token_decoded = provider.decode(access_token) - except PyJWTError: - access_token_decoded = {} - user_info["auth_provider_id"] = auth_provider.id - user = User(**user_info) - user.userinfo = user_info - # user.access_token = access_token - # user.access_token_decoded = access_token_decoded - # Add roles provided in the access token - roles = set() + user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider) try: - r = access_token_decoded["resource_access"][auth_provider.client_id]["roles"] - roles.update(r) - except KeyError: - pass - try: - r = access_token_decoded["realm_access"]["roles"] - if isinstance(r, str): - roles.add(r) - else: - roles.update(r) - except KeyError: - pass - user.roles = [Role(name=role_name) for role_name in roles] + raw_roles = user_info_from_endpoint["resource_access"][ + oidc_provider.client_id + ]["roles"] + except Exception as err: + logger.debug(f"Cannot read additional roles: {err}") + raw_roles = [] + for raw_role in raw_roles: + user.roles.append(Role(name=raw_role)) self.users[sub] = user return user async def get_user(self, sub: str) -> User: - if sub not in self.users: - raise UserNotInDB return self.users[sub] - async def add_token(self, provider: Provider, token: OAuth2Token) -> None: - """Store a token using as key the sid (auth provider's session id) - in the id_token""" - sid = provider.get_session_key(token["userinfo"]) - self.tokens[sid] = token + async def add_token(self, token: OAuth2Token, user: User) -> None: + self.tokens[token["id_token"]] = token - async def get_token( - self, - provider: Provider, - sid: str | None, - ) -> OAuth2Token: - # TODO: key of the token table should be provider: sid - assert isinstance(provider, Provider) - if sid is None: - raise TokenNotInDb - try: - return self.tokens[sid] - except KeyError: - raise TokenNotInDb + async def get_token(self, id_token: str) -> OAuth2Token | None: + return self.tokens.get(id_token) db = Database() diff --git a/src/oidc_test/log_conf.yaml b/src/oidc_test/log_conf.yaml deleted file mode 100644 index a6bb0b4..0000000 --- a/src/oidc_test/log_conf.yaml +++ /dev/null @@ -1,34 +0,0 @@ -version: 1 -disable_existing_loggers: False -formatters: - default: - "()": uvicorn.logging.DefaultFormatter - format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - access: - "()": uvicorn.logging.AccessFormatter - format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" -handlers: - default: - formatter: default - class: logging.StreamHandler - stream: ext://sys.stderr - access: - formatter: access - class: logging.StreamHandler - stream: ext://sys.stdout -loggers: - uvicorn.error: - level: INFO - handlers: - - default - propagate: no - uvicorn.access: - level: INFO - handlers: - - access - propagate: no - "oidc-test": - level: DEBUG - handlers: - - default - propagate: yes diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index e882cda..90ab910 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -6,324 +6,318 @@ from typing import Annotated from pathlib import Path from datetime import datetime import logging -import logging.config -import importlib.resources -from yaml import safe_load from urllib.parse import urlencode -from contextlib import asynccontextmanager from httpx import HTTPError from fastapi import Depends, FastAPI, HTTPException, Request, status -from fastapi.staticfiles import StaticFiles -from fastapi.responses import HTMLResponse, RedirectResponse +from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.templating import Jinja2Templates -from fastapi.middleware.cors import CORSMiddleware -from jwt import PyJWTError +from fastapi.security import OpenIdConnect from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.base_client import OAuthError +from authlib.integrations.httpx_client import AsyncOAuth2Client from authlib.oauth2.rfc6749 import OAuth2Token +from pkce import generate_code_verifier, generate_pkce_pair -# TODO: PKCE -# from authlib.integrations.httpx_client import AsyncOAuth2Client -# from fastapi.security import OpenIdConnect -# from pkce import generate_code_verifier, generate_pkce_pair - -from oidc_test import __version__ -from oidc_test.registry import registry -from oidc_test.auth.provider import NoPublicKey, Provider -from oidc_test.auth.utils import ( - get_auth_provider, - get_auth_provider_or_none, +from .settings import settings +from .models import User +from .auth_utils import ( + get_provider, + hasrole, get_current_user_or_none, + get_current_user, authlib_oauth, - get_token_from_session_or_none, - get_token_from_session, - update_token, + get_token, ) -from oidc_test.auth.utils import init_providers -from oidc_test.settings import settings -from oidc_test.auth_providers import providers -from oidc_test.models import User -from oidc_test.database import TokenNotInDb, db -from oidc_test.resource_server import resource_server +from .auth_misc import pretty_details +from .database import db -logger = logging.getLogger("oidc-test") - -if settings.log: - assert __package__ is not None - with ( - importlib.resources.path(__package__) as package_path, - open(package_path / settings.log_config_file) as f, - ): - logging_config = safe_load(f) - logging.config.dictConfig(logging_config) +logger = logging.getLogger("uvicorn.error") templates = Jinja2Templates(Path(__file__).parent / "templates") -@asynccontextmanager -async def lifespan(app: FastAPI): - assert app is not None - init_providers() - registry.make_registry() - for provider in list(providers.values()): - if provider.disabled: - continue - try: - await provider.get_info() - except NoPublicKey: - logger.warning(f"Disable {provider.id}: public key not found") - del providers[provider.id] - yield - - -app = FastAPI(title="OIDC auth test", lifespan=lifespan) - -app.add_middleware( - CORSMiddleware, - allow_origins=settings.cors_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], +app = FastAPI( + title="OIDC auth test", ) + # SessionMiddleware is required by authlib app.add_middleware( SessionMiddleware, secret_key=settings.secret_key, ) -app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static") -app.mount("/resource", resource_server, name="resource_server") +# Add oidc providers to authlib from the settings +fastapi_providers = {} +_providers = {} -@app.get("/") -async def home( - request: Request, - user: Annotated[User | None, Depends(get_current_user_or_none)], - provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)], - token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)], -) -> HTMLResponse: - context = { - "show_token": settings.show_token, - "user": user, - "now": datetime.now(), - "__version__": __version__, - } - if provider is None or token is None: - context["providers"] = providers - context["access_token"] = None - context["id_token_parsed"] = None - context["access_token_parsed"] = None - context["refresh_token_parsed"] = None - context["resources"] = None - context["auth_provider"] = None - else: - context["auth_provider"] = provider - context["access_token"] = token["access_token"] - try: - access_token_parsed = provider.decode(token["access_token"], verify_signature=False) - context["access_token_parsed"] = access_token_parsed - context["access_token_scope"] = access_token_parsed.get("scope") - except PyJWTError as err: - context["access_token_parsed"] = {"Cannot parse": err.__class__.__name__} - context["access_token_scope"] = None - try: - id_token_parsed = provider.decode(token["id_token"], verify_signature=False) - context["id_token_parsed"] = id_token_parsed - except PyJWTError as err: - context["id_token_parsed"] = {"Cannot parse": err.__class__.__name__} - try: - refresh_token_parsed = provider.decode(token["refresh_token"], verify_signature=False) - context["refresh_token_parsed"] = refresh_token_parsed - except PyJWTError as err: - context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__} - context["resources"] = registry.resources - context["resource_providers"] = provider.resource_providers - return templates.TemplateResponse(name="home.html", request=request, context=context) +for provider in settings.oidc.providers: + authlib_oauth.register( + name=provider.id, + server_metadata_url=provider.openid_configuration, + client_kwargs={ + "scope": "openid email", # offline_access profile", + }, + client_id=provider.client_id, + client_secret=provider.client_secret, + api_base_url=provider.url, + # For PKCE (not implemented yet): + # code_challenge_method="S256", + # fetch_token=fetch_token, + # update_token=update_token, + # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) + ) + fastapi_providers[provider.id] = OpenIdConnect( + openIdConnectUrl=provider.openid_configuration + ) + _providers[provider.id] = provider # Endpoints for the login / authorization process -@app.get("/login/{auth_provider_id}") -async def login(request: Request, auth_provider_id: str) -> RedirectResponse: +@app.get("/login/{oidc_provider_id}") +async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: """Login with the provider id, giving the browser a redirect to its authorize page. - The provider is expected to send the browser back to our own /auth/{auth_provider_id} url + The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url with the token. """ - redirect_uri = request.url_for("auth", auth_provider_id=auth_provider_id) + redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id) try: - provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) + provider_: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") - # if ( - # code_challenge_method := providers[ - # auth_provider_id - # ].code_challenge_method - # ) is not None: - # #client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method) - # code_verifier = generate_code_verifier() - # logger.debug("TODO: PKCE") - # else: - # code_verifier = None + if ( + code_challenge_method := _providers[oidc_provider_id].code_challenge_method + ) is not None: + client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method) + code_verifier = generate_code_verifier() + logger.debug("TODO: PKCE") + else: + code_verifier = None try: - response = await provider.authorize_redirect( + response = await provider_.authorize_redirect( request, redirect_uri, access_type="offline", - code_verifier=None, + code_verifier=code_verifier, ) return response except HTTPError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") -@app.get("/auth/{auth_provider_id}") -async def auth( - request: Request, - auth_provider_id: str, -) -> RedirectResponse: +@app.get("/auth/{oidc_provider_id}") +async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: """Decrypt the auth token, store it to the session (cookie based) and response to the browser with a redirect to a "welcome user" page. """ try: - provider = providers[auth_provider_id] - except KeyError: + oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) + except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") try: - token: OAuth2Token = await provider.authlib_client.authorize_access_token(request) + token: OAuth2Token = await oidc_provider.authorize_access_token(request) except OAuthError as error: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) - # Remember the authlib_client in the session - # logger.info(f"Scope: {token['scope']}") - request.session["auth_provider_id"] = auth_provider_id + # Remember the oidc_provider in the session + # logger.debug(f"Scope: {token['scope']}") + request.session["oidc_provider_id"] = oidc_provider_id # # One could process the full decoded token which contains extra information # eg for updates. Here we are only interested in roles # if userinfo := token.get("userinfo"): - # Remember the authlib_client in the session - request.session["auth_provider_id"] = auth_provider_id - # User id (sub) given by auth provider + # Remember the oidc_provider in the session + request.session["oidc_provider_id"] = oidc_provider_id + # User id (sub) given by oidc provider sub = userinfo["sub"] - ## Get additional data from userinfo endpoint - # try: - # user_info_from_endpoint = await authlib_client.userinfo( - # token=token, follow_redirects=True - # ) - # except Exception as err: - # logger.warn(f"Cannot get userinfo from endpoint: {err}") - # user_info_from_endpoint = {} + # Get additional data from userinfo endpoint + try: + user_info_from_endpoint = await oidc_provider.userinfo( + token=token, follow_redirects=True + ) + except Exception as err: + logger.warn(f"Cannot get userinfo from endpoint: {err}") + user_info_from_endpoint = {} # Build and remember the user in the session request.session["user_sub"] = sub - # Store the user in the database, which also verifies the token validity and signature - try: - user = await db.add_user( - sub, - user_info=userinfo, - auth_provider=providers[auth_provider_id], - access_token=token["access_token"], - ) - except PyJWTError as err: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - detail=f"Token invalid: {err.__class__.__name__}", - ) - assert isinstance(user, User) - # Add the provider session id to the session - request.session["sid"] = provider.get_session_key(userinfo) + # Store the user in the database + user = await db.add_user( + sub, + user_info=userinfo, + oidc_provider=oidc_provider, + user_info_from_endpoint=user_info_from_endpoint, + ) + # Add the id_token to the session + request.session["token"] = token["id_token"] # Add the token to the db because it is used for logout - await db.add_token(provider, token) + await db.add_token(token, user) # Send the user to the home: (s)he is authenticated return RedirectResponse(url=request.url_for("home")) else: # Not sure if it's correct to redirect to plain login # if no userinfo is provided - return RedirectResponse(url=request.url_for("login", auth_provider_id=auth_provider_id)) - - -@app.get("/account") -async def account( - provider: Annotated[Provider, Depends(get_auth_provider)], -) -> RedirectResponse: - """Redirect to the auth provider account management, - if account_url_template is in the provider's settings""" - return RedirectResponse(f"{provider.account_url_template}") + return RedirectResponse( + url=request.url_for("login", oidc_provider_id=oidc_provider_id) + ) @app.get("/logout") async def logout( request: Request, - provider: Annotated[Provider, Depends(get_auth_provider)], + provider: Annotated[StarletteOAuth2App, Depends(get_provider)], ) -> RedirectResponse: - # Get provider's endpoint - if ( - provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint") - ) is None: - logger.warning(f"Cannot find end_session_endpoint for provider {provider.id}") - return RedirectResponse(request.url_for("non_compliant_logout")) - post_logout_uri = request.url_for("home") # Clear session request.session.pop("user_sub", None) - request.session.pop("auth_provider_id", None) - try: - token = await db.get_token(provider, request.session.pop("sid", None)) - except TokenNotInDb: - logger.warning("No session in db for the token or no token") + # Get provider's endpoint + if ( + provider_logout_uri := provider.server_metadata.get("end_session_endpoint") + ) is None: + logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}") + return RedirectResponse(request.url_for("non_compliant_logout")) + post_logout_uri = request.url_for("home") + if (token := await db.get_token(request.session.pop("token", None))) is None: + logger.warn("No session in db for the token") return RedirectResponse(request.url_for("home")) - url_query = { - "post_logout_redirect_uri": post_logout_uri, - "client_id": provider.client_id, - } - if provider.logout_with_id_token_hint: - url_query["id_token_hint"] = token["id_token"] - logout_url = f"{provider_logout_uri}?{urlencode(url_query)}" + logout_url = ( + provider_logout_uri + + "?" + + urlencode( + { + "post_logout_redirect_uri": post_logout_uri, + "id_token_hint": token["id_token"], + "cliend_id": "oidc_local_test", + } + ) + ) return RedirectResponse(logout_url) @app.get("/non-compliant-logout") async def non_compliant_logout( request: Request, - provider: Annotated[Provider, Depends(get_auth_provider)], + provider: Annotated[StarletteOAuth2App, Depends(get_provider)], ): """A page for non-compliant OAuth2 servers that we cannot log out.""" - # Clear session - request.session.pop("user_sub", None) - request.session.pop("auth_provider_id", None) return templates.TemplateResponse( name="non_compliant_logout.html", request=request, - context={"auth_provider": provider, "home_url": request.url_for("home")}, + context={"provider": provider, "home_url": request.url_for("home")}, ) -@app.get("/refresh") -async def refresh( +# Home URL + + +@app.get("/") +async def home( + request: Request, user: Annotated[User, Depends(get_current_user_or_none)] +) -> HTMLResponse: + now = datetime.now() + return templates.TemplateResponse( + name="home.html", + request=request, + context={ + "settings": settings.model_dump(), + "user": user, + "now": now, + "user_info_details": ( + pretty_details(user, now) + if user and settings.oidc.show_session_details + else None + ), + }, + ) + + +@app.get("/public") +async def public() -> HTMLResponse: + return HTMLResponse("

Not protected

") + + +# Some URIs for testing the permissions + + +@app.get("/protected") +async def get_protected( + user: Annotated[User, Depends(get_current_user)] +) -> HTMLResponse: + return HTMLResponse("

Only authenticated users can see this

") + + +@app.get("/protected-by-foorole") +@hasrole("foorole") +async def get_protected_by_foorole(request: Request) -> HTMLResponse: + return HTMLResponse("

Only users with foorole can see this

") + + +@app.get("/protected-by-barrole") +@hasrole("barrole") +async def get_protected_by_barrole(request: Request) -> HTMLResponse: + return HTMLResponse("

Protected by barrole

") + + +@app.get("/protected-by-foorole-and-barrole") +@hasrole("barrole") +@hasrole("foorole") +async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse: + return HTMLResponse("

Only users with foorole and barrole can see this

") + + +@app.get("/protected-by-foorole-or-barrole") +@hasrole(["foorole", "barrole"]) +async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse: + return HTMLResponse("

Only users with foorole or barrole can see this

") + + +@app.get("/introspect") +async def get_introspect( request: Request, - provider: Annotated[Provider, Depends(get_auth_provider)], - token: Annotated[OAuth2Token, Depends(get_token_from_session)], -) -> RedirectResponse: - """Manually refresh token""" - new_token = await provider.authlib_client.fetch_access_token( - refresh_token=token["refresh_token"], - grant_type="refresh_token", - ) - try: - await update_token(provider.id, new_token) - except PyJWTError as err: - logger.info(f"Cannot refresh token: {err.__class__.__name__}") - raise HTTPException( - status.HTTP_510_NOT_EXTENDED, f"Token refresh error: {err.__class__.__name__}" + provider: Annotated[StarletteOAuth2App, Depends(get_provider)], + token: Annotated[OAuth2Token, Depends(get_token)], +) -> JSONResponse: + if ( + response := await provider.post( + provider.server_metadata["introspection_endpoint"], + token=token, + data={"token": token["access_token"]}, ) - return RedirectResponse(url=request.url_for("home")) + ).is_success: + return response.json() + else: + raise HTTPException(status_code=response.status_code, detail=response.text) -# Snippet for running standalone -# Mostly useful for the --version option, -# as running with uvicorn is easy and provides better flexibility, eg. -# uvicorn --host foo oidc_test.main:app --reload +@app.get("/oauth2-forgejo-test") +async def get_forgejo_user_info( + request: Request, + user: Annotated[User, Depends(get_current_user)], + provider: Annotated[StarletteOAuth2App, Depends(get_provider)], + token: Annotated[OAuth2Token, Depends(get_token)], +) -> HTMLResponse: + if ( + response := await provider.get( + "/api/v1/user/repos", + # headers={"Authorization": f"token {token['access_token']}"}, + token=token, + ) + ).is_success: + repos = response.json() + names = [repo["name"] for repo in repos] + return HTMLResponse(f"{user.name} has {len(repos)} repos: {', '.join(names)}") + else: + raise HTTPException(status_code=response.status_code, detail=response.text) + + +# @app.get("/fast_api_depends") +# def fast_api_depends( +# token: Annotated[str, Depends(fastapi_providers["Keycloak"])] +# ) -> HTMLResponse: +# return HTMLResponse("You're Authenticated") def main(): @@ -341,7 +335,9 @@ def main(): parser.add_argument( "-p", "--port", type=int, default=80, help="Port to listen to (default: 80)" ) - parser.add_argument("-v", "--version", action="store_true", help="Print version and exit") + parser.add_argument( + "-v", "--version", action="store_true", help="Print version and exit" + ) args = parser.parse_args() if args.version: diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 7b6fd0e..3d484aa 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -1,66 +1,55 @@ -import logging from functools import cached_property -from typing import Any +from typing import Self from pydantic import ( computed_field, AnyHttpUrl, EmailStr, ConfigDict, + GetCoreSchemaHandler, ) +from pydantic_core import CoreSchema, core_schema +from authlib.integrations.starlette_client.apps import StarletteOAuth2App +from authlib.oauth2.rfc6749 import OAuth2Token as OAuth2Token_authlib from sqlmodel import SQLModel, Field -logger = logging.getLogger("oidc-test") - class Role(SQLModel, extra="ignore"): name: str class UserBase(SQLModel, extra="ignore"): + id: str | None = None sid: str | None = None - name: str | None = None + name: str email: EmailStr | None = None picture: AnyHttpUrl | None = None roles: list[Role] = [] class User(UserBase): - model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore + model_config = ConfigDict(arbitrary_types_allowed=True) sub: str = Field( description="""subject id of the user given by the oidc provider, also the key for the database 'table'""", ) userinfo: dict = {} - access_token: str | None = None - access_token_decoded: dict[str, Any] | None = None - auth_provider_id: str + oidc_provider: StarletteOAuth2App | None = None + + @classmethod + def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self: + user = cls(**userinfo) + user.userinfo = userinfo + user.oidc_provider = oidc_provider + # Add roles if they are provided in the token + if raw_ra := userinfo.get("realm_access"): + if raw_roles := raw_ra.get("roles"): + user.roles = [Role(name=raw_role) for raw_role in raw_roles] + return user @computed_field @cached_property def roles_as_set(self) -> set[str]: return set([role.name for role in self.roles]) - - def has_scope(self, scope: str) -> bool: - """Check if the scope is present in user info or access token""" - info_scopes = self.userinfo.get("scope", "").split(" ") - try: - access_token_scopes = self.decode_access_token().get("scope", "").split(" ") - except Exception as err: - logger.debug(f"Cannot find scope because the access token cannot be decoded: {err}") - access_token_scopes = [] - return scope in set(info_scopes + access_token_scopes) - - def decode_access_token(self, verify_signature: bool = True): - assert self.access_token is not None, "no access_token" - assert self.auth_provider_id is not None, "no auth_provider_id" - from .auth_providers import providers - - return providers[self.auth_provider_id].decode( - self.access_token, verify_signature=verify_signature - ) - - def get_scope(self, verify_signature: bool = True): - return self.decode_access_token(verify_signature=verify_signature)["scope"] diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py deleted file mode 100644 index 3b91ad4..0000000 --- a/src/oidc_test/registry.py +++ /dev/null @@ -1,47 +0,0 @@ -from importlib.metadata import entry_points -import logging - -from pydantic import BaseModel, ConfigDict - -from oidc_test.models import User - -logger = logging.getLogger("registry") - - -class ProcessResult(BaseModel): - model_config = ConfigDict( - extra="allow", - ) - - -class ProcessError(Exception): - pass - - -class Resource(BaseModel): - name: str - scope_required: str | None = None - role_required: str | None = None - is_public: bool = False - default_resource_id: str | None = None - - def __init__(self, name: str): - super().__init__() - self.__id__ = name - - async def process(self, user: User | None, resource_id: str | None = None) -> ProcessResult: - logger.warning(f"{self.__id__} should define a process method") - return ProcessResult() - - -class ResourceRegistry(BaseModel): - resources: dict[str, Resource] = {} - - def make_registry(self): - for ep in entry_points().select(group="oidc_test.resource_provider"): - ResourceClass = ep.load() - if issubclass(ResourceClass, Resource): - self.resources[ep.name] = ResourceClass(ep.name) - - -registry = ResourceRegistry() diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py deleted file mode 100644 index ddc5762..0000000 --- a/src/oidc_test/resource_server.py +++ /dev/null @@ -1,343 +0,0 @@ -from typing import Annotated, Any -import logging -from json import JSONDecodeError - -from authlib.oauth2.rfc6749 import OAuth2Token -from httpx import AsyncClient, HTTPError -from jwt.exceptions import DecodeError, ExpiredSignatureError, InvalidTokenError -from fastapi import FastAPI, HTTPException, Depends, Request, status -from fastapi.middleware.cors import CORSMiddleware - -# from starlette.middleware.sessions import SessionMiddleware -# from authlib.integrations.starlette_client.apps import StarletteOAuth2App -# from authlib.oauth2.rfc6749 import OAuth2Token - -from oidc_test.auth.provider import Provider -from oidc_test.auth.utils import ( - get_user_from_token_or_none, - oauth2_scheme_optional, -) -from oidc_test.auth_providers import providers -from oidc_test.settings import ResourceProvider, settings -from oidc_test.models import User -from oidc_test.registry import ProcessError, ProcessResult, registry - -logger = logging.getLogger("oidc-test") - -resource_server = FastAPI() - - -resource_server.add_middleware( - CORSMiddleware, - allow_origins=settings.cors_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -# SessionMiddleware is required by authlib -# resource_server.add_middleware( -# SessionMiddleware, -# secret_key=settings.secret_key, -# ) - -# Route for OAuth resource server - - -# Routes for RBAC based tests - - -@resource_server.get("/") -async def resources() -> dict[str, dict[str, Any]]: - return {"internal": {}, "plugins": registry.resources} - - -@resource_server.get("/{resource_name}") -@resource_server.get("/{resource_name}/{resource_id}") -async def get_resource( - resource_name: str, - user: Annotated[User | None, Depends(get_user_from_token_or_none)], - token: Annotated[OAuth2Token | None, Depends(oauth2_scheme_optional)], - resource_id: str | None = None, -): - """Generic path for testing a resource provided by a provider. - There's no field validation (response type of ProcessResult) on purpose, - leaving the responsibility of the response validation to resource providers""" - # Get the resource if it's defined in user auth provider's resources (external) - if user is not None: - provider = providers[user.auth_provider_id] - if ":" in resource_name: - # Third-party resource provider: send the request with the request token - resource_provider_id, resource_name = resource_name.split(":", 1) - provider = providers[user.auth_provider_id] - resource_provider: ResourceProvider = provider.get_resource_provider( - resource_provider_id - ) - resource_url = resource_provider.get_resource_url(resource_name) - async with AsyncClient(verify=resource_provider.verify_ssl) as client: - try: - logger.debug(f"GET request to {resource_url}") - resp = await client.get( - resource_url, - headers={ - "Content-type": "application/json", - "Authorization": f"Bearer {token}", - "auth_provider": user.auth_provider_id, - }, - ) - except HTTPError as err: - raise HTTPException( - status.HTTP_503_SERVICE_UNAVAILABLE, err.__class__.__name__ - ) - except Exception as err: - raise HTTPException( - status.HTTP_500_INTERNAL_SERVER_ERROR, err.__class__.__name__ - ) - else: - if resp.is_success: - return resp.json() - else: - reason_str: str - try: - reason_str = resp.json().get("detail", str(resp)) - except Exception: - reason_str = str(resp.text) - raise HTTPException(resp.status_code, reason_str) - # Third party resource (provided through the auth provider) - # The token is just passed on - # XXX: is this branch valid anymore? - if resource_name in [r.resource_name for r in provider.resources]: - return await get_auth_provider_resource( - provider=provider, - resource_name=resource_name, - token=token, - user=user, - ) - # Internal resource (provided here) - if resource_name in registry.resources: - resource = registry.resources[resource_name] - reason: dict[str, str] = {} - if not resource.is_public: - if user is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public") - else: - if resource.scope_required is not None and not user.has_scope( - resource.scope_required - ): - reason["scope"] = ( - f"No scope {resource.scope_required} in the access token " - "but it is required for accessing this resource" - ) - if ( - resource.role_required is not None - and resource.role_required not in user.roles_as_set - ): - reason["role"] = ( - f"You don't have the role {resource.role_required} " - "but it is required for accessing this resource" - ) - if len(reason) == 0: - try: - resp = await resource.process(user=user, resource_id=resource_id) - return resp - except ProcessError as err: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}" - ) - else: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, ", ".join(reason.values())) - else: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Unknown resource") - # return await get_resource_(resource_name, user, **request.query_params) - - -async def get_auth_provider_resource( - provider: Provider, resource_name: str, token: OAuth2Token | None, user: User -) -> ProcessResult: - if token is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth token") - access_token = token - async with AsyncClient() as client: - resp = await client.get( - url=provider.get_resource_url(resource_name), - headers={ - "Content-type": "application/json", - "Authorization": f"Bearer {access_token}", - }, - ) - if resp.is_error: - raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}") - # Only a demo, real application would really process the response - resp_length = len(resp.text) - if resp_length > 1024: - return ProcessResult( - msg=f"The resource is too long ({resp_length} bytes) to show in this demo, here is just the begining in raw format", - start=resp.text[:100] + "...", - ) - else: - try: - resp_json = resp.json() - except JSONDecodeError: - return ProcessResult(msg="The resource is not formatted in JSON", text=resp.text) - if isinstance(resp_json, dict): - return ProcessResult(**resp.json()) - elif isinstance(resp_json, list): - return ProcessResult(**{str(i): line for i, line in enumerate(resp_json)}) - - -# @resource_server.get("/public") -# async def public() -> dict: -# return {"msg": "Not protected"} -# -# -# @resource_server.get("/protected") -# async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): -# assert user is not None # Just to keep QA checks happy -# return {"msg": "Only authenticated users can see this"} -# -# -# @resource_server.get("/protected-by-foorole") -# async def get_protected_by_foorole( -# user: Annotated[User, Depends(UserWithRole("foorole"))], -# ): -# assert user is not None -# return {"msg": "Only users with foorole can see this"} -# -# -# @resource_server.get("/protected-by-barrole") -# async def get_protected_by_barrole( -# user: Annotated[User, Depends(UserWithRole("barrole"))], -# ): -# assert user is not None -# return {"msg": "Protected by barrole"} -# -# -# @resource_server.get("/protected-by-foorole-and-barrole") -# async def get_protected_by_foorole_and_barrole( -# user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], -# ): -# assert user is not None # Just to keep QA checks happy -# return {"msg": "Only users with foorole and barrole can see this"} -# -# -# @resource_server.get("/protected-by-foorole-or-barrole") -# async def get_protected_by_foorole_or_barrole( -# user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))], -# ): -# assert user is not None # Just to keep QA checks happy -# return {"msg": "Only users with foorole or barrole can see this"} - -# async def get_resource_(resource_id: str, user: User, **kwargs) -> dict: -# """ -# Resource processing: build an informative rely as a simple showcase -# """ -# if resource_id == "petition": -# return await sign(user, kwargs["petition_id"]) -# provider = providers[user.auth_provider_id] -# try: -# pname = provider.name -# except KeyError: -# pname = "?" -# resp = { -# "hello": f"Hi {user.name} from an OAuth resource provider", -# "comment": f"I received a request for '{resource_id}' " -# + f"with an access token signed by {pname}", -# } -# # For the demo, resource resource_id matches a scope get:resource_id, -# # but this has to be refined for production -# required_scope = f"get:{resource_id}" -# # Check if the required scope is in the scopes allowed in userinfo -# try: -# if user.has_scope(required_scope): -# await process(user, resource_id, resp) -# else: -# ## For the showcase, giving a explanation. -# ## Alternatively, raise HTTP_401_UNAUTHORIZED -# raise HTTPException( -# status.HTTP_401_UNAUTHORIZED, -# f"No scope {required_scope} in the access token " -# + "but it is required for accessing this resource", -# ) -# except ExpiredSignatureError: -# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired") -# except InvalidTokenError: -# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid") -# return resp - - -# async def process(user, resource_id, resp): -# """ -# Too simple to be serious. -# It's a good fit for a plugin architecture for production -# """ -# if resource_id == "time": -# resp["time"] = datetime.now().strftime("%c") -# elif resource_id == "bs": -# async with AsyncClient() as client: -# bs = await client.get("https://corporatebs-generator.sameerkumar.website/") -# resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.") -# else: -# raise HTTPException( -# status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'." -# ) - - -# @resource_server.get("/introspect") -# async def get_introspect( -# request: Request, -# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], -# token: Annotated[OAuth2Token, Depends(get_token)], -# ) -> JSONResponse: -# assert request is not None # Just to keep QA checks happy -# if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None: -# raise HTTPException( -# status_code=status.HTTP_401_UNAUTHORIZED, -# detail="No introspection endpoint found for the OIDC provider", -# ) -# if ( -# response := await oidc_provider.post( -# url, -# token=token, -# data={"token": token["access_token"]}, -# ) -# ).is_success: -# return response.json() -# else: -# raise HTTPException(status_code=response.status_code, detail=response.text) - -# assert user.oidc_provider is not None -### Get some info (TODO: refactor) -# if (auth_provider_id := user.oidc_provider.name) is None: -# raise HTTPException( -# status.HTTP_401_UNAUTHORIZED, -# "Request headers must have a 'auth_provider' field", -# ) -# if ( -# auth_provider_settings := oidc_providers_settings.get(auth_provider_id) -# ) is None: -# raise HTTPException( -# status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" -# ) -# if (key := auth_provider_settings.get_public_key()) is None: -# raise HTTPException( -# status.HTTP_401_UNAUTHORIZED, -# f"Key for provider '{auth_provider_id}' unknown", -# ) -# logger.warn(f"refresh with scope {scope}") -# breakpoint() -# refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope) -### Decode the new token -# try: -# payload = decode( -# refreshed_auth_info["access_token"], -# key=key, -# algorithms=["RS256"], -# audience="account", -# options={"verify_signature": not settings.insecure.skip_verify_signature}, -# ) -# except ExpiredSignatureError as err: -# logger.info(f"Expired signature: {err}") -# raise HTTPException( -# status.HTTP_401_UNAUTHORIZED, -# "Expired signature (refresh not implemented yet)", -# ) diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index ad80c06..0f41a2e 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -4,43 +4,16 @@ import random from typing import Type, Tuple from pathlib import Path -from pydantic import AnyHttpUrl, BaseModel, computed_field, AnyUrl +from pydantic import BaseModel, computed_field from pydantic_settings import ( BaseSettings, SettingsConfigDict, PydanticBaseSettingsSource, YamlConfigSettingsSource, ) -from starlette.requests import Request -class Resource(BaseModel): - """A resource with an URL that can be accessed with an OAuth2 access token""" - - resource_name: str - name: str - url: str - - -class ResourceProvider(BaseModel): - id: str - name: str - base_url: AnyUrl - resources: list[Resource] = [] - verify_ssl: bool = True - - def get_resource(self, resource_name: str) -> Resource: - return [ - resource for resource in self.resources if resource.resource_name == resource_name - ][0] - - def get_resource_url(self, resource_name: str) -> str: - return f"{self.base_url}{self.get_resource(resource_name).url}" - - -class AuthProviderSettings(BaseModel): - """Auth provider, can also be a resource server""" - +class OIDCProvider(BaseModel): id: str name: str url: str @@ -49,19 +22,6 @@ class AuthProviderSettings(BaseModel): # For PKCE (not implemented yet) code_challenge_method: str | None = None hint: str = "No hint" - resources: list[Resource] = [] - account_url_template: str | None = None - info_url: str | None = ( - None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key) - ) - public_key: str | None = None - public_key_url: str | None = None - signature_alg: str = "RS256" - resource_provider_scopes: list[str] = [] - session_key: str = "sid" - skip_verify_signature: bool = True - disabled: bool = False - resource_providers: list[ResourceProvider] = [] @computed_field @property @@ -73,43 +33,21 @@ class AuthProviderSettings(BaseModel): def token_url(self) -> str: return "auth/" + self.id - def get_account_url(self, request: Request, user: dict) -> str | None: - if self.account_url_template: - if not (self.url.endswith("/") or self.account_url_template.startswith("/")): - sep = "/" - else: - sep = "" - return self.url + sep + self.account_url_template.format(request=request, user=user) - else: - return None - -class AuthSettings(BaseModel): +class OIDCSettings(BaseModel): show_session_details: bool = False - providers: list[AuthProviderSettings] = [] + providers: list[OIDCProvider] = [] swagger_provider: str = "" -class Insecure(BaseModel): - """Warning: changing these defaults are only suitable for debugging""" - - skip_verify_signature: bool = False - - class Settings(BaseSettings): """Settings wil be read from an .env file""" - model_config = SettingsConfigDict(env_nested_delimiter="__") - - auth: AuthSettings = AuthSettings() + oidc: OIDCSettings = OIDCSettings() secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False - log_config_file: str = "log_conf.yaml" - insecure: Insecure = Insecure() - cors_origins: list[str] = [] - debug_token: bool = False - show_token: bool = False - show_external_resource_providers_links: bool = False + + model_config = SettingsConfigDict(env_nested_delimiter="__") @classmethod def settings_customise_sources( @@ -128,7 +66,9 @@ class Settings(BaseSettings): settings_cls, Path( Path( - environ.get("OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml"), + environ.get( + "OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml" + ), ) ), ), diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css deleted file mode 100644 index 1e8dc03..0000000 --- a/src/oidc_test/static/styles.css +++ /dev/null @@ -1,239 +0,0 @@ -body { - font-family: Arial, Helvetica, sans-serif; - background-color: floralwhite; - margin: 0; - font-family: system-ui; - text-align: center; -} -h1 { - background-color: #f7c7867d; - margin: 0 0 0.2em 0; - box-shadow: 0px 0.2em 0.2em #f7c7867d; - text-shadow: 0 0 2px #00000080; - font-weight: 200; -} -p { - margin: 0.2em; -} -hr { - margin: 0.2em; -} -.hidden { - display: none; -} -.version { - position: absolute; - font-size: 75%; - top: 0.3em; - right: 0.3em; -} -.center { - text-align: center; -} -.error { - color: darkred; -} -.content { - width: 100%; - display: flex; - flex-direction: column; - align-items: center; - justify-content: center; -} -.user-info { - padding: 0.5em; - display: flex; - gap: 0.5em; - flex-direction: column; - width: fit-content; - align-items: center; - margin: 5px auto; - box-shadow: 0px 0px 10px lightgreen; - background-color: lightgreen; - border-radius: 8px; -} -.user-info * { - flex: 2 1 auto; - margin: 0; -} -.user-info .picture { - max-width: 3em; - max-height: 3em -} -.user-info a.logout { - border: 2px solid darkkhaki; - padding: 3px 6px; - text-decoration: none; - color: black; -} -.user-info a.logout:hover { - background-color: orange; -} -.debug-auth { - font-size: 90%; - background-color: #d8bebc75; - padding: 6px; -} -.debug-auth * { - margin: 0; -} -.debug-auth p { - border-bottom: 1px solid black; -} -.debug-auth ul { - padding: 0; - list-style: none; -} -.debug-auth p, .debug-auth .key { - font-weight: bold; -} -.content { - text-align: left; -} -.hasResponseStatus { - background-color: #88888840; -} -.hasResponseStatus.status-200 { - background-color: #00ff0040; -} -.hasResponseStatus.status-401 { - background-color: #ff000040; -} -.hasResponseStatus.status-403 { - background-color: #ff990040; -} -.hasResponseStatus.status-404 { - background-color: #ffCC0040; -} -.hasResponseStatus.status-503 { - background-color: #ffA88050; -} - -.role, .scope { - padding: 3px 6px; - margin: 3px; - border-radius: 6px; -} - -.role { - background-color: #44228840; -} - -.scope { - background-color: #8888FF80; -} - - -/* For home */ - -.login-box { - background-color: antiquewhite; - margin: 0.5em auto; - width: fit-content; - box-shadow: 0 0 10px #49759b88; - border-radius: 8px; -} -.login-box .description { - font-style: italic; - font-weight: bold; - background-color: #f7c7867d; - padding: 6px; - margin: 0; - border-radius: 8px 8px 0 0; -} -.providers { - justify-content: center; - padding: 0.8em; -} -.providers .provider { - min-height: 2em; -} -.providers .provider .link { - text-decoration: none; - max-height: 2em; -} -.providers .provider .link { - background-color: #f7c7867d; - border-radius: 8px; - padding: 6px; - text-align: center; - color: black; - font-weight: 400; - cursor: pointer; - border: 0; - width: 100%; -} - -.providers .provider .link.disabled { - color: gray; - cursor: not-allowed; -} - -.providers .provider .hint { - font-size: 80%; - max-width: 13em; -} -.providers .error { - padding: 3px 6px; - font-weight: bold; - flex: 1 1 auto; -} -.content .links-to-check { - display: flex; - justify-content: center; - gap: 0.5em; - flex-flow: wrap; -} -.content .links-to-check button { - color: black; - padding: 5px 10px; - text-decoration: none; - border-radius: 8px; - border: none; - cursor: pointer; -} - -.token { - overflow-wrap: anywhere; - font-family: monospace; -} - -.resourceResult { - padding: 0.5em; - display: flex; - gap: 0.5em; - width: fit-content; - align-items: center; - margin: 5px auto; - box-shadow: 0px 0px 10px #90c3eeA0; - background-color: #90c3eeA0; - border-radius: 8px; -} - -.resources { - display: flex; -} - -.resource { - text-align: center; -} - -.token-info { - margin: 0 1em; -} - -.key { - font-weight: bold; -} - -.token .key, .token .value { - display: inline; -} -.token .value { - padding-left: 1em; -} - -.msg { - text-align: center; - font-weight: bold; -} diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js deleted file mode 100644 index e988dfe..0000000 --- a/src/oidc_test/static/utils.js +++ /dev/null @@ -1,90 +0,0 @@ -async function checkHref(elem, token, authProvider) { - const msg = document.getElementById("msg") - const resourceName = elem.getAttribute("resource-name") - const resourceId = elem.getAttribute("resource-id") - const resourceProviderId = elem.getAttribute("resource-provider-id") ? elem.getAttribute("resource-provider-id") : "" - const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName - const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}` - const resp = await fetch(url, { - method: "GET", - headers: new Headers({ - "Content-type": "application/json", - "Authorization": `Bearer ${token}`, - "auth_provider": authProvider, - }), - }).catch(err => { - msg.innerHTML = "Cannot fetch resource: " + err.message - resourceElem.innerHTML = "" - }) - if (resp === undefined) { - return - } else { - elem.classList.add("hasResponseStatus") - elem.classList.add("status-" + resp.status) - elem.title = "Response code: " + resp.status + " - " + resp.statusText - } -} - -function checkPerms(className, token, authProvider) { - var rootElems = document.getElementsByClassName(className) - Array.from(rootElems).forEach(elem => - Array.from(elem.children).forEach(elem => checkHref(elem, token, authProvider)) - ) -} - -async function get_resource(resourceName, token, authProvider, resourceId, resourceProviderId) { - // BaseUrl for an external resource provider - //if (!keycloak.keycloak) { return } - const msg = document.getElementById("msg") - const resourceElem = document.getElementById('resource') - const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName - const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}` - const resp = await fetch(url, { - method: "GET", - headers: new Headers({ - "Content-type": "application/json", - "Authorization": `Bearer ${token}`, - "auth_provider": authProvider, - }), - }).catch(err => { - msg.innerHTML = "Cannot fetch resource: " + err.message - resourceElem.innerHTML = "" - }) - if (resp === undefined) { - return - } - const resource = await resp.json() - if (!resp.ok) { - msg.innerHTML = resource["detail"] - resourceElem.innerHTML = "" - return - } - msg.innerHTML = "" - resourceElem.innerHTML = "" - Object.entries(resource).forEach( - ([key, value]) => { - let r = document.createElement('div') - let kElem = document.createElement('div') - kElem.innerText = key - kElem.className = "key" - let vElem = document.createElement('div') - if (typeof value == "object") { - Object.entries(value).forEach(v => { - const ne = document.createElement('div') - ne.innerHTML = `${v[0]}: ${v[1]}` - vElem.appendChild(ne) - }) - } - else { - vElem.innerText = value - } - vElem.className = "value" - if (key == "sorry") { - vElem.classList.add("error") - } - r.appendChild(kElem) - r.appendChild(vElem) - resourceElem.appendChild(r) - } - ) -} diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 157e26f..30f6194 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -1,12 +1,139 @@ - OIDC (FastAPI) test - - + FastAPI OIDC test + + - -
v. {{ __version__}}
-

OIDC-test - FastAPI client

+ +

OIDC-test

{% block content %} {% endblock %} diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 167616f..2d858a3 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -1,42 +1,85 @@ {% extends "base.html" %} {% block content %} +

Test the authentication and authorization, with OpenID Connect and OAuth2 with different providers.

{% if not user %} - - {% else %} + + {% endif %} + {% if user %} {% endif %}
- {% if resources %} -

This application provides all these resources, eventually protected with scope or roles:

- - {% endif %} - {% if auth_provider.resources %} -

{{ auth_provider.name }} is also defined as a provider for these resources:

- - {% endif %} - {% if resource_providers %} -

{{ auth_provider.name }} allows this application to request resources from third party resource providers:

- {% for resource_provider in resource_providers %} - - {% endfor %} - {% endif %} - - {% if show_token and id_token_parsed %} -
-
-
-

id token

-
- {% for key, value in id_token_parsed.items() %} -
-
{{ key }}
-
{{ value }}
-
+ {% if user_info_details %} +
+

User info

+
    + {% for key, value in user_info_details.items() %} +
  • + {{ key }}: {{ value }} +
  • {% endfor %} -
-

access token

-
- {% for key, value in access_token_parsed.items() %} -
-
{{ key }}
-
{{ value }}
-
- {% endfor %} -
-

refresh token

-
- {% for key, value in refresh_token_parsed.items() %} -
-
{{ key }}
-
{{ value }}
-
- {% endfor %} -
+
+
Now is: {{ now }}
{% endif %} {% endblock %} diff --git a/src/oidc_test/templates/non_compliant_logout.html b/src/oidc_test/templates/non_compliant_logout.html index 56758de..2f5b247 100644 --- a/src/oidc_test/templates/non_compliant_logout.html +++ b/src/oidc_test/templates/non_compliant_logout.html @@ -6,12 +6,12 @@ authorisation to log in again without asking for credentials.

- This is because {{ auth_provider.name }} does not provide "end_session_endpoint" in its metadata - (see: {{ auth_provider.authlib_client._server_metadata_url }}). + This is because {{ provider.name }} does not provide "end_session_endpoint" in its metadata + (see: {{ provider._server_metadata_url }}).

You can just also go back to the application home page, but - it recommended to go to the OIDC provider's site + it recommended to go to the provider's site and log out explicitely from there.

{% endblock %} diff --git a/uv.lock b/uv.lock index 0566bb5..6ceb4ca 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.13" [[package]] @@ -207,18 +206,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632 }, ] -[[package]] -name = "dunamai" -version = "1.23.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "packaging" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/06/4e/a5c8c337a1d9ac0384298ade02d322741fb5998041a5ea74d1cd2a4a1d47/dunamai-1.23.0.tar.gz", hash = "sha256:a163746de7ea5acb6dacdab3a6ad621ebc612ed1e528aaa8beedb8887fccd2c4", size = 44681 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/21/4c/963169386309fec4f96fd61210ac0a0666887d0fb0a50205395674d20b71/dunamai-1.23.0-py3-none-any.whl", hash = "sha256:a0906d876e92441793c6a423e16a4802752e723e9c9a5aabdc5535df02dbe041", size = 26342 }, -] - [[package]] name = "ecdsa" version = "0.19.0" @@ -495,17 +482,16 @@ wheels = [ [[package]] name = "oidc-fastapi-test" +version = "0.0.0" source = { editable = "." } dependencies = [ { name = "authlib" }, { name = "cachetools" }, { name = "fastapi", extra = ["standard"] }, - { name = "httpx" }, { name = "itsdangerous" }, { name = "passlib", extra = ["bcrypt"] }, { name = "pkce" }, { name = "pydantic-settings" }, - { name = "pyjwt" }, { name = "python-jose", extra = ["cryptography"] }, { name = "requests" }, { name = "sqlmodel" }, @@ -513,7 +499,6 @@ dependencies = [ [package.dev-dependencies] dev = [ - { name = "dunamai" }, { name = "ipdb" }, { name = "pytest" }, ] @@ -523,12 +508,10 @@ requires-dist = [ { name = "authlib", specifier = ">=1.4.0" }, { name = "cachetools", specifier = ">=5.5.0" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" }, - { name = "httpx", specifier = ">=0.28.1" }, { name = "itsdangerous", specifier = ">=2.2.0" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "pkce", specifier = ">=1.0.3" }, { name = "pydantic-settings", specifier = ">=2.7.1" }, - { name = "pyjwt", specifier = ">=2.10.1" }, { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" }, { name = "requests", specifier = ">=2.32.3" }, { name = "sqlmodel", specifier = ">=0.0.22" }, @@ -536,7 +519,6 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ - { name = "dunamai", specifier = ">=1.23.0" }, { name = "ipdb", specifier = ">=0.13.13" }, { name = "pytest", specifier = ">=8.3.4" }, ] @@ -712,15 +694,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513 }, ] -[[package]] -name = "pyjwt" -version = "2.10.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997 }, -] - [[package]] name = "pytest" version = "8.3.4"