Compare commits
55 commits
Author | SHA1 | Date | |
---|---|---|---|
8b3a339196 | |||
b01f233208 | |||
4355e6dc42 | |||
c3ebad42d5 | |||
c5b1bdeda9 | |||
821df02758 | |||
9f7b090273 | |||
22d0a9852c | |||
6f060dc2bf | |||
f4b38e1c69 | |||
b465394766 | |||
9c46237905 | |||
3da485c945 | |||
9c1f843283 | |||
ef7c265d8e | |||
395ec1c7f7 | |||
9249885c80 | |||
5f429797ff | |||
850db9f590 | |||
f6a84fd3aa | |||
4c2b197850 | |||
347c395394 | |||
3f945310a4 | |||
ecdd3702f8 | |||
d924c56b17 | |||
0764b1c003 | |||
703985f311 | |||
e925f21762 | |||
435c11b6ca | |||
1c57944a90 | |||
4008036bca | |||
c89ca4098b | |||
40ddb61636 | |||
5bd4b82804 | |||
9d3146dc1c | |||
381ce1ebc1 | |||
0464047f8a | |||
64f6a90f22 | |||
e56be3c378 | |||
496ce016e3 | |||
c5bb4f4319 | |||
38b983c2a5 | |||
923a63f5d5 | |||
ff72f0cae5 | |||
3eb6dc3dcf | |||
d39adf41ef | |||
ee8ba3d2df | |||
5c9ed9724e | |||
76da695b66 | |||
b86ae4eb11 | |||
3dc14ae57b | |||
31a783cbf1 | |||
aa86f81358 | |||
fefe44acfe | |||
af49242192 |
24 changed files with 1461 additions and 761 deletions
|
@ -19,7 +19,7 @@ jobs:
|
|||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
with:
|
||||
version: "0.5.16"
|
||||
version: "0.6.9"
|
||||
|
||||
- name: Install
|
||||
run: uv sync
|
||||
|
@ -27,32 +27,26 @@ jobs:
|
|||
- name: Run tests (API call)
|
||||
run: .venv/bin/pytest -s tests/basic.py
|
||||
|
||||
- name: Get version with git describe
|
||||
id: version
|
||||
run: |
|
||||
echo "version=$(git describe)" >> $GITHUB_OUTPUT
|
||||
echo "$VERSION"
|
||||
- name: Get version
|
||||
run: echo "VERSION=$(.venv/bin/dunamai from any --style semver)" >> $GITHUB_ENV
|
||||
|
||||
- name: Check if the container should be built
|
||||
id: builder
|
||||
env:
|
||||
RUN: ${{ toJSON(inputs.build || !contains(steps.version.outputs.version, '-')) }}
|
||||
run: |
|
||||
echo "run=$RUN" >> $GITHUB_OUTPUT
|
||||
echo "Run build: $RUN"
|
||||
- name: Version
|
||||
run: echo $VERSION
|
||||
|
||||
- name: Set the version in pyproject.toml (workaround for uv not supporting dynamic version)
|
||||
if: fromJSON(steps.builder.outputs.run)
|
||||
env:
|
||||
VERSION: ${{ steps.version.outputs.version }}
|
||||
run: sed "s/0.0.0/$VERSION/" -i pyproject.toml
|
||||
- name: Get distance from tag
|
||||
run: echo "DISTANCE=$(.venv/bin/dunamai from any --format '{distance}')" >> $GITHUB_ENV
|
||||
|
||||
- name: Distance
|
||||
run: echo $DISTANCE
|
||||
|
||||
- name: Workaround for bug of podman-login
|
||||
if: env.DISTANCE == '0'
|
||||
run: |
|
||||
mkdir -p $HOME/.docker
|
||||
echo "{ \"auths\": {} }" > $HOME/.docker/config.json
|
||||
|
||||
- name: Log in to the container registry (with another workaround)
|
||||
if: env.DISTANCE == '0'
|
||||
uses: actions/podman-login@v1
|
||||
with:
|
||||
registry: ${{ vars.REGISTRY }}
|
||||
|
@ -61,26 +55,31 @@ jobs:
|
|||
auth_file_path: /tmp/auth.json
|
||||
|
||||
- name: Build the container image
|
||||
if: env.DISTANCE == '0'
|
||||
uses: actions/buildah-build@v1
|
||||
with:
|
||||
image: oidc-fastapi-test
|
||||
oci: true
|
||||
labels: oidc-fastapi-test
|
||||
tags: latest ${{ steps.version.outputs.version }}
|
||||
tags: "latest ${{ env.VERSION }}"
|
||||
containerfiles: |
|
||||
./Containerfile
|
||||
|
||||
- name: Push the image to the registry
|
||||
if: env.DISTANCE == '0'
|
||||
uses: actions/push-to-registry@v2
|
||||
with:
|
||||
registry: "docker://${{ vars.REGISTRY }}/${{ vars.ORGANISATION }}"
|
||||
image: oidc-fastapi-test
|
||||
tags: latest ${{ steps.version.outputs.version }}
|
||||
tags: "latest ${{ env.VERSION }}"
|
||||
|
||||
- name: Build wheel
|
||||
if: env.DISTANCE == '0'
|
||||
run: uv build --wheel
|
||||
|
||||
- name: Publish Python package (home)
|
||||
if: env.DISTANCE == '0'
|
||||
env:
|
||||
LOCAL_PYPI_TOKEN: ${{ secrets.LOCAL_PYPI_TOKEN }}
|
||||
run: uv publish --publish-url https://code.philo.ydns.eu/api/packages/philorg/pypi --token $LOCAL_PYPI_TOKEN
|
||||
continue-on-error: true
|
||||
|
|
|
@ -19,7 +19,7 @@ jobs:
|
|||
- name: Install the latest version of uv
|
||||
uses: astral-sh/setup-uv@v4
|
||||
with:
|
||||
version: "0.5.16"
|
||||
version: "0.6.3"
|
||||
|
||||
- name: Install
|
||||
run: uv sync
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
FROM docker.io/library/python:alpine
|
||||
FROM docker.io/library/python:latest
|
||||
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /usr/local/bin/
|
||||
|
||||
|
@ -9,6 +9,9 @@ WORKDIR /app
|
|||
|
||||
RUN uv pip install --system .
|
||||
|
||||
# Add demo plugin
|
||||
RUN PIP_EXTRA_INDEX_URL=https://pypi.org/simple/ uv pip install --system --index-url https://code.philo.ydns.eu/api/packages/philorg/pypi/simple/ oidc-fastapi-test-resource-provider-demo
|
||||
|
||||
# Possible to run with:
|
||||
#CMD ["oidc-test", "--port", "80"]
|
||||
#CMD ["fastapi", "run", "src/oidc_test/main.py", "--port", "8873", "--root-path", "/oidc-test"]
|
||||
|
|
60
README.md
60
README.md
|
@ -52,31 +52,59 @@ given by the OIDC providers.
|
|||
For example:
|
||||
|
||||
```yaml
|
||||
oidc:
|
||||
secret_key: "ASecretNoOneKnows"
|
||||
show_session_details: yes
|
||||
secret_key: AVeryWellKeptSecret
|
||||
debug_token: no
|
||||
show_token: yes
|
||||
log: yes
|
||||
|
||||
auth:
|
||||
providers:
|
||||
- id: auth0
|
||||
name: Okta / Auth0
|
||||
url: "https://<your_auth0_app_URL>"
|
||||
client_id: "<your_auth0_client_id>"
|
||||
client_secret: "client_secret_generated_by_auth0"
|
||||
hint: "A hint for test credentials"
|
||||
url: https://<your_auth0_app_URL>
|
||||
public_key_url: https://<your_auth0_app_URL>/pem
|
||||
client_id: <your_auth0_client_id>
|
||||
client_secret: client_secret_generated_by_auth0
|
||||
hint: A hint for test credentials
|
||||
|
||||
- id: keycloak
|
||||
name: Keycloak at somewhere
|
||||
url: "https://<the_keycloak_realm_url>"
|
||||
account_url_template: "/account"
|
||||
client_id: "<your_keycloak_client_id>"
|
||||
client_secret: "client_secret_generated_by_keycloak"
|
||||
hint: "User: foo, password: foofoo"
|
||||
url: https://<the_keycloak_realm_url>
|
||||
info_url: https://philo.ydns.eu/auth/realms/test
|
||||
account_url_template: /account
|
||||
client_id: <your_keycloak_client_id>
|
||||
client_secret: <client_secret_generated_by_keycloak>
|
||||
hint: A hint for test credentials
|
||||
code_challenge_method: S256
|
||||
resource_provider_scopes:
|
||||
- get:time
|
||||
- get:bs
|
||||
resource_providers:
|
||||
- id: <third_party_resource_provider_id>
|
||||
name: A third party resource provider
|
||||
base_url: https://some.example.com/
|
||||
verify_ssl: yes
|
||||
resources:
|
||||
- name: Public RS2
|
||||
resource_name: public
|
||||
url: resource/public
|
||||
- name: BS RS2
|
||||
resource_name: bs
|
||||
url: resource/bs
|
||||
- name: Time RS2
|
||||
resource_name: time
|
||||
url: resource/time
|
||||
|
||||
- id: codeberg
|
||||
disabled: no
|
||||
name: Codeberg
|
||||
url: "https://codeberg.org"
|
||||
account_url_template: "/user/settings"
|
||||
client_id: "<your_codeberg_client_id>"
|
||||
client_secret: "client_secret_generated_by_codeberg"
|
||||
url: https://codeberg.org
|
||||
account_url_template: /user/settings
|
||||
client_id: <your_codeberg_client_id>
|
||||
client_secret: client_secret_generated_by_codeberg
|
||||
info_url: https://codeberg.org/login/oauth/keys
|
||||
session_key: sub
|
||||
skip_verify_signature: no
|
||||
resources:
|
||||
- name: List of repos
|
||||
id: repos
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
[project]
|
||||
name = "oidc-fastapi-test"
|
||||
version = "0.0.0"
|
||||
# dynamic = ["version"]
|
||||
#version = "0.0.0"
|
||||
dynamic = ["version"]
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
|
@ -24,14 +24,24 @@ dependencies = [
|
|||
oidc-test = "oidc_test.main:main"
|
||||
|
||||
[dependency-groups]
|
||||
dev = ["ipdb>=0.13.13", "pytest>=8.3.4"]
|
||||
dev = ["dunamai>=1.23.0", "ipdb>=0.13.13", "pytest>=8.3.4"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
requires = ["hatchling", "uv-dynamic-versioning"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.version]
|
||||
source = "uv-dynamic-versioning"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/oidc_test"]
|
||||
package = true
|
||||
|
||||
[tool.uv-dynamic-versioning]
|
||||
style = "semver"
|
||||
|
||||
[tool.uv]
|
||||
package = true
|
||||
|
||||
[tool.black]
|
||||
line-length = 98
|
||||
|
|
|
@ -0,0 +1,11 @@
|
|||
import importlib.metadata
|
||||
|
||||
try:
|
||||
from dunamai import Version, Style
|
||||
|
||||
__version__ = Version.from_git().serialize(style=Style.SemVer, dirty=True)
|
||||
except ImportError:
|
||||
# __name__ could be used if the package name is the same
|
||||
# as the directory - not the case here
|
||||
# __version__ = importlib.metadata.version(__name__)
|
||||
__version__ = importlib.metadata.version("oidc-fastapi-test")
|
113
src/oidc_test/auth/provider.py
Normal file
113
src/oidc_test/auth/provider.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
from json import JSONDecodeError
|
||||
from typing import Any
|
||||
from jwt import decode
|
||||
import logging
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from httpx import AsyncClient
|
||||
|
||||
from oidc_test.settings import AuthProviderSettings, ResourceProvider, Resource, settings
|
||||
from oidc_test.models import User
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
||||
class NoPublicKey(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Provider(AuthProviderSettings):
|
||||
# To allow authlib_client as StarletteOAuth2App
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore
|
||||
|
||||
authlib_client: StarletteOAuth2App = StarletteOAuth2App(None)
|
||||
info: dict[str, Any] = {}
|
||||
unknown_auth_user: User
|
||||
logout_with_id_token_hint: bool = True
|
||||
|
||||
def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]:
|
||||
"""Decode the token with signature check"""
|
||||
if self.public_key is None:
|
||||
raise NoPublicKey
|
||||
if verify_signature is None:
|
||||
verify_signature = self.skip_verify_signature
|
||||
if settings.debug_token:
|
||||
decoded = decode(
|
||||
token,
|
||||
self.public_key,
|
||||
algorithms=[self.signature_alg],
|
||||
audience=["account", "oidc-test", "oidc-test-web"],
|
||||
options={
|
||||
"verify_signature": False,
|
||||
"verify_aud": False,
|
||||
}, # not settings.insecure.skip_verify_signature},
|
||||
)
|
||||
logger.debug(str(decoded))
|
||||
return decode(
|
||||
token,
|
||||
self.public_key,
|
||||
algorithms=[self.signature_alg],
|
||||
audience=["account", "oidc-test", "oidc-test-web"],
|
||||
options={
|
||||
"verify_signature": verify_signature,
|
||||
}, # not settings.insecure.skip_verify_signature},
|
||||
)
|
||||
|
||||
async def get_info(self):
|
||||
# Get the public key:
|
||||
async with AsyncClient() as client:
|
||||
public_key: str | None = None
|
||||
if self.info_url is not None:
|
||||
try:
|
||||
provider_info = await client.get(self.info_url)
|
||||
except Exception as err:
|
||||
logger.debug("Provider_info: cannot connect")
|
||||
logger.exception(err)
|
||||
raise NoPublicKey
|
||||
try:
|
||||
self.info = provider_info.json()
|
||||
except JSONDecodeError:
|
||||
logger.debug("Provider_info: cannot decode json response")
|
||||
raise NoPublicKey
|
||||
if "public_key" in self.info:
|
||||
# For Keycloak
|
||||
try:
|
||||
public_key = str(self.info["public_key"])
|
||||
except KeyError:
|
||||
logger.debug("Provider_info: cannot get public_key")
|
||||
raise NoPublicKey
|
||||
elif "keys" in self.info:
|
||||
# For Forgejo/Gitea
|
||||
try:
|
||||
public_key = str(self.info["keys"][0]["n"])
|
||||
except KeyError:
|
||||
logger.debug("Provider_info: cannot get key 0.n")
|
||||
raise NoPublicKey
|
||||
if self.public_key_url is not None:
|
||||
resp = await client.get(self.public_key_url)
|
||||
public_key = resp.text
|
||||
if public_key is None:
|
||||
logger.debug("Provider_info: cannot determine public key")
|
||||
raise NoPublicKey
|
||||
self.public_key = "\n".join(
|
||||
["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"]
|
||||
)
|
||||
|
||||
def get_session_key(self, userinfo):
|
||||
return userinfo[self.session_key]
|
||||
|
||||
def get_resource(self, resource_name: str) -> Resource:
|
||||
return [
|
||||
resource for resource in self.resources if resource.resource_name == resource_name
|
||||
][0]
|
||||
|
||||
def get_resource_url(self, resource_name: str) -> str:
|
||||
return self.url + self.get_resource(resource_name).url
|
||||
|
||||
def get_resource_provider(self, resource_provider_id: str) -> ResourceProvider:
|
||||
return [
|
||||
provider
|
||||
for provider in self.resource_providers
|
||||
if provider.id == resource_provider_id
|
||||
][0]
|
305
src/oidc_test/auth/utils.py
Normal file
305
src/oidc_test/auth/utils.py
Normal file
|
@ -0,0 +1,305 @@
|
|||
from typing import Union, Annotated
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
from fastapi import HTTPException, Request, Depends, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
|
||||
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||
from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError
|
||||
|
||||
from oidc_test.auth.provider import Provider
|
||||
from oidc_test.models import User
|
||||
from oidc_test.database import db, TokenNotInDb, UserNotInDB
|
||||
from oidc_test.settings import settings
|
||||
from oidc_test.auth_providers import providers
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
||||
async def fetch_token(name, request):
|
||||
assert name is not None
|
||||
assert request is not None
|
||||
logger.warning("TODO: fetch_token")
|
||||
...
|
||||
# if name in oidc_providers:
|
||||
# model = OAuth2Token
|
||||
# else:
|
||||
# model = OAuthToken
|
||||
|
||||
# token = model.find(name=name, user=request.user)
|
||||
# return token.to_token()
|
||||
|
||||
|
||||
async def update_token(
|
||||
provider_id,
|
||||
token,
|
||||
refresh_token: str | None = None,
|
||||
access_token: str | None = None,
|
||||
):
|
||||
"""Update the token in the database"""
|
||||
provider = providers[provider_id]
|
||||
sid: str = provider.get_session_key(provider.decode(token["id_token"]))
|
||||
item = await db.get_token(provider, sid)
|
||||
# update old token
|
||||
item["access_token"] = token["access_token"]
|
||||
item["refresh_token"] = token["refresh_token"]
|
||||
item["id_token"] = token["id_token"]
|
||||
item["expires_at"] = token["expires_at"]
|
||||
logger.info(f"Token {sid} refreshed")
|
||||
# It's a fake db and only in memory, so there's nothing to save
|
||||
# await item.save()
|
||||
|
||||
|
||||
def init_providers():
|
||||
"""Add oidc providers to authlib from the settings
|
||||
and build the providers dict"""
|
||||
for provider_settings in settings.auth.providers:
|
||||
provider_settings_dict = provider_settings.model_dump()
|
||||
# Add an anonymous user, that cannot be identified but has provided a valid access token
|
||||
provider_settings_dict["unknown_auth_user"] = User(
|
||||
sub="", auth_provider_id=provider_settings.id
|
||||
)
|
||||
provider = Provider(**provider_settings_dict)
|
||||
if provider.disabled:
|
||||
logger.info(f"{provider_settings.name} is disabled, skipping")
|
||||
else:
|
||||
authlib_oauth.register(
|
||||
name=provider.id,
|
||||
server_metadata_url=provider.openid_configuration,
|
||||
client_kwargs={
|
||||
"scope": " ".join(
|
||||
["openid", "email", "offline_access", "profile"]
|
||||
+ provider.resource_provider_scopes
|
||||
),
|
||||
},
|
||||
client_id=provider.client_id,
|
||||
client_secret=provider.client_secret,
|
||||
api_base_url=provider.url,
|
||||
# For PKCE (not implemented yet):
|
||||
# code_challenge_method="S256",
|
||||
fetch_token=fetch_token,
|
||||
update_token=update_token,
|
||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||
)
|
||||
provider.authlib_client = getattr(authlib_oauth, provider.id)
|
||||
providers[provider.id] = provider
|
||||
|
||||
|
||||
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
||||
|
||||
|
||||
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
|
||||
"""Return the oidc_provider from a request object, from the session.
|
||||
It can be used in Depends()"""
|
||||
if (auth_provider_id := request.session.get("auth_provider_id")) is None:
|
||||
return
|
||||
return getattr(authlib_oauth, str(auth_provider_id), None)
|
||||
|
||||
|
||||
def get_auth_provider_client(request: Request) -> StarletteOAuth2App:
|
||||
if (oidc_provider := get_auth_provider_client_or_none(request)) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
else:
|
||||
return oidc_provider
|
||||
|
||||
|
||||
def get_auth_provider_or_none(request: Request) -> Provider | None:
|
||||
"""Return the oidc_provider settings from a request object, from the session.
|
||||
It can be used in Depends()"""
|
||||
if (auth_provider_id := request.session.get("auth_provider_id")) is None:
|
||||
return
|
||||
return providers.get(auth_provider_id)
|
||||
|
||||
|
||||
def get_auth_provider(request: Request) -> Provider:
|
||||
if (provider := get_auth_provider_or_none(request)) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
return provider
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> User:
|
||||
"""Get the current user from a request object.
|
||||
Also validates the token expiration time.
|
||||
... TODO: complete about refresh token
|
||||
"""
|
||||
if (user_sub := request.session.get("user_sub")) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
token = await get_token_from_session(request)
|
||||
user = await db.get_user(user_sub)
|
||||
## Check if the token is expired
|
||||
if token.is_expired():
|
||||
provider = get_auth_provider(request=request)
|
||||
## Ask a new refresh token from the provider
|
||||
logger.info(f"Token expired for user {user.name}")
|
||||
try:
|
||||
userinfo = await provider.authlib_client.fetch_access_token(
|
||||
refresh_token=token.get("refresh_token")
|
||||
)
|
||||
assert userinfo is not None
|
||||
except OAuthError as err:
|
||||
logger.exception(err)
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh"
|
||||
# )
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_token_from_session_or_none(request: Request) -> OAuth2Token | None:
|
||||
"""Return the auth token from the session or None.
|
||||
Can be used in Depends()"""
|
||||
try:
|
||||
return await get_token_from_session(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
async def get_token_from_session(request: Request) -> OAuth2Token:
|
||||
"""Return the token from the session.
|
||||
Can be used in Depends()"""
|
||||
try:
|
||||
provider = providers[request.session.get("auth_provider_id", "")]
|
||||
except KeyError:
|
||||
request.session.pop("auth_provider_id", None)
|
||||
request.session.pop("user_sub", None)
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider")
|
||||
try:
|
||||
return await db.get_token(
|
||||
provider,
|
||||
request.session.get("sid"),
|
||||
)
|
||||
except (TokenNotInDb, InvalidKeyError, DecodeError) as err:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, err.__class__.__name__)
|
||||
|
||||
|
||||
async def get_current_user_or_none(request: Request) -> User | None:
|
||||
"""Return the user from a request object, from the session.
|
||||
It can be used in Depends()"""
|
||||
try:
|
||||
return await get_current_user(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
def hasrole(required_roles: Union[str, list[str]] = []):
|
||||
"""Decorator for RBAC permissions"""
|
||||
required_roles_set: set[str]
|
||||
if isinstance(required_roles, str):
|
||||
required_roles_set = set([required_roles])
|
||||
else:
|
||||
required_roles_set = set(required_roles)
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(request=None, *args, **kwargs):
|
||||
if request is None:
|
||||
raise HTTPException(
|
||||
500,
|
||||
"Functions decorated with hasrole must have a request:Request argument",
|
||||
)
|
||||
user: User = await get_current_user(request)
|
||||
if not any(required_roles_set.intersection(user.roles_as_set)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return await func(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_token_info(token: dict) -> dict:
|
||||
token_info = dict()
|
||||
for key in token:
|
||||
if key != "userinfo":
|
||||
token_info[key] = token[key]
|
||||
return token_info
|
||||
|
||||
|
||||
async def get_user_from_token(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
request: Request,
|
||||
) -> User:
|
||||
try:
|
||||
auth_provider_id = request.headers["auth_provider"]
|
||||
except KeyError:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
"Request headers must have a 'auth_provider' field",
|
||||
)
|
||||
try:
|
||||
provider = providers[auth_provider_id]
|
||||
except KeyError:
|
||||
if auth_provider_id == "":
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth provider")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
|
||||
)
|
||||
if token == "None":
|
||||
request.session.pop("auth_provider_id", None)
|
||||
request.session.pop("user_sub", None)
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
|
||||
try:
|
||||
payload = provider.decode(token)
|
||||
except ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
"Expired signature (token refresh not implemented yet)",
|
||||
)
|
||||
except InvalidKeyError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
|
||||
except PyJWTError as err:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, f"Cannot decode token: {err.__class__.__name__}"
|
||||
)
|
||||
try:
|
||||
user_id = payload["sub"]
|
||||
except KeyError:
|
||||
return provider.unknown_auth_user
|
||||
try:
|
||||
user = await db.get_user(user_id)
|
||||
if user.access_token != token:
|
||||
user.access_token = token
|
||||
except UserNotInDB:
|
||||
logger.info(
|
||||
f"User {user_id} not found in DB, creating it (real apps can behave differently)"
|
||||
)
|
||||
user = await db.add_user(
|
||||
sub=payload["sub"],
|
||||
user_info=payload,
|
||||
auth_provider=providers[auth_provider_id],
|
||||
access_token=token,
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
async def get_user_from_token_or_none(
|
||||
token: Annotated[str | None, Depends(oauth2_scheme_optional)],
|
||||
request: Request,
|
||||
) -> User | None:
|
||||
if token is None:
|
||||
return None
|
||||
try:
|
||||
return await get_user_from_token(token, request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
class UserWithRole:
|
||||
roles: set[str]
|
||||
|
||||
def __init__(self, roles: str | list[str] | tuple[str] | set[str]):
|
||||
if isinstance(roles, str):
|
||||
self.roles = set([roles])
|
||||
elif isinstance(roles, (list, tuple, set)):
|
||||
self.roles = set(roles)
|
||||
|
||||
def __call__(self, user: User = Depends(get_user_from_token)) -> User:
|
||||
if not any(self.roles.intersection(user.roles_as_set)):
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, f"Not of any required role {', '.join(self.roles)}"
|
||||
)
|
||||
return user
|
|
@ -1,29 +0,0 @@
|
|||
from datetime import datetime, timedelta
|
||||
from collections import OrderedDict
|
||||
|
||||
from .models import User
|
||||
|
||||
time_keys = set(("iat", "exp", "auth_time", "updated_at"))
|
||||
|
||||
|
||||
def pretty_details(user: User, now: datetime) -> OrderedDict:
|
||||
details = OrderedDict()
|
||||
# breakpoint()
|
||||
for key in sorted(time_keys):
|
||||
try:
|
||||
dt = datetime.fromtimestamp(user.userinfo[key])
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
else:
|
||||
td = now - dt
|
||||
td = timedelta(days=td.days, seconds=td.seconds)
|
||||
if td.days < 0:
|
||||
ptd = f"in {-td} h:m:s"
|
||||
else:
|
||||
ptd = f"{td} h:m:s ago"
|
||||
details[key] = f"{user.userinfo[key]} - {dt} ({ptd})"
|
||||
for key in sorted(user.userinfo):
|
||||
if key in time_keys:
|
||||
continue
|
||||
details[key] = user.userinfo[key]
|
||||
return details
|
5
src/oidc_test/auth_providers.py
Normal file
5
src/oidc_test/auth_providers.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
from collections import OrderedDict
|
||||
|
||||
from oidc_test.auth.provider import Provider
|
||||
|
||||
providers: OrderedDict[str, Provider] = OrderedDict()
|
|
@ -1,227 +0,0 @@
|
|||
from typing import Union, Annotated
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
from fastapi import HTTPException, Request, Depends, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
|
||||
from jwt import ExpiredSignatureError, InvalidKeyError
|
||||
from httpx import AsyncClient
|
||||
|
||||
# from authlib.oauth1.auth import OAuthToken
|
||||
# from authlib.oauth2.auth import OAuth2Token
|
||||
|
||||
from .models import User
|
||||
from .database import TokenNotInDb, db, UserNotInDB
|
||||
from .settings import settings, OIDCProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
oidc_providers_settings: dict[str, OIDCProvider] = dict(
|
||||
[(provider.id, provider) for provider in settings.oidc.providers]
|
||||
)
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
async def fetch_token(name, request):
|
||||
logger.warn("TODO: fetch_token")
|
||||
...
|
||||
# if name in oidc_providers:
|
||||
# model = OAuth2Token
|
||||
# else:
|
||||
# model = OAuthToken
|
||||
|
||||
# token = model.find(name=name, user=request.user)
|
||||
# return token.to_token()
|
||||
|
||||
|
||||
async def update_token(*args, **kwargs):
|
||||
logger.warn("TODO: update_token")
|
||||
...
|
||||
|
||||
|
||||
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
||||
|
||||
|
||||
def init_providers():
|
||||
# Add oidc providers to authlib from the settings
|
||||
for id, provider in oidc_providers_settings.items():
|
||||
authlib_oauth.register(
|
||||
name=id,
|
||||
server_metadata_url=provider.openid_configuration,
|
||||
client_kwargs={
|
||||
"scope": "openid email offline_access profile",
|
||||
},
|
||||
client_id=provider.client_id,
|
||||
client_secret=provider.client_secret,
|
||||
api_base_url=provider.url,
|
||||
# For PKCE (not implemented yet):
|
||||
# code_challenge_method="S256",
|
||||
# fetch_token=fetch_token,
|
||||
# update_token=update_token,
|
||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||
)
|
||||
|
||||
|
||||
init_providers()
|
||||
|
||||
|
||||
async def get_providers_info():
|
||||
# Get the public key:
|
||||
async with AsyncClient() as client:
|
||||
for provider_settings in oidc_providers_settings.values():
|
||||
if provider_settings.info_url:
|
||||
provider_info = await client.get(provider_settings.url)
|
||||
provider_settings.info = provider_info.json()
|
||||
|
||||
|
||||
def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None:
|
||||
"""Return the oidc_provider from a request object, from the session.
|
||||
It can be used in Depends()"""
|
||||
if (oidc_provider_id := request.session.get("oidc_provider_id")) is None:
|
||||
return
|
||||
try:
|
||||
return getattr(authlib_oauth, str(oidc_provider_id))
|
||||
except AttributeError:
|
||||
return
|
||||
|
||||
|
||||
def get_oidc_provider(request: Request) -> StarletteOAuth2App:
|
||||
if (oidc_provider := get_oidc_provider_or_none(request)) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
else:
|
||||
return oidc_provider
|
||||
|
||||
|
||||
async def get_current_user(request: Request) -> User:
|
||||
"""Get the current user from a request object.
|
||||
Also validates the token expiration time.
|
||||
... TODO: complete about refresh token
|
||||
"""
|
||||
if (user_sub := request.session.get("user_sub")) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
if (token := await db.get_token(request.session["token"])) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown")
|
||||
user = await db.get_user(user_sub)
|
||||
## Check if the token is expired
|
||||
if token.is_expired():
|
||||
oidc_provider = get_oidc_provider(request=request)
|
||||
## Ask a new refresh token from the provider
|
||||
logger.info(f"Token expired for user {user.name}")
|
||||
try:
|
||||
userinfo = await oidc_provider.fetch_access_token(
|
||||
refresh_token=token.get("refresh_token")
|
||||
)
|
||||
except OAuthError as err:
|
||||
logger.exception(err)
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED, "Token expired, cannot refresh"
|
||||
# )
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_token(request: Request) -> OAuth2Token:
|
||||
"""Return the token from a request object, from the session.
|
||||
It can be used in Depends()"""
|
||||
try:
|
||||
return await db.get_token(request.session["token"])
|
||||
except (KeyError, TokenNotInDb):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
|
||||
|
||||
|
||||
async def get_current_user_or_none(request: Request) -> User | None:
|
||||
"""Return the user from a request object, from the session.
|
||||
It can be used in Depends()"""
|
||||
try:
|
||||
return await get_current_user(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
def hasrole(required_roles: Union[str, list[str]] = []):
|
||||
"""Decorator for RBAC permissions"""
|
||||
required_roles_set: set[str]
|
||||
if isinstance(required_roles, str):
|
||||
required_roles_set = set([required_roles])
|
||||
else:
|
||||
required_roles_set = set(required_roles)
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(request=None, *args, **kwargs):
|
||||
if request is None:
|
||||
raise HTTPException(
|
||||
500,
|
||||
"Functions decorated with hasrole must have a request:Request argument",
|
||||
)
|
||||
user: User = await get_current_user(request)
|
||||
if not any(required_roles_set.intersection(user.roles_as_set)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return await func(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_token_info(token: dict) -> dict:
|
||||
token_info = dict()
|
||||
for key in token:
|
||||
if key != "userinfo":
|
||||
token_info[key] = token[key]
|
||||
return token_info
|
||||
|
||||
|
||||
async def get_user_from_token(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
request: Request,
|
||||
) -> User:
|
||||
if (auth_provider_id := request.headers.get("auth_provider")) is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
"Request headers must have a 'auth_provider' field",
|
||||
)
|
||||
if (
|
||||
auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
|
||||
) is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
|
||||
)
|
||||
try:
|
||||
payload = auth_provider_settings.decode(token)
|
||||
except ExpiredSignatureError as err:
|
||||
logger.info(f"Expired signature: {err}")
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
"Expired signature (refresh not implemented yet)",
|
||||
)
|
||||
except InvalidKeyError as err:
|
||||
logger.info(f"Invalid key: {err}")
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
|
||||
except Exception as err:
|
||||
logger.info("Cannot decode token, see below")
|
||||
logger.exception(err)
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
|
||||
if (user_id := payload.get("sub")) is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found"
|
||||
)
|
||||
try:
|
||||
user = await db.get_user(user_id)
|
||||
if user.access_token != token:
|
||||
user.access_token = token
|
||||
except UserNotInDB:
|
||||
logger.info(
|
||||
f"User {user_id} not found in DB, creating it (real apps can behave differently"
|
||||
)
|
||||
user = await db.add_user(
|
||||
sub=payload["sub"],
|
||||
user_info=payload,
|
||||
oidc_provider=getattr(authlib_oauth, auth_provider_id),
|
||||
user_info_from_endpoint={},
|
||||
access_token=token,
|
||||
)
|
||||
return user
|
|
@ -2,12 +2,15 @@
|
|||
|
||||
import logging
|
||||
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
|
||||
from .models import User, Role
|
||||
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||
from jwt import PyJWTError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from oidc_test.auth.provider import Provider
|
||||
|
||||
from oidc_test.models import User, Role
|
||||
from oidc_test.auth_providers import providers
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
||||
class UserNotInDB(Exception):
|
||||
|
@ -20,6 +23,7 @@ class TokenNotInDb(Exception):
|
|||
|
||||
class Database:
|
||||
users: dict[str, User] = {}
|
||||
# TODO: key of the token table should be provider: sid
|
||||
tokens: dict[str, OAuth2Token] = {}
|
||||
|
||||
# Last sessions for the user (key: users's subject id (sub))
|
||||
|
@ -28,21 +32,38 @@ class Database:
|
|||
self,
|
||||
sub: str,
|
||||
user_info: dict,
|
||||
oidc_provider: StarletteOAuth2App,
|
||||
user_info_from_endpoint: dict,
|
||||
auth_provider: Provider,
|
||||
access_token: str,
|
||||
access_token_decoded: dict | None = None,
|
||||
) -> User:
|
||||
user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider)
|
||||
user.access_token = access_token
|
||||
if access_token_decoded is None:
|
||||
assert auth_provider.name is not None
|
||||
provider = providers[auth_provider.id]
|
||||
try:
|
||||
raw_roles = user_info_from_endpoint["resource_access"][
|
||||
oidc_provider.client_id
|
||||
]["roles"]
|
||||
except Exception as err:
|
||||
logger.debug(f"Cannot read additional roles: {err}")
|
||||
raw_roles = []
|
||||
for raw_role in raw_roles:
|
||||
user.roles.append(Role(name=raw_role))
|
||||
access_token_decoded = provider.decode(access_token)
|
||||
except PyJWTError:
|
||||
access_token_decoded = {}
|
||||
user_info["auth_provider_id"] = auth_provider.id
|
||||
user = User(**user_info)
|
||||
user.userinfo = user_info
|
||||
# user.access_token = access_token
|
||||
# user.access_token_decoded = access_token_decoded
|
||||
# Add roles provided in the access token
|
||||
roles = set()
|
||||
try:
|
||||
r = access_token_decoded["resource_access"][auth_provider.client_id]["roles"]
|
||||
roles.update(r)
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
r = access_token_decoded["realm_access"]["roles"]
|
||||
if isinstance(r, str):
|
||||
roles.add(r)
|
||||
else:
|
||||
roles.update(r)
|
||||
except KeyError:
|
||||
pass
|
||||
user.roles = [Role(name=role_name) for role_name in roles]
|
||||
self.users[sub] = user
|
||||
return user
|
||||
|
||||
|
@ -51,12 +72,23 @@ class Database:
|
|||
raise UserNotInDB
|
||||
return self.users[sub]
|
||||
|
||||
async def add_token(self, token: OAuth2Token, user: User) -> None:
|
||||
self.tokens[token["id_token"]] = token
|
||||
async def add_token(self, provider: Provider, token: OAuth2Token) -> None:
|
||||
"""Store a token using as key the sid (auth provider's session id)
|
||||
in the id_token"""
|
||||
sid = provider.get_session_key(token["userinfo"])
|
||||
self.tokens[sid] = token
|
||||
|
||||
async def get_token(self, id_token: str) -> OAuth2Token:
|
||||
async def get_token(
|
||||
self,
|
||||
provider: Provider,
|
||||
sid: str | None,
|
||||
) -> OAuth2Token:
|
||||
# TODO: key of the token table should be provider: sid
|
||||
assert isinstance(provider, Provider)
|
||||
if sid is None:
|
||||
raise TokenNotInDb
|
||||
try:
|
||||
return self.tokens[id_token]
|
||||
return self.tokens[sid]
|
||||
except KeyError:
|
||||
raise TokenNotInDb
|
||||
|
||||
|
|
34
src/oidc_test/log_conf.yaml
Normal file
34
src/oidc_test/log_conf.yaml
Normal file
|
@ -0,0 +1,34 @@
|
|||
version: 1
|
||||
disable_existing_loggers: False
|
||||
formatters:
|
||||
default:
|
||||
"()": uvicorn.logging.DefaultFormatter
|
||||
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
access:
|
||||
"()": uvicorn.logging.AccessFormatter
|
||||
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
handlers:
|
||||
default:
|
||||
formatter: default
|
||||
class: logging.StreamHandler
|
||||
stream: ext://sys.stderr
|
||||
access:
|
||||
formatter: access
|
||||
class: logging.StreamHandler
|
||||
stream: ext://sys.stdout
|
||||
loggers:
|
||||
uvicorn.error:
|
||||
level: INFO
|
||||
handlers:
|
||||
- default
|
||||
propagate: no
|
||||
uvicorn.access:
|
||||
level: INFO
|
||||
handlers:
|
||||
- access
|
||||
propagate: no
|
||||
"oidc-test":
|
||||
level: DEBUG
|
||||
handlers:
|
||||
- default
|
||||
propagate: yes
|
|
@ -6,15 +6,19 @@ from typing import Annotated
|
|||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import logging.config
|
||||
import importlib.resources
|
||||
from yaml import safe_load
|
||||
from urllib.parse import urlencode
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from httpx import HTTPError
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from jwt import PyJWTError
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from authlib.integrations.base_client import OAuthError
|
||||
|
@ -25,32 +29,52 @@ from authlib.oauth2.rfc6749 import OAuth2Token
|
|||
# from fastapi.security import OpenIdConnect
|
||||
# from pkce import generate_code_verifier, generate_pkce_pair
|
||||
|
||||
from .settings import settings
|
||||
from .models import User
|
||||
from .auth_utils import (
|
||||
get_oidc_provider,
|
||||
get_oidc_provider_or_none,
|
||||
hasrole,
|
||||
from oidc_test import __version__
|
||||
from oidc_test.registry import registry
|
||||
from oidc_test.auth.provider import NoPublicKey, Provider
|
||||
from oidc_test.auth.utils import (
|
||||
get_auth_provider,
|
||||
get_auth_provider_or_none,
|
||||
get_current_user_or_none,
|
||||
get_current_user,
|
||||
get_user_from_token,
|
||||
authlib_oauth,
|
||||
get_token,
|
||||
oidc_providers_settings,
|
||||
get_providers_info,
|
||||
get_token_from_session_or_none,
|
||||
get_token_from_session,
|
||||
update_token,
|
||||
)
|
||||
from .auth_misc import pretty_details
|
||||
from .database import db
|
||||
from .resource_server import get_resource
|
||||
from oidc_test.auth.utils import init_providers
|
||||
from oidc_test.settings import settings
|
||||
from oidc_test.auth_providers import providers
|
||||
from oidc_test.models import User
|
||||
from oidc_test.database import TokenNotInDb, db
|
||||
from oidc_test.resource_server import resource_server
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
if settings.log:
|
||||
assert __package__ is not None
|
||||
with (
|
||||
importlib.resources.path(__package__) as package_path,
|
||||
open(package_path / settings.log_config_file) as f,
|
||||
):
|
||||
logging_config = safe_load(f)
|
||||
logging.config.dictConfig(logging_config)
|
||||
|
||||
templates = Jinja2Templates(Path(__file__).parent / "templates")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await get_providers_info()
|
||||
assert app is not None
|
||||
init_providers()
|
||||
registry.make_registry()
|
||||
for provider in list(providers.values()):
|
||||
if provider.disabled:
|
||||
continue
|
||||
try:
|
||||
await provider.get_info()
|
||||
except NoPublicKey:
|
||||
logger.warning(f"Disable {provider.id}: public key not found")
|
||||
del providers[provider.id]
|
||||
yield
|
||||
|
||||
|
||||
|
@ -64,74 +88,79 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.mount(
|
||||
"/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static"
|
||||
)
|
||||
|
||||
# SessionMiddleware is required by authlib
|
||||
app.add_middleware(
|
||||
SessionMiddleware,
|
||||
secret_key=settings.secret_key,
|
||||
)
|
||||
|
||||
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
|
||||
app.mount("/resource", resource_server, name="resource_server")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def home(
|
||||
request: Request,
|
||||
user: Annotated[User, Depends(get_current_user_or_none)],
|
||||
oidc_provider: Annotated[
|
||||
StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)
|
||||
],
|
||||
user: Annotated[User | None, Depends(get_current_user_or_none)],
|
||||
provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)],
|
||||
token: Annotated[OAuth2Token | None, Depends(get_token_from_session_or_none)],
|
||||
) -> HTMLResponse:
|
||||
now = datetime.now()
|
||||
if oidc_provider and (
|
||||
(
|
||||
oidc_provider_settings := oidc_providers_settings.get(
|
||||
request.session.get("oidc_provider_id", "")
|
||||
)
|
||||
)
|
||||
is not None
|
||||
):
|
||||
resources = oidc_provider_settings.resources
|
||||
else:
|
||||
resources = []
|
||||
oidc_provider_settings = None
|
||||
return templates.TemplateResponse(
|
||||
name="home.html",
|
||||
request=request,
|
||||
context={
|
||||
"settings": settings.model_dump(),
|
||||
context = {
|
||||
"show_token": settings.show_token,
|
||||
"user": user,
|
||||
"now": now,
|
||||
"oidc_provider": oidc_provider,
|
||||
"oidc_provider_settings": oidc_provider_settings,
|
||||
"resources": resources,
|
||||
"user_info_details": (
|
||||
pretty_details(user, now)
|
||||
if user and settings.oidc.show_session_details
|
||||
else None
|
||||
),
|
||||
},
|
||||
)
|
||||
"now": datetime.now(),
|
||||
"__version__": __version__,
|
||||
}
|
||||
if provider is None or token is None:
|
||||
context["providers"] = providers
|
||||
context["access_token"] = None
|
||||
context["id_token_parsed"] = None
|
||||
context["access_token_parsed"] = None
|
||||
context["refresh_token_parsed"] = None
|
||||
context["resources"] = None
|
||||
context["auth_provider"] = None
|
||||
else:
|
||||
context["auth_provider"] = provider
|
||||
context["access_token"] = token["access_token"]
|
||||
try:
|
||||
access_token_parsed = provider.decode(token["access_token"], verify_signature=False)
|
||||
context["access_token_parsed"] = access_token_parsed
|
||||
context["access_token_scope"] = access_token_parsed.get("scope")
|
||||
except PyJWTError as err:
|
||||
context["access_token_parsed"] = {"Cannot parse": err.__class__.__name__}
|
||||
context["access_token_scope"] = None
|
||||
try:
|
||||
id_token_parsed = provider.decode(token["id_token"], verify_signature=False)
|
||||
context["id_token_parsed"] = id_token_parsed
|
||||
except PyJWTError as err:
|
||||
context["id_token_parsed"] = {"Cannot parse": err.__class__.__name__}
|
||||
try:
|
||||
refresh_token_parsed = provider.decode(token["refresh_token"], verify_signature=False)
|
||||
context["refresh_token_parsed"] = refresh_token_parsed
|
||||
except PyJWTError as err:
|
||||
context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__}
|
||||
context["resources"] = registry.resources
|
||||
context["resource_providers"] = provider.resource_providers
|
||||
return templates.TemplateResponse(name="home.html", request=request, context=context)
|
||||
|
||||
|
||||
# Endpoints for the login / authorization process
|
||||
|
||||
|
||||
@app.get("/login/{oidc_provider_id}")
|
||||
async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||
@app.get("/login/{auth_provider_id}")
|
||||
async def login(request: Request, auth_provider_id: str) -> RedirectResponse:
|
||||
"""Login with the provider id, giving the browser a redirect to its authorize page.
|
||||
The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url
|
||||
The provider is expected to send the browser back to our own /auth/{auth_provider_id} url
|
||||
with the token.
|
||||
"""
|
||||
redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id)
|
||||
redirect_uri = request.url_for("auth", auth_provider_id=auth_provider_id)
|
||||
try:
|
||||
provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
|
||||
provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
|
||||
except AttributeError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
# if (
|
||||
# code_challenge_method := oidc_providers_settings[
|
||||
# oidc_provider_id
|
||||
# code_challenge_method := providers[
|
||||
# auth_provider_id
|
||||
# ].code_challenge_method
|
||||
# ) is not None:
|
||||
# #client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method)
|
||||
|
@ -151,217 +180,144 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
|||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
|
||||
|
||||
|
||||
@app.get("/auth/{oidc_provider_id}")
|
||||
async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||
@app.get("/auth/{auth_provider_id}")
|
||||
async def auth(
|
||||
request: Request,
|
||||
auth_provider_id: str,
|
||||
) -> RedirectResponse:
|
||||
"""Decrypt the auth token, store it to the session (cookie based)
|
||||
and response to the browser with a redirect to a "welcome user" page.
|
||||
"""
|
||||
try:
|
||||
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id)
|
||||
except AttributeError:
|
||||
provider = providers[auth_provider_id]
|
||||
except KeyError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
|
||||
try:
|
||||
token: OAuth2Token = await oidc_provider.authorize_access_token(request)
|
||||
token: OAuth2Token = await provider.authlib_client.authorize_access_token(request)
|
||||
except OAuthError as error:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
|
||||
# Remember the oidc_provider in the session
|
||||
# Remember the authlib_client in the session
|
||||
# logger.info(f"Scope: {token['scope']}")
|
||||
request.session["oidc_provider_id"] = oidc_provider_id
|
||||
request.session["auth_provider_id"] = auth_provider_id
|
||||
#
|
||||
# One could process the full decoded token which contains extra information
|
||||
# eg for updates. Here we are only interested in roles
|
||||
#
|
||||
if userinfo := token.get("userinfo"):
|
||||
# Remember the oidc_provider in the session
|
||||
request.session["oidc_provider_id"] = oidc_provider_id
|
||||
# User id (sub) given by oidc provider
|
||||
# Remember the authlib_client in the session
|
||||
request.session["auth_provider_id"] = auth_provider_id
|
||||
# User id (sub) given by auth provider
|
||||
sub = userinfo["sub"]
|
||||
# Get additional data from userinfo endpoint
|
||||
try:
|
||||
user_info_from_endpoint = await oidc_provider.userinfo(
|
||||
token=token, follow_redirects=True
|
||||
)
|
||||
except Exception as err:
|
||||
logger.warn(f"Cannot get userinfo from endpoint: {err}")
|
||||
user_info_from_endpoint = {}
|
||||
## Get additional data from userinfo endpoint
|
||||
# try:
|
||||
# user_info_from_endpoint = await authlib_client.userinfo(
|
||||
# token=token, follow_redirects=True
|
||||
# )
|
||||
# except Exception as err:
|
||||
# logger.warn(f"Cannot get userinfo from endpoint: {err}")
|
||||
# user_info_from_endpoint = {}
|
||||
# Build and remember the user in the session
|
||||
request.session["user_sub"] = sub
|
||||
# Store the user in the database
|
||||
# Store the user in the database, which also verifies the token validity and signature
|
||||
try:
|
||||
oidc_provider_settings = oidc_providers_settings[oidc_provider_id]
|
||||
access_token = oidc_provider_settings.decode(token["access_token"])
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Cannot decode token or verify its signature",
|
||||
)
|
||||
user = await db.add_user(
|
||||
sub,
|
||||
user_info=userinfo,
|
||||
oidc_provider=oidc_provider,
|
||||
user_info_from_endpoint=user_info_from_endpoint,
|
||||
auth_provider=providers[auth_provider_id],
|
||||
access_token=token["access_token"],
|
||||
)
|
||||
# Add the id_token to the session
|
||||
request.session["token"] = token["id_token"]
|
||||
except PyJWTError as err:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Token invalid: {err.__class__.__name__}",
|
||||
)
|
||||
assert isinstance(user, User)
|
||||
# Add the provider session id to the session
|
||||
request.session["sid"] = provider.get_session_key(userinfo)
|
||||
# Add the token to the db because it is used for logout
|
||||
await db.add_token(token, user)
|
||||
await db.add_token(provider, token)
|
||||
# Send the user to the home: (s)he is authenticated
|
||||
return RedirectResponse(url=request.url_for("home"))
|
||||
else:
|
||||
# Not sure if it's correct to redirect to plain login
|
||||
# if no userinfo is provided
|
||||
return RedirectResponse(
|
||||
url=request.url_for("login", oidc_provider_id=oidc_provider_id)
|
||||
)
|
||||
return RedirectResponse(url=request.url_for("login", auth_provider_id=auth_provider_id))
|
||||
|
||||
|
||||
@app.get("/account")
|
||||
async def account(
|
||||
request: Request,
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
provider: Annotated[Provider, Depends(get_auth_provider)],
|
||||
) -> RedirectResponse:
|
||||
if (
|
||||
oidc_provider_settings := oidc_providers_settings.get(
|
||||
request.session.get("oidc_provider_id", "")
|
||||
)
|
||||
) is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings"
|
||||
)
|
||||
return RedirectResponse(f"{oidc_provider_settings.account_url_template}")
|
||||
"""Redirect to the auth provider account management,
|
||||
if account_url_template is in the provider's settings"""
|
||||
return RedirectResponse(f"{provider.account_url_template}")
|
||||
|
||||
|
||||
@app.get("/logout")
|
||||
async def logout(
|
||||
request: Request,
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
provider: Annotated[Provider, Depends(get_auth_provider)],
|
||||
) -> RedirectResponse:
|
||||
# Clear session
|
||||
request.session.pop("user_sub", None)
|
||||
# Get provider's endpoint
|
||||
if (
|
||||
provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")
|
||||
provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint")
|
||||
) is None:
|
||||
logger.warn(
|
||||
f"Cannot find end_session_endpoint for provider {oidc_provider.name}"
|
||||
)
|
||||
logger.warning(f"Cannot find end_session_endpoint for provider {provider.id}")
|
||||
return RedirectResponse(request.url_for("non_compliant_logout"))
|
||||
post_logout_uri = request.url_for("home")
|
||||
if (token := await db.get_token(request.session.pop("token", None))) is None:
|
||||
logger.warn("No session in db for the token")
|
||||
# Clear session
|
||||
request.session.pop("user_sub", None)
|
||||
request.session.pop("auth_provider_id", None)
|
||||
try:
|
||||
token = await db.get_token(provider, request.session.pop("sid", None))
|
||||
except TokenNotInDb:
|
||||
logger.warning("No session in db for the token or no token")
|
||||
return RedirectResponse(request.url_for("home"))
|
||||
logout_url = (
|
||||
provider_logout_uri
|
||||
+ "?"
|
||||
+ urlencode(
|
||||
{
|
||||
url_query = {
|
||||
"post_logout_redirect_uri": post_logout_uri,
|
||||
"id_token_hint": token["id_token"],
|
||||
"cliend_id": "oidc_local_test",
|
||||
"client_id": provider.client_id,
|
||||
}
|
||||
)
|
||||
)
|
||||
if provider.logout_with_id_token_hint:
|
||||
url_query["id_token_hint"] = token["id_token"]
|
||||
logout_url = f"{provider_logout_uri}?{urlencode(url_query)}"
|
||||
return RedirectResponse(logout_url)
|
||||
|
||||
|
||||
@app.get("/non-compliant-logout")
|
||||
async def non_compliant_logout(
|
||||
request: Request,
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
provider: Annotated[Provider, Depends(get_auth_provider)],
|
||||
):
|
||||
"""A page for non-compliant OAuth2 servers that we cannot log out."""
|
||||
# Clear the remain of the session
|
||||
request.session.pop("oidc_provider_id", None)
|
||||
# Clear session
|
||||
request.session.pop("user_sub", None)
|
||||
request.session.pop("auth_provider_id", None)
|
||||
return templates.TemplateResponse(
|
||||
name="non_compliant_logout.html",
|
||||
request=request,
|
||||
context={"oidc_provider": oidc_provider, "home_url": request.url_for("home")},
|
||||
context={"auth_provider": provider, "home_url": request.url_for("home")},
|
||||
)
|
||||
|
||||
|
||||
# Route for OAuth resource server
|
||||
|
||||
|
||||
@app.get("/resource/{id}")
|
||||
async def get_resource_(
|
||||
id: str,
|
||||
# user: Annotated[User, Depends(get_current_user)],
|
||||
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
# token: Annotated[OAuth2Token, Depends(get_token)],
|
||||
user: Annotated[User, Depends(get_user_from_token)],
|
||||
) -> JSONResponse:
|
||||
"""Generic path for testing a resource provided by a provider"""
|
||||
return JSONResponse(await get_resource(id, user))
|
||||
|
||||
|
||||
# Routes for RBAC based tests
|
||||
|
||||
|
||||
@app.get("/public")
|
||||
async def public() -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Not protected</h1>")
|
||||
|
||||
|
||||
@app.get("/protected")
|
||||
async def get_protected(
|
||||
user: Annotated[User, Depends(get_current_user)]
|
||||
) -> HTMLResponse:
|
||||
assert user is not None # Just to keep QA checks happy
|
||||
return HTMLResponse("<h1>Only authenticated users can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-foorole")
|
||||
@hasrole("foorole")
|
||||
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
|
||||
assert request is not None # Just to keep QA checks happy
|
||||
return HTMLResponse("<h1>Only users with foorole can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-barrole")
|
||||
@hasrole("barrole")
|
||||
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
|
||||
assert request is not None # Just to keep QA checks happy
|
||||
return HTMLResponse("<h1>Protected by barrole</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-foorole-and-barrole")
|
||||
@hasrole("barrole")
|
||||
@hasrole("foorole")
|
||||
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
|
||||
assert request is not None # Just to keep QA checks happy
|
||||
return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-foorole-or-barrole")
|
||||
@hasrole(["foorole", "barrole"])
|
||||
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
|
||||
assert request is not None # Just to keep QA checks happy
|
||||
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/introspect")
|
||||
async def get_introspect(
|
||||
@app.get("/refresh")
|
||||
async def refresh(
|
||||
request: Request,
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
token: Annotated[OAuth2Token, Depends(get_token)],
|
||||
) -> JSONResponse:
|
||||
assert request is not None # Just to keep QA checks happy
|
||||
if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None:
|
||||
provider: Annotated[Provider, Depends(get_auth_provider)],
|
||||
token: Annotated[OAuth2Token, Depends(get_token_from_session)],
|
||||
) -> RedirectResponse:
|
||||
"""Manually refresh token"""
|
||||
new_token = await provider.authlib_client.fetch_access_token(
|
||||
refresh_token=token["refresh_token"],
|
||||
grant_type="refresh_token",
|
||||
)
|
||||
try:
|
||||
await update_token(provider.id, new_token)
|
||||
except PyJWTError as err:
|
||||
logger.info(f"Cannot refresh token: {err.__class__.__name__}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="No intrispection endpoint found for the OIDC provider",
|
||||
status.HTTP_510_NOT_EXTENDED, f"Token refresh error: {err.__class__.__name__}"
|
||||
)
|
||||
if (
|
||||
response := await oidc_provider.post(
|
||||
url,
|
||||
token=token,
|
||||
data={"token": token["access_token"]},
|
||||
)
|
||||
).is_success:
|
||||
return response.json()
|
||||
else:
|
||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
return RedirectResponse(url=request.url_for("home"))
|
||||
|
||||
|
||||
# Snippet for running standalone
|
||||
|
@ -385,9 +341,7 @@ def main():
|
|||
parser.add_argument(
|
||||
"-p", "--port", type=int, default=80, help="Port to listen to (default: 80)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--version", action="store_true", help="Print version and exit"
|
||||
)
|
||||
parser.add_argument("-v", "--version", action="store_true", help="Print version and exit")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.version:
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
from functools import cached_property
|
||||
from typing import Self
|
||||
from typing import Any
|
||||
|
||||
from pydantic import (
|
||||
computed_field,
|
||||
|
@ -7,9 +8,10 @@ from pydantic import (
|
|||
EmailStr,
|
||||
ConfigDict,
|
||||
)
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
||||
class Role(SQLModel, extra="ignore"):
|
||||
name: str
|
||||
|
@ -33,18 +35,8 @@ class User(UserBase):
|
|||
)
|
||||
userinfo: dict = {}
|
||||
access_token: str | None = None
|
||||
oidc_provider: StarletteOAuth2App | None = None
|
||||
|
||||
@classmethod
|
||||
def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self:
|
||||
user = cls(**userinfo)
|
||||
user.userinfo = userinfo
|
||||
user.oidc_provider = oidc_provider
|
||||
# Add roles if they are provided in the token
|
||||
if raw_ra := userinfo.get("realm_access"):
|
||||
if raw_roles := raw_ra.get("roles"):
|
||||
user.roles = [Role(name=raw_role) for raw_role in raw_roles]
|
||||
return user
|
||||
access_token_decoded: dict[str, Any] | None = None
|
||||
auth_provider_id: str
|
||||
|
||||
@computed_field
|
||||
@cached_property
|
||||
|
@ -54,15 +46,21 @@ class User(UserBase):
|
|||
def has_scope(self, scope: str) -> bool:
|
||||
"""Check if the scope is present in user info or access token"""
|
||||
info_scopes = self.userinfo.get("scope", "").split(" ")
|
||||
access_token_scopes = self.access_token_parsed().get("scope", "").split(" ")
|
||||
try:
|
||||
access_token_scopes = self.decode_access_token().get("scope", "").split(" ")
|
||||
except Exception as err:
|
||||
logger.debug(f"Cannot find scope because the access token cannot be decoded: {err}")
|
||||
access_token_scopes = []
|
||||
return scope in set(info_scopes + access_token_scopes)
|
||||
|
||||
def access_token_parsed(self):
|
||||
assert self.access_token is not None
|
||||
assert self.oidc_provider is not None
|
||||
assert self.oidc_provider.name is not None
|
||||
from .auth_utils import oidc_providers_settings
|
||||
def decode_access_token(self, verify_signature: bool = True):
|
||||
assert self.access_token is not None, "no access_token"
|
||||
assert self.auth_provider_id is not None, "no auth_provider_id"
|
||||
from .auth_providers import providers
|
||||
|
||||
return oidc_providers_settings[self.oidc_provider.name].decode(
|
||||
self.access_token
|
||||
return providers[self.auth_provider_id].decode(
|
||||
self.access_token, verify_signature=verify_signature
|
||||
)
|
||||
|
||||
def get_scope(self, verify_signature: bool = True):
|
||||
return self.decode_access_token(verify_signature=verify_signature)["scope"]
|
||||
|
|
47
src/oidc_test/registry.py
Normal file
47
src/oidc_test/registry.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
from importlib.metadata import entry_points
|
||||
import logging
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from oidc_test.models import User
|
||||
|
||||
logger = logging.getLogger("registry")
|
||||
|
||||
|
||||
class ProcessResult(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
extra="allow",
|
||||
)
|
||||
|
||||
|
||||
class ProcessError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Resource(BaseModel):
|
||||
name: str
|
||||
scope_required: str | None = None
|
||||
role_required: str | None = None
|
||||
is_public: bool = False
|
||||
default_resource_id: str | None = None
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__()
|
||||
self.__id__ = name
|
||||
|
||||
async def process(self, user: User | None, resource_id: str | None = None) -> ProcessResult:
|
||||
logger.warning(f"{self.__id__} should define a process method")
|
||||
return ProcessResult()
|
||||
|
||||
|
||||
class ResourceRegistry(BaseModel):
|
||||
resources: dict[str, Resource] = {}
|
||||
|
||||
def make_registry(self):
|
||||
for ep in entry_points().select(group="oidc_test.resource_provider"):
|
||||
ResourceClass = ep.load()
|
||||
if issubclass(ResourceClass, Resource):
|
||||
self.resources[ep.name] = ResourceClass(ep.name)
|
||||
|
||||
|
||||
registry = ResourceRegistry()
|
|
@ -1,55 +1,310 @@
|
|||
from datetime import datetime
|
||||
from typing import Annotated, Any
|
||||
import logging
|
||||
from json import JSONDecodeError
|
||||
|
||||
from httpx import AsyncClient
|
||||
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||
from httpx import AsyncClient, HTTPError
|
||||
from jwt.exceptions import DecodeError, ExpiredSignatureError, InvalidTokenError
|
||||
from fastapi import FastAPI, HTTPException, Depends, Request, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from .models import User
|
||||
# from starlette.middleware.sessions import SessionMiddleware
|
||||
# from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
# from authlib.oauth2.rfc6749 import OAuth2Token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from oidc_test.auth.provider import Provider
|
||||
from oidc_test.auth.utils import (
|
||||
get_user_from_token_or_none,
|
||||
oauth2_scheme_optional,
|
||||
)
|
||||
from oidc_test.auth_providers import providers
|
||||
from oidc_test.settings import ResourceProvider, settings
|
||||
from oidc_test.models import User
|
||||
from oidc_test.registry import ProcessError, ProcessResult, registry
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
resource_server = FastAPI()
|
||||
|
||||
|
||||
async def get_resource(resource_id: str, user: User) -> dict:
|
||||
"""
|
||||
Resource processing: build an informative rely as a simple showcase
|
||||
"""
|
||||
pname = getattr(user.oidc_provider, "name", "?")
|
||||
resp = {
|
||||
"hello": f"Hi {user.name} from an OAuth resource provider",
|
||||
"comment": f"I received a request for '{resource_id}' "
|
||||
+ f"with an access token signed by {pname}",
|
||||
}
|
||||
# For the demo, resource resource_id matches a scope get:resource_id,
|
||||
# but this has to be refined for production
|
||||
required_scope = f"get:{resource_id}"
|
||||
# Check if the required scope is in the scopes allowed in userinfo
|
||||
if user.has_scope(required_scope):
|
||||
await process(user, resource_id, resp)
|
||||
else:
|
||||
## For the showcase, giving a explanation.
|
||||
## Alternatively, raise HTTP_401_UNAUTHORIZED
|
||||
resp["sorry"] = (
|
||||
f"No scope {required_scope} in the access token "
|
||||
+ "but it is required for accessing this resource."
|
||||
resource_server.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# SessionMiddleware is required by authlib
|
||||
# resource_server.add_middleware(
|
||||
# SessionMiddleware,
|
||||
# secret_key=settings.secret_key,
|
||||
# )
|
||||
|
||||
# Route for OAuth resource server
|
||||
|
||||
|
||||
# Routes for RBAC based tests
|
||||
|
||||
|
||||
@resource_server.get("/")
|
||||
async def resources() -> dict[str, dict[str, Any]]:
|
||||
return {"internal": {}, "plugins": registry.resources}
|
||||
|
||||
|
||||
@resource_server.get("/{resource_name}")
|
||||
@resource_server.get("/{resource_name}/{resource_id}")
|
||||
async def get_resource(
|
||||
resource_name: str,
|
||||
user: Annotated[User | None, Depends(get_user_from_token_or_none)],
|
||||
token: Annotated[OAuth2Token | None, Depends(oauth2_scheme_optional)],
|
||||
resource_id: str | None = None,
|
||||
):
|
||||
"""Generic path for testing a resource provided by a provider.
|
||||
There's no field validation (response type of ProcessResult) on purpose,
|
||||
leaving the responsibility of the response validation to resource providers"""
|
||||
# Get the resource if it's defined in user auth provider's resources (external)
|
||||
if user is not None:
|
||||
provider = providers[user.auth_provider_id]
|
||||
if ":" in resource_name:
|
||||
# Third-party resource provider: send the request with the request token
|
||||
resource_provider_id, resource_name = resource_name.split(":", 1)
|
||||
provider = providers[user.auth_provider_id]
|
||||
resource_provider: ResourceProvider = provider.get_resource_provider(
|
||||
resource_provider_id
|
||||
)
|
||||
resource_url = resource_provider.get_resource_url(resource_name)
|
||||
async with AsyncClient(verify=resource_provider.verify_ssl) as client:
|
||||
try:
|
||||
logger.debug(f"GET request to {resource_url}")
|
||||
resp = await client.get(
|
||||
resource_url,
|
||||
headers={
|
||||
"Content-type": "application/json",
|
||||
"Authorization": f"Bearer {token}",
|
||||
"auth_provider": user.auth_provider_id,
|
||||
},
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise HTTPException(
|
||||
status.HTTP_503_SERVICE_UNAVAILABLE, err.__class__.__name__
|
||||
)
|
||||
except Exception as err:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, err.__class__.__name__
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
async def process(user, resource_id, resp):
|
||||
"""
|
||||
Too simple to be serious.
|
||||
It's a good fit for a plugin architecture for production
|
||||
"""
|
||||
assert user is not None
|
||||
if resource_id == "time":
|
||||
resp["time"] = datetime.now().strftime("%c")
|
||||
elif resource_id == "bs":
|
||||
async with AsyncClient() as client:
|
||||
bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
|
||||
resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
|
||||
else:
|
||||
resp["sorry"] = f"I don't known how to give '{resource_id}'."
|
||||
if resp.is_success:
|
||||
return resp.json()
|
||||
else:
|
||||
reason_str: str
|
||||
try:
|
||||
reason_str = resp.json().get("detail", str(resp))
|
||||
except Exception:
|
||||
reason_str = str(resp.text)
|
||||
raise HTTPException(resp.status_code, reason_str)
|
||||
# Third party resource (provided through the auth provider)
|
||||
# The token is just passed on
|
||||
# XXX: is this branch valid anymore?
|
||||
if resource_name in [r.resource_name for r in provider.resources]:
|
||||
return await get_auth_provider_resource(
|
||||
provider=provider,
|
||||
resource_name=resource_name,
|
||||
token=token,
|
||||
user=user,
|
||||
)
|
||||
# Internal resource (provided here)
|
||||
if resource_name in registry.resources:
|
||||
resource = registry.resources[resource_name]
|
||||
reason: dict[str, str] = {}
|
||||
if not resource.is_public:
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Resource is not public")
|
||||
else:
|
||||
if resource.scope_required is not None and not user.has_scope(
|
||||
resource.scope_required
|
||||
):
|
||||
reason["scope"] = (
|
||||
f"No scope {resource.scope_required} in the access token "
|
||||
"but it is required for accessing this resource"
|
||||
)
|
||||
if (
|
||||
resource.role_required is not None
|
||||
and resource.role_required not in user.roles_as_set
|
||||
):
|
||||
reason["role"] = (
|
||||
f"You don't have the role {resource.role_required} "
|
||||
"but it is required for accessing this resource"
|
||||
)
|
||||
if len(reason) == 0:
|
||||
try:
|
||||
resp = await resource.process(user=user, resource_id=resource_id)
|
||||
return resp
|
||||
except ProcessError as err:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, ", ".join(reason.values()))
|
||||
else:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Unknown resource")
|
||||
# return await get_resource_(resource_name, user, **request.query_params)
|
||||
|
||||
|
||||
async def get_auth_provider_resource(
|
||||
provider: Provider, resource_name: str, token: OAuth2Token | None, user: User
|
||||
) -> ProcessResult:
|
||||
if token is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth token")
|
||||
access_token = token
|
||||
async with AsyncClient() as client:
|
||||
resp = await client.get(
|
||||
url=provider.get_resource_url(resource_name),
|
||||
headers={
|
||||
"Content-type": "application/json",
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
},
|
||||
)
|
||||
if resp.is_error:
|
||||
raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}")
|
||||
# Only a demo, real application would really process the response
|
||||
resp_length = len(resp.text)
|
||||
if resp_length > 1024:
|
||||
return ProcessResult(
|
||||
msg=f"The resource is too long ({resp_length} bytes) to show in this demo, here is just the begining in raw format",
|
||||
start=resp.text[:100] + "...",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
resp_json = resp.json()
|
||||
except JSONDecodeError:
|
||||
return ProcessResult(msg="The resource is not formatted in JSON", text=resp.text)
|
||||
if isinstance(resp_json, dict):
|
||||
return ProcessResult(**resp.json())
|
||||
elif isinstance(resp_json, list):
|
||||
return ProcessResult(**{str(i): line for i, line in enumerate(resp_json)})
|
||||
|
||||
|
||||
# @resource_server.get("/public")
|
||||
# async def public() -> dict:
|
||||
# return {"msg": "Not protected"}
|
||||
#
|
||||
#
|
||||
# @resource_server.get("/protected")
|
||||
# async def get_protected(user: Annotated[User, Depends(get_user_from_token)]):
|
||||
# assert user is not None # Just to keep QA checks happy
|
||||
# return {"msg": "Only authenticated users can see this"}
|
||||
#
|
||||
#
|
||||
# @resource_server.get("/protected-by-foorole")
|
||||
# async def get_protected_by_foorole(
|
||||
# user: Annotated[User, Depends(UserWithRole("foorole"))],
|
||||
# ):
|
||||
# assert user is not None
|
||||
# return {"msg": "Only users with foorole can see this"}
|
||||
#
|
||||
#
|
||||
# @resource_server.get("/protected-by-barrole")
|
||||
# async def get_protected_by_barrole(
|
||||
# user: Annotated[User, Depends(UserWithRole("barrole"))],
|
||||
# ):
|
||||
# assert user is not None
|
||||
# return {"msg": "Protected by barrole"}
|
||||
#
|
||||
#
|
||||
# @resource_server.get("/protected-by-foorole-and-barrole")
|
||||
# async def get_protected_by_foorole_and_barrole(
|
||||
# user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))],
|
||||
# ):
|
||||
# assert user is not None # Just to keep QA checks happy
|
||||
# return {"msg": "Only users with foorole and barrole can see this"}
|
||||
#
|
||||
#
|
||||
# @resource_server.get("/protected-by-foorole-or-barrole")
|
||||
# async def get_protected_by_foorole_or_barrole(
|
||||
# user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))],
|
||||
# ):
|
||||
# assert user is not None # Just to keep QA checks happy
|
||||
# return {"msg": "Only users with foorole or barrole can see this"}
|
||||
|
||||
# async def get_resource_(resource_id: str, user: User, **kwargs) -> dict:
|
||||
# """
|
||||
# Resource processing: build an informative rely as a simple showcase
|
||||
# """
|
||||
# if resource_id == "petition":
|
||||
# return await sign(user, kwargs["petition_id"])
|
||||
# provider = providers[user.auth_provider_id]
|
||||
# try:
|
||||
# pname = provider.name
|
||||
# except KeyError:
|
||||
# pname = "?"
|
||||
# resp = {
|
||||
# "hello": f"Hi {user.name} from an OAuth resource provider",
|
||||
# "comment": f"I received a request for '{resource_id}' "
|
||||
# + f"with an access token signed by {pname}",
|
||||
# }
|
||||
# # For the demo, resource resource_id matches a scope get:resource_id,
|
||||
# # but this has to be refined for production
|
||||
# required_scope = f"get:{resource_id}"
|
||||
# # Check if the required scope is in the scopes allowed in userinfo
|
||||
# try:
|
||||
# if user.has_scope(required_scope):
|
||||
# await process(user, resource_id, resp)
|
||||
# else:
|
||||
# ## For the showcase, giving a explanation.
|
||||
# ## Alternatively, raise HTTP_401_UNAUTHORIZED
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED,
|
||||
# f"No scope {required_scope} in the access token "
|
||||
# + "but it is required for accessing this resource",
|
||||
# )
|
||||
# except ExpiredSignatureError:
|
||||
# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired")
|
||||
# except InvalidTokenError:
|
||||
# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid")
|
||||
# return resp
|
||||
|
||||
|
||||
# async def process(user, resource_id, resp):
|
||||
# """
|
||||
# Too simple to be serious.
|
||||
# It's a good fit for a plugin architecture for production
|
||||
# """
|
||||
# if resource_id == "time":
|
||||
# resp["time"] = datetime.now().strftime("%c")
|
||||
# elif resource_id == "bs":
|
||||
# async with AsyncClient() as client:
|
||||
# bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
|
||||
# resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
|
||||
# else:
|
||||
# raise HTTPException(
|
||||
# status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'."
|
||||
# )
|
||||
|
||||
|
||||
# @resource_server.get("/introspect")
|
||||
# async def get_introspect(
|
||||
# request: Request,
|
||||
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
# token: Annotated[OAuth2Token, Depends(get_token)],
|
||||
# ) -> JSONResponse:
|
||||
# assert request is not None # Just to keep QA checks happy
|
||||
# if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
# detail="No introspection endpoint found for the OIDC provider",
|
||||
# )
|
||||
# if (
|
||||
# response := await oidc_provider.post(
|
||||
# url,
|
||||
# token=token,
|
||||
# data={"token": token["access_token"]},
|
||||
# )
|
||||
# ).is_success:
|
||||
# return response.json()
|
||||
# else:
|
||||
# raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
|
||||
# assert user.oidc_provider is not None
|
||||
### Get some info (TODO: refactor)
|
||||
# if (auth_provider_id := user.oidc_provider.name) is None:
|
||||
|
|
|
@ -4,8 +4,7 @@ import random
|
|||
from typing import Type, Tuple
|
||||
from pathlib import Path
|
||||
|
||||
from jwt import decode
|
||||
from pydantic import BaseModel, computed_field, AnyUrl
|
||||
from pydantic import AnyHttpUrl, BaseModel, computed_field, AnyUrl
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
SettingsConfigDict,
|
||||
|
@ -14,18 +13,33 @@ from pydantic_settings import (
|
|||
)
|
||||
from starlette.requests import Request
|
||||
|
||||
from .models import User
|
||||
|
||||
|
||||
class Resource(BaseModel):
|
||||
"""A resource with an URL that can be accessed with an OAuth2 access token"""
|
||||
|
||||
resource_name: str
|
||||
name: str
|
||||
url: str
|
||||
|
||||
|
||||
class ResourceProvider(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
base_url: AnyUrl
|
||||
resources: list[Resource] = []
|
||||
verify_ssl: bool = True
|
||||
|
||||
def get_resource(self, resource_name: str) -> Resource:
|
||||
return [
|
||||
resource for resource in self.resources if resource.resource_name == resource_name
|
||||
][0]
|
||||
|
||||
def get_resource_url(self, resource_name: str) -> str:
|
||||
return f"{self.base_url}{self.get_resource(resource_name).url}"
|
||||
|
||||
|
||||
class OIDCProvider(BaseModel):
|
||||
"""OIDC provider, can also be a resource server"""
|
||||
class AuthProviderSettings(BaseModel):
|
||||
"""Auth provider, can also be a resource server"""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
|
@ -40,11 +54,14 @@ class OIDCProvider(BaseModel):
|
|||
info_url: str | None = (
|
||||
None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
|
||||
)
|
||||
info: dict[str, str | int] | None = (
|
||||
None # Info fetched from info_url, eg. public key
|
||||
)
|
||||
public_key: str | None = None
|
||||
public_key_url: str | None = None
|
||||
signature_alg: str = "RS256"
|
||||
resource_provider_scopes: list[str] = []
|
||||
session_key: str = "sid"
|
||||
skip_verify_signature: bool = True
|
||||
disabled: bool = False
|
||||
resource_providers: list[ResourceProvider] = []
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
|
@ -56,56 +73,20 @@ class OIDCProvider(BaseModel):
|
|||
def token_url(self) -> str:
|
||||
return "auth/" + self.id
|
||||
|
||||
def get_account_url(self, request: Request, user: User) -> str | None:
|
||||
def get_account_url(self, request: Request, user: dict) -> str | None:
|
||||
if self.account_url_template:
|
||||
if not (
|
||||
self.url.endswith("/") or self.account_url_template.startswith("/")
|
||||
):
|
||||
if not (self.url.endswith("/") or self.account_url_template.startswith("/")):
|
||||
sep = "/"
|
||||
else:
|
||||
sep = ""
|
||||
return (
|
||||
self.url
|
||||
+ sep
|
||||
+ self.account_url_template.format(request=request, user=user)
|
||||
)
|
||||
return self.url + sep + self.account_url_template.format(request=request, user=user)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_public_key(self) -> str:
|
||||
"""Return the public key formatted for decoding token"""
|
||||
public_key = self.public_key or (
|
||||
self.info is not None and self.info["public_key"]
|
||||
)
|
||||
if public_key is None:
|
||||
raise AttributeError(f"Cannot get public key for {self.name}")
|
||||
return f"""
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
{public_key}
|
||||
-----END PUBLIC KEY-----
|
||||
"""
|
||||
|
||||
def decode(self, token: str) -> dict:
|
||||
"""Decode the token with signature check"""
|
||||
return decode(
|
||||
token,
|
||||
self.get_public_key(),
|
||||
algorithms=[self.signature_alg],
|
||||
audience=["oidc-test", "oidc-test-web"],
|
||||
options={"verify_signature": not settings.insecure.skip_verify_signature},
|
||||
)
|
||||
|
||||
|
||||
class ResourceProvider(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
base_url: AnyUrl
|
||||
resources: list[Resource] = []
|
||||
|
||||
|
||||
class OIDCSettings(BaseModel):
|
||||
class AuthSettings(BaseModel):
|
||||
show_session_details: bool = False
|
||||
providers: list[OIDCProvider] = []
|
||||
providers: list[AuthProviderSettings] = []
|
||||
swagger_provider: str = ""
|
||||
|
||||
|
||||
|
@ -120,12 +101,15 @@ class Settings(BaseSettings):
|
|||
|
||||
model_config = SettingsConfigDict(env_nested_delimiter="__")
|
||||
|
||||
oidc: OIDCSettings = OIDCSettings()
|
||||
resource_providers: list[ResourceProvider] = []
|
||||
auth: AuthSettings = AuthSettings()
|
||||
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
|
||||
log: bool = False
|
||||
log_config_file: str = "log_conf.yaml"
|
||||
insecure: Insecure = Insecure()
|
||||
cors_origins: list[str] = []
|
||||
debug_token: bool = False
|
||||
show_token: bool = False
|
||||
show_external_resource_providers_links: bool = False
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
|
@ -144,9 +128,7 @@ class Settings(BaseSettings):
|
|||
settings_cls,
|
||||
Path(
|
||||
Path(
|
||||
environ.get(
|
||||
"OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml"
|
||||
),
|
||||
environ.get("OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml"),
|
||||
)
|
||||
),
|
||||
),
|
||||
|
|
|
@ -21,6 +21,18 @@ hr {
|
|||
.hidden {
|
||||
display: none;
|
||||
}
|
||||
.version {
|
||||
position: absolute;
|
||||
font-size: 75%;
|
||||
top: 0.3em;
|
||||
right: 0.3em;
|
||||
}
|
||||
.center {
|
||||
text-align: center;
|
||||
}
|
||||
.error {
|
||||
color: darkred;
|
||||
}
|
||||
.content {
|
||||
width: 100%;
|
||||
display: flex;
|
||||
|
@ -111,6 +123,7 @@ hr {
|
|||
background-color: #8888FF80;
|
||||
}
|
||||
|
||||
|
||||
/* For home */
|
||||
|
||||
.login-box {
|
||||
|
@ -135,19 +148,27 @@ hr {
|
|||
.providers .provider {
|
||||
min-height: 2em;
|
||||
}
|
||||
.providers .provider a.link {
|
||||
.providers .provider .link {
|
||||
text-decoration: none;
|
||||
max-height: 2em;
|
||||
}
|
||||
.providers .provider .link div {
|
||||
.providers .provider .link {
|
||||
background-color: #f7c7867d;
|
||||
border-radius: 8px;
|
||||
padding: 6px;
|
||||
text-align: center;
|
||||
color: black;
|
||||
font-weight: bold;
|
||||
font-weight: 400;
|
||||
cursor: pointer;
|
||||
border: 0;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.providers .provider .link.disabled {
|
||||
color: gray;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.providers .provider .hint {
|
||||
font-size: 80%;
|
||||
max-width: 13em;
|
||||
|
@ -163,10 +184,56 @@ hr {
|
|||
gap: 0.5em;
|
||||
flex-flow: wrap;
|
||||
}
|
||||
.content .links-to-check a {
|
||||
.content .links-to-check button {
|
||||
color: black;
|
||||
padding: 5px 10px;
|
||||
text-decoration: none;
|
||||
border-radius: 8px;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.token {
|
||||
overflow-wrap: anywhere;
|
||||
font-family: monospace;
|
||||
}
|
||||
|
||||
.resourceResult {
|
||||
padding: 0.5em;
|
||||
display: flex;
|
||||
gap: 0.5em;
|
||||
width: fit-content;
|
||||
align-items: center;
|
||||
margin: 5px auto;
|
||||
box-shadow: 0px 0px 10px #90c3eeA0;
|
||||
background-color: #90c3eeA0;
|
||||
border-radius: 8px;
|
||||
}
|
||||
|
||||
.resources {
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.resource {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.token-info {
|
||||
margin: 0 1em;
|
||||
}
|
||||
|
||||
.key {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.token .key, .token .value {
|
||||
display: inline;
|
||||
}
|
||||
.token .value {
|
||||
padding-left: 1em;
|
||||
}
|
||||
|
||||
.msg {
|
||||
text-align: center;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
|
|
@ -1,40 +1,90 @@
|
|||
function checkHref(elem) {
|
||||
var xmlHttp = new XMLHttpRequest()
|
||||
xmlHttp.onreadystatechange = function () {
|
||||
if (xmlHttp.readyState == 4) {
|
||||
elem.classList.add("hasResponseStatus")
|
||||
elem.classList.add("status-" + xmlHttp.status)
|
||||
elem.title = "Response code: " + xmlHttp.status + " - " + xmlHttp.statusText
|
||||
}
|
||||
}
|
||||
xmlHttp.open("GET", elem.href, true) // true for asynchronous
|
||||
xmlHttp.send(null)
|
||||
}
|
||||
|
||||
function checkPerms(className) {
|
||||
var rootElems = document.getElementsByClassName(className)
|
||||
Array.from(rootElems).forEach(elem =>
|
||||
Array.from(elem.children).forEach(elem => checkHref(elem))
|
||||
)
|
||||
}
|
||||
|
||||
async function get_resource(id, token, authProvider) {
|
||||
//if (!keycloak.keycloak) { return }
|
||||
const resp = await fetch("resource/" + id, {
|
||||
async function checkHref(elem, token, authProvider) {
|
||||
const msg = document.getElementById("msg")
|
||||
const resourceName = elem.getAttribute("resource-name")
|
||||
const resourceId = elem.getAttribute("resource-id")
|
||||
const resourceProviderId = elem.getAttribute("resource-provider-id") ? elem.getAttribute("resource-provider-id") : ""
|
||||
const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName
|
||||
const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}`
|
||||
const resp = await fetch(url, {
|
||||
method: "GET",
|
||||
headers: new Headers({
|
||||
"Content-type": "application/json",
|
||||
"Authorization": `Bearer ${token}`,
|
||||
"auth_provider": authProvider,
|
||||
}),
|
||||
}).catch(err => {
|
||||
msg.innerHTML = "Cannot fetch resource: " + err.message
|
||||
resourceElem.innerHTML = ""
|
||||
})
|
||||
/*
|
||||
resource.value = resp['data']
|
||||
msg.value = ""
|
||||
if (resp === undefined) {
|
||||
return
|
||||
} else {
|
||||
elem.classList.add("hasResponseStatus")
|
||||
elem.classList.add("status-" + resp.status)
|
||||
elem.title = "Response code: " + resp.status + " - " + resp.statusText
|
||||
}
|
||||
).catch (
|
||||
err => msg.value = err
|
||||
)
|
||||
*/
|
||||
console.log(await resp.json())
|
||||
}
|
||||
|
||||
function checkPerms(className, token, authProvider) {
|
||||
var rootElems = document.getElementsByClassName(className)
|
||||
Array.from(rootElems).forEach(elem =>
|
||||
Array.from(elem.children).forEach(elem => checkHref(elem, token, authProvider))
|
||||
)
|
||||
}
|
||||
|
||||
async function get_resource(resourceName, token, authProvider, resourceId, resourceProviderId) {
|
||||
// BaseUrl for an external resource provider
|
||||
//if (!keycloak.keycloak) { return }
|
||||
const msg = document.getElementById("msg")
|
||||
const resourceElem = document.getElementById('resource')
|
||||
const fqResourceName = resourceProviderId ? `${resourceProviderId}:${resourceName}` : resourceName
|
||||
const url = resourceId ? `resource/${fqResourceName}/${resourceId}` : `resource/${fqResourceName}`
|
||||
const resp = await fetch(url, {
|
||||
method: "GET",
|
||||
headers: new Headers({
|
||||
"Content-type": "application/json",
|
||||
"Authorization": `Bearer ${token}`,
|
||||
"auth_provider": authProvider,
|
||||
}),
|
||||
}).catch(err => {
|
||||
msg.innerHTML = "Cannot fetch resource: " + err.message
|
||||
resourceElem.innerHTML = ""
|
||||
})
|
||||
if (resp === undefined) {
|
||||
return
|
||||
}
|
||||
const resource = await resp.json()
|
||||
if (!resp.ok) {
|
||||
msg.innerHTML = resource["detail"]
|
||||
resourceElem.innerHTML = ""
|
||||
return
|
||||
}
|
||||
msg.innerHTML = ""
|
||||
resourceElem.innerHTML = ""
|
||||
Object.entries(resource).forEach(
|
||||
([key, value]) => {
|
||||
let r = document.createElement('div')
|
||||
let kElem = document.createElement('div')
|
||||
kElem.innerText = key
|
||||
kElem.className = "key"
|
||||
let vElem = document.createElement('div')
|
||||
if (typeof value == "object") {
|
||||
Object.entries(value).forEach(v => {
|
||||
const ne = document.createElement('div')
|
||||
ne.innerHTML = `<span class="key">${v[0]}</span>: <span class="value">${v[1]}</span>`
|
||||
vElem.appendChild(ne)
|
||||
})
|
||||
}
|
||||
else {
|
||||
vElem.innerText = value
|
||||
}
|
||||
vElem.className = "value"
|
||||
if (key == "sorry") {
|
||||
vElem.classList.add("error")
|
||||
}
|
||||
r.appendChild(kElem)
|
||||
r.appendChild(vElem)
|
||||
resourceElem.appendChild(r)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
<link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet">
|
||||
<script src="{{ url_for('static', path='/utils.js') }}"></script>
|
||||
</head>
|
||||
<body onload="checkPerms('links-to-check')">
|
||||
<body onload="checkPerms('links-to-check', '{{ access_token }}', '{{ auth_provider.id }}')">
|
||||
<div class="version">v. {{ __version__}}</div>
|
||||
<h1>OIDC-test - FastAPI client</h1>
|
||||
{% block content %}
|
||||
{% endblock %}
|
||||
|
|
|
@ -8,10 +8,14 @@
|
|||
<div class="login-box">
|
||||
<p class="description">Log in with:</p>
|
||||
<table class="providers">
|
||||
{% for provider in settings.oidc.providers %}
|
||||
{% for provider in providers.values() %}
|
||||
<tr class="provider">
|
||||
<td>
|
||||
<a class="link" href="login/{{ provider.id }}"><div>{{ provider.name }}</div></a>
|
||||
<button class="link{% if provider.disabled %} disabled{% endif %}"
|
||||
{% if provider.disabled %}disabled{% endif %}
|
||||
onclick="location.href='login/{{ provider.id }}'">
|
||||
{{ provider.name }}
|
||||
</button>
|
||||
</td>
|
||||
<td class="hint">{{ provider.hint }}</div>
|
||||
</td>
|
||||
|
@ -22,8 +26,7 @@
|
|||
{% endfor %}
|
||||
</table>
|
||||
</div>
|
||||
{% endif %}
|
||||
{% if user %}
|
||||
{% else %}
|
||||
<div class="user-info">
|
||||
<p>Hey, {{ user.name }}</p>
|
||||
{% if user.picture %}
|
||||
|
@ -32,7 +35,7 @@
|
|||
<div>{{ user.email }}</div>
|
||||
<div>
|
||||
<span>Provider:</span>
|
||||
{{ oidc_provider_settings.name }}
|
||||
{{ auth_provider.name }}
|
||||
</div>
|
||||
{% if user.roles %}
|
||||
<div>
|
||||
|
@ -42,80 +45,125 @@
|
|||
{% endfor %}
|
||||
</div>
|
||||
{% endif %}
|
||||
{% if user.access_token.scope %}
|
||||
{% if access_token_scope %}
|
||||
<div>
|
||||
<span>Scopes</span>:
|
||||
{% for scope in user.access_token.scope.split(' ') %}
|
||||
{% for scope in access_token_scope.split(' ') %}
|
||||
<span class="scope">{{ scope }}</span>
|
||||
{% endfor %}
|
||||
</div>
|
||||
{% endif %}
|
||||
{% if oidc_provider_settings.account_url_template %}
|
||||
{% if auth_provider.account_url_template %}
|
||||
<button
|
||||
onclick="location.href='{{ oidc_provider_settings.get_account_url(request, user) }}'"
|
||||
onclick="location.href='{{ auth_provider.get_account_url(request, user.model_dump()) }}'"
|
||||
class="account">
|
||||
Account management
|
||||
</button>
|
||||
{% endif %}
|
||||
<button onclick="location.href='{{ request.url_for("refresh") }}'" class="refresh">Refresh access token</button>
|
||||
<button onclick="location.href='{{ request.url_for("logout") }}'" class="logout">Logout</button>
|
||||
</div>
|
||||
{% endif %}
|
||||
<hr>
|
||||
<p class="center">
|
||||
Fetch resources from the resource server with your authentication token:
|
||||
</p>
|
||||
<div class="actions">
|
||||
<button onclick="get_resource('time', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">Time</button>
|
||||
<button onclick="get_resource('bs', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">BS</button>
|
||||
</div>
|
||||
<div class="resources">
|
||||
<div v-if="Object.entries(resource).length > 0" class="resource">
|
||||
<div v-for="(value, key) in resource">
|
||||
<div class="key">{{ key }}</div>
|
||||
<div v-if="key == 'sorry' || key == 'error'" class="error">{{ value }}</div>
|
||||
<div v-else class="value">{{ value }}</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div v-if="msg" class="msg resource">{{ msg }}</div>
|
||||
<hr>
|
||||
<div class="content">
|
||||
<p>
|
||||
These links should get different response codes depending on the authorization:
|
||||
</p>
|
||||
<div class="links-to-check">
|
||||
<a href="public">Public</a>
|
||||
<a href="protected">Auth protected content</a>
|
||||
<a href="protected-by-foorole">Auth + foorole protected content</a>
|
||||
<a href="protected-by-foorole-or-barrole">Auth + foorole or barrole protected content</a>
|
||||
<a href="protected-by-barrole">Auth + barrole protected content</a>
|
||||
<a href="protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a>
|
||||
<a href="fast_api_depends" class="hidden">Using FastAPI Depends</a>
|
||||
<a href="introspect">Introspect token (401 expected)</a>
|
||||
</div>
|
||||
{% if resources %}
|
||||
<p>
|
||||
Resources for this provider:
|
||||
</p>
|
||||
<p>This application provides all these resources, eventually protected with scope or roles:</p>
|
||||
<div class="links-to-check">
|
||||
{% for resource in resources %}
|
||||
<a href="{{ request.url_for('get_resource', id=resource.id) }}">{{ resource.name }}</a>
|
||||
{% for name, resource in resources.items() %}
|
||||
{% if resource.default_resource_id %}
|
||||
<button resource-name="{{ name }}"
|
||||
resource-id="{{ resource.default_resource_id }}"
|
||||
onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource.default_resource_id }}')"
|
||||
>
|
||||
{{ resource.name }}
|
||||
</button>
|
||||
{% else %}
|
||||
<button resource-name="{{ name }}"
|
||||
onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}')"
|
||||
>
|
||||
{{ resource.name }}
|
||||
</button>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
</div>
|
||||
{% endif %}
|
||||
{% if user_info_details %}
|
||||
<hr>
|
||||
<div class="debug-auth">
|
||||
<p>User info</p>
|
||||
<ul>
|
||||
{% for key, value in user_info_details.items() %}
|
||||
<li>
|
||||
<span class="key">{{ key }}</span>: {{ value }}
|
||||
</li>
|
||||
{% if auth_provider.resources %}
|
||||
<p>{{ auth_provider.name }} is also defined as a provider for these resources:</p>
|
||||
<div class="links-to-check">
|
||||
{% for resource in auth_provider.resources %}
|
||||
{% if resource.default_resource_id %}
|
||||
<button resource-name="{{ resource.resource_name }}"
|
||||
resource-id="{{ resource.default_resource_id }}"
|
||||
onclick="get_resource('{{ resource.resource_name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource.default_resource_id }}')"
|
||||
>
|
||||
{{ resource.name }}
|
||||
</button>
|
||||
{% else %}
|
||||
<button resource-name="{{ resource.resource_name }}"
|
||||
onclick="get_resource('{{ resource.resource_name }}', '{{ access_token }}', '{{ auth_provider.id }}')"
|
||||
>
|
||||
{{ resource.name }}
|
||||
</button>
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</div>
|
||||
<div>Now is: {{ now.strftime("%T, %D") }} </div>
|
||||
{% endif %}
|
||||
{% if resource_providers %}
|
||||
<p>{{ auth_provider.name }} allows this application to request resources from third party resource providers:</p>
|
||||
{% for resource_provider in resource_providers %}
|
||||
<div class="links-to-check">
|
||||
{{ resource_provider.name }}
|
||||
{% for resource in resource_provider.resources %}
|
||||
<button resource-name="{{ resource.resource_name }}"
|
||||
resource-id="{{ resource.default_resource_id }}"
|
||||
resource-provider-id="{{ resource_provider.id }}"
|
||||
onclick="get_resource('{{ resource.resource_name }}', '{{ access_token }}',
|
||||
'{{ auth_provider.id }}', '{{ resource.default_resource_id }}',
|
||||
'{{ resource_provider.id }}')"
|
||||
>
|
||||
{{ resource.name }}
|
||||
</button>
|
||||
{% endfor %}
|
||||
</div>
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
<div class="resourceResult">
|
||||
<div id="resource" class="resource"></div>
|
||||
<div id="msg" class="msg error"></div>
|
||||
</div>
|
||||
</div>
|
||||
{% if show_token and id_token_parsed %}
|
||||
<div class="token-info">
|
||||
<hr>
|
||||
<div>
|
||||
<h2>id token</h2>
|
||||
<div class="token">
|
||||
{% for key, value in id_token_parsed.items() %}
|
||||
<div>
|
||||
<div class="key">{{ key }}</div>
|
||||
<div class="value">{{ value }}</div>
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
<h2>access token</h2>
|
||||
<div class="token">
|
||||
{% for key, value in access_token_parsed.items() %}
|
||||
<div>
|
||||
<div class="key">{{ key }}</div>
|
||||
<div class="value">{{ value }}</div>
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
<h2>refresh token</h2>
|
||||
<div class="token">
|
||||
{% for key, value in refresh_token_parsed.items() %}
|
||||
<div>
|
||||
<div class="key">{{ key }}</div>
|
||||
<div class="value">{{ value }}</div>
|
||||
</div>
|
||||
{% endfor %}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
{% endblock %}
|
||||
|
|
|
@ -6,12 +6,12 @@
|
|||
authorisation to log in again without asking for credentials.
|
||||
</p>
|
||||
<p>
|
||||
This is because {{ oidc_provider.name }} does not provide "end_session_endpoint" in its metadata
|
||||
(see: <a href="{{ oidc_provider._server_metadata_url }}">{{ oidc_provider._server_metadata_url }}</a>).
|
||||
This is because {{ auth_provider.name }} does not provide "end_session_endpoint" in its metadata
|
||||
(see: <a href="{{ auth_provider.authlib_client._server_metadata_url }}">{{ auth_provider.authlib_client._server_metadata_url }}</a>).
|
||||
</p>
|
||||
<p>
|
||||
You can just also go back to the <a href="{{ home_url }}">application home page</a>, but
|
||||
it recommended to go to the <a href="{{ oidc_provider.server_metadata['issuer'] }}">OIDC provider's site</a>
|
||||
it recommended to go to the <a href="{{ auth_provider.authlib_client.server_metadata['issuer'] }}">OIDC provider's site</a>
|
||||
and log out explicitely from there.
|
||||
</p>
|
||||
{% endblock %}
|
||||
|
|
16
uv.lock
generated
16
uv.lock
generated
|
@ -1,4 +1,5 @@
|
|||
version = 1
|
||||
revision = 1
|
||||
requires-python = ">=3.13"
|
||||
|
||||
[[package]]
|
||||
|
@ -206,6 +207,18 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/68/1b/e0a87d256e40e8c888847551b20a017a6b98139178505dc7ffb96f04e954/dnspython-2.7.0-py3-none-any.whl", hash = "sha256:b4c34b7d10b51bcc3a5071e7b8dee77939f1e878477eeecc965e9835f63c6c86", size = 313632 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dunamai"
|
||||
version = "1.23.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "packaging" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/06/4e/a5c8c337a1d9ac0384298ade02d322741fb5998041a5ea74d1cd2a4a1d47/dunamai-1.23.0.tar.gz", hash = "sha256:a163746de7ea5acb6dacdab3a6ad621ebc612ed1e528aaa8beedb8887fccd2c4", size = 44681 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/21/4c/963169386309fec4f96fd61210ac0a0666887d0fb0a50205395674d20b71/dunamai-1.23.0-py3-none-any.whl", hash = "sha256:a0906d876e92441793c6a423e16a4802752e723e9c9a5aabdc5535df02dbe041", size = 26342 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ecdsa"
|
||||
version = "0.19.0"
|
||||
|
@ -482,7 +495,6 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "oidc-fastapi-test"
|
||||
version = "0.0.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "authlib" },
|
||||
|
@ -501,6 +513,7 @@ dependencies = [
|
|||
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "dunamai" },
|
||||
{ name = "ipdb" },
|
||||
{ name = "pytest" },
|
||||
]
|
||||
|
@ -523,6 +536,7 @@ requires-dist = [
|
|||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "dunamai", specifier = ">=1.23.0" },
|
||||
{ name = "ipdb", specifier = ">=0.13.13" },
|
||||
{ name = "pytest", specifier = ">=8.3.4" },
|
||||
]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue