Add resource provided registry and plugin system

This commit is contained in:
phil 2025-02-11 17:27:49 +01:00
parent e56be3c378
commit 7439aa082b
10 changed files with 229 additions and 142 deletions

View file

@ -7,8 +7,8 @@ from pydantic import ConfigDict
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from httpx import AsyncClient from httpx import AsyncClient
from ..settings import AuthProviderSettings, settings from oidc_test.settings import AuthProviderSettings, settings
from ..models import User from oidc_test.models import User
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")

View file

@ -10,12 +10,11 @@ from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError
# from authlib.oauth1.auth import OAuthToken # from authlib.oauth1.auth import OAuthToken
from authlib.oauth2.auth import OAuth2Token from authlib.oauth2.auth import OAuth2Token
from .provider import Provider from oidc_test.auth.provider import Provider
from oidc_test.models import User
from ..models import User from oidc_test.database import db, TokenNotInDb, UserNotInDB
from ..database import db, TokenNotInDb, UserNotInDB from oidc_test.settings import settings
from ..settings import settings from oidc_test.auth_providers import providers
from ..auth_providers import providers
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")

View file

@ -1,5 +1,5 @@
from collections import OrderedDict from collections import OrderedDict
from .auth.provider import Provider from oidc_test.auth.provider import Provider
providers: OrderedDict[str, Provider] = OrderedDict() providers: OrderedDict[str, Provider] = OrderedDict()

View file

@ -5,10 +5,10 @@ import logging
from authlib.oauth2.rfc6749 import OAuth2Token from authlib.oauth2.rfc6749 import OAuth2Token
from jwt import PyJWTError from jwt import PyJWTError
from .auth.provider import Provider from oidc_test.auth.provider import Provider
from .models import User, Role from oidc_test.models import User, Role
from .auth_providers import providers from oidc_test.auth_providers import providers
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
@ -23,6 +23,7 @@ class TokenNotInDb(Exception):
class Database: class Database:
users: dict[str, User] = {} users: dict[str, User] = {}
# TODO: key of the token table should be provider: sid
tokens: dict[str, OAuth2Token] = {} tokens: dict[str, OAuth2Token] = {}
# Last sessions for the user (key: users's subject id (sub)) # Last sessions for the user (key: users's subject id (sub))
@ -82,6 +83,7 @@ class Database:
provider: Provider, provider: Provider,
sid: str | None, sid: str | None,
) -> OAuth2Token: ) -> OAuth2Token:
# TODO: key of the token table should be provider: sid
assert isinstance(provider, Provider) assert isinstance(provider, Provider)
if sid is None: if sid is None:
raise TokenNotInDb raise TokenNotInDb

View file

@ -26,8 +26,9 @@ from authlib.oauth2.rfc6749 import OAuth2Token
# from fastapi.security import OpenIdConnect # from fastapi.security import OpenIdConnect
# from pkce import generate_code_verifier, generate_pkce_pair # from pkce import generate_code_verifier, generate_pkce_pair
from .auth.provider import NoPublicKey, Provider from oidc_test.registry import registry
from .auth.utils import ( from oidc_test.auth.provider import NoPublicKey, Provider
from oidc_test.auth.utils import (
get_auth_provider, get_auth_provider,
get_auth_provider_or_none, get_auth_provider_or_none,
get_current_user_or_none, get_current_user_or_none,
@ -36,13 +37,12 @@ from .auth.utils import (
get_token, get_token,
update_token, update_token,
) )
from oidc_test.auth.utils import init_providers
from .auth.utils import init_providers from oidc_test.settings import settings
from .settings import settings from oidc_test.auth_providers import providers
from .auth_providers import providers from oidc_test.models import User
from .models import User from oidc_test.database import TokenNotInDb, db
from .database import TokenNotInDb, db from oidc_test.resource_server import resource_server
from .resource_server import resource_server
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
@ -53,6 +53,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates")
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
assert app is not None assert app is not None
init_providers() init_providers()
registry.make_registry()
for provider in list(providers.values()): for provider in list(providers.values()):
try: try:
await provider.get_info() await provider.get_info()
@ -104,6 +105,7 @@ async def home(
context["resources"] = None context["resources"] = None
else: else:
context["access_token"] = token["access_token"] context["access_token"] = token["access_token"]
# XXX: resources defined externally? I am confused...
context["resources"] = provider.resources context["resources"] = provider.resources
try: try:
access_token_parsed = provider.decode(token["access_token"], verify_signature=False) access_token_parsed = provider.decode(token["access_token"], verify_signature=False)
@ -113,9 +115,9 @@ async def home(
context["access_token_scope"] = access_token_parsed["scope"] context["access_token_scope"] = access_token_parsed["scope"]
except KeyError: except KeyError:
context["access_token_scope"] = None context["access_token_scope"] = None
# context["id_token_parsed"] = pretty_details(user, now)
context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False) context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False)
context["access_token_parsed"] = access_token_parsed context["access_token_parsed"] = access_token_parsed
context["resource_providers"] = registry.resource_providers
try: try:
context["refresh_token_parsed"] = provider.decode( context["refresh_token_parsed"] = provider.decode(
token["refresh_token"], verify_signature=False token["refresh_token"], verify_signature=False

43
src/oidc_test/registry.py Normal file
View file

@ -0,0 +1,43 @@
from importlib.metadata import entry_points
import logging
from typing import Any
from pydantic import BaseModel
from oidc_test.models import User
logger = logging.getLogger("registry")
class ProcessResult(BaseModel):
result: dict[str, Any] = {}
class ProcessError(Exception):
pass
class ResourceProvider:
name: str
scope_required: str | None = None
default_resource_id: str | None = None
def __init__(self, name: str):
self.name = name
async def process(self, user: User, resource_id: str | None = None) -> ProcessResult:
logger.warning(f"{self.name} should define a process method")
return ProcessResult()
class ResourceRegistry:
resource_providers: dict[str, ResourceProvider] = {}
def make_registry(self):
for ep in entry_points().select(group="oidc_test.resource_provider"):
ResourceProviderClass = ep.load()
if issubclass(ResourceProviderClass, ResourceProvider):
self.resource_providers[ep.name] = ResourceProviderClass(ep.name)
registry = ResourceRegistry()

View file

@ -5,23 +5,23 @@ import logging
from authlib.oauth2.auth import OAuth2Token from authlib.oauth2.auth import OAuth2Token
from httpx import AsyncClient from httpx import AsyncClient
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from fastapi import FastAPI, HTTPException, Depends, status from fastapi import FastAPI, HTTPException, Depends, Request, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
# from starlette.middleware.sessions import SessionMiddleware # from starlette.middleware.sessions import SessionMiddleware
# from authlib.integrations.starlette_client.apps import StarletteOAuth2App # from authlib.integrations.starlette_client.apps import StarletteOAuth2App
# from authlib.oauth2.rfc6749 import OAuth2Token # from authlib.oauth2.rfc6749 import OAuth2Token
from .auth.provider import Provider from oidc_test.auth.provider import Provider
from .auth.utils import ( from oidc_test.auth.utils import (
get_token_or_none, get_token_or_none,
get_user_from_token, get_user_from_token,
UserWithRole, UserWithRole,
) )
from oidc_test.auth_providers import providers
from .auth_providers import providers from oidc_test.settings import settings
from .settings import settings from oidc_test.models import User
from .models import User from oidc_test.registry import ProcessError, ProcessResult, registry
logger = logging.getLogger("oidc-test") logger = logging.getLogger("oidc-test")
@ -91,6 +91,128 @@ async def get_protected_by_foorole_or_barrole(
return {"msg": "Only users with foorole or barrole can see this"} return {"msg": "Only users with foorole or barrole can see this"}
@resource_server.get("/{resource_name}")
@resource_server.get("/{resource_name}/{resource_id}")
async def get_resource(
resource_name: str,
user: Annotated[User, Depends(get_user_from_token)],
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
request: Request,
resource_id: str | None = None,
) -> ProcessResult:
"""Generic path for testing a resource provided by a provider"""
provider = providers[user.auth_provider_id]
# Third party resource (provided through the auth provider)
# The token is just passed on
if resource_name in [r.resource_name for r in provider.resources]:
return await get_auth_provider_resource(
provider=provider,
resource_name=resource_name,
access_token=token["access_token"] if token else None,
user=user,
)
# Internal resource (provided here)
if resource_name in registry.resource_providers:
resource_provider = registry.resource_providers[resource_name]
if resource_provider.scope_required is not None and user.has_scope(
resource_provider.scope_required
):
try:
return await resource_provider.process(user=user, resource_id=resource_id)
except ProcessError as err:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Cannot process resource: {err}"
)
else:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
f"No scope {resource_provider.scope_required} in the access token "
+ "but it is required for accessing this resource",
)
else:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown resource {resource_name}")
# return await get_resource_(resource_name, user, **request.query_params)
async def get_auth_provider_resource(
provider: Provider, resource_name: str, access_token: str | None, user: User
) -> ProcessResult:
resource = [r for r in provider.resources if r.resource_name == resource_name][0]
async with AsyncClient() as client:
resp = await client.get(
url=provider.url + resource.url,
headers={
"Content-type": "application/json",
"Authorization": f"Bearer {access_token}",
},
)
if resp.is_error:
raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}")
# Only a demo, real application would really process the response
resp_length = len(resp.text)
if resp_length > 1024:
return ProcessResult(
result={"msg": f"The resource is too long ({resp_length} bytes) to show here"}
)
else:
return ProcessResult(result=resp.json())
# async def get_resource_(resource_id: str, user: User, **kwargs) -> dict:
# """
# Resource processing: build an informative rely as a simple showcase
# """
# if resource_id == "petition":
# return await sign(user, kwargs["petition_id"])
# provider = providers[user.auth_provider_id]
# try:
# pname = provider.name
# except KeyError:
# pname = "?"
# resp = {
# "hello": f"Hi {user.name} from an OAuth resource provider",
# "comment": f"I received a request for '{resource_id}' "
# + f"with an access token signed by {pname}",
# }
# # For the demo, resource resource_id matches a scope get:resource_id,
# # but this has to be refined for production
# required_scope = f"get:{resource_id}"
# # Check if the required scope is in the scopes allowed in userinfo
# try:
# if user.has_scope(required_scope):
# await process(user, resource_id, resp)
# else:
# ## For the showcase, giving a explanation.
# ## Alternatively, raise HTTP_401_UNAUTHORIZED
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED,
# f"No scope {required_scope} in the access token "
# + "but it is required for accessing this resource",
# )
# except ExpiredSignatureError:
# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired")
# except InvalidTokenError:
# raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid")
# return resp
# async def process(user, resource_id, resp):
# """
# Too simple to be serious.
# It's a good fit for a plugin architecture for production
# """
# if resource_id == "time":
# resp["time"] = datetime.now().strftime("%c")
# elif resource_id == "bs":
# async with AsyncClient() as client:
# bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
# resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
# else:
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'."
# )
# @resource_server.get("/introspect") # @resource_server.get("/introspect")
# async def get_introspect( # async def get_introspect(
# request: Request, # request: Request,
@ -114,99 +236,6 @@ async def get_protected_by_foorole_or_barrole(
# else: # else:
# raise HTTPException(status_code=response.status_code, detail=response.text) # raise HTTPException(status_code=response.status_code, detail=response.text)
@resource_server.get("/{id}")
async def get_resource(
id: str,
user: Annotated[User, Depends(get_user_from_token)],
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
) -> dict | list:
"""Generic path for testing a resource provided by a provider"""
provider = providers[user.auth_provider_id]
if id in [r.id for r in provider.resources]:
return await get_external_resource(
provider=provider,
id=id,
access_token=token["access_token"] if token else None,
user=user,
)
return await get_resource_(id, user)
async def get_external_resource(
provider: Provider, id: str, access_token: str | None, user: User
) -> dict | list:
resource = [r for r in provider.resources if r.id == id][0]
async with AsyncClient() as client:
resp = await client.get(
url=provider.url + resource.url,
headers={
"Content-type": "application/json",
"Authorization": f"Bearer {access_token}",
},
)
if resp.is_error:
raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}")
resp_length = len(resp.text)
if resp_length > 1024:
return {"msg": f"The resource is too long ({resp_length} bytes) to show here"}
else:
return resp.json()
async def get_resource_(resource_id: str, user: User) -> dict:
"""
Resource processing: build an informative rely as a simple showcase
"""
provider = providers[user.auth_provider_id]
try:
pname = provider.name
except KeyError:
pname = "?"
resp = {
"hello": f"Hi {user.name} from an OAuth resource provider",
"comment": f"I received a request for '{resource_id}' "
+ f"with an access token signed by {pname}",
}
# For the demo, resource resource_id matches a scope get:resource_id,
# but this has to be refined for production
required_scope = f"get:{resource_id}"
# Check if the required scope is in the scopes allowed in userinfo
try:
if user.has_scope(required_scope):
await process(user, resource_id, resp)
else:
## For the showcase, giving a explanation.
## Alternatively, raise HTTP_401_UNAUTHORIZED
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
f"No scope {required_scope} in the access token "
+ "but it is required for accessing this resource",
)
except ExpiredSignatureError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired")
except InvalidTokenError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid")
return resp
async def process(user, resource_id, resp):
"""
Too simple to be serious.
It's a good fit for a plugin architecture for production
"""
if resource_id == "time":
resp["time"] = datetime.now().strftime("%c")
elif resource_id == "bs":
async with AsyncClient() as client:
bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
else:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'."
)
# assert user.oidc_provider is not None # assert user.oidc_provider is not None
### Get some info (TODO: refactor) ### Get some info (TODO: refactor)
# if (auth_provider_id := user.oidc_provider.name) is None: # if (auth_provider_id := user.oidc_provider.name) is None:

View file

@ -17,7 +17,7 @@ from starlette.requests import Request
class Resource(BaseModel): class Resource(BaseModel):
"""A resource with an URL that can be accessed with an OAuth2 access token""" """A resource with an URL that can be accessed with an OAuth2 access token"""
id: str resource_name: str
name: str name: str
url: str url: str

View file

@ -1,6 +1,8 @@
async function checkHref(elem, token, authProvider) { async function checkHref(elem, token, authProvider) {
const msg = document.getElementById("msg") const msg = document.getElementById("msg")
const url = `resource/${elem.getAttribute("resource-id")}` const resourceName = elem.getAttribute("resource-name")
const resourceId = elem.getAttribute("resource-id")
const url = resourceId ? `resource/${resourceName}/${resourceId}` : `resource/${resourceName}`
const resp = await fetch(url, { const resp = await fetch(url, {
method: "GET", method: "GET",
headers: new Headers({ headers: new Headers({
@ -28,11 +30,12 @@ function checkPerms(className, token, authProvider) {
) )
} }
async function get_resource(id, token, authProvider) { async function get_resource(resource_name, token, authProvider, resource_id) {
//if (!keycloak.keycloak) { return } //if (!keycloak.keycloak) { return }
const msg = document.getElementById("msg") const msg = document.getElementById("msg")
const resourceElem = document.getElementById('resource') const resourceElem = document.getElementById('resource')
const resp = await fetch("resource/" + id, { const url = resource_id ? `resource/${resource_name}/${resource_id}` : `resource/${resource_name}`
const resp = await fetch(url, {
method: "GET", method: "GET",
headers: new Headers({ headers: new Headers({
"Content-type": "application/json", "Content-type": "application/json",

View file

@ -62,33 +62,42 @@
{% endif %} {% endif %}
<hr> <hr>
<div class="content"> <div class="content">
<p class="center">
Resources validated by scope:
</p>
<div class="links-to-check">
<button resource-id="time" onclick="get_resource('time', '{{ access_token }}', '{{ auth_provider.id }}')">Time</button>
<button resource-id="bs" onclick="get_resource('bs', '{{ access_token }}', '{{ auth_provider.id }}')">BS</button>
</div>
<p> <p>
Resources validated by role: Resources validated by role:
</p> </p>
<div class="links-to-check"> <div class="links-to-check">
<button resource-id="public" onclick="get_resource('public', '{{ access_token }}', '{{ auth_provider.id }}')">Public</button> <button resource-name="public" onclick="get_resource('public', '{{ access_token }}', '{{ auth_provider.id }}')">Public</button>
<button resource-id="protected" onclick="get_resource('protected', '{{ access_token }}', '{{ auth_provider.id }}')">Auth protected content</button> <button resource-name="protected" onclick="get_resource('protected', '{{ access_token }}', '{{ auth_provider.id }}')">Auth protected content</button>
<button resource-id="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole protected content</button> <button resource-name="protected-by-foorole" onclick="get_resource('protected-by-foorole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole protected content</button>
<button resource-id="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole or barrole protected content</button> <button resource-name="protected-by-foorole-or-barrole" onclick="get_resource('protected-by-foorole-or-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole or barrole protected content</button>
<button resource-id="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + barrole protected content</button> <button resource-name="protected-by-barrole" onclick="get_resource('protected-by-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + barrole protected content</button>
<button resource-id="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole and barrole protected content</button> <button resource-name="protected-by-foorole-and-barrole" onclick="get_resource('protected-by-foorole-and-barrole', '{{ access_token }}', '{{ auth_provider.id }}')">Auth + foorole and barrole protected content</button>
<button resource-id="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ auth_provider.id }}')">Using FastAPI Depends</button> <button resource-name="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ auth_provider.id }}')">Using FastAPI Depends</button>
<!--<button resource-id="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ auth_provider.id }}')">Introspect token (401 expected)</button>--> <!--<button resource-name="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ auth_provider.id }}')">Introspect token (401 expected)</button>-->
</div> </div>
<!-- XXX confused...
{% if resources %} {% if resources %}
<p> <p>
Resources for this provider: Resources for this provider:
</p> </p>
<div class="links-to-check"> <div class="links-to-check">
{% for resource in resources %} {% for resource in resources %}
<button resource-id="{{ resource.id }}" onclick="get_resource('{{ resource.id }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ resource.name }}</buttona> <button resource-name="{{ resource.id }}" onclick="get_resource('{{ resource.name }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ resource.name }}</buttona>
{% endfor %}
</div>
{% endif %}
-->
{% if resource_providers %}
<p>
Resource providers (validated by scope):
</p>
<div class="links-to-check">
{% for name, resource_provider in resource_providers.items() %}
{% if resource_provider.default_resource_id %}
<button resource-name="{{ name }}" resource-id="{{ resource_provider.default_resource_id }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}', '{{ resource_provider.default_resource_id }}')">{{ name }}</buttona>
{% else %}
<button resource-name="{{ name }}" onclick="get_resource('{{ name }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ name }}</buttona>
{% endif %}
{% endfor %} {% endfor %}
</div> </div>
{% endif %} {% endif %}