diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml
index df52e0b..379aaa8 100644
--- a/.forgejo/workflows/build.yaml
+++ b/.forgejo/workflows/build.yaml
@@ -19,7 +19,7 @@ jobs:
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v4
with:
- version: "0.5.16"
+ version: "0.6.9"
- name: Install
run: uv sync
@@ -27,32 +27,26 @@ jobs:
- name: Run tests (API call)
run: .venv/bin/pytest -s tests/basic.py
- - name: Get version with git describe
- id: version
- run: |
- echo "version=$(git describe)" >> $GITHUB_OUTPUT
- echo "$VERSION"
+ - name: Get version
+ run: echo "VERSION=$(.venv/bin/dunamai from any --style semver)" >> $GITHUB_ENV
- - name: Check if the container should be built
- id: builder
- env:
- RUN: ${{ toJSON(inputs.build || !contains(steps.version.outputs.version, '-')) }}
- run: |
- echo "run=$RUN" >> $GITHUB_OUTPUT
- echo "Run build: $RUN"
+ - name: Version
+ run: echo $VERSION
- - name: Set the version in pyproject.toml (workaround for uv not supporting dynamic version)
- if: fromJSON(steps.builder.outputs.run)
- env:
- VERSION: ${{ steps.version.outputs.version }}
- run: sed "s/0.0.0/$VERSION/" -i pyproject.toml
+ - name: Get distance from tag
+ run: echo "DISTANCE=$(.venv/bin/dunamai from any --format '{distance}')" >> $GITHUB_ENV
+
+ - name: Distance
+ run: echo $DISTANCE
- name: Workaround for bug of podman-login
+ if: env.DISTANCE == '0'
run: |
mkdir -p $HOME/.docker
echo "{ \"auths\": {} }" > $HOME/.docker/config.json
- name: Log in to the container registry (with another workaround)
+ if: env.DISTANCE == '0'
uses: actions/podman-login@v1
with:
registry: ${{ vars.REGISTRY }}
@@ -61,26 +55,31 @@ jobs:
auth_file_path: /tmp/auth.json
- name: Build the container image
+ if: env.DISTANCE == '0'
uses: actions/buildah-build@v1
with:
image: oidc-fastapi-test
oci: true
labels: oidc-fastapi-test
- tags: latest ${{ steps.version.outputs.version }}
+ tags: "latest ${{ env.VERSION }}"
containerfiles: |
./Containerfile
- name: Push the image to the registry
+ if: env.DISTANCE == '0'
uses: actions/push-to-registry@v2
with:
registry: "docker://${{ vars.REGISTRY }}/${{ vars.ORGANISATION }}"
image: oidc-fastapi-test
- tags: latest ${{ steps.version.outputs.version }}
+ tags: "latest ${{ env.VERSION }}"
- name: Build wheel
+ if: env.DISTANCE == '0'
run: uv build --wheel
- name: Publish Python package (home)
+ if: env.DISTANCE == '0'
env:
LOCAL_PYPI_TOKEN: ${{ secrets.LOCAL_PYPI_TOKEN }}
run: uv publish --publish-url https://code.philo.ydns.eu/api/packages/philorg/pypi --token $LOCAL_PYPI_TOKEN
+ continue-on-error: true
diff --git a/.forgejo/workflows/test.yaml b/.forgejo/workflows/test.yaml
index a56a9ce..f4d994e 100644
--- a/.forgejo/workflows/test.yaml
+++ b/.forgejo/workflows/test.yaml
@@ -19,7 +19,7 @@ jobs:
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v4
with:
- version: "0.5.16"
+ version: "0.6.3"
- name: Install
run: uv sync
diff --git a/Containerfile b/Containerfile
index aef57f8..0ec45d1 100644
--- a/Containerfile
+++ b/Containerfile
@@ -1,4 +1,4 @@
-FROM docker.io/library/python:alpine
+FROM docker.io/library/python:latest
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
@@ -9,6 +9,9 @@ WORKDIR /app
RUN uv pip install --system .
+# Add demo plugin
+RUN PIP_EXTRA_INDEX_URL=https://pypi.org/simple/ uv pip install --system --index-url https://code.philo.ydns.eu/api/packages/philorg/pypi/simple/ oidc-fastapi-test-resource-provider-demo
+
# Possible to run with:
#CMD ["oidc-test", "--port", "80"]
#CMD ["fastapi", "run", "src/oidc_test/main.py", "--port", "8873", "--root-path", "/oidc-test"]
diff --git a/README.md b/README.md
index 0dc11be..68f335d 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,12 @@ as a template for integration in other FastAPI/SQLModel applications.
Feedback welcome.
+## Resource server
+
+It also functions as a resource server in a OAuth architecture.
+See a sibling test project, a web based OIDC/OAuth:
+[oidc-vue-test](https://code.philo.ydns.eu/philorg/oidc-vue-test).
+
## RBAC
The application is also a playground for RBAC (Role Based Access control)
@@ -45,36 +51,78 @@ given by the OIDC providers.
For example:
-```text
-oidc:
- secret_key: "ASecretNoOneKnows"
- show_session_details: yes
+```yaml
+secret_key: AVeryWellKeptSecret
+debug_token: no
+show_token: yes
+log: yes
+
+auth:
providers:
- id: auth0
name: Okta / Auth0
- url: "https://"
- client_id: ""
- client_secret: "client_secret_generated_by_auth0"
- hint: "A hint for test credentials"
+ url: https://
+ public_key_url: https:///pem
+ client_id:
+ client_secret: client_secret_generated_by_auth0
+ hint: A hint for test credentials
- id: keycloak
name: Keycloak at somewhere
- url: "https://"
- client_id: ""
- client_secret: "client_secret_generated_by_keycloak"
- hint: "User: foo, password: foofoo"
+ url: https://
+ info_url: https://philo.ydns.eu/auth/realms/test
+ account_url_template: /account
+ client_id:
+ client_secret:
+ hint: A hint for test credentials
+ code_challenge_method: S256
+ resource_provider_scopes:
+ - get:time
+ - get:bs
+ resource_providers:
+ - id:
+ name: A third party resource provider
+ base_url: https://some.example.com/
+ verify_ssl: yes
+ resources:
+ - name: Public RS2
+ resource_name: public
+ url: resource/public
+ - name: BS RS2
+ resource_name: bs
+ url: resource/bs
+ - name: Time RS2
+ resource_name: time
+ url: resource/time
- id: codeberg
+ disabled: no
name: Codeberg
- url: "https://codeberg.org"
- client_id: ""
- client_secret: "client_secret_generated_by_codeberg"
+ url: https://codeberg.org
+ account_url_template: /user/settings
+ client_id:
+ client_secret: client_secret_generated_by_codeberg
+ info_url: https://codeberg.org/login/oauth/keys
+ session_key: sub
+ skip_verify_signature: no
+ resources:
+ - name: List of repos
+ id: repos
+ url: /api/v1/user/repos
+ - name: List of OAuth2 applications
+ id: oauth2_applications
+ url: /api/v1/user/applications/oauth2
+
+cors_origins:
+ - https://some.client
+ - https://localhost:8000
```
The application reads the `OIDC_TEST_SETTINGS_FILE` environment variable
to determine the location of this file at startup.
-For example, to run on port 8000 in a container, with the setting file in the current working directory:
+For example, to run on port 8000 in a container,
+with the setting file in the current working directory:
```sh
podman run -p 8000:80 --env OIDC_TEST_CONFIG_FILE=/app/settings.yaml --mount type=bind,source=settings.yaml,destination=/app/settings.yaml code.philo.ydns.eu/philorg/oidc-fastapi-test:latest
diff --git a/TODO b/TODO
index 5d7e575..93e80a9 100644
--- a/TODO
+++ b/TODO
@@ -1,3 +1,5 @@
https://docs.authlib.org/en/latest/oauth/2/intro.html#intro-oauth2
https://www.keycloak.org/docs/latest/authorization_services/index.html
+
+https://thinhdanggroup.github.io/oauth2-python/
diff --git a/pyproject.toml b/pyproject.toml
index 4509e5b..c44e9f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "oidc-fastapi-test"
-version = "0.0.0"
-# dynamic = ["version"]
+#version = "0.0.0"
+dynamic = ["version"]
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
@@ -9,10 +9,12 @@ dependencies = [
"authlib>=1.4.0",
"cachetools>=5.5.0",
"fastapi[standard]>=0.115.6",
+ "httpx>=0.28.1",
"itsdangerous>=2.2.0",
"passlib[bcrypt]>=1.7.4",
"pkce>=1.0.3",
"pydantic-settings>=2.7.1",
+ "pyjwt>=2.10.1",
"python-jose[cryptography]>=3.3.0",
"requests>=2.32.3",
"sqlmodel>=0.0.22",
@@ -22,14 +24,24 @@ dependencies = [
oidc-test = "oidc_test.main:main"
[dependency-groups]
-dev = ["ipdb>=0.13.13", "pytest>=8.3.4"]
+dev = ["dunamai>=1.23.0", "ipdb>=0.13.13", "pytest>=8.3.4"]
[build-system]
-requires = ["hatchling"]
+requires = ["hatchling", "uv-dynamic-versioning"]
build-backend = "hatchling.build"
+[tool.hatch.version]
+source = "uv-dynamic-versioning"
+
[tool.hatch.build.targets.wheel]
packages = ["src/oidc_test"]
+package = true
+
+[tool.uv-dynamic-versioning]
+style = "semver"
[tool.uv]
package = true
+
+[tool.black]
+line-length = 98
diff --git a/src/oidc_test/__init__.py b/src/oidc_test/__init__.py
index e69de29..f449e2b 100644
--- a/src/oidc_test/__init__.py
+++ b/src/oidc_test/__init__.py
@@ -0,0 +1,11 @@
+import importlib.metadata
+
+try:
+ from dunamai import Version, Style
+
+ __version__ = Version.from_git().serialize(style=Style.SemVer, dirty=True)
+except ImportError:
+ # __name__ could be used if the package name is the same
+ # as the directory - not the case here
+ # __version__ = importlib.metadata.version(__name__)
+ __version__ = importlib.metadata.version("oidc-fastapi-test")
diff --git a/src/oidc_test/auth/provider.py b/src/oidc_test/auth/provider.py
new file mode 100644
index 0000000..ce288a6
--- /dev/null
+++ b/src/oidc_test/auth/provider.py
@@ -0,0 +1,113 @@
+from json import JSONDecodeError
+from typing import Any
+from jwt import decode
+import logging
+
+from pydantic import ConfigDict
+from authlib.integrations.starlette_client.apps import StarletteOAuth2App
+from httpx import AsyncClient
+
+from oidc_test.settings import AuthProviderSettings, ResourceProvider, Resource, settings
+from oidc_test.models import User
+
+logger = logging.getLogger("oidc-test")
+
+
+class NoPublicKey(Exception):
+ pass
+
+
+class Provider(AuthProviderSettings):
+ # To allow authlib_client as StarletteOAuth2App
+ model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore
+
+ authlib_client: StarletteOAuth2App = StarletteOAuth2App(None)
+ info: dict[str, Any] = {}
+ unknown_auth_user: User
+ logout_with_id_token_hint: bool = True
+
+ def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]:
+ """Decode the token with signature check"""
+ if self.public_key is None:
+ raise NoPublicKey
+ if verify_signature is None:
+ verify_signature = self.skip_verify_signature
+ if settings.debug_token:
+ decoded = decode(
+ token,
+ self.public_key,
+ algorithms=[self.signature_alg],
+ audience=["account", "oidc-test", "oidc-test-web"],
+ options={
+ "verify_signature": False,
+ "verify_aud": False,
+ }, # not settings.insecure.skip_verify_signature},
+ )
+ logger.debug(str(decoded))
+ return decode(
+ token,
+ self.public_key,
+ algorithms=[self.signature_alg],
+ audience=["account", "oidc-test", "oidc-test-web"],
+ options={
+ "verify_signature": verify_signature,
+ }, # not settings.insecure.skip_verify_signature},
+ )
+
+ async def get_info(self):
+ # Get the public key:
+ async with AsyncClient() as client:
+ public_key: str | None = None
+ if self.info_url is not None:
+ try:
+ provider_info = await client.get(self.info_url)
+ except Exception as err:
+ logger.debug("Provider_info: cannot connect")
+ logger.exception(err)
+ raise NoPublicKey
+ try:
+ self.info = provider_info.json()
+ except JSONDecodeError:
+ logger.debug("Provider_info: cannot decode json response")
+ raise NoPublicKey
+ if "public_key" in self.info:
+ # For Keycloak
+ try:
+ public_key = str(self.info["public_key"])
+ except KeyError:
+ logger.debug("Provider_info: cannot get public_key")
+ raise NoPublicKey
+ elif "keys" in self.info:
+ # For Forgejo/Gitea
+ try:
+ public_key = str(self.info["keys"][0]["n"])
+ except KeyError:
+ logger.debug("Provider_info: cannot get key 0.n")
+ raise NoPublicKey
+ if self.public_key_url is not None:
+ resp = await client.get(self.public_key_url)
+ public_key = resp.text
+ if public_key is None:
+ logger.debug("Provider_info: cannot determine public key")
+ raise NoPublicKey
+ self.public_key = "\n".join(
+ ["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"]
+ )
+
+ def get_session_key(self, userinfo):
+ return userinfo[self.session_key]
+
+ def get_resource(self, resource_name: str) -> Resource:
+ return [
+ resource for resource in self.resources if resource.resource_name == resource_name
+ ][0]
+
+ def get_resource_url(self, resource_name: str) -> str:
+ return self.url + self.get_resource(resource_name).url
+
+ def get_resource_provider(self, resource_provider_id: str) -> ResourceProvider:
+ return [
+ provider
+ for provider in self.resource_providers
+ if provider.id == resource_provider_id
+ ][0]
diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py
new file mode 100644
index 0000000..c51b039
--- /dev/null
+++ b/src/oidc_test/auth/utils.py
@@ -0,0 +1,305 @@
+from typing import Union, Annotated
+from functools import wraps
+import logging
+
+from fastapi import HTTPException, Request, Depends, status
+from fastapi.security import OAuth2PasswordBearer
+from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
+from authlib.oauth2.rfc6749 import OAuth2Token
+from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError
+
+from oidc_test.auth.provider import Provider
+from oidc_test.models import User
+from oidc_test.database import db, TokenNotInDb, UserNotInDB
+from oidc_test.settings import settings
+from oidc_test.auth_providers import providers
+
+logger = logging.getLogger("oidc-test")
+
+
+async def fetch_token(name, request):
+ assert name is not None
+ assert request is not None
+ logger.warning("TODO: fetch_token")
+ ...
+ # if name in oidc_providers:
+ # model = OAuth2Token
+ # else:
+ # model = OAuthToken
+
+ # token = model.find(name=name, user=request.user)
+ # return token.to_token()
+
+
+async def update_token(
+ provider_id,
+ token,
+ refresh_token: str | None = None,
+ access_token: str | None = None,
+):
+ """Update the token in the database"""
+ provider = providers[provider_id]
+ sid: str = provider.get_session_key(provider.decode(token["id_token"]))
+ item = await db.get_token(provider, sid)
+ # update old token
+ item["access_token"] = token["access_token"]
+ item["refresh_token"] = token["refresh_token"]
+ item["id_token"] = token["id_token"]
+ item["expires_at"] = token["expires_at"]
+ logger.info(f"Token {sid} refreshed")
+ # It's a fake db and only in memory, so there's nothing to save
+ # await item.save()
+
+
+def init_providers():
+ """Add oidc providers to authlib from the settings
+ and build the providers dict"""
+ for provider_settings in settings.auth.providers:
+ provider_settings_dict = provider_settings.model_dump()
+ # Add an anonymous user, that cannot be identified but has provided a valid access token
+ provider_settings_dict["unknown_auth_user"] = User(
+ sub="", auth_provider_id=provider_settings.id
+ )
+ provider = Provider(**provider_settings_dict)
+ if provider.disabled:
+ logger.info(f"{provider_settings.name} is disabled, skipping")
+ else:
+ authlib_oauth.register(
+ name=provider.id,
+ server_metadata_url=provider.openid_configuration,
+ client_kwargs={
+ "scope": " ".join(
+ ["openid", "email", "offline_access", "profile"]
+ + provider.resource_provider_scopes
+ ),
+ },
+ client_id=provider.client_id,
+ client_secret=provider.client_secret,
+ api_base_url=provider.url,
+ # For PKCE (not implemented yet):
+ # code_challenge_method="S256",
+ fetch_token=fetch_token,
+ update_token=update_token,
+ # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
+ )
+ provider.authlib_client = getattr(authlib_oauth, provider.id)
+ providers[provider.id] = provider
+
+
+authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
+
+
+def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
+ """Return the oidc_provider from a request object, from the session.
+ It can be used in Depends()"""
+ if (auth_provider_id := request.session.get("auth_provider_id")) is None:
+ return
+ return getattr(authlib_oauth, str(auth_provider_id), None)
+
+
+def get_auth_provider_client(request: Request) -> StarletteOAuth2App:
+ if (oidc_provider := get_auth_provider_client_or_none(request)) is None:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
+ else:
+ return oidc_provider
+
+
+def get_auth_provider_or_none(request: Request) -> Provider | None:
+ """Return the oidc_provider settings from a request object, from the session.
+ It can be used in Depends()"""
+ if (auth_provider_id := request.session.get("auth_provider_id")) is None:
+ return
+ return providers.get(auth_provider_id)
+
+
+def get_auth_provider(request: Request) -> Provider:
+ if (provider := get_auth_provider_or_none(request)) is None:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
+ return provider
+
+
+async def get_current_user(request: Request) -> User:
+ """Get the current user from a request object.
+ Also validates the token expiration time.
+ ... TODO: complete about refresh token
+ """
+ if (user_sub := request.session.get("user_sub")) is None:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED)
+ token = await get_token_from_session(request)
+ user = await db.get_user(user_sub)
+ ## Check if the token is expired
+ if token.is_expired():
+ provider = get_auth_provider(request=request)
+ ## Ask a new refresh token from the provider
+ logger.info(f"Token expired for user {user.name}")
+ try:
+ userinfo = await provider.authlib_client.fetch_access_token(
+ refresh_token=token.get("refresh_token")
+ )
+ assert userinfo is not None
+ except OAuthError as err:
+ logger.exception(err)
+ # raise HTTPException(
+ # status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh"
+ # )
+
+ return user
+
+
+async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None:
+ """Return the auth token from the session or None.
+ Can be used in Depends()"""
+ try:
+ return await get_token_from_session(request)
+ except HTTPException:
+ return None
+
+
+async def get_token_from_session(request: Request) -> OAuth2Token:
+ """Return the token from the session.
+ Can be used in Depends()"""
+ try:
+ provider = providers[request.session.get("auth_provider_id", "")]
+ except KeyError:
+ request.session.pop("auth_provider_id", None)
+ request.session.pop("user_sub", None)
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider")
+ try:
+ return await db.get_token(
+ provider,
+ request.session.get("sid"),
+ )
+ except (TokenNotInDb, InvalidKeyError, DecodeError) as err:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, err.__class__.__name__)
+
+
+async def get_current_user_or_none(request: Request) -> User | None:
+ """Return the user from a request object, from the session.
+ It can be used in Depends()"""
+ try:
+ return await get_current_user(request)
+ except HTTPException:
+ return None
+
+
+def hasrole(required_roles: Union[str, list[str]] = []):
+ """Decorator for RBAC permissions"""
+ required_roles_set: set[str]
+ if isinstance(required_roles, str):
+ required_roles_set = set([required_roles])
+ else:
+ required_roles_set = set(required_roles)
+
+ def decorator(func):
+ @wraps(func)
+ async def wrapper(request=None, *args, **kwargs):
+ if request is None:
+ raise HTTPException(
+ 500,
+ "Functions decorated with hasrole must have a request:Request argument",
+ )
+ user: User = await get_current_user(request)
+ if not any(required_roles_set.intersection(user.roles_as_set)):
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED)
+ return await func(request, *args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def get_token_info(token: dict) -> dict:
+ token_info = dict()
+ for key in token:
+ if key != "userinfo":
+ token_info[key] = token[key]
+ return token_info
+
+
+async def get_user_from_token(
+ token: Annotated[str, Depends(oauth2_scheme)],
+ request: Request,
+) -> User:
+ try:
+ auth_provider_id = request.headers["auth_provider"]
+ except KeyError:
+ raise HTTPException(
+ status.HTTP_401_UNAUTHORIZED,
+ "Request headers must have a 'auth_provider' field",
+ )
+ try:
+ provider = providers[auth_provider_id]
+ except KeyError:
+ if auth_provider_id == "":
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth provider")
+ else:
+ raise HTTPException(
+ status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
+ )
+ if token == "None":
+ request.session.pop("auth_provider_id", None)
+ request.session.pop("user_sub", None)
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
+ try:
+ payload = provider.decode(token)
+ except ExpiredSignatureError:
+ raise HTTPException(
+ status.HTTP_401_UNAUTHORIZED,
+ "Expired signature (token refresh not implemented yet)",
+ )
+ except InvalidKeyError:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
+ except PyJWTError as err:
+ raise HTTPException(
+ status.HTTP_401_UNAUTHORIZED, f"Cannot decode token: {err.__class__.__name__}"
+ )
+ try:
+ user_id = payload["sub"]
+ except KeyError:
+ return provider.unknown_auth_user
+ try:
+ user = await db.get_user(user_id)
+ if user.access_token != token:
+ user.access_token = token
+ except UserNotInDB:
+ logger.info(
+ f"User {user_id} not found in DB, creating it (real apps can behave differently)"
+ )
+ user = await db.add_user(
+ sub=payload["sub"],
+ user_info=payload,
+ auth_provider=providers[auth_provider_id],
+ access_token=token,
+ )
+ return user
+
+
+async def get_user_from_token_or_none(
+ token: Annotated[str | None, Depends(oauth2_scheme_optional)],
+ request: Request,
+) -> User | None:
+ if token is None:
+ return None
+ try:
+ return await get_user_from_token(token, request)
+ except HTTPException:
+ return None
+
+
+class UserWithRole:
+ roles: set[str]
+
+ def __init__(self, roles: str | list[str] | tuple[str] | set[str]):
+ if isinstance(roles, str):
+ self.roles = set([roles])
+ elif isinstance(roles, (list, tuple, set)):
+ self.roles = set(roles)
+
+ def __call__(self, user: User = Depends(get_user_from_token)) -> User:
+ if not any(self.roles.intersection(user.roles_as_set)):
+ raise HTTPException(
+ status.HTTP_401_UNAUTHORIZED, f"Not of any required role {', '.join(self.roles)}"
+ )
+ return user
diff --git a/src/oidc_test/auth_misc.py b/src/oidc_test/auth_misc.py
deleted file mode 100644
index a4e9ea3..0000000
--- a/src/oidc_test/auth_misc.py
+++ /dev/null
@@ -1,29 +0,0 @@
-from datetime import datetime, timedelta
-from collections import OrderedDict
-
-from .models import User
-
-time_keys = set(("iat", "exp", "auth_time", "updated_at"))
-
-
-def pretty_details(user: User, now: datetime) -> OrderedDict:
- details = OrderedDict()
- # breakpoint()
- for key in sorted(time_keys):
- try:
- dt = datetime.fromtimestamp(user.userinfo[key])
- except (KeyError, TypeError):
- pass
- else:
- td = now - dt
- td = timedelta(days=td.days, seconds=td.seconds)
- if td.days < 0:
- ptd = f"in {-td} h:m:s"
- else:
- ptd = f"{td} h:m:s ago"
- details[key] = f"{user.userinfo[key]} - {dt} ({ptd})"
- for key in sorted(user.userinfo):
- if key in time_keys:
- continue
- details[key] = user.userinfo[key]
- return details
diff --git a/src/oidc_test/auth_providers.py b/src/oidc_test/auth_providers.py
new file mode 100644
index 0000000..1c33ae8
--- /dev/null
+++ b/src/oidc_test/auth_providers.py
@@ -0,0 +1,5 @@
+from collections import OrderedDict
+
+from oidc_test.auth.provider import Provider
+
+providers: OrderedDict[str, Provider] = OrderedDict()
diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py
deleted file mode 100644
index 880c111..0000000
--- a/src/oidc_test/auth_utils.py
+++ /dev/null
@@ -1,136 +0,0 @@
-from typing import Union
-from functools import wraps
-from datetime import datetime
-import logging
-
-from fastapi import HTTPException, Request, status
-from authlib.oauth2.rfc6749 import OAuth2Token
-from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
-
-# from authlib.oauth1.auth import OAuthToken
-# from authlib.oauth2.auth import OAuth2Token
-
-from .models import User
-from .database import db
-from .settings import settings
-
-logger = logging.getLogger(__name__)
-
-OIDC_PROVIDERS = set([provider.id for provider in settings.oidc.providers])
-
-
-def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None:
- """Return the oidc_provider from a request object, from the session.
- It can be used in Depends()"""
- if (oidc_provider_id := request.session.get("oidc_provider_id")) is None:
- return
- try:
- return getattr(authlib_oauth, str(oidc_provider_id))
- except AttributeError:
- return
-
-
-def get_oidc_provider(request: Request) -> StarletteOAuth2App:
- if (oidc_provider := get_oidc_provider_or_none(request)) is None:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
- else:
- return oidc_provider
-
-
-async def get_current_user(request: Request) -> User:
- """Get the current user from a request object.
- Also validates the token expiration time.
- ... TODO: complete about refresh token
- """
- if (user_sub := request.session.get("user_sub")) is None:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED)
- if (token := await db.get_token(request.session["token"])) is None:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown")
- user = await db.get_user(user_sub)
- ## Check if the token is expired
- if token.is_expired():
- oidc_provider = get_oidc_provider(request=request)
- ## Ask a new refresh token from the provider
- logger.info(f"Token expired for user {user.name}")
- try:
- userinfo = await oidc_provider.fetch_access_token(
- refresh_token=token.refresh_token
- )
- except OAuthError as err:
- logger.exception(err)
- # raise HTTPException(
- # status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh"
- # )
-
- return user
-
-
-async def get_token(request: Request) -> OAuth2Token:
- """Return the token from a request object, from the session.
- It can be used in Depends()"""
- if (token := await db.get_token(request.session.get("token"))) is None:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
- return token
-
-
-async def get_current_user_or_none(request: Request) -> User | None:
- """Return the user from a request object, from the session.
- It can be used in Depends()"""
- try:
- return await get_current_user(request)
- except HTTPException:
- return None
-
-
-def hasrole(required_roles: Union[str, list[str]] = []):
- """Decorator for RBAC permissions"""
- required_roles_set: set[str]
- if isinstance(required_roles, str):
- required_roles_set = set([required_roles])
- else:
- required_roles_set = set(required_roles)
-
- def decorator(func):
- @wraps(func)
- async def wrapper(request=None, *args, **kwargs):
- if request is None:
- raise HTTPException(
- 500,
- "Functions decorated with hasrole must have a request:Request argument",
- )
- user: User = await get_current_user(request)
- if not any(required_roles_set.intersection(user.roles_as_set)):
- raise HTTPException(status.HTTP_401_UNAUTHORIZED)
- return await func(request, *args, **kwargs)
-
- return wrapper
-
- return decorator
-
-
-def get_token_info(token: dict) -> dict:
- token_info = dict()
- for key in token:
- if key != "userinfo":
- token_info[key] = token[key]
- return token_info
-
-
-def fetch_token(name, request):
- breakpoint()
- ...
- # if name in OIDC_PROVIDERS:
- # model = OAuth2Token
- # else:
- # model = OAuthToken
-
- # token = model.find(name=name, user=request.user)
- # return token.to_token()
-
-
-def update_token(*args, **kwargs):
- breakpoint()
- ...
-
-
-authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py
index 4b3f529..8d87a48 100644
--- a/src/oidc_test/database.py
+++ b/src/oidc_test/database.py
@@ -2,16 +2,28 @@
import logging
-from authlib.integrations.starlette_client.apps import StarletteOAuth2App
-
-from .models import User, Role
from authlib.oauth2.rfc6749 import OAuth2Token
+from jwt import PyJWTError
-logger = logging.getLogger(__name__)
+from oidc_test.auth.provider import Provider
+
+from oidc_test.models import User, Role
+from oidc_test.auth_providers import providers
+
+logger = logging.getLogger("oidc-test")
+
+
+class UserNotInDB(Exception):
+ pass
+
+
+class TokenNotInDb(Exception):
+ pass
class Database:
users: dict[str, User] = {}
+ # TODO: key of the token table should be provider: sid
tokens: dict[str, OAuth2Token] = {}
# Last sessions for the user (key: users's subject id (sub))
@@ -20,30 +32,65 @@ class Database:
self,
sub: str,
user_info: dict,
- oidc_provider: StarletteOAuth2App,
- user_info_from_endpoint: dict,
+ auth_provider: Provider,
+ access_token: str,
+ access_token_decoded: dict | None = None,
) -> User:
- user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider)
+ if access_token_decoded is None:
+ assert auth_provider.name is not None
+ provider = providers[auth_provider.id]
+ try:
+ access_token_decoded = provider.decode(access_token)
+ except PyJWTError:
+ access_token_decoded = {}
+ user_info["auth_provider_id"] = auth_provider.id
+ user = User(**user_info)
+ user.userinfo = user_info
+ # user.access_token = access_token
+ # user.access_token_decoded = access_token_decoded
+ # Add roles provided in the access token
+ roles = set()
try:
- raw_roles = user_info_from_endpoint["resource_access"][
- oidc_provider.client_id
- ]["roles"]
- except Exception as err:
- logger.debug(f"Cannot read additional roles: {err}")
- raw_roles = []
- for raw_role in raw_roles:
- user.roles.append(Role(name=raw_role))
+ r = access_token_decoded["resource_access"][auth_provider.client_id]["roles"]
+ roles.update(r)
+ except KeyError:
+ pass
+ try:
+ r = access_token_decoded["realm_access"]["roles"]
+ if isinstance(r, str):
+ roles.add(r)
+ else:
+ roles.update(r)
+ except KeyError:
+ pass
+ user.roles = [Role(name=role_name) for role_name in roles]
self.users[sub] = user
return user
async def get_user(self, sub: str) -> User:
+ if sub not in self.users:
+ raise UserNotInDB
return self.users[sub]
- async def add_token(self, token: OAuth2Token, user: User) -> None:
- self.tokens[token["id_token"]] = token
+ async def add_token(self, provider: Provider, token: OAuth2Token) -> None:
+ """Store a token using as key the sid (auth provider's session id)
+ in the id_token"""
+ sid = provider.get_session_key(token["userinfo"])
+ self.tokens[sid] = token
- async def get_token(self, id_token: str) -> OAuth2Token | None:
- return self.tokens.get(id_token)
+ async def get_token(
+ self,
+ provider: Provider,
+ sid: str | None,
+ ) -> OAuth2Token:
+ # TODO: key of the token table should be provider: sid
+ assert isinstance(provider, Provider)
+ if sid is None:
+ raise TokenNotInDb
+ try:
+ return self.tokens[sid]
+ except KeyError:
+ raise TokenNotInDb
db = Database()
diff --git a/src/oidc_test/log_conf.yaml b/src/oidc_test/log_conf.yaml
new file mode 100644
index 0000000..a6bb0b4
--- /dev/null
+++ b/src/oidc_test/log_conf.yaml
@@ -0,0 +1,34 @@
+version: 1
+disable_existing_loggers: False
+formatters:
+ default:
+ "()": uvicorn.logging.DefaultFormatter
+ format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+ access:
+ "()": uvicorn.logging.AccessFormatter
+ format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+handlers:
+ default:
+ formatter: default
+ class: logging.StreamHandler
+ stream: ext://sys.stderr
+ access:
+ formatter: access
+ class: logging.StreamHandler
+ stream: ext://sys.stdout
+loggers:
+ uvicorn.error:
+ level: INFO
+ handlers:
+ - default
+ propagate: no
+ uvicorn.access:
+ level: INFO
+ handlers:
+ - access
+ propagate: no
+ "oidc-test":
+ level: DEBUG
+ handlers:
+ - default
+ propagate: yes
diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py
index ee50025..e882cda 100644
--- a/src/oidc_test/main.py
+++ b/src/oidc_test/main.py
@@ -6,46 +6,86 @@ from typing import Annotated
from pathlib import Path
from datetime import datetime
import logging
+import logging.config
+import importlib.resources
+from yaml import safe_load
from urllib.parse import urlencode
+from contextlib import asynccontextmanager
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.staticfiles import StaticFiles
-from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
+from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
-from fastapi.security import OpenIdConnect
+from fastapi.middleware.cors import CORSMiddleware
+from jwt import PyJWTError
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.base_client import OAuthError
-from authlib.integrations.httpx_client import AsyncOAuth2Client
from authlib.oauth2.rfc6749 import OAuth2Token
-from pkce import generate_code_verifier, generate_pkce_pair
-from .settings import settings, OIDCProvider
-from .models import User
-from .auth_utils import (
- get_oidc_provider,
- get_oidc_provider_or_none,
- hasrole,
+# TODO: PKCE
+# from authlib.integrations.httpx_client import AsyncOAuth2Client
+# from fastapi.security import OpenIdConnect
+# from pkce import generate_code_verifier, generate_pkce_pair
+
+from oidc_test import __version__
+from oidc_test.registry import registry
+from oidc_test.auth.provider import NoPublicKey, Provider
+from oidc_test.auth.utils import (
+ get_auth_provider,
+ get_auth_provider_or_none,
get_current_user_or_none,
- get_current_user,
authlib_oauth,
- get_token,
+ get_token_from_session_or_none,
+ get_token_from_session,
+ update_token,
)
-from .auth_misc import pretty_details
-from .database import db
+from oidc_test.auth.utils import init_providers
+from oidc_test.settings import settings
+from oidc_test.auth_providers import providers
+from oidc_test.models import User
+from oidc_test.database import TokenNotInDb, db
+from oidc_test.resource_server import resource_server
-logger = logging.getLogger("uvicorn.error")
+logger = logging.getLogger("oidc-test")
+
+if settings.log:
+ assert __package__ is not None
+ with (
+ importlib.resources.path(__package__) as package_path,
+ open(package_path / settings.log_config_file) as f,
+ ):
+ logging_config = safe_load(f)
+ logging.config.dictConfig(logging_config)
templates = Jinja2Templates(Path(__file__).parent / "templates")
-app = FastAPI(
- title="OIDC auth test",
-)
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ assert app is not None
+ init_providers()
+ registry.make_registry()
+ for provider in list(providers.values()):
+ if provider.disabled:
+ continue
+ try:
+ await provider.get_info()
+ except NoPublicKey:
+ logger.warning(f"Disable {provider.id}: public key not found")
+ del providers[provider.id]
+ yield
-app.mount(
- "/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static"
+
+app = FastAPI(title="OIDC auth test", lifespan=lifespan)
+
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=settings.cors_origins,
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
)
# SessionMiddleware is required by authlib
@@ -54,321 +94,236 @@ app.add_middleware(
secret_key=settings.secret_key,
)
-# Add oidc providers to authlib from the settings
-
-# fastapi_providers: dict[str, OpenIdConnect] = {}
-providers_settings: dict[str, OIDCProvider] = {}
-
-for provider in settings.oidc.providers:
- authlib_oauth.register(
- name=provider.id,
- server_metadata_url=provider.openid_configuration,
- client_kwargs={
- "scope": "openid email", # offline_access profile",
- },
- client_id=provider.client_id,
- client_secret=provider.client_secret,
- api_base_url=provider.url,
- # For PKCE (not implemented yet):
- # code_challenge_method="S256",
- # fetch_token=fetch_token,
- # update_token=update_token,
- # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
- )
- # fastapi_providers[provider.id] = OpenIdConnect(
- # openIdConnectUrl=provider.openid_configuration
- # )
- providers_settings[provider.id] = provider
+app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
+app.mount("/resource", resource_server, name="resource_server")
@app.get("/")
async def home(
request: Request,
- user: Annotated[User, Depends(get_current_user_or_none)],
- oidc_provider: Annotated[
- StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)
- ],
+ user: Annotated[User | None, Depends(get_current_user_or_none)],
+ provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)],
+ token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)],
) -> HTMLResponse:
- now = datetime.now()
- if oidc_provider and (
- (provider := providers_settings.get(oidc_provider.name)) is not None
- ):
- resources = provider.resources
+ context = {
+ "show_token": settings.show_token,
+ "user": user,
+ "now": datetime.now(),
+ "__version__": __version__,
+ }
+ if provider is None or token is None:
+ context["providers"] = providers
+ context["access_token"] = None
+ context["id_token_parsed"] = None
+ context["access_token_parsed"] = None
+ context["refresh_token_parsed"] = None
+ context["resources"] = None
+ context["auth_provider"] = None
else:
- resources = []
- return templates.TemplateResponse(
- name="home.html",
- request=request,
- context={
- "settings": settings.model_dump(),
- "user": user,
- "now": now,
- "resources": resources,
- "user_info_details": (
- pretty_details(user, now)
- if user and settings.oidc.show_session_details
- else None
- ),
- },
- )
+ context["auth_provider"] = provider
+ context["access_token"] = token["access_token"]
+ try:
+ access_token_parsed = provider.decode(token["access_token"], verify_signature=False)
+ context["access_token_parsed"] = access_token_parsed
+ context["access_token_scope"] = access_token_parsed.get("scope")
+ except PyJWTError as err:
+ context["access_token_parsed"] = {"Cannot parse": err.__class__.__name__}
+ context["access_token_scope"] = None
+ try:
+ id_token_parsed = provider.decode(token["id_token"], verify_signature=False)
+ context["id_token_parsed"] = id_token_parsed
+ except PyJWTError as err:
+ context["id_token_parsed"] = {"Cannot parse": err.__class__.__name__}
+ try:
+ refresh_token_parsed = provider.decode(token["refresh_token"], verify_signature=False)
+ context["refresh_token_parsed"] = refresh_token_parsed
+ except PyJWTError as err:
+ context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__}
+ context["resources"] = registry.resources
+ context["resource_providers"] = provider.resource_providers
+ return templates.TemplateResponse(name="home.html", request=request, context=context)
# Endpoints for the login / authorization process
-@app.get("/login/{oidc_provider_id}")
-async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
+@app.get("/login/{auth_provider_id}")
+async def login(request: Request, auth_provider_id: str) -> RedirectResponse:
"""Login with the provider id, giving the browser a redirect to its authorize page.
- The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url
+ The provider is expected to send the browser back to our own /auth/{auth_provider_id} url
with the token.
"""
- redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id)
+ redirect_uri = request.url_for("auth", auth_provider_id=auth_provider_id)
try:
- provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
+ provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
except AttributeError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
- if (
- code_challenge_method := providers_settings[
- oidc_provider_id
- ].code_challenge_method
- ) is not None:
- client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method)
- code_verifier = generate_code_verifier()
- logger.debug("TODO: PKCE")
- else:
- code_verifier = None
+ # if (
+ # code_challenge_method := providers[
+ # auth_provider_id
+ # ].code_challenge_method
+ # ) is not None:
+ # #client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method)
+ # code_verifier = generate_code_verifier()
+ # logger.debug("TODO: PKCE")
+ # else:
+ # code_verifier = None
try:
response = await provider.authorize_redirect(
request,
redirect_uri,
access_type="offline",
- code_verifier=code_verifier,
+ code_verifier=None,
)
return response
except HTTPError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
-@app.get("/auth/{oidc_provider_id}")
-async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
+@app.get("/auth/{auth_provider_id}")
+async def auth(
+ request: Request,
+ auth_provider_id: str,
+) -> RedirectResponse:
"""Decrypt the auth token, store it to the session (cookie based)
and response to the browser with a redirect to a "welcome user" page.
"""
try:
- oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
- except AttributeError:
+ provider = providers[auth_provider_id]
+ except KeyError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try:
- token: OAuth2Token = await oidc_provider.authorize_access_token(request)
+ token: OAuth2Token = await provider.authlib_client.authorize_access_token(request)
except OAuthError as error:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
- # Remember the oidc_provider in the session
- # logger.debug(f"Scope: {token['scope']}")
- request.session["oidc_provider_id"] = oidc_provider_id
+ # Remember the authlib_client in the session
+ # logger.info(f"Scope: {token['scope']}")
+ request.session["auth_provider_id"] = auth_provider_id
#
# One could process the full decoded token which contains extra information
# eg for updates. Here we are only interested in roles
#
if userinfo := token.get("userinfo"):
- # Remember the oidc_provider in the session
- request.session["oidc_provider_id"] = oidc_provider_id
- # User id (sub) given by oidc provider
+ # Remember the authlib_client in the session
+ request.session["auth_provider_id"] = auth_provider_id
+ # User id (sub) given by auth provider
sub = userinfo["sub"]
- # Get additional data from userinfo endpoint
- try:
- user_info_from_endpoint = await oidc_provider.userinfo(
- token=token, follow_redirects=True
- )
- except Exception as err:
- logger.warn(f"Cannot get userinfo from endpoint: {err}")
- user_info_from_endpoint = {}
+ ## Get additional data from userinfo endpoint
+ # try:
+ # user_info_from_endpoint = await authlib_client.userinfo(
+ # token=token, follow_redirects=True
+ # )
+ # except Exception as err:
+ # logger.warn(f"Cannot get userinfo from endpoint: {err}")
+ # user_info_from_endpoint = {}
# Build and remember the user in the session
request.session["user_sub"] = sub
- # Store the user in the database
- user = await db.add_user(
- sub,
- user_info=userinfo,
- oidc_provider=oidc_provider,
- user_info_from_endpoint=user_info_from_endpoint,
- )
- # Add the id_token to the session
- request.session["token"] = token["id_token"]
+ # Store the user in the database, which also verifies the token validity and signature
+ try:
+ user = await db.add_user(
+ sub,
+ user_info=userinfo,
+ auth_provider=providers[auth_provider_id],
+ access_token=token["access_token"],
+ )
+ except PyJWTError as err:
+ raise HTTPException(
+ status.HTTP_401_UNAUTHORIZED,
+ detail=f"Token invalid: {err.__class__.__name__}",
+ )
+ assert isinstance(user, User)
+ # Add the provider session id to the session
+ request.session["sid"] = provider.get_session_key(userinfo)
# Add the token to the db because it is used for logout
- await db.add_token(token, user)
+ await db.add_token(provider, token)
# Send the user to the home: (s)he is authenticated
return RedirectResponse(url=request.url_for("home"))
else:
# Not sure if it's correct to redirect to plain login
# if no userinfo is provided
- return RedirectResponse(
- url=request.url_for("login", oidc_provider_id=oidc_provider_id)
- )
+ return RedirectResponse(url=request.url_for("login", auth_provider_id=auth_provider_id))
+
+
+@app.get("/account")
+async def account(
+ provider: Annotated[Provider, Depends(get_auth_provider)],
+) -> RedirectResponse:
+ """Redirect to the auth provider account management,
+ if account_url_template is in the provider's settings"""
+ return RedirectResponse(f"{provider.account_url_template}")
@app.get("/logout")
async def logout(
request: Request,
- oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
+ provider: Annotated[Provider, Depends(get_auth_provider)],
) -> RedirectResponse:
- # Clear session
- request.session.pop("user_sub", None)
# Get provider's endpoint
if (
- provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")
+ provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint")
) is None:
- logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
+ logger.warning(f"Cannot find end_session_endpoint for provider {provider.id}")
return RedirectResponse(request.url_for("non_compliant_logout"))
post_logout_uri = request.url_for("home")
- if (token := await db.get_token(request.session.pop("token", None))) is None:
- logger.warn("No session in db for the token")
+ # Clear session
+ request.session.pop("user_sub", None)
+ request.session.pop("auth_provider_id", None)
+ try:
+ token = await db.get_token(provider, request.session.pop("sid", None))
+ except TokenNotInDb:
+ logger.warning("No session in db for the token or no token")
return RedirectResponse(request.url_for("home"))
- logout_url = (
- provider_logout_uri
- + "?"
- + urlencode(
- {
- "post_logout_redirect_uri": post_logout_uri,
- "id_token_hint": token["id_token"],
- "cliend_id": "oidc_local_test",
- }
- )
- )
+ url_query = {
+ "post_logout_redirect_uri": post_logout_uri,
+ "client_id": provider.client_id,
+ }
+ if provider.logout_with_id_token_hint:
+ url_query["id_token_hint"] = token["id_token"]
+ logout_url = f"{provider_logout_uri}?{urlencode(url_query)}"
return RedirectResponse(logout_url)
@app.get("/non-compliant-logout")
async def non_compliant_logout(
request: Request,
- oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
+ provider: Annotated[Provider, Depends(get_auth_provider)],
):
"""A page for non-compliant OAuth2 servers that we cannot log out."""
+ # Clear session
+ request.session.pop("user_sub", None)
+ request.session.pop("auth_provider_id", None)
return templates.TemplateResponse(
name="non_compliant_logout.html",
request=request,
- context={"oidc_provider": oidc_provider, "home_url": request.url_for("home")},
+ context={"auth_provider": provider, "home_url": request.url_for("home")},
)
-# Route for OAuth resource server
-
-
-@app.get("/resource/{id}")
-async def get_resource(
- id: str,
+@app.get("/refresh")
+async def refresh(
request: Request,
- user: Annotated[User, Depends(get_current_user)],
- oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
- token: Annotated[OAuth2Token, Depends(get_token)],
-) -> JSONResponse:
- """Generic path for testing a resource provided by a provider"""
- if oidc_provider is None:
- raise HTTPException(
- status.HTTP_406_NOT_ACCEPTABLE, detail="No such oidc provider"
- )
- if (provider := providers_settings.get(oidc_provider.name)) is None:
- raise HTTPException(
- status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting"
- )
+ provider: Annotated[Provider, Depends(get_auth_provider)],
+ token: Annotated[OAuth2Token, Depends(get_token_from_session)],
+) -> RedirectResponse:
+ """Manually refresh token"""
+ new_token = await provider.authlib_client.fetch_access_token(
+ refresh_token=token["refresh_token"],
+ grant_type="refresh_token",
+ )
try:
- resource = next(x for x in provider.resources if x.id == id)
- except StopIteration:
+ await update_token(provider.id, new_token)
+ except PyJWTError as err:
+ logger.info(f"Cannot refresh token: {err.__class__.__name__}")
raise HTTPException(
- status.HTTP_406_NOT_ACCEPTABLE, detail="No such resource for this provider"
+ status.HTTP_510_NOT_EXTENDED, f"Token refresh error: {err.__class__.__name__}"
)
- if (
- response := await oidc_provider.get(
- resource.url,
- # headers={"Authorization": f"token {token['access_token']}"},
- token=token,
- )
- ).is_success:
- return JSONResponse(response.json())
- else:
- raise HTTPException(status_code=response.status_code, detail=response.text)
-
-
-# Routes for test
-
-
-@app.get("/public")
-async def public() -> HTMLResponse:
- return HTMLResponse("Not protected ")
-
-
-@app.get("/protected")
-async def get_protected(
- user: Annotated[User, Depends(get_current_user)]
-) -> HTMLResponse:
- return HTMLResponse("Only authenticated users can see this ")
-
-
-@app.get("/protected-by-foorole")
-@hasrole("foorole")
-async def get_protected_by_foorole(request: Request) -> HTMLResponse:
- return HTMLResponse("Only users with foorole can see this ")
-
-
-@app.get("/protected-by-barrole")
-@hasrole("barrole")
-async def get_protected_by_barrole(request: Request) -> HTMLResponse:
- return HTMLResponse("Protected by barrole ")
-
-
-@app.get("/protected-by-foorole-and-barrole")
-@hasrole("barrole")
-@hasrole("foorole")
-async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
- return HTMLResponse("Only users with foorole and barrole can see this ")
-
-
-@app.get("/protected-by-foorole-or-barrole")
-@hasrole(["foorole", "barrole"])
-async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
- return HTMLResponse("Only users with foorole or barrole can see this ")
-
-
-@app.get("/introspect")
-async def get_introspect(
- request: Request,
- oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
- token: Annotated[OAuth2Token, Depends(get_token)],
-) -> JSONResponse:
- if (
- response := await oidc_provider.post(
- oidc_provider.server_metadata["introspection_endpoint"],
- token=token,
- data={"token": token["access_token"]},
- )
- ).is_success:
- return response.json()
- else:
- raise HTTPException(status_code=response.status_code, detail=response.text)
-
-
-@app.get("/oauth2-forgejo-test")
-async def get_forgejo_user_info(
- request: Request,
- user: Annotated[User, Depends(get_current_user)],
- oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
- token: Annotated[OAuth2Token, Depends(get_token)],
-) -> HTMLResponse:
- if (
- response := await oidc_provider.get(
- "/api/v1/user/repos",
- # headers={"Authorization": f"token {token['access_token']}"},
- token=token,
- )
- ).is_success:
- repos = response.json()
- names = [repo["name"] for repo in repos]
- return HTMLResponse(f"{user.name} has {len(repos)} repos: {', '.join(names)}")
- else:
- raise HTTPException(status_code=response.status_code, detail=response.text)
+ return RedirectResponse(url=request.url_for("home"))
# Snippet for running standalone
# Mostly useful for the --version option,
-# as running with uvicorn is easy and provides flaxibility
+# as running with uvicorn is easy and provides better flexibility, eg.
+# uvicorn --host foo oidc_test.main:app --reload
def main():
@@ -386,9 +341,7 @@ def main():
parser.add_argument(
"-p", "--port", type=int, default=80, help="Port to listen to (default: 80)"
)
- parser.add_argument(
- "-v", "--version", action="store_true", help="Print version and exit"
- )
+ parser.add_argument("-v", "--version", action="store_true", help="Print version and exit")
args = parser.parse_args()
if args.version:
diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py
index 3d484aa..7b6fd0e 100644
--- a/src/oidc_test/models.py
+++ b/src/oidc_test/models.py
@@ -1,55 +1,66 @@
+import logging
from functools import cached_property
-from typing import Self
+from typing import Any
from pydantic import (
computed_field,
AnyHttpUrl,
EmailStr,
ConfigDict,
- GetCoreSchemaHandler,
)
-from pydantic_core import CoreSchema, core_schema
-from authlib.integrations.starlette_client.apps import StarletteOAuth2App
-from authlib.oauth2.rfc6749 import OAuth2Token as OAuth2Token_authlib
from sqlmodel import SQLModel, Field
+logger = logging.getLogger("oidc-test")
+
class Role(SQLModel, extra="ignore"):
name: str
class UserBase(SQLModel, extra="ignore"):
-
id: str | None = None
sid: str | None = None
- name: str
+ name: str | None = None
email: EmailStr | None = None
picture: AnyHttpUrl | None = None
roles: list[Role] = []
class User(UserBase):
- model_config = ConfigDict(arbitrary_types_allowed=True)
+ model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore
sub: str = Field(
description="""subject id of the user given by the oidc provider,
also the key for the database 'table'""",
)
userinfo: dict = {}
- oidc_provider: StarletteOAuth2App | None = None
-
- @classmethod
- def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self:
- user = cls(**userinfo)
- user.userinfo = userinfo
- user.oidc_provider = oidc_provider
- # Add roles if they are provided in the token
- if raw_ra := userinfo.get("realm_access"):
- if raw_roles := raw_ra.get("roles"):
- user.roles = [Role(name=raw_role) for raw_role in raw_roles]
- return user
+ access_token: str | None = None
+ access_token_decoded: dict[str, Any] | None = None
+ auth_provider_id: str
@computed_field
@cached_property
def roles_as_set(self) -> set[str]:
return set([role.name for role in self.roles])
+
+ def has_scope(self, scope: str) -> bool:
+ """Check if the scope is present in user info or access token"""
+ info_scopes = self.userinfo.get("scope", "").split(" ")
+ try:
+ access_token_scopes = self.decode_access_token().get("scope", "").split(" ")
+ except Exception as err:
+ logger.debug(f"Cannot find scope because the access token cannot be decoded: {err}")
+ access_token_scopes = []
+ return scope in set(info_scopes + access_token_scopes)
+
+ def decode_access_token(self, verify_signature: bool = True):
+ assert self.access_token is not None, "no access_token"
+ assert self.auth_provider_id is not None, "no auth_provider_id"
+ from .auth_providers import providers
+
+ return providers[self.auth_provider_id].decode(
+ self.access_token, verify_signature=verify_signature
+ )
+
+ def get_scope(self, verify_signature: bool = True):
+ return self.decode_access_token(verify_signature=verify_signature)["scope"]
diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py
new file mode 100644
index 0000000..3b91ad4
--- /dev/null
+++ b/src/oidc_test/registry.py
@@ -0,0 +1,47 @@
+from importlib.metadata import entry_points
+import logging
+
+from pydantic import BaseModel, ConfigDict
+
+from oidc_test.models import User
+
+logger = logging.getLogger("registry")
+
+
+class ProcessResult(BaseModel):
+ model_config = ConfigDict(
+ extra="allow",
+ )
+
+
+class ProcessError(Exception):
+ pass
+
+
+class Resource(BaseModel):
+ name: str
+ scope_required: str | None = None
+ role_required: str | None = None
+ is_public: bool = False
+ default_resource_id: str | None = None
+
+ def __init__(self, name: str):
+ super().__init__()
+ self.__id__ = name
+
+ async def process(self, user: User | None, resource_id: str | None = None) -> ProcessResult:
+ logger.warning(f"{self.__id__} should define a process method")
+ return ProcessResult()
+
+
+class ResourceRegistry(BaseModel):
+ resources: dict[str, Resource] = {}
+
+ def make_registry(self):
+ for ep in entry_points().select(group="oidc_test.resource_provider"):
+ ResourceClass = ep.load()
+ if issubclass(ResourceClass, Resource):
+ self.resources[ep.name] = ResourceClass(ep.name)
+
+
+registry = ResourceRegistry()
diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py
new file mode 100644
index 0000000..ddc5762
--- /dev/null
+++ b/src/oidc_test/resource_server.py
@@ -0,0 +1,343 @@
+from typing import Annotated, Any
+import logging
+from json import JSONDecodeError
+
+from authlib.oauth2.rfc6749 import OAuth2Token
+from httpx import AsyncClient, HTTPError
+from jwt.exceptions import DecodeError, ExpiredSignatureError, InvalidTokenError
+from fastapi import FastAPI, HTTPException, Depends, Request, status
+from fastapi.middleware.cors import CORSMiddleware
+
+# from starlette.middleware.sessions import SessionMiddleware
+# from authlib.integrations.starlette_client.apps import StarletteOAuth2App
+# from authlib.oauth2.rfc6749 import OAuth2Token
+
+from oidc_test.auth.provider import Provider
+from oidc_test.auth.utils import (
+ get_user_from_token_or_none,
+ oauth2_scheme_optional,
+)
+from oidc_test.auth_providers import providers
+from oidc_test.settings import ResourceProvider, settings
+from oidc_test.models import User
+from oidc_test.registry import ProcessError, ProcessResult, registry
+
+logger = logging.getLogger("oidc-test")
+
+resource_server = FastAPI()
+
+
+resource_server.add_middleware(
+ CORSMiddleware,
+ allow_origins=settings.cors_origins,
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+)
+
+# SessionMiddleware is required by authlib
+# resource_server.add_middleware(
+# SessionMiddleware,
+# secret_key=settings.secret_key,
+# )
+
+# Route for OAuth resource server
+
+
+# Routes for RBAC based tests
+
+
+@resource_server.get("/")
+async def resources() -> dict[str, dict[str, Any]]:
+ return {"internal": {}, "plugins": registry.resources}
+
+
+@resource_server.get("/{resource_name}")
+@resource_server.get("/{resource_name}/{resource_id}")
+async def get_resource(
+ resource_name: str,
+ user: Annotated[User | None, Depends(get_user_from_token_or_none)],
+ token: Annotated[OAuth2Token | None, Depends(oauth2_scheme_optional)],
+ resource_id: str | None = None,
+):
+ """Generic path for testing a resource provided by a provider.
+ There's no field validation (response type of ProcessResult) on purpose,
+ leaving the responsibility of the response validation to resource providers"""
+ # Get the resource if it's defined in user auth provider's resources (external)
+ if user is not None:
+ provider = providers[user.auth_provider_id]
+ if ":" in resource_name:
+ # Third-party resource provider: send the request with the request token
+ resource_provider_id, resource_name = resource_name.split(":", 1)
+ provider = providers[user.auth_provider_id]
+ resource_provider: ResourceProvider = provider.get_resource_provider(
+ resource_provider_id
+ )
+ resource_url = resource_provider.get_resource_url(resource_name)
+ async with AsyncClient(verify=resource_provider.verify_ssl) as client:
+ try:
+ logger.debug(f"GET request to {resource_url}")
+ resp = await client.get(
+ resource_url,
+ headers={
+ "Content-type": "application/json",
+ "Authorization": f"Bearer {token}",
+ "auth_provider": user.auth_provider_id,
+ },
+ )
+ except HTTPError as err:
+ raise HTTPException(
+ status.HTTP_503_SERVICE_UNAVAILABLE, err.__class__.__name__
+ )
+ except Exception as err:
+ raise HTTPException(
+ status.HTTP_500_INTERNAL_SERVER_ERROR, err.__class__.__name__
+ )
+ else:
+ if resp.is_success:
+ return resp.json()
+ else:
+ reason_str: str
+ try:
+ reason_str = resp.json().get("detail", str(resp))
+ except Exception:
+ reason_str = str(resp.text)
+ raise HTTPException(resp.status_code, reason_str)
+ # Third party resource (provided through the auth provider)
+ # The token is just passed on
+ # XXX: is this branch valid anymore?
+ if resource_name in [r.resource_name for r in provider.resources]:
+ return await get_auth_provider_resource(
+ provider=provider,
+ resource_name=resource_name,
+ token=token,
+ user=user,
+ )
+ # Internal resource (provided here)
+ if resource_name in registry.resources:
+ resource = registry.resources[resource_name]
+ reason: dict[str, str] = {}
+ if not resource.is_public:
+ if user is None:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public")
+ else:
+ if resource.scope_required is not None and not user.has_scope(
+ resource.scope_required
+ ):
+ reason["scope"] = (
+ f"No scope {resource.scope_required} in the access token "
+ "but it is required for accessing this resource"
+ )
+ if (
+ resource.role_required is not None
+ and resource.role_required not in user.roles_as_set
+ ):
+ reason["role"] = (
+ f"You don't have the role {resource.role_required} "
+ "but it is required for accessing this resource"
+ )
+ if len(reason) == 0:
+ try:
+ resp = await resource.process(user=user, resource_id=resource_id)
+ return resp
+ except ProcessError as err:
+ raise HTTPException(
+ status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}"
+ )
+ else:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, ", ".join(reason.values()))
+ else:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Unknown resource")
+ # return await get_resource_(resource_name, user, **request.query_params)
+
+
+async def get_auth_provider_resource(
+ provider: Provider, resource_name: str, token: OAuth2Token | None, user: User
+) -> ProcessResult:
+ if token is None:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth token")
+ access_token = token
+ async with AsyncClient() as client:
+ resp = await client.get(
+ url=provider.get_resource_url(resource_name),
+ headers={
+ "Content-type": "application/json",
+ "Authorization": f"Bearer {access_token}",
+ },
+ )
+ if resp.is_error:
+ raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}")
+ # Only a demo, real application would really process the response
+ resp_length = len(resp.text)
+ if resp_length > 1024:
+ return ProcessResult(
+ msg=f"The resource is too long ({resp_length} bytes) to show in this demo, here is just the begining in raw format",
+ start=resp.text[:100] + "...",
+ )
+ else:
+ try:
+ resp_json = resp.json()
+ except JSONDecodeError:
+ return ProcessResult(msg="The resource is not formatted in JSON", text=resp.text)
+ if isinstance(resp_json, dict):
+ return ProcessResult(**resp.json())
+ elif isinstance(resp_json, list):
+ return ProcessResult(**{str(i): line for i, line in enumerate(resp_json)})
+
+
+# @resource_server.get("/public")
+# async def public() -> dict:
+# return {"msg": "Not protected"}
+#
+#
+# @resource_server.get("/protected")
+# async def get_protected(user: Annotated[User, Depends(get_user_from_token)]):
+# assert user is not None # Just to keep QA checks happy
+# return {"msg": "Only authenticated users can see this"}
+#
+#
+# @resource_server.get("/protected-by-foorole")
+# async def get_protected_by_foorole(
+# user: Annotated[User, Depends(UserWithRole("foorole"))],
+# ):
+# assert user is not None
+# return {"msg": "Only users with foorole can see this"}
+#
+#
+# @resource_server.get("/protected-by-barrole")
+# async def get_protected_by_barrole(
+# user: Annotated[User, Depends(UserWithRole("barrole"))],
+# ):
+# assert user is not None
+# return {"msg": "Protected by barrole"}
+#
+#
+# @resource_server.get("/protected-by-foorole-and-barrole")
+# async def get_protected_by_foorole_and_barrole(
+# user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))],
+# ):
+# assert user is not None # Just to keep QA checks happy
+# return {"msg": "Only users with foorole and barrole can see this"}
+#
+#
+# @resource_server.get("/protected-by-foorole-or-barrole")
+# async def get_protected_by_foorole_or_barrole(
+# user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))],
+# ):
+# assert user is not None # Just to keep QA checks happy
+# return {"msg": "Only users with foorole or barrole can see this"}
+
+# async def get_resource_(resource_id: str, user: User, **kwargs) -> dict:
+# """
+# Resource processing: build an informative rely as a simple showcase
+# """
+# if resource_id == "petition":
+# return await sign(user, kwargs["petition_id"])
+# provider = providers[user.auth_provider_id]
+# try:
+# pname = provider.name
+# except KeyError:
+# pname = "?"
+# resp = {
+# "hello": f"Hi {user.name} from an OAuth resource provider",
+# "comment": f"I received a request for '{resource_id}' "
+# + f"with an access token signed by {pname}",
+# }
+# # For the demo, resource resource_id matches a scope get:resource_id,
+# # but this has to be refined for production
+# required_scope = f"get:{resource_id}"
+# # Check if the required scope is in the scopes allowed in userinfo
+# try:
+# if user.has_scope(required_scope):
+# await process(user, resource_id, resp)
+# else:
+# ## For the showcase, giving a explanation.
+# ## Alternatively, raise HTTP_401_UNAUTHORIZED
+# raise HTTPException(
+# status.HTTP_401_UNAUTHORIZED,
+# f"No scope {required_scope} in the access token "
+# + "but it is required for accessing this resource",
+# )
+# except ExpiredSignatureError:
+# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired")
+# except InvalidTokenError:
+# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid")
+# return resp
+
+
+# async def process(user, resource_id, resp):
+# """
+# Too simple to be serious.
+# It's a good fit for a plugin architecture for production
+# """
+# if resource_id == "time":
+# resp["time"] = datetime.now().strftime("%c")
+# elif resource_id == "bs":
+# async with AsyncClient() as client:
+# bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
+# resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
+# else:
+# raise HTTPException(
+# status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'."
+# )
+
+
+# @resource_server.get("/introspect")
+# async def get_introspect(
+# request: Request,
+# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
+# token: Annotated[OAuth2Token, Depends(get_token)],
+# ) -> JSONResponse:
+# assert request is not None # Just to keep QA checks happy
+# if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None:
+# raise HTTPException(
+# status_code=status.HTTP_401_UNAUTHORIZED,
+# detail="No introspection endpoint found for the OIDC provider",
+# )
+# if (
+# response := await oidc_provider.post(
+# url,
+# token=token,
+# data={"token": token["access_token"]},
+# )
+# ).is_success:
+# return response.json()
+# else:
+# raise HTTPException(status_code=response.status_code, detail=response.text)
+
+# assert user.oidc_provider is not None
+### Get some info (TODO: refactor)
+# if (auth_provider_id := user.oidc_provider.name) is None:
+# raise HTTPException(
+# status.HTTP_401_UNAUTHORIZED,
+# "Request headers must have a 'auth_provider' field",
+# )
+# if (
+# auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
+# ) is None:
+# raise HTTPException(
+# status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
+# )
+# if (key := auth_provider_settings.get_public_key()) is None:
+# raise HTTPException(
+# status.HTTP_401_UNAUTHORIZED,
+# f"Key for provider '{auth_provider_id}' unknown",
+# )
+# logger.warn(f"refresh with scope {scope}")
+# breakpoint()
+# refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope)
+### Decode the new token
+# try:
+# payload = decode(
+# refreshed_auth_info["access_token"],
+# key=key,
+# algorithms=["RS256"],
+# audience="account",
+# options={"verify_signature": not settings.insecure.skip_verify_signature},
+# )
+# except ExpiredSignatureError as err:
+# logger.info(f"Expired signature: {err}")
+# raise HTTPException(
+# status.HTTP_401_UNAUTHORIZED,
+# "Expired signature (refresh not implemented yet)",
+# )
diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py
index 3a9447c..ad80c06 100644
--- a/src/oidc_test/settings.py
+++ b/src/oidc_test/settings.py
@@ -4,25 +4,42 @@ import random
from typing import Type, Tuple
from pathlib import Path
-from pydantic import BaseModel, computed_field
+from pydantic import AnyHttpUrl, BaseModel, computed_field, AnyUrl
from pydantic_settings import (
BaseSettings,
SettingsConfigDict,
PydanticBaseSettingsSource,
YamlConfigSettingsSource,
)
+from starlette.requests import Request
class Resource(BaseModel):
"""A resource with an URL that can be accessed with an OAuth2 access token"""
- id: str
+ resource_name: str
name: str
url: str
-class OIDCProvider(BaseModel):
- """OIDC provider, can also be a resource server"""
+class ResourceProvider(BaseModel):
+ id: str
+ name: str
+ base_url: AnyUrl
+ resources: list[Resource] = []
+ verify_ssl: bool = True
+
+ def get_resource(self, resource_name: str) -> Resource:
+ return [
+ resource for resource in self.resources if resource.resource_name == resource_name
+ ][0]
+
+ def get_resource_url(self, resource_name: str) -> str:
+ return f"{self.base_url}{self.get_resource(resource_name).url}"
+
+
+class AuthProviderSettings(BaseModel):
+ """Auth provider, can also be a resource server"""
id: str
name: str
@@ -33,6 +50,18 @@ class OIDCProvider(BaseModel):
code_challenge_method: str | None = None
hint: str = "No hint"
resources: list[Resource] = []
+ account_url_template: str | None = None
+ info_url: str | None = (
+ None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
+ )
+ public_key: str | None = None
+ public_key_url: str | None = None
+ signature_alg: str = "RS256"
+ resource_provider_scopes: list[str] = []
+ session_key: str = "sid"
+ skip_verify_signature: bool = True
+ disabled: bool = False
+ resource_providers: list[ResourceProvider] = []
@computed_field
@property
@@ -44,21 +73,43 @@ class OIDCProvider(BaseModel):
def token_url(self) -> str:
return "auth/" + self.id
+ def get_account_url(self, request: Request, user: dict) -> str | None:
+ if self.account_url_template:
+ if not (self.url.endswith("/") or self.account_url_template.startswith("/")):
+ sep = "/"
+ else:
+ sep = ""
+ return self.url + sep + self.account_url_template.format(request=request, user=user)
+ else:
+ return None
-class OIDCSettings(BaseModel):
+
+class AuthSettings(BaseModel):
show_session_details: bool = False
- providers: list[OIDCProvider] = []
+ providers: list[AuthProviderSettings] = []
swagger_provider: str = ""
+class Insecure(BaseModel):
+ """Warning: changing these defaults are only suitable for debugging"""
+
+ skip_verify_signature: bool = False
+
+
class Settings(BaseSettings):
"""Settings wil be read from an .env file"""
- oidc: OIDCSettings = OIDCSettings()
+ model_config = SettingsConfigDict(env_nested_delimiter="__")
+
+ auth: AuthSettings = AuthSettings()
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
log: bool = False
-
- model_config = SettingsConfigDict(env_nested_delimiter="__")
+ log_config_file: str = "log_conf.yaml"
+ insecure: Insecure = Insecure()
+ cors_origins: list[str] = []
+ debug_token: bool = False
+ show_token: bool = False
+ show_external_resource_providers_links: bool = False
@classmethod
def settings_customise_sources(
@@ -77,9 +128,7 @@ class Settings(BaseSettings):
settings_cls,
Path(
Path(
- environ.get(
- "OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml"
- ),
+ environ.get("OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml"),
)
),
),
diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css
index 6065a91..1e8dc03 100644
--- a/src/oidc_test/static/styles.css
+++ b/src/oidc_test/static/styles.css
@@ -2,11 +2,15 @@ body {
font-family: Arial, Helvetica, sans-serif;
background-color: floralwhite;
margin: 0;
+ font-family: system-ui;
+ text-align: center;
}
h1 {
- text-align: center;
background-color: #f7c7867d;
margin: 0 0 0.2em 0;
+ box-shadow: 0px 0.2em 0.2em #f7c7867d;
+ text-shadow: 0 0 2px #00000080;
+ font-weight: 200;
}
p {
margin: 0.2em;
@@ -17,9 +21,18 @@ hr {
.hidden {
display: none;
}
+.version {
+ position: absolute;
+ font-size: 75%;
+ top: 0.3em;
+ right: 0.3em;
+}
.center {
text-align: center;
}
+.error {
+ color: darkred;
+}
.content {
width: 100%;
display: flex;
@@ -51,7 +64,6 @@ hr {
border: 2px solid darkkhaki;
padding: 3px 6px;
text-decoration: none;
- text-align: center;
color: black;
}
.user-info a.logout:hover {
@@ -66,7 +78,6 @@ hr {
margin: 0;
}
.debug-auth p {
- text-align: center;
border-bottom: 1px solid black;
}
.debug-auth ul {
@@ -97,15 +108,25 @@ hr {
.hasResponseStatus.status-503 {
background-color: #ffA88050;
}
-.role {
+
+.role, .scope {
padding: 3px 6px;
+ margin: 3px;
+ border-radius: 6px;
+}
+
+.role {
background-color: #44228840;
}
+.scope {
+ background-color: #8888FF80;
+}
+
+
/* For home */
.login-box {
- text-align: center;
background-color: antiquewhite;
margin: 0.5em auto;
width: fit-content;
@@ -127,42 +148,92 @@ hr {
.providers .provider {
min-height: 2em;
}
-.providers .provider a.link {
+.providers .provider .link {
text-decoration: none;
max-height: 2em;
}
-.providers .provider .link div {
- text-align: center;
+.providers .provider .link {
background-color: #f7c7867d;
border-radius: 8px;
padding: 6px;
text-align: center;
color: black;
- font-weight: bold;
+ font-weight: 400;
cursor: pointer;
+ border: 0;
+ width: 100%;
}
+
+.providers .provider .link.disabled {
+ color: gray;
+ cursor: not-allowed;
+}
+
.providers .provider .hint {
font-size: 80%;
max-width: 13em;
}
.providers .error {
- color: darkred;
padding: 3px 6px;
- text-align: center;
font-weight: bold;
flex: 1 1 auto;
}
.content .links-to-check {
display: flex;
- text-align: center;
justify-content: center;
gap: 0.5em;
flex-flow: wrap;
}
-.content .links-to-check a {
+.content .links-to-check button {
color: black;
padding: 5px 10px;
text-decoration: none;
border-radius: 8px;
+ border: none;
+ cursor: pointer;
}
+.token {
+ overflow-wrap: anywhere;
+ font-family: monospace;
+}
+
+.resourceResult {
+ padding: 0.5em;
+ display: flex;
+ gap: 0.5em;
+ width: fit-content;
+ align-items: center;
+ margin: 5px auto;
+ box-shadow: 0px 0px 10px #90c3eeA0;
+ background-color: #90c3eeA0;
+ border-radius: 8px;
+}
+
+.resources {
+ display: flex;
+}
+
+.resource {
+ text-align: center;
+}
+
+.token-info {
+ margin: 0 1em;
+}
+
+.key {
+ font-weight: bold;
+}
+
+.token .key, .token .value {
+ display: inline;
+}
+.token .value {
+ padding-left: 1em;
+}
+
+.msg {
+ text-align: center;
+ font-weight: bold;
+}
diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js
index 6b40d3d..e988dfe 100644
--- a/src/oidc_test/static/utils.js
+++ b/src/oidc_test/static/utils.js
@@ -1,19 +1,90 @@
-function checkHref(elem) {
- var xmlHttp = new XMLHttpRequest()
- xmlHttp.onreadystatechange = function () {
- if (xmlHttp.readyState == 4) {
- elem.classList.add("hasResponseStatus")
- elem.classList.add("status-" + xmlHttp.status)
- elem.title = "Response code: " + xmlHttp.status + " - " + xmlHttp.statusText
- }
+async function checkHref(elem, token, authProvider) {
+ const msg = document.getElementById("msg")
+ const resourceName = elem.getAttribute("resource-name")
+ const resourceId = elem.getAttribute("resource-id")
+ const resourceProviderId = elem.getAttribute("resource-provider-id") ? elem.getAttribute("resource-provider-id") : ""
+ const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName
+ const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}`
+ const resp = await fetch(url, {
+ method: "GET",
+ headers: new Headers({
+ "Content-type": "application/json",
+ "Authorization": `Bearer ${token}`,
+ "auth_provider": authProvider,
+ }),
+ }).catch(err => {
+ msg.innerHTML = "Cannot fetch resource: " + err.message
+ resourceElem.innerHTML = ""
+ })
+ if (resp === undefined) {
+ return
+ } else {
+ elem.classList.add("hasResponseStatus")
+ elem.classList.add("status-" + resp.status)
+ elem.title = "Response code: " + resp.status + " - " + resp.statusText
}
- xmlHttp.open("GET", elem.href, true) // true for asynchronous
- xmlHttp.send(null)
}
-function checkPerms(className) {
+function checkPerms(className, token, authProvider) {
var rootElems = document.getElementsByClassName(className)
Array.from(rootElems).forEach(elem =>
- Array.from(elem.children).forEach(elem => checkHref(elem))
+ Array.from(elem.children).forEach(elem => checkHref(elem, token, authProvider))
+ )
+}
+
+async function get_resource(resourceName, token, authProvider, resourceId, resourceProviderId) {
+ // BaseUrl for an external resource provider
+ //if (!keycloak.keycloak) { return }
+ const msg = document.getElementById("msg")
+ const resourceElem = document.getElementById('resource')
+ const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName
+ const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}`
+ const resp = await fetch(url, {
+ method: "GET",
+ headers: new Headers({
+ "Content-type": "application/json",
+ "Authorization": `Bearer ${token}`,
+ "auth_provider": authProvider,
+ }),
+ }).catch(err => {
+ msg.innerHTML = "Cannot fetch resource: " + err.message
+ resourceElem.innerHTML = ""
+ })
+ if (resp === undefined) {
+ return
+ }
+ const resource = await resp.json()
+ if (!resp.ok) {
+ msg.innerHTML = resource["detail"]
+ resourceElem.innerHTML = ""
+ return
+ }
+ msg.innerHTML = ""
+ resourceElem.innerHTML = ""
+ Object.entries(resource).forEach(
+ ([key, value]) => {
+ let r = document.createElement('div')
+ let kElem = document.createElement('div')
+ kElem.innerText = key
+ kElem.className = "key"
+ let vElem = document.createElement('div')
+ if (typeof value == "object") {
+ Object.entries(value).forEach(v => {
+ const ne = document.createElement('div')
+ ne.innerHTML = `${v[0]} : ${v[1]} `
+ vElem.appendChild(ne)
+ })
+ }
+ else {
+ vElem.innerText = value
+ }
+ vElem.className = "value"
+ if (key == "sorry") {
+ vElem.classList.add("error")
+ }
+ r.appendChild(kElem)
+ r.appendChild(vElem)
+ resourceElem.appendChild(r)
+ }
)
}
diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html
index d2aa44b..157e26f 100644
--- a/src/oidc_test/templates/base.html
+++ b/src/oidc_test/templates/base.html
@@ -1,11 +1,12 @@
- FastAPI OIDC test
+ OIDC (FastAPI) test
-
- OIDC-test
+
+ v. {{ __version__}}
+ OIDC-test - FastAPI client
{% block content %}
{% endblock %}
diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html
index ab4dc77..167616f 100644
--- a/src/oidc_test/templates/home.html
+++ b/src/oidc_test/templates/home.html
@@ -5,31 +5,38 @@
with OpenID Connect and OAuth2 with different providers.
{% if not user %}
-
-
Log in with:
-
- {% for provider in settings.oidc.providers %}
-
-
- {{ provider.name }}
-
- {{ provider.hint }}
-
-
- {% else %}
- There is no authentication provider defined.
- Hint: check the settings.yaml file.
- {% endfor %}
-
-
- {% endif %}
- {% if user %}
+
+
Log in with:
+
+ {% for provider in providers.values() %}
+
+
+
+ {{ provider.name }}
+
+
+ {{ provider.hint }}
+
+
+ {% else %}
+ There is no authentication provider defined.
+ Hint: check the settings.yaml file.
+ {% endfor %}
+
+
+ {% else %}
Hey, {{ user.name }}
{% if user.picture %}
{% endif %}
{{ user.email }}
+
+ Provider:
+ {{ auth_provider.name }}
+
{% if user.roles %}
Roles:
@@ -38,53 +45,125 @@
{% endfor %}
{% endif %}
-
- Provider:
- {{ user.oidc_provider.name }}
-
-
Logout
+ {% if access_token_scope %}
+
+ Scopes :
+ {% for scope in access_token_scope.split(' ') %}
+ {{ scope }}
+ {% endfor %}
+
+ {% endif %}
+ {% if auth_provider.account_url_template %}
+
+ Account management
+
+ {% endif %}
+
Refresh access token
+
Logout
{% endif %}
-
- These links should get different response codes depending on the authorization:
-
-
{% if resources %}
-
- Resources for this provider:
-
+
This application provides all these resources, eventually protected with scope or roles:
- {% for resource in resources %}
-
{{ resource.name }}
+ {% for name, resource in resources.items() %}
+ {% if resource.default_resource_id %}
+
+ {{ resource.name }}
+
+ {% else %}
+
+ {{ resource.name }}
+
+ {% endif %}
{% endfor %}
{% endif %}
- {% if user_info_details %}
-
-
-
User info
-
- {% for key, value in user_info_details.items() %}
-
- {{ key }} : {{ value }}
-
+ {% if auth_provider.resources %}
+ {{ auth_provider.name }} is also defined as a provider for these resources:
+
+ {% for resource in auth_provider.resources %}
+ {% if resource.default_resource_id %}
+
+ {{ resource.name }}
+
+ {% else %}
+
+ {{ resource.name }}
+
+ {% endif %}
{% endfor %}
-
- Now is: {{ now.strftime("%T, %D") }}
+ {% endif %}
+ {% if resource_providers %}
+ {{ auth_provider.name }} allows this application to request resources from third party resource providers:
+ {% for resource_provider in resource_providers %}
+
+ {{ resource_provider.name }}
+ {% for resource in resource_provider.resources %}
+
+ {{ resource.name }}
+
+ {% endfor %}
+
+ {% endfor %}
+ {% endif %}
+
+
+ {% if show_token and id_token_parsed %}
+
+
+
+
id token
+
+ {% for key, value in id_token_parsed.items() %}
+
+
{{ key }}
+
{{ value }}
+
+ {% endfor %}
+
+
access token
+
+ {% for key, value in access_token_parsed.items() %}
+
+
{{ key }}
+
{{ value }}
+
+ {% endfor %}
+
+
refresh token
+
+ {% for key, value in refresh_token_parsed.items() %}
+
+
{{ key }}
+
{{ value }}
+
+ {% endfor %}
+
+
{% endif %}
{% endblock %}
diff --git a/src/oidc_test/templates/non_compliant_logout.html b/src/oidc_test/templates/non_compliant_logout.html
index 24a96ae..56758de 100644
--- a/src/oidc_test/templates/non_compliant_logout.html
+++ b/src/oidc_test/templates/non_compliant_logout.html
@@ -6,12 +6,12 @@
authorisation to log in again without asking for credentials.
- This is because {{ oidc_provider.name }} does not provide "end_session_endpoint" in its metadata
- (see: {{ oidc_provider._server_metadata_url }} ).
+ This is because {{ auth_provider.name }} does not provide "end_session_endpoint" in its metadata
+ (see: {{ auth_provider.authlib_client._server_metadata_url }} ).
You can just also go back to the application home page , but
- it recommended to go to the OIDC provider's site
+ it recommended to go to the OIDC provider's site
and log out explicitely from there.
{% endblock %}
diff --git a/uv.lock b/uv.lock
index 6ceb4ca..0566bb5 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,4 +1,5 @@
version = 1
+revision = 1
requires-python = ">=3.13"
[[package]]
@@ -206,6 +207,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632 },
]
+[[package]]
+name = "dunamai"
+version = "1.23.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+ { name = "packaging" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/06/4e/a5c8c337a1d9ac0384298ade02d322741fb5998041a5ea74d1cd2a4a1d47/dunamai-1.23.0.tar.gz", hash = "sha256:a163746de7ea5acb6dacdab3a6ad621ebc612ed1e528aaa8beedb8887fccd2c4", size = 44681 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/21/4c/963169386309fec4f96fd61210ac0a0666887d0fb0a50205395674d20b71/dunamai-1.23.0-py3-none-any.whl", hash = "sha256:a0906d876e92441793c6a423e16a4802752e723e9c9a5aabdc5535df02dbe041", size = 26342 },
+]
+
[[package]]
name = "ecdsa"
version = "0.19.0"
@@ -482,16 +495,17 @@ wheels = [
[[package]]
name = "oidc-fastapi-test"
-version = "0.0.0"
source = { editable = "." }
dependencies = [
{ name = "authlib" },
{ name = "cachetools" },
{ name = "fastapi", extra = ["standard"] },
+ { name = "httpx" },
{ name = "itsdangerous" },
{ name = "passlib", extra = ["bcrypt"] },
{ name = "pkce" },
{ name = "pydantic-settings" },
+ { name = "pyjwt" },
{ name = "python-jose", extra = ["cryptography"] },
{ name = "requests" },
{ name = "sqlmodel" },
@@ -499,6 +513,7 @@ dependencies = [
[package.dev-dependencies]
dev = [
+ { name = "dunamai" },
{ name = "ipdb" },
{ name = "pytest" },
]
@@ -508,10 +523,12 @@ requires-dist = [
{ name = "authlib", specifier = ">=1.4.0" },
{ name = "cachetools", specifier = ">=5.5.0" },
{ name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" },
+ { name = "httpx", specifier = ">=0.28.1" },
{ name = "itsdangerous", specifier = ">=2.2.0" },
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
{ name = "pkce", specifier = ">=1.0.3" },
{ name = "pydantic-settings", specifier = ">=2.7.1" },
+ { name = "pyjwt", specifier = ">=2.10.1" },
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
{ name = "requests", specifier = ">=2.32.3" },
{ name = "sqlmodel", specifier = ">=0.0.22" },
@@ -519,6 +536,7 @@ requires-dist = [
[package.metadata.requires-dev]
dev = [
+ { name = "dunamai", specifier = ">=1.23.0" },
{ name = "ipdb", specifier = ">=0.13.13" },
{ name = "pytest", specifier = ">=8.3.4" },
]
@@ -694,6 +712,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513 },
]
+[[package]]
+name = "pyjwt"
+version = "2.10.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785 }
+wheels = [
+ { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997 },
+]
+
[[package]]
name = "pytest"
version = "8.3.4"