This commit is contained in:
phil 2025-01-02 11:23:53 +01:00
parent f67d1f4a1d
commit 17662dd5bc
2 changed files with 120 additions and 119 deletions

View file

@ -1,6 +1,4 @@
from typing import Annotated, Type, Tuple
import string
import random
from typing import Annotated
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
@ -10,93 +8,32 @@ from fastapi.security import OpenIdConnect
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.starlette_client import OAuth, OAuthError
from pydantic import BaseModel, computed_field
from pydantic_settings import (
BaseSettings,
YamlConfigSettingsSource,
PydanticBaseSettingsSource,
)
from .settings import settings
from .models import User
templates = Jinja2Templates("src/templates")
class OIDCProvider(BaseModel):
name: str = ""
url: str = ""
client_id: str = ""
client_secret: str = ""
is_swagger: bool = False
@computed_field
@property
def provider_url(self) -> str:
return self.url + "/.well-known/openid-configuration"
@computed_field
@property
def token_url(self) -> str:
return "auth/" + self.name
class OIDCSettings(BaseModel):
show_session_details: bool = False
providers: list[OIDCProvider] = []
swagger_provider: str = ""
def get_swagger_provider(self) -> OIDCProvider:
for provider in self.providers:
if provider.is_swagger:
return provider
else:
raise UserWarning("Please define a provider for Swagger with id_swagger")
class Settings(BaseSettings):
"""Settings wil be read from an .env file"""
oidc: OIDCSettings = OIDCSettings()
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
class Config:
env_nested_delimiter = "__"
@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
file_secret_settings,
YamlConfigSettingsSource(settings_cls, "settings.yaml"),
dotenv_settings,
)
settings = Settings()
swagger_provider = settings.oidc.get_swagger_provider()
if swagger_provider is not None:
swagger_ui_init_oauth = {
"clientId": settings.oidc.get_swagger_provider().client_id,
"scopes": ["openid"], # fill in additional scopes when necessary
"appName": "Test Application",
# "usePkceWithAuthorizationCodeGrant": True,
}
else:
swagger_ui_init_oauth = None
# swagger_provider = settings.oidc.get_swagger_provider()
# if swagger_provider is not None:
# swagger_ui_init_oauth = {
# "clientId": settings.oidc.get_swagger_provider().client_id,
# "scopes": ["openid"], # fill in additional scopes when necessary
# "appName": "Test Application",
# # "usePkceWithAuthorizationCodeGrant": True,
# }
# else:
# swagger_ui_init_oauth = None
app = FastAPI(
title="OIDC auth test",
swagger_ui_init_oauth=swagger_ui_init_oauth,
# swagger_ui_init_oauth=swagger_ui_init_oauth,
)
app.add_middleware(
SessionMiddleware,
secret_key=settings.secret_key,
)
authlib_oauth = OAuth()
@ -110,19 +47,18 @@ for provider in settings.oidc.providers:
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
)
oidc_providers = dict(
(
provider.name,
OpenIdConnect(
openIdConnectUrl=provider.url,
scheme_name="openid",
auto_error=True,
),
)
for provider in settings.oidc.providers
)
oidc_scheme = oidc_providers[swagger_provider.name]
# oidc_providers = dict(
# (
# provider.name,
# OpenIdConnect(
# openIdConnectUrl=provider.url,
# scheme_name="openid",
# auto_error=True,
# ),
# )
# for provider in settings.oidc.providers
# )
# oidc_scheme = oidc_providers[swagger_provider.name]
def get_current_user(request: Request) -> User:
@ -139,33 +75,27 @@ def get_current_user_or_none(request: Request) -> User | None:
return None
def fastapi_oauth2():
breakpoint()
...
# def fastapi_oauth2():
# breakpoint()
# ...
app.add_middleware(
SessionMiddleware,
secret_key=settings.secret_key,
)
async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)):
# we could query the identity provider to give us some information about the user
# userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token})
# in my case, the JWT already contains all the information so I only need to decode and verify it
try:
# note that this also validates the JWT by validating all the claims
user = await authlib_oauth.provider.parse_id_token(
request, token={"id_token": token}
)
except Exception as exp:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Supplied authentication could not be validated ({exp})",
)
return user
# async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)):
# # we could query the identity provider to give us some information about the user
# # userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token})
#
# # in my case, the JWT already contains all the information so I only need to decode and verify it
# try:
# # note that this also validates the JWT by validating all the claims
# user = await authlib_oauth.provider.parse_id_token(
# request, token={"id_token": token}
# )
# except Exception as exp:
# raise HTTPException(
# status_code=status.HTTP_401_UNAUTHORIZED,
# detail=f"Supplied authentication could not be validated ({exp})",
# )
# return user
@app.get("/login")

71
src/oidc-test/settings.py Normal file
View file

@ -0,0 +1,71 @@
import string
import random
from typing import Type, Tuple
from pydantic import BaseModel, computed_field
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
YamlConfigSettingsSource,
)
class OIDCProvider(BaseModel):
name: str = ""
url: str = ""
client_id: str = ""
client_secret: str = ""
is_swagger: bool = False
@computed_field
@property
def provider_url(self) -> str:
return self.url + "/.well-known/openid-configuration"
@computed_field
@property
def token_url(self) -> str:
return "auth/" + self.name
class OIDCSettings(BaseModel):
show_session_details: bool = False
providers: list[OIDCProvider] = []
swagger_provider: str = ""
def get_swagger_provider(self) -> OIDCProvider:
for provider in self.providers:
if provider.is_swagger:
return provider
else:
raise UserWarning("Please define a provider for Swagger with id_swagger")
class Settings(BaseSettings):
"""Settings wil be read from an .env file"""
oidc: OIDCSettings = OIDCSettings()
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
class Config:
env_nested_delimiter = "__"
@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
file_secret_settings,
YamlConfigSettingsSource(settings_cls, "settings.yaml"),
dotenv_settings,
)
settings = Settings()