Add resource provider
This commit is contained in:
parent
61be70054b
commit
5b31ef888c
7 changed files with 152 additions and 83 deletions
|
@ -1,22 +1,63 @@
|
|||
from typing import Union
|
||||
from typing import Union, Annotated, Tuple
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi import HTTPException, Request, Depends, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
|
||||
from jwt import decode
|
||||
|
||||
# from authlib.oauth1.auth import OAuthToken
|
||||
# from authlib.oauth2.auth import OAuth2Token
|
||||
|
||||
from .models import User
|
||||
from .database import db
|
||||
from .settings import settings
|
||||
from .database import db, UserNotInDB
|
||||
from .settings import settings, OIDCProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OIDC_PROVIDERS = set([provider.id for provider in settings.oidc.providers])
|
||||
oidc_providers_settings: dict[str, OIDCProvider] = dict([(provider.id, provider) for provider in settings.oidc.providers])
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
def fetch_token(name, request):
|
||||
breakpoint()
|
||||
...
|
||||
# if name in oidc_providers:
|
||||
# model = OAuth2Token
|
||||
# else:
|
||||
# model = OAuthToken
|
||||
|
||||
# token = model.find(name=name, user=request.user)
|
||||
# return token.to_token()
|
||||
|
||||
|
||||
def update_token(*args, **kwargs):
|
||||
breakpoint()
|
||||
...
|
||||
|
||||
|
||||
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
||||
|
||||
|
||||
# Add oidc providers to authlib from the settings
|
||||
for id, provider in oidc_providers_settings.items():
|
||||
authlib_oauth.register(
|
||||
name=id,
|
||||
server_metadata_url=provider.openid_configuration,
|
||||
client_kwargs={
|
||||
"scope": "openid email offline_access profile",
|
||||
},
|
||||
client_id=provider.client_id,
|
||||
client_secret=provider.client_secret,
|
||||
api_base_url=provider.url,
|
||||
# For PKCE (not implemented yet):
|
||||
# code_challenge_method="S256",
|
||||
# fetch_token=fetch_token,
|
||||
# update_token=update_token,
|
||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||
)
|
||||
|
||||
def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None:
|
||||
"""Return the oidc_provider from a request object, from the session.
|
||||
|
@ -115,21 +156,34 @@ def get_token_info(token: dict) -> dict:
|
|||
return token_info
|
||||
|
||||
|
||||
def fetch_token(name, request):
|
||||
breakpoint()
|
||||
...
|
||||
# if name in OIDC_PROVIDERS:
|
||||
# model = OAuth2Token
|
||||
# else:
|
||||
# model = OAuthToken
|
||||
|
||||
# token = model.find(name=name, user=request.user)
|
||||
# return token.to_token()
|
||||
|
||||
|
||||
def update_token(*args, **kwargs):
|
||||
breakpoint()
|
||||
...
|
||||
|
||||
|
||||
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
||||
async def get_resource_user(
|
||||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
request: Request,
|
||||
) -> User:
|
||||
# TODO: decode token (ah!)
|
||||
# See https://fastapi.tiangolo.com/tutorial/security/oauth2-jwt/#hash-and-verify-the-passwords
|
||||
if (auth_provider_id := request.headers.get("auth_provider")) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field")
|
||||
if (auth_provider := oidc_providers_settings.get(auth_provider_id)) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'")
|
||||
if (key := auth_provider.get_key()) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown")
|
||||
try:
|
||||
payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test")
|
||||
except Exception as err:
|
||||
logger.info("Cannot decode token, see below")
|
||||
logger.exception(err)
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
|
||||
if (user_id := payload.get('sub')) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found")
|
||||
try:
|
||||
user = await db.get_user(user_id)
|
||||
except UserNotInDB:
|
||||
logger.info(f"User {user_id} not found in DB, creating it (real apps can behave differently")
|
||||
user = await db.add_user(
|
||||
sub=payload['sub'],
|
||||
user_info=payload,
|
||||
oidc_provider=getattr(authlib_oauth, auth_provider_id),
|
||||
user_info_from_endpoint={}
|
||||
)
|
||||
return user
|
||||
|
|
|
@ -10,6 +10,9 @@ from authlib.oauth2.rfc6749 import OAuth2Token
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserNotInDB(Exception):
|
||||
pass
|
||||
|
||||
class Database:
|
||||
users: dict[str, User] = {}
|
||||
tokens: dict[str, OAuth2Token] = {}
|
||||
|
@ -37,6 +40,8 @@ class Database:
|
|||
return user
|
||||
|
||||
async def get_user(self, sub: str) -> User:
|
||||
if sub not in self.users:
|
||||
raise UserNotInDB
|
||||
return self.users[sub]
|
||||
|
||||
async def add_token(self, token: OAuth2Token, user: User) -> None:
|
||||
|
|
|
@ -13,6 +13,7 @@ from fastapi import Depends, FastAPI, HTTPException, Request, status
|
|||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from authlib.integrations.base_client import OAuthError
|
||||
|
@ -23,7 +24,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token
|
|||
# from fastapi.security import OpenIdConnect
|
||||
# from pkce import generate_code_verifier, generate_pkce_pair
|
||||
|
||||
from .settings import settings, OIDCProvider
|
||||
from .settings import settings
|
||||
from .models import User
|
||||
from .auth_utils import (
|
||||
get_oidc_provider,
|
||||
|
@ -31,21 +32,37 @@ from .auth_utils import (
|
|||
hasrole,
|
||||
get_current_user_or_none,
|
||||
get_current_user,
|
||||
get_resource_user,
|
||||
authlib_oauth,
|
||||
get_token,
|
||||
oidc_providers_settings,
|
||||
)
|
||||
from .auth_misc import pretty_details
|
||||
from .database import db
|
||||
from .resource_server import get_resource
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
templates = Jinja2Templates(Path(__file__).parent / "templates")
|
||||
|
||||
|
||||
origins = [
|
||||
"https://tiptop:3002",
|
||||
"https://philo.ydns.eu/",
|
||||
]
|
||||
|
||||
app = FastAPI(
|
||||
title="OIDC auth test",
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.mount(
|
||||
"/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static"
|
||||
)
|
||||
|
@ -56,32 +73,6 @@ app.add_middleware(
|
|||
secret_key=settings.secret_key,
|
||||
)
|
||||
|
||||
# Add oidc providers to authlib from the settings
|
||||
|
||||
# fastapi_providers: dict[str, OpenIdConnect] = {}
|
||||
oidc_providers_settings: dict[str, OIDCProvider] = {}
|
||||
|
||||
for provider in settings.oidc.providers:
|
||||
authlib_oauth.register(
|
||||
name=provider.id,
|
||||
server_metadata_url=provider.openid_configuration,
|
||||
client_kwargs={
|
||||
"scope": "openid email offline_access profile",
|
||||
},
|
||||
client_id=provider.client_id,
|
||||
client_secret=provider.client_secret,
|
||||
api_base_url=provider.url,
|
||||
# For PKCE (not implemented yet):
|
||||
# code_challenge_method="S256",
|
||||
# fetch_token=fetch_token,
|
||||
# update_token=update_token,
|
||||
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
|
||||
)
|
||||
# fastapi_providers[provider.id] = OpenIdConnect(
|
||||
# openIdConnectUrl=provider.openid_configuration
|
||||
# )
|
||||
oidc_providers_settings[provider.id] = provider
|
||||
|
||||
|
||||
@app.get("/")
|
||||
async def home(
|
||||
|
@ -281,43 +272,16 @@ async def non_compliant_logout(
|
|||
|
||||
|
||||
@app.get("/resource/{id}")
|
||||
async def get_resource(
|
||||
async def get_resource_(
|
||||
id: str,
|
||||
request: Request,
|
||||
user: Annotated[User, Depends(get_current_user)],
|
||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
token: Annotated[OAuth2Token, Depends(get_token)],
|
||||
# user: Annotated[User, Depends(get_current_user)],
|
||||
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
# token: Annotated[OAuth2Token, Depends(get_token)],
|
||||
user: Annotated[User, Depends(get_resource_user)],
|
||||
) -> JSONResponse:
|
||||
"""Generic path for testing a resource provided by a provider"""
|
||||
assert user is not None # Just to keep QA checks happy
|
||||
if oidc_provider is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_406_NOT_ACCEPTABLE, detail="No such oidc provider"
|
||||
)
|
||||
if (
|
||||
provider := oidc_providers_settings.get(
|
||||
request.session.get("oidc_provider_id", "")
|
||||
)
|
||||
) is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting"
|
||||
)
|
||||
try:
|
||||
resource = next(x for x in provider.resources if x.id == id)
|
||||
except StopIteration:
|
||||
raise HTTPException(
|
||||
status.HTTP_406_NOT_ACCEPTABLE, detail="No such resource for this provider"
|
||||
)
|
||||
if (
|
||||
response := await oidc_provider.get(
|
||||
resource.url,
|
||||
# headers={"Authorization": f"token {token['access_token']}"},
|
||||
token=token,
|
||||
)
|
||||
).is_success:
|
||||
return JSONResponse(response.json())
|
||||
else:
|
||||
raise HTTPException(status_code=response.status_code, detail=response.text)
|
||||
return JSONResponse(await get_resource(id, user))
|
||||
|
||||
|
||||
# Routes for test
|
||||
|
|
21
src/oidc_test/resource_server.py
Normal file
21
src/oidc_test/resource_server.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
from datetime import datetime
|
||||
from httpx import AsyncClient
|
||||
|
||||
from .models import User
|
||||
|
||||
async def get_resource(id: str, user: User) -> dict:
|
||||
pname = getattr(user.oidc_provider, "name", "?")
|
||||
resp = {
|
||||
"hello" : f"Hi {user.name} from an OAuth resource provider.",
|
||||
"comment": f"I received a request for '{id}' with an access token signed by {pname}."
|
||||
}
|
||||
if id == "time":
|
||||
resp["time"] = datetime.now().strftime("%c")
|
||||
elif id == "bs":
|
||||
async with AsyncClient() as client:
|
||||
bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
|
||||
resp['bs'] = bs.json().get("phrase", "Sorry, i am out of BS today.")
|
||||
else:
|
||||
resp['sorry'] = f"I don't known how to give '{id}' but i know corporate bs."
|
||||
|
||||
return resp
|
|
@ -36,6 +36,7 @@ class OIDCProvider(BaseModel):
|
|||
hint: str = "No hint"
|
||||
resources: list[Resource] = []
|
||||
account_url_template: str | None = None
|
||||
key: str | None = None
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
|
@ -63,6 +64,15 @@ class OIDCProvider(BaseModel):
|
|||
else:
|
||||
return None
|
||||
|
||||
def get_key(self) -> str | None:
|
||||
"""Return the public key formatted for """
|
||||
if self.key is None:
|
||||
return None
|
||||
return f"""
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
{self.key}
|
||||
-----END PUBLIC KEY-----
|
||||
"""
|
||||
|
||||
class ResourceProvider(BaseModel):
|
||||
id: str
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue