Resource server: read the required scope in access token
All checks were successful
/ build (push) Successful in 15s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-01-30 20:40:04 +01:00
parent f910834736
commit b3e19b3e40
4 changed files with 129 additions and 37 deletions

View file

@ -1,4 +1,4 @@
from typing import Union, Annotated, Tuple from typing import Union, Annotated
from functools import wraps from functools import wraps
import logging import logging
@ -18,10 +18,13 @@ from .settings import settings, OIDCProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
oidc_providers_settings: dict[str, OIDCProvider] = dict([(provider.id, provider) for provider in settings.oidc.providers]) oidc_providers_settings: dict[str, OIDCProvider] = dict(
[(provider.id, provider) for provider in settings.oidc.providers]
)
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def fetch_token(name, request): def fetch_token(name, request):
breakpoint() breakpoint()
... ...
@ -61,8 +64,10 @@ def init_providers():
# client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience)
) )
init_providers() init_providers()
async def get_providers_info(): async def get_providers_info():
# Get the public key: # Get the public key:
async with AsyncClient() as client: async with AsyncClient() as client:
@ -174,18 +179,35 @@ async def get_user_from_token(
request: Request, request: Request,
) -> User: ) -> User:
if (auth_provider_id := request.headers.get("auth_provider")) is None: if (auth_provider_id := request.headers.get("auth_provider")) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field") raise HTTPException(
if (auth_provider_settings := oidc_providers_settings.get(auth_provider_id)) is None: status.HTTP_401_UNAUTHORIZED,
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'") "Request headers must have a 'auth_provider' field",
oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) )
await oidc_provider.load_server_metadata() if (
auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
) is None:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
)
if (key := auth_provider_settings.get_public_key()) is None: if (key := auth_provider_settings.get_public_key()) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown") raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
f"Key for provider '{auth_provider_id}' unknown",
)
try: try:
payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test") payload = decode(
token,
key=key,
algorithms=["RS256"],
audience="oidc-test",
options={"verify_signature": not settings.insecure.skip_verify_signature},
)
except ExpiredSignatureError as err: except ExpiredSignatureError as err:
logger.info(f"Expired signature: {err}") logger.info(f"Expired signature: {err}")
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Expired signature (refresh not implemented yet)") raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
"Expired signature (refresh not implemented yet)",
)
except InvalidKeyError as err: except InvalidKeyError as err:
logger.info(f"Invalid key: {err}") logger.info(f"Invalid key: {err}")
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
@ -193,16 +215,20 @@ async def get_user_from_token(
logger.info("Cannot decode token, see below") logger.info("Cannot decode token, see below")
logger.exception(err) logger.exception(err)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
if (user_id := payload.get('sub')) is None: if (user_id := payload.get("sub")) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found") raise HTTPException(
status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found"
)
try: try:
user = await db.get_user(user_id) user = await db.get_user(user_id)
except UserNotInDB: except UserNotInDB:
logger.info(f"User {user_id} not found in DB, creating it (real apps can behave differently") logger.info(
f"User {user_id} not found in DB, creating it (real apps can behave differently"
)
user = await db.add_user( user = await db.add_user(
sub=payload['sub'], sub=payload["sub"],
user_info=payload, user_info=payload,
oidc_provider=getattr(authlib_oauth, auth_provider_id), oidc_provider=getattr(authlib_oauth, auth_provider_id),
user_info_from_endpoint={} user_info_from_endpoint={},
) )
return user return user

View file

@ -53,16 +53,14 @@ origins = [
"https://philo.ydns.eu/", "https://philo.ydns.eu/",
] ]
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
await get_providers_info() await get_providers_info()
yield yield
app = FastAPI( app = FastAPI(title="OIDC auth test", lifespan=lifespan)
title="OIDC auth test",
lifespan=lifespan
)
app.add_middleware( app.add_middleware(
@ -284,7 +282,6 @@ async def non_compliant_logout(
@app.get("/resource/{id}") @app.get("/resource/{id}")
async def get_resource_( async def get_resource_(
id: str, id: str,
request: Request,
# user: Annotated[User, Depends(get_current_user)], # user: Annotated[User, Depends(get_current_user)],
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
# token: Annotated[OAuth2Token, Depends(get_token)], # token: Annotated[OAuth2Token, Depends(get_token)],
@ -294,7 +291,7 @@ async def get_resource_(
return JSONResponse(await get_resource(id, user)) return JSONResponse(await get_resource(id, user))
# Routes for test # Routes for RBAC based tests
@app.get("/public") @app.get("/public")

View file

@ -1,21 +1,76 @@
from datetime import datetime from datetime import datetime
import logging
from httpx import AsyncClient from httpx import AsyncClient
from fastapi import HTTPException, status
from jwt import ExpiredSignatureError, InvalidKeyError, decode
from .models import User from .models import User
from .auth_utils import oidc_providers_settings
from .settings import settings
logger = logging.getLogger(__name__)
async def get_resource(id: str, user: User) -> dict: async def get_resource(id: str, user: User) -> dict:
pname = getattr(user.oidc_provider, "name", "?") pname = getattr(user.oidc_provider, "name", "?")
resp = { resp = {
"hello": f"Hi {user.name} from an OAuth resource provider.", "hello": f"Hi {user.name} from an OAuth resource provider.",
"comment": f"I received a request for '{id}' with an access token signed by {pname}." "comment": f"I received a request for '{id}' with an access token signed by {pname}.",
} }
scope = f"get:{id}"
user_scopes = user.userinfo["scope"].split(" ")
if scope in user_scopes:
if id == "time": if id == "time":
resp["time"] = datetime.now().strftime("%c") resp["time"] = datetime.now().strftime("%c")
elif id == "bs": elif id == "bs":
async with AsyncClient() as client: async with AsyncClient() as client:
bs = await client.get("https://corporatebs-generator.sameerkumar.website/") bs = await client.get(
resp['bs'] = bs.json().get("phrase", "Sorry, i am out of BS today.") "https://corporatebs-generator.sameerkumar.website/"
)
resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.")
else: else:
resp['sorry'] = f"I don't known how to give '{id}' but i know corporate bs." resp["sorry"] = f"I don't known how to give '{id}' but i know corporate bs."
else:
resp["sorry"] = (
f"I don't serve the ressource {id} to you because"
"there is no scope {scope} in the access token,"
)
return resp return resp
# assert user.oidc_provider is not None
### Get some info (TODO: refactor)
# if (auth_provider_id := user.oidc_provider.name) is None:
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED,
# "Request headers must have a 'auth_provider' field",
# )
# if (
# auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
# ) is None:
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
# )
# if (key := auth_provider_settings.get_public_key()) is None:
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED,
# f"Key for provider '{auth_provider_id}' unknown",
# )
# logger.warn(f"refresh with scope {scope}")
# breakpoint()
# refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope)
### Decode the new token
# try:
# payload = decode(
# refreshed_auth_info["access_token"],
# key=key,
# algorithms=["RS256"],
# audience="account",
# options={"verify_signature": not settings.insecure.skip_verify_signature},
# )
# except ExpiredSignatureError as err:
# logger.info(f"Expired signature: {err}")
# raise HTTPException(
# status.HTTP_401_UNAUTHORIZED,
# "Expired signature (refresh not implemented yet)",
# )

View file

@ -36,8 +36,12 @@ class OIDCProvider(BaseModel):
hint: str = "No hint" hint: str = "No hint"
resources: list[Resource] = [] resources: list[Resource] = []
account_url_template: str | None = None account_url_template: str | None = None
info_url: str | None = None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key) info_url: str | None = (
info: dict[str, str | int ] | None = None # Info fetched from info_url, eg. public key None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
)
info: dict[str, str | int] | None = (
None # Info fetched from info_url, eg. public key
)
public_key: str | None = None public_key: str | None = None
@computed_field @computed_field
@ -68,7 +72,9 @@ class OIDCProvider(BaseModel):
def get_public_key(self) -> str | None: def get_public_key(self) -> str | None:
"""Return the public key formatted for decoding token""" """Return the public key formatted for decoding token"""
public_key = self.public_key or (self.info is not None and self.info["public_key"]) public_key = self.public_key or (
self.info is not None and self.info["public_key"]
)
if public_key is None: if public_key is None:
return None return None
return f""" return f"""
@ -77,6 +83,7 @@ class OIDCProvider(BaseModel):
-----END PUBLIC KEY----- -----END PUBLIC KEY-----
""" """
class ResourceProvider(BaseModel): class ResourceProvider(BaseModel):
id: str id: str
name: str name: str
@ -90,15 +97,22 @@ class OIDCSettings(BaseModel):
swagger_provider: str = "" swagger_provider: str = ""
class Insecure(BaseModel):
"""Warning: changing these defaults are only suitable for debugging"""
skip_verify_signature: bool = False
class Settings(BaseSettings): class Settings(BaseSettings):
"""Settings wil be read from an .env file""" """Settings wil be read from an .env file"""
model_config = SettingsConfigDict(env_nested_delimiter="__")
oidc: OIDCSettings = OIDCSettings() oidc: OIDCSettings = OIDCSettings()
resource_providers: list[ResourceProvider] = [] resource_providers: list[ResourceProvider] = []
secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16))
log: bool = False log: bool = False
insecure: Insecure = Insecure()
model_config = SettingsConfigDict(env_nested_delimiter="__")
@classmethod @classmethod
def settings_customise_sources( def settings_customise_sources(