Resource server: read the required scope in access token

This commit is contained in:
phil 2025-01-30 20:40:04 +01:00
parent f910834736
commit b3e19b3e40
4 changed files with 129 additions and 37 deletions

View file

@ -1,4 +1,4 @@
from typing import Union, Annotated, Tuple
from typing import Union, Annotated
from functools import wraps
import logging
@ -18,10 +18,13 @@ from .settings import settings, OIDCProvider
logger = logging.getLogger(__name__)
oidc_providers_settings: dict[str, OIDCProvider] = dict([(provider.id, provider) for provider in settings.oidc.providers])
oidc_providers_settings: dict[str, OIDCProvider] = dict(
[(provider.id, provider) for provider in settings.oidc.providers]
)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def fetch_token(name, request):
breakpoint()
...
@ -61,8 +64,10 @@ def init_providers():
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
)
init_providers()
async def get_providers_info():
# Get the public key:
async with AsyncClient() as client:
@ -174,18 +179,35 @@ async def get_user_from_token(
request: Request,
) -> User:
if (auth_provider_id := request.headers.get("auth_provider")) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field")
if (auth_provider_settings := oidc_providers_settings.get(auth_provider_id)) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'")
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
await oidc_provider.load_server_metadata()
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
"Request headers must have a 'auth_provider' field",
)
if (
auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
) is None:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
)
if (key := auth_provider_settings.get_public_key()) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown")
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
f"Key for provider '{auth_provider_id}' unknown",
)
try:
payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test")
payload = decode(
token,
key=key,
algorithms=["RS256"],
audience="oidc-test",
options={"verify_signature": not settings.insecure.skip_verify_signature},
)
except ExpiredSignatureError as err:
logger.info(f"Expired signature: {err}")
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Expired signature (refresh not implemented yet)")
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
"Expired signature (refresh not implemented yet)",
)
except InvalidKeyError as err:
logger.info(f"Invalid key: {err}")
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
@ -193,16 +215,20 @@ async def get_user_from_token(
logger.info("Cannot decode token, see below")
logger.exception(err)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
if (user_id := payload.get('sub')) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found")
if (user_id := payload.get("sub")) is None:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found"
)
try:
user = await db.get_user(user_id)
except UserNotInDB:
logger.info(f"User {user_id} not found in DB, creating it (real apps can behave differently")
logger.info(
f"User {user_id} not found in DB, creating it (real apps can behave differently"
)
user = await db.add_user(
sub=payload['sub'],
sub=payload["sub"],
user_info=payload,
oidc_provider=getattr(authlib_oauth, auth_provider_id),
user_info_from_endpoint={}
user_info_from_endpoint={},
)
return user

View file

@ -53,16 +53,14 @@ origins = [
"https://philo.ydns.eu/",
]
@asynccontextmanager
async def lifespan(app: FastAPI):
await get_providers_info()
yield
app = FastAPI(
title="OIDC auth test",
lifespan=lifespan
)
app = FastAPI(title="OIDC auth test", lifespan=lifespan)
app.add_middleware(
@ -284,7 +282,6 @@ async def non_compliant_logout(
@app.get("/resource/{id}")
async def get_resource_(
id: str,
request: Request,
# user: Annotated[User, Depends(get_current_user)],
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
# token: Annotated[OAuth2Token, Depends(get_token)],
@ -294,7 +291,7 @@ async def get_resource_(
return JSONResponse(await get_resource(id, user))
# Routes for test
# Routes for RBAC based tests
@app.get("/public")

View file

@ -1,21 +1,76 @@
from datetime import datetime
import logging
from httpx import AsyncClient
from fastapi import HTTPException, status
from jwt import ExpiredSignatureError, InvalidKeyError, decode
from .models import User
from .auth_utils import oidc_providers_settings
from .settings import settings
logger = logging.getLogger(__name__)
async def get_resource(id: str, user: User) -> dict:
pname = getattr(user.oidc_provider, "name", "?")
resp = {
"hello": f"Hi {user.name} from an OAuth resource provider.",
"comment": f"I received a request for '{id}' with an access token signed by {pname}."
"comment": f"I received a request for '{id}' with an access token signed by {pname}.",
}
scope = f"get:{id}"
user_scopes = user.userinfo["scope"].split(" ")
if scope in user_scopes:
if id == "time":
resp["time"] = datetime.now().strftime("%c")
elif id == "bs":
async with AsyncClient() as client:
bs = await client.get("https://corporatebs-generator.sameerkumar.website/")
resp['bs'] = bs.json().get("phrase", "Sorry, i am out of BS today.")
bs = await client.get(
"https://corporatebs-generator.sameerkumar.website/"
)
resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
else:
resp['sorry'] = f"I don't known how to give '{id}' but i know corporate bs."
resp["sorry"] = f"I don't known how to give '{id}' but i know corporate bs."
else:
resp["sorry"] = (
f"I don't serve the ressource {id} to you because"
"there is no scope {scope} in the access token,"
)
return resp
# assert user.oidc_provider is not None
### Get some info (TODO: refactor)
# if (auth_provider_id := user.oidc_provider.name) is None:
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED,
# "Request headers must have a 'auth_provider' field",
# )
# if (
# auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
# ) is None:
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
# )
# if (key := auth_provider_settings.get_public_key()) is None:
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED,
# f"Key for provider '{auth_provider_id}' unknown",
# )
# logger.warn(f"refresh with scope {scope}")
# breakpoint()
# refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope)
### Decode the new token
# try:
# payload = decode(
# refreshed_auth_info["access_token"],
# key=key,
# algorithms=["RS256"],
# audience="account",
# options={"verify_signature": not settings.insecure.skip_verify_signature},
# )
# except ExpiredSignatureError as err:
# logger.info(f"Expired signature: {err}")
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED,
# "Expired signature (refresh not implemented yet)",
# )

View file

@ -36,8 +36,12 @@ class OIDCProvider(BaseModel):
hint: str = "No hint"
resources: list[Resource] = []
account_url_template: str | None = None
info_url: str | None = None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
info: dict[str, str | int ] | None = None # Info fetched from info_url, eg. public key
info_url: str | None = (
None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
)
info: dict[str, str | int] | None = (
None # Info fetched from info_url, eg. public key
)
public_key: str | None = None
@computed_field
@ -68,7 +72,9 @@ class OIDCProvider(BaseModel):
def get_public_key(self) -> str | None:
"""Return the public key formatted for decoding token"""
public_key = self.public_key or (self.info is not None and self.info["public_key"])
public_key = self.public_key or (
self.info is not None and self.info["public_key"]
)
if public_key is None:
return None
return f"""
@ -77,6 +83,7 @@ class OIDCProvider(BaseModel):
-----END PUBLIC KEY-----
"""
class ResourceProvider(BaseModel):
id: str
name: str
@ -90,15 +97,22 @@ class OIDCSettings(BaseModel):
swagger_provider: str = ""
class Insecure(BaseModel):
"""Warning: changing these defaults are only suitable for debugging"""
skip_verify_signature: bool = False
class Settings(BaseSettings):
"""Settings wil be read from an .env file"""
model_config = SettingsConfigDict(env_nested_delimiter="__")
oidc: OIDCSettings = OIDCSettings()
resource_providers: list[ResourceProvider] = []
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
log: bool = False
model_config = SettingsConfigDict(env_nested_delimiter="__")
insecure: Insecure = Insecure()
@classmethod
def settings_customise_sources(