oidc-fastapi-test/src/oidc_test/resource_server.py
phil 64f6a90f22
All checks were successful
/ build (push) Successful in 6s
/ test (push) Successful in 5s
Add resource provided registry and plugin system
2025-02-11 17:27:49 +01:00

274 lines
9.7 KiB
Python

from datetime import datetime
from typing import Annotated
import logging
from authlib.oauth2.auth import OAuth2Token
from httpx import AsyncClient
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from fastapi import FastAPI, HTTPException, Depends, Request, status
from fastapi.middleware.cors import CORSMiddleware
# from starlette.middleware.sessions import SessionMiddleware
# from authlib.integrations.starlette_client.apps import StarletteOAuth2App
# from authlib.oauth2.rfc6749 import OAuth2Token
from oidc_test.auth.provider import Provider
from oidc_test.auth.utils import (
get_token_or_none,
get_user_from_token,
UserWithRole,
)
from oidc_test.auth_providers import providers
from oidc_test.settings import settings
from oidc_test.models import User
from oidc_test.registry import ProcessError, ProcessResult, registry
logger = logging.getLogger("oidc-test")

# The FastAPI application serving the protected test resources.
resource_server = FastAPI()

# Let the origins listed in settings.cors_origins call this API from a browser,
# with credentials, any HTTP method and any request header.
resource_server.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=settings.cors_origins,
)
# SessionMiddleware is required by authlib
# resource_server.add_middleware(
# SessionMiddleware,
# secret_key=settings.secret_key,
# )
# Route for OAuth resource server
# Routes for RBAC based tests
@resource_server.get("/public")
async def public() -> dict:
    """Unprotected route: declares no authentication dependency."""
    reply = {"msg": "Not protected"}
    return reply
@resource_server.get("/protected")
async def get_protected(user: Annotated[User, Depends(get_user_from_token)]):
    """Route for authenticated users.

    Authentication is delegated to the ``get_user_from_token`` dependency
    (presumably it rejects requests without a valid token — see its
    definition); the handler only returns a fixed message.
    """
    assert user is not None  # Just to keep QA checks happy
    text = "Only authenticated users can see this"
    return {"msg": text}
@resource_server.get("/protected-by-foorole")
async def get_protected_by_foorole(
    user: Annotated[User, Depends(UserWithRole("foorole"))],
):
    """Route guarded by the ``UserWithRole("foorole")`` dependency.

    The role check itself lives in that dependency (defined elsewhere); this
    handler only builds the reply.
    """
    assert user is not None
    reply = {"msg": "Only users with foorole can see this"}
    return reply
@resource_server.get("/protected-by-barrole")
async def get_protected_by_barrole(
    user: Annotated[User, Depends(UserWithRole("barrole"))],
):
    """Route guarded by the ``UserWithRole("barrole")`` dependency.

    The role check itself lives in that dependency (defined elsewhere); this
    handler only builds the reply.
    """
    assert user is not None
    reply = {"msg": "Protected by barrole"}
    return reply
@resource_server.get("/protected-by-foorole-and-barrole")
async def get_protected_by_foorole_and_barrole(
    user: Annotated[User, Depends(UserWithRole("foorole"))],
    # Separate parameter on purpose: when several Depends() are stacked in one
    # Annotated[...], FastAPI only uses the last one, so the original
    # single-parameter form silently skipped the foorole check.
    user_with_barrole: Annotated[User, Depends(UserWithRole("barrole"))],
):
    """Route requiring BOTH ``foorole`` and ``barrole``.

    Each role is enforced by its own ``UserWithRole`` dependency (defined
    elsewhere), so FastAPI runs both checks before calling this handler.
    """
    assert user is not None  # Just to keep QA checks happy
    assert user_with_barrole is not None
    return {"msg": "Only users with foorole and barrole can see this"}
@resource_server.get("/protected-by-foorole-or-barrole")
async def get_protected_by_foorole_or_barrole(
    user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))],
):
    """Route guarded by ``UserWithRole`` built from a list of two roles.

    Presumably passing a list to ``UserWithRole`` means "any of these roles"
    (NOTE(review): confirm in oidc_test.auth.utils); the handler itself only
    builds the reply.
    """
    assert user is not None  # Just to keep QA checks happy
    text = "Only users with foorole or barrole can see this"
    return {"msg": text}
@resource_server.get("/{resource_name}")
@resource_server.get("/{resource_name}/{resource_id}")
async def get_resource(
    resource_name: str,
    user: Annotated[User, Depends(get_user_from_token)],
    token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
    request: Request,
    resource_id: str | None = None,
) -> ProcessResult:
    """Generic path for testing a resource provided by a provider.

    Resources are looked up in two places, in this order:

    1. Third-party resources declared by the user's auth provider
       (``provider.resources``). The request is forwarded to the provider with
       the user's access token through ``get_auth_provider_resource``.
    2. Internal resources registered in ``registry.resource_providers``. They
       are processed here, after checking the provider's ``scope_required``
       (when it is not None) against the user.

    Raises:
        HTTPException(401): unknown resource, missing required scope, or the
            resource provider raised ProcessError.
    """
    provider = providers[user.auth_provider_id]

    # Third party resource (provided through the auth provider)
    # The token is just passed on
    if any(r.resource_name == resource_name for r in provider.resources):
        return await get_auth_provider_resource(
            provider=provider,
            resource_name=resource_name,
            access_token=token["access_token"] if token else None,
            user=user,
        )

    # Internal resource (provided here)
    if resource_name not in registry.resource_providers:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}")
    resource_provider = registry.resource_providers[resource_name]

    # A provider without scope_required is open to any authenticated user;
    # previously such providers were always rejected ("No scope None ...").
    scope_required = resource_provider.scope_required
    if scope_required is not None and not user.has_scope(scope_required):
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED,
            f"No scope {scope_required} in the access token "
            + "but it is required for accessing this resource",
        )

    try:
        return await resource_provider.process(user=user, resource_id=resource_id)
    except ProcessError as err:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}"
        ) from err
