Create a sub-app for resource server move all resources to resource server;
All checks were successful
/ build (push) Successful in 5s
/ test (push) Successful in 5s

use token bearer instead of session cookie for resources and use fetch instead of XMLHttpRequest for checking resource status;
add UserWithRole class for fastapi depends (instead of has_role decorator);
add asserts for typing QC; code formatting;
comment out introspect endpoint processing
This commit is contained in:
phil 2025-02-07 13:57:17 +01:00
parent ee8ba3d2df
commit d39adf41ef
8 changed files with 188 additions and 153 deletions

View file

@ -13,7 +13,7 @@ from authlib.oauth2.auth import OAuth2Token
from .models import User from .models import User
from .database import TokenNotInDb, db, UserNotInDB from .database import TokenNotInDb, db, UserNotInDB
from .settings import settings, OIDCProvider, oidc_providers_settings from .settings import oidc_providers_settings
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
@ -21,6 +21,8 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
async def fetch_token(name, request): async def fetch_token(name, request):
assert name is not None
assert request is not None
logger.warn("TODO: fetch_token") logger.warn("TODO: fetch_token")
... ...
# if name in oidc_providers: # if name in oidc_providers:
@ -37,8 +39,10 @@ async def update_token(name, token, refresh_token=None, access_token=None):
sid: str = oidc_provider_settings.decode(token["id_token"])["sid"] sid: str = oidc_provider_settings.decode(token["id_token"])["sid"]
item = await db.get_token(oidc_provider_settings, sid) item = await db.get_token(oidc_provider_settings, sid)
# update old token # update old token
if access_token is not None:
item["access_token"] = token.get("access_token") item["access_token"] = token.get("access_token")
item["refresh_token"] = token.get("refresh_token") if refresh_token is not None:
item["refresh_token"] = refresh_token
item["expires_at"] = token["expires_at"] item["expires_at"] = token["expires_at"]
logger.info(f"Token {sid} refreshed") logger.info(f"Token {sid} refreshed")
# It's a fake db and only in memory, so there's nothing to save # It's a fake db and only in memory, so there's nothing to save
@ -119,6 +123,7 @@ async def get_current_user(request: Request) -> User:
userinfo = await oidc_provider.fetch_access_token( userinfo = await oidc_provider.fetch_access_token(
refresh_token=token.get("refresh_token") refresh_token=token.get("refresh_token")
) )
assert userinfo is not None
except OAuthError as err: except OAuthError as err:
logger.exception(err) logger.exception(err)
# raise HTTPException( # raise HTTPException(
@ -242,3 +247,20 @@ async def get_user_from_token(
access_token=token, access_token=token,
) )
return user return user
class UserWithRole:
roles: set[str]
def __init__(self, roles: str | list[str] | tuple[str] | set[str]):
if isinstance(roles, str):
self.roles = set([roles])
elif isinstance(roles, (list, tuple, set)):
self.roles = set(roles)
def __call__(self, user: User = Depends(get_user_from_token)) -> User:
if not any(self.roles.intersection(user.roles_as_set)):
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Not of any required role {', '.join(self.roles)}"
)
return user

View file

@ -69,6 +69,7 @@ class Database:
async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None: async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None:
"""Store a token using as key the sid (auth provider's session id) """Store a token using as key the sid (auth provider's session id)
in the id_token""" in the id_token"""
assert isinstance(oidc_provider_settings, OIDCProvider)
sid = token["userinfo"]["sid"] sid = token["userinfo"]["sid"]
self.tokens[sid] = token self.tokens[sid] = token
@ -77,6 +78,7 @@ class Database:
oidc_provider_settings: OIDCProvider, oidc_provider_settings: OIDCProvider,
sid: str | None, sid: str | None,
) -> OAuth2Token: ) -> OAuth2Token:
assert isinstance(oidc_provider_settings, OIDCProvider)
if sid is None: if sid is None:
raise TokenNotInDb raise TokenNotInDb
try: try:

View file

@ -12,7 +12,7 @@ from contextlib import asynccontextmanager
from httpx import HTTPError from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from jwt import InvalidTokenError, PyJWTError from jwt import InvalidTokenError, PyJWTError
@ -31,17 +31,13 @@ from .models import User
from .auth_utils import ( from .auth_utils import (
get_oidc_provider, get_oidc_provider,
get_oidc_provider_or_none, get_oidc_provider_or_none,
hasrole,
get_current_user_or_none, get_current_user_or_none,
get_current_user,
get_user_from_token,
authlib_oauth, authlib_oauth,
get_token,
get_providers_info, get_providers_info,
) )
from .auth_misc import pretty_details from .auth_misc import pretty_details
from .database import TokenNotInDb, db from .database import TokenNotInDb, db
from .resource_server import get_resource from .resource_server import resource_server
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
@ -50,6 +46,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates")
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
assert app is not None
await get_providers_info() await get_providers_info()
yield yield
@ -64,24 +61,21 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.mount(
"/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static"
)
# SessionMiddleware is required by authlib # SessionMiddleware is required by authlib
app.add_middleware( app.add_middleware(
SessionMiddleware, SessionMiddleware,
secret_key=settings.secret_key, secret_key=settings.secret_key,
) )
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
app.mount("/resource", resource_server, name="resource_server")
@app.get("/") @app.get("/")
async def home( async def home(
request: Request, request: Request,
user: Annotated[User, Depends(get_current_user_or_none)], user: Annotated[User, Depends(get_current_user_or_none)],
oidc_provider: Annotated[ oidc_provider: Annotated[StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)],
StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)
],
) -> HTMLResponse: ) -> HTMLResponse:
now = datetime.now() now = datetime.now()
if oidc_provider and ( if oidc_provider and (
@ -119,9 +113,7 @@ async def home(
"oidc_provider_settings": oidc_provider_settings, "oidc_provider_settings": oidc_provider_settings,
"resources": resources, "resources": resources,
"user_info_details": ( "user_info_details": (
pretty_details(user, now) pretty_details(user, now) if user and settings.oidc.show_session_details else None
if user and settings.oidc.show_session_details
else None
), ),
}, },
) )
@ -215,24 +207,19 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
else: else:
# Not sure if it's correct to redirect to plain login # Not sure if it's correct to redirect to plain login
# if no userinfo is provided # if no userinfo is provided
return RedirectResponse( return RedirectResponse(url=request.url_for("login", oidc_provider_id=oidc_provider_id))
url=request.url_for("login", oidc_provider_id=oidc_provider_id)
)
@app.get("/account") @app.get("/account")
async def account( async def account(
request: Request, request: Request,
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
) -> RedirectResponse: ) -> RedirectResponse:
if ( if (
oidc_provider_settings := oidc_providers_settings.get( oidc_provider_settings := oidc_providers_settings.get(
request.session.get("oidc_provider_id", "") request.session.get("oidc_provider_id", "")
) )
) is None: ) is None:
raise HTTPException( raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings")
status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings"
)
return RedirectResponse(f"{oidc_provider_settings.account_url_template}") return RedirectResponse(f"{oidc_provider_settings.account_url_template}")
@ -244,12 +231,8 @@ async def logout(
# Clear session # Clear session
request.session.pop("user_sub", None) request.session.pop("user_sub", None)
# Get provider's endpoint # Get provider's endpoint
if ( if (provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")) is None:
provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint") logger.warn(f"Cannot find end_session_endpoint for provider {oidc_provider.name}")
) is None:
logger.warn(
f"Cannot find end_session_endpoint for provider {oidc_provider.name}"
)
return RedirectResponse(request.url_for("non_compliant_logout")) return RedirectResponse(request.url_for("non_compliant_logout"))
post_logout_uri = request.url_for("home") post_logout_uri = request.url_for("home")
oidc_provider_settings = oidc_providers_settings.get( oidc_provider_settings = oidc_providers_settings.get(
@ -257,9 +240,7 @@ async def logout(
) )
assert oidc_provider_settings is not None assert oidc_provider_settings is not None
try: try:
token = await db.get_token( token = await db.get_token(oidc_provider_settings, request.session.pop("sid", None))
oidc_provider_settings, request.session.pop("sid", None)
)
except TokenNotInDb: except TokenNotInDb:
logger.warn("No session in db for the token or no token") logger.warn("No session in db for the token or no token")
return RedirectResponse(request.url_for("home")) return RedirectResponse(request.url_for("home"))
@ -292,90 +273,6 @@ async def non_compliant_logout(
) )
# Route for OAuth resource server
@app.get("/resource/{id}")
async def get_resource_(
id: str,
# user: Annotated[User, Depends(get_current_user)],
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
# token: Annotated[OAuth2Token, Depends(get_token)],
user: Annotated[User, Depends(get_user_from_token)],
) -> JSONResponse:
"""Generic path for testing a resource provided by a provider"""
return JSONResponse(await get_resource(id, user))
# Routes for RBAC based tests
@app.get("/public")
async def public() -> HTMLResponse:
return HTMLResponse("<h1>Not protected</h1>")
@app.get("/protected")
async def get_protected(
user: Annotated[User, Depends(get_current_user)]
) -> HTMLResponse:
assert user is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only authenticated users can see this</h1>")
@app.get("/protected-by-foorole")
@hasrole("foorole")
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
assert request is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only users with foorole can see this</h1>")
@app.get("/protected-by-barrole")
@hasrole("barrole")
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
assert request is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Protected by barrole</h1>")
@app.get("/protected-by-foorole-and-barrole")
@hasrole("barrole")
@hasrole("foorole")
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
assert request is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>")
@app.get("/protected-by-foorole-or-barrole")
@hasrole(["foorole", "barrole"])
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
assert request is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
@app.get("/introspect")
async def get_introspect(
request: Request,
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
token: Annotated[OAuth2Token, Depends(get_token)],
) -> JSONResponse:
assert request is not None # Just to keep QA checks happy
if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="No intrispection endpoint found for the OIDC provider",
)
if (
response := await oidc_provider.post(
url,
token=token,
data={"token": token["access_token"]},
)
).is_success:
return response.json()
else:
raise HTTPException(status_code=response.status_code, detail=response.text)
# Snippet for running standalone # Snippet for running standalone
# Mostly useful for the --version option, # Mostly useful for the --version option,
# as running with uvicorn is easy and provides better flexibility, eg. # as running with uvicorn is easy and provides better flexibility, eg.
@ -397,9 +294,7 @@ def main():
parser.add_argument( parser.add_argument(
"-p", "--port", type=int, default=80, help="Port to listen to (default: 80)" "-p", "--port", type=int, default=80, help="Port to listen to (default: 80)"
) )
parser.add_argument( parser.add_argument("-v", "--version", action="store_true", help="Print version and exit")
"-v", "--version", action="store_true", help="Print version and exit"
)
args = parser.parse_args() args = parser.parse_args()
if args.version: if args.version:

View file

@ -1,6 +1,6 @@
import logging import logging
from functools import cached_property from functools import cached_property
from typing import Self, Any from typing import Any
from pydantic import ( from pydantic import (
computed_field, computed_field,
@ -60,6 +60,4 @@ class User(UserBase):
assert self.oidc_provider.name is not None assert self.oidc_provider.name is not None
from .settings import oidc_providers_settings from .settings import oidc_providers_settings
return oidc_providers_settings[self.oidc_provider.name].decode( return oidc_providers_settings[self.oidc_provider.name].decode(self.access_token)
self.access_token
)

View file

@ -1,15 +1,127 @@
from datetime import datetime from datetime import datetime
from typing import Annotated
import logging import logging
from httpx import AsyncClient from httpx import AsyncClient
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from fastapi import HTTPException, status from fastapi import FastAPI, HTTPException, Depends, Request, status
from starlette.status import HTTP_401_UNAUTHORIZED from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
# from starlette.middleware.sessions import SessionMiddleware
# from authlib.integrations.starlette_client.apps import StarletteOAuth2App
# from authlib.oauth2.rfc6749 import OAuth2Token
from .models import User from .models import User
from .auth_utils import (
get_user_from_token,
UserWithRole,
get_oidc_provider,
get_token,
)
from .settings import settings
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
resource_server = FastAPI()
resource_server.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# SessionMiddleware is required by authlib
# resource_server.add_middleware(
# SessionMiddleware,
# secret_key=settings.secret_key,
# )
# Route for OAuth resource server
# Routes for RBAC based tests
@resource_server.get("/public")
async def public() -> HTMLResponse:
return HTMLResponse("<h1>Not protected</h1>")
@resource_server.get("/protected")
async def get_protected(user: Annotated[User, Depends(get_user_from_token)]) -> HTMLResponse:
assert user is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only authenticated users can see this</h1>")
@resource_server.get("/protected-by-foorole")
async def get_protected_by_foorole(
user: Annotated[User, Depends(UserWithRole("foorole"))]
) -> HTMLResponse:
return HTMLResponse("<h1>Only users with foorole can see this</h1>")
@resource_server.get("/protected-by-barrole")
async def get_protected_by_barrole(
user: Annotated[User, Depends(UserWithRole("barrole"))]
) -> HTMLResponse:
return HTMLResponse("<h1>Protected by barrole</h1>")
@resource_server.get("/protected-by-foorole-and-barrole")
async def get_protected_by_foorole_and_barrole(
user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))],
) -> HTMLResponse:
assert user is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>")
@resource_server.get("/protected-by-foorole-or-barrole")
async def get_protected_by_foorole_or_barrole(
user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))]
) -> HTMLResponse:
assert user is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
# @resource_server.get("/introspect")
# async def get_introspect(
# request: Request,
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
# token: Annotated[OAuth2Token, Depends(get_token)],
# ) -> JSONResponse:
# assert request is not None # Just to keep QA checks happy
# if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None:
# raise HTTPException(
# status_code=status.HTTP_401_UNAUTHORIZED,
# detail="No introspection endpoint found for the OIDC provider",
# )
# if (
# response := await oidc_provider.post(
# url,
# token=token,
# data={"token": token["access_token"]},
# )
# ).is_success:
# return response.json()
# else:
# raise HTTPException(status_code=response.status_code, detail=response.text)
@resource_server.get("/{id}")
async def get_resource_(
id: str,
# user: Annotated[User, Depends(get_current_user)],
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
# token: Annotated[OAuth2Token, Depends(get_token)],
user: Annotated[User, Depends(get_user_from_token)],
) -> JSONResponse:
"""Generic path for testing a resource provided by a provider"""
return JSONResponse(await get_resource(id, user))
async def get_resource(resource_id: str, user: User) -> dict: async def get_resource(resource_id: str, user: User) -> dict:
""" """
@ -34,12 +146,10 @@ async def get_resource(resource_id: str, user: User) -> dict:
raise HTTPException( raise HTTPException(
status.HTTP_401_UNAUTHORIZED, status.HTTP_401_UNAUTHORIZED,
f"No scope {required_scope} in the access token " f"No scope {required_scope} in the access token "
+ "but it is required for accessing this resource.", + "but it is required for accessing this resource",
) )
except ExpiredSignatureError: except ExpiredSignatureError:
raise HTTPException( raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired")
status.HTTP_401_UNAUTHORIZED, "The token's signature has expired"
)
except InvalidTokenError: except InvalidTokenError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid")
return resp return resp

