Initial commit
This commit is contained in:
commit
4be2036f3b
10 changed files with 1303 additions and 0 deletions
0
src/oidc-test/__init__.py
Normal file
0
src/oidc-test/__init__.py
Normal file
224
src/oidc-test/main.py
Normal file
224
src/oidc-test/main.py
Normal file
|
@ -0,0 +1,224 @@
|
|||
from typing import Annotated, Type, Tuple
|
||||
import string
|
||||
import random
|
||||
from json import dumps
|
||||
|
||||
from httpx import HTTPError
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.security import (
|
||||
OAuth2PasswordBearer,
|
||||
OpenIdConnect,
|
||||
OAuth2AuthorizationCodeBearer,
|
||||
)
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||
from pydantic import BaseModel
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
YamlConfigSettingsSource,
|
||||
PydanticBaseSettingsSource,
|
||||
)
|
||||
from jose import jwt
|
||||
|
||||
from .models import User
|
||||
|
||||
templates = Jinja2Templates("src/templates")
|
||||
|
||||
|
||||
class OIDCProvider(BaseModel):
|
||||
name: str = ""
|
||||
url: str = ""
|
||||
client_id: str = ""
|
||||
client_secret: str = ""
|
||||
is_swagger: bool = False
|
||||
|
||||
@property
|
||||
def provider_url(self):
|
||||
return self.url + "/.well-known/openid-configuration"
|
||||
|
||||
@property
|
||||
def token_url(self):
|
||||
return "auth/" + self.name
|
||||
|
||||
|
||||
class OIDCSettings(BaseModel):
|
||||
show_session_details: bool = False
|
||||
providers: list[OIDCProvider] = []
|
||||
swagger_provider: str = ""
|
||||
|
||||
def get_swagger_provider(self) -> OIDCProvider:
|
||||
for provider in self.providers:
|
||||
if provider.is_swagger:
|
||||
return provider
|
||||
else:
|
||||
raise UserWarning("Please define a provider for Swagger with id_swagger")
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Settings wil be read from an .env file"""
|
||||
|
||||
oidc: OIDCSettings = OIDCSettings()
|
||||
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
|
||||
|
||||
class Config:
|
||||
env_nested_delimiter = "__"
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: Type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> Tuple[PydanticBaseSettingsSource, ...]:
|
||||
return (
|
||||
init_settings,
|
||||
env_settings,
|
||||
file_secret_settings,
|
||||
YamlConfigSettingsSource(settings_cls, "settings.yaml"),
|
||||
dotenv_settings,
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
swagger_provider = settings.oidc.get_swagger_provider()
|
||||
if swagger_provider is not None:
|
||||
swagger_ui_init_oauth = {
|
||||
"clientId": settings.oidc.get_swagger_provider().client_id,
|
||||
"scopes": ["openid"], # fill in additional scopes when necessary
|
||||
"appName": "Test Application",
|
||||
# "usePkceWithAuthorizationCodeGrant": True,
|
||||
}
|
||||
|
||||
app = FastAPI(
|
||||
title="OIDC auth test",
|
||||
swagger_ui_init_oauth=swagger_ui_init_oauth,
|
||||
)
|
||||
|
||||
authlib_oauth = OAuth()
|
||||
for provider in settings.oidc.providers:
|
||||
authlib_oauth.register(
|
||||
name=provider.name,
|
||||
server_metadata_url=provider.provider_url,
|
||||
client_kwargs={"scope": "openid email offline_access profile"},
|
||||
client_id=provider.client_id,
|
||||
client_secret=provider.client_secret,
|
||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||
)
|
||||
|
||||
oidc_providers = dict(
|
||||
(
|
||||
provider.name,
|
||||
OpenIdConnect(
|
||||
openIdConnectUrl=provider.url,
|
||||
scheme_name="openid",
|
||||
auto_error=True,
|
||||
),
|
||||
)
|
||||
for provider in settings.oidc.providers
|
||||
)
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=swagger_provider.token_url)
|
||||
oidc_scheme = oidc_providers[swagger_provider.name]
|
||||
|
||||
|
||||
def get_current_user(request: Request) -> User:
|
||||
auth_data = request.session.get("user")
|
||||
if auth_data is None:
|
||||
raise HTTPException(401, "Not authorized")
|
||||
return User(**auth_data)
|
||||
|
||||
|
||||
def get_current_user_or_none(request: Request) -> User | None:
|
||||
try:
|
||||
return get_current_user(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
def fastapi_oauth2():
|
||||
breakpoint()
|
||||
...
|
||||
|
||||
|
||||
app.add_middleware(
|
||||
SessionMiddleware,
|
||||
secret_key=settings.secret_key,
|
||||
)
|
||||
|
||||
|
||||
async def current_user(request: Request, token: str | None = Depends(fastapi_oauth2)):
|
||||
# we could query the identity provider to give us some information about the user
|
||||
# userinfo = await self.authlib_oauth.provider.userinfo(token={"access_token": token})
|
||||
|
||||
# in my case, the JWT already contains all the information so I only need to decode and verify it
|
||||
try:
|
||||
# note that this also validates the JWT by validating all the claims
|
||||
user = await authlib_oauth.provider.parse_id_token(
|
||||
request, token={"id_token": token}
|
||||
)
|
||||
except Exception as exp:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Supplied authentication could not be validated ({exp})",
|
||||
)
|
||||
return user
|
||||
|
||||
|
||||
@app.get("/login")
|
||||
async def login(request: Request, provider: str) -> RedirectResponse:
|
||||
redirect_uri = request.url_for("auth", provider=provider)
|
||||
try:
|
||||
return await getattr(authlib_oauth, provider).authorize_redirect(
|
||||
request, redirect_uri
|
||||
)
|
||||
except HTTPError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider")
|
||||
|
||||
|
||||
@app.get("/auth/{provider}")
|
||||
async def auth(request: Request, provider: str) -> RedirectResponse:
|
||||
try:
|
||||
token = await getattr(authlib_oauth, provider).authorize_access_token(request)
|
||||
|
||||
except OAuthError as error:
|
||||
return HTMLResponse(f"<h1>{error.error}</h1>")
|
||||
|
||||
user = token.get("userinfo")
|
||||
|
||||
if user:
|
||||
request.session["user"] = dict(user)
|
||||
return RedirectResponse(url=request.session.pop("next", "/"))
|
||||
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
|
||||
@app.get("/logout")
|
||||
async def logout(request: Request) -> RedirectResponse:
|
||||
request.session.pop("user", None)
|
||||
return RedirectResponse(url="/")
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def home(
|
||||
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
|
||||
) -> HTMLResponse:
|
||||
return templates.TemplateResponse(
|
||||
request=request,
|
||||
context={
|
||||
"settings": settings.dict(),
|
||||
"user": user,
|
||||
"auth_data": request.session.get("user"),
|
||||
},
|
||||
name="index.html",
|
||||
)
|
||||
|
||||
|
||||
@app.get("/private")
|
||||
async def read_items(token: Annotated[str, Depends(get_current_user)]):
|
||||
return {"token": token}
|
11
src/oidc-test/models.py
Normal file
11
src/oidc-test/models.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
from pydantic import BaseModel, EmailStr, AnyHttpUrl
|
||||
|
||||
# from app.models import User
|
||||
|
||||
|
||||
class User(BaseModel, extra="ignore"):
|
||||
|
||||
id: str | None = None
|
||||
name: str
|
||||
email: EmailStr | None = None
|
||||
picture: AnyHttpUrl | None = None
|
119
src/templates/index.html
Normal file
119
src/templates/index.html
Normal file
|
@ -0,0 +1,119 @@
|
|||
<html>
|
||||
<head>
|
||||
<title>FastAPI OIDC test</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, Helvetica, sans-serif;
|
||||
background-color: antiquewhite;
|
||||
}
|
||||
h1 {
|
||||
text-align: center;
|
||||
}
|
||||
.content {
|
||||
width: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
.user-info {
|
||||
padding: 1em;
|
||||
margin: 1em 0;
|
||||
display: flex;
|
||||
gap: 0.5em;
|
||||
flex-direction: column;
|
||||
width: fit-content;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
box-shadow: 0px 0px 10px lightgreen;
|
||||
background-color: lightgreen;
|
||||
}
|
||||
.user-info * {
|
||||
flex: 2 1 auto;
|
||||
margin: 0;
|
||||
}
|
||||
.user-info .picture {
|
||||
max-width: 3em;
|
||||
max-height: 3em
|
||||
}
|
||||
.login-box {
|
||||
text-align: center;
|
||||
}
|
||||
.login-toolbox {
|
||||
max-width: 20em;
|
||||
margin: auto;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
padding: 0 1em;
|
||||
gap: 5px;
|
||||
}
|
||||
.login-toolbox a {
|
||||
background-color: lightblue;
|
||||
padding: 3px 6px;
|
||||
text-decoration: none;
|
||||
text-align: center;
|
||||
color: black;
|
||||
flex: 1 1 auto;
|
||||
}
|
||||
.login-toolbox .error {
|
||||
color: darkred;
|
||||
padding: 3px 6px;
|
||||
text-align: center;
|
||||
font-weight: bold;
|
||||
flex: 1 1 auto;
|
||||
}
|
||||
.login-toolbox a.logout:hover {
|
||||
background-color: orange;
|
||||
}
|
||||
.login-toolbox a:hover {
|
||||
background-color: lightgreen;
|
||||
}
|
||||
.debug {
|
||||
font-size: 90%;
|
||||
}
|
||||
.debug p, .debug .key {
|
||||
font-weight: bold;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>Test app for OIDC</h1>
|
||||
<div class="login-box">
|
||||
<div class="login-toolbox">
|
||||
{% if user %}
|
||||
<a href="logout" class="logout">Logout</a>
|
||||
{% else %}
|
||||
{% for provider in settings.oidc.providers %}
|
||||
<a href="login?provider={{ provider.name }}">Login with: {{ provider.name }}</a>
|
||||
{% else %}
|
||||
<span class="error">Cannot login: no oidc prodiver in settings.yaml</span>
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
</div>
|
||||
<a href="private">Private area, accessible only when Authorization is set (logged in)</a>
|
||||
</div>
|
||||
<div class="content">
|
||||
{% if user %}
|
||||
<div class="user-info">
|
||||
<p>Hey, {{ user.name }}</p>
|
||||
{% if user.picture %}
|
||||
<img src="{{ user.picture }}" class="picture"></img>
|
||||
{% endif %}
|
||||
<p>{{ user.email }}</p>
|
||||
</div>
|
||||
{% endif %}
|
||||
{% if user and settings.oidc.show_session_details %}
|
||||
<div class="debug">
|
||||
<p>Session details:</p>
|
||||
<ul>
|
||||
{% for key, value in auth_data.items() %}
|
||||
<li>
|
||||
<span class="key">{{ key }}</span>: {{ value }}
|
||||
</li>
|
||||
{% endfor %}
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
</body>
|
||||
</html>
|
Loading…
Add table
Add a link
Reference in a new issue