274 lines
9.7 KiB
Python
274 lines
9.7 KiB
Python
from datetime import datetime
|
|
from typing import Annotated
|
|
import logging
|
|
|
|
from authlib.oauth2.auth import OAuth2Token
|
|
from httpx import AsyncClient
|
|
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
|
|
from fastapi import FastAPI, HTTPException, Depends, Request, status
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
# from starlette.middleware.sessions import SessionMiddleware
|
|
# from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
|
# from authlib.oauth2.rfc6749 import OAuth2Token
|
|
|
|
from oidc_test.auth.provider import Provider
|
|
from oidc_test.auth.utils import (
|
|
get_token_or_none,
|
|
get_user_from_token,
|
|
UserWithRole,
|
|
)
|
|
from oidc_test.auth_providers import providers
|
|
from oidc_test.settings import settings
|
|
from oidc_test.models import User
|
|
from oidc_test.registry import ProcessError, ProcessResult, registry
|
|
|
|
logger = logging.getLogger("oidc-test")

# FastAPI application exposing the protected test resources below
resource_server = FastAPI()

# Cross-origin policy: the allowed origins come from the settings,
# credentialed requests are allowed, and every method / header is accepted.
resource_server.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=settings.cors_origins,
)
|
|
|
|
# SessionMiddleware is required by authlib
|
|
# resource_server.add_middleware(
|
|
# SessionMiddleware,
|
|
# secret_key=settings.secret_key,
|
|
# )
|
|
|
|
# Route for OAuth resource server
|
|
|
|
|
|
# Routes for RBAC based tests
|
|
|
|
|
|
@resource_server.get("/public")
async def public() -> dict:
    """Unauthenticated endpoint: always answers with a fixed message."""
    body = {"msg": "Not protected"}
    return body
|
|
|
|
|
|
@resource_server.get("/protected")
async def get_protected(user: Annotated[User, Depends(get_user_from_token)]):
    """Fixed message for any caller accepted by ``get_user_from_token``.

    Authentication itself is delegated to the ``get_user_from_token``
    dependency; this handler only builds the response.
    """
    assert user is not None  # Just to keep QA checks happy
    message = "Only authenticated users can see this"
    return {"msg": message}
|
|
|
|
|
|
@resource_server.get("/protected-by-foorole")
async def get_protected_by_foorole(
    user: Annotated[User, Depends(UserWithRole("foorole"))],
):
    """Fixed message for callers accepted by ``UserWithRole("foorole")``.

    The role check is delegated entirely to that dependency.
    """
    assert user is not None
    message = "Only users with foorole can see this"
    return {"msg": message}
|
|
|
|
|
|
@resource_server.get("/protected-by-barrole")
async def get_protected_by_barrole(
    user: Annotated[User, Depends(UserWithRole("barrole"))],
):
    """Fixed message for callers accepted by ``UserWithRole("barrole")``.

    The role check is delegated entirely to that dependency.
    """
    assert user is not None
    message = "Protected by barrole"
    return {"msg": message}
|
|
|
|
|
|
@resource_server.get(
    "/protected-by-foorole-and-barrole",
    # FastAPI only honours the LAST Depends() inside an Annotated[...], so
    # stacking two of them on `user` silently drops the first check. The
    # foorole requirement is therefore declared as a route-level dependency,
    # which FastAPI always runs in addition to the parameter's dependency.
    dependencies=[Depends(UserWithRole("foorole"))],
)
async def get_protected_by_foorole_and_barrole(
    user: Annotated[User, Depends(UserWithRole("barrole"))],
):
    """Fixed message for callers accepted by both role dependencies.

    ``UserWithRole("foorole")`` runs as a route dependency and
    ``UserWithRole("barrole")`` resolves the ``user`` parameter; the handler is
    only reached when neither dependency rejects the request.
    """
    assert user is not None  # Just to keep QA checks happy
    return {"msg": "Only users with foorole and barrole can see this"}
|
|
|
|
|
|
@resource_server.get("/protected-by-foorole-or-barrole")
async def get_protected_by_foorole_or_barrole(
    user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))],
):
    """Fixed message for callers accepted by ``UserWithRole([...])``.

    NOTE(review): a list of roles presumably means "any of these roles"
    (per the route name) — confirm against ``UserWithRole``.
    """
    assert user is not None  # Just to keep QA checks happy
    message = "Only users with foorole or barrole can see this"
    return {"msg": message}
|
|
|
|
|
|
@resource_server.get("/{resource_name}")
@resource_server.get("/{resource_name}/{resource_id}")
async def get_resource(
    resource_name: str,
    user: Annotated[User, Depends(get_user_from_token)],
    token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
    request: Request,
    resource_id: str | None = None,
) -> ProcessResult:
    """Generic path for testing a resource provided by a provider.

    Resolution order:

    1. If ``resource_name`` is one of the resources of the user's auth provider,
       the request is forwarded there with the user's access token.
    2. Otherwise, if ``resource_name`` is registered in
       ``registry.resource_providers``, it is processed locally, after checking
       the provider's ``scope_required`` (if any) against the user's scopes.

    Args:
        resource_name: name of the requested resource.
        user: authenticated user, resolved by ``get_user_from_token``.
        token: the user's OAuth2 token, or None.
        request: incoming request (currently unused by this handler).
        resource_id: optional identifier passed on to the local resource provider.

    Raises:
        HTTPException: 401 for an unknown resource, a missing required scope,
            or a ``ProcessError`` raised by the local resource provider.
    """
    provider = providers[user.auth_provider_id]

    # Third party resource (provided through the auth provider)
    # The token is just passed on
    if any(r.resource_name == resource_name for r in provider.resources):
        return await get_auth_provider_resource(
            provider=provider,
            resource_name=resource_name,
            access_token=token["access_token"] if token else None,
            user=user,
        )

    # Internal resource (provided here)
    if resource_name not in registry.resource_providers:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}")
    resource_provider = registry.resource_providers[resource_name]

    # A provider without a required scope is open to any authenticated user;
    # otherwise the user must hold that scope.
    scope_required = resource_provider.scope_required
    if scope_required is not None and not user.has_scope(scope_required):
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED,
            f"No scope {scope_required} in the access token "
            + "but it is required for accessing this resource",
        )

    try:
        return await resource_provider.process(user=user, resource_id=resource_id)
    except ProcessError as err:
        raise HTTPException(
            status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}"
        ) from err
|
|
# return await get_resource_(resource_name, user, **request.query_params)
|
|
|
|
|
|
async def get_auth_provider_resource(
    provider: Provider, resource_name: str, access_token: str | None, user: User
) -> ProcessResult:
    """Fetch a resource served by the auth provider, forwarding the access token.

    The resource is fetched with a GET on ``provider.url + resource.url``. When
    an access token is available, it is sent as a Bearer token.

    Args:
        provider: auth provider whose ``resources`` contain the resource.
        resource_name: ``resource_name`` of the entry in ``provider.resources``.
        access_token: raw access token to forward, or None when there is none.
        user: the authenticated user (not used by this function's body).

    Returns:
        A ProcessResult wrapping the decoded JSON body. If the body is longer
        than 1024 bytes, it wraps a short message instead.

    Raises:
        HTTPException: with one of these statuses:
            - 401 if the provider has no resource with that name.
            - The upstream status code if the provider answers with an error.
            - 502 if the (short) body is not valid JSON.
    """
    resource = next(
        (r for r in provider.resources if r.resource_name == resource_name), None
    )
    if resource is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}")

    headers = {"Content-type": "application/json"}
    # Only send an Authorization header when there is a token, rather than a
    # meaningless literal "Bearer None".
    if access_token is not None:
        headers["Authorization"] = f"Bearer {access_token}"

    async with AsyncClient() as client:
        resp = await client.get(url=provider.url + resource.url, headers=headers)

    if resp.is_error:
        raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}")

    # Only a demo, real application would really process the response
    resp_length = len(resp.content)  # raw byte count, as the message reports
    if resp_length > 1024:
        return ProcessResult(
            result={"msg": f"The resource is too long ({resp_length} bytes) to show here"}
        )

    try:
        payload = resp.json()
    except ValueError as err:  # json.JSONDecodeError is a ValueError subclass
        raise HTTPException(
            status.HTTP_502_BAD_GATEWAY, f"Resource is not valid JSON: {err}"
        ) from err
    return ProcessResult(result=payload)
|
|
|
|
|
|
# async def get_resource_(resource_id: str, user: User, **kwargs) -> dict:
|
|
# """
|
|
# Resource processing: build an informative rely as a simple showcase
|
|
# """
|
|
# if resource_id == "petition":
|
|
# return await sign(user, kwargs["petition_id"])
|
|
# provider = providers[user.auth_provider_id]
|
|
# try:
|
|
# pname = provider.name
|
|
# except KeyError:
|
|
# pname = "?"
|
|
# resp = {
|
|
# "hello": f"Hi {user.name} from an OAuth resource provider",
|
|
# "comment": f"I received a request for '{resource_id}' "
|
|
# + f"with an access token signed by {pname}",
|
|
# }
|
|
# # For the demo, resource resource_id matches a scope get:resource_id,
|
|
# # but this has to be refined for production
|
|
# required_scope = f"get:{resource_id}"
|
|
# # Check if the required scope is in the scopes allowed in userinfo
|
|
# try:
|
|
# if user.has_scope(required_scope):
|
|
# await process(user, resource_id, resp)
|
|
# else:
|
|
# ## For the showcase, giving a explanation.
|
|
# ## Alternatively, raise HTTP_401_UNAUTHORIZED
|
|
# raise HTTPException(
|
|
# status.HTTP_401_UNAUTHORIZED,
|
|
# f"No scope {required_scope} in the access token "
|
|
# + "but it is required for accessing this resource",
|
|
# )
|
|
# except ExpiredSignatureError:
|
|
# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired")
|
|
# except InvalidTokenError:
|
|
# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid")
|
|
# return resp
|
|
|
|
|
|
# async def process(user, resource_id, resp):
|
|
# """
|
|
# Too simple to be serious.
|
|
# It's a good fit for a plugin architecture for production
|
|
# """
|
|
# if resource_id == "time":
|
|
# resp["time"] = datetime.now().strftime("%c")
|
|
# elif resource_id == "bs":
|
|
# async with AsyncClient() as client:
|
|
# bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
|
|
# resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
|
|
# else:
|
|
# raise HTTPException(
|
|
# status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'."
|
|
# )
|
|
|
|
|
|
# @resource_server.get("/introspect")
|
|
# async def get_introspect(
|
|
# request: Request,
|
|
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
|
# token: Annotated[OAuth2Token, Depends(get_token)],
|
|
# ) -> JSONResponse:
|
|
# assert request is not None # Just to keep QA checks happy
|
|
# if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None:
|
|
# raise HTTPException(
|
|
# status_code=status.HTTP_401_UNAUTHORIZED,
|
|
# detail="No introspection endpoint found for the OIDC provider",
|
|
# )
|
|
# if (
|
|
# response := await oidc_provider.post(
|
|
# url,
|
|
# token=token,
|
|
# data={"token": token["access_token"]},
|
|
# )
|
|
# ).is_success:
|
|
# return response.json()
|
|
# else:
|
|
# raise HTTPException(status_code=response.status_code, detail=response.text)
|
|
|
|
# assert user.oidc_provider is not None
|
|
### Get some info (TODO: refactor)
|
|
# if (auth_provider_id := user.oidc_provider.name) is None:
|
|
# raise HTTPException(
|
|
# status.HTTP_401_UNAUTHORIZED,
|
|
# "Request headers must have a 'auth_provider' field",
|
|
# )
|
|
# if (
|
|
# auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
|
|
# ) is None:
|
|
# raise HTTPException(
|
|
# status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
|
|
# )
|
|
# if (key := auth_provider_settings.get_public_key()) is None:
|
|
# raise HTTPException(
|
|
# status.HTTP_401_UNAUTHORIZED,
|
|
# f"Key for provider '{auth_provider_id}' unknown",
|
|
# )
|
|
# logger.warn(f"refresh with scope {scope}")
|
|
# breakpoint()
|
|
# refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope)
|
|
### Decode the new token
|
|
# try:
|
|
# payload = decode(
|
|
# refreshed_auth_info["access_token"],
|
|
# key=key,
|
|
# algorithms=["RS256"],
|
|
# audience="account",
|
|
# options={"verify_signature": not settings.insecure.skip_verify_signature},
|
|
# )
|
|
# except ExpiredSignatureError as err:
|
|
# logger.info(f"Expired signature: {err}")
|
|
# raise HTTPException(
|
|
# status.HTTP_401_UNAUTHORIZED,
|
|
# "Expired signature (refresh not implemented yet)",
|
|
# )
|