Refactor
This commit is contained in:
parent
f67d1f4a1d
commit
17662dd5bc
2 changed files with 120 additions and 119 deletions
|
@ -1,6 +1,4 @@
|
|||
from typing import Annotated, Type, Tuple
|
||||
import string
|
||||
import random
|
||||
from typing import Annotated
|
||||
|
||||
from httpx import HTTPError
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
|
@ -10,93 +8,32 @@ from fastapi.security import OpenIdConnect
|
|||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||
from pydantic import BaseModel, computed_field
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
YamlConfigSettingsSource,
|
||||
PydanticBaseSettingsSource,
|
||||
)
|
||||
|
||||
from .settings import settings
|
||||
from .models import User
|
||||
|
||||
templates = Jinja2Templates("src/templates")
|
||||
|
||||
|
||||
class OIDCProvider(BaseModel):
|
||||
name: str = ""
|
||||
url: str = ""
|
||||
client_id: str = ""
|
||||
client_secret: str = ""
|
||||
is_swagger: bool = False
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def provider_url(self) -> str:
|
||||
return self.url + "/.well-known/openid-configuration"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def token_url(self) -> str:
|
||||
return "auth/" + self.name
|
||||
|
||||
|
||||
class OIDCSettings(BaseModel):
|
||||
show_session_details: bool = False
|
||||
providers: list[OIDCProvider] = []
|
||||
swagger_provider: str = ""
|
||||
|
||||
def get_swagger_provider(self) -> OIDCProvider:
|
||||
for provider in self.providers:
|
||||
if provider.is_swagger:
|
||||
return provider
|
||||
else:
|
||||
raise UserWarning("Please define a provider for Swagger with id_swagger")
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Settings wil be read from an .env file"""
|
||||
|
||||
oidc: OIDCSettings = OIDCSettings()
|
||||
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
|
||||
|
||||
class Config:
|
||||
env_nested_delimiter = "__"
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: Type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> Tuple[PydanticBaseSettingsSource, ...]:
|
||||
return (
|
||||
init_settings,
|
||||
env_settings,
|
||||
file_secret_settings,
|
||||
YamlConfigSettingsSource(settings_cls, "settings.yaml"),
|
||||
dotenv_settings,
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
swagger_provider = settings.oidc.get_swagger_provider()
|
||||
if swagger_provider is not None:
|
||||
swagger_ui_init_oauth = {
|
||||
"clientId": settings.oidc.get_swagger_provider().client_id,
|
||||
"scopes": ["openid"], # fill in additional scopes when necessary
|
||||
"appName": "Test Application",
|
||||
# "usePkceWithAuthorizationCodeGrant": True,
|
||||
}
|
||||
else:
|
||||
swagger_ui_init_oauth = None
|
||||
# swagger_provider = settings.oidc.get_swagger_provider()
|
||||
# if swagger_provider is not None:
|
||||
# swagger_ui_init_oauth = {
|
||||
# "clientId": settings.oidc.get_swagger_provider().client_id,
|
||||
# "scopes": ["openid"], # fill in additional scopes when necessary
|
||||
# "appName": "Test Application",
|
||||
# # "usePkceWithAuthorizationCodeGrant": True,
|
||||
# }
|
||||
# else:
|
||||
# swagger_ui_init_oauth = None
|
||||
|
||||
app = FastAPI(
|
||||
title="OIDC auth test",
|
||||
swagger_ui_init_oauth=swagger_ui_init_oauth,
|
||||
# swagger_ui_init_oauth=swagger_ui_init_oauth,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
SessionMiddleware,
|
||||
secret_key=settings.secret_key,
|
||||
)
|
||||
|
||||
authlib_oauth = OAuth()
|
||||
|
@ -110,19 +47,18 @@ for provider in settings.oidc.providers:
|
|||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||
)
|
||||
|
||||
oidc_providers = dict(
|
||||
(
|
||||
provider.name,
|
||||
OpenIdConnect(
|
||||
openIdConnectUrl=provider.url,
|
||||
scheme_name="openid",
|
||||
auto_error=True,
|
||||
),
|
||||
)
|
||||
for provider in settings.oidc.providers
|
||||
)
|
||||
|
||||
oidc_scheme = oidc_providers[swagger_provider.name]
|
||||
# oidc_providers = dict(
|
||||
# (
|
||||
# provider.name,
|
||||
# OpenIdConnect(
|
||||
# openIdConnectUrl=provider.url,
|
||||
# scheme_name="openid",
|
||||
# auto_error=True,
|
||||
# ),
|
||||
# )
|
||||
# for provider in settings.oidc.providers
|
||||
# )
|
||||
# oidc_scheme = oidc_providers[swagger_provider.name]
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> User:
|
||||
|
@ -139,33 +75,27 @@ def get_current_user_or_none(request: Request) -> User | None:
|
|||
return None
|
||||
|
||||
|
||||
def fastapi_oauth2():
|
||||
breakpoint()
|
||||
...
|
||||
# def fastapi_oauth2():
|
||||
# breakpoint()
|
||||
# ...
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
SessionMiddleware,
|
||||
secret_key=settings.secret_key,
|
||||
)
|
||||
|
||||
|
||||
async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)):
|
||||
# we could query the identity provider to give us some information about the user
|
||||
# userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token})
|
||||
|
||||
# in my case, the JWT already contains all the information so I only need to decode and verify it
|
||||
try:
|
||||
# note that this also validates the JWT by validating all the claims
|
||||
user = await authlib_oauth.provider.parse_id_token(
|
||||
request, token={"id_token": token}
|
||||
)
|
||||
except Exception as exp:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Supplied authentication could not be validated ({exp})",
|
||||
)
|
||||
return user
|
||||
# async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)):
|
||||
# # we could query the identity provider to give us some information about the user
|
||||
# # userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token})
|
||||
#
|
||||
# # in my case, the JWT already contains all the information so I only need to decode and verify it
|
||||
# try:
|
||||
# # note that this also validates the JWT by validating all the claims
|
||||
# user = await authlib_oauth.provider.parse_id_token(
|
||||
# request, token={"id_token": token}
|
||||
# )
|
||||
# except Exception as exp:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
# detail=f"Supplied authentication could not be validated ({exp})",
|
||||
# )
|
||||
# return user
|
||||
|
||||
|
||||
@app.get("/login")
|
||||
|
|
71
src/oidc-test/settings.py
Normal file
71
src/oidc-test/settings.py
Normal file
|
@ -0,0 +1,71 @@
|
|||
import string
|
||||
import random
|
||||
from typing import Type, Tuple
|
||||
|
||||
from pydantic import BaseModel, computed_field
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
PydanticBaseSettingsSource,
|
||||
YamlConfigSettingsSource,
|
||||
)
|
||||
|
||||
|
||||
class OIDCProvider(BaseModel):
|
||||
name: str = ""
|
||||
url: str = ""
|
||||
client_id: str = ""
|
||||
client_secret: str = ""
|
||||
is_swagger: bool = False
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def provider_url(self) -> str:
|
||||
return self.url + "/.well-known/openid-configuration"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def token_url(self) -> str:
|
||||
return "auth/" + self.name
|
||||
|
||||
|
||||
class OIDCSettings(BaseModel):
|
||||
show_session_details: bool = False
|
||||
providers: list[OIDCProvider] = []
|
||||
swagger_provider: str = ""
|
||||
|
||||
def get_swagger_provider(self) -> OIDCProvider:
|
||||
for provider in self.providers:
|
||||
if provider.is_swagger:
|
||||
return provider
|
||||
else:
|
||||
raise UserWarning("Please define a provider for Swagger with id_swagger")
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Settings wil be read from an .env file"""
|
||||
|
||||
oidc: OIDCSettings = OIDCSettings()
|
||||
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
|
||||
|
||||
class Config:
|
||||
env_nested_delimiter = "__"
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: Type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> Tuple[PydanticBaseSettingsSource, ...]:
|
||||
return (
|
||||
init_settings,
|
||||
env_settings,
|
||||
file_secret_settings,
|
||||
YamlConfigSettingsSource(settings_cls, "settings.yaml"),
|
||||
dotenv_settings,
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
Loading…
Add table
Add a link
Reference in a new issue