Refactor; add services in settings

This commit is contained in:
phil 2025-01-19 01:48:00 +01:00
parent 17fabd21c9
commit f14d8d3114
7 changed files with 272 additions and 224 deletions

View file

@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
OIDC_PROVIDERS = set([provider.id for provider in settings.oidc.providers]) OIDC_PROVIDERS = set([provider.id for provider in settings.oidc.providers])
def get_provider(request: Request) -> StarletteOAuth2App: def get_oidc_provider(request: Request) -> StarletteOAuth2App:
"""Return the oidc_provider from a request object, from the session. """Return the oidc_provider from a request object, from the session.
It can be used in Depends()""" It can be used in Depends()"""
if (oidc_provider_id := request.session.get("oidc_provider_id")) is None: if (oidc_provider_id := request.session.get("oidc_provider_id")) is None:
@ -45,7 +45,7 @@ async def get_current_user(request: Request) -> User:
user = await db.get_user(user_sub) user = await db.get_user(user_sub)
## Check if the token is expired ## Check if the token is expired
if token.is_expired(): if token.is_expired():
oidc_provider = get_provider(request=request) oidc_provider = get_oidc_provider(request=request)
## Ask a new refresh token from the provider ## Ask a new refresh token from the provider
logger.info(f"Token expired for user {user.name}") logger.info(f"Token expired for user {user.name}")
try: try:
@ -61,7 +61,17 @@ async def get_current_user(request: Request) -> User:
return user return user
async def get_token(request: Request) -> OAuth2Token:
"""Return the token from a request object, from the session.
It can be used in Depends()"""
if (token := await db.get_token(request.session.get("token"))) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
return token
async def get_current_user_or_none(request: Request) -> User | None: async def get_current_user_or_none(request: Request) -> User | None:
"""Return the user from a request object, from the session.
It can be used in Depends()"""
try: try:
return await get_current_user(request) return await get_current_user(request)
except HTTPException: except HTTPException:
@ -69,6 +79,7 @@ async def get_current_user_or_none(request: Request) -> User | None:
def hasrole(required_roles: Union[str, list[str]] = []): def hasrole(required_roles: Union[str, list[str]] = []):
"""Decorator for RBAC permissions"""
required_roles_set: set[str] required_roles_set: set[str]
if isinstance(required_roles, str): if isinstance(required_roles, str):
required_roles_set = set([required_roles]) required_roles_set = set([required_roles])
@ -118,10 +129,4 @@ def update_token(*args, **kwargs):
... ...
async def get_token(request: Request) -> OAuth2Token:
if (token := await db.get_token(request.session.get("token"))) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
return token
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)

View file

