diff --git a/src/oidc_test/auth_provider.py b/src/oidc_test/auth_provider.py index bed4596..e50241c 100644 --- a/src/oidc_test/auth_provider.py +++ b/src/oidc_test/auth_provider.py @@ -1,26 +1,41 @@ +from json import JSONDecodeError from typing import Any from jwt import decode import logging +from collections import OrderedDict +from pydantic import ConfigDict from authlib.integrations.starlette_client.apps import StarletteOAuth2App +from httpx import AsyncClient from .settings import AuthProviderSettings, settings +from .models import User logger = logging.getLogger("oidc-test") +class NoPublicKey(Exception): + pass + + class Provider(AuthProviderSettings): - class Config: - arbitrary_types_allowed = True + # To allow authlib_client as StarletteOAuth2App + model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore authlib_client: StarletteOAuth2App = StarletteOAuth2App(None) + info: dict[str, Any] = {} + unknown_auth_user: User - def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]: + def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]: """Decode the token with signature check""" + if self.public_key is None: + raise NoPublicKey + if verify_signature is None: + verify_signature = self.skip_verify_signature if settings.debug_token: decoded = decode( token, - self.get_public_key(), + self.public_key, algorithms=[self.signature_alg], audience=["account", "oidc-test", "oidc-test-web"], options={ @@ -31,7 +46,7 @@ class Provider(AuthProviderSettings): logger.debug(str(decoded)) return decode( token, - self.get_public_key(), + self.public_key, algorithms=[self.signature_alg], audience=["account", "oidc-test", "oidc-test-web"], options={ @@ -39,5 +54,42 @@ class Provider(AuthProviderSettings): }, # not settings.insecure.skip_verify_signature}, ) + async def get_info(self): + # Get the public key: + async with AsyncClient() as client: + public_key: str | None = None + if self.info_url is not None: + try: + provider_info = await client.get(self.info_url) + except Exception: + raise NoPublicKey + try: + self.info = provider_info.json() + except JSONDecodeError: + raise NoPublicKey + if "public_key" in self.info: + # For Keycloak + try: + public_key = str(self.info["public_key"]) + except KeyError: + raise NoPublicKey + elif "keys" in self.info: + # For Forgejo/Gitea + try: + public_key = str(self.info["keys"][0]["n"]) + except KeyError: + raise NoPublicKey + if self.public_key_url is not None: + resp = await client.get(self.public_key_url) + public_key = resp.text + if public_key is None: + raise NoPublicKey + self.public_key = "\n".join( + ["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"] + ) -providers: dict[str, Provider] = {} + def get_session_key(self, userinfo): + return userinfo[self.session_key] + + +providers: OrderedDict[str, Provider] = OrderedDict() diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index e62fe39..8cd5028 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -5,8 +5,7 @@ import logging from fastapi import HTTPException, Request, Depends, status from fastapi.security import OAuth2PasswordBearer from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App -from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError -from httpx import AsyncClient +from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError # from authlib.oauth1.auth import OAuthToken from authlib.oauth2.auth import OAuth2Token @@ -40,7 +39,7 @@ async def update_token( ): """Update the token in the database""" provider = providers[provider_id] - sid: str = provider.decode(token["id_token"])["sid"] + sid: str = provider.get_session_key(provider.decode(token["id_token"])) item = await db.get_token(provider, sid) # update old token item["access_token"] = token["access_token"] @@ -59,7 +58,12 @@ def init_providers(): """Add oidc providers to authlib from the settings and build the providers dict""" for provider_settings in settings.auth.providers: - provider = Provider(**provider_settings.model_dump()) + provider_settings_dict = provider_settings.model_dump() + # Add an anonymous user, that cannot be identified but has provided a valid access token + provider_settings_dict["unknown_auth_user"] = User( + sub="", auth_provider_id=provider_settings.id + ) + provider = Provider(**provider_settings_dict) authlib_oauth.register( name=provider.id, server_metadata_url=provider.openid_configuration, @@ -85,15 +89,6 @@ def init_providers(): init_providers() -async def get_providers_info(): - # Get the public key: - async with AsyncClient() as client: - for provider in providers.values(): - if provider.info_url: - provider_info = await client.get(provider.url) - provider.info = provider_info.json() - - def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: """Return the oidc_provider from a request object, from the session. It can be used in Depends()""" @@ -166,6 +161,8 @@ async def get_token(request: Request) -> OAuth2Token: try: provider = providers[request.session.get("auth_provider_id", "")] except KeyError: + request.session.pop("auth_provider_id", None) + request.session.pop("user_sub", None) raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider") try: return await db.get_token( @@ -239,29 +236,27 @@ async def get_user_from_token( raise HTTPException( status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" ) - if token == "": + if token == "None": + request.session.pop("auth_provider_id", None) + request.session.pop("user_sub", None) raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") try: payload = provider.decode(token) - except ExpiredSignatureError as err: - logger.info(f"Expired signature: {err}") + except ExpiredSignatureError: raise HTTPException( status.HTTP_401_UNAUTHORIZED, "Expired signature (refresh not implemented yet)", ) - except InvalidKeyError as err: - logger.info(f"Invalid key: {err}") + except InvalidKeyError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key") - except Exception as err: - logger.info("Cannot decode token, see below") - logger.exception(err) - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token") + except PyJWTError as err: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Cannot decode token: {err.__class__.__name__}" + ) try: user_id = payload["sub"] except KeyError: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found" - ) + return provider.unknown_auth_user try: user = await db.get_user(user_id) if user.access_token != token: diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 3493429..659fd13 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -3,6 +3,7 @@ import logging from authlib.oauth2.rfc6749 import OAuth2Token +from jwt import PyJWTError from .models import User, Role from .auth_provider import Provider, providers @@ -35,7 +36,10 @@ class Database: if access_token_decoded is None: assert auth_provider.name is not None provider = providers[auth_provider.id] - access_token_decoded = provider.decode(access_token) + try: + access_token_decoded = provider.decode(access_token) + except PyJWTError: + access_token_decoded = {} user_info["auth_provider_id"] = auth_provider.id user = User(**user_info) user.userinfo = user_info @@ -68,8 +72,7 @@ class Database: async def add_token(self, provider: Provider, token: OAuth2Token) -> None: """Store a token using as key the sid (auth provider's session id) in the id_token""" - assert isinstance(provider, Provider) - sid = token["userinfo"]["sid"] + sid = provider.get_session_key(token["userinfo"]) self.tokens[sid] = token async def get_token( diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 304df92..4018997 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -27,14 +27,13 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from pkce import generate_code_verifier, generate_pkce_pair from .settings import settings -from .auth_provider import Provider, providers +from .auth_provider import NoPublicKey, Provider, providers from .models import User from .auth_utils import ( get_auth_provider, get_auth_provider_or_none, get_current_user_or_none, authlib_oauth, - get_providers_info, get_token_or_none, get_token, update_token, @@ -50,7 +49,12 @@ templates = Jinja2Templates(Path(__file__).parent / "templates") @asynccontextmanager async def lifespan(app: FastAPI): assert app is not None - await get_providers_info() + for provider in list(providers.values()): + try: + await provider.get_info() + except NoPublicKey: + logger.warn(f"Disable {provider.id}: public key not found") + del providers[provider.id] yield @@ -82,12 +86,13 @@ async def home( token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], ) -> HTMLResponse: context = { - "settings": settings.model_dump(), + "show_token": settings.show_token, "user": user, "now": datetime.now(), "auth_provider": provider, } if provider is None or token is None: + context["providers"] = providers context["access_token"] = None context["id_token_parsed"] = None context["access_token_parsed"] = None @@ -96,14 +101,23 @@ async def home( else: context["access_token"] = token["access_token"] context["resources"] = provider.resources - access_token_parsed = provider.decode(token["access_token"], verify_signature=False) - context["access_token_scope"] = access_token_parsed["scope"] + try: + access_token_parsed = provider.decode(token["access_token"], verify_signature=False) + except PyJWTError as err: + access_token_parsed = {"Cannot parse": err.__class__.__name__} + try: + context["access_token_scope"] = access_token_parsed["scope"] + except KeyError: + context["access_token_scope"] = None # context["id_token_parsed"] = pretty_details(user, now) context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False) context["access_token_parsed"] = access_token_parsed - context["refresh_token_parsed"] = provider.decode( - token["refresh_token"], verify_signature=False - ) + try: + context["refresh_token_parsed"] = provider.decode( + token["refresh_token"], verify_signature=False + ) + except PyJWTError as err: + context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__} return templates.TemplateResponse(name="home.html", request=request, context=context) @@ -144,16 +158,19 @@ async def login(request: Request, auth_provider_id: str) -> RedirectResponse: @app.get("/auth/{auth_provider_id}") -async def auth(request: Request, auth_provider_id: str) -> RedirectResponse: +async def auth( + request: Request, + auth_provider_id: str, +) -> RedirectResponse: """Decrypt the auth token, store it to the session (cookie based) and response to the browser with a redirect to a "welcome user" page. """ try: - authlib_client: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) - except AttributeError: + provider = providers[auth_provider_id] + except KeyError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") try: - token: OAuth2Token = await authlib_client.authorize_access_token(request) + token: OAuth2Token = await provider.authlib_client.authorize_access_token(request) except OAuthError as error: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) # Remember the authlib_client in the session @@ -168,6 +185,14 @@ async def auth(request: Request, auth_provider_id: str) -> RedirectResponse: request.session["auth_provider_id"] = auth_provider_id # User id (sub) given by auth provider sub = userinfo["sub"] + ## Get additional data from userinfo endpoint + # try: + # user_info_from_endpoint = await authlib_client.userinfo( + # token=token, follow_redirects=True + # ) + # except Exception as err: + # logger.warn(f"Cannot get userinfo from endpoint: {err}") + # user_info_from_endpoint = {} # Build and remember the user in the session request.session["user_sub"] = sub # Store the user in the database, which also verifies the token validity and signature @@ -185,9 +210,8 @@ async def auth(request: Request, auth_provider_id: str) -> RedirectResponse: ) assert isinstance(user, User) # Add the provider session id to the session - request.session["sid"] = userinfo["sid"] + request.session["sid"] = provider.get_session_key(userinfo) # Add the token to the db because it is used for logout - provider = providers[auth_provider_id] await db.add_token(provider, token) # Send the user to the home: (s)he is authenticated return RedirectResponse(url=request.url_for("home")) @@ -211,15 +235,16 @@ async def logout( request: Request, provider: Annotated[Provider, Depends(get_auth_provider)], ) -> RedirectResponse: - # Clear session - request.session.pop("user_sub", None) # Get provider's endpoint if ( provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint") ) is None: - logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}") + logger.warn(f"Cannot find end_session_endpoint for provider {provider.id}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") + # Clear session + request.session.pop("user_sub", None) + request.session.pop("auth_provider_id", None) try: token = await db.get_token(provider, request.session.pop("sid", None)) except TokenNotInDb: @@ -242,15 +267,16 @@ async def logout( @app.get("/non-compliant-logout") async def non_compliant_logout( request: Request, - provider: Annotated[StarletteOAuth2App, Depends(get_auth_provider)], + provider: Annotated[Provider, Depends(get_auth_provider)], ): """A page for non-compliant OAuth2 servers that we cannot log out.""" - # Clear the remain of the session + # Clear session + request.session.pop("user_sub", None) request.session.pop("auth_provider_id", None) return templates.TemplateResponse( name="non_compliant_logout.html", request=request, - context={"oidc_provider": provider, "home_url": request.url_for("home")}, + context={"auth_provider": provider, "home_url": request.url_for("home")}, ) diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index eda63a6..7c5250b 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -49,13 +49,13 @@ class User(UserBase): try: access_token_scopes = self.decode_access_token().get("scope", "").split(" ") except Exception as err: - logger.info(f"Access token cannot be decoded: {err}") + logger.debug(f"Cannot find scope because the access token cannot be decoded: {err}") access_token_scopes = [] return scope in set(info_scopes + access_token_scopes) def decode_access_token(self, verify_signature: bool = True): - assert self.access_token is not None - assert self.auth_provider_id is not None + assert self.access_token is not None, "no access_token" + assert self.auth_provider_id is not None, "no auth_provider_id" from .auth_provider import providers return providers[self.auth_provider_id].decode( diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index e5670ed..f7f0433 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Annotated import logging -from authlib.jose import Key +from authlib.oauth2.auth import OAuth2Token from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from fastapi import FastAPI, HTTPException, Depends, status @@ -14,11 +14,12 @@ from fastapi.middleware.cors import CORSMiddleware from .models import User from .auth_utils import ( + get_token_or_none, get_user_from_token, UserWithRole, ) from .settings import settings -from .auth_provider import providers +from .auth_provider import providers, Provider logger = logging.getLogger("oidc-test") @@ -113,23 +114,51 @@ async def get_protected_by_foorole_or_barrole( @resource_server.get("/{id}") -async def get_resource_( +async def get_resource( id: str, - # user: Annotated[User, Depends(get_current_user)], - # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], - # token: Annotated[OAuth2Token, Depends(get_token)], user: Annotated[User, Depends(get_user_from_token)], -): + token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], +) -> dict | list: """Generic path for testing a resource provided by a provider""" - return await get_resource(id, user) + provider = providers[user.auth_provider_id] + if id in [r.id for r in provider.resources]: + return await get_external_resource( + provider=provider, + id=id, + access_token=token["access_token"] if token else None, + user=user, + ) + return await get_resource_(id, user) -async def get_resource(resource_id: str, user: User) -> dict: +async def get_external_resource( + provider: Provider, id: str, access_token: str | None, user: User +) -> dict | list: + resource = [r for r in provider.resources if r.id == id][0] + async with AsyncClient() as client: + resp = await client.get( + url=provider.url + resource.url, + headers={ + "Content-type": "application/json", + "Authorization": f"Bearer {access_token}", + }, + ) + if resp.is_error: + raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}") + resp_length = len(resp.text) + if resp_length > 1024: + return {"msg": f"The resource is too long ({resp_length} bytes) to show here"} + else: + return resp.json() + + +async def get_resource_(resource_id: str, user: User) -> dict: """ Resource processing: build an informative rely as a simple showcase """ + provider = providers[user.auth_provider_id] try: - pname = providers[user.auth_provider_id].name + pname = provider.name except KeyError: pname = "?" resp = { @@ -164,7 +193,6 @@ async def process(user, resource_id, resp): Too simple to be serious. It's a good fit for a plugin architecture for production """ - assert user is not None if resource_id == "time": resp["time"] = datetime.now().strftime("%c") elif resource_id == "bs": diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 9a789a0..86a2b6b 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -21,6 +21,7 @@ class Resource(BaseModel): id: str name: str + url: str class AuthProviderSettings(BaseModel): @@ -39,10 +40,12 @@ class AuthProviderSettings(BaseModel): info_url: str | None = ( None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key) ) - info: dict[str, str | int] | None = None # Info fetched from info_url, eg. public key public_key: str | None = None + public_key_url: str | None = None signature_alg: str = "RS256" resource_provider_scopes: list[str] = [] + session_key: str = "sid" + skip_verify_signature: bool = True @computed_field @property @@ -64,17 +67,6 @@ class AuthProviderSettings(BaseModel): else: return None - def get_public_key(self) -> str: - """Return the public key formatted for decoding token""" - public_key = self.public_key or (self.info is not None and self.info["public_key"]) - if public_key is None: - raise AttributeError(f"Cannot get public key for {self.name}") - return f""" - -----BEGIN PUBLIC KEY----- - {public_key} - -----END PUBLIC KEY----- - """ - class ResourceProvider(BaseModel): id: str diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 8e8ad59..6c9fae4 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -55,15 +55,24 @@ async function get_resource(id, token, authProvider) { msg.innerHTML = "" resourceElem.innerHTML = "" Object.entries(resource).forEach( - ([k, v]) => { + ([key, value]) => { let r = document.createElement('div') let kElem = document.createElement('div') - kElem.innerText = k + kElem.innerText = key kElem.className = "key" let vElem = document.createElement('div') - vElem.innerText = v + if (typeof value == "object") { + Object.entries(value).forEach(v => { + const ne = document.createElement('div') + ne.innerHTML = `${v[0]}: ${v[1]}` + vElem.appendChild(ne) + }) + } + else { + vElem.innerText = value + } vElem.className = "value" - if (k == "sorry") { + if (key == "sorry") { vElem.classList.add("error") } r.appendChild(kElem) diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 7275f2d..da513c9 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -5,25 +5,24 @@ with OpenID Connect and OAuth2 with different providers.

{% if not user %} -
-

Log in with:

- - {% for provider in settings.auth.providers %} - - - - - {% else %} -
There is no authentication provider defined. - Hint: check the settings.yaml file.
- {% endfor %} -
-
{{ provider.name }}
-
{{ provider.hint }} -
-
- {% endif %} - {% if user %} +
+

Log in with:

+ + {% for provider in providers.values() %} + + + + + {% else %} +
There is no authentication provider defined. + Hint: check the settings.yaml file.
+ {% endfor %} +
+
{{ provider.name }}
+
{{ provider.hint }} +
+
+ {% else %}

Hey, {{ user.name }}

{% if user.picture %} @@ -83,22 +82,22 @@
-
-
-
-
{% if resources %}

Resources for this provider:

{% endif %} +
+
+
+
- {% if settings.show_token and id_token_parsed %} + {% if show_token and id_token_parsed %}

diff --git a/src/oidc_test/templates/non_compliant_logout.html b/src/oidc_test/templates/non_compliant_logout.html index 24a96ae..56758de 100644 --- a/src/oidc_test/templates/non_compliant_logout.html +++ b/src/oidc_test/templates/non_compliant_logout.html @@ -6,12 +6,12 @@ authorisation to log in again without asking for credentials.

- This is because {{ oidc_provider.name }} does not provide "end_session_endpoint" in its metadata - (see: {{ oidc_provider._server_metadata_url }}). + This is because {{ auth_provider.name }} does not provide "end_session_endpoint" in its metadata + (see: {{ auth_provider.authlib_client._server_metadata_url }}).

You can just also go back to the application home page, but - it recommended to go to the OIDC provider's site + it recommended to go to the OIDC provider's site and log out explicitely from there.

{% endblock %}