Use pydantic on ResourceServer
This commit is contained in:
parent
0464047f8a
commit
381ce1ebc1
3 changed files with 16 additions and 10 deletions
|
@ -58,7 +58,7 @@ async def lifespan(app: FastAPI):
|
||||||
try:
|
try:
|
||||||
await provider.get_info()
|
await provider.get_info()
|
||||||
except NoPublicKey:
|
except NoPublicKey:
|
||||||
logger.warn(f"Disable {provider.id}: public key not found")
|
logger.warning(f"Disable {provider.id}: public key not found")
|
||||||
del providers[provider.id]
|
del providers[provider.id]
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
@ -300,7 +300,6 @@ async def refresh(
|
||||||
await update_token(provider.id, new_token)
|
await update_token(provider.id, new_token)
|
||||||
return RedirectResponse(url=request.url_for("home"))
|
return RedirectResponse(url=request.url_for("home"))
|
||||||
|
|
||||||
|
|
||||||
# Snippet for running standalone
|
# Snippet for running standalone
|
||||||
# Mostly useful for the --version option,
|
# Mostly useful for the --version option,
|
||||||
# as running with uvicorn is easy and provides better flexibility, eg.
|
# as running with uvicorn is easy and provides better flexibility, eg.
|
||||||
|
|
|
@ -17,20 +17,20 @@ class ProcessError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class ResourceProvider:
|
class ResourceProvider(BaseModel):
|
||||||
name: str
|
|
||||||
scope_required: str | None = None
|
scope_required: str | None = None
|
||||||
default_resource_id: str | None = None
|
default_resource_id: str | None = None
|
||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
self.name = name
|
super().__init__()
|
||||||
|
self.__name__ = name
|
||||||
|
|
||||||
async def process(self, user: User, resource_id: str | None = None) -> ProcessResult:
|
async def process(self, user: User, resource_id: str | None = None) -> ProcessResult:
|
||||||
logger.warning(f"{self.name} should define a process method")
|
logger.warning(f"{self.__name__} should define a process method")
|
||||||
return ProcessResult()
|
return ProcessResult()
|
||||||
|
|
||||||
|
|
||||||
class ResourceRegistry:
|
class ResourceRegistry(BaseModel):
|
||||||
resource_providers: dict[str, ResourceProvider] = {}
|
resource_providers: dict[str, ResourceProvider] = {}
|
||||||
|
|
||||||
def make_registry(self):
|
def make_registry(self):
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from datetime import datetime
|
from typing import Annotated, Any
|
||||||
from typing import Annotated
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from authlib.oauth2.auth import OAuth2Token
|
from authlib.oauth2.auth import OAuth2Token
|
||||||
|
@ -21,7 +20,7 @@ from oidc_test.auth.utils import (
|
||||||
from oidc_test.auth_providers import providers
|
from oidc_test.auth_providers import providers
|
||||||
from oidc_test.settings import settings
|
from oidc_test.settings import settings
|
||||||
from oidc_test.models import User
|
from oidc_test.models import User
|
||||||
from oidc_test.registry import ProcessError, ProcessResult, registry
|
from oidc_test.registry import ProcessError, ProcessResult, ResourceProvider, registry
|
||||||
|
|
||||||
logger = logging.getLogger("oidc-test")
|
logger = logging.getLogger("oidc-test")
|
||||||
|
|
||||||
|
@ -48,6 +47,14 @@ resource_server.add_middleware(
|
||||||
# Routes for RBAC based tests
|
# Routes for RBAC based tests
|
||||||
|
|
||||||
|
|
||||||
|
@resource_server.get("/")
|
||||||
|
async def resources() -> dict[str, dict[str, Any]]:
|
||||||
|
return {
|
||||||
|
"internal": {},
|
||||||
|
"plugins": registry.resource_providers
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@resource_server.get("/public")
|
@resource_server.get("/public")
|
||||||
async def public() -> dict:
|
async def public() -> dict:
|
||||||
return {"msg": "Not protected"}
|
return {"msg": "Not protected"}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue