Create a sub-app for resource server move all resources to resource server;
All checks were successful
/ build (push) Successful in 5s
/ test (push) Successful in 5s

use token bearer instead of session cookie for resources and use fetch instead of XMLHttpRequest for checking resource status;
add UserWithRole class for fastapi depends (instead of has_role decorator);
add asserts for typing QC; code formatting;
comment out introspect endpoint processing
This commit is contained in:
phil 2025-02-07 13:57:17 +01:00
parent ee8ba3d2df
commit d39adf41ef
8 changed files with 188 additions and 153 deletions

View file

@ -13,7 +13,7 @@ from authlib.oauth2.auth import OAuth2Token
from .models import User
from .database import TokenNotInDb, db, UserNotInDB
from .settings import settings, OIDCProvider, oidc_providers_settings
from .settings import oidc_providers_settings
logger = logging.getLogger("oidc-test")
@ -21,6 +21,8 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
async def fetch_token(name, request):
assert name is not None
assert request is not None
logger.warn("TODO: fetch_token")
...
# if name in oidc_providers:
@ -37,8 +39,10 @@ async def update_token(name, token, refresh_token=None, access_token=None):
sid: str = oidc_provider_settings.decode(token["id_token"])["sid"]
item = await db.get_token(oidc_provider_settings, sid)
# update old token
item["access_token"] = token.get("access_token")
item["refresh_token"] = token.get("refresh_token")
if access_token is not None:
item["access_token"] = token.get("access_token")
if refresh_token is not None:
item["refresh_token"] = refresh_token
item["expires_at"] = token["expires_at"]
logger.info(f"Token {sid} refreshed")
# It's a fake db and only in memory, so there's nothing to save
@ -119,6 +123,7 @@ async def get_current_user(request: Request) -> User:
userinfo = await oidc_provider.fetch_access_token(
refresh_token=token.get("refresh_token")
)
assert userinfo is not None
except OAuthError as err:
logger.exception(err)
# raise HTTPException(
@ -242,3 +247,20 @@ async def get_user_from_token(
access_token=token,
)
return user
class UserWithRole:
roles: set[str]
def __init__(self, roles: str | list[str] | tuple[str] | set[str]):
if isinstance(roles, str):
self.roles = set([roles])
elif isinstance(roles, (list, tuple, set)):
self.roles = set(roles)
def __call__(self, user: User = Depends(get_user_from_token)) -> User:
if not any(self.roles.intersection(user.roles_as_set)):
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Not of any required role {', '.join(self.roles)}"
)
return user

View file

@ -69,6 +69,7 @@ class Database:
async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None:
"""Store a token using as key the sid (auth provider's session id)
in the id_token"""
assert isinstance(oidc_provider_settings, OIDCProvider)
sid = token["userinfo"]["sid"]
self.tokens[sid] = token
@ -77,6 +78,7 @@ class Database:
oidc_provider_settings: OIDCProvider,
sid: str | None,
) -> OAuth2Token:
assert isinstance(oidc_provider_settings, OIDCProvider)
if sid is None:
raise TokenNotInDb
try:

View file