View file

@ -1,20 +1,28 @@
function checkHref(elem) { async function checkHref(elem, token, authProvider) {
var xmlHttp = new XMLHttpRequest() const msg = document.getElementById("msg")
xmlHttp.onreadystatechange = function () { const resp = await fetch(elem.href, {
if (xmlHttp.readyState == 4) { headers: new Headers({
"Content-type": "application/json",
"Authorization": `Bearer ${token}`,
"auth_provider": authProvider,
}),
}).catch(err => {
msg.innerHTML = "Cannot fetch resource: " + err.message
resourceElem.innerHTML = ""
})
if (resp === undefined) {
return
} else {
elem.classList.add("hasResponseStatus") elem.classList.add("hasResponseStatus")
elem.classList.add("status-" + xmlHttp.status) elem.classList.add("status-" + resp.status)
elem.title = "Response code: " + xmlHttp.status + " - " + xmlHttp.statusText elem.title = "Response code: " + resp.status + " - " + resp.statusText
} }
}
xmlHttp.open("GET", elem.href, true) // true for asynchronous
xmlHttp.send(null)
} }
function checkPerms(className) { function checkPerms(className, token, authProvider) {
var rootElems = document.getElementsByClassName(className) var rootElems = document.getElementsByClassName(className)
Array.from(rootElems).forEach(elem => Array.from(rootElems).forEach(elem =>
Array.from(elem.children).forEach(elem => checkHref(elem)) Array.from(elem.children).forEach(elem => checkHref(elem, token, authProvider))
) )
} }

View file

@ -4,7 +4,7 @@
<link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet"> <link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet">
<script src="{{ url_for('static', path='/utils.js') }}"></script> <script src="{{ url_for('static', path='/utils.js') }}"></script>
</head> </head>
<body onload="checkPerms('links-to-check')"> <body onload="checkPerms('links-to-check', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">
<h1>OIDC-test - FastAPI client</h1> <h1>OIDC-test - FastAPI client</h1>
{% block content %} {% block content %}
{% endblock %} {% endblock %}

View file

@ -80,14 +80,14 @@
These links should get different response codes depending on the authorization: These links should get different response codes depending on the authorization:
</p> </p>
<div class="links-to-check"> <div class="links-to-check">
<a href="public">Public</a> <a href="resource/public">Public</a>
<a href="protected">Auth protected content</a> <a href="resource/protected">Auth protected content</a>
<a href="protected-by-foorole">Auth + foorole protected content</a> <a href="resource/protected-by-foorole">Auth + foorole protected content</a>
<a href="protected-by-foorole-or-barrole">Auth + foorole or barrole protected content</a> <a href="resource/protected-by-foorole-or-barrole">Auth + foorole or barrole protected content</a>
<a href="protected-by-barrole">Auth + barrole protected content</a> <a href="resource/protected-by-barrole">Auth + barrole protected content</a>
<a href="protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a> <a href="resource/protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a>
<a href="fast_api_depends" class="hidden">Using FastAPI Depends</a> <a href="resource/fast_api_depends" class="hidden">Using FastAPI Depends</a>
<a href="introspect">Introspect token (401 expected)</a> <!--<a href="resource/introspect">Introspect token (401 expected)</a>-->
</div> </div>
{% if resources %} {% if resources %}
<p> <p>