2025-01-10 17:33:10 +01:00
|
|
|
import logging
import random
import secrets
import string
from os import environ
from pathlib import Path
from typing import Any, Tuple, Type

from jwt import decode
from pydantic import BaseModel, computed_field, AnyUrl
from pydantic_settings import (
    BaseSettings,
    SettingsConfigDict,
    PydanticBaseSettingsSource,
    YamlConfigSettingsSource,
)
from starlette.requests import Request

from .models import User
|
2025-01-02 11:23:53 +01:00
|
|
|
|
2025-02-06 13:30:35 +01:00
|
|
|
# Module logger; all messages of this application go to the "oidc-test" logger.
logger = logging.getLogger(name="oidc-test")
|
|
|
|
|
2025-01-02 11:23:53 +01:00
|
|
|
|
2025-01-19 01:48:00 +01:00
|
|
|
class Resource(BaseModel):
    """Resource that can be accessed with an OAuth2 access token.

    Only the resource's identity is stored in this model.
    NOTE(review): no URL field is declared here; the URL used to reach the
    resource is presumably derived elsewhere (e.g. from its provider) — confirm.
    """

    # Identifier of the resource
    id: str
    # Name of the resource
    name: str
|
|
|
|
|
|
|
|
|
2025-01-02 11:23:53 +01:00
|
|
|
class OIDCProvider(BaseModel):
    """OIDC provider, can also be a resource server"""

    id: str
    name: str
    # Base URL of the provider; the discovery document and account URLs are built from it
    url: str
    client_id: str
    client_secret: str = ""
    # For PKCE (not implemented yet)
    code_challenge_method: str | None = None
    hint: str = "No hint"
    resources: list[Resource] = []
    # str.format template (with `request` and `user` keywords) appended to `url`
    account_url_template: str | None = None
    # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
    info_url: str | None = None
    # Info fetched from info_url, eg. public key
    info: dict[str, str | int] | None = None
    # Explicitly configured public key; takes precedence over info["public_key"]
    public_key: str | None = None
    signature_alg: str = "RS256"
    resource_provider_scopes: list[str] = []

    @computed_field
    @property
    def openid_configuration(self) -> str:
        """Return the URL of the provider's OpenID discovery document.

        A trailing slash on ``url`` is ignored so the result never contains
        ``//.well-known``.
        """
        return self.url.rstrip("/") + "/.well-known/openid-configuration"

    @computed_field
    @property
    def token_url(self) -> str:
        """Return the relative path ``auth/<id>`` for this provider."""
        return "auth/" + self.id

    def get_account_url(self, request: Request, user: User) -> str | None:
        """Return the account URL for ``user``, or None if no template is configured.

        ``account_url_template`` is formatted with the ``request`` and ``user``
        keyword arguments and joined to ``url`` with exactly one ``/``, whether
        or not either side already carries the slash.
        """
        if not self.account_url_template:
            return None
        base = self.url.rstrip("/")
        path = self.account_url_template.lstrip("/")
        return base + "/" + path.format(request=request, user=user)

    def get_public_key(self) -> str:
        """Return the public key formatted (PEM) for decoding tokens.

        The explicitly configured ``public_key`` is used first; otherwise the
        ``"public_key"`` entry of ``info`` (fetched from ``info_url``).

        Raises:
            AttributeError: if neither source provides a non-empty key.
        """
        public_key = self.public_key
        if not public_key and self.info is not None:
            public_key = self.info.get("public_key")
        # Also catches the previous bug where a missing `info` produced the
        # literal value False and was wrapped into a bogus PEM.
        if not public_key:
            raise AttributeError(f"Cannot get public key for {self.name}")
        return f"\n-----BEGIN PUBLIC KEY-----\n{public_key}\n-----END PUBLIC KEY-----\n"

    def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]:
        """Decode ``token`` with this provider's public key and return its claims.

        Args:
            token: the encoded JWT.
            verify_signature: when False, the signature (and the checks PyJWT
                ties to it) is not verified.

        When ``settings.debug_token`` is set, the token is first decoded without
        any verification and its claims are logged at DEBUG level.

        Raises:
            AttributeError: if no public key is available (see ``get_public_key``).
            Any exception raised by ``jwt.decode`` for invalid tokens propagates.
        """
        # `decode` below is the module-level jwt.decode: the method name does not
        # shadow it because class attributes are not in a method's scope.
        public_key = self.get_public_key()
        algorithms = [self.signature_alg]
        audience = ["account", "oidc-test", "oidc-test-web"]
        if settings.debug_token:
            decoded = decode(
                token,
                public_key,
                algorithms=algorithms,
                audience=audience,
                options={
                    "verify_signature": False,
                    "verify_aud": False,
                },
            )
            logger.debug("%s", decoded)
        return decode(
            token,
            public_key,
            algorithms=algorithms,
            audience=audience,
            options={
                "verify_signature": verify_signature,
            },
        )
|
|
|
|
|
2025-01-30 20:40:04 +01:00
|
|
|
|
2025-01-20 01:16:17 +01:00
|
|
|
class ResourceProvider(BaseModel):
    """A provider of resources, located at ``base_url``."""

    # Identifier of the provider
    id: str
    # Name of the provider
    name: str
    # Base URL of the provider (validated as a URL by pydantic)
    base_url: AnyUrl
    # Resources offered by this provider
    resources: list[Resource] = []
|
|
|
|
|
|
|
|
|
2025-01-02 11:23:53 +01:00
|
|
|
class OIDCSettings(BaseModel):
    """OIDC section of the settings: configured providers and display options."""

    # Whether session details are shown
    show_session_details: bool = False
    # All configured OIDC providers
    providers: list[OIDCProvider] = []
    # NOTE(review): presumably the id of the provider used by Swagger UI — confirm
    swagger_provider: str = ""
|
|
|
|
|
|
|
|
|
2025-01-30 20:40:04 +01:00
|
|
|
class Insecure(BaseModel):
    """Insecure switches.

    Warning: changing these defaults is only suitable for debugging.
    """

    # Disable token signature verification when set to True
    skip_verify_signature: bool = False
|
|
|
|
|
|
|
|
|
2025-01-02 11:23:53 +01:00
|
|
|
class Settings(BaseSettings):
    """Application settings.

    Values are resolved from these sources, highest priority first (see
    ``settings_customise_sources``): init keyword arguments, environment
    variables, secret files, the YAML settings file, and finally the dotenv
    source. Nested fields are addressed in environment variables with ``__``
    (e.g. ``OIDC__SHOW_SESSION_DETAILS``).

    NOTE: ``model_config`` sets no ``env_file``, so the dotenv source reads
    nothing unless one is supplied (e.g. ``Settings(_env_file=".env")``).
    """

    model_config = SettingsConfigDict(env_nested_delimiter="__")

    oidc: OIDCSettings = OIDCSettings()
    resource_providers: list[ResourceProvider] = []
    # Generated with a CSPRNG (`secrets`) when not configured. The default is
    # computed once at import time, so it changes on every restart unless set.
    secret_key: str = secrets.token_urlsafe(32)
    log: bool = False
    insecure: Insecure = Insecure()
    cors_origins: list[str] = []
    # Log the claims of every token decoded by OIDCProvider.decode
    debug_token: bool = False
    show_token: bool = False

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Return the settings sources, highest priority first.

        The YAML file path is taken from the ``OIDC_TEST_SETTINGS_FILE``
        environment variable, defaulting to ``settings.yaml`` in the current
        working directory.
        """
        yaml_file = Path(
            environ.get("OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml")
        )
        return (
            init_settings,
            env_settings,
            file_secret_settings,
            YamlConfigSettingsSource(settings_cls, yaml_file),
            dotenv_settings,
        )
|
|
|
|
|
|
|
|
|
|
|
|
# Settings instance for the whole application, loaded once at import time.
settings = Settings()

# OIDC providers indexed by their id.
# NOTE: if several providers share an id, the last one listed wins.
oidc_providers_settings: dict[str, OIDCProvider] = {
    provider.id: provider for provider in settings.oidc.providers
}
|