This commit is contained in:
phil 2025-01-02 10:46:02 +01:00
parent e13b7e1e29
commit 24f1761632

View file

@ -1,27 +1,21 @@
from typing import Annotated, Type, Tuple from typing import Annotated, Type, Tuple
import string import string
import random import random
from json import dumps
from httpx import HTTPError from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.security import ( from fastapi.security import OpenIdConnect
OAuth2PasswordBearer,
OpenIdConnect,
OAuth2AuthorizationCodeBearer,
)
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.starlette_client import OAuth, OAuthError from authlib.integrations.starlette_client import OAuth, OAuthError
from pydantic import BaseModel from pydantic import BaseModel, computed_field
from pydantic_settings import ( from pydantic_settings import (
BaseSettings, BaseSettings,
YamlConfigSettingsSource, YamlConfigSettingsSource,
PydanticBaseSettingsSource, PydanticBaseSettingsSource,
) )
from jose import jwt
from .models import User from .models import User
@ -35,10 +29,12 @@ class OIDCProvider(BaseModel):
client_secret: str = "" client_secret: str = ""
is_swagger: bool = False is_swagger: bool = False
@computed_field
@property @property
def provider_url(self): def provider_url(self):
return self.url + "/.well-known/openid-configuration" return self.url + "/.well-known/openid-configuration"
@computed_field
@property @property
def token_url(self): def token_url(self):
return "auth/" + self.name return "auth/" + self.name
@ -95,6 +91,8 @@ if swagger_provider is not None:
"appName": "Test Application", "appName": "Test Application",
# "usePkceWithAuthorizationCodeGrant": True, # "usePkceWithAuthorizationCodeGrant": True,
} }
else:
swagger_ui_init_oauth = None
app = FastAPI( app = FastAPI(
title="OIDC auth test", title="OIDC auth test",
@ -124,7 +122,6 @@ oidc_providers = dict(
for provider in settings.oidc.providers for provider in settings.oidc.providers
) )
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=swagger_provider.token_url)
oidc_scheme = oidc_providers[swagger_provider.name] oidc_scheme = oidc_providers[swagger_provider.name]
@ -179,6 +176,7 @@ async def login(request: Request, provider: str) -> RedirectResponse:
except AttributeError: except AttributeError:
raise HTTPException(500, "") raise HTTPException(500, "")
try: try:
breakpoint()
return await provider_.authorize_redirect(request, redirect_uri) return await provider_.authorize_redirect(request, redirect_uri)
except HTTPError: except HTTPError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
@ -191,6 +189,7 @@ async def auth(request: Request, provider: str) -> RedirectResponse:
except AttributeError: except AttributeError:
raise HTTPException(500, "") raise HTTPException(500, "")
try: try:
breakpoint()
token = await provider_.authorize_access_token(request) token = await provider_.authorize_access_token(request)
except OAuthError as error: except OAuthError as error:
raise HTTPException(status_code=401, detail=error.error) raise HTTPException(status_code=401, detail=error.error)
@ -215,7 +214,7 @@ async def home(
return templates.TemplateResponse( return templates.TemplateResponse(
request=request, request=request,
context={ context={
"settings": settings.dict(), "settings": settings.model_dump(),
"user": user, "user": user,
"auth_data": request.session.get("user"), "auth_data": request.session.get("user"),
}, },