diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index ed3350c..33ca582 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -6,14 +6,14 @@ from fastapi import HTTPException, Request, Depends, status from fastapi.security import OAuth2PasswordBearer from authlib.oauth2.rfc6749 import OAuth2Token from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App -from jwt import ExpiredSignatureError, InvalidKeyError, decode +from jwt import ExpiredSignatureError, InvalidKeyError from httpx import AsyncClient # from authlib.oauth1.auth import OAuthToken # from authlib.oauth2.auth import OAuth2Token from .models import User -from .database import db, UserNotInDB +from .database import TokenNotInDb, db, UserNotInDB from .settings import settings, OIDCProvider logger = logging.getLogger(__name__) @@ -126,9 +126,10 @@ async def get_current_user(request: Request) -> User: async def get_token(request: Request) -> OAuth2Token: """Return the token from a request object, from the session. It can be used in Depends()""" - if (token := await db.get_token(request.session.get("token"))) is None: + try: + return await db.get_token(request.session["token"]) + except (KeyError, TokenNotInDb): raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") - return token async def get_current_user_or_none(request: Request) -> User | None: @@ -189,19 +190,8 @@ async def get_user_from_token( raise HTTPException( status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" ) - if (key := auth_provider_settings.get_public_key()) is None: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - f"Key for provider '{auth_provider_id}' unknown", - ) try: - payload = decode( - token, - key=key, - algorithms=["RS256"], - audience="oidc-test", - options={"verify_signature": not settings.insecure.skip_verify_signature}, - ) + payload = auth_provider_settings.decode(token) except ExpiredSignatureError as err: logger.info(f"Expired signature: {err}") raise HTTPException( diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 5dec7fc..b2cf1b9 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -14,6 +14,10 @@ class UserNotInDB(Exception): pass +class TokenNotInDb(Exception): + pass + + class Database: users: dict[str, User] = {} tokens: dict[str, OAuth2Token] = {} @@ -50,8 +54,11 @@ class Database: async def add_token(self, token: OAuth2Token, user: User) -> None: self.tokens[token["id_token"]] = token - async def get_token(self, id_token: str) -> OAuth2Token | None: - return self.tokens.get(id_token) + async def get_token(self, id_token: str) -> OAuth2Token: + try: + return self.tokens[id_token] + except KeyError: + raise TokenNotInDb db = Database() diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 739dd1b..ef19245 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -187,11 +187,20 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: # Build and remember the user in the session request.session["user_sub"] = sub # Store the user in the database + try: + oidc_provider_settings = oidc_providers_settings[oidc_provider_id] + access_token = oidc_provider_settings.decode(token["access_token"]) + except Exception: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + detail="Cannot decode token or verify its signature", + ) user = await db.add_user( sub, user_info=userinfo, oidc_provider=oidc_provider, user_info_from_endpoint=user_info_from_endpoint, + access_token=access_token, ) # Add the id_token to the session request.session["token"] = token["id_token"] @@ -213,14 +222,14 @@ async def account( oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], ) -> RedirectResponse: if ( - provider := oidc_providers_settings.get( + oidc_provider_settings := oidc_providers_settings.get( request.session.get("oidc_provider_id", "") ) ) is None: raise HTTPException( - status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting" + status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings" ) - return RedirectResponse(f"{provider.account_url}") + return RedirectResponse(f"{oidc_provider_settings.account_url}") @app.get("/logout") diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 329b9c0..46d857d 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -4,6 +4,7 @@ import random from typing import Type, Tuple from pathlib import Path +from jwt import decode from pydantic import BaseModel, computed_field, AnyUrl from pydantic_settings import ( BaseSettings, @@ -43,6 +44,7 @@ class OIDCProvider(BaseModel): None # Info fetched from info_url, eg. public key ) public_key: str | None = None + signature_alg: str = "RS256" @computed_field @property @@ -70,19 +72,29 @@ class OIDCProvider(BaseModel): else: return None - def get_public_key(self) -> str | None: + def get_public_key(self) -> str: """Return the public key formatted for decoding token""" public_key = self.public_key or ( self.info is not None and self.info["public_key"] ) if public_key is None: - return None + raise AttributeError(f"Cannot get public key for {self.name}") return f""" -----BEGIN PUBLIC KEY----- {public_key} -----END PUBLIC KEY----- """ + def decode(self, token: str) -> dict: + """Decode the token with signature check""" + return decode( + token, + self.get_public_key(), + algorithms=[self.signature_alg], + audience=["oidc-test", "oidc-test-web"], + options={"verify_signature": not settings.insecure.skip_verify_signature}, + ) + class ResourceProvider(BaseModel): id: str diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index a4a0178..cc84736 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -104,6 +104,7 @@ hr { .role { padding: 3px 6px; background-color: #44228840; + border-radius: 6px; } /* For home */