from os import environ
import string
import random
from typing import Type, Tuple, Any
from pathlib import Path
import logging

from jwt import decode
from pydantic import BaseModel, computed_field, AnyUrl
from pydantic_settings import (
    BaseSettings,
    SettingsConfigDict,
    PydanticBaseSettingsSource,
    YamlConfigSettingsSource,
)
from starlette.requests import Request

from .models import User

logger = logging.getLogger("oidc-test")


class Resource(BaseModel):
    """A resource with a URL that can be accessed with an OAuth2 access token"""

    id: str
    name: str


class OIDCProvider(BaseModel):
    """OIDC provider, can also be a resource server"""

    id: str
    name: str
    url: str
    client_id: str
    client_secret: str = ""
    # For PKCE (not implemented yet)
    code_challenge_method: str | None = None
    hint: str = "No hint"
    resources: list[Resource] = []
    account_url_template: str | None = None
    # Used eg. for Keycloak's public key
    # (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
    info_url: str | None = None
    # Info fetched from info_url, eg. public key
    info: dict[str, str | int] | None = None
    public_key: str | None = None
    signature_alg: str = "RS256"
    resource_provider_scopes: list[str] = []

    @computed_field
    @property
    def openid_configuration(self) -> str:
        """URL of the provider's OpenID configuration document, built from ``url``"""
        return self.url + "/.well-known/openid-configuration"

    @computed_field
    @property
    def token_url(self) -> str:
        """Relative path ``auth/<provider id>``"""
        return "auth/" + self.id

    def get_account_url(self, request: Request, user: User) -> str | None:
        """Build the account URL for the given user.

        The template is formatted with ``request`` and ``user`` as keyword
        arguments and appended to the provider ``url`` with exactly one "/"
        between them.

        Returns:
            The account URL, or None when no ``account_url_template`` is set.
        """
        if not self.account_url_template:
            return None
        # Drop at most one slash on each side so that the join never yields
        # "//" (both sides had one) nor a missing separator (neither had one).
        base = self.url.removesuffix("/")
        template = self.account_url_template.removeprefix("/")
        return base + "/" + template.format(request=request, user=user)

    def get_public_key(self) -> str:
        """Return the public key formatted for decoding token

        The explicitly configured ``public_key`` takes precedence over the
        "public_key" entry of ``info``.

        Raises:
            AttributeError: if no public key is available from either place.
        """
        public_key = self.public_key
        if not public_key and self.info is not None:
            # .get() so that a missing entry ends up in the AttributeError
            # below instead of leaking a KeyError
            public_key = self.info.get("public_key")
        if not public_key:
            raise AttributeError(f"Cannot get public key for {self.name}")
        # Well-formed PEM block: the markers must start at the beginning of a line
        return (
            "-----BEGIN PUBLIC KEY-----\n"
            f"{public_key}\n"
            "-----END PUBLIC KEY-----\n"
        )

    def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]:
        """Decode the token with signature check

        When ``settings.debug_token`` is set, the token is first decoded
        without signature or audience verification and the payload is logged
        at debug level; the regular decoding then takes place.

        Args:
            token: the encoded JWT.
            verify_signature: passed to the decoder as the "verify_signature" option.

        Returns:
            The decoded payload of the token.
        """
        # Note: inside this method, ``decode`` is the module-level jwt function
        public_key = self.get_public_key()
        algorithms = [self.signature_alg]
        audience = ["account", "oidc-test", "oidc-test-web"]
        if settings.debug_token:
            decoded = decode(
                token,
                public_key,
                algorithms=algorithms,
                audience=audience,
                options={
                    "verify_signature": False,
                    "verify_aud": False,
                },
            )
            logger.debug("%s", decoded)
        return decode(
            token,
            public_key,
            algorithms=algorithms,
            audience=audience,
            options={
                "verify_signature": verify_signature,
            },  # not settings.insecure.skip_verify_signature},
        )


class ResourceProvider(BaseModel):
    """A provider of resources, identified by its base URL"""

    id: str
    name: str
    base_url: AnyUrl
    resources: list[Resource] = []


class OIDCSettings(BaseModel):
    """OIDC related settings: the list of providers and display options"""

    show_session_details: bool = False
    providers: list[OIDCProvider] = []
    swagger_provider: str = ""


class Insecure(BaseModel):
    """Warning: changing these defaults is only suitable for debugging"""

    skip_verify_signature: bool = False


class Settings(BaseSettings):
    """Application settings

    Sources, in the order returned by ``settings_customise_sources``:
    init arguments, environment variables (nested with "__"), secret files,
    a YAML file and the dotenv source.
    """

    model_config = SettingsConfigDict(env_nested_delimiter="__")

    oidc: OIDCSettings = OIDCSettings()
    resource_providers: list[ResourceProvider] = []
    # Default generated once, when the class is defined: 16 random ASCII
    # letters drawn from the OS entropy source (SystemRandom) rather than
    # from the predictable default generator.
    secret_key: str = "".join(
        random.SystemRandom().choice(string.ascii_letters) for _ in range(16)
    )
    log: bool = False
    insecure: Insecure = Insecure()
    cors_origins: list[str] = []
    debug_token: bool = False
    show_token: bool = False

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        """Insert a YAML settings source between the secret files and dotenv sources.

        The YAML file is the one named by the OIDC_TEST_SETTINGS_FILE
        environment variable, or ``settings.yaml`` in the current working
        directory when the variable is not set.
        """
        yaml_file = Path(
            environ.get("OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml")
        )
        return (
            init_settings,
            env_settings,
            file_secret_settings,
            YamlConfigSettingsSource(settings_cls, yaml_file),
            dotenv_settings,
        )


settings = Settings()

# Providers indexed by their id
oidc_providers_settings: dict[str, OIDCProvider] = {
    provider.id: provider for provider in settings.oidc.providers
}