diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index d64b6cf..2fcfc76 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -1,4 +1,4 @@ -from typing import Union, Annotated, Tuple +from typing import Union, Annotated from functools import wraps import logging @@ -18,10 +18,13 @@ from .settings import settings, OIDCProvider logger = logging.getLogger(__name__) -oidc_providers_settings: dict[str, OIDCProvider] = dict([(provider.id, provider) for provider in settings.oidc.providers]) +oidc_providers_settings: dict[str, OIDCProvider] = dict( + [(provider.id, provider) for provider in settings.oidc.providers] +) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + def fetch_token(name, request): breakpoint() ... @@ -43,7 +46,7 @@ authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_t def init_providers(): -# Add oidc providers to authlib from the settings + # Add oidc providers to authlib from the settings for id, provider in oidc_providers_settings.items(): authlib_oauth.register( name=id, @@ -61,8 +64,10 @@ def init_providers(): # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) ) + init_providers() + async def get_providers_info(): # Get the public key: async with AsyncClient() as client: @@ -174,18 +179,35 @@ async def get_user_from_token( request: Request, ) -> User: if (auth_provider_id := request.headers.get("auth_provider")) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field") - if (auth_provider_settings := oidc_providers_settings.get(auth_provider_id)) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'") - oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) - await oidc_provider.load_server_metadata() + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + "Request headers must have a 'auth_provider' field", + ) + if ( + auth_provider_settings := oidc_providers_settings.get(auth_provider_id) + ) is None: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" + ) if (key := auth_provider_settings.get_public_key()) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown") + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + f"Key for provider '{auth_provider_id}' unknown", + ) try: - payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test") + payload = decode( + token, + key=key, + algorithms=["RS256"], + audience="oidc-test", + options={"verify_signature": not settings.insecure.skip_verify_signature}, + ) except ExpiredSignatureError as err: logger.info(f"Expired signature: {err}") - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Expired signature (refresh not implemented yet)") + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + "Expired signature (refresh not implemented yet)", + ) except InvalidKeyError as err: logger.info(f"Invalid key: {err}") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key") @@ -193,16 +215,20 @@ async def get_user_from_token( logger.info("Cannot decode token, see below") logger.exception(err) raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token") - if (user_id := payload.get('sub')) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found") + if (user_id := payload.get("sub")) is None: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found" + ) try: user = await db.get_user(user_id) except UserNotInDB: - logger.info(f"User {user_id} not found in DB, creating it (real apps can behave differently") + logger.info( + f"User {user_id} not found in DB, creating it (real apps can behave differently" + ) user = await db.add_user( - sub=payload['sub'], + sub=payload["sub"], user_info=payload, oidc_provider=getattr(authlib_oauth, auth_provider_id), - user_info_from_endpoint={} + user_info_from_endpoint={}, ) return user diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 6942759..f6ce405 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -53,16 +53,14 @@ origins = [ "https://philo.ydns.eu/", ] + @asynccontextmanager async def lifespan(app: FastAPI): await get_providers_info() yield -app = FastAPI( - title="OIDC auth test", - lifespan=lifespan -) +app = FastAPI(title="OIDC auth test", lifespan=lifespan) app.add_middleware( @@ -284,7 +282,6 @@ async def non_compliant_logout( @app.get("/resource/{id}") async def get_resource_( id: str, - request: Request, # user: Annotated[User, Depends(get_current_user)], # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], # token: Annotated[OAuth2Token, Depends(get_token)], @@ -294,7 +291,7 @@ async def get_resource_( return JSONResponse(await get_resource(id, user)) -# Routes for test +# Routes for RBAC based tests @app.get("/public") diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index a9dfe3a..0186064 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -1,21 +1,76 @@ from datetime import datetime +import logging + from httpx import AsyncClient +from fastapi import HTTPException, status +from jwt import ExpiredSignatureError, InvalidKeyError, decode from .models import User +from .auth_utils import oidc_providers_settings +from .settings import settings + +logger = logging.getLogger(__name__) + async def get_resource(id: str, user: User) -> dict: pname = getattr(user.oidc_provider, "name", "?") resp = { - "hello" : f"Hi {user.name} from an OAuth resource provider.", - "comment": f"I received a request for '{id}' with an access token signed by {pname}." + "hello": f"Hi {user.name} from an OAuth resource provider.", + "comment": f"I received a request for '{id}' with an access token signed by {pname}.", } - if id == "time": - resp["time"] = datetime.now().strftime("%c") - elif id == "bs": - async with AsyncClient() as client: - bs = await client.get("https://corporatebs-generator.sameerkumar.website/") - resp['bs'] = bs.json().get("phrase", "Sorry, i am out of BS today.") + scope = f"get:{id}" + user_scopes = user.userinfo["scope"].split(" ") + if scope in user_scopes: + if id == "time": + resp["time"] = datetime.now().strftime("%c") + elif id == "bs": + async with AsyncClient() as client: + bs = await client.get( + "https://corporatebs-generator.sameerkumar.website/" + ) + resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.") + else: + resp["sorry"] = f"I don't known how to give '{id}' but i know corporate bs." else: - resp['sorry'] = f"I don't known how to give '{id}' but i know corporate bs." - + resp["sorry"] = ( + f"I don't serve the ressource {id} to you because" + "there is no scope {scope} in the access token," + ) return resp + + # assert user.oidc_provider is not None + ### Get some info (TODO: refactor) + # if (auth_provider_id := user.oidc_provider.name) is None: + # raise HTTPException( + # status.HTTP_401_UNAUTHORIZED, + # "Request headers must have a 'auth_provider' field", + # ) + # if ( + # auth_provider_settings := oidc_providers_settings.get(auth_provider_id) + # ) is None: + # raise HTTPException( + # status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" + # ) + # if (key := auth_provider_settings.get_public_key()) is None: + # raise HTTPException( + # status.HTTP_401_UNAUTHORIZED, + # f"Key for provider '{auth_provider_id}' unknown", + # ) + # logger.warn(f"refresh with scope {scope}") + # breakpoint() + # refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope) + ### Decode the new token + # try: + # payload = decode( + # refreshed_auth_info["access_token"], + # key=key, + # algorithms=["RS256"], + # audience="account", + # options={"verify_signature": not settings.insecure.skip_verify_signature}, + # ) + # except ExpiredSignatureError as err: + # logger.info(f"Expired signature: {err}") + # raise HTTPException( + # status.HTTP_401_UNAUTHORIZED, + # "Expired signature (refresh not implemented yet)", + # ) diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 00c3f23..399fbac 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -36,8 +36,12 @@ class OIDCProvider(BaseModel): hint: str = "No hint" resources: list[Resource] = [] account_url_template: str | None = None - info_url: str | None = None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key) - info: dict[str, str | int ] | None = None # Info fetched from info_url, eg. public key + info_url: str | None = ( + None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key) + ) + info: dict[str, str | int] | None = ( + None # Info fetched from info_url, eg. public key + ) public_key: str | None = None @computed_field @@ -68,7 +72,9 @@ class OIDCProvider(BaseModel): def get_public_key(self) -> str | None: """Return the public key formatted for decoding token""" - public_key = self.public_key or (self.info is not None and self.info["public_key"]) + public_key = self.public_key or ( + self.info is not None and self.info["public_key"] + ) if public_key is None: return None return f""" @@ -77,6 +83,7 @@ class OIDCProvider(BaseModel): -----END PUBLIC KEY----- """ + class ResourceProvider(BaseModel): id: str name: str @@ -90,15 +97,22 @@ class OIDCSettings(BaseModel): swagger_provider: str = "" +class Insecure(BaseModel): + """Warning: changing these defaults are only suitable for debugging""" + + skip_verify_signature: bool = False + + class Settings(BaseSettings): """Settings wil be read from an .env file""" + model_config = SettingsConfigDict(env_nested_delimiter="__") + oidc: OIDCSettings = OIDCSettings() resource_providers: list[ResourceProvider] = [] secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False - - model_config = SettingsConfigDict(env_nested_delimiter="__") + insecure: Insecure = Insecure() @classmethod def settings_customise_sources(