# return await get_resource_(resource_name, user, **request.query_params)
async def get_auth_provider_resource(
    provider: Provider, resource_name: str, access_token: str | None, user: User
) -> ProcessResult:
    """Fetch a resource served by the auth provider and wrap the reply.

    The resource URL is ``provider.url + resource.url``, where ``resource`` is
    the entry of ``provider.resources`` whose ``resource_name`` matches. When
    an access token is available it is forwarded as a Bearer token.

    Args:
        provider: the auth provider declaring the resource.
        resource_name: name of one of ``provider.resources``.
        access_token: the user's access token, or None when there is none.
        user: the requesting user (currently unused by this function).

    Returns:
        A ProcessResult holding the decoded JSON body. If the body exceeds
        1024 bytes, it holds a short message instead.

    Raises:
        HTTPException: 404 if the provider declares no such resource; the
            upstream status code if the provider replies with an error; 502
            if the reply is not valid JSON.
    """
    resource = next(
        (r for r in provider.resources if r.resource_name == resource_name), None
    )
    if resource is None:
        raise HTTPException(status.HTTP_404_NOT_FOUND, f"Unknown resource {resource_name}")

    headers = {"Content-type": "application/json"}
    # Without a token, send no Authorization header rather than "Bearer None"
    if access_token is not None:
        headers["Authorization"] = f"Bearer {access_token}"

    async with AsyncClient() as client:
        resp = await client.get(url=provider.url + resource.url, headers=headers)

    if resp.is_error:
        raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}")

    # Only a demo, real application would really process the response.
    # Measure the raw body (bytes), matching the unit in the message below.
    resp_length = len(resp.content)
    if resp_length > 1024:
        return ProcessResult(
            result={"msg": f"The resource is too long ({resp_length} bytes) to show here"}
        )
    try:
        result = resp.json()
    except ValueError as err:  # json.JSONDecodeError is a ValueError
        raise HTTPException(
            status.HTTP_502_BAD_GATEWAY, "Cannot fetch resource: invalid JSON reply"
        ) from err
    return ProcessResult(result=result)
# async def get_resource_(resource_id: str, user: User, **kwargs) -> dict:
# """
# Resource processing: build an informative reply as a simple showcase
# """
# if resource_id == "petition":
# return await sign(user, kwargs["petition_id"])
# provider = providers[user.auth_provider_id]
# try:
# pname = provider.name
# except KeyError:
# pname = "?"
# resp = {
# "hello": f"Hi {user.name} from an OAuth resource provider",
# "comment": f"I received a request for '{resource_id}' "
# + f"with an access token signed by {pname}",
# }
# # For the demo, resource resource_id matches a scope get:resource_id,
# # but this has to be refined for production
# required_scope = f"get:{resource_id}"
# # Check if the required scope is in the scopes allowed in userinfo
# try:
# if user.has_scope(required_scope):
# await process(user, resource_id, resp)
# else:
# ## For the showcase, giving a explanation.
# ## Alternatively, raise HTTP_401_UNAUTHORIZED
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED,
# f"No scope {required_scope} in the access token "
# + "but it is required for accessing this resource",
# )
# except ExpiredSignatureError:
# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired")
# except InvalidTokenError:
# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid")
# return resp
# async def process(user, resource_id, resp):
# """
# Too simple to be serious.
# It's a good fit for a plugin architecture for production
# """
# if resource_id == "time":
# resp["time"] = datetime.now().strftime("%c")
# elif resource_id == "bs":
# async with AsyncClient() as client:
# bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
# resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
# else:
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'."
# )
# @resource_server.get("/introspect")
# async def get_introspect(
# request: Request,
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
# token: Annotated[OAuth2Token, Depends(get_token)],
# ) -> JSONResponse:
# assert request is not None # Just to keep QA checks happy
# if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None:
# raise HTTPException(
# status_code=status.HTTP_401_UNAUTHORIZED,
# detail="No introspection endpoint found for the OIDC provider",
# )
# if (
# response := await oidc_provider.post(
# url,
# token=token,
# data={"token": token["access_token"]},
# )
# ).is_success:
# return response.json()
# else:
# raise HTTPException(status_code=response.status_code, detail=response.text)
# assert user.oidc_provider is not None
### Get some info (TODO: refactor)
# if (auth_provider_id := user.oidc_provider.name) is None:
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED,
# "Request headers must have a 'auth_provider' field",
# )
# if (
# auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
# ) is None:
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
# )
# if (key := auth_provider_settings.get_public_key()) is None:
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED,
# f"Key for provider '{auth_provider_id}' unknown",
# )
# logger.warn(f"refresh with scope {scope}")
# breakpoint()
# refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope)
### Decode the new token
# try:
# payload = decode(
# refreshed_auth_info["access_token"],
# key=key,
# algorithms=["RS256"],
# audience="account",
# options={"verify_signature": not settings.insecure.skip_verify_signature},
# )
# except ExpiredSignatureError as err:
# logger.info(f"Expired signature: {err}")
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED,
# "Expired signature (refresh not implemented yet)",
# )