@ -10,6 +10,7 @@ from urllib.parse import urlencode
from httpx import HTTPError from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.security import OpenIdConnect from fastapi.security import OpenIdConnect
@ -23,7 +24,7 @@ from pkce import generate_code_verifier, generate_pkce_pair
from .settings import settings from .settings import settings
from .models import User from .models import User
from .auth_utils import ( from .auth_utils import (
get_provider, get_oidc_provider,
hasrole, hasrole,
get_current_user_or_none, get_current_user_or_none,
get_current_user, get_current_user,
@ -42,6 +43,9 @@ app = FastAPI(
title="OIDC auth test", title="OIDC auth test",
) )
app.mount(
"/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static"
)
# SessionMiddleware is required by authlib # SessionMiddleware is required by authlib
app.add_middleware( app.add_middleware(
@ -76,6 +80,27 @@ for provider in settings.oidc.providers:
_providers[provider.id] = provider _providers[provider.id] = provider
@app.get("/")
async def home(
request: Request, user: Annotated[User, Depends(get_current_user_or_none)]
) -> HTMLResponse:
now = datetime.now()
return templates.TemplateResponse(
name="home.html",
request=request,
context={
"settings": settings.model_dump(),
"user": user,
"now": now,
"user_info_details": (
pretty_details(user, now)
if user and settings.oidc.show_session_details
else None
),
},
)
# Endpoints for the login / authorization process # Endpoints for the login / authorization process
@ -169,13 +194,13 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
@app.get("/logout") @app.get("/logout")
async def logout( async def logout(
request: Request, request: Request,
provider: Annotated[StarletteOAuth2App, Depends(get_provider)], oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
) -> RedirectResponse: ) -> RedirectResponse:
# Clear session # Clear session
request.session.pop("user_sub", None) request.session.pop("user_sub", None)
# Get provider's endpoint # Get provider's endpoint
if ( if (
provider_logout_uri := provider.server_metadata.get("end_session_endpoint") provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")
) is None: ) is None:
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}") logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
return RedirectResponse(request.url_for("non_compliant_logout")) return RedirectResponse(request.url_for("non_compliant_logout"))
@ -200,7 +225,7 @@ async def logout(
@app.get("/non-compliant-logout") @app.get("/non-compliant-logout")
async def non_compliant_logout( async def non_compliant_logout(
request: Request, request: Request,
provider: Annotated[StarletteOAuth2App, Depends(get_provider)], oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
): ):
"""A page for non-compliant OAuth2 servers that we cannot log out.""" """A page for non-compliant OAuth2 servers that we cannot log out."""
return templates.TemplateResponse( return templates.TemplateResponse(
@ -210,28 +235,34 @@ async def non_compliant_logout(
) )
# Home URL # Route for OAuth resource server
@app.get("/") @app.get("/resource/{name}")
async def home( async def get_resource(
request: Request, user: Annotated[User, Depends(get_current_user_or_none)] name: str,
request: Request,
user: Annotated[User, Depends(get_current_user)],
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
token: Annotated[OAuth2Token, Depends(get_token)],
) -> HTMLResponse: ) -> HTMLResponse:
now = datetime.now() """Generic path for testing a resource provided by a provider"""
return templates.TemplateResponse( provider = _providers[oidc_provider.name]
name="home.html", if (
request=request, response := await oidc_provider.get(
context={ "/api/v1/user/repos",
"settings": settings.model_dump(), # headers={"Authorization": f"token {token['access_token']}"},
"user": user, token=token,
"now": now, )
"user_info_details": ( ).is_success:
pretty_details(user, now) repos = response.json()
if user and settings.oidc.show_session_details names = [repo["name"] for repo in repos]
else None return HTMLResponse(f"{user.name} has {len(repos)} repos: {', '.join(names)}")
), else:
}, raise HTTPException(status_code=response.status_code, detail=response.text)
)
# Routes for test
@app.get("/public") @app.get("/public")
@ -239,9 +270,6 @@ async def public() -> HTMLResponse:
return HTMLResponse("<h1>Not protected</h1>") return HTMLResponse("<h1>Not protected</h1>")
# Some URIs for testing the permissions
@app.get("/protected") @app.get("/protected")
async def get_protected( async def get_protected(
user: Annotated[User, Depends(get_current_user)] user: Annotated[User, Depends(get_current_user)]
@ -277,12 +305,12 @@ async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
@app.get("/introspect") @app.get("/introspect")
async def get_introspect( async def get_introspect(
request: Request, request: Request,
provider: Annotated[StarletteOAuth2App, Depends(get_provider)], oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
token: Annotated[OAuth2Token, Depends(get_token)], token: Annotated[OAuth2Token, Depends(get_token)],
) -> JSONResponse: ) -> JSONResponse:
if ( if (
response := await provider.post( response := await oidc_provider.post(
provider.server_metadata["introspection_endpoint"], oidc_provider.server_metadata["introspection_endpoint"],
token=token, token=token,
data={"token": token["access_token"]}, data={"token": token["access_token"]},
) )
@ -296,11 +324,11 @@ async def get_introspect(
async def get_forgejo_user_info( async def get_forgejo_user_info(
request: Request, request: Request,
user: Annotated[User, Depends(get_current_user)], user: Annotated[User, Depends(get_current_user)],
provider: Annotated[StarletteOAuth2App, Depends(get_provider)], oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
token: Annotated[OAuth2Token, Depends(get_token)], token: Annotated[OAuth2Token, Depends(get_token)],
) -> HTMLResponse: ) -> HTMLResponse:
if ( if (
response := await provider.get( response := await oidc_provider.get(
"/api/v1/user/repos", "/api/v1/user/repos",
# headers={"Authorization": f"token {token['access_token']}"}, # headers={"Authorization": f"token {token['access_token']}"},
token=token, token=token,
@ -313,11 +341,9 @@ async def get_forgejo_user_info(
raise HTTPException(status_code=response.status_code, detail=response.text) raise HTTPException(status_code=response.status_code, detail=response.text)
# @app.get("/fast_api_depends") # Snippet for running standalone
# def fast_api_depends( # Mostly useful for the --version option,
# token: Annotated[str, Depends(fastapi_providers["Keycloak"])] # as running with uvicorn is easy and provides flaxibility
# ) -> HTMLResponse:
# return HTMLResponse("You're Authenticated")
def main(): def main():

View file

@ -13,7 +13,16 @@ from pydantic_settings import (
) )
class Resource(BaseModel):
"""A resource with an URL that can be accessed with an OAuth2 access token"""
name: str
url: str
class OIDCProvider(BaseModel): class OIDCProvider(BaseModel):
"""OIDC provider, can also be a resource server"""
id: str id: str
name: str name: str
url: str url: str
@ -22,6 +31,7 @@ class OIDCProvider(BaseModel):
# For PKCE (not implemented yet) # For PKCE (not implemented yet)
code_challenge_method: str | None = None code_challenge_method: str | None = None
hint: str = "No hint" hint: str = "No hint"
resources: list[Resource] = []
@computed_field @computed_field
@property @property

View file

@ -0,0 +1,168 @@
body {
font-family: Arial, Helvetica, sans-serif;
background-color: floralwhite;
margin: 0;
}
h1 {
text-align: center;
background-color: #f7c7867d;
margin: 0 0 0.2em 0;
}
p {
margin: 0.2em;
}
hr {
margin: 0.2em;
}
.hidden {
display: none;
}
.center {
text-align: center;
}
.content {
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
}
.user-info {
padding: 0.5em;
display: flex;
gap: 0.5em;
flex-direction: column;
width: fit-content;
align-items: center;
margin: 5px auto;
box-shadow: 0px 0px 10px lightgreen;
background-color: lightgreen;
border-radius: 8px;
}
.user-info * {
flex: 2 1 auto;
margin: 0;
}
.user-info .picture {
max-width: 3em;
max-height: 3em
}
.user-info a.logout {
border: 2px solid darkkhaki;
padding: 3px 6px;
text-decoration: none;
text-align: center;
color: black;
}
.user-info a.logout:hover {
background-color: orange;
}
debug-auth {
font-size: 90%;
background-color: #d8bebc75;
padding: 6px;
}
.debug-auth * {
margin: 0;
}
.debug-auth p {
text-align: center;
border-bottom: 1px solid black;
}
.debug-auth ul {
padding: 0;
list-style: none;
}
.debug-auth p, .debug-auth .key {
font-weight: bold;
}
.content {
text-align: left;
}
.hasResponseStatus {
background-color: #88888840;
}
.hasResponseStatus.status-200 {
background-color: #00ff0040;
}
.hasResponseStatus.status-401 {
background-color: #ff000040;
}
.hasResponseStatus.status-403 {
background-color: #ff990040;
}
.hasResponseStatus.status-404 {
background-color: #ffCC0040;
}
.hasResponseStatus.status-503 {
background-color: #ffA88050;
}
.role {
padding: 3px 6px;
background-color: #44228840;
}
/* For home */
.login-box {
text-align: center;
background-color: antiquewhite;
margin: 0.5em auto;
width: fit-content;
box-shadow: 0 0 10px #49759b88;
border-radius: 8px;
}
.login-box .description {
font-style: italic;
font-weight: bold;
background-color: #f7c7867d;
padding: 6px;
margin: 0;
border-radius: 8px 8px 0 0;
}
.providers {
justify-content: center;
padding: 0.8em;
}
.providers .provider {
min-height: 2em;
}
.providers .provider a.link {
text-decoration: none;
max-height: 2em;
}
.providers .provider .link div {
text-align: center;
background-color: #f7c7867d;
border-radius: 8px;
padding: 6px;
text-align: center;
color: black;
font-weight: bold;
cursor: pointer;
}
.providers .provider .hint {
font-size: 80%;
max-width: 13em;
}
.providers .error {
color: darkred;
padding: 3px 6px;
text-align: center;
font-weight: bold;
flex: 1 1 auto;
}
.content #links-to-check {
display: flex;
text-align: center;
justify-content: center;
gap: 0.5em;
flex-flow: wrap;
}
.content #links-to-check a {
color: black;
padding: 5px 10px;
text-decoration: none;
border-radius: 8px;
}

View file

@ -0,0 +1,17 @@
function checkHref(elem) {
var xmlHttp = new XMLHttpRequest()
xmlHttp.onreadystatechange = function () {
if (xmlHttp.readyState == 4) {
elem.classList.add("hasResponseStatus")
elem.classList.add("status-" + xmlHttp.status)
elem.title = "Response code: " + xmlHttp.status + " - " + xmlHttp.statusText
}
}
xmlHttp.open("GET", elem.href, true) // true for asynchronous
xmlHttp.send(null)
}
function checkPerms(rootId) {
var rootElem = document.getElementById(rootId)
Array.from(rootElem.children).forEach(elem => checkHref(elem))
}

View file

@ -1,136 +1,8 @@
<html> <html>
<head> <head>
<title>FastAPI OIDC test</title> <title>FastAPI OIDC test</title>
<style> <link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet">
body { <script src="{{ url_for('static', path='/utils.js') }}"></script>
font-family: Arial, Helvetica, sans-serif;
background-color: floralwhite;
margin: 0;
}
h1 {
text-align: center;
background-color: #f7c7867d;
margin: 0 0 0.2em 0;
}
p {
margin: 0.2em;
}
hr {
margin: 0.2em;
}
.hidden {
display: none;
}
.center {
text-align: center;
}
.content {
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
}
.user-info {
padding: 0.5em;
display: flex;
gap: 0.5em;
flex-direction: column;
width: fit-content;
align-items: center;
margin: 5px auto;
box-shadow: 0px 0px 10px lightgreen;
background-color: lightgreen;
border-radius: 8px;
}
.user-info * {
flex: 2 1 auto;
margin: 0;
}
.user-info .picture {
max-width: 3em;
max-height: 3em
}
.user-info a.logout {
border: 2px solid darkkhaki;
padding: 3px 6px;
text-decoration: none;
text-align: center;
color: black;
}
.user-info a.logout:hover {
background-color: orange;
}
.debug-auth {
font-size: 90%;
background-color: #d8bebc75;
padding: 6px;
}
.debug-auth * {
margin: 0;
}
.debug-auth p {
text-align: center;
border-bottom: 1px solid black;
}
.debug-auth ul {
padding: 0;
list-style: none;
}
.debug-auth p, .debug-auth .key {
font-weight: bold;
}
.content {
text-align: left;
}
.content #links-to-check {
display: flex;
text-align: center;
justify-content: center;
gap: 0.5em;
flex-flow: wrap;
}
.content #links-to-check a {
color: black;
padding: 5px 10px;
text-decoration: none;
border-radius: 8px;
}
.hasResponseStatus {
background-color: #88888840;
}
.hasResponseStatus.status-200 {
background-color: #00ff0040;
}
.hasResponseStatus.status-401 {
background-color: #ff000040;
}
.hasResponseStatus.status-403 {
background-color: #ff990040;
}
.role {
padding: 3px 6px;
background-color: #44228840;
}
</style>
<script>
function checkHref(elem) {
var xmlHttp = new XMLHttpRequest()
xmlHttp.onreadystatechange = function() {
if (xmlHttp.readyState == 4) {
elem.classList.add("hasResponseStatus")
elem.classList.add("status-" + xmlHttp.status)
elem.title = "Response code: " + xmlHttp.status + " - " + xmlHttp.statusText
}
}
xmlHttp.open("GET", elem.href, true) // true for asynchronous
xmlHttp.send(null)
}
function checkPerms(rootId) {
var rootElem = document.getElementById(rootId)
Array.from(rootElem.children).forEach(elem => checkHref(elem))
}
</script>
</head> </head>
<body onload="checkPerms('links-to-check')"> <body onload="checkPerms('links-to-check')">
<h1>OIDC-test</h1> <h1>OIDC-test</h1>

View file

@ -1,55 +1,5 @@
{% extends "base.html" %} {% extends "base.html" %}
{% block content %} {% block content %}
<style>
.login-box {
text-align: center;
background-color: antiquewhite;
margin: 0.5em auto;
width: fit-content;
box-shadow: 0 0 10px #49759b88;
border-radius: 8px;
}
.login-box .description {
font-style: italic;
font-weight: bold;
background-color: #f7c7867d;
padding: 6px;
margin: 0;
border-radius: 8px 8px 0 0;
}
.providers {
justify-content: center;
padding: 0.8em;
}
.providers .provider {
min-height: 2em;
}
.providers .provider a.link {
text-decoration: none;
max-height: 2em;
}
.providers .provider .link div {
text-align: center;
background-color: #f7c7867d;
border-radius: 8px;
padding: 6px;
text-align: center;
color: black;
font-weight: bold;
cursor: pointer;
}
.providers .provider .hint {
font-size: 80%;
max-width: 13em;
}
.providers .error {
color: darkred;
padding: 3px 6px;
text-align: center;
font-weight: bold;
flex: 1 1 auto;
}
</style>
<p class="center"> <p class="center">
Test the authentication and authorization, Test the authentication and authorization,
with OpenID Connect and OAuth2 with different providers. with OpenID Connect and OAuth2 with different providers.