Add resource provider
Some checks failed
/ build (push) Failing after 14s
/ test (push) Successful in 4s

This commit is contained in:
phil 2025-01-28 19:48:35 +01:00
parent 61be70054b
commit 5b31ef888c
7 changed files with 152 additions and 83 deletions

View file

@ -9,10 +9,12 @@ dependencies = [
"authlib>=1.4.0", "authlib>=1.4.0",
"cachetools>=5.5.0", "cachetools>=5.5.0",
"fastapi[standard]>=0.115.6", "fastapi[standard]>=0.115.6",
"httpx>=0.28.1",
"itsdangerous>=2.2.0", "itsdangerous>=2.2.0",
"passlib[bcrypt]>=1.7.4", "passlib[bcrypt]>=1.7.4",
"pkce>=1.0.3", "pkce>=1.0.3",
"pydantic-settings>=2.7.1", "pydantic-settings>=2.7.1",
"pyjwt>=2.10.1",
"python-jose[cryptography]>=3.3.0", "python-jose[cryptography]>=3.3.0",
"requests>=2.32.3", "requests>=2.32.3",
"sqlmodel>=0.0.22", "sqlmodel>=0.0.22",

View file

@ -1,22 +1,63 @@
from typing import Union from typing import Union, Annotated, Tuple
from functools import wraps from functools import wraps
import logging import logging
from fastapi import HTTPException, Request, status from fastapi import HTTPException, Request, Depends, status
from fastapi.security import OAuth2PasswordBearer
from authlib.oauth2.rfc6749 import OAuth2Token from authlib.oauth2.rfc6749 import OAuth2Token
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
from jwt import decode
# from authlib.oauth1.auth import OAuthToken # from authlib.oauth1.auth import OAuthToken
# from authlib.oauth2.auth import OAuth2Token # from authlib.oauth2.auth import OAuth2Token
from .models import User from .models import User
from .database import db from .database import db, UserNotInDB
from .settings import settings from .settings import settings, OIDCProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
OIDC_PROVIDERS = set([provider.id for provider in settings.oidc.providers]) oidc_providers_settings: dict[str, OIDCProvider] = dict([(provider.id, provider) for provider in settings.oidc.providers])
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def fetch_token(name, request):
breakpoint()
...
# if name in oidc_providers:
# model = OAuth2Token
# else:
# model = OAuthToken
# token = model.find(name=name, user=request.user)
# return token.to_token()
def update_token(*args, **kwargs):
breakpoint()
...
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
# Add oidc providers to authlib from the settings
for id, provider in oidc_providers_settings.items():
authlib_oauth.register(
name=id,
server_metadata_url=provider.openid_configuration,
client_kwargs={
"scope": "openid email offline_access profile",
},
client_id=provider.client_id,
client_secret=provider.client_secret,
api_base_url=provider.url,
# For PKCE (not implemented yet):
# code_challenge_method="S256",
# fetch_token=fetch_token,
# update_token=update_token,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
)
def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None:
"""Return the oidc_provider from a request object, from the session. """Return the oidc_provider from a request object, from the session.
@ -115,21 +156,34 @@ def get_token_info(token: dict) -> dict:
return token_info return token_info
def fetch_token(name, request): async def get_resource_user(
breakpoint() token: Annotated[str, Depends(oauth2_scheme)],
... request: Request,
# if name in OIDC_PROVIDERS: ) -> User:
# model = OAuth2Token # TODO: decode token (ah!)
# else: # See https://fastapi.tiangolo.com/tutorial/security/oauth2-jwt/#hash-and-verify-the-passwords
# model = OAuthToken if (auth_provider_id := request.headers.get("auth_provider")) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field")
# token = model.find(name=name, user=request.user) if (auth_provider := oidc_providers_settings.get(auth_provider_id)) is None:
# return token.to_token() raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'")
if (key := auth_provider.get_key()) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown")
def update_token(*args, **kwargs): try:
breakpoint() payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test")
... except Exception as err:
logger.info("Cannot decode token, see below")
logger.exception(err)
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
if (user_id := payload.get('sub')) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found")
try:
user = await db.get_user(user_id)
except UserNotInDB:
logger.info(f"User {user_id} not found in DB, creating it (real apps can behave differently")
user = await db.add_user(
sub=payload['sub'],
user_info=payload,
oidc_provider=getattr(authlib_oauth, auth_provider_id),
user_info_from_endpoint={}
)
return user

View file

@ -10,6 +10,9 @@ from authlib.oauth2.rfc6749 import OAuth2Token
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UserNotInDB(Exception):
pass
class Database: class Database:
users: dict[str, User] = {} users: dict[str, User] = {}
tokens: dict[str, OAuth2Token] = {} tokens: dict[str, OAuth2Token] = {}
@ -37,6 +40,8 @@ class Database:
return user return user
async def get_user(self, sub: str) -> User: async def get_user(self, sub: str) -> User:
if sub not in self.users:
raise UserNotInDB
return self.users[sub] return self.users[sub]
async def add_token(self, token: OAuth2Token, user: User) -> None: async def add_token(self, token: OAuth2Token, user: User) -> None:

View file

@ -13,6 +13,7 @@ from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from authlib.integrations.base_client import OAuthError from authlib.integrations.base_client import OAuthError
@ -23,7 +24,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token
# from fastapi.security import OpenIdConnect # from fastapi.security import OpenIdConnect
# from pkce import generate_code_verifier, generate_pkce_pair # from pkce import generate_code_verifier, generate_pkce_pair
from .settings import settings, OIDCProvider from .settings import settings
from .models import User from .models import User
from .auth_utils import ( from .auth_utils import (
get_oidc_provider, get_oidc_provider,
@ -31,21 +32,37 @@ from .auth_utils import (
hasrole, hasrole,
get_current_user_or_none, get_current_user_or_none,
get_current_user, get_current_user,
get_resource_user,
authlib_oauth, authlib_oauth,
get_token, get_token,
oidc_providers_settings,
) )
from .auth_misc import pretty_details from .auth_misc import pretty_details
from .database import db from .database import db
from .resource_server import get_resource
logger = logging.getLogger("uvicorn.error") logger = logging.getLogger("uvicorn.error")
templates = Jinja2Templates(Path(__file__).parent / "templates") templates = Jinja2Templates(Path(__file__).parent / "templates")
origins = [
"https://tiptop:3002",
"https://philo.ydns.eu/",
]
app = FastAPI( app = FastAPI(
title="OIDC auth test", title="OIDC auth test",
) )
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.mount( app.mount(
"/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static" "/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static"
) )
@ -56,32 +73,6 @@ app.add_middleware(
secret_key=settings.secret_key, secret_key=settings.secret_key,
) )
# Add oidc providers to authlib from the settings
# fastapi_providers: dict[str, OpenIdConnect] = {}
oidc_providers_settings: dict[str, OIDCProvider] = {}
for provider in settings.oidc.providers:
authlib_oauth.register(
name=provider.id,
server_metadata_url=provider.openid_configuration,
client_kwargs={
"scope": "openid email offline_access profile",
},
client_id=provider.client_id,
client_secret=provider.client_secret,
api_base_url=provider.url,
# For PKCE (not implemented yet):
# code_challenge_method="S256",
# fetch_token=fetch_token,
# update_token=update_token,
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
)
# fastapi_providers[provider.id] = OpenIdConnect(
# openIdConnectUrl=provider.openid_configuration
# )
oidc_providers_settings[provider.id] = provider
@app.get("/") @app.get("/")
async def home( async def home(
@ -281,43 +272,16 @@ async def non_compliant_logout(
@app.get("/resource/{id}") @app.get("/resource/{id}")
async def get_resource( async def get_resource_(
id: str, id: str,
request: Request, request: Request,
user: Annotated[User, Depends(get_current_user)], # user: Annotated[User, Depends(get_current_user)],
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
token: Annotated[OAuth2Token, Depends(get_token)], # token: Annotated[OAuth2Token, Depends(get_token)],
user: Annotated[User, Depends(get_resource_user)],
) -> JSONResponse: ) -> JSONResponse:
"""Generic path for testing a resource provided by a provider""" """Generic path for testing a resource provided by a provider"""
assert user is not None # Just to keep QA checks happy return JSONResponse(await get_resource(id, user))
if oidc_provider is None:
raise HTTPException(
status.HTTP_406_NOT_ACCEPTABLE, detail="No such oidc provider"
)
if (
provider := oidc_providers_settings.get(
request.session.get("oidc_provider_id", "")
)
) is None:
raise HTTPException(
status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting"
)
try:
resource = next(x for x in provider.resources if x.id == id)
except StopIteration:
raise HTTPException(
status.HTTP_406_NOT_ACCEPTABLE, detail="No such resource for this provider"
)
if (
response := await oidc_provider.get(
resource.url,
# headers={"Authorization": f"token {token['access_token']}"},
token=token,
)
).is_success:
return JSONResponse(response.json())
else:
raise HTTPException(status_code=response.status_code, detail=response.text)
# Routes for test # Routes for test

View file

@ -0,0 +1,21 @@
from datetime import datetime
from httpx import AsyncClient
from .models import User
async def get_resource(id: str, user: User) -> dict:
pname = getattr(user.oidc_provider, "name", "?")
resp = {
"hello" : f"Hi {user.name} from an OAuth resource provider.",
"comment": f"I received a request for '{id}' with an access token signed by {pname}."
}
if id == "time":
resp["time"] = datetime.now().strftime("%c")
elif id == "bs":
async with AsyncClient() as client:
bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
resp['bs'] = bs.json().get("phrase", "Sorry, i am out of BS today.")
else:
resp['sorry'] = f"I don't known how to give '{id}' but i know corporate bs."
return resp

View file

@ -36,6 +36,7 @@ class OIDCProvider(BaseModel):
hint: str = "No hint" hint: str = "No hint"
resources: list[Resource] = [] resources: list[Resource] = []
account_url_template: str | None = None account_url_template: str | None = None
key: str | None = None
@computed_field @computed_field
@property @property
@ -63,6 +64,15 @@ class OIDCProvider(BaseModel):
else: else:
return None return None
def get_key(self) -> str | None:
"""Return the public key formatted for """
if self.key is None:
return None
return f"""
-----BEGIN PUBLIC KEY-----
{self.key}
-----END PUBLIC KEY-----
"""
class ResourceProvider(BaseModel): class ResourceProvider(BaseModel):
id: str id: str

13
uv.lock generated
View file

@ -488,10 +488,12 @@ dependencies = [
{ name = "authlib" }, { name = "authlib" },
{ name = "cachetools" }, { name = "cachetools" },
{ name = "fastapi", extra = ["standard"] }, { name = "fastapi", extra = ["standard"] },
{ name = "httpx" },
{ name = "itsdangerous" }, { name = "itsdangerous" },
{ name = "passlib", extra = ["bcrypt"] }, { name = "passlib", extra = ["bcrypt"] },
{ name = "pkce" }, { name = "pkce" },
{ name = "pydantic-settings" }, { name = "pydantic-settings" },
{ name = "pyjwt" },
{ name = "python-jose", extra = ["cryptography"] }, { name = "python-jose", extra = ["cryptography"] },
{ name = "requests" }, { name = "requests" },
{ name = "sqlmodel" }, { name = "sqlmodel" },
@ -508,10 +510,12 @@ requires-dist = [
{ name = "authlib", specifier = ">=1.4.0" }, { name = "authlib", specifier = ">=1.4.0" },
{ name = "cachetools", specifier = ">=5.5.0" }, { name = "cachetools", specifier = ">=5.5.0" },
{ name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" },
{ name = "httpx", specifier = ">=0.28.1" },
{ name = "itsdangerous", specifier = ">=2.2.0" }, { name = "itsdangerous", specifier = ">=2.2.0" },
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" },
{ name = "pkce", specifier = ">=1.0.3" }, { name = "pkce", specifier = ">=1.0.3" },
{ name = "pydantic-settings", specifier = ">=2.7.1" }, { name = "pydantic-settings", specifier = ">=2.7.1" },
{ name = "pyjwt", specifier = ">=2.10.1" },
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" }, { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
{ name = "requests", specifier = ">=2.32.3" }, { name = "requests", specifier = ">=2.32.3" },
{ name = "sqlmodel", specifier = ">=0.0.22" }, { name = "sqlmodel", specifier = ">=0.0.22" },
@ -694,6 +698,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513 }, { url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513 },
] ]
[[package]]
name = "pyjwt"
version = "2.10.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997 },
]
[[package]] [[package]]
name = "pytest" name = "pytest"
version = "8.3.4" version = "8.3.4"