diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml
index 379aaa8..df52e0b 100644
--- a/.forgejo/workflows/build.yaml
+++ b/.forgejo/workflows/build.yaml
@@ -19,7 +19,7 @@ jobs:
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v4
with:
- version: "0.6.9"
+ version: "0.5.16"
- name: Install
run: uv sync
@@ -27,26 +27,32 @@ jobs:
- name: Run tests (API call)
run: .venv/bin/pytest -s tests/basic.py
- - name: Get version
- run: echo "VERSION=$(.venv/bin/dunamai from any --style semver)" >> $GITHUB_ENV
+ - name: Get version with git describe
+ id: version
+ run: |
+ echo "version=$(git describe)" >> $GITHUB_OUTPUT
+ echo "$VERSION"
- - name: Version
- run: echo $VERSION
+ - name: Check if the container should be built
+ id: builder
+ env:
+ RUN: ${{ toJSON(inputs.build || !contains(steps.version.outputs.version, '-')) }}
+ run: |
+ echo "run=$RUN" >> $GITHUB_OUTPUT
+ echo "Run build: $RUN"
- - name: Get distance from tag
- run: echo "DISTANCE=$(.venv/bin/dunamai from any --format '{distance}')" >> $GITHUB_ENV
-
- - name: Distance
- run: echo $DISTANCE
+ - name: Set the version in pyproject.toml (workaround for uv not supporting dynamic version)
+ if: fromJSON(steps.builder.outputs.run)
+ env:
+ VERSION: ${{ steps.version.outputs.version }}
+ run: sed "s/0.0.0/$VERSION/" -i pyproject.toml
- name: Workaround for bug of podman-login
- if: env.DISTANCE == '0'
run: |
mkdir -p $HOME/.docker
echo "{ \"auths\": {} }" > $HOME/.docker/config.json
- name: Log in to the container registry (with another workaround)
- if: env.DISTANCE == '0'
uses: actions/podman-login@v1
with:
registry: ${{ vars.REGISTRY }}
@@ -55,31 +61,26 @@ jobs:
auth_file_path: /tmp/auth.json
- name: Build the container image
- if: env.DISTANCE == '0'
uses: actions/buildah-build@v1
with:
image: oidc-fastapi-test
oci: true
labels: oidc-fastapi-test
- tags: "latest ${{ env.VERSION }}"
+ tags: latest ${{ steps.version.outputs.version }}
containerfiles: |
./Containerfile
- name: Push the image to the registry
- if: env.DISTANCE == '0'
uses: actions/push-to-registry@v2
with:
registry: "docker://${{ vars.REGISTRY }}/${{ vars.ORGANISATION }}"
image: oidc-fastapi-test
- tags: "latest ${{ env.VERSION }}"
+ tags: latest ${{ steps.version.outputs.version }}
- name: Build wheel
- if: env.DISTANCE == '0'
run: uv build --wheel
- name: Publish Python package (home)
- if: env.DISTANCE == '0'
env:
LOCAL_PYPI_TOKEN: ${{ secrets.LOCAL_PYPI_TOKEN }}
run: uv publish --publish-url https://code.philo.ydns.eu/api/packages/philorg/pypi --token $LOCAL_PYPI_TOKEN
- continue-on-error: true
diff --git a/.forgejo/workflows/test.yaml b/.forgejo/workflows/test.yaml
index f4d994e..a56a9ce 100644
--- a/.forgejo/workflows/test.yaml
+++ b/.forgejo/workflows/test.yaml
@@ -19,7 +19,7 @@ jobs:
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v4
with:
- version: "0.6.3"
+ version: "0.5.16"
- name: Install
run: uv sync
diff --git a/Containerfile b/Containerfile
index 0ec45d1..aef57f8 100644
--- a/Containerfile
+++ b/Containerfile
@@ -1,4 +1,4 @@
-FROM docker.io/library/python:latest
+FROM docker.io/library/python:alpine
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
@@ -9,9 +9,6 @@ WORKDIR /app
RUN uv pip install --system .
-# Add demo plugin
-RUN PIP_EXTRA_INDEX_URL=https://pypi.org/simple/ uv pip install --system --index-url https://code.philo.ydns.eu/api/packages/philorg/pypi/simple/ oidc-fastapi-test-resource-provider-demo
-
# Possible to run with:
#CMD ["oidc-test", "--port", "80"]
#CMD ["fastapi", "run", "src/oidc_test/main.py", "--port", "8873", "--root-path", "/oidc-test"]
diff --git a/README.md b/README.md
index 68f335d..9e00474 100644
--- a/README.md
+++ b/README.md
@@ -52,59 +52,31 @@ given by the OIDC providers.
For example:
```yaml
-secret_key: AVeryWellKeptSecret
-debug_token: no
-show_token: yes
-log: yes
-
-auth:
+oidc:
+ secret_key: "ASecretNoOneKnows"
+ show_session_details: yes
providers:
- id: auth0
name: Okta / Auth0
- url: https://
- public_key_url: https:///pem
- client_id:
- client_secret: client_secret_generated_by_auth0
- hint: A hint for test credentials
+ url: "https://"
+ client_id: ""
+ client_secret: "client_secret_generated_by_auth0"
+ hint: "A hint for test credentials"
- id: keycloak
name: Keycloak at somewhere
- url: https://
- info_url: https://philo.ydns.eu/auth/realms/test
- account_url_template: /account
- client_id:
- client_secret:
- hint: A hint for test credentials
- code_challenge_method: S256
- resource_provider_scopes:
- - get:time
- - get:bs
- resource_providers:
- - id:
- name: A third party resource provider
- base_url: https://some.example.com/
- verify_ssl: yes
- resources:
- - name: Public RS2
- resource_name: public
- url: resource/public
- - name: BS RS2
- resource_name: bs
- url: resource/bs
- - name: Time RS2
- resource_name: time
- url: resource/time
+ url: "https://"
+ account_url_template: "/account"
+ client_id: ""
+ client_secret: "client_secret_generated_by_keycloak"
+ hint: "User: foo, password: foofoo"
- id: codeberg
- disabled: no
name: Codeberg
- url: https://codeberg.org
- account_url_template: /user/settings
- client_id:
- client_secret: client_secret_generated_by_codeberg
- info_url: https://codeberg.org/login/oauth/keys
- session_key: sub
- skip_verify_signature: no
+ url: "https://codeberg.org"
+ account_url_template: "/user/settings"
+ client_id: ""
+ client_secret: "client_secret_generated_by_codeberg"
resources:
- name: List of repos
id: repos
diff --git a/pyproject.toml b/pyproject.toml
index c44e9f3..980bcfc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "oidc-fastapi-test"
-#version = "0.0.0"
-dynamic = ["version"]
+version = "0.0.0"
+# dynamic = ["version"]
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
@@ -24,24 +24,14 @@ dependencies = [
oidc-test = "oidc_test.main:main"
[dependency-groups]
-dev = ["dunamai>=1.23.0", "ipdb>=0.13.13", "pytest>=8.3.4"]
+dev = ["ipdb>=0.13.13", "pytest>=8.3.4"]
[build-system]
-requires = ["hatchling", "uv-dynamic-versioning"]
+requires = ["hatchling"]
build-backend = "hatchling.build"
-[tool.hatch.version]
-source = "uv-dynamic-versioning"
-
[tool.hatch.build.targets.wheel]
packages = ["src/oidc_test"]
-package = true
-
-[tool.uv-dynamic-versioning]
-style = "semver"
[tool.uv]
package = true
-
-[tool.black]
-line-length = 98
diff --git a/src/oidc_test/__init__.py b/src/oidc_test/__init__.py
index f449e2b..e69de29 100644
--- a/src/oidc_test/__init__.py
+++ b/src/oidc_test/__init__.py
@@ -1,11 +0,0 @@
-import importlib.metadata
-
-try:
- from dunamai import Version, Style
-
- __version__ = Version.from_git().serialize(style=Style.SemVer, dirty=True)
-except ImportError:
- # __name__ could be used if the package name is the same
- # as the directory - not the case here
- # __version__ = importlib.metadata.version(__name__)
- __version__ = importlib.metadata.version("oidc-fastapi-test")
diff --git a/src/oidc_test/auth/provider.py b/src/oidc_test/auth/provider.py
deleted file mode 100644
index ce288a6..0000000
--- a/src/oidc_test/auth/provider.py
+++ /dev/null
@@ -1,113 +0,0 @@
-from json import JSONDecodeError
-from typing import Any
-from jwt import decode
-import logging
-
-from pydantic import ConfigDict
-from authlib.integrations.starlette_client.apps import StarletteOAuth2App
-from httpx import AsyncClient
-
-from oidc_test.settings import AuthProviderSettings, ResourceProvider, Resource, settings
-from oidc_test.models import User
-
-logger = logging.getLogger("oidc-test")
-
-
-class NoPublicKey(Exception):
- pass
-
-
-class Provider(AuthProviderSettings):
- # To allow authlib_client as StarletteOAuth2App
- model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore
-
- authlib_client: StarletteOAuth2App = StarletteOAuth2App(None)
- info: dict[str, Any] = {}
- unknown_auth_user: User
- logout_with_id_token_hint: bool = True
-
- def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]:
- """Decode the token with signature check"""
- if self.public_key is None:
- raise NoPublicKey
- if verify_signature is None:
- verify_signature = self.skip_verify_signature
- if settings.debug_token:
- decoded = decode(
- token,
- self.public_key,
- algorithms=[self.signature_alg],
- audience=["account", "oidc-test", "oidc-test-web"],
- options={
- "verify_signature": False,
- "verify_aud": False,
- }, # not settings.insecure.skip_verify_signature},
- )
- logger.debug(str(decoded))
- return decode(
- token,
- self.public_key,
- algorithms=[self.signature_alg],
- audience=["account", "oidc-test", "oidc-test-web"],
- options={
- "verify_signature": verify_signature,
- }, # not settings.insecure.skip_verify_signature},
- )
-
- async def get_info(self):
- # Get the public key:
- async with AsyncClient() as client:
- public_key: str | None = None
- if self.info_url is not None:
- try:
- provider_info = await client.get(self.info_url)
- except Exception as err:
- logger.debug("Provider_info: cannot connect")
- logger.exception(err)
- raise NoPublicKey
- try:
- self.info = provider_info.json()
- except JSONDecodeError:
- logger.debug("Provider_info: cannot decode json response")
- raise NoPublicKey
- if "public_key" in self.info:
- # For Keycloak
- try:
- public_key = str(self.info["public_key"])
- except KeyError:
- logger.debug("Provider_info: cannot get public_key")
- raise NoPublicKey
- elif "keys" in self.info:
- # For Forgejo/Gitea
- try:
- public_key = str(self.info["keys"][0]["n"])
- except KeyError:
- logger.debug("Provider_info: cannot get key 0.n")
- raise NoPublicKey
- if self.public_key_url is not None:
- resp = await client.get(self.public_key_url)
- public_key = resp.text
- if public_key is None:
- logger.debug("Provider_info: cannot determine public key")
- raise NoPublicKey
- self.public_key = "\n".join(
- ["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"]
- )
-
- def get_session_key(self, userinfo):
- return userinfo[self.session_key]
-
- def get_resource(self, resource_name: str) -> Resource:
- return [
- resource for resource in self.resources if resource.resource_name == resource_name
- ][0]
-
- def get_resource_url(self, resource_name: str) -> str:
- return self.url + self.get_resource(resource_name).url
-
- def get_resource_provider(self, resource_provider_id: str) -> ResourceProvider:
- return [
- provider
- for provider in self.resource_providers
- if provider.id == resource_provider_id
- ][0]
diff --git a/src/oidc_test/auth/utils.py b/src/oidc_test/auth/utils.py
deleted file mode 100644
index c51b039..0000000
--- a/src/oidc_test/auth/utils.py
+++ /dev/null
@@ -1,305 +0,0 @@
-from typing import Union, Annotated
-from functools import wraps
-import logging
-
-from fastapi import HTTPException, Request, Depends, status
-from fastapi.security import OAuth2PasswordBearer
-from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
-from authlib.oauth2.rfc6749 import OAuth2Token
-from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError
-
-from oidc_test.auth.provider import Provider
-from oidc_test.models import User
-from oidc_test.database import db, TokenNotInDb, UserNotInDB
-from oidc_test.settings import settings
-from oidc_test.auth_providers import providers
-
-logger = logging.getLogger("oidc-test")
-
-
-async def fetch_token(name, request):
- assert name is not None
- assert request is not None
- logger.warning("TODO: fetch_token")
- ...
- # if name in oidc_providers:
- # model = OAuth2Token
- # else:
- # model = OAuthToken
-
- # token = model.find(name=name, user=request.user)
- # return token.to_token()
-
-
-async def update_token(
- provider_id,
- token,
- refresh_token: str | None = None,
- access_token: str | None = None,
-):
- """Update the token in the database"""
- provider = providers[provider_id]
- sid: str = provider.get_session_key(provider.decode(token["id_token"]))
- item = await db.get_token(provider, sid)
- # update old token
- item["access_token"] = token["access_token"]
- item["refresh_token"] = token["refresh_token"]
- item["id_token"] = token["id_token"]
- item["expires_at"] = token["expires_at"]
- logger.info(f"Token {sid} refreshed")
- # It's a fake db and only in memory, so there's nothing to save
- # await item.save()
-
-
-def init_providers():
- """Add oidc providers to authlib from the settings
- and build the providers dict"""
- for provider_settings in settings.auth.providers:
- provider_settings_dict = provider_settings.model_dump()
- # Add an anonymous user, that cannot be identified but has provided a valid access token
- provider_settings_dict["unknown_auth_user"] = User(
- sub="", auth_provider_id=provider_settings.id
- )
- provider = Provider(**provider_settings_dict)
- if provider.disabled:
- logger.info(f"{provider_settings.name} is disabled, skipping")
- else:
- authlib_oauth.register(
- name=provider.id,
- server_metadata_url=provider.openid_configuration,
- client_kwargs={
- "scope": " ".join(
- ["openid", "email", "offline_access", "profile"]
- + provider.resource_provider_scopes
- ),
- },
- client_id=provider.client_id,
- client_secret=provider.client_secret,
- api_base_url=provider.url,
- # For PKCE (not implemented yet):
- # code_challenge_method="S256",
- fetch_token=fetch_token,
- update_token=update_token,
- # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
- )
- provider.authlib_client = getattr(authlib_oauth, provider.id)
- providers[provider.id] = provider
-
-
-authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
-oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
-
-
-def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
- """Return the oidc_provider from a request object, from the session.
- It can be used in Depends()"""
- if (auth_provider_id := request.session.get("auth_provider_id")) is None:
- return
- return getattr(authlib_oauth, str(auth_provider_id), None)
-
-
-def get_auth_provider_client(request: Request) -> StarletteOAuth2App:
- if (oidc_provider := get_auth_provider_client_or_none(request)) is None:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
- else:
- return oidc_provider
-
-
-def get_auth_provider_or_none(request: Request) -> Provider | None:
- """Return the oidc_provider settings from a request object, from the session.
- It can be used in Depends()"""
- if (auth_provider_id := request.session.get("auth_provider_id")) is None:
- return
- return providers.get(auth_provider_id)
-
-
-def get_auth_provider(request: Request) -> Provider:
- if (provider := get_auth_provider_or_none(request)) is None:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
- return provider
-
-
-async def get_current_user(request: Request) -> User:
- """Get the current user from a request object.
- Also validates the token expiration time.
- ... TODO: complete about refresh token
- """
- if (user_sub := request.session.get("user_sub")) is None:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED)
- token = await get_token_from_session(request)
- user = await db.get_user(user_sub)
- ## Check if the token is expired
- if token.is_expired():
- provider = get_auth_provider(request=request)
- ## Ask a new refresh token from the provider
- logger.info(f"Token expired for user {user.name}")
- try:
- userinfo = await provider.authlib_client.fetch_access_token(
- refresh_token=token.get("refresh_token")
- )
- assert userinfo is not None
- except OAuthError as err:
- logger.exception(err)
- # raise HTTPException(
- # status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh"
- # )
-
- return user
-
-
-async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None:
- """Return the auth token from the session or None.
- Can be used in Depends()"""
- try:
- return await get_token_from_session(request)
- except HTTPException:
- return None
-
-
-async def get_token_from_session(request: Request) -> OAuth2Token:
- """Return the token from the session.
- Can be used in Depends()"""
- try:
- provider = providers[request.session.get("auth_provider_id", "")]
- except KeyError:
- request.session.pop("auth_provider_id", None)
- request.session.pop("user_sub", None)
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider")
- try:
- return await db.get_token(
- provider,
- request.session.get("sid"),
- )
- except (TokenNotInDb, InvalidKeyError, DecodeError) as err:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, err.__class__.__name__)
-
-
-async def get_current_user_or_none(request: Request) -> User | None:
- """Return the user from a request object, from the session.
- It can be used in Depends()"""
- try:
- return await get_current_user(request)
- except HTTPException:
- return None
-
-
-def hasrole(required_roles: Union[str, list[str]] = []):
- """Decorator for RBAC permissions"""
- required_roles_set: set[str]
- if isinstance(required_roles, str):
- required_roles_set = set([required_roles])
- else:
- required_roles_set = set(required_roles)
-
- def decorator(func):
- @wraps(func)
- async def wrapper(request=None, *args, **kwargs):
- if request is None:
- raise HTTPException(
- 500,
- "Functions decorated with hasrole must have a request:Request argument",
- )
- user: User = await get_current_user(request)
- if not any(required_roles_set.intersection(user.roles_as_set)):
- raise HTTPException(status.HTTP_401_UNAUTHORIZED)
- return await func(request, *args, **kwargs)
-
- return wrapper
-
- return decorator
-
-
-def get_token_info(token: dict) -> dict:
- token_info = dict()
- for key in token:
- if key != "userinfo":
- token_info[key] = token[key]
- return token_info
-
-
-async def get_user_from_token(
- token: Annotated[str, Depends(oauth2_scheme)],
- request: Request,
-) -> User:
- try:
- auth_provider_id = request.headers["auth_provider"]
- except KeyError:
- raise HTTPException(
- status.HTTP_401_UNAUTHORIZED,
- "Request headers must have a 'auth_provider' field",
- )
- try:
- provider = providers[auth_provider_id]
- except KeyError:
- if auth_provider_id == "":
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth provider")
- else:
- raise HTTPException(
- status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
- )
- if token == "None":
- request.session.pop("auth_provider_id", None)
- request.session.pop("user_sub", None)
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
- try:
- payload = provider.decode(token)
- except ExpiredSignatureError:
- raise HTTPException(
- status.HTTP_401_UNAUTHORIZED,
- "Expired signature (token refresh not implemented yet)",
- )
- except InvalidKeyError:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
- except PyJWTError as err:
- raise HTTPException(
- status.HTTP_401_UNAUTHORIZED, f"Cannot decode token: {err.__class__.__name__}"
- )
- try:
- user_id = payload["sub"]
- except KeyError:
- return provider.unknown_auth_user
- try:
- user = await db.get_user(user_id)
- if user.access_token != token:
- user.access_token = token
- except UserNotInDB:
- logger.info(
- f"User {user_id} not found in DB, creating it (real apps can behave differently)"
- )
- user = await db.add_user(
- sub=payload["sub"],
- user_info=payload,
- auth_provider=providers[auth_provider_id],
- access_token=token,
- )
- return user
-
-
-async def get_user_from_token_or_none(
- token: Annotated[str | None, Depends(oauth2_scheme_optional)],
- request: Request,
-) -> User | None:
- if token is None:
- return None
- try:
- return await get_user_from_token(token, request)
- except HTTPException:
- return None
-
-
-class UserWithRole:
- roles: set[str]
-
- def __init__(self, roles: str | list[str] | tuple[str] | set[str]):
- if isinstance(roles, str):
- self.roles = set([roles])
- elif isinstance(roles, (list, tuple, set)):
- self.roles = set(roles)
-
- def __call__(self, user: User = Depends(get_user_from_token)) -> User:
- if not any(self.roles.intersection(user.roles_as_set)):
- raise HTTPException(
- status.HTTP_401_UNAUTHORIZED, f"Not of any required role {', '.join(self.roles)}"
- )
- return user
diff --git a/src/oidc_test/auth_misc.py b/src/oidc_test/auth_misc.py
new file mode 100644
index 0000000..a4e9ea3
--- /dev/null
+++ b/src/oidc_test/auth_misc.py
@@ -0,0 +1,29 @@
+from datetime import datetime, timedelta
+from collections import OrderedDict
+
+from .models import User
+
+time_keys = set(("iat", "exp", "auth_time", "updated_at"))
+
+
+def pretty_details(user: User, now: datetime) -> OrderedDict:
+ details = OrderedDict()
+ # breakpoint()
+ for key in sorted(time_keys):
+ try:
+ dt = datetime.fromtimestamp(user.userinfo[key])
+ except (KeyError, TypeError):
+ pass
+ else:
+ td = now - dt
+ td = timedelta(days=td.days, seconds=td.seconds)
+ if td.days < 0:
+ ptd = f"in {-td} h:m:s"
+ else:
+ ptd = f"{td} h:m:s ago"
+ details[key] = f"{user.userinfo[key]} - {dt} ({ptd})"
+ for key in sorted(user.userinfo):
+ if key in time_keys:
+ continue
+ details[key] = user.userinfo[key]
+ return details
diff --git a/src/oidc_test/auth_providers.py b/src/oidc_test/auth_providers.py
deleted file mode 100644
index 1c33ae8..0000000
--- a/src/oidc_test/auth_providers.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from collections import OrderedDict
-
-from oidc_test.auth.provider import Provider
-
-providers: OrderedDict[str, Provider] = OrderedDict()
diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py
new file mode 100644
index 0000000..ed3350c
--- /dev/null
+++ b/src/oidc_test/auth_utils.py
@@ -0,0 +1,236 @@
+from typing import Union, Annotated
+from functools import wraps
+import logging
+
+from fastapi import HTTPException, Request, Depends, status
+from fastapi.security import OAuth2PasswordBearer
+from authlib.oauth2.rfc6749 import OAuth2Token
+from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
+from jwt import ExpiredSignatureError, InvalidKeyError, decode
+from httpx import AsyncClient
+
+# from authlib.oauth1.auth import OAuthToken
+# from authlib.oauth2.auth import OAuth2Token
+
+from .models import User
+from .database import db, UserNotInDB
+from .settings import settings, OIDCProvider
+
+logger = logging.getLogger(__name__)
+
+oidc_providers_settings: dict[str, OIDCProvider] = dict(
+ [(provider.id, provider) for provider in settings.oidc.providers]
+)
+
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+
+def fetch_token(name, request):
+ breakpoint()
+ ...
+ # if name in oidc_providers:
+ # model = OAuth2Token
+ # else:
+ # model = OAuthToken
+
+ # token = model.find(name=name, user=request.user)
+ # return token.to_token()
+
+
+def update_token(*args, **kwargs):
+ breakpoint()
+ ...
+
+
+authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
+
+
+def init_providers():
+ # Add oidc providers to authlib from the settings
+ for id, provider in oidc_providers_settings.items():
+ authlib_oauth.register(
+ name=id,
+ server_metadata_url=provider.openid_configuration,
+ client_kwargs={
+ "scope": "openid email offline_access profile",
+ },
+ client_id=provider.client_id,
+ client_secret=provider.client_secret,
+ api_base_url=provider.url,
+ # For PKCE (not implemented yet):
+ # code_challenge_method="S256",
+ # fetch_token=fetch_token,
+ # update_token=update_token,
+ # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
+ )
+
+
+init_providers()
+
+
+async def get_providers_info():
+ # Get the public key:
+ async with AsyncClient() as client:
+ for provider_settings in oidc_providers_settings.values():
+ if provider_settings.info_url:
+ provider_info = await client.get(provider_settings.url)
+ provider_settings.info = provider_info.json()
+
+
+def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None:
+ """Return the oidc_provider from a request object, from the session.
+ It can be used in Depends()"""
+ if (oidc_provider_id := request.session.get("oidc_provider_id")) is None:
+ return
+ try:
+ return getattr(authlib_oauth, str(oidc_provider_id))
+ except AttributeError:
+ return
+
+
+def get_oidc_provider(request: Request) -> StarletteOAuth2App:
+ if (oidc_provider := get_oidc_provider_or_none(request)) is None:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
+ else:
+ return oidc_provider
+
+
+async def get_current_user(request: Request) -> User:
+ """Get the current user from a request object.
+ Also validates the token expiration time.
+ ... TODO: complete about refresh token
+ """
+ if (user_sub := request.session.get("user_sub")) is None:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED)
+ if (token := await db.get_token(request.session["token"])) is None:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown")
+ user = await db.get_user(user_sub)
+ ## Check if the token is expired
+ if token.is_expired():
+ oidc_provider = get_oidc_provider(request=request)
+ ## Ask a new refresh token from the provider
+ logger.info(f"Token expired for user {user.name}")
+ try:
+ userinfo = await oidc_provider.fetch_access_token(
+ refresh_token=token.get("refresh_token")
+ )
+ except OAuthError as err:
+ logger.exception(err)
+ # raise HTTPException(
+ # status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh"
+ # )
+
+ return user
+
+
+async def get_token(request: Request) -> OAuth2Token:
+ """Return the token from a request object, from the session.
+ It can be used in Depends()"""
+ if (token := await db.get_token(request.session.get("token"))) is None:
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
+ return token
+
+
+async def get_current_user_or_none(request: Request) -> User | None:
+ """Return the user from a request object, from the session.
+ It can be used in Depends()"""
+ try:
+ return await get_current_user(request)
+ except HTTPException:
+ return None
+
+
+def hasrole(required_roles: Union[str, list[str]] = []):
+ """Decorator for RBAC permissions"""
+ required_roles_set: set[str]
+ if isinstance(required_roles, str):
+ required_roles_set = set([required_roles])
+ else:
+ required_roles_set = set(required_roles)
+
+ def decorator(func):
+ @wraps(func)
+ async def wrapper(request=None, *args, **kwargs):
+ if request is None:
+ raise HTTPException(
+ 500,
+ "Functions decorated with hasrole must have a request:Request argument",
+ )
+ user: User = await get_current_user(request)
+ if not any(required_roles_set.intersection(user.roles_as_set)):
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED)
+ return await func(request, *args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def get_token_info(token: dict) -> dict:
+ token_info = dict()
+ for key in token:
+ if key != "userinfo":
+ token_info[key] = token[key]
+ return token_info
+
+
+async def get_user_from_token(
+ token: Annotated[str, Depends(oauth2_scheme)],
+ request: Request,
+) -> User:
+ if (auth_provider_id := request.headers.get("auth_provider")) is None:
+ raise HTTPException(
+ status.HTTP_401_UNAUTHORIZED,
+ "Request headers must have a 'auth_provider' field",
+ )
+ if (
+ auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
+ ) is None:
+ raise HTTPException(
+ status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
+ )
+ if (key := auth_provider_settings.get_public_key()) is None:
+ raise HTTPException(
+ status.HTTP_401_UNAUTHORIZED,
+ f"Key for provider '{auth_provider_id}' unknown",
+ )
+ try:
+ payload = decode(
+ token,
+ key=key,
+ algorithms=["RS256"],
+ audience="oidc-test",
+ options={"verify_signature": not settings.insecure.skip_verify_signature},
+ )
+ except ExpiredSignatureError as err:
+ logger.info(f"Expired signature: {err}")
+ raise HTTPException(
+ status.HTTP_401_UNAUTHORIZED,
+ "Expired signature (refresh not implemented yet)",
+ )
+ except InvalidKeyError as err:
+ logger.info(f"Invalid key: {err}")
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
+ except Exception as err:
+ logger.info("Cannot decode token, see below")
+ logger.exception(err)
+ raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
+ if (user_id := payload.get("sub")) is None:
+ raise HTTPException(
+ status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found"
+ )
+ try:
+ user = await db.get_user(user_id)
+ user.access_token = payload
+ except UserNotInDB:
+ logger.info(
+ f"User {user_id} not found in DB, creating it (real apps can behave differently"
+ )
+ user = await db.add_user(
+ sub=payload["sub"],
+ user_info=payload,
+ oidc_provider=getattr(authlib_oauth, auth_provider_id),
+ user_info_from_endpoint={},
+ access_token=payload,
+ )
+ return user
diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py
index 8d87a48..5dec7fc 100644
--- a/src/oidc_test/database.py
+++ b/src/oidc_test/database.py
@@ -2,28 +2,20 @@
import logging
+from authlib.integrations.starlette_client.apps import StarletteOAuth2App
+
+from .models import User, Role
from authlib.oauth2.rfc6749 import OAuth2Token
-from jwt import PyJWTError
-from oidc_test.auth.provider import Provider
-
-from oidc_test.models import User, Role
-from oidc_test.auth_providers import providers
-
-logger = logging.getLogger("oidc-test")
+logger = logging.getLogger(__name__)
class UserNotInDB(Exception):
pass
-class TokenNotInDb(Exception):
- pass
-
-
class Database:
users: dict[str, User] = {}
- # TODO: key of the token table should be provider: sid
tokens: dict[str, OAuth2Token] = {}
# Last sessions for the user (key: users's subject id (sub))
@@ -32,38 +24,21 @@ class Database:
self,
sub: str,
user_info: dict,
- auth_provider: Provider,
- access_token: str,
- access_token_decoded: dict | None = None,
+ oidc_provider: StarletteOAuth2App,
+ user_info_from_endpoint: dict,
+ access_token: dict,
) -> User:
- if access_token_decoded is None:
- assert auth_provider.name is not None
- provider = providers[auth_provider.id]
- try:
- access_token_decoded = provider.decode(access_token)
- except PyJWTError:
- access_token_decoded = {}
- user_info["auth_provider_id"] = auth_provider.id
- user = User(**user_info)
- user.userinfo = user_info
- # user.access_token = access_token
- # user.access_token_decoded = access_token_decoded
- # Add roles provided in the access token
- roles = set()
+ user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider)
+ user.access_token = access_token
try:
- r = access_token_decoded["resource_access"][auth_provider.client_id]["roles"]
- roles.update(r)
- except KeyError:
- pass
- try:
- r = access_token_decoded["realm_access"]["roles"]
- if isinstance(r, str):
- roles.add(r)
- else:
- roles.update(r)
- except KeyError:
- pass
- user.roles = [Role(name=role_name) for role_name in roles]
+ raw_roles = user_info_from_endpoint["resource_access"][
+ oidc_provider.client_id
+ ]["roles"]
+ except Exception as err:
+ logger.debug(f"Cannot read additional roles: {err}")
+ raw_roles = []
+ for raw_role in raw_roles:
+ user.roles.append(Role(name=raw_role))
self.users[sub] = user
return user
@@ -72,25 +47,11 @@ class Database:
raise UserNotInDB
return self.users[sub]
- async def add_token(self, provider: Provider, token: OAuth2Token) -> None:
- """Store a token using as key the sid (auth provider's session id)
- in the id_token"""
- sid = provider.get_session_key(token["userinfo"])
- self.tokens[sid] = token
+ async def add_token(self, token: OAuth2Token, user: User) -> None:
+ self.tokens[token["id_token"]] = token
- async def get_token(
- self,
- provider: Provider,
- sid: str | None,
- ) -> OAuth2Token:
- # TODO: key of the token table should be provider: sid
- assert isinstance(provider, Provider)
- if sid is None:
- raise TokenNotInDb
- try:
- return self.tokens[sid]
- except KeyError:
- raise TokenNotInDb
+ async def get_token(self, id_token: str) -> OAuth2Token | None:
+ return self.tokens.get(id_token)
db = Database()
diff --git a/src/oidc_test/log_conf.yaml b/src/oidc_test/log_conf.yaml
deleted file mode 100644
index a6bb0b4..0000000
--- a/src/oidc_test/log_conf.yaml
+++ /dev/null
@@ -1,34 +0,0 @@
-version: 1
-disable_existing_loggers: False
-formatters:
- default:
- "()": uvicorn.logging.DefaultFormatter
- format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
- access:
- "()": uvicorn.logging.AccessFormatter
- format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
-handlers:
- default:
- formatter: default
- class: logging.StreamHandler
- stream: ext://sys.stderr
- access:
- formatter: access
- class: logging.StreamHandler
- stream: ext://sys.stdout
-loggers:
- uvicorn.error:
- level: INFO
- handlers:
- - default
- propagate: no
- uvicorn.access:
- level: INFO
- handlers:
- - access
- propagate: no
- "oidc-test":
- level: DEBUG
- handlers:
- - default
- propagate: yes
diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py
index e882cda..739dd1b 100644
--- a/src/oidc_test/main.py
+++ b/src/oidc_test/main.py
@@ -6,19 +6,15 @@ from typing import Annotated
from pathlib import Path
from datetime import datetime
import logging
-import logging.config
-import importlib.resources
-from yaml import safe_load
from urllib.parse import urlencode
from contextlib import asynccontextmanager
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.staticfiles import StaticFiles
-from fastapi.responses import HTMLResponse, RedirectResponse
+from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
-from jwt import PyJWTError
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.base_client import OAuthError
@@ -29,52 +25,32 @@ from authlib.oauth2.rfc6749 import OAuth2Token
# from fastapi.security import OpenIdConnect
# from pkce import generate_code_verifier, generate_pkce_pair
-from oidc_test import __version__
-from oidc_test.registry import registry
-from oidc_test.auth.provider import NoPublicKey, Provider
-from oidc_test.auth.utils import (
- get_auth_provider,
- get_auth_provider_or_none,
+from .settings import settings
+from .models import User
+from .auth_utils import (
+ get_oidc_provider,
+ get_oidc_provider_or_none,
+ hasrole,
get_current_user_or_none,
+ get_current_user,
+ get_user_from_token,
authlib_oauth,
- get_token_from_session_or_none,
- get_token_from_session,
- update_token,
+ get_token,
+ oidc_providers_settings,
+ get_providers_info,
)
-from oidc_test.auth.utils import init_providers
-from oidc_test.settings import settings
-from oidc_test.auth_providers import providers
-from oidc_test.models import User
-from oidc_test.database import TokenNotInDb, db
-from oidc_test.resource_server import resource_server
+from .auth_misc import pretty_details
+from .database import db
+from .resource_server import get_resource
-logger = logging.getLogger("oidc-test")
-
-if settings.log:
- assert __package__ is not None
- with (
- importlib.resources.path(__package__) as package_path,
- open(package_path / settings.log_config_file) as f,
- ):
- logging_config = safe_load(f)
- logging.config.dictConfig(logging_config)
+logger = logging.getLogger("uvicorn.error")
templates = Jinja2Templates(Path(__file__).parent / "templates")
@asynccontextmanager
async def lifespan(app: FastAPI):
- assert app is not None
- init_providers()
- registry.make_registry()
- for provider in list(providers.values()):
- if provider.disabled:
- continue
- try:
- await provider.get_info()
- except NoPublicKey:
- logger.warning(f"Disable {provider.id}: public key not found")
- del providers[provider.id]
+ await get_providers_info()
yield
@@ -88,79 +64,74 @@ app.add_middleware(
allow_headers=["*"],
)
+app.mount(
+ "/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static"
+)
+
# SessionMiddleware is required by authlib
app.add_middleware(
SessionMiddleware,
secret_key=settings.secret_key,
)
-app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
-app.mount("/resource", resource_server, name="resource_server")
-
@app.get("/")
async def home(
request: Request,
- user: Annotated[User | None, Depends(get_current_user_or_none)],
- provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)],
- token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)],
+ user: Annotated[User, Depends(get_current_user_or_none)],
+ oidc_provider: Annotated[
+ StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)
+ ],
) -> HTMLResponse:
- context = {
- "show_token": settings.show_token,
- "user": user,
- "now": datetime.now(),
- "__version__": __version__,
- }
- if provider is None or token is None:
- context["providers"] = providers
- context["access_token"] = None
- context["id_token_parsed"] = None
- context["access_token_parsed"] = None
- context["refresh_token_parsed"] = None
- context["resources"] = None
- context["auth_provider"] = None
+ now = datetime.now()
+ if oidc_provider and (
+ (
+ oidc_provider_settings := oidc_providers_settings.get(
+ request.session.get("oidc_provider_id", "")
+ )
+ )
+ is not None
+ ):
+ resources = oidc_provider_settings.resources
else:
- context["auth_provider"] = provider
- context["access_token"] = token["access_token"]
- try:
- access_token_parsed = provider.decode(token["access_token"], verify_signature=False)
- context["access_token_parsed"] = access_token_parsed
- context["access_token_scope"] = access_token_parsed.get("scope")
- except PyJWTError as err:
- context["access_token_parsed"] = {"Cannot parse": err.__class__.__name__}
- context["access_token_scope"] = None
- try:
- id_token_parsed = provider.decode(token["id_token"], verify_signature=False)
- context["id_token_parsed"] = id_token_parsed
- except PyJWTError as err:
- context["id_token_parsed"] = {"Cannot parse": err.__class__.__name__}
- try:
- refresh_token_parsed = provider.decode(token["refresh_token"], verify_signature=False)
- context["refresh_token_parsed"] = refresh_token_parsed
- except PyJWTError as err:
- context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__}
- context["resources"] = registry.resources
- context["resource_providers"] = provider.resource_providers
- return templates.TemplateResponse(name="home.html", request=request, context=context)
+ resources = []
+ oidc_provider_settings = None
+ return templates.TemplateResponse(
+ name="home.html",
+ request=request,
+ context={
+ "settings": settings.model_dump(),
+ "user": user,
+ "now": now,
+ "oidc_provider": oidc_provider,
+ "oidc_provider_settings": oidc_provider_settings,
+ "resources": resources,
+ "user_info_details": (
+ pretty_details(user, now)
+ if user and settings.oidc.show_session_details
+ else None
+ ),
+ },
+ )
# Endpoints for the login / authorization process
-@app.get("/login/{auth_provider_id}")
-async def login(request: Request, auth_provider_id: str) -> RedirectResponse:
+@app.get("/login/{oidc_provider_id}")
+async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
"""Login with the provider id, giving the browser a redirect to its authorize page.
- The provider is expected to send the browser back to our own /auth/{auth_provider_id} url
+ The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url
with the token.
"""
- redirect_uri = request.url_for("auth", auth_provider_id=auth_provider_id)
+ redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id)
try:
- provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
+ provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
except AttributeError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
# if (
- # code_challenge_method := providers[
- # auth_provider_id
+ # code_challenge_method := oidc_providers_settings[
+ # oidc_provider_id
# ].code_challenge_method
# ) is not None:
# #client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method)
@@ -180,144 +151,206 @@ async def login(request: Request, auth_provider_id: str) -> RedirectResponse:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
-@app.get("/auth/{auth_provider_id}")
-async def auth(
- request: Request,
- auth_provider_id: str,
-) -> RedirectResponse:
+@app.get("/auth/{oidc_provider_id}")
+async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
"""Decrypt the auth token, store it to the session (cookie based)
and response to the browser with a redirect to a "welcome user" page.
"""
try:
- provider = providers[auth_provider_id]
- except KeyError:
+ oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
+ except AttributeError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try:
- token: OAuth2Token = await provider.authlib_client.authorize_access_token(request)
+ token: OAuth2Token = await oidc_provider.authorize_access_token(request)
except OAuthError as error:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
- # Remember the authlib_client in the session
+ # Remember the oidc_provider in the session
# logger.info(f"Scope: {token['scope']}")
- request.session["auth_provider_id"] = auth_provider_id
+ request.session["oidc_provider_id"] = oidc_provider_id
#
# One could process the full decoded token which contains extra information
# eg for updates. Here we are only interested in roles
#
if userinfo := token.get("userinfo"):
- # Remember the authlib_client in the session
- request.session["auth_provider_id"] = auth_provider_id
- # User id (sub) given by auth provider
+ # Remember the oidc_provider in the session
+ request.session["oidc_provider_id"] = oidc_provider_id
+ # User id (sub) given by oidc provider
sub = userinfo["sub"]
- ## Get additional data from userinfo endpoint
- # try:
- # user_info_from_endpoint = await authlib_client.userinfo(
- # token=token, follow_redirects=True
- # )
- # except Exception as err:
- # logger.warn(f"Cannot get userinfo from endpoint: {err}")
- # user_info_from_endpoint = {}
+ # Get additional data from userinfo endpoint
+ try:
+ user_info_from_endpoint = await oidc_provider.userinfo(
+ token=token, follow_redirects=True
+ )
+ except Exception as err:
+ logger.warn(f"Cannot get userinfo from endpoint: {err}")
+ user_info_from_endpoint = {}
# Build and remember the user in the session
request.session["user_sub"] = sub
- # Store the user in the database, which also verifies the token validity and signature
- try:
- user = await db.add_user(
- sub,
- user_info=userinfo,
- auth_provider=providers[auth_provider_id],
- access_token=token["access_token"],
- )
- except PyJWTError as err:
- raise HTTPException(
- status.HTTP_401_UNAUTHORIZED,
- detail=f"Token invalid: {err.__class__.__name__}",
- )
- assert isinstance(user, User)
- # Add the provider session id to the session
- request.session["sid"] = provider.get_session_key(userinfo)
+ # Store the user in the database
+ user = await db.add_user(
+ sub,
+ user_info=userinfo,
+ oidc_provider=oidc_provider,
+ user_info_from_endpoint=user_info_from_endpoint,
+ )
+ # Add the id_token to the session
+ request.session["token"] = token["id_token"]
# Add the token to the db because it is used for logout
- await db.add_token(provider, token)
+ await db.add_token(token, user)
# Send the user to the home: (s)he is authenticated
return RedirectResponse(url=request.url_for("home"))
else:
# Not sure if it's correct to redirect to plain login
# if no userinfo is provided
- return RedirectResponse(url=request.url_for("login", auth_provider_id=auth_provider_id))
+ return RedirectResponse(
+ url=request.url_for("login", oidc_provider_id=oidc_provider_id)
+ )
@app.get("/account")
async def account(
- provider: Annotated[Provider, Depends(get_auth_provider)],
+ request: Request,
+ oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
) -> RedirectResponse:
- """Redirect to the auth provider account management,
- if account_url_template is in the provider's settings"""
- return RedirectResponse(f"{provider.account_url_template}")
+ if (
+ provider := oidc_providers_settings.get(
+ request.session.get("oidc_provider_id", "")
+ )
+ ) is None:
+ raise HTTPException(
+ status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting"
+ )
+ return RedirectResponse(f"{provider.account_url}")
@app.get("/logout")
async def logout(
request: Request,
- provider: Annotated[Provider, Depends(get_auth_provider)],
+ oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
) -> RedirectResponse:
- # Get provider's endpoint
- if (
- provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint")
- ) is None:
- logger.warning(f"Cannot find end_session_endpoint for provider {provider.id}")
- return RedirectResponse(request.url_for("non_compliant_logout"))
- post_logout_uri = request.url_for("home")
# Clear session
request.session.pop("user_sub", None)
- request.session.pop("auth_provider_id", None)
- try:
- token = await db.get_token(provider, request.session.pop("sid", None))
- except TokenNotInDb:
- logger.warning("No session in db for the token or no token")
+ # Get provider's endpoint
+ if (
+ provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")
+ ) is None:
+ logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
+ return RedirectResponse(request.url_for("non_compliant_logout"))
+ post_logout_uri = request.url_for("home")
+ if (token := await db.get_token(request.session.pop("token", None))) is None:
+ logger.warn("No session in db for the token")
return RedirectResponse(request.url_for("home"))
- url_query = {
- "post_logout_redirect_uri": post_logout_uri,
- "client_id": provider.client_id,
- }
- if provider.logout_with_id_token_hint:
- url_query["id_token_hint"] = token["id_token"]
- logout_url = f"{provider_logout_uri}?{urlencode(url_query)}"
+ logout_url = (
+ provider_logout_uri
+ + "?"
+ + urlencode(
+ {
+ "post_logout_redirect_uri": post_logout_uri,
+ "id_token_hint": token["id_token"],
+ "cliend_id": "oidc_local_test",
+ }
+ )
+ )
return RedirectResponse(logout_url)
@app.get("/non-compliant-logout")
async def non_compliant_logout(
request: Request,
- provider: Annotated[Provider, Depends(get_auth_provider)],
+ oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
):
"""A page for non-compliant OAuth2 servers that we cannot log out."""
- # Clear session
- request.session.pop("user_sub", None)
- request.session.pop("auth_provider_id", None)
+ # Clear the remain of the session
+ request.session.pop("oidc_provider_id", None)
return templates.TemplateResponse(
name="non_compliant_logout.html",
request=request,
- context={"auth_provider": provider, "home_url": request.url_for("home")},
+ context={"oidc_provider": oidc_provider, "home_url": request.url_for("home")},
)
-@app.get("/refresh")
-async def refresh(
+# Route for OAuth resource server
+
+
+@app.get("/resource/{id}")
+async def get_resource_(
+ id: str,
+ # user: Annotated[User, Depends(get_current_user)],
+ # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
+ # token: Annotated[OAuth2Token, Depends(get_token)],
+ user: Annotated[User, Depends(get_user_from_token)],
+) -> JSONResponse:
+ """Generic path for testing a resource provided by a provider"""
+ return JSONResponse(await get_resource(id, user))
+
+
+# Routes for RBAC based tests
+
+
+@app.get("/public")
+async def public() -> HTMLResponse:
+ return HTMLResponse("Not protected ")
+
+
+@app.get("/protected")
+async def get_protected(
+ user: Annotated[User, Depends(get_current_user)]
+) -> HTMLResponse:
+ assert user is not None # Just to keep QA checks happy
+ return HTMLResponse("Only authenticated users can see this ")
+
+
+@app.get("/protected-by-foorole")
+@hasrole("foorole")
+async def get_protected_by_foorole(request: Request) -> HTMLResponse:
+ assert request is not None # Just to keep QA checks happy
+ return HTMLResponse("Only users with foorole can see this ")
+
+
+@app.get("/protected-by-barrole")
+@hasrole("barrole")
+async def get_protected_by_barrole(request: Request) -> HTMLResponse:
+ assert request is not None # Just to keep QA checks happy
+ return HTMLResponse("Protected by barrole ")
+
+
+@app.get("/protected-by-foorole-and-barrole")
+@hasrole("barrole")
+@hasrole("foorole")
+async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
+ assert request is not None # Just to keep QA checks happy
+ return HTMLResponse("Only users with foorole and barrole can see this ")
+
+
+@app.get("/protected-by-foorole-or-barrole")
+@hasrole(["foorole", "barrole"])
+async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
+ assert request is not None # Just to keep QA checks happy
+ return HTMLResponse("Only users with foorole or barrole can see this ")
+
+
+@app.get("/introspect")
+async def get_introspect(
request: Request,
- provider: Annotated[Provider, Depends(get_auth_provider)],
- token: Annotated[OAuth2Token, Depends(get_token_from_session)],
-) -> RedirectResponse:
- """Manually refresh token"""
- new_token = await provider.authlib_client.fetch_access_token(
- refresh_token=token["refresh_token"],
- grant_type="refresh_token",
- )
- try:
- await update_token(provider.id, new_token)
- except PyJWTError as err:
- logger.info(f"Cannot refresh token: {err.__class__.__name__}")
+ oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
+ token: Annotated[OAuth2Token, Depends(get_token)],
+) -> JSONResponse:
+ assert request is not None # Just to keep QA checks happy
+ if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None:
raise HTTPException(
- status.HTTP_510_NOT_EXTENDED, f"Token refresh error: {err.__class__.__name__}"
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="No intrispection endpoint found for the OIDC provider",
)
- return RedirectResponse(url=request.url_for("home"))
+ if (
+ response := await oidc_provider.post(
+ url,
+ token=token,
+ data={"token": token["access_token"]},
+ )
+ ).is_success:
+ return response.json()
+ else:
+ raise HTTPException(status_code=response.status_code, detail=response.text)
# Snippet for running standalone
@@ -341,7 +374,9 @@ def main():
parser.add_argument(
"-p", "--port", type=int, default=80, help="Port to listen to (default: 80)"
)
- parser.add_argument("-v", "--version", action="store_true", help="Print version and exit")
+ parser.add_argument(
+ "-v", "--version", action="store_true", help="Print version and exit"
+ )
args = parser.parse_args()
if args.version:
diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py
index 7b6fd0e..542a9c4 100644
--- a/src/oidc_test/models.py
+++ b/src/oidc_test/models.py
@@ -1,6 +1,5 @@
-import logging
from functools import cached_property
-from typing import Any
+from typing import Self
from pydantic import (
computed_field,
@@ -8,10 +7,9 @@ from pydantic import (
EmailStr,
ConfigDict,
)
+from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from sqlmodel import SQLModel, Field
-logger = logging.getLogger("oidc-test")
-
class Role(SQLModel, extra="ignore"):
name: str
@@ -27,16 +25,26 @@ class UserBase(SQLModel, extra="ignore"):
class User(UserBase):
- model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore
+ model_config = ConfigDict(arbitrary_types_allowed=True)
sub: str = Field(
description="""subject id of the user given by the oidc provider,
also the key for the database 'table'""",
)
userinfo: dict = {}
- access_token: str | None = None
- access_token_decoded: dict[str, Any] | None = None
- auth_provider_id: str
+ access_token: dict = {}
+ oidc_provider: StarletteOAuth2App | None = None
+
+ @classmethod
+ def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self:
+ user = cls(**userinfo)
+ user.userinfo = userinfo
+ user.oidc_provider = oidc_provider
+ # Add roles if they are provided in the token
+ if raw_ra := userinfo.get("realm_access"):
+ if raw_roles := raw_ra.get("roles"):
+ user.roles = [Role(name=raw_role) for raw_role in raw_roles]
+ return user
@computed_field
@cached_property
@@ -46,21 +54,5 @@ class User(UserBase):
def has_scope(self, scope: str) -> bool:
"""Check if the scope is present in user info or access token"""
info_scopes = self.userinfo.get("scope", "").split(" ")
- try:
- access_token_scopes = self.decode_access_token().get("scope", "").split(" ")
- except Exception as err:
- logger.debug(f"Cannot find scope because the access token cannot be decoded: {err}")
- access_token_scopes = []
+ access_token_scopes = self.access_token.get("scope", "").split(" ")
return scope in set(info_scopes + access_token_scopes)
-
- def decode_access_token(self, verify_signature: bool = True):
- assert self.access_token is not None, "no access_token"
- assert self.auth_provider_id is not None, "no auth_provider_id"
- from .auth_providers import providers
-
- return providers[self.auth_provider_id].decode(
- self.access_token, verify_signature=verify_signature
- )
-
- def get_scope(self, verify_signature: bool = True):
- return self.decode_access_token(verify_signature=verify_signature)["scope"]
diff --git a/src/oidc_test/registry.py b/src/oidc_test/registry.py
deleted file mode 100644
index 3b91ad4..0000000
--- a/src/oidc_test/registry.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from importlib.metadata import entry_points
-import logging
-
-from pydantic import BaseModel, ConfigDict
-
-from oidc_test.models import User
-
-logger = logging.getLogger("registry")
-
-
-class ProcessResult(BaseModel):
- model_config = ConfigDict(
- extra="allow",
- )
-
-
-class ProcessError(Exception):
- pass
-
-
-class Resource(BaseModel):
- name: str
- scope_required: str | None = None
- role_required: str | None = None
- is_public: bool = False
- default_resource_id: str | None = None
-
- def __init__(self, name: str):
- super().__init__()
- self.__id__ = name
-
- async def process(self, user: User | None, resource_id: str | None = None) -> ProcessResult:
- logger.warning(f"{self.__id__} should define a process method")
- return ProcessResult()
-
-
-class ResourceRegistry(BaseModel):
- resources: dict[str, Resource] = {}
-
- def make_registry(self):
- for ep in entry_points().select(group="oidc_test.resource_provider"):
- ResourceClass = ep.load()
- if issubclass(ResourceClass, Resource):
- self.resources[ep.name] = ResourceClass(ep.name)
-
-
-registry = ResourceRegistry()
diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py
index ddc5762..ecaa597 100644
--- a/src/oidc_test/resource_server.py
+++ b/src/oidc_test/resource_server.py
@@ -1,310 +1,55 @@
-from typing import Annotated, Any
+from datetime import datetime
import logging
-from json import JSONDecodeError
-from authlib.oauth2.rfc6749 import OAuth2Token
-from httpx import AsyncClient, HTTPError
-from jwt.exceptions import DecodeError, ExpiredSignatureError, InvalidTokenError
-from fastapi import FastAPI, HTTPException, Depends, Request, status
-from fastapi.middleware.cors import CORSMiddleware
+from httpx import AsyncClient
-# from starlette.middleware.sessions import SessionMiddleware
-# from authlib.integrations.starlette_client.apps import StarletteOAuth2App
-# from authlib.oauth2.rfc6749 import OAuth2Token
+from .models import User
-from oidc_test.auth.provider import Provider
-from oidc_test.auth.utils import (
- get_user_from_token_or_none,
- oauth2_scheme_optional,
-)
-from oidc_test.auth_providers import providers
-from oidc_test.settings import ResourceProvider, settings
-from oidc_test.models import User
-from oidc_test.registry import ProcessError, ProcessResult, registry
-
-logger = logging.getLogger("oidc-test")
-
-resource_server = FastAPI()
+logger = logging.getLogger(__name__)
-resource_server.add_middleware(
- CORSMiddleware,
- allow_origins=settings.cors_origins,
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
-)
-
-# SessionMiddleware is required by authlib
-# resource_server.add_middleware(
-# SessionMiddleware,
-# secret_key=settings.secret_key,
-# )
-
-# Route for OAuth resource server
-
-
-# Routes for RBAC based tests
-
-
-@resource_server.get("/")
-async def resources() -> dict[str, dict[str, Any]]:
- return {"internal": {}, "plugins": registry.resources}
-
-
-@resource_server.get("/{resource_name}")
-@resource_server.get("/{resource_name}/{resource_id}")
-async def get_resource(
- resource_name: str,
- user: Annotated[User | None, Depends(get_user_from_token_or_none)],
- token: Annotated[OAuth2Token | None, Depends(oauth2_scheme_optional)],
- resource_id: str | None = None,
-):
- """Generic path for testing a resource provided by a provider.
- There's no field validation (response type of ProcessResult) on purpose,
- leaving the responsibility of the response validation to resource providers"""
- # Get the resource if it's defined in user auth provider's resources (external)
- if user is not None:
- provider = providers[user.auth_provider_id]
- if ":" in resource_name:
- # Third-party resource provider: send the request with the request token
- resource_provider_id, resource_name = resource_name.split(":", 1)
- provider = providers[user.auth_provider_id]
- resource_provider: ResourceProvider = provider.get_resource_provider(
- resource_provider_id
- )
- resource_url = resource_provider.get_resource_url(resource_name)
- async with AsyncClient(verify=resource_provider.verify_ssl) as client:
- try:
- logger.debug(f"GET request to {resource_url}")
- resp = await client.get(
- resource_url,
- headers={
- "Content-type": "application/json",
- "Authorization": f"Bearer {token}",
- "auth_provider": user.auth_provider_id,
- },
- )
- except HTTPError as err:
- raise HTTPException(
- status.HTTP_503_SERVICE_UNAVAILABLE, err.__class__.__name__
- )
- except Exception as err:
- raise HTTPException(
- status.HTTP_500_INTERNAL_SERVER_ERROR, err.__class__.__name__
- )
- else:
- if resp.is_success:
- return resp.json()
- else:
- reason_str: str
- try:
- reason_str = resp.json().get("detail", str(resp))
- except Exception:
- reason_str = str(resp.text)
- raise HTTPException(resp.status_code, reason_str)
- # Third party resource (provided through the auth provider)
- # The token is just passed on
- # XXX: is this branch valid anymore?
- if resource_name in [r.resource_name for r in provider.resources]:
- return await get_auth_provider_resource(
- provider=provider,
- resource_name=resource_name,
- token=token,
- user=user,
- )
- # Internal resource (provided here)
- if resource_name in registry.resources:
- resource = registry.resources[resource_name]
- reason: dict[str, str] = {}
- if not resource.is_public:
- if user is None:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public")
- else:
- if resource.scope_required is not None and not user.has_scope(
- resource.scope_required
- ):
- reason["scope"] = (
- f"No scope {resource.scope_required} in the access token "
- "but it is required for accessing this resource"
- )
- if (
- resource.role_required is not None
- and resource.role_required not in user.roles_as_set
- ):
- reason["role"] = (
- f"You don't have the role {resource.role_required} "
- "but it is required for accessing this resource"
- )
- if len(reason) == 0:
- try:
- resp = await resource.process(user=user, resource_id=resource_id)
- return resp
- except ProcessError as err:
- raise HTTPException(
- status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}"
- )
- else:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, ", ".join(reason.values()))
+async def get_resource(resource_id: str, user: User) -> dict:
+ """
+ Resource processing: build an informative rely as a simple showcase
+ """
+ pname = getattr(user.oidc_provider, "name", "?")
+ resp = {
+ "hello": f"Hi {user.name} from an OAuth resource provider",
+ "comment": f"I received a request for '{resource_id}' "
+ + f"with an access token signed by {pname}",
+ }
+ # For the demo, resource resource_id matches a scope get:resource_id,
+ # but this has to be refined for production
+ required_scope = f"get:{resource_id}"
+ # Check if the required scope is in the scopes allowed in userinfo
+ if user.has_scope(required_scope):
+ await process(user, resource_id, resp)
else:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Unknown resource")
- # return await get_resource_(resource_name, user, **request.query_params)
+ ## For the showcase, giving a explanation.
+ ## Alternatively, raise HTTP_401_UNAUTHORIZED
+ resp["sorry"] = (
+ f"No scope {required_scope} in the access token "
+ + "but it is required for accessing this resource."
+ )
+ return resp
-async def get_auth_provider_resource(
- provider: Provider, resource_name: str, token: OAuth2Token | None, user: User
-) -> ProcessResult:
- if token is None:
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth token")
- access_token = token
- async with AsyncClient() as client:
- resp = await client.get(
- url=provider.get_resource_url(resource_name),
- headers={
- "Content-type": "application/json",
- "Authorization": f"Bearer {access_token}",
- },
- )
- if resp.is_error:
- raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}")
- # Only a demo, real application would really process the response
- resp_length = len(resp.text)
- if resp_length > 1024:
- return ProcessResult(
- msg=f"The resource is too long ({resp_length} bytes) to show in this demo, here is just the begining in raw format",
- start=resp.text[:100] + "...",
- )
+async def process(user, resource_id, resp):
+ """
+ Too simple to be serious.
+ It's a good fit for a plugin architecture for production
+ """
+ assert user is not None
+ if resource_id == "time":
+ resp["time"] = datetime.now().strftime("%c")
+ elif resource_id == "bs":
+ async with AsyncClient() as client:
+ bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
+ resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
else:
- try:
- resp_json = resp.json()
- except JSONDecodeError:
- return ProcessResult(msg="The resource is not formatted in JSON", text=resp.text)
- if isinstance(resp_json, dict):
- return ProcessResult(**resp.json())
- elif isinstance(resp_json, list):
- return ProcessResult(**{str(i): line for i, line in enumerate(resp_json)})
+ resp["sorry"] = f"I don't known how to give '{resource_id}'."
-# @resource_server.get("/public")
-# async def public() -> dict:
-# return {"msg": "Not protected"}
-#
-#
-# @resource_server.get("/protected")
-# async def get_protected(user: Annotated[User, Depends(get_user_from_token)]):
-# assert user is not None # Just to keep QA checks happy
-# return {"msg": "Only authenticated users can see this"}
-#
-#
-# @resource_server.get("/protected-by-foorole")
-# async def get_protected_by_foorole(
-# user: Annotated[User, Depends(UserWithRole("foorole"))],
-# ):
-# assert user is not None
-# return {"msg": "Only users with foorole can see this"}
-#
-#
-# @resource_server.get("/protected-by-barrole")
-# async def get_protected_by_barrole(
-# user: Annotated[User, Depends(UserWithRole("barrole"))],
-# ):
-# assert user is not None
-# return {"msg": "Protected by barrole"}
-#
-#
-# @resource_server.get("/protected-by-foorole-and-barrole")
-# async def get_protected_by_foorole_and_barrole(
-# user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))],
-# ):
-# assert user is not None # Just to keep QA checks happy
-# return {"msg": "Only users with foorole and barrole can see this"}
-#
-#
-# @resource_server.get("/protected-by-foorole-or-barrole")
-# async def get_protected_by_foorole_or_barrole(
-# user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))],
-# ):
-# assert user is not None # Just to keep QA checks happy
-# return {"msg": "Only users with foorole or barrole can see this"}
-
-# async def get_resource_(resource_id: str, user: User, **kwargs) -> dict:
-# """
-# Resource processing: build an informative rely as a simple showcase
-# """
-# if resource_id == "petition":
-# return await sign(user, kwargs["petition_id"])
-# provider = providers[user.auth_provider_id]
-# try:
-# pname = provider.name
-# except KeyError:
-# pname = "?"
-# resp = {
-# "hello": f"Hi {user.name} from an OAuth resource provider",
-# "comment": f"I received a request for '{resource_id}' "
-# + f"with an access token signed by {pname}",
-# }
-# # For the demo, resource resource_id matches a scope get:resource_id,
-# # but this has to be refined for production
-# required_scope = f"get:{resource_id}"
-# # Check if the required scope is in the scopes allowed in userinfo
-# try:
-# if user.has_scope(required_scope):
-# await process(user, resource_id, resp)
-# else:
-# ## For the showcase, giving a explanation.
-# ## Alternatively, raise HTTP_401_UNAUTHORIZED
-# raise HTTPException(
-# status.HTTP_401_UNAUTHORIZED,
-# f"No scope {required_scope} in the access token "
-# + "but it is required for accessing this resource",
-# )
-# except ExpiredSignatureError:
-# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired")
-# except InvalidTokenError:
-# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid")
-# return resp
-
-
-# async def process(user, resource_id, resp):
-# """
-# Too simple to be serious.
-# It's a good fit for a plugin architecture for production
-# """
-# if resource_id == "time":
-# resp["time"] = datetime.now().strftime("%c")
-# elif resource_id == "bs":
-# async with AsyncClient() as client:
-# bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
-# resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
-# else:
-# raise HTTPException(
-# status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'."
-# )
-
-
-# @resource_server.get("/introspect")
-# async def get_introspect(
-# request: Request,
-# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
-# token: Annotated[OAuth2Token, Depends(get_token)],
-# ) -> JSONResponse:
-# assert request is not None # Just to keep QA checks happy
-# if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None:
-# raise HTTPException(
-# status_code=status.HTTP_401_UNAUTHORIZED,
-# detail="No introspection endpoint found for the OIDC provider",
-# )
-# if (
-# response := await oidc_provider.post(
-# url,
-# token=token,
-# data={"token": token["access_token"]},
-# )
-# ).is_success:
-# return response.json()
-# else:
-# raise HTTPException(status_code=response.status_code, detail=response.text)
-
# assert user.oidc_provider is not None
### Get some info (TODO: refactor)
# if (auth_provider_id := user.oidc_provider.name) is None:
diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py
index ad80c06..329b9c0 100644
--- a/src/oidc_test/settings.py
+++ b/src/oidc_test/settings.py
@@ -4,7 +4,7 @@ import random
from typing import Type, Tuple
from pathlib import Path
-from pydantic import AnyHttpUrl, BaseModel, computed_field, AnyUrl
+from pydantic import BaseModel, computed_field, AnyUrl
from pydantic_settings import (
BaseSettings,
SettingsConfigDict,
@@ -13,33 +13,18 @@ from pydantic_settings import (
)
from starlette.requests import Request
+from .models import User
+
class Resource(BaseModel):
"""A resource with an URL that can be accessed with an OAuth2 access token"""
- resource_name: str
- name: str
- url: str
-
-
-class ResourceProvider(BaseModel):
id: str
name: str
- base_url: AnyUrl
- resources: list[Resource] = []
- verify_ssl: bool = True
-
- def get_resource(self, resource_name: str) -> Resource:
- return [
- resource for resource in self.resources if resource.resource_name == resource_name
- ][0]
-
- def get_resource_url(self, resource_name: str) -> str:
- return f"{self.base_url}{self.get_resource(resource_name).url}"
-class AuthProviderSettings(BaseModel):
- """Auth provider, can also be a resource server"""
+class OIDCProvider(BaseModel):
+ """OIDC provider, can also be a resource server"""
id: str
name: str
@@ -54,14 +39,10 @@ class AuthProviderSettings(BaseModel):
info_url: str | None = (
None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
)
+ info: dict[str, str | int] | None = (
+ None # Info fetched from info_url, eg. public key
+ )
public_key: str | None = None
- public_key_url: str | None = None
- signature_alg: str = "RS256"
- resource_provider_scopes: list[str] = []
- session_key: str = "sid"
- skip_verify_signature: bool = True
- disabled: bool = False
- resource_providers: list[ResourceProvider] = []
@computed_field
@property
@@ -73,20 +54,46 @@ class AuthProviderSettings(BaseModel):
def token_url(self) -> str:
return "auth/" + self.id
- def get_account_url(self, request: Request, user: dict) -> str | None:
+ def get_account_url(self, request: Request, user: User) -> str | None:
if self.account_url_template:
- if not (self.url.endswith("/") or self.account_url_template.startswith("/")):
+ if not (
+ self.url.endswith("/") or self.account_url_template.startswith("/")
+ ):
sep = "/"
else:
sep = ""
- return self.url + sep + self.account_url_template.format(request=request, user=user)
+ return (
+ self.url
+ + sep
+ + self.account_url_template.format(request=request, user=user)
+ )
else:
return None
+ def get_public_key(self) -> str | None:
+ """Return the public key formatted for decoding token"""
+ public_key = self.public_key or (
+ self.info is not None and self.info["public_key"]
+ )
+ if public_key is None:
+ return None
+ return f"""
+ -----BEGIN PUBLIC KEY-----
+ {public_key}
+ -----END PUBLIC KEY-----
+ """
-class AuthSettings(BaseModel):
+
+class ResourceProvider(BaseModel):
+ id: str
+ name: str
+ base_url: AnyUrl
+ resources: list[Resource] = []
+
+
+class OIDCSettings(BaseModel):
show_session_details: bool = False
- providers: list[AuthProviderSettings] = []
+ providers: list[OIDCProvider] = []
swagger_provider: str = ""
@@ -101,15 +108,12 @@ class Settings(BaseSettings):
model_config = SettingsConfigDict(env_nested_delimiter="__")
- auth: AuthSettings = AuthSettings()
+ oidc: OIDCSettings = OIDCSettings()
+ resource_providers: list[ResourceProvider] = []
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
log: bool = False
- log_config_file: str = "log_conf.yaml"
insecure: Insecure = Insecure()
cors_origins: list[str] = []
- debug_token: bool = False
- show_token: bool = False
- show_external_resource_providers_links: bool = False
@classmethod
def settings_customise_sources(
@@ -128,7 +132,9 @@ class Settings(BaseSettings):
settings_cls,
Path(
Path(
- environ.get("OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml"),
+ environ.get(
+ "OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml"
+ ),
)
),
),
diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css
index 1e8dc03..a4a0178 100644
--- a/src/oidc_test/static/styles.css
+++ b/src/oidc_test/static/styles.css
@@ -3,9 +3,9 @@ body {
background-color: floralwhite;
margin: 0;
font-family: system-ui;
- text-align: center;
}
h1 {
+ text-align: center;
background-color: #f7c7867d;
margin: 0 0 0.2em 0;
box-shadow: 0px 0.2em 0.2em #f7c7867d;
@@ -21,18 +21,9 @@ hr {
.hidden {
display: none;
}
-.version {
- position: absolute;
- font-size: 75%;
- top: 0.3em;
- right: 0.3em;
-}
.center {
text-align: center;
}
-.error {
- color: darkred;
-}
.content {
width: 100%;
display: flex;
@@ -64,6 +55,7 @@ hr {
border: 2px solid darkkhaki;
padding: 3px 6px;
text-decoration: none;
+ text-align: center;
color: black;
}
.user-info a.logout:hover {
@@ -78,6 +70,7 @@ hr {
margin: 0;
}
.debug-auth p {
+ text-align: center;
border-bottom: 1px solid black;
}
.debug-auth ul {
@@ -108,25 +101,15 @@ hr {
.hasResponseStatus.status-503 {
background-color: #ffA88050;
}
-
-.role, .scope {
- padding: 3px 6px;
- margin: 3px;
- border-radius: 6px;
-}
-
.role {
+ padding: 3px 6px;
background-color: #44228840;
}
-.scope {
- background-color: #8888FF80;
-}
-
-
/* For home */
.login-box {
+ text-align: center;
background-color: antiquewhite;
margin: 0.5em auto;
width: fit-content;
@@ -148,92 +131,41 @@ hr {
.providers .provider {
min-height: 2em;
}
-.providers .provider .link {
+.providers .provider a.link {
text-decoration: none;
max-height: 2em;
}
-.providers .provider .link {
+.providers .provider .link div {
+ text-align: center;
background-color: #f7c7867d;
border-radius: 8px;
padding: 6px;
text-align: center;
color: black;
- font-weight: 400;
+ font-weight: bold;
cursor: pointer;
- border: 0;
- width: 100%;
}
-
-.providers .provider .link.disabled {
- color: gray;
- cursor: not-allowed;
-}
-
.providers .provider .hint {
font-size: 80%;
max-width: 13em;
}
.providers .error {
padding: 3px 6px;
+ text-align: center;
font-weight: bold;
flex: 1 1 auto;
}
.content .links-to-check {
display: flex;
+ text-align: center;
justify-content: center;
gap: 0.5em;
flex-flow: wrap;
}
-.content .links-to-check button {
+.content .links-to-check a {
color: black;
padding: 5px 10px;
text-decoration: none;
border-radius: 8px;
- border: none;
- cursor: pointer;
}
-.token {
- overflow-wrap: anywhere;
- font-family: monospace;
-}
-
-.resourceResult {
- padding: 0.5em;
- display: flex;
- gap: 0.5em;
- width: fit-content;
- align-items: center;
- margin: 5px auto;
- box-shadow: 0px 0px 10px #90c3eeA0;
- background-color: #90c3eeA0;
- border-radius: 8px;
-}
-
-.resources {
- display: flex;
-}
-
-.resource {
- text-align: center;
-}
-
-.token-info {
- margin: 0 1em;
-}
-
-.key {
- font-weight: bold;
-}
-
-.token .key, .token .value {
- display: inline;
-}
-.token .value {
- padding-left: 1em;
-}
-
-.msg {
- text-align: center;
- font-weight: bold;
-}
diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js
index e988dfe..6b40d3d 100644
--- a/src/oidc_test/static/utils.js
+++ b/src/oidc_test/static/utils.js
@@ -1,90 +1,19 @@
-async function checkHref(elem, token, authProvider) {
- const msg = document.getElementById("msg")
- const resourceName = elem.getAttribute("resource-name")
- const resourceId = elem.getAttribute("resource-id")
- const resourceProviderId = elem.getAttribute("resource-provider-id") ? elem.getAttribute("resource-provider-id") : ""
- const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName
- const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}`
- const resp = await fetch(url, {
- method: "GET",
- headers: new Headers({
- "Content-type": "application/json",
- "Authorization": `Bearer ${token}`,
- "auth_provider": authProvider,
- }),
- }).catch(err => {
- msg.innerHTML = "Cannot fetch resource: " + err.message
- resourceElem.innerHTML = ""
- })
- if (resp === undefined) {
- return
- } else {
- elem.classList.add("hasResponseStatus")
- elem.classList.add("status-" + resp.status)
- elem.title = "Response code: " + resp.status + " - " + resp.statusText
+function checkHref(elem) {
+ var xmlHttp = new XMLHttpRequest()
+ xmlHttp.onreadystatechange = function () {
+ if (xmlHttp.readyState == 4) {
+ elem.classList.add("hasResponseStatus")
+ elem.classList.add("status-" + xmlHttp.status)
+ elem.title = "Response code: " + xmlHttp.status + " - " + xmlHttp.statusText
+ }
}
+ xmlHttp.open("GET", elem.href, true) // true for asynchronous
+ xmlHttp.send(null)
}
-function checkPerms(className, token, authProvider) {
+function checkPerms(className) {
var rootElems = document.getElementsByClassName(className)
Array.from(rootElems).forEach(elem =>
- Array.from(elem.children).forEach(elem => checkHref(elem, token, authProvider))
- )
-}
-
-async function get_resource(resourceName, token, authProvider, resourceId, resourceProviderId) {
- // BaseUrl for an external resource provider
- //if (!keycloak.keycloak) { return }
- const msg = document.getElementById("msg")
- const resourceElem = document.getElementById('resource')
- const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName
- const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}`
- const resp = await fetch(url, {
- method: "GET",
- headers: new Headers({
- "Content-type": "application/json",
- "Authorization": `Bearer ${token}`,
- "auth_provider": authProvider,
- }),
- }).catch(err => {
- msg.innerHTML = "Cannot fetch resource: " + err.message
- resourceElem.innerHTML = ""
- })
- if (resp === undefined) {
- return
- }
- const resource = await resp.json()
- if (!resp.ok) {
- msg.innerHTML = resource["detail"]
- resourceElem.innerHTML = ""
- return
- }
- msg.innerHTML = ""
- resourceElem.innerHTML = ""
- Object.entries(resource).forEach(
- ([key, value]) => {
- let r = document.createElement('div')
- let kElem = document.createElement('div')
- kElem.innerText = key
- kElem.className = "key"
- let vElem = document.createElement('div')
- if (typeof value == "object") {
- Object.entries(value).forEach(v => {
- const ne = document.createElement('div')
- ne.innerHTML = `${v[0]} : ${v[1]} `
- vElem.appendChild(ne)
- })
- }
- else {
- vElem.innerText = value
- }
- vElem.className = "value"
- if (key == "sorry") {
- vElem.classList.add("error")
- }
- r.appendChild(kElem)
- r.appendChild(vElem)
- resourceElem.appendChild(r)
- }
+ Array.from(elem.children).forEach(elem => checkHref(elem))
)
}
diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html
index 157e26f..3bdb3f3 100644
--- a/src/oidc_test/templates/base.html
+++ b/src/oidc_test/templates/base.html
@@ -4,8 +4,7 @@
-
- v. {{ __version__}}
+
OIDC-test - FastAPI client
{% block content %}
{% endblock %}
diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html
index 167616f..c062101 100644
--- a/src/oidc_test/templates/home.html
+++ b/src/oidc_test/templates/home.html
@@ -5,38 +5,31 @@
with OpenID Connect and OAuth2 with different providers.
{% if not user %}
-
-
Log in with:
-
- {% for provider in providers.values() %}
-
-
-
- {{ provider.name }}
-
-
- {{ provider.hint }}
-
-
- {% else %}
- There is no authentication provider defined.
- Hint: check the settings.yaml file.
- {% endfor %}
-
-
- {% else %}
+
+
Log in with:
+
+ {% for provider in settings.oidc.providers %}
+
+
+ {{ provider.name }}
+
+ {{ provider.hint }}
+
+
+ {% else %}
+ There is no authentication provider defined.
+ Hint: check the settings.yaml file.
+ {% endfor %}
+
+
+ {% endif %}
+ {% if user %}
Hey, {{ user.name }}
{% if user.picture %}
{% endif %}
{{ user.email }}
-
- Provider:
- {{ auth_provider.name }}
-
{% if user.roles %}
Roles:
@@ -45,125 +38,54 @@
{% endfor %}
{% endif %}
- {% if access_token_scope %}
-
- Scopes :
- {% for scope in access_token_scope.split(' ') %}
- {{ scope }}
- {% endfor %}
-
+
+ Provider:
+ {{ oidc_provider_settings.name }}
+
+ {% if oidc_provider_settings.account_url_template %}
+
Account management
{% endif %}
- {% if auth_provider.account_url_template %}
-
- Account management
-
- {% endif %}
-
Refresh access token
Logout
{% endif %}
- {% if resources %}
-
This application provides all these resources, eventually protected with scope or roles:
-
- {% for name, resource in resources.items() %}
- {% if resource.default_resource_id %}
-
- {{ resource.name }}
-
- {% else %}
-
- {{ resource.name }}
-
- {% endif %}
- {% endfor %}
-
- {% endif %}
- {% if auth_provider.resources %}
-
{{ auth_provider.name }} is also defined as a provider for these resources:
-
- {% for resource in auth_provider.resources %}
- {% if resource.default_resource_id %}
-
- {{ resource.name }}
-
- {% else %}
-
- {{ resource.name }}
-
- {% endif %}
- {% endfor %}
-
- {% endif %}
- {% if resource_providers %}
-
{{ auth_provider.name }} allows this application to request resources from third party resource providers:
- {% for resource_provider in resource_providers %}
-
- {{ resource_provider.name }}
- {% for resource in resource_provider.resources %}
-
- {{ resource.name }}
-
- {% endfor %}
-
- {% endfor %}
- {% endif %}
-
-
-
+
+ These links should get different response codes depending on the authorization:
+
+
-
- {% if show_token and id_token_parsed %}
-
-
-
-
id token
-
- {% for key, value in id_token_parsed.items() %}
-
-
{{ key }}
-
{{ value }}
-
+ {% if resources %}
+
+ Resources for this provider:
+
+
-
access token
-
- {% for key, value in access_token_parsed.items() %}
-
-
{{ key }}
-
{{ value }}
-
- {% endfor %}
-
-
refresh token
-
- {% for key, value in refresh_token_parsed.items() %}
-
-
{{ key }}
-
{{ value }}
-
- {% endfor %}
-
+ {% endif %}
+ {% if user_info_details %}
+
+
+
User info
+
+ {% for key, value in user_info_details.items() %}
+
+ {{ key }} : {{ value }}
+
+ {% endfor %}
+
+
+
Now is: {{ now.strftime("%T, %D") }}
{% endif %}
{% endblock %}
diff --git a/src/oidc_test/templates/non_compliant_logout.html b/src/oidc_test/templates/non_compliant_logout.html
index 56758de..24a96ae 100644
--- a/src/oidc_test/templates/non_compliant_logout.html
+++ b/src/oidc_test/templates/non_compliant_logout.html
@@ -6,12 +6,12 @@
authorisation to log in again without asking for credentials.
- This is because {{ auth_provider.name }} does not provide "end_session_endpoint" in its metadata
- (see: {{ auth_provider.authlib_client._server_metadata_url }} ).
+ This is because {{ oidc_provider.name }} does not provide "end_session_endpoint" in its metadata
+ (see: {{ oidc_provider._server_metadata_url }} ).
You can just also go back to the application home page , but
- it recommended to go to the OIDC provider's site
+ it recommended to go to the OIDC provider's site
and log out explicitely from there.
{% endblock %}
diff --git a/uv.lock b/uv.lock
index 0566bb5..01b64de 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,4 @@
version = 1
-revision = 1
requires-python = ">=3.13"
[[package]]
@@ -207,18 +206,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632 },
]
-[[package]]
-name = "dunamai"
-version = "1.23.0"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
- { name = "packaging" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/06/4e/a5c8c337a1d9ac0384298ade02d322741fb5998041a5ea74d1cd2a4a1d47/dunamai-1.23.0.tar.gz", hash = "sha256:a163746de7ea5acb6dacdab3a6ad621ebc612ed1e528aaa8beedb8887fccd2c4", size = 44681 }
-wheels = [
- { url = "https://files.pythonhosted.org/packages/21/4c/963169386309fec4f96fd61210ac0a0666887d0fb0a50205395674d20b71/dunamai-1.23.0-py3-none-any.whl", hash = "sha256:a0906d876e92441793c6a423e16a4802752e723e9c9a5aabdc5535df02dbe041", size = 26342 },
-]
-
[[package]]
name = "ecdsa"
version = "0.19.0"
@@ -495,6 +482,7 @@ wheels = [
[[package]]
name = "oidc-fastapi-test"
+version = "0.0.0"
source = { editable = "." }
dependencies = [
{ name = "authlib" },
@@ -513,7 +501,6 @@ dependencies = [
[package.dev-dependencies]
dev = [
- { name = "dunamai" },
{ name = "ipdb" },
{ name = "pytest" },
]
@@ -536,7 +523,6 @@ requires-dist = [
[package.metadata.requires-dev]
dev = [
- { name = "dunamai", specifier = ">=1.23.0" },
{ name = "ipdb", specifier = ">=0.13.13" },
{ name = "pytest", specifier = ">=8.3.4" },
]