This commit is contained in:
phil 2025-01-02 11:23:53 +01:00
parent f67d1f4a1d
commit 17662dd5bc
2 changed files with 120 additions and 119 deletions

View file

@ -1,6 +1,4 @@
from typing import Annotated, Type, Tuple from typing import Annotated
import string
import random
from httpx import HTTPError from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi import Depends, FastAPI, HTTPException, Request, status
@ -10,93 +8,32 @@ from fastapi.security import OpenIdConnect
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.starlette_client import OAuth, OAuthError from authlib.integrations.starlette_client import OAuth, OAuthError
from pydantic import BaseModel, computed_field
from pydantic_settings import (
BaseSettings,
YamlConfigSettingsSource,
PydanticBaseSettingsSource,
)
from .settings import settings
from .models import User from .models import User
templates = Jinja2Templates("src/templates") templates = Jinja2Templates("src/templates")
class OIDCProvider(BaseModel): # swagger_provider = settings.oidc.get_swagger_provider()
name: str = "" # if swagger_provider is not None:
url: str = "" # swagger_ui_init_oauth = {
client_id: str = "" # "clientId": settings.oidc.get_swagger_provider().client_id,
client_secret: str = "" # "scopes": ["openid"], # fill in additional scopes when necessary
is_swagger: bool = False # "appName": "Test Application",
# # "usePkceWithAuthorizationCodeGrant": True,
@computed_field # }
@property # else:
def provider_url(self) -> str: # swagger_ui_init_oauth = None
return self.url + "/.well-known/openid-configuration"
@computed_field
@property
def token_url(self) -> str:
return "auth/" + self.name
class OIDCSettings(BaseModel):
show_session_details: bool = False
providers: list[OIDCProvider] = []
swagger_provider: str = ""
def get_swagger_provider(self) -> OIDCProvider:
for provider in self.providers:
if provider.is_swagger:
return provider
else:
raise UserWarning("Please define a provider for Swagger with id_swagger")
class Settings(BaseSettings):
"""Settings wil be read from an .env file"""
oidc: OIDCSettings = OIDCSettings()
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
class Config:
env_nested_delimiter = "__"
@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
file_secret_settings,
YamlConfigSettingsSource(settings_cls, "settings.yaml"),
dotenv_settings,
)
settings = Settings()
swagger_provider = settings.oidc.get_swagger_provider()
if swagger_provider is not None:
swagger_ui_init_oauth = {
"clientId": settings.oidc.get_swagger_provider().client_id,
"scopes": ["openid"], # fill in additional scopes when necessary
"appName": "Test Application",
# "usePkceWithAuthorizationCodeGrant": True,
}
else:
swagger_ui_init_oauth = None
app = FastAPI( app = FastAPI(
title="OIDC auth test", title="OIDC auth test",
swagger_ui_init_oauth=swagger_ui_init_oauth, # swagger_ui_init_oauth=swagger_ui_init_oauth,
)
app.add_middleware(
SessionMiddleware,
secret_key=settings.secret_key,
) )
authlib_oauth = OAuth() authlib_oauth = OAuth()
@ -110,19 +47,18 @@ for provider in settings.oidc.providers:
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
) )
oidc_providers = dict( # oidc_providers = dict(
( # (
provider.name, # provider.name,
OpenIdConnect( # OpenIdConnect(
openIdConnectUrl=provider.url, # openIdConnectUrl=provider.url,
scheme_name="openid", # scheme_name="openid",
auto_error=True, # auto_error=True,
), # ),
) # )
for provider in settings.oidc.providers # for provider in settings.oidc.providers
) # )
# oidc_scheme = oidc_providers[swagger_provider.name]
oidc_scheme = oidc_providers[swagger_provider.name]
def get_current_user(request: Request) -> User: def get_current_user(request: Request) -> User:
@ -139,33 +75,27 @@ def get_current_user_or_none(request: Request) -> User | None:
return None return None
def fastapi_oauth2(): # def fastapi_oauth2():
breakpoint() # breakpoint()
... # ...
app.add_middleware( # async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)):
SessionMiddleware, # # we could query the identity provider to give us some information about the user
secret_key=settings.secret_key, # # userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token})
) #
# # in my case, the JWT already contains all the information so I only need to decode and verify it
# try:
async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)): # # note that this also validates the JWT by validating all the claims
# we could query the identity provider to give us some information about the user # user = await authlib_oauth.provider.parse_id_token(
# userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token}) # request, token={"id_token": token}
# )
# in my case, the JWT already contains all the information so I only need to decode and verify it # except Exception as exp:
try: # raise HTTPException(
# note that this also validates the JWT by validating all the claims # status_code=status.HTTP_401_UNAUTHORIZED,
user = await authlib_oauth.provider.parse_id_token( # detail=f"Supplied authentication could not be validated ({exp})",
request, token={"id_token": token} # )
) # return user
except Exception as exp:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Supplied authentication could not be validated ({exp})",
)
return user
@app.get("/login") @app.get("/login")

71
src/oidc-test/settings.py Normal file
View file

@ -0,0 +1,71 @@
import string
import random
from typing import Type, Tuple
from pydantic import BaseModel, computed_field
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
YamlConfigSettingsSource,
)
class OIDCProvider(BaseModel):
name: str = ""
url: str = ""
client_id: str = ""
client_secret: str = ""
is_swagger: bool = False
@computed_field
@property
def provider_url(self) -> str:
return self.url + "/.well-known/openid-configuration"
@computed_field
@property
def token_url(self) -> str:
return "auth/" + self.name
class OIDCSettings(BaseModel):
show_session_details: bool = False
providers: list[OIDCProvider] = []
swagger_provider: str = ""
def get_swagger_provider(self) -> OIDCProvider:
for provider in self.providers:
if provider.is_swagger:
return provider
else:
raise UserWarning("Please define a provider for Swagger with id_swagger")
class Settings(BaseSettings):
"""Settings wil be read from an .env file"""
oidc: OIDCSettings = OIDCSettings()
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
class Config:
env_nested_delimiter = "__"
@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
file_secret_settings,
YamlConfigSettingsSource(settings_cls, "settings.yaml"),
dotenv_settings,
)
settings = Settings()