Compare commits

..

3 commits
main ... db

Author SHA1 Message Date
2051addfc2 Merge branch 'db'
Some checks failed
/ build (push) Failing after 5s
/ test (push) Failing after 5s
2025-02-17 18:48:28 +01:00
2534c1cbb4 Add postgres db (messy) 2025-02-17 17:27:12 +01:00
fb433e27be Add postgres db (messy)
Some checks failed
/ build (push) Failing after 6s
/ test (push) Failing after 6s
2025-02-17 02:42:38 +01:00
20 changed files with 384 additions and 440 deletions

View file

@ -19,7 +19,7 @@ jobs:
- name: Install the latest version of uv - name: Install the latest version of uv
uses: astral-sh/setup-uv@v4 uses: astral-sh/setup-uv@v4
with: with:
version: "0.6.9" version: "0.5.16"
- name: Install - name: Install
run: uv sync run: uv sync
@ -27,26 +27,34 @@ jobs:
- name: Run tests (API call) - name: Run tests (API call)
run: .venv/bin/pytest -s tests/basic.py run: .venv/bin/pytest -s tests/basic.py
- name: Get version - name: Get version with git describe
run: echo "VERSION=$(.venv/bin/dunamai from any --style semver)" >> $GITHUB_ENV id: version
run: |
echo "version=$(git describe)" >> $GITHUB_OUTPUT
echo "$VERSION"
- name: Version - name: Check if the container should be built
run: echo $VERSION id: builder
env:
RUN: ${{ toJSON(inputs.build || !contains(steps.version.outputs.version, '-')) }}
run: |
echo "run=$RUN" >> $GITHUB_OUTPUT
echo "Run build: $RUN"
- name: Get distance from tag - name: Set the version in pyproject.toml (workaround for uv not supporting dynamic version)
run: echo "DISTANCE=$(.venv/bin/dunamai from any --format '{distance}')" >> $GITHUB_ENV if: fromJSON(steps.builder.outputs.run)
env:
- name: Distance VERSION: ${{ steps.version.outputs.version }}
run: echo $DISTANCE run: sed "s/0.0.0/$VERSION/" -i pyproject.toml
- name: Workaround for bug of podman-login - name: Workaround for bug of podman-login
if: env.DISTANCE == '0' if: fromJSON(steps.builder.outputs.run)
run: | run: |
mkdir -p $HOME/.docker mkdir -p $HOME/.docker
echo "{ \"auths\": {} }" > $HOME/.docker/config.json echo "{ \"auths\": {} }" > $HOME/.docker/config.json
- name: Log in to the container registry (with another workaround) - name: Log in to the container registry (with another workaround)
if: env.DISTANCE == '0' if: fromJSON(steps.builder.outputs.run)
uses: actions/podman-login@v1 uses: actions/podman-login@v1
with: with:
registry: ${{ vars.REGISTRY }} registry: ${{ vars.REGISTRY }}
@ -55,30 +63,30 @@ jobs:
auth_file_path: /tmp/auth.json auth_file_path: /tmp/auth.json
- name: Build the container image - name: Build the container image
if: env.DISTANCE == '0' if: fromJSON(steps.builder.outputs.run)
uses: actions/buildah-build@v1 uses: actions/buildah-build@v1
with: with:
image: oidc-fastapi-test image: oidc-fastapi-test
oci: true oci: true
labels: oidc-fastapi-test labels: oidc-fastapi-test
tags: "latest ${{ env.VERSION }}" tags: latest ${{ steps.version.outputs.version }}
containerfiles: | containerfiles: |
./Containerfile ./Containerfile
- name: Push the image to the registry - name: Push the image to the registry
if: env.DISTANCE == '0' if: fromJSON(steps.builder.outputs.run)
uses: actions/push-to-registry@v2 uses: actions/push-to-registry@v2
with: with:
registry: "docker://${{ vars.REGISTRY }}/${{ vars.ORGANISATION }}" registry: "docker://${{ vars.REGISTRY }}/${{ vars.ORGANISATION }}"
image: oidc-fastapi-test image: oidc-fastapi-test
tags: "latest ${{ env.VERSION }}" tags: latest ${{ steps.version.outputs.version }}
- name: Build wheel - name: Build wheel
if: env.DISTANCE == '0' if: fromJSON(steps.builder.outputs.run)
run: uv build --wheel run: uv build --wheel
- name: Publish Python package (home) - name: Publish Python package (home)
if: env.DISTANCE == '0' if: fromJSON(steps.builder.outputs.run)
env: env:
LOCAL_PYPI_TOKEN: ${{ secrets.LOCAL_PYPI_TOKEN }} LOCAL_PYPI_TOKEN: ${{ secrets.LOCAL_PYPI_TOKEN }}
run: uv publish --publish-url https://code.philo.ydns.eu/api/packages/philorg/pypi --token $LOCAL_PYPI_TOKEN run: uv publish --publish-url https://code.philo.ydns.eu/api/packages/philorg/pypi --token $LOCAL_PYPI_TOKEN

View file

@ -19,7 +19,7 @@ jobs:
- name: Install the latest version of uv - name: Install the latest version of uv
uses: astral-sh/setup-uv@v4 uses: astral-sh/setup-uv@v4
with: with:
version: "0.6.3" version: "0.5.16"
- name: Install - name: Install
run: uv sync run: uv sync

View file

@ -1,4 +1,4 @@
FROM docker.io/library/python:latest FROM docker.io/library/python:alpine
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/

View file

@ -52,59 +52,31 @@ given by the OIDC providers.
For example: For example:
```yaml ```yaml
secret_key: AVeryWellKeptSecret oidc:
debug_token: no secret_key: "ASecretNoOneKnows"
show_token: yes show_session_details: yes
log: yes
auth:
providers: providers:
- id: auth0 - id: auth0
name: Okta / Auth0 name: Okta / Auth0
url: https://<your_auth0_app_URL> url: "https://<your_auth0_app_URL>"
public_key_url: https://<your_auth0_app_URL>/pem client_id: "<your_auth0_client_id>"
client_id: <your_auth0_client_id> client_secret: "client_secret_generated_by_auth0"
client_secret: client_secret_generated_by_auth0 hint: "A hint for test credentials"
hint: A hint for test credentials
- id: keycloak - id: keycloak
name: Keycloak at somewhere name: Keycloak at somewhere
url: https://<the_keycloak_realm_url> url: "https://<the_keycloak_realm_url>"
info_url: https://philo.ydns.eu/auth/realms/test account_url_template: "/account"
account_url_template: /account client_id: "<your_keycloak_client_id>"
client_id: <your_keycloak_client_id> client_secret: "client_secret_generated_by_keycloak"
client_secret: <client_secret_generated_by_keycloak> hint: "User: foo, password: foofoo"
hint: A hint for test credentials
code_challenge_method: S256
resource_provider_scopes:
- get:time
- get:bs
resource_providers:
- id: <third_party_resource_provider_id>
name: A third party resource provider
base_url: https://some.example.com/
verify_ssl: yes
resources:
- name: Public RS2
resource_name: public
url: resource/public
- name: BS RS2
resource_name: bs
url: resource/bs
- name: Time RS2
resource_name: time
url: resource/time
- id: codeberg - id: codeberg
disabled: no
name: Codeberg name: Codeberg
url: https://codeberg.org url: "https://codeberg.org"
account_url_template: /user/settings account_url_template: "/user/settings"
client_id: <your_codeberg_client_id> client_id: "<your_codeberg_client_id>"
client_secret: client_secret_generated_by_codeberg client_secret: "client_secret_generated_by_codeberg"
info_url: https://codeberg.org/login/oauth/keys
session_key: sub
skip_verify_signature: no
resources: resources:
- name: List of repos - name: List of repos
id: repos id: repos

View file

@ -1,14 +1,16 @@
[project] [project]
name = "oidc-fastapi-test" name = "oidc-fastapi-test"
#version = "0.0.0" version = "0.0.0"
dynamic = ["version"] # dynamic = ["version"]
description = "Add your description here" description = "Add your description here"
readme = "README.md" readme = "README.md"
requires-python = ">=3.13" requires-python = ">=3.13"
dependencies = [ dependencies = [
"asyncpg>=0.30.0",
"authlib>=1.4.0", "authlib>=1.4.0",
"cachetools>=5.5.0", "cachetools>=5.5.0",
"fastapi[standard]>=0.115.6", "fastapi[standard]>=0.115.6",
"greenlet>=3.1.1",
"httpx>=0.28.1", "httpx>=0.28.1",
"itsdangerous>=2.2.0", "itsdangerous>=2.2.0",
"passlib[bcrypt]>=1.7.4", "passlib[bcrypt]>=1.7.4",
@ -24,21 +26,14 @@ dependencies = [
oidc-test = "oidc_test.main:main" oidc-test = "oidc_test.main:main"
[dependency-groups] [dependency-groups]
dev = ["dunamai>=1.23.0", "ipdb>=0.13.13", "pytest>=8.3.4"] dev = ["ipdb>=0.13.13", "pytest>=8.3.4"]
[build-system] [build-system]
requires = ["hatchling", "uv-dynamic-versioning"] requires = ["hatchling"]
build-backend = "hatchling.build" build-backend = "hatchling.build"
[tool.hatch.version]
source = "uv-dynamic-versioning"
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]
packages = ["src/oidc_test"] packages = ["src/oidc_test"]
package = true
[tool.uv-dynamic-versioning]
style = "semver"
[tool.uv] [tool.uv]
package = true package = true

View file

@ -1,11 +0,0 @@
import importlib.metadata
try:
from dunamai import Version, Style
__version__ = Version.from_git().serialize(style=Style.SemVer, dirty=True)
except ImportError:
# __name__ could be used if the package name is the same
# as the directory - not the case here
# __version__ = importlib.metadata.version(__name__)
__version__ = importlib.metadata.version("oidc-fastapi-test")

View file

@ -7,7 +7,7 @@ from pydantic import ConfigDict
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from httpx import AsyncClient from httpx import AsyncClient
from oidc_test.settings import AuthProviderSettings, ResourceProvider, Resource, settings from oidc_test.settings import AuthProviderSettings, settings
from oidc_test.models import User from oidc_test.models import User
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
@ -24,7 +24,6 @@ class Provider(AuthProviderSettings):
authlib_client: StarletteOAuth2App = StarletteOAuth2App(None) authlib_client: StarletteOAuth2App = StarletteOAuth2App(None)
info: dict[str, Any] = {} info: dict[str, Any] = {}
unknown_auth_user: User unknown_auth_user: User
logout_with_id_token_hint: bool = True
def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]: def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]:
"""Decode the token with signature check""" """Decode the token with signature check"""
@ -61,34 +60,28 @@ class Provider(AuthProviderSettings):
if self.info_url is not None: if self.info_url is not None:
try: try:
provider_info = await client.get(self.info_url) provider_info = await client.get(self.info_url)
except Exception as err: except Exception:
logger.debug("Provider_info: cannot connect")
logger.exception(err)
raise NoPublicKey raise NoPublicKey
try: try:
self.info = provider_info.json() self.info = provider_info.json()
except JSONDecodeError: except JSONDecodeError:
logger.debug("Provider_info: cannot decode json response")
raise NoPublicKey raise NoPublicKey
if "public_key" in self.info: if "public_key" in self.info:
# For Keycloak # For Keycloak
try: try:
public_key = str(self.info["public_key"]) public_key = str(self.info["public_key"])
except KeyError: except KeyError:
logger.debug("Provider_info: cannot get public_key")
raise NoPublicKey raise NoPublicKey
elif "keys" in self.info: elif "keys" in self.info:
# For Forgejo/Gitea # For Forgejo/Gitea
try: try:
public_key = str(self.info["keys"][0]["n"]) public_key = str(self.info["keys"][0]["n"])
except KeyError: except KeyError:
logger.debug("Provider_info: cannot get key 0.n")
raise NoPublicKey raise NoPublicKey
if self.public_key_url is not None: if self.public_key_url is not None:
resp = await client.get(self.public_key_url) resp = await client.get(self.public_key_url)
public_key = resp.text public_key = resp.text
if public_key is None: if public_key is None:
logger.debug("Provider_info: cannot determine public key")
raise NoPublicKey raise NoPublicKey
self.public_key = "\n".join( self.public_key = "\n".join(
["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"] ["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"]
@ -96,18 +89,3 @@ class Provider(AuthProviderSettings):
def get_session_key(self, userinfo): def get_session_key(self, userinfo):
return userinfo[self.session_key] return userinfo[self.session_key]
def get_resource(self, resource_name: str) -> Resource:
return [
resource for resource in self.resources if resource.resource_name == resource_name
][0]
def get_resource_url(self, resource_name: str) -> str:
return self.url + self.get_resource(resource_name).url
def get_resource_provider(self, resource_provider_id: str) -> ResourceProvider:
return [
provider
for provider in self.resource_providers
if provider.id == resource_provider_id
][0]

View file

@ -4,13 +4,16 @@ import logging
from fastapi import HTTPException, Request, Depends, status from fastapi import HTTPException, Request, Depends, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from sqlmodel.ext.asyncio.session import AsyncSession
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
from authlib.oauth2.rfc6749 import OAuth2Token
from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError
# from authlib.oauth1.auth import OAuthToken
from authlib.oauth2.rfc6749 import OAuth2Token
from oidc_test.auth.provider import Provider from oidc_test.auth.provider import Provider
from oidc_test.models import User from oidc_test.models import User
from oidc_test.database import db, TokenNotInDb, UserNotInDB from oidc_test.database import db, get_db_session, TokenNotInDb, UserNotInDB
from oidc_test.settings import settings from oidc_test.settings import settings
from oidc_test.auth_providers import providers from oidc_test.auth_providers import providers
@ -20,7 +23,7 @@ logger = logging.getLogger("oidc-test")
async def fetch_token(name, request): async def fetch_token(name, request):
assert name is not None assert name is not None
assert request is not None assert request is not None
logger.warning("TODO: fetch_token") logger.warn("TODO: fetch_token")
... ...
# if name in oidc_providers: # if name in oidc_providers:
# model = OAuth2Token # model = OAuth2Token
@ -32,10 +35,7 @@ async def fetch_token(name, request):
async def update_token( async def update_token(
provider_id, provider_id, token, refresh_token: str | None = None, access_token: str | None = None
token,
refresh_token: str | None = None,
access_token: str | None = None,
): ):
"""Update the token in the database""" """Update the token in the database"""
provider = providers[provider_id] provider = providers[provider_id]
@ -58,7 +58,11 @@ def init_providers():
provider_settings_dict = provider_settings.model_dump() provider_settings_dict = provider_settings.model_dump()
# Add an anonymous user, that cannot be identified but has provided a valid access token # Add an anonymous user, that cannot be identified but has provided a valid access token
provider_settings_dict["unknown_auth_user"] = User( provider_settings_dict["unknown_auth_user"] = User(
sub="", auth_provider_id=provider_settings.id sub="",
auth_provider_id=provider_settings.id,
roles=[],
userinfo={},
access_token_decoded={},
) )
provider = Provider(**provider_settings_dict) provider = Provider(**provider_settings_dict)
if provider.disabled: if provider.disabled:
@ -120,17 +124,32 @@ def get_auth_provider(request: Request) -> Provider:
return provider return provider
async def get_current_user(request: Request) -> User: async def get_current_user(
request: Request,
db_session: Annotated[AsyncSession, Depends(get_db_session)],
) -> User:
"""Return the user from the request's session.
It can be used in Depends()"""
if user := await get_current_user_or_none(request, db_session):
return user
else:
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
async def get_current_user_or_none(
request: Request,
db_session: Annotated[AsyncSession, Depends(get_db_session)],
) -> User | None:
"""Get the current user from a request object. """Get the current user from a request object.
Also validates the token expiration time. Also validates the token expiration time.
... TODO: complete about refresh token ... TODO: complete about refresh token
""" """
if (user_sub := request.session.get("user_sub")) is None: if (user_sub := request.session.get("user_sub")) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED) return None
token = await get_token_from_session(request) token = await get_token_from_session_or_none(request, db_session)
user = await db.get_user(user_sub)
## Check if the token is expired ## Check if the token is expired
if token.is_expired(): breakpoint()
if token is not None and token.is_expired():
provider = get_auth_provider(request=request) provider = get_auth_provider(request=request)
## Ask a new refresh token from the provider ## Ask a new refresh token from the provider
logger.info(f"Token expired for user {user.name}") logger.info(f"Token expired for user {user.name}")
@ -144,20 +163,28 @@ async def get_current_user(request: Request) -> User:
# raise HTTPException( # raise HTTPException(
# status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh" # status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh"
# ) # )
user = await db.get_or_add_user(
user_sub, db_session=db_session, auth_provider=provider, token=token
)
return user return user
async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None: async def get_token_from_session_or_none(
request: Request,
db_session: Annotated[AsyncSession, Depends(get_db_session)],
) -> OAuth2Token | None:
"""Return the auth token from the session or None. """Return the auth token from the session or None.
Can be used in Depends()""" Can be used in Depends()"""
try: try:
return await get_token_from_session(request) return await get_token_from_session(request, db_session)
except HTTPException: except HTTPException:
return None return None
async def get_token_from_session(request: Request) -> OAuth2Token: async def get_token_from_session(
request: Request,
db_session: Annotated[AsyncSession, Depends(get_db_session)],
) -> OAuth2Token:
"""Return the token from the session. """Return the token from the session.
Can be used in Depends()""" Can be used in Depends()"""
try: try:
@ -167,60 +194,15 @@ async def get_token_from_session(request: Request) -> OAuth2Token:
request.session.pop("user_sub", None) request.session.pop("user_sub", None)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider")
try: try:
return await db.get_token( return await db.get_token(provider, request.session.get("sid"), db_session)
provider,
request.session.get("sid"),
)
except (TokenNotInDb, InvalidKeyError, DecodeError) as err: except (TokenNotInDb, InvalidKeyError, DecodeError) as err:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, err.__class__.__name__) raise HTTPException(status.HTTP_401_UNAUTHORIZED, err.__class__.__name__)
async def get_current_user_or_none(request: Request) -> User | None:
"""Return the user from a request object, from the session.
It can be used in Depends()"""
try:
return await get_current_user(request)
except HTTPException:
return None
def hasrole(required_roles: Union[str, list[str]] = []):
"""Decorator for RBAC permissions"""
required_roles_set: set[str]
if isinstance(required_roles, str):
required_roles_set = set([required_roles])
else:
required_roles_set = set(required_roles)
def decorator(func):
@wraps(func)
async def wrapper(request=None, *args, **kwargs):
if request is None:
raise HTTPException(
500,
"Functions decorated with hasrole must have a request:Request argument",
)
user: User = await get_current_user(request)
if not any(required_roles_set.intersection(user.roles_as_set)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
return await func(request, *args, **kwargs)
return wrapper
return decorator
def get_token_info(token: dict) -> dict:
token_info = dict()
for key in token:
if key != "userinfo":
token_info[key] = token[key]
return token_info
async def get_user_from_token( async def get_user_from_token(
token: Annotated[str, Depends(oauth2_scheme)], token: Annotated[str, Depends(oauth2_scheme)],
request: Request, request: Request,
db_session: Annotated[AsyncSession, Depends(get_db_session)],
) -> User: ) -> User:
try: try:
auth_provider_id = request.headers["auth_provider"] auth_provider_id = request.headers["auth_provider"]
@ -258,9 +240,10 @@ async def get_user_from_token(
try: try:
user_id = payload["sub"] user_id = payload["sub"]
except KeyError: except KeyError:
logger.info(f"'sub' not found in the token, using {auth_provider_id}'s default user")
return provider.unknown_auth_user return provider.unknown_auth_user
try: try:
user = await db.get_user(user_id) user = await db.get_user(user_id, db_session)
if user.access_token != token: if user.access_token != token:
user.access_token = token user.access_token = token
except UserNotInDB: except UserNotInDB:
@ -279,11 +262,12 @@ async def get_user_from_token(
async def get_user_from_token_or_none( async def get_user_from_token_or_none(
token: Annotated[str | None, Depends(oauth2_scheme_optional)], token: Annotated[str | None, Depends(oauth2_scheme_optional)],
request: Request, request: Request,
db_session: Annotated[AsyncSession, Depends(get_db_session)],
) -> User | None: ) -> User | None:
if token is None: if token is None:
return None return None
try: try:
return await get_user_from_token(token, request) return await get_user_from_token(token, request, db_session)
except HTTPException: except HTTPException:
return None return None
@ -303,3 +287,29 @@ class UserWithRole:
status.HTTP_401_UNAUTHORIZED, f"Not of any required role {', '.join(self.roles)}" status.HTTP_401_UNAUTHORIZED, f"Not of any required role {', '.join(self.roles)}"
) )
return user return user
# def hasrole(required_roles: Union[str, list[str]] = []):
# """Decorator for RBAC permissions"""
# required_roles_set: set[str]
# if isinstance(required_roles, str):
# required_roles_set = set([required_roles])
# else:
# required_roles_set = set(required_roles)
#
# def decorator(func):
# @wraps(func)
# async def wrapper(request=None, *args, **kwargs):
# if request is None:
# raise HTTPException(
# 500,
# "Functions decorated with hasrole must have a request:Request argument",
# )
# user: User = await get_current_user(request)
# if not any(required_roles_set.intersection(user.roles_as_set)):
# raise HTTPException(status.HTTP_401_UNAUTHORIZED)
# return await func(request, *args, **kwargs)
#
# return wrapper
#
# return decorator

View file

@ -1,17 +1,26 @@
"""Fake in-memory database interface for demo purpose""" """Fake in-memory database interface for demo purpose"""
import logging import logging
from collections.abc import AsyncGenerator
from authlib.oauth2.rfc6749 import OAuth2Token from authlib.oauth2.rfc6749 import OAuth2Token
from jwt import PyJWTError from jwt import PyJWTError
from oidc_test.auth.provider import Provider from sqlmodel import SQLModel, create_engine, select
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from oidc_test.models import User, Role from oidc_test.auth import provider
from oidc_test.settings import settings
from oidc_test.auth.provider import Provider
from oidc_test.models import User, Role, Token
from oidc_test.auth_providers import providers from oidc_test.auth_providers import providers
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
engine = create_async_engine(settings.db.sqla_url)
sync_engine = create_engine(settings.db.sqla_url)
class UserNotInDB(Exception): class UserNotInDB(Exception):
pass pass
@ -21,6 +30,11 @@ class TokenNotInDb(Exception):
pass pass
async def get_db_session() -> AsyncGenerator[AsyncSession]:
async with AsyncSession(engine) as db_session:
yield db_session
class Database: class Database:
users: dict[str, User] = {} users: dict[str, User] = {}
# TODO: key of the token table should be provider: sid # TODO: key of the token table should be provider: sid
@ -31,20 +45,21 @@ class Database:
async def add_user( async def add_user(
self, self,
sub: str, sub: str,
user_info: dict,
auth_provider: Provider, auth_provider: Provider,
access_token: str, token: OAuth2Token,
# access_token: str,
access_token_decoded: dict | None = None, access_token_decoded: dict | None = None,
) -> User: ) -> User:
if access_token_decoded is None: if access_token_decoded is None:
assert auth_provider.name is not None assert auth_provider.name is not None
provider = providers[auth_provider.id] provider = providers[auth_provider.id]
try: try:
access_token_decoded = provider.decode(access_token) access_token_decoded = provider.decode(token["access_token"])
except PyJWTError: except PyJWTError:
access_token_decoded = {} access_token_decoded = {}
user_info["auth_provider_id"] = auth_provider.id user_info: dict = token["user_info"]
user = User(**user_info) sub = user_info["sub"]
user = User(auth_provider_id=auth_provider.id, **user_info)
user.userinfo = user_info user.userinfo = user_info
# user.access_token = access_token # user.access_token = access_token
# user.access_token_decoded = access_token_decoded # user.access_token_decoded = access_token_decoded
@ -63,34 +78,62 @@ class Database:
roles.update(r) roles.update(r)
except KeyError: except KeyError:
pass pass
user.roles = [Role(name=role_name) for role_name in roles] # user.roles = [Role(name=role_name) for role_name in roles]
user.roles = []
self.users[sub] = user self.users[sub] = user
return user return user
async def get_user(self, sub: str) -> User: async def get_user(self, sub: str, db_session: AsyncSession) -> User:
if sub not in self.users: query = select(User).where(User.sub == sub)
user = (await db_session.exec(query)).first()
if user is None:
raise UserNotInDB raise UserNotInDB
return self.users[sub] return user
async def add_token(self, provider: Provider, token: OAuth2Token) -> None: async def get_or_add_user(
self, sub: str, db_session: AsyncSession, auth_provider: Provider, token: OAuth2Token
):
if user := self.get_user(sub, db_session):
return user
else:
return await self.add_user(sub=sub, auth_provider=auth_provider, token=token)
async def add_token(
self, provider: Provider, token: OAuth2Token, db_session: AsyncSession
) -> None:
"""Store a token using as key the sid (auth provider's session id) """Store a token using as key the sid (auth provider's session id)
in the id_token""" in the id_token"""
sid = provider.get_session_key(token["userinfo"]) sid = provider.get_session_key(token["userinfo"])
self.tokens[sid] = token if existing_token := await db_session.get(Token, sid):
# The token already exists: update it
# XXX: check is token is different?
existing_token.token = token
db_session.add(existing_token)
await db_session.commit()
else:
token = Token(sid=sid, token=token)
db_session.add(token)
await db_session.commit()
async def get_token( async def get_token(
self, self, provider: Provider, sid: str | None, db_session: AsyncSession
provider: Provider,
sid: str | None,
) -> OAuth2Token: ) -> OAuth2Token:
# TODO: key of the token table should be provider: sid # TODO: key of the token table should be provider: sid
assert isinstance(provider, Provider) assert isinstance(provider, Provider)
if sid is None: if sid is None:
raise TokenNotInDb raise TokenNotInDb
try: if token := await db_session.get(Token, sid):
return self.tokens[sid] return OAuth2Token.from_dict(token.token)
except KeyError: else:
raise TokenNotInDb raise TokenNotInDb
async def create_db(drop=False):
logger.debug(f"Connect to database with config: {settings.db}")
async with engine.begin() as conn:
if drop:
await conn.run_sync(SQLModel.metadata.drop_all)
await conn.run_sync(SQLModel.metadata.create_all)
db = Database() db = Database()

View file

@ -6,9 +6,6 @@ from typing import Annotated
from pathlib import Path from pathlib import Path
from datetime import datetime from datetime import datetime
import logging import logging
import logging.config
import importlib.resources
from yaml import safe_load
from urllib.parse import urlencode from urllib.parse import urlencode
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
@ -19,6 +16,8 @@ from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from jwt import PyJWTError from jwt import PyJWTError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import select
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.base_client import OAuthError from authlib.integrations.base_client import OAuthError
@ -29,7 +28,6 @@ from authlib.oauth2.rfc6749 import OAuth2Token
# from fastapi.security import OpenIdConnect # from fastapi.security import OpenIdConnect
# from pkce import generate_code_verifier, generate_pkce_pair # from pkce import generate_code_verifier, generate_pkce_pair
from oidc_test import __version__
from oidc_test.registry import registry from oidc_test.registry import registry
from oidc_test.auth.provider import NoPublicKey, Provider from oidc_test.auth.provider import NoPublicKey, Provider
from oidc_test.auth.utils import ( from oidc_test.auth.utils import (
@ -45,26 +43,18 @@ from oidc_test.auth.utils import init_providers
from oidc_test.settings import settings from oidc_test.settings import settings
from oidc_test.auth_providers import providers from oidc_test.auth_providers import providers
from oidc_test.models import User from oidc_test.models import User
from oidc_test.database import TokenNotInDb, db from oidc_test.database import TokenNotInDb, db, create_db, get_db_session
from oidc_test.resource_server import resource_server from oidc_test.resource_server import resource_server
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
if settings.log:
assert __package__ is not None
with (
importlib.resources.path(__package__) as package_path,
open(package_path / settings.log_config_file) as f,
):
logging_config = safe_load(f)
logging.config.dictConfig(logging_config)
templates = Jinja2Templates(Path(__file__).parent / "templates") templates = Jinja2Templates(Path(__file__).parent / "templates")
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
assert app is not None assert app is not None
await create_db()
init_providers() init_providers()
registry.make_registry() registry.make_registry()
for provider in list(providers.values()): for provider in list(providers.values()):
@ -109,38 +99,35 @@ async def home(
"show_token": settings.show_token, "show_token": settings.show_token,
"user": user, "user": user,
"now": datetime.now(), "now": datetime.now(),
"__version__": __version__, "auth_provider": provider,
} }
if provider is None or token is None: if provider is None or token is None or user is None:
context["providers"] = providers context["providers"] = providers
context["access_token"] = None context["access_token"] = None
context["id_token_parsed"] = None context["id_token_parsed"] = None
context["access_token_parsed"] = None context["access_token_parsed"] = None
context["refresh_token_parsed"] = None context["refresh_token_parsed"] = None
context["resources"] = None context["resources"] = None
context["auth_provider"] = None
else: else:
context["auth_provider"] = provider
context["access_token"] = token["access_token"] context["access_token"] = token["access_token"]
# XXX: resources defined externally? I am confused...
try: try:
access_token_parsed = provider.decode(token["access_token"], verify_signature=False) access_token_parsed = provider.decode(token["access_token"], verify_signature=False)
context["access_token_parsed"] = access_token_parsed
context["access_token_scope"] = access_token_parsed.get("scope")
except PyJWTError as err: except PyJWTError as err:
context["access_token_parsed"] = {"Cannot parse": err.__class__.__name__} access_token_parsed = {"Cannot parse": err.__class__.__name__}
try:
context["access_token_scope"] = access_token_parsed["scope"]
except KeyError:
context["access_token_scope"] = None context["access_token_scope"] = None
context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False)
context["access_token_parsed"] = access_token_parsed
context["resource_providers"] = registry.resource_providers
try: try:
id_token_parsed = provider.decode(token["id_token"], verify_signature=False) context["refresh_token_parsed"] = provider.decode(
context["id_token_parsed"] = id_token_parsed token["refresh_token"], verify_signature=False
except PyJWTError as err: )
context["id_token_parsed"] = {"Cannot parse": err.__class__.__name__}
try:
refresh_token_parsed = provider.decode(token["refresh_token"], verify_signature=False)
context["refresh_token_parsed"] = refresh_token_parsed
except PyJWTError as err: except PyJWTError as err:
context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__} context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__}
context["resources"] = registry.resources
context["resource_providers"] = provider.resource_providers
return templates.TemplateResponse(name="home.html", request=request, context=context) return templates.TemplateResponse(name="home.html", request=request, context=context)
@ -184,6 +171,7 @@ async def login(request: Request, auth_provider_id: str) -> RedirectResponse:
async def auth( async def auth(
request: Request, request: Request,
auth_provider_id: str, auth_provider_id: str,
db_session: Annotated[AsyncSession, Depends(get_db_session)],
) -> RedirectResponse: ) -> RedirectResponse:
"""Decrypt the auth token, store it to the session (cookie based) """Decrypt the auth token, store it to the session (cookie based)
and response to the browser with a redirect to a "welcome user" page. and response to the browser with a redirect to a "welcome user" page.
@ -218,24 +206,30 @@ async def auth(
# user_info_from_endpoint = {} # user_info_from_endpoint = {}
# Build and remember the user in the session # Build and remember the user in the session
request.session["user_sub"] = sub request.session["user_sub"] = sub
# Store the user in the database, which also verifies the token validity and signature
try: try:
user = await db.add_user( user = db.get_or_add_user(sub, db_session, auth_provider=provider, token=token)
sub, query = select(User).where(User.sub == sub)
user_info=userinfo, user = (await db_session.exec(query)).first()
auth_provider=providers[auth_provider_id], assert user is not None
access_token=token["access_token"], except Exception as err:
) # Store the user in the database, which also verifies the token validity and signature
except PyJWTError as err: logger.info(f"New user {userinfo}")
raise HTTPException( try:
status.HTTP_401_UNAUTHORIZED, user = await db.add_user(
detail=f"Token invalid: {err.__class__.__name__}", sub,
) user_info=userinfo,
assert isinstance(user, User) auth_provider=providers[auth_provider_id],
access_token=token["access_token"],
)
except PyJWTError as err:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
detail=f"Token invalid: {err.__class__.__name__}",
)
# Add the provider session id to the session # Add the provider session id to the session
request.session["sid"] = provider.get_session_key(userinfo) request.session["sid"] = provider.get_session_key(userinfo)
# Add the token to the db because it is used for logout # Add the token to the db because it is used for logout
await db.add_token(provider, token) await db.add_token(provider, token, db_session)
# Send the user to the home: (s)he is authenticated # Send the user to the home: (s)he is authenticated
return RedirectResponse(url=request.url_for("home")) return RedirectResponse(url=request.url_for("home"))
else: else:
@ -262,7 +256,7 @@ async def logout(
if ( if (
provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint") provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint")
) is None: ) is None:
logger.warning(f"Cannot find end_session_endpoint for provider {provider.id}") logger.warn(f"Cannot find end_session_endpoint for provider {provider.id}")
return RedirectResponse(request.url_for("non_compliant_logout")) return RedirectResponse(request.url_for("non_compliant_logout"))
post_logout_uri = request.url_for("home") post_logout_uri = request.url_for("home")
# Clear session # Clear session
@ -273,13 +267,17 @@ async def logout(
except TokenNotInDb: except TokenNotInDb:
logger.warning("No session in db for the token or no token") logger.warning("No session in db for the token or no token")
return RedirectResponse(request.url_for("home")) return RedirectResponse(request.url_for("home"))
url_query = { logout_url = (
"post_logout_redirect_uri": post_logout_uri, provider_logout_uri
"client_id": provider.client_id, + "?"
} + urlencode(
if provider.logout_with_id_token_hint: {
url_query["id_token_hint"] = token["id_token"] "post_logout_redirect_uri": post_logout_uri,
logout_url = f"{provider_logout_uri}?{urlencode(url_query)}" "id_token_hint": token["id_token"],
"client_id": "oidc_local_test",
}
)
)
return RedirectResponse(logout_url) return RedirectResponse(logout_url)
@ -310,13 +308,7 @@ async def refresh(
refresh_token=token["refresh_token"], refresh_token=token["refresh_token"],
grant_type="refresh_token", grant_type="refresh_token",
) )
try: await update_token(provider.id, new_token)
await update_token(provider.id, new_token)
except PyJWTError as err:
logger.info(f"Cannot refresh token: {err.__class__.__name__}")
raise HTTPException(
status.HTTP_510_NOT_EXTENDED, f"Token refresh error: {err.__class__.__name__}"
)
return RedirectResponse(url=request.url_for("home")) return RedirectResponse(url=request.url_for("home"))

View file

@ -2,41 +2,49 @@ import logging
from functools import cached_property from functools import cached_property
from typing import Any from typing import Any
from sqlalchemy.types import JSON
from pydantic import ( from pydantic import (
BaseModel,
computed_field, computed_field,
field_validator,
AnyHttpUrl, AnyHttpUrl,
EmailStr, EmailStr,
ConfigDict, ConfigDict,
) )
from sqlmodel import SQLModel, Field from sqlmodel import Relationship, SQLModel, Field
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
class Role(SQLModel, extra="ignore"): class Role(SQLModel, table=True):
id: str = Field(primary_key=True)
name: str name: str
class UserBase(SQLModel, extra="ignore"): class UserBase(SQLModel):
id: str | None = None
sid: str | None = None sid: str | None = None
name: str | None = None name: str | None = None
email: EmailStr | None = None email: EmailStr | None = None
picture: AnyHttpUrl | None = None picture: str | None = None
roles: list[Role] = []
@classmethod
@field_validator("picture")
def _valid_url(cls, v):
return AnyHttpUrl(v)
class User(UserBase): class User(UserBase, table=True):
model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore # model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore
id: int | None = Field(primary_key=True, default=None)
roles: list[str] = Field(sa_type=JSON, default=[]) # Relationship(link_model=Role)
sub: str = Field( sub: str = Field(
description="""subject id of the user given by the oidc provider, description="""subject id of the user given by the oidc provider,
also the key for the database 'table'""", also the key for the database 'table'""",
) )
userinfo: dict = {}
access_token: str | None = None
access_token_decoded: dict[str, Any] | None = None
auth_provider_id: str auth_provider_id: str
access_token: str | None = None
userinfo: dict[str, Any] = Field(sa_type=JSON)
access_token_decoded: dict[str, Any] | None = Field(sa_type=JSON)
@computed_field @computed_field
@cached_property @cached_property
@ -64,3 +72,8 @@ class User(UserBase):
def get_scope(self, verify_signature: bool = True): def get_scope(self, verify_signature: bool = True):
return self.decode_access_token(verify_signature=verify_signature)["scope"] return self.decode_access_token(verify_signature=verify_signature)["scope"]
class Token(SQLModel, table=True):
sid: str | None = Field(primary_key=True, default=None)
token: dict[str, Any] = Field(sa_type=JSON)

View file

@ -12,13 +12,14 @@ class ProcessResult(BaseModel):
model_config = ConfigDict( model_config = ConfigDict(
extra="allow", extra="allow",
) )
msg: str | None = None
class ProcessError(Exception): class ProcessError(Exception):
pass pass
class Resource(BaseModel): class ResourceProvider(BaseModel):
name: str name: str
scope_required: str | None = None scope_required: str | None = None
role_required: str | None = None role_required: str | None = None
@ -35,13 +36,13 @@ class Resource(BaseModel):
class ResourceRegistry(BaseModel): class ResourceRegistry(BaseModel):
resources: dict[str, Resource] = {} resource_providers: dict[str, ResourceProvider] = {}
def make_registry(self): def make_registry(self):
for ep in entry_points().select(group="oidc_test.resource_provider"): for ep in entry_points().select(group="oidc_test.resource_provider"):
ResourceClass = ep.load() ResourceProviderClass = ep.load()
if issubclass(ResourceClass, Resource): if issubclass(ResourceProviderClass, ResourceProvider):
self.resources[ep.name] = ResourceClass(ep.name) self.resource_providers[ep.name] = ResourceProviderClass(ep.name)
registry = ResourceRegistry() registry = ResourceRegistry()

View file

@ -1,11 +1,9 @@
from typing import Annotated, Any from typing import Annotated, Any
import logging import logging
from json import JSONDecodeError
from authlib.oauth2.rfc6749 import OAuth2Token from authlib.oauth2.rfc6749 import OAuth2Token
from httpx import AsyncClient, HTTPError from httpx import AsyncClient
from jwt.exceptions import DecodeError, ExpiredSignatureError, InvalidTokenError from fastapi import FastAPI, HTTPException, Depends, status
from fastapi import FastAPI, HTTPException, Depends, Request, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
# from starlette.middleware.sessions import SessionMiddleware # from starlette.middleware.sessions import SessionMiddleware
@ -18,7 +16,7 @@ from oidc_test.auth.utils import (
oauth2_scheme_optional, oauth2_scheme_optional,
) )
from oidc_test.auth_providers import providers from oidc_test.auth_providers import providers
from oidc_test.settings import ResourceProvider, settings from oidc_test.settings import settings
from oidc_test.models import User from oidc_test.models import User
from oidc_test.registry import ProcessError, ProcessResult, registry from oidc_test.registry import ProcessError, ProcessResult, registry
@ -49,7 +47,7 @@ resource_server.add_middleware(
@resource_server.get("/") @resource_server.get("/")
async def resources() -> dict[str, dict[str, Any]]: async def resources() -> dict[str, dict[str, Any]]:
return {"internal": {}, "plugins": registry.resources} return {"internal": {}, "plugins": registry.resource_providers}
@resource_server.get("/{resource_name}") @resource_server.get("/{resource_name}")
@ -66,46 +64,8 @@ async def get_resource(
# Get the resource if it's defined in user auth provider's resources (external) # Get the resource if it's defined in user auth provider's resources (external)
if user is not None: if user is not None:
provider = providers[user.auth_provider_id] provider = providers[user.auth_provider_id]
if ":" in resource_name:
# Third-party resource provider: send the request with the request token
resource_provider_id, resource_name = resource_name.split(":", 1)
provider = providers[user.auth_provider_id]
resource_provider: ResourceProvider = provider.get_resource_provider(
resource_provider_id
)
resource_url = resource_provider.get_resource_url(resource_name)
async with AsyncClient(verify=resource_provider.verify_ssl) as client:
try:
logger.debug(f"GET request to {resource_url}")
resp = await client.get(
resource_url,
headers={
"Content-type": "application/json",
"Authorization": f"Bearer {token}",
"auth_provider": user.auth_provider_id,
},
)
except HTTPError as err:
raise HTTPException(
status.HTTP_503_SERVICE_UNAVAILABLE, err.__class__.__name__
)
except Exception as err:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, err.__class__.__name__
)
else:
if resp.is_success:
return resp.json()
else:
reason_str: str
try:
reason_str = resp.json().get("detail", str(resp))
except Exception:
reason_str = str(resp.text)
raise HTTPException(resp.status_code, reason_str)
# Third party resource (provided through the auth provider) # Third party resource (provided through the auth provider)
# The token is just passed on # The token is just passed on
# XXX: is this branch valid anymore?
if resource_name in [r.resource_name for r in provider.resources]: if resource_name in [r.resource_name for r in provider.resources]:
return await get_auth_provider_resource( return await get_auth_provider_resource(
provider=provider, provider=provider,
@ -114,31 +74,31 @@ async def get_resource(
user=user, user=user,
) )
# Internal resource (provided here) # Internal resource (provided here)
if resource_name in registry.resources: if resource_name in registry.resource_providers:
resource = registry.resources[resource_name] resource_provider = registry.resource_providers[resource_name]
reason: dict[str, str] = {} reason: dict[str, str] = {}
if not resource.is_public: if not resource_provider.is_public:
if user is None: if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public")
else: else:
if resource.scope_required is not None and not user.has_scope( if resource_provider.scope_required is not None and not user.has_scope(
resource.scope_required resource_provider.scope_required
): ):
reason["scope"] = ( reason["scope"] = (
f"No scope {resource.scope_required} in the access token " f"No scope {resource_provider.scope_required} in the access token "
"but it is required for accessing this resource" "but it is required for accessing this resource"
) )
if ( if (
resource.role_required is not None resource_provider.role_required is not None
and resource.role_required not in user.roles_as_set and resource_provider.role_required not in user.roles_as_set
): ):
reason["role"] = ( reason["role"] = (
f"You don't have the role {resource.role_required} " f"You don't have the role {resource_provider.role_required} "
"but it is required for accessing this resource" "but it is required for accessing this resource"
) )
if len(reason) == 0: if len(reason) == 0:
try: try:
resp = await resource.process(user=user, resource_id=resource_id) resp = await resource_provider.process(user=user, resource_id=resource_id)
return resp return resp
except ProcessError as err: except ProcessError as err:
raise HTTPException( raise HTTPException(
@ -156,10 +116,11 @@ async def get_auth_provider_resource(
) -> ProcessResult: ) -> ProcessResult:
if token is None: if token is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth token") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth token")
access_token = token access_token = token["access_token"]
resource = [r for r in provider.resources if r.resource_name == resource_name][0]
async with AsyncClient() as client: async with AsyncClient() as client:
resp = await client.get( resp = await client.get(
url=provider.get_resource_url(resource_name), url=provider.url + resource.url,
headers={ headers={
"Content-type": "application/json", "Content-type": "application/json",
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
@ -170,19 +131,9 @@ async def get_auth_provider_resource(
# Only a demo, real application would really process the response # Only a demo, real application would really process the response
resp_length = len(resp.text) resp_length = len(resp.text)
if resp_length > 1024: if resp_length > 1024:
return ProcessResult( return ProcessResult(msg=f"The resource is too long ({resp_length} bytes) to show here")
msg=f"The resource is too long ({resp_length} bytes) to show in this demo, here is just the begining in raw format",
start=resp.text[:100] + "...",
)
else: else:
try: return ProcessResult(**resp.json())
resp_json = resp.json()
except JSONDecodeError:
return ProcessResult(msg="The resource is not formatted in JSON", text=resp.text)
if isinstance(resp_json, dict):
return ProcessResult(**resp.json())
elif isinstance(resp_json, list):
return ProcessResult(**{str(i): line for i, line in enumerate(resp_json)})
# @resource_server.get("/public") # @resource_server.get("/public")

View file

@ -4,7 +4,7 @@ import random
from typing import Type, Tuple from typing import Type, Tuple
from pathlib import Path from pathlib import Path
from pydantic import AnyHttpUrl, BaseModel, computed_field, AnyUrl from pydantic import BaseModel, computed_field, AnyUrl
from pydantic_settings import ( from pydantic_settings import (
BaseSettings, BaseSettings,
SettingsConfigDict, SettingsConfigDict,
@ -22,22 +22,6 @@ class Resource(BaseModel):
url: str url: str
class ResourceProvider(BaseModel):
id: str
name: str
base_url: AnyUrl
resources: list[Resource] = []
verify_ssl: bool = True
def get_resource(self, resource_name: str) -> Resource:
return [
resource for resource in self.resources if resource.resource_name == resource_name
][0]
def get_resource_url(self, resource_name: str) -> str:
return f"{self.base_url}{self.get_resource(resource_name).url}"
class AuthProviderSettings(BaseModel): class AuthProviderSettings(BaseModel):
"""Auth provider, can also be a resource server""" """Auth provider, can also be a resource server"""
@ -61,7 +45,6 @@ class AuthProviderSettings(BaseModel):
session_key: str = "sid" session_key: str = "sid"
skip_verify_signature: bool = True skip_verify_signature: bool = True
disabled: bool = False disabled: bool = False
resource_providers: list[ResourceProvider] = []
@computed_field @computed_field
@property @property
@ -84,6 +67,13 @@ class AuthProviderSettings(BaseModel):
return None return None
class ResourceProvider(BaseModel):
id: str
name: str
base_url: AnyUrl
resources: list[Resource] = []
class AuthSettings(BaseModel): class AuthSettings(BaseModel):
show_session_details: bool = False show_session_details: bool = False
providers: list[AuthProviderSettings] = [] providers: list[AuthProviderSettings] = []
@ -96,20 +86,41 @@ class Insecure(BaseModel):
skip_verify_signature: bool = False skip_verify_signature: bool = False
class DB(BaseModel):
host: str = "localhost"
port: int = 5432
db: str = "oidc-test"
user: str = "oidc-test"
password: str = "oidc-test"
debug: bool = False
pool_size: int = 10
max_overflow: int = 10
echo: bool = False
@property
def sqla_url(self):
return (
f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
)
def get_pg_url(self):
return f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
class Settings(BaseSettings): class Settings(BaseSettings):
"""Settings wil be read from an .env file""" """Settings wil be read from an .env file"""
model_config = SettingsConfigDict(env_nested_delimiter="__") model_config = SettingsConfigDict(env_nested_delimiter="__")
auth: AuthSettings = AuthSettings() auth: AuthSettings = AuthSettings()
resource_providers: list[ResourceProvider] = []
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
log: bool = False log: bool = False
log_config_file: str = "log_conf.yaml"
insecure: Insecure = Insecure() insecure: Insecure = Insecure()
db: DB = DB()
cors_origins: list[str] = [] cors_origins: list[str] = []
debug_token: bool = False debug_token: bool = False
show_token: bool = False show_token: bool = False
show_external_resource_providers_links: bool = False
@classmethod @classmethod
def settings_customise_sources( def settings_customise_sources(

View file

@ -21,12 +21,6 @@ hr {
.hidden { .hidden {
display: none; display: none;
} }
.version {
position: absolute;
font-size: 75%;
top: 0.3em;
right: 0.3em;
}
.center { .center {
text-align: center; text-align: center;
} }

View file

@ -2,9 +2,7 @@ async function checkHref(elem, token, authProvider) {
const msg = document.getElementById("msg") const msg = document.getElementById("msg")
const resourceName = elem.getAttribute("resource-name") const resourceName = elem.getAttribute("resource-name")
const resourceId = elem.getAttribute("resource-id") const resourceId = elem.getAttribute("resource-id")
const resourceProviderId = elem.getAttribute("resource-provider-id") ? elem.getAttribute("resource-provider-id") : "" const url = resourceId ? `resource/${resourceName}/${resourceId}` : `resource/${resourceName}`
const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName
const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}`
const resp = await fetch(url, { const resp = await fetch(url, {
method: "GET", method: "GET",
headers: new Headers({ headers: new Headers({
@ -32,13 +30,11 @@ function checkPerms(className, token, authProvider) {
) )
} }
async function get_resource(resourceName, token, authProvider, resourceId, resourceProviderId) { async function get_resource(resource_name, token, authProvider, resource_id) {
// BaseUrl for an external resource provider
//if (!keycloak.keycloak) { return } //if (!keycloak.keycloak) { return }
const msg = document.getElementById("msg") const msg = document.getElementById("msg")
const resourceElem = document.getElementById('resource') const resourceElem = document.getElementById('resource')
const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName const url = resource_id ? `resource/${resource_name}/${resource_id}` : `resource/${resource_name}`
const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}`
const resp = await fetch(url, { const resp = await fetch(url, {
method: "GET", method: "GET",
headers: new Headers({ headers: new Headers({

View file

@ -5,7 +5,6 @@
<script src="{{ url_for('static', path='/utils.js') }}"></script> <script src="{{ url_for('static', path='/utils.js') }}"></script>
</head> </head>
<body onload="checkPerms('links-to-check', '{{ access_token }}', '{{ auth_provider.id }}')"> <body onload="checkPerms('links-to-check', '{{ access_token }}', '{{ auth_provider.id }}')">
<div class="version">v. {{ __version__}}</div>
<h1>OIDC-test - FastAPI client</h1> <h1>OIDC-test - FastAPI client</h1>
{% block content %} {% block content %}
{% endblock %} {% endblock %}

View file

@ -66,67 +66,29 @@
{% endif %} {% endif %}
<hr> <hr>
<div class="content"> <div class="content">
{% if resources %} {% if resource_providers %}
<p>This application provides all these resources, eventually protected with scope or roles:</p> <p>
{{ auth_provider.name }} provides these resources:
</p>
<div class="links-to-check"> <div class="links-to-check">
{% for name, resource in resources.items() %} {% for name, resource_provider in resource_providers.items() %}
{% if resource.default_resource_id %} {% if resource_provider.default_resource_id %}
<button resource-name="{{ name }}" <button resource-name="{{ name }}"
resource-id="{{ resource.default_resource_id }}" resource-id="{{ resource_provider.default_resource_id }}"
onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource.default_resource_id }}')" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource_provider.default_resource_id }}')"
> >
{{ resource.name }} {{ resource_provider.name }}
</button> </button>
{% else %} {% else %}
<button resource-name="{{ name }}" <button resource-name="{{ name }}"
onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}')" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}')"
> >
{{ resource.name }} {{ resource_provider.name }}
</button> </buttona>
{% endif %} {% endif %}
{% endfor %} {% endfor %}
</div> </div>
{% endif %} {% endif %}
{% if auth_provider.resources %}
<p>{{ auth_provider.name }} is also defined as a provider for these resources:</p>
<div class="links-to-check">
{% for resource in auth_provider.resources %}
{% if resource.default_resource_id %}
<button resource-name="{{ resource.resource_name }}"
resource-id="{{ resource.default_resource_id }}"
onclick="get_resource('{{ resource.resource_name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource.default_resource_id }}')"
>
{{ resource.name }}
</button>
{% else %}
<button resource-name="{{ resource.resource_name }}"
onclick="get_resource('{{ resource.resource_name }}', '{{ access_token }}', '{{ auth_provider.id }}')"
>
{{ resource.name }}
</button>
{% endif %}
{% endfor %}
</div>
{% endif %}
{% if resource_providers %}
<p>{{ auth_provider.name }} allows this application to request resources from third party resource providers:</p>
{% for resource_provider in resource_providers %}
<div class="links-to-check">
{{ resource_provider.name }}
{% for resource in resource_provider.resources %}
<button resource-name="{{ resource.resource_name }}"
resource-id="{{ resource.default_resource_id }}"
resource-provider-id="{{ resource_provider.id }}"
onclick="get_resource('{{ resource.resource_name }}', '{{ access_token }}',
'{{ auth_provider.id }}', '{{ resource.default_resource_id }}',
'{{ resource_provider.id }}')"
>
{{ resource.name }}
</button>
{% endfor %}
</div>
{% endfor %}
{% endif %}
<div class="resourceResult"> <div class="resourceResult">
<div id="resource" class="resource"></div> <div id="resource" class="resource"></div>
<div id="msg" class="msg error"></div> <div id="msg" class="msg error"></div>

60
uv.lock generated
View file

@ -1,5 +1,4 @@
version = 1 version = 1
revision = 1
requires-python = ">=3.13" requires-python = ">=3.13"
[[package]] [[package]]
@ -33,6 +32,22 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918 }, { url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918 },
] ]
[[package]]
name = "asyncpg"
version = "0.30.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/2f/4c/7c991e080e106d854809030d8584e15b2e996e26f16aee6d757e387bc17d/asyncpg-0.30.0.tar.gz", hash = "sha256:c551e9928ab6707602f44811817f82ba3c446e018bfe1d3abecc8ba5f3eac851", size = 957746 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3a/22/e20602e1218dc07692acf70d5b902be820168d6282e69ef0d3cb920dc36f/asyncpg-0.30.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05b185ebb8083c8568ea8a40e896d5f7af4b8554b64d7719c0eaa1eb5a5c3a70", size = 670373 },
{ url = "https://files.pythonhosted.org/packages/3d/b3/0cf269a9d647852a95c06eb00b815d0b95a4eb4b55aa2d6ba680971733b9/asyncpg-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c47806b1a8cbb0a0db896f4cd34d89942effe353a5035c62734ab13b9f938da3", size = 634745 },
{ url = "https://files.pythonhosted.org/packages/8e/6d/a4f31bf358ce8491d2a31bfe0d7bcf25269e80481e49de4d8616c4295a34/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b6fde867a74e8c76c71e2f64f80c64c0f3163e687f1763cfaf21633ec24ec33", size = 3512103 },
{ url = "https://files.pythonhosted.org/packages/96/19/139227a6e67f407b9c386cb594d9628c6c78c9024f26df87c912fabd4368/asyncpg-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46973045b567972128a27d40001124fbc821c87a6cade040cfcd4fa8a30bcdc4", size = 3592471 },
{ url = "https://files.pythonhosted.org/packages/67/e4/ab3ca38f628f53f0fd28d3ff20edff1c975dd1cb22482e0061916b4b9a74/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9110df111cabc2ed81aad2f35394a00cadf4f2e0635603db6ebbd0fc896f46a4", size = 3496253 },
{ url = "https://files.pythonhosted.org/packages/ef/5f/0bf65511d4eeac3a1f41c54034a492515a707c6edbc642174ae79034d3ba/asyncpg-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04ff0785ae7eed6cc138e73fc67b8e51d54ee7a3ce9b63666ce55a0bf095f7ba", size = 3662720 },
{ url = "https://files.pythonhosted.org/packages/e7/31/1513d5a6412b98052c3ed9158d783b1e09d0910f51fbe0e05f56cc370bc4/asyncpg-0.30.0-cp313-cp313-win32.whl", hash = "sha256:ae374585f51c2b444510cdf3595b97ece4f233fde739aa14b50e0d64e8a7a590", size = 560404 },
{ url = "https://files.pythonhosted.org/packages/c8/a4/cec76b3389c4c5ff66301cd100fe88c318563ec8a520e0b2e792b5b84972/asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e", size = 621623 },
]
[[package]] [[package]]
name = "authlib" name = "authlib"
version = "1.4.0" version = "1.4.0"
@ -207,18 +222,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632 }, { url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632 },
] ]
[[package]]
name = "dunamai"
version = "1.23.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "packaging" },
]
sdist = { url = "https://files.pythonhosted.org/packages/06/4e/a5c8c337a1d9ac0384298ade02d322741fb5998041a5ea74d1cd2a4a1d47/dunamai-1.23.0.tar.gz", hash = "sha256:a163746de7ea5acb6dacdab3a6ad621ebc612ed1e528aaa8beedb8887fccd2c4", size = 44681 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/21/4c/963169386309fec4f96fd61210ac0a0666887d0fb0a50205395674d20b71/dunamai-1.23.0-py3-none-any.whl", hash = "sha256:a0906d876e92441793c6a423e16a4802752e723e9c9a5aabdc5535df02dbe041", size = 26342 },
]
[[package]] [[package]]
name = "ecdsa" name = "ecdsa"
version = "0.19.0" version = "0.19.0"
@ -296,6 +299,30 @@ standard = [
{ name = "uvicorn", extra = ["standard"] }, { name = "uvicorn", extra = ["standard"] },
] ]
[[package]]
name = "greenlet"
version = "3.1.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/2f/ff/df5fede753cc10f6a5be0931204ea30c35fa2f2ea7a35b25bdaf4fe40e46/greenlet-3.1.1.tar.gz", hash = "sha256:4ce3ac6cdb6adf7946475d7ef31777c26d94bccc377e070a7986bd2d5c515467", size = 186022 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f3/57/0db4940cd7bb461365ca8d6fd53e68254c9dbbcc2b452e69d0d41f10a85e/greenlet-3.1.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:05175c27cb459dcfc05d026c4232f9de8913ed006d42713cb8a5137bd49375f1", size = 272990 },
{ url = "https://files.pythonhosted.org/packages/1c/ec/423d113c9f74e5e402e175b157203e9102feeb7088cee844d735b28ef963/greenlet-3.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:935e943ec47c4afab8965954bf49bfa639c05d4ccf9ef6e924188f762145c0ff", size = 649175 },
{ url = "https://files.pythonhosted.org/packages/a9/46/ddbd2db9ff209186b7b7c621d1432e2f21714adc988703dbdd0e65155c77/greenlet-3.1.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:667a9706c970cb552ede35aee17339a18e8f2a87a51fba2ed39ceeeb1004798a", size = 663425 },
{ url = "https://files.pythonhosted.org/packages/bc/f9/9c82d6b2b04aa37e38e74f0c429aece5eeb02bab6e3b98e7db89b23d94c6/greenlet-3.1.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b8a678974d1f3aa55f6cc34dc480169d58f2e6d8958895d68845fa4ab566509e", size = 657736 },
{ url = "https://files.pythonhosted.org/packages/d9/42/b87bc2a81e3a62c3de2b0d550bf91a86939442b7ff85abb94eec3fc0e6aa/greenlet-3.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:efc0f674aa41b92da8c49e0346318c6075d734994c3c4e4430b1c3f853e498e4", size = 660347 },
{ url = "https://files.pythonhosted.org/packages/37/fa/71599c3fd06336cdc3eac52e6871cfebab4d9d70674a9a9e7a482c318e99/greenlet-3.1.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0153404a4bb921f0ff1abeb5ce8a5131da56b953eda6e14b88dc6bbc04d2049e", size = 615583 },
{ url = "https://files.pythonhosted.org/packages/4e/96/e9ef85de031703ee7a4483489b40cf307f93c1824a02e903106f2ea315fe/greenlet-3.1.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:275f72decf9932639c1c6dd1013a1bc266438eb32710016a1c742df5da6e60a1", size = 1133039 },
{ url = "https://files.pythonhosted.org/packages/87/76/b2b6362accd69f2d1889db61a18c94bc743e961e3cab344c2effaa4b4a25/greenlet-3.1.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c4aab7f6381f38a4b42f269057aee279ab0fc7bf2e929e3d4abfae97b682a12c", size = 1160716 },
{ url = "https://files.pythonhosted.org/packages/1f/1b/54336d876186920e185066d8c3024ad55f21d7cc3683c856127ddb7b13ce/greenlet-3.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:b42703b1cf69f2aa1df7d1030b9d77d3e584a70755674d60e710f0af570f3761", size = 299490 },
{ url = "https://files.pythonhosted.org/packages/5f/17/bea55bf36990e1638a2af5ba10c1640273ef20f627962cf97107f1e5d637/greenlet-3.1.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1695e76146579f8c06c1509c7ce4dfe0706f49c6831a817ac04eebb2fd02011", size = 643731 },
{ url = "https://files.pythonhosted.org/packages/78/d2/aa3d2157f9ab742a08e0fd8f77d4699f37c22adfbfeb0c610a186b5f75e0/greenlet-3.1.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7876452af029456b3f3549b696bb36a06db7c90747740c5302f74a9e9fa14b13", size = 649304 },
{ url = "https://files.pythonhosted.org/packages/f1/8e/d0aeffe69e53ccff5a28fa86f07ad1d2d2d6537a9506229431a2a02e2f15/greenlet-3.1.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4ead44c85f8ab905852d3de8d86f6f8baf77109f9da589cb4fa142bd3b57b475", size = 646537 },
{ url = "https://files.pythonhosted.org/packages/05/79/e15408220bbb989469c8871062c97c6c9136770657ba779711b90870d867/greenlet-3.1.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8320f64b777d00dd7ccdade271eaf0cad6636343293a25074cc5566160e4de7b", size = 642506 },
{ url = "https://files.pythonhosted.org/packages/18/87/470e01a940307796f1d25f8167b551a968540fbe0551c0ebb853cb527dd6/greenlet-3.1.1-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6510bf84a6b643dabba74d3049ead221257603a253d0a9873f55f6a59a65f822", size = 602753 },
{ url = "https://files.pythonhosted.org/packages/e2/72/576815ba674eddc3c25028238f74d7b8068902b3968cbe456771b166455e/greenlet-3.1.1-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:04b013dc07c96f83134b1e99888e7a79979f1a247e2a9f59697fa14b5862ed01", size = 1122731 },
{ url = "https://files.pythonhosted.org/packages/ac/38/08cc303ddddc4b3d7c628c3039a61a3aae36c241ed01393d00c2fd663473/greenlet-3.1.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:411f015496fec93c1c8cd4e5238da364e1da7a124bcb293f085bf2860c32c6f6", size = 1142112 },
]
[[package]] [[package]]
name = "h11" name = "h11"
version = "0.14.0" version = "0.14.0"
@ -495,11 +522,14 @@ wheels = [
[[package]] [[package]]
name = "oidc-fastapi-test" name = "oidc-fastapi-test"
version = "0.0.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "asyncpg" },
{ name = "authlib" }, { name = "authlib" },
{ name = "cachetools" }, { name = "cachetools" },
{ name = "fastapi", extra = ["standard"] }, { name = "fastapi", extra = ["standard"] },
{ name = "greenlet" },
{ name = "httpx" }, { name = "httpx" },
{ name = "itsdangerous" }, { name = "itsdangerous" },
{ name = "passlib", extra = ["bcrypt"] }, { name = "passlib", extra = ["bcrypt"] },
@ -513,16 +543,17 @@ dependencies = [
[package.dev-dependencies] [package.dev-dependencies]
dev = [ dev = [
{ name = "dunamai" },
{ name = "ipdb" }, { name = "ipdb" },
{ name = "pytest" }, { name = "pytest" },
] ]
[package.metadata] [package.metadata]
requires-dist = [ requires-dist = [
{ name = "asyncpg", specifier = ">=0.30.0" },
{ name = "authlib", specifier = ">=1.4.0" }, { name = "authlib", specifier = ">=1.4.0" },
{ name = "cachetools", specifier = ">=5.5.0" }, { name = "cachetools", specifier = ">=5.5.0" },
{ name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" },
{ name = "greenlet", specifier = ">=3.1.1" },
{ name = "httpx", specifier = ">=0.28.1" }, { name = "httpx", specifier = ">=0.28.1" },
{ name = "itsdangerous", specifier = ">=2.2.0" }, { name = "itsdangerous", specifier = ">=2.2.0" },
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
@ -536,7 +567,6 @@ requires-dist = [
[package.metadata.requires-dev] [package.metadata.requires-dev]
dev = [ dev = [
{ name = "dunamai", specifier = ">=1.23.0" },
{ name = "ipdb", specifier = ">=0.13.13" }, { name = "ipdb", specifier = ">=0.13.13" },
{ name = "pytest", specifier = ">=8.3.4" }, { name = "pytest", specifier = ">=8.3.4" },
] ]