Refactor
This commit is contained in:
parent
f67d1f4a1d
commit
17662dd5bc
2 changed files with 120 additions and 119 deletions
|
@ -1,6 +1,4 @@
|
||||||
from typing import Annotated, Type, Tuple
|
from typing import Annotated
|
||||||
import string
|
|
||||||
import random
|
|
||||||
|
|
||||||
from httpx import HTTPError
|
from httpx import HTTPError
|
||||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||||
|
@ -10,93 +8,32 @@ from fastapi.security import OpenIdConnect
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||||
from pydantic import BaseModel, computed_field
|
|
||||||
from pydantic_settings import (
|
|
||||||
BaseSettings,
|
|
||||||
YamlConfigSettingsSource,
|
|
||||||
PydanticBaseSettingsSource,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from .settings import settings
|
||||||
from .models import User
|
from .models import User
|
||||||
|
|
||||||
templates = Jinja2Templates("src/templates")
|
templates = Jinja2Templates("src/templates")
|
||||||
|
|
||||||
|
|
||||||
class OIDCProvider(BaseModel):
|
# swagger_provider = settings.oidc.get_swagger_provider()
|
||||||
name: str = ""
|
# if swagger_provider is not None:
|
||||||
url: str = ""
|
# swagger_ui_init_oauth = {
|
||||||
client_id: str = ""
|
# "clientId": settings.oidc.get_swagger_provider().client_id,
|
||||||
client_secret: str = ""
|
# "scopes": ["openid"], # fill in additional scopes when necessary
|
||||||
is_swagger: bool = False
|
# "appName": "Test Application",
|
||||||
|
# # "usePkceWithAuthorizationCodeGrant": True,
|
||||||
@computed_field
|
# }
|
||||||
@property
|
# else:
|
||||||
def provider_url(self) -> str:
|
# swagger_ui_init_oauth = None
|
||||||
return self.url + "/.well-known/openid-configuration"
|
|
||||||
|
|
||||||
@computed_field
|
|
||||||
@property
|
|
||||||
def token_url(self) -> str:
|
|
||||||
return "auth/" + self.name
|
|
||||||
|
|
||||||
|
|
||||||
class OIDCSettings(BaseModel):
|
|
||||||
show_session_details: bool = False
|
|
||||||
providers: list[OIDCProvider] = []
|
|
||||||
swagger_provider: str = ""
|
|
||||||
|
|
||||||
def get_swagger_provider(self) -> OIDCProvider:
|
|
||||||
for provider in self.providers:
|
|
||||||
if provider.is_swagger:
|
|
||||||
return provider
|
|
||||||
else:
|
|
||||||
raise UserWarning("Please define a provider for Swagger with id_swagger")
|
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
|
||||||
"""Settings wil be read from an .env file"""
|
|
||||||
|
|
||||||
oidc: OIDCSettings = OIDCSettings()
|
|
||||||
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
env_nested_delimiter = "__"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def settings_customise_sources(
|
|
||||||
cls,
|
|
||||||
settings_cls: Type[BaseSettings],
|
|
||||||
init_settings: PydanticBaseSettingsSource,
|
|
||||||
env_settings: PydanticBaseSettingsSource,
|
|
||||||
dotenv_settings: PydanticBaseSettingsSource,
|
|
||||||
file_secret_settings: PydanticBaseSettingsSource,
|
|
||||||
) -> Tuple[PydanticBaseSettingsSource, ...]:
|
|
||||||
return (
|
|
||||||
init_settings,
|
|
||||||
env_settings,
|
|
||||||
file_secret_settings,
|
|
||||||
YamlConfigSettingsSource(settings_cls, "settings.yaml"),
|
|
||||||
dotenv_settings,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
|
|
||||||
|
|
||||||
swagger_provider = settings.oidc.get_swagger_provider()
|
|
||||||
if swagger_provider is not None:
|
|
||||||
swagger_ui_init_oauth = {
|
|
||||||
"clientId": settings.oidc.get_swagger_provider().client_id,
|
|
||||||
"scopes": ["openid"], # fill in additional scopes when necessary
|
|
||||||
"appName": "Test Application",
|
|
||||||
# "usePkceWithAuthorizationCodeGrant": True,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
swagger_ui_init_oauth = None
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="OIDC auth test",
|
title="OIDC auth test",
|
||||||
swagger_ui_init_oauth=swagger_ui_init_oauth,
|
# swagger_ui_init_oauth=swagger_ui_init_oauth,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
SessionMiddleware,
|
||||||
|
secret_key=settings.secret_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
authlib_oauth = OAuth()
|
authlib_oauth = OAuth()
|
||||||
|
@ -110,19 +47,18 @@ for provider in settings.oidc.providers:
|
||||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||||
)
|
)
|
||||||
|
|
||||||
oidc_providers = dict(
|
# oidc_providers = dict(
|
||||||
(
|
# (
|
||||||
provider.name,
|
# provider.name,
|
||||||
OpenIdConnect(
|
# OpenIdConnect(
|
||||||
openIdConnectUrl=provider.url,
|
# openIdConnectUrl=provider.url,
|
||||||
scheme_name="openid",
|
# scheme_name="openid",
|
||||||
auto_error=True,
|
# auto_error=True,
|
||||||
),
|
# ),
|
||||||
)
|
# )
|
||||||
for provider in settings.oidc.providers
|
# for provider in settings.oidc.providers
|
||||||
)
|
# )
|
||||||
|
# oidc_scheme = oidc_providers[swagger_provider.name]
|
||||||
oidc_scheme = oidc_providers[swagger_provider.name]
|
|
||||||
|
|
||||||
|
|
||||||
def get_current_user(request: Request) -> User:
|
def get_current_user(request: Request) -> User:
|
||||||
|
@ -139,33 +75,27 @@ def get_current_user_or_none(request: Request) -> User | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def fastapi_oauth2():
|
# def fastapi_oauth2():
|
||||||
breakpoint()
|
# breakpoint()
|
||||||
...
|
# ...
|
||||||
|
|
||||||
|
|
||||||
app.add_middleware(
|
# async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)):
|
||||||
SessionMiddleware,
|
# # we could query the identity provider to give us some information about the user
|
||||||
secret_key=settings.secret_key,
|
# # userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token})
|
||||||
)
|
#
|
||||||
|
# # in my case, the JWT already contains all the information so I only need to decode and verify it
|
||||||
|
# try:
|
||||||
async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)):
|
# # note that this also validates the JWT by validating all the claims
|
||||||
# we could query the identity provider to give us some information about the user
|
# user = await authlib_oauth.provider.parse_id_token(
|
||||||
# userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token})
|
# request, token={"id_token": token}
|
||||||
|
# )
|
||||||
# in my case, the JWT already contains all the information so I only need to decode and verify it
|
# except Exception as exp:
|
||||||
try:
|
# raise HTTPException(
|
||||||
# note that this also validates the JWT by validating all the claims
|
# status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
user = await authlib_oauth.provider.parse_id_token(
|
# detail=f"Supplied authentication could not be validated ({exp})",
|
||||||
request, token={"id_token": token}
|
# )
|
||||||
)
|
# return user
|
||||||
except Exception as exp:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=f"Supplied authentication could not be validated ({exp})",
|
|
||||||
)
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/login")
|
@app.get("/login")
|
||||||
|
|
71
src/oidc-test/settings.py
Normal file
71
src/oidc-test/settings.py
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
import string
|
||||||
|
import random
|
||||||
|
from typing import Type, Tuple
|
||||||
|
|
||||||
|
from pydantic import BaseModel, computed_field
|
||||||
|
from pydantic_settings import (
|
||||||
|
BaseSettings,
|
||||||
|
PydanticBaseSettingsSource,
|
||||||
|
YamlConfigSettingsSource,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OIDCProvider(BaseModel):
|
||||||
|
name: str = ""
|
||||||
|
url: str = ""
|
||||||
|
client_id: str = ""
|
||||||
|
client_secret: str = ""
|
||||||
|
is_swagger: bool = False
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def provider_url(self) -> str:
|
||||||
|
return self.url + "/.well-known/openid-configuration"
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def token_url(self) -> str:
|
||||||
|
return "auth/" + self.name
|
||||||
|
|
||||||
|
|
||||||
|
class OIDCSettings(BaseModel):
|
||||||
|
show_session_details: bool = False
|
||||||
|
providers: list[OIDCProvider] = []
|
||||||
|
swagger_provider: str = ""
|
||||||
|
|
||||||
|
def get_swagger_provider(self) -> OIDCProvider:
|
||||||
|
for provider in self.providers:
|
||||||
|
if provider.is_swagger:
|
||||||
|
return provider
|
||||||
|
else:
|
||||||
|
raise UserWarning("Please define a provider for Swagger with id_swagger")
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
"""Settings wil be read from an .env file"""
|
||||||
|
|
||||||
|
oidc: OIDCSettings = OIDCSettings()
|
||||||
|
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_nested_delimiter = "__"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def settings_customise_sources(
|
||||||
|
cls,
|
||||||
|
settings_cls: Type[BaseSettings],
|
||||||
|
init_settings: PydanticBaseSettingsSource,
|
||||||
|
env_settings: PydanticBaseSettingsSource,
|
||||||
|
dotenv_settings: PydanticBaseSettingsSource,
|
||||||
|
file_secret_settings: PydanticBaseSettingsSource,
|
||||||
|
) -> Tuple[PydanticBaseSettingsSource, ...]:
|
||||||
|
return (
|
||||||
|
init_settings,
|
||||||
|
env_settings,
|
||||||
|
file_secret_settings,
|
||||||
|
YamlConfigSettingsSource(settings_cls, "settings.yaml"),
|
||||||
|
dotenv_settings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
Loading…
Add table
Add a link
Reference in a new issue