diff --git a/src/oidc-test/main.py b/src/oidc-test/main.py index 6c5c15c..df82b60 100644 --- a/src/oidc-test/main.py +++ b/src/oidc-test/main.py @@ -1,6 +1,4 @@ -from typing import Annotated, Type, Tuple -import string -import random +from typing import Annotated from httpx import HTTPError from fastapi import Depends, FastAPI, HTTPException, Request, status @@ -10,93 +8,32 @@ from fastapi.security import OpenIdConnect from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client import OAuth, OAuthError -from pydantic import BaseModel, computed_field -from pydantic_settings import ( - BaseSettings, - YamlConfigSettingsSource, - PydanticBaseSettingsSource, -) +from .settings import settings from .models import User templates = Jinja2Templates("src/templates") -class OIDCProvider(BaseModel): - name: str = "" - url: str = "" - client_id: str = "" - client_secret: str = "" - is_swagger: bool = False - - @computed_field - @property - def provider_url(self) -> str: - return self.url + "/.well-known/openid-configuration" - - @computed_field - @property - def token_url(self) -> str: - return "auth/" + self.name - - -class OIDCSettings(BaseModel): - show_session_details: bool = False - providers: list[OIDCProvider] = [] - swagger_provider: str = "" - - def get_swagger_provider(self) -> OIDCProvider: - for provider in self.providers: - if provider.is_swagger: - return provider - else: - raise UserWarning("Please define a provider for Swagger with id_swagger") - - -class Settings(BaseSettings): - """Settings wil be read from an .env file""" - - oidc: OIDCSettings = OIDCSettings() - secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) - - class Config: - env_nested_delimiter = "__" - - @classmethod - def settings_customise_sources( - cls, - settings_cls: Type[BaseSettings], - init_settings: PydanticBaseSettingsSource, - env_settings: PydanticBaseSettingsSource, - dotenv_settings: PydanticBaseSettingsSource, - file_secret_settings: PydanticBaseSettingsSource, - ) -> Tuple[PydanticBaseSettingsSource, ...]: - return ( - init_settings, - env_settings, - file_secret_settings, - YamlConfigSettingsSource(settings_cls, "settings.yaml"), - dotenv_settings, - ) - - -settings = Settings() - - -swagger_provider = settings.oidc.get_swagger_provider() -if swagger_provider is not None: - swagger_ui_init_oauth = { - "clientId": settings.oidc.get_swagger_provider().client_id, - "scopes": ["openid"], # fill in additional scopes when necessary - "appName": "Test Application", - # "usePkceWithAuthorizationCodeGrant": True, - } -else: - swagger_ui_init_oauth = None +# swagger_provider = settings.oidc.get_swagger_provider() +# if swagger_provider is not None: +# swagger_ui_init_oauth = { +# "clientId": settings.oidc.get_swagger_provider().client_id, +# "scopes": ["openid"], # fill in additional scopes when necessary +# "appName": "Test Application", +# # "usePkceWithAuthorizationCodeGrant": True, +# } +# else: +# swagger_ui_init_oauth = None app = FastAPI( title="OIDC auth test", - swagger_ui_init_oauth=swagger_ui_init_oauth, + # swagger_ui_init_oauth=swagger_ui_init_oauth, +) + +app.add_middleware( + SessionMiddleware, + secret_key=settings.secret_key, ) authlib_oauth = OAuth() @@ -110,19 +47,18 @@ for provider in settings.oidc.providers: # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) ) -oidc_providers = dict( - ( - provider.name, - OpenIdConnect( - openIdConnectUrl=provider.url, - scheme_name="openid", - auto_error=True, - ), - ) - for provider in settings.oidc.providers -) - -oidc_scheme = oidc_providers[swagger_provider.name] +# oidc_providers = dict( +# ( +# provider.name, +# OpenIdConnect( +# openIdConnectUrl=provider.url, +# scheme_name="openid", +# auto_error=True, +# ), +# ) +# for provider in settings.oidc.providers +# ) +# oidc_scheme = oidc_providers[swagger_provider.name] def get_current_user(request: Request) -> User: @@ -139,33 +75,27 @@ def get_current_user_or_none(request: Request) -> User | None: return None -def fastapi_oauth2(): - breakpoint() - ... +# def fastapi_oauth2(): +# breakpoint() +# ... -app.add_middleware( - SessionMiddleware, - secret_key=settings.secret_key, -) - - -async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)): - # we could query the identity provider to give us some information about the user - # userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token}) - - # in my case, the JWT already contains all the information so I only need to decode and verify it - try: - # note that this also validates the JWT by validating all the claims - user = await authlib_oauth.provider.parse_id_token( - request, token={"id_token": token} - ) - except Exception as exp: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=f"Supplied authentication could not be validated ({exp})", - ) - return user +# async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)): +# # we could query the identity provider to give us some information about the user +# # userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token}) +# +# # in my case, the JWT already contains all the information so I only need to decode and verify it +# try: +# # note that this also validates the JWT by validating all the claims +# user = await authlib_oauth.provider.parse_id_token( +# request, token={"id_token": token} +# ) +# except Exception as exp: +# raise HTTPException( +# status_code=status.HTTP_401_UNAUTHORIZED, +# detail=f"Supplied authentication could not be validated ({exp})", +# ) +# return user @app.get("/login") diff --git a/src/oidc-test/settings.py b/src/oidc-test/settings.py new file mode 100644 index 0000000..e188b82 --- /dev/null +++ b/src/oidc-test/settings.py @@ -0,0 +1,71 @@ +import string +import random +from typing import Type, Tuple + +from pydantic import BaseModel, computed_field +from pydantic_settings import ( + BaseSettings, + PydanticBaseSettingsSource, + YamlConfigSettingsSource, +) + + +class OIDCProvider(BaseModel): + name: str = "" + url: str = "" + client_id: str = "" + client_secret: str = "" + is_swagger: bool = False + + @computed_field + @property + def provider_url(self) -> str: + return self.url + "/.well-known/openid-configuration" + + @computed_field + @property + def token_url(self) -> str: + return "auth/" + self.name + + +class OIDCSettings(BaseModel): + show_session_details: bool = False + providers: list[OIDCProvider] = [] + swagger_provider: str = "" + + def get_swagger_provider(self) -> OIDCProvider: + for provider in self.providers: + if provider.is_swagger: + return provider + else: + raise UserWarning("Please define a provider for Swagger with id_swagger") + + +class Settings(BaseSettings): + """Settings wil be read from an .env file""" + + oidc: OIDCSettings = OIDCSettings() + secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) + + class Config: + env_nested_delimiter = "__" + + @classmethod + def settings_customise_sources( + cls, + settings_cls: Type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> Tuple[PydanticBaseSettingsSource, ...]: + return ( + init_settings, + env_settings, + file_secret_settings, + YamlConfigSettingsSource(settings_cls, "settings.yaml"), + dotenv_settings, + ) + + +settings = Settings()