@ -12,7 +12,7 @@ from contextlib import asynccontextmanager
from httpx import HTTPError
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from jwt import InvalidTokenError, PyJWTError
@ -31,17 +31,13 @@ from .models import User
from .auth_utils import (
get_oidc_provider,
get_oidc_provider_or_none,
hasrole,
get_current_user_or_none,
get_current_user,
get_user_from_token,
authlib_oauth,
get_token,
get_providers_info,
)
from .auth_misc import pretty_details
from .database import TokenNotInDb, db
from .resource_server import get_resource
from .resource_server import resource_server
logger = logging.getLogger("oidc-test")
@ -50,6 +46,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates")
@asynccontextmanager
async def lifespan(app: FastAPI):
assert app is not None
await get_providers_info()
yield
@ -64,24 +61,21 @@ app.add_middleware(
allow_headers=["*"],
)
app.mount(
"/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static"
)
# SessionMiddleware is required by authlib
app.add_middleware(
SessionMiddleware,
secret_key=settings.secret_key,
)
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
app.mount("/resource", resource_server, name="resource_server")
@app.get("/")
async def home(
request: Request,
user: Annotated[User, Depends(get_current_user_or_none)],
oidc_provider: Annotated[
StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)
],
oidc_provider: Annotated[StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)],
) -> HTMLResponse:
now = datetime.now()
if oidc_provider and (
@ -119,9 +113,7 @@ async def home(
"oidc_provider_settings": oidc_provider_settings,
"resources": resources,
"user_info_details": (
pretty_details(user, now)
if user and settings.oidc.show_session_details
else None
pretty_details(user, now) if user and settings.oidc.show_session_details else None
),
},
)
@ -215,24 +207,19 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
else:
# Not sure if it's correct to redirect to plain login
# if no userinfo is provided
return RedirectResponse(
url=request.url_for("login", oidc_provider_id=oidc_provider_id)
)
return RedirectResponse(url=request.url_for("login", oidc_provider_id=oidc_provider_id))
@app.get("/account")
async def account(
request: Request,
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
) -> RedirectResponse:
if (
oidc_provider_settings := oidc_providers_settings.get(
request.session.get("oidc_provider_id", "")
)
) is None:
raise HTTPException(
status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings"
)
raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings")
return RedirectResponse(f"{oidc_provider_settings.account_url_template}")
@ -244,12 +231,8 @@ async def logout(
# Clear session
request.session.pop("user_sub", None)
# Get provider's endpoint
if (
provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")
) is None:
logger.warn(
f"Cannot find end_session_endpoint for provider {oidc_provider.name}"
)
if (provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")) is None:
logger.warn(f"Cannot find end_session_endpoint for provider {oidc_provider.name}")
return RedirectResponse(request.url_for("non_compliant_logout"))
post_logout_uri = request.url_for("home")
oidc_provider_settings = oidc_providers_settings.get(
@ -257,9 +240,7 @@ async def logout(
)
assert oidc_provider_settings is not None
try:
token = await db.get_token(
oidc_provider_settings, request.session.pop("sid", None)
)
token = await db.get_token(oidc_provider_settings, request.session.pop("sid", None))
except TokenNotInDb:
logger.warn("No session in db for the token or no token")
return RedirectResponse(request.url_for("home"))
@ -292,90 +273,6 @@ async def non_compliant_logout(
)
# Route for OAuth resource server
@app.get("/resource/{id}")
async def get_resource_(
id: str,
# user: Annotated[User, Depends(get_current_user)],
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
# token: Annotated[OAuth2Token, Depends(get_token)],
user: Annotated[User, Depends(get_user_from_token)],
) -> JSONResponse:
"""Generic path for testing a resource provided by a provider"""
return JSONResponse(await get_resource(id, user))
# Routes for RBAC based tests
@app.get("/public")
async def public() -> HTMLResponse:
return HTMLResponse("<h1>Not protected</h1>")
@app.get("/protected")
async def get_protected(
user: Annotated[User, Depends(get_current_user)]
) -> HTMLResponse:
assert user is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only authenticated users can see this</h1>")
@app.get("/protected-by-foorole")
@hasrole("foorole")
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
assert request is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only users with foorole can see this</h1>")
@app.get("/protected-by-barrole")
@hasrole("barrole")
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
assert request is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Protected by barrole</h1>")
@app.get("/protected-by-foorole-and-barrole")
@hasrole("barrole")
@hasrole("foorole")
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
assert request is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>")
@app.get("/protected-by-foorole-or-barrole")
@hasrole(["foorole", "barrole"])
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
assert request is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
@app.get("/introspect")
async def get_introspect(
request: Request,
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
token: Annotated[OAuth2Token, Depends(get_token)],
) -> JSONResponse:
assert request is not None # Just to keep QA checks happy
if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="No intrispection endpoint found for the OIDC provider",
)
if (
response := await oidc_provider.post(
url,
token=token,
data={"token": token["access_token"]},
)
).is_success:
return response.json()
else:
raise HTTPException(status_code=response.status_code, detail=response.text)
# Snippet for running standalone
# Mostly useful for the --version option,
# as running with uvicorn is easy and provides better flexibility, eg.
@ -397,9 +294,7 @@ def main():
parser.add_argument(
"-p", "--port", type=int, default=80, help="Port to listen to (default: 80)"
)
parser.add_argument(
"-v", "--version", action="store_true", help="Print version and exit"
)
parser.add_argument("-v", "--version", action="store_true", help="Print version and exit")
args = parser.parse_args()
if args.version:

View file

@ -1,6 +1,6 @@
import logging
from functools import cached_property
from typing import Self, Any
from typing import Any
from pydantic import (
computed_field,
@ -60,6 +60,4 @@ class User(UserBase):
assert self.oidc_provider.name is not None
from .settings import oidc_providers_settings
return oidc_providers_settings[self.oidc_provider.name].decode(
self.access_token
)
return oidc_providers_settings[self.oidc_provider.name].decode(self.access_token)

View file

@ -1,15 +1,127 @@
from datetime import datetime
from typing import Annotated
import logging
from httpx import AsyncClient
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from fastapi import HTTPException, status
from starlette.status import HTTP_401_UNAUTHORIZED
from fastapi import FastAPI, HTTPException, Depends, Request, status
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
# from starlette.middleware.sessions import SessionMiddleware
# from authlib.integrations.starlette_client.apps import StarletteOAuth2App
# from authlib.oauth2.rfc6749 import OAuth2Token
from .models import User
from .auth_utils import (
get_user_from_token,
UserWithRole,
get_oidc_provider,
get_token,
)
from .settings import settings
logger = logging.getLogger("oidc-test")
resource_server = FastAPI()
resource_server.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# SessionMiddleware is required by authlib
# resource_server.add_middleware(
# SessionMiddleware,
# secret_key=settings.secret_key,
# )
# Route for OAuth resource server
# Routes for RBAC based tests
@resource_server.get("/public")
async def public() -> HTMLResponse:
return HTMLResponse("<h1>Not protected</h1>")
@resource_server.get("/protected")
async def get_protected(user: Annotated[User, Depends(get_user_from_token)]) -> HTMLResponse:
assert user is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only authenticated users can see this</h1>")
@resource_server.get("/protected-by-foorole")
async def get_protected_by_foorole(
user: Annotated[User, Depends(UserWithRole("foorole"))]
) -> HTMLResponse:
return HTMLResponse("<h1>Only users with foorole can see this</h1>")
@resource_server.get("/protected-by-barrole")
async def get_protected_by_barrole(
user: Annotated[User, Depends(UserWithRole("barrole"))]
) -> HTMLResponse:
return HTMLResponse("<h1>Protected by barrole</h1>")
@resource_server.get("/protected-by-foorole-and-barrole")
async def get_protected_by_foorole_and_barrole(
user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))],
) -> HTMLResponse:
assert user is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>")
@resource_server.get("/protected-by-foorole-or-barrole")
async def get_protected_by_foorole_or_barrole(
user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))]
) -> HTMLResponse:
assert user is not None # Just to keep QA checks happy
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
# @resource_server.get("/introspect")
# async def get_introspect(
# request: Request,
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
# token: Annotated[OAuth2Token, Depends(get_token)],
# ) -> JSONResponse:
# assert request is not None # Just to keep QA checks happy
# if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None:
# raise HTTPException(
# status_code=status.HTTP_401_UNAUTHORIZED,
# detail="No introspection endpoint found for the OIDC provider",
# )
# if (
# response := await oidc_provider.post(
# url,
# token=token,
# data={"token": token["access_token"]},
# )
# ).is_success:
# return response.json()
# else:
# raise HTTPException(status_code=response.status_code, detail=response.text)
@resource_server.get("/{id}")
async def get_resource_(
id: str,
# user: Annotated[User, Depends(get_current_user)],
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
# token: Annotated[OAuth2Token, Depends(get_token)],
user: Annotated[User, Depends(get_user_from_token)],
) -> JSONResponse:
"""Generic path for testing a resource provided by a provider"""
return JSONResponse(await get_resource(id, user))
async def get_resource(resource_id: str, user: User) -> dict:
"""
@ -34,12 +146,10 @@ async def get_resource(resource_id: str, user: User) -> dict:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
f"No scope {required_scope} in the access token "
+ "but it is required for accessing this resource.",
+ "but it is required for accessing this resource",
)
except ExpiredSignatureError:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, "The token's signature has expired"
)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired")
except InvalidTokenError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid")
return resp

View file

@ -1,20 +1,28 @@
function checkHref(elem) {
var xmlHttp = new XMLHttpRequest()
xmlHttp.onreadystatechange = function () {
if (xmlHttp.readyState == 4) {
elem.classList.add("hasResponseStatus")
elem.classList.add("status-" + xmlHttp.status)
elem.title = "Response code: " + xmlHttp.status + " - " + xmlHttp.statusText
}
async function checkHref(elem, token, authProvider) {
const msg = document.getElementById("msg")
const resp = await fetch(elem.href, {
headers: new Headers({
"Content-type": "application/json",
"Authorization": `Bearer ${token}`,
"auth_provider": authProvider,
}),
}).catch(err => {
msg.innerHTML = "Cannot fetch resource: " + err.message
resourceElem.innerHTML = ""
})
if (resp === undefined) {
return
} else {
elem.classList.add("hasResponseStatus")
elem.classList.add("status-" + resp.status)
elem.title = "Response code: " + resp.status + " - " + resp.statusText
}
xmlHttp.open("GET", elem.href, true) // true for asynchronous
xmlHttp.send(null)
}
function checkPerms(className) {
function checkPerms(className, token, authProvider) {
var rootElems = document.getElementsByClassName(className)
Array.from(rootElems).forEach(elem =>
Array.from(elem.children).forEach(elem => checkHref(elem))
Array.from(elem.children).forEach(elem => checkHref(elem, token, authProvider))
)
}

View file

@ -4,7 +4,7 @@
<link href="{{ url_for('static', path='/styles.css') }}" rel="stylesheet">
<script src="{{ url_for('static', path='/utils.js') }}"></script>
</head>
<body onload="checkPerms('links-to-check')">
<body onload="checkPerms('links-to-check', '{{ user.access_token }}', '{{ oidc_provider_settings.id }}')">
<h1>OIDC-test - FastAPI client</h1>
{% block content %}
{% endblock %}

View file

@ -80,14 +80,14 @@
These links should get different response codes depending on the authorization:
</p>
<div class="links-to-check">
<a href="public">Public</a>
<a href="protected">Auth protected content</a>
<a href="protected-by-foorole">Auth + foorole protected content</a>
<a href="protected-by-foorole-or-barrole">Auth + foorole or barrole protected content</a>
<a href="protected-by-barrole">Auth + barrole protected content</a>
<a href="protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a>
<a href="fast_api_depends" class="hidden">Using FastAPI Depends</a>
<a href="introspect">Introspect token (401 expected)</a>
<a href="resource/public">Public</a>
<a href="resource/protected">Auth protected content</a>
<a href="resource/protected-by-foorole">Auth + foorole protected content</a>
<a href="resource/protected-by-foorole-or-barrole">Auth + foorole or barrole protected content</a>
<a href="resource/protected-by-barrole">Auth + barrole protected content</a>
<a href="resource/protected-by-foorole-and-barrole">Auth + foorole and barrole protected content</a>
<a href="resource/fast_api_depends" class="hidden">Using FastAPI Depends</a>
<!--<a href="resource/introspect">Introspect token (401 expected)</a>-->
</div>
{% if resources %}
<p>