Create a sub-app for resource server move all resources to resource server;
use token bearer instead of session cookie for resources and use fetch instead of XMLHttpRequest for checking resource status; add UserWithRole class for fastapi depends (instead of has_role decorator); add asserts for typing QC; code formatting; comment out introspect endpoint processing
This commit is contained in:
parent
ee8ba3d2df
commit
d39adf41ef
8 changed files with 188 additions and 153 deletions
|
@ -12,7 +12,7 @@ from contextlib import asynccontextmanager
|
|||
from httpx import HTTPError
|
||||
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from jwt import InvalidTokenError, PyJWTError
|
||||
|
@ -31,17 +31,13 @@ from .models import User
|
|||
from .auth_utils import (
|
||||
get_oidc_provider,
|
||||
get_oidc_provider_or_none,
|
||||
hasrole,
|
||||
get_current_user_or_none,
|
||||
get_current_user,
|
||||
get_user_from_token,
|
||||
authlib_oauth,
|
||||
get_token,
|
||||
get_providers_info,
|
||||
)
|
||||
from .auth_misc import pretty_details
|
||||
from .database import TokenNotInDb, db
|
||||
from .resource_server import get_resource
|
||||
from .resource_server import resource_server
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
@ -50,6 +46,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates")
|
|||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
assert app is not None
|
||||
await get_providers_info()
|
||||
yield
|
||||
|
||||
|
@ -64,24 +61,21 @@ app.add_middleware(
|
|||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.mount(
|
||||
"/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static"
|
||||
)
|
||||
|
||||
# SessionMiddleware is required by authlib
|
||||
app.add_middleware(
|
||||
SessionMiddleware,
|
||||
secret_key=settings.secret_key,
|
||||
)
|
||||
|
||||
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
|
||||
app.mount("/resource", resource_server, name="resource_server")
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def home(
|
||||
request: Request,
|
||||
user: Annotated[User, Depends(get_current_user_or_none)],
|
||||
oidc_provider: Annotated[
|
||||
StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)
|
||||
],
|
||||
oidc_provider: Annotated[StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)],
|
||||
) -> HTMLResponse:
|
||||
now = datetime.now()
|
||||
if oidc_provider and (
|
||||
|
@ -119,9 +113,7 @@ async def home(
|
|||
"oidc_provider_settings": oidc_provider_settings,
|
||||
"resources": resources,
|
||||
"user_info_details": (
|
||||
pretty_details(user, now)
|
||||
if user and settings.oidc.show_session_details
|
||||
else None
|
||||
pretty_details(user, now) if user and settings.oidc.show_session_details else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
@ -215,24 +207,19 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
|||
else:
|
||||
# Not sure if it's correct to redirect to plain login
|
||||
# if no userinfo is provided
|
||||
return RedirectResponse(
|
||||
url=request.url_for("login", oidc_provider_id=oidc_provider_id)
|
||||
)
|
||||
return RedirectResponse(url=request.url_for("login", oidc_provider_id=oidc_provider_id))
|
||||
|
||||
|
||||
@app.get("/account")
|
||||
async def account(
|
||||
request: Request,
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
) -> RedirectResponse:
|
||||
if (
|
||||
oidc_provider_settings := oidc_providers_settings.get(
|
||||
request.session.get("oidc_provider_id", "")
|
||||
)
|
||||
) is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings"
|
||||
)
|
||||
raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings")
|
||||
return RedirectResponse(f"{oidc_provider_settings.account_url_template}")
|
||||
|
||||
|
||||
|
@ -244,12 +231,8 @@ async def logout(
|
|||
# Clear session
|
||||
request.session.pop("user_sub", None)
|
||||
# Get provider's endpoint
|
||||
if (
|
||||
provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")
|
||||
) is None:
|
||||
logger.warn(
|
||||
f"Cannot find end_session_endpoint for provider {oidc_provider.name}"
|
||||
)
|
||||
if (provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")) is None:
|
||||
logger.warn(f"Cannot find end_session_endpoint for provider {oidc_provider.name}")
|
||||
return RedirectResponse(request.url_for("non_compliant_logout"))
|
||||
post_logout_uri = request.url_for("home")
|
||||
oidc_provider_settings = oidc_providers_settings.get(
|
||||
|
@ -257,9 +240,7 @@ async def logout(
|
|||
)
|
||||
assert oidc_provider_settings is not None
|
||||
try:
|
||||
token = await db.get_token(
|
||||
oidc_provider_settings, request.session.pop("sid", None)
|
||||
)
|
||||
token = await db.get_token(oidc_provider_settings, request.session.pop("sid", None))
|
||||
except TokenNotInDb:
|
||||
logger.warn("No session in db for the token or no token")
|
||||
return RedirectResponse(request.url_for("home"))
|
||||
|
@ -292,90 +273,6 @@ async def non_compliant_logout(
|
|||
)
|
||||
|
||||
|
||||
# Route for OAuth resource server
|
||||
|
||||
|
||||
@app.get("/resource/{id}")
|
||||
async def get_resource_(
|
||||
id: str,
|
||||
# user: Annotated[User, Depends(get_current_user)],
|
||||
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
# token: Annotated[OAuth2Token, Depends(get_token)],
|
||||
user: Annotated[User, Depends(get_user_from_token)],
|
||||
) -> JSONResponse:
|
||||
"""Generic path for testing a resource provided by a provider"""
|
||||
return JSONResponse(await get_resource(id, user))
|
||||
|
||||
|
||||
# Routes for RBAC based tests
|
||||
|
||||
|
||||
@app.get("/public")
|
||||
async def public() -> HTMLResponse:
|
||||
return HTMLResponse("<h1>Not protected</h1>")
|
||||
|
||||
|
||||
@app.get("/protected")
|
||||
async def get_protected(
|
||||
user: Annotated[User, Depends(get_current_user)]
|
||||
) -> HTMLResponse:
|
||||
assert user is not None # Just to keep QA checks happy
|
||||
return HTMLResponse("<h1>Only authenticated users can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-foorole")
|
||||
@hasrole("foorole")
|
||||
async def get_protected_by_foorole(request: Request) -> HTMLResponse:
|
||||
assert request is not None # Just to keep QA checks happy
|
||||
return HTMLResponse("<h1>Only users with foorole can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-barrole")
|
||||
@hasrole("barrole")
|
||||
async def get_protected_by_barrole(request: Request) -> HTMLResponse:
|
||||
assert request is not None # Just to keep QA checks happy
|
||||
return HTMLResponse("<h1>Protected by barrole</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-foorole-and-barrole")
|
||||
@hasrole("barrole")
|
||||
@hasrole("foorole")
|
||||
async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse:
|
||||
assert request is not None # Just to keep QA checks happy
|
||||
return HTMLResponse("<h1>Only users with foorole and barrole can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/protected-by-foorole-or-barrole")
|
||||
@hasrole(["foorole", "barrole"])
|
||||
async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse:
|
||||
assert request is not None # Just to keep QA checks happy
|
||||
return HTMLResponse("<h1>Only users with foorole or barrole can see this</h1>")
|
||||
|
||||
|
||||
@app.get("/introspect")
|
||||
async def get_introspect(
|
||||
request: Request,
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
token: Annotated[OAuth2Token, Depends(get_token)],
|
||||
) -> JSONResponse:
|
||||
assert request is not None # Just to keep QA checks happy
|
||||
if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="No intrispection endpoint found for the OIDC provider",
|
||||
)
|
||||
if (
|
||||
response := await oidc_provider.post(
|
||||
url,
|
||||
token=token,
|
||||
data={"token": token["access_token"]},
|
||||
)
|
||||
).is_success:
|
||||
return response.json()
|
||||
else:
|
||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
|
||||
|
||||
# Snippet for running standalone
|
||||
# Mostly useful for the --version option,
|
||||
# as running with uvicorn is easy and provides better flexibility, eg.
|
||||
|
@ -397,9 +294,7 @@ def main():
|
|||
parser.add_argument(
|
||||
"-p", "--port", type=int, default=80, help="Port to listen to (default: 80)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--version", action="store_true", help="Print version and exit"
|
||||
)
|
||||
parser.add_argument("-v", "--version", action="store_true", help="Print version and exit")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.version:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue