Initial commit

This commit is contained in:
phil 2025-01-02 02:14:30 +01:00
commit 4be2036f3b
10 changed files with 1303 additions and 0 deletions

View file

224
src/oidc-test/main.py Normal file
View file

@ -0,0 +1,224 @@
from typing import Annotated, Type, Tuple
import string
import random
from json import dumps
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from fastapi.security import (
OAuth2PasswordBearer,
OpenIdConnect,
OAuth2AuthorizationCodeBearer,
)
from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client import OAuth, OAuthError
from pydantic import BaseModel
from pydantic_settings import (
BaseSettings,
YamlConfigSettingsSource,
PydanticBaseSettingsSource,
)
from jose import jwt
from .models import User
templates = Jinja2Templates("src/templates")
class OIDCProvider(BaseModel):
name: str = ""
url: str = ""
client_id: str = ""
client_secret: str = ""
is_swagger: bool = False
@property
def provider_url(self):
return self.url + "/.well-known/openid-configuration"
@property
def token_url(self):
return "auth/" + self.name
class OIDCSettings(BaseModel):
show_session_details: bool = False
providers: list[OIDCProvider] = []
swagger_provider: str = ""
def get_swagger_provider(self) -> OIDCProvider:
for provider in self.providers:
if provider.is_swagger:
return provider
else:
raise UserWarning("Please define a provider for Swagger with id_swagger")
class Settings(BaseSettings):
"""Settings wil be read from an .env file"""
oidc: OIDCSettings = OIDCSettings()
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
class Config:
env_nested_delimiter = "__"
@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
return (
init_settings,
env_settings,
file_secret_settings,
YamlConfigSettingsSource(settings_cls, "settings.yaml"),
dotenv_settings,
)
settings = Settings()
swagger_provider = settings.oidc.get_swagger_provider()
if swagger_provider is not None:
swagger_ui_init_oauth = {
"clientId": settings.oidc.get_swagger_provider().client_id,
"scopes": ["openid"], # fill in additional scopes when necessary
"appName": "Test Application",
# "usePkceWithAuthorizationCodeGrant": True,
}
app = FastAPI(
title="OIDC auth test",
swagger_ui_init_oauth=swagger_ui_init_oauth,
)
authlib_oauth = OAuth()
for provider in settings.oidc.providers:
authlib_oauth.register(
name=provider.name,
server_metadata_url=provider.provider_url,
client_kwargs={"scope": "openid email offline_access profile"},
client_id=provider.client_id,
client_secret=provider.client_secret,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
)
oidc_providers = dict(
(
provider.name,
OpenIdConnect(
openIdConnectUrl=provider.url,
scheme_name="openid",
auto_error=True,
),
)
for provider in settings.oidc.providers
)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=swagger_provider.token_url)
oidc_scheme = oidc_providers[swagger_provider.name]
def get_current_user(request: Request) -> User:
auth_data = request.session.get("user")
if auth_data is None:
raise HTTPException(401, "Not authorized")
return User(**auth_data)
def get_current_user_or_none(request: Request) -> User | None:
try:
return get_current_user(request)
except HTTPException:
return None
def fastapi_oauth2():
breakpoint()
...
app.add_middleware(
SessionMiddleware,
secret_key=settings.secret_key,
)
async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)):
# we could query the identity provider to give us some information about the user
# userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token})
# in my case, the JWT already contains all the information so I only need to decode and verify it
try:
# note that this also validates the JWT by validating all the claims
user = await authlib_oauth.provider.parse_id_token(
request, token={"id_token": token}
)
except Exception as exp:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Supplied authentication could not be validated ({exp})",
)
return user
@app.get("/login")
async def login(request: Request, provider: str) -> RedirectResponse:
redirect_uri = request.url_for("auth", provider=provider)
try:
return await getattr(authlib_oauth, provider).authorize_redirect(
request, redirect_uri
)
except HTTPError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
@app.get("/auth/{provider}")
async def auth(request: Request, provider: str) -> RedirectResponse:
try:
token = await getattr(authlib_oauth, provider).authorize_access_token(request)
except OAuthError as error:
return HTMLResponse(f"<h1>{error.error}</h1>")
user = token.get("userinfo")
if user:
request.session["user"] = dict(user)
return RedirectResponse(url=request.session.pop("next", "/"))
return RedirectResponse(url="/login")
@app.get("/logout")
async def logout(request: Request) -> RedirectResponse:
request.session.pop("user", None)
return RedirectResponse(url="/")
@app.get("/", response_class=HTMLResponse)
async def home(
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
) -> HTMLResponse:
return templates.TemplateResponse(
request=request,
context={
"settings": settings.dict(),
"user": user,
"auth_data": request.session.get("user"),
},
name="index.html",
)
@app.get("/private")
async def read_items(token: Annotated[str, Depends(get_current_user)]):
return {"token": token}

11
src/oidc-test/models.py Normal file
View file

@ -0,0 +1,11 @@
from pydantic import BaseModel, EmailStr, AnyHttpUrl
# from app.models import User
class User(BaseModel, extra="ignore"):
id: str | None = None
name: str
email: EmailStr | None = None
picture: AnyHttpUrl | None = None

119
src/templates/index.html Normal file
View file

@ -0,0 +1,119 @@
<html>
<head>
<title>FastAPI OIDC test</title>
<style>
body {
font-family: Arial, Helvetica, sans-serif;
background-color: antiquewhite;
}
h1 {
text-align: center;
}
.content {
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
}
.user-info {
padding: 1em;
margin: 1em 0;
display: flex;
gap: 0.5em;
flex-direction: column;
width: fit-content;
align-items: center;
justify-content: center;
box-shadow: 0px 0px 10px lightgreen;
background-color: lightgreen;
}
.user-info * {
flex: 2 1 auto;
margin: 0;
}
.user-info .picture {
max-width: 3em;
max-height: 3em
}
.login-box {
text-align: center;
}
.login-toolbox {
max-width: 20em;
margin: auto;
display: flex;
flex-direction: column;
padding: 0 1em;
gap: 5px;
}
.login-toolbox a {
background-color: lightblue;
padding: 3px 6px;
text-decoration: none;
text-align: center;
color: black;
flex: 1 1 auto;
}
.login-toolbox .error {
color: darkred;
padding: 3px 6px;
text-align: center;
font-weight: bold;
flex: 1 1 auto;
}
.login-toolbox a.logout:hover {
background-color: orange;
}
.login-toolbox a:hover {
background-color: lightgreen;
}
.debug {
font-size: 90%;
}
.debug p, .debug .key {
font-weight: bold;
}
</style>
</head>
<body>
<h1>Test app for OIDC</h1>
<div class="login-box">
<div class="login-toolbox">
{% if user %}
<a href="logout" class="logout">Logout</a>
{% else %}
{% for provider in settings.oidc.providers %}
<a href="login?provider={{ provider.name }}">Login with: {{ provider.name }}</a>
{% else %}
<span class="error">Cannot login: no oidc prodiver in settings.yaml</span>
{% endfor %}
{% endif %}
</div>
<a href="private">Private area, accessible only when Authorization is set (logged in)</a>
</div>
<div class="content">
{% if user %}
<div class="user-info">
<p>Hey, {{ user.name }}</p>
{% if user.picture %}
<img src="{{ user.picture }}" class="picture"></img>
{% endif %}
<p>{{ user.email }}</p>
</div>
{% endif %}
{% if user and settings.oidc.show_session_details %}
<div class="debug">
<p>Session details:</p>
<ul>
{% for key, value in auth_data.items() %}
<li>
<span class="key">{{ key }}</span>: {{ value }}
</li>
{% endfor %}
</ul>
</div>
</div>
{% endif %}
</body>
</html>