From 2fe7536c53098826523413bf668f9e1aaf44365b Mon Sep 17 00:00:00 2001 From: phil Date: Sat, 18 Jan 2025 06:20:44 +0100 Subject: [PATCH 01/79] Remove OAuthToken from db (use authlib dict); basic OAuth2 service provider with Forgejo --- TODO | 3 ++ pyproject.toml | 1 + src/oidc_test/auth_utils.py | 13 ++++-- src/oidc_test/database.py | 14 ++++--- src/oidc_test/main.py | 68 +++++++++++++++++++++++++++---- src/oidc_test/models.py | 40 ++++-------------- src/oidc_test/settings.py | 2 +- src/oidc_test/templates/base.html | 2 +- src/oidc_test/templates/home.html | 2 + uv.lock | 11 +++++ 10 files changed, 106 insertions(+), 50 deletions(-) create mode 100644 TODO diff --git a/TODO b/TODO new file mode 100644 index 0000000..5d7e575 --- /dev/null +++ b/TODO @@ -0,0 +1,3 @@ +https://docs.authlib.org/en/latest/oauth/2/intro.html#intro-oauth2 + +https://www.keycloak.org/docs/latest/authorization_services/index.html diff --git a/pyproject.toml b/pyproject.toml index 36fed63..4509e5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "fastapi[standard]>=0.115.6", "itsdangerous>=2.2.0", "passlib[bcrypt]>=1.7.4", + "pkce>=1.0.3", "pydantic-settings>=2.7.1", "python-jose[cryptography]>=3.3.0", "requests>=2.32.3", diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 14f82cd..1f026b3 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -4,12 +4,13 @@ from datetime import datetime import logging from fastapi import HTTPException, Request, status +from authlib.oauth2.rfc6749 import OAuth2Token from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App # from authlib.oauth1.auth import OAuthToken # from authlib.oauth2.auth import OAuth2Token -from .models import OAuth2Token, User +from .models import User from .database import db from .settings import settings @@ -23,7 +24,7 @@ def get_provider(request: Request) -> StarletteOAuth2App: It can be used in Depends()""" if (oidc_provider_id := request.session.get("oidc_provider_id")) is None: raise HTTPException( - status.HTTP_500_INTERNAL_SERVER_ERROR, + status.HTTP_503_SERVICE_UNAVAILABLE, "Not logged in (no provider in session)", ) try: @@ -43,7 +44,7 @@ async def get_current_user(request: Request) -> User: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown") user = await db.get_user(user_sub) ## Check if the token is expired - if token.expires_at < datetime.timestamp(datetime.now()): + if token.is_expired(): oidc_provider = get_provider(request=request) ## Ask a new refresh token from the provider logger.info(f"Token expired for user {user.name}") @@ -117,4 +118,10 @@ def update_token(*args, **kwargs): ... +async def get_token(request: Request) -> OAuth2Token: + if (token := await db.get_token(request.session.get("token"))) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") + return token + + authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index f30d14d..4b3f529 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -1,9 +1,11 @@ -# Implement a fake in-memory database interface for demo purpose +"""Fake in-memory database interface for demo purpose""" + import logging from authlib.integrations.starlette_client.apps import StarletteOAuth2App -from .models import User, OAuth2Token, Role +from .models import User, Role +from authlib.oauth2.rfc6749 import OAuth2Token logger = logging.getLogger(__name__) @@ -37,11 +39,11 @@ class Database: async def get_user(self, sub: str) -> User: return self.users[sub] - async def add_token(self, token_dict: dict, user: User) -> None: - self.tokens[token_dict['id_token']] = OAuth2Token.from_dict(token_dict=token_dict, user=user) + async def add_token(self, token: OAuth2Token, user: User) -> None: + self.tokens[token["id_token"]] = token - async def get_token(self, name) -> OAuth2Token | None: - return self.tokens.get(name) + async def get_token(self, id_token: str) -> OAuth2Token | None: + return self.tokens.get(id_token) db = Database() diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index e4a00f2..b96763b 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -10,12 +10,15 @@ from urllib.parse import urlencode from httpx import HTTPError from fastapi import Depends, FastAPI, HTTPException, Request, status -from fastapi.responses import HTMLResponse, RedirectResponse +from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.templating import Jinja2Templates from fastapi.security import OpenIdConnect from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App -from authlib.integrations.starlette_client import OAuthError +from authlib.integrations.base_client import OAuthError +from authlib.integrations.httpx_client import AsyncOAuth2Client +from authlib.oauth2.rfc6749 import OAuth2Token +from pkce import generate_code_verifier, generate_pkce_pair from .settings import settings from .models import User @@ -25,6 +28,7 @@ from .auth_utils import ( get_current_user_or_none, get_current_user, authlib_oauth, + get_token, ) from .auth_misc import pretty_details from .database import db @@ -77,8 +81,7 @@ for provider in settings.oidc.providers: @app.get("/login/{oidc_provider_id}") async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: - """Login with the provider id, - by giving the browser a redirect to its authorize page. + """Login with the provider id, giving the browser a redirect to its authorize page. The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url with the token. """ @@ -87,9 +90,20 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: provider_: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") + if ( + code_challenge_method := _providers[oidc_provider_id].code_challenge_method + ) is not None: + client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method) + code_verifier = generate_code_verifier() + logger.debug("TODO: PKCE") + else: + code_verifier = None try: response = await provider_.authorize_redirect( - request, redirect_uri, access_type="offline" + request, + redirect_uri, + access_type="offline", + code_verifier=code_verifier, ) return response except HTTPError: @@ -106,7 +120,7 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") try: - token = await oidc_provider.authorize_access_token(request) + token: OAuth2Token = await oidc_provider.authorize_access_token(request) except OAuthError as error: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) # Remember the oidc_provider in the session @@ -166,7 +180,7 @@ async def logout( logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") - if (id_token := await db.get_token(request.session.pop("token", None))) is None: + if (token := await db.get_token(request.session.pop("token", None))) is None: logger.warn("No session in db for the token") return RedirectResponse(request.url_for("home")) logout_url = ( @@ -175,7 +189,7 @@ async def logout( + urlencode( { "post_logout_redirect_uri": post_logout_uri, - "id_token_hint": id_token.raw_id_token, + "id_token_hint": token["id_token"], "cliend_id": "oidc_local_test", } ) @@ -260,6 +274,44 @@ async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse: return HTMLResponse("

Only users with foorole or barrole can see this

") +@app.get("/introspect") +async def get_introspect( + request: Request, + provider: Annotated[StarletteOAuth2App, Depends(get_provider)], + token: Annotated[OAuth2Token, Depends(get_token)], +) -> JSONResponse: + if ( + response := await provider.get( + provider.server_metadata["introspection_endpoint"], + token=token, + ) + ).is_success: + return response.json() + else: + raise HTTPException(status_code=response.status_code, detail=response.text) + + +@app.get("/oauth2-forgejo-test") +async def get_forgejo_user_info( + request: Request, + user: Annotated[User, Depends(get_current_user)], + provider: Annotated[StarletteOAuth2App, Depends(get_provider)], + token: Annotated[OAuth2Token, Depends(get_token)], +) -> HTMLResponse: + if ( + response := await provider.get( + "/api/v1/user/repos", + # headers={"Authorization": f"token {token['access_token']}"}, + token=token, + ) + ).is_success: + repos = response.json() + names = [repo["name"] for repo in repos] + return HTMLResponse(f"{user.name} has {len(repos)} repos: {', '.join(names)}") + else: + raise HTTPException(status_code=response.status_code, detail=response.text) + + # @app.get("/fast_api_depends") # def fast_api_depends( # token: Annotated[str, Depends(fastapi_providers["Keycloak"])] diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 8c25a69..3d484aa 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -1,8 +1,16 @@ from functools import cached_property from typing import Self -from pydantic import computed_field, AnyHttpUrl, EmailStr, ConfigDict +from pydantic import ( + computed_field, + AnyHttpUrl, + EmailStr, + ConfigDict, + GetCoreSchemaHandler, +) +from pydantic_core import CoreSchema, core_schema from authlib.integrations.starlette_client.apps import StarletteOAuth2App +from authlib.oauth2.rfc6749 import OAuth2Token as OAuth2Token_authlib from sqlmodel import SQLModel, Field @@ -45,33 +53,3 @@ class User(UserBase): @cached_property def roles_as_set(self) -> set[str]: return set([role.name for role in self.roles]) - - -class OAuth2Token(SQLModel): - name: str = Field(primary_key=True) - token_type: str # = Field(max_length=40) - access_token: str # = Field(max_length=2000) - raw_id_token: str - refresh_token: str # = Field(max_length=200) - expires_at: int # = PositiveIntegerField() - user: User # = ForeignKey(User) - - def to_token(self): - return dict( - access_token=self.access_token, - token_type=self.token_type, - refresh_token=self.refresh_token, - expires_at=self.expires_at, - ) - - @classmethod - def from_dict(cls, token_dict: dict, user: User) -> Self: - return cls( - name=token_dict["access_token"], - access_token=token_dict["access_token"], - raw_id_token=token_dict["id_token"], - token_type=token_dict["token_type"], - refresh_token=token_dict["refresh_token"], - expires_at=token_dict["expires_at"], - user=user, - ) diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index d2d9c55..0f41a2e 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -20,7 +20,7 @@ class OIDCProvider(BaseModel): client_id: str client_secret: str = "" # For PKCE (not implemented yet) - # code_challenge_method: str | None = None + code_challenge_method: str | None = None hint: str = "No hint" @computed_field diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 7558cc7..30f6194 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -120,7 +120,7 @@ if (xmlHttp.readyState == 4) { elem.classList.add("hasResponseStatus") elem.classList.add("status-" + xmlHttp.status) - elem.title = "Response code: " + xmlHttp.status + elem.title = "Response code: " + xmlHttp.status + " - " + xmlHttp.statusText } } xmlHttp.open("GET", elem.href, true) // true for asynchronous diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index fcebd0f..eaf2de7 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -109,6 +109,8 @@ Auth + foorole and barrole protected content Other + OAuth2 test (forgejo user info) + Introspect token {% if user_info_details %}
diff --git a/uv.lock b/uv.lock index 60f5572..6ceb4ca 100644 --- a/uv.lock +++ b/uv.lock @@ -490,6 +490,7 @@ dependencies = [ { name = "fastapi", extra = ["standard"] }, { name = "itsdangerous" }, { name = "passlib", extra = ["bcrypt"] }, + { name = "pkce" }, { name = "pydantic-settings" }, { name = "python-jose", extra = ["cryptography"] }, { name = "requests" }, @@ -509,6 +510,7 @@ requires-dist = [ { name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" }, { name = "itsdangerous", specifier = ">=2.2.0" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, + { name = "pkce", specifier = ">=1.0.3" }, { name = "pydantic-settings", specifier = ">=2.7.1" }, { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" }, { name = "requests", specifier = ">=2.32.3" }, @@ -565,6 +567,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772 }, ] +[[package]] +name = "pkce" +version = "1.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/29/ea/ddd845c2ec21bf1e8555c782b32dc39b82f0b12764feb9f73ccbb2470f13/pkce-1.0.3.tar.gz", hash = "sha256:9775fd76d8a743d39b87df38af1cd04a58c9b5a5242d5a6350ef343d06814ab6", size = 2757 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/51/52c22ec0812d25f5bf297a01153604bfa7bfa59ed66f6cd8345beb3c2b2a/pkce-1.0.3-py3-none-any.whl", hash = "sha256:55927e24c7d403b2491ebe182b95d9dcb1807643243d47e3879fbda5aad4471d", size = 3200 }, +] + [[package]] name = "pluggy" version = "1.5.0" From b96bfa870a2b3e81dc31c2f4c5c0bbc5ffd77f00 Mon Sep 17 00:00:00 2001 From: phil Date: Sat, 18 Jan 2025 14:23:01 +0100 Subject: [PATCH 02/79] Fix token introspection link (should be 401) --- src/oidc_test/main.py | 3 ++- src/oidc_test/templates/home.html | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index b96763b..90ab910 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -281,9 +281,10 @@ async def get_introspect( token: Annotated[OAuth2Token, Depends(get_token)], ) -> JSONResponse: if ( - response := await provider.get( + response := await provider.post( provider.server_metadata["introspection_endpoint"], token=token, + data={"token": token["access_token"]}, ) ).is_success: return response.json() diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index eaf2de7..2d858a3 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -110,7 +110,7 @@ Other OAuth2 test (forgejo user info) - Introspect token + Introspect token (401 expected)
{% if user_info_details %}
From 17fabd21c976c936e33f00a547bf57997c9e9f7e Mon Sep 17 00:00:00 2001 From: phil Date: Sat, 18 Jan 2025 14:24:28 +0100 Subject: [PATCH 03/79] Cosmetic --- src/oidc_test/templates/home.html | 1 + 1 file changed, 1 insertion(+) diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 2d858a3..80e7133 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -113,6 +113,7 @@ Introspect token (401 expected)
{% if user_info_details %} +

User info

-
Now is: {{ now }}
+
Now is: {{ now.strftime("%T, %D") }}
{% endif %} {% endblock %} From 5b70d4bbea52e702510f85d8c74c3fa983278419 Mon Sep 17 00:00:00 2001 From: phil Date: Sun, 19 Jan 2025 14:54:08 +0100 Subject: [PATCH 06/79] Fix non complient logout --- src/oidc_test/main.py | 2 +- src/oidc_test/templates/non_compliant_logout.html | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index c835b7c..ee50025 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -245,7 +245,7 @@ async def non_compliant_logout( return templates.TemplateResponse( name="non_compliant_logout.html", request=request, - context={"provider": provider, "home_url": request.url_for("home")}, + context={"oidc_provider": oidc_provider, "home_url": request.url_for("home")}, ) diff --git a/src/oidc_test/templates/non_compliant_logout.html b/src/oidc_test/templates/non_compliant_logout.html index 2f5b247..24a96ae 100644 --- a/src/oidc_test/templates/non_compliant_logout.html +++ b/src/oidc_test/templates/non_compliant_logout.html @@ -6,12 +6,12 @@ authorisation to log in again without asking for credentials.

- This is because {{ provider.name }} does not provide "end_session_endpoint" in its metadata - (see: {{ provider._server_metadata_url }}). + This is because {{ oidc_provider.name }} does not provide "end_session_endpoint" in its metadata + (see: {{ oidc_provider._server_metadata_url }}).

You can just also go back to the application home page, but - it recommended to go to the provider's site + it recommended to go to the OIDC provider's site and log out explicitely from there.

{% endblock %} From 90cfdb66dd28c0e2efc5fbde4ba31e41f09e8ae7 Mon Sep 17 00:00:00 2001 From: phil Date: Sun, 19 Jan 2025 16:27:12 +0100 Subject: [PATCH 07/79] Cleanup --- src/oidc_test/main.py | 49 +++++++++++++++++++++---------- src/oidc_test/templates/home.html | 2 -- 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index ee50025..92cddf7 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -13,13 +13,15 @@ from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.templating import Jinja2Templates -from fastapi.security import OpenIdConnect from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.base_client import OAuthError -from authlib.integrations.httpx_client import AsyncOAuth2Client from authlib.oauth2.rfc6749 import OAuth2Token -from pkce import generate_code_verifier, generate_pkce_pair + +# TODO: PKCE +# from authlib.integrations.httpx_client import AsyncOAuth2Client +# from fastapi.security import OpenIdConnect +# from pkce import generate_code_verifier, generate_pkce_pair from .settings import settings, OIDCProvider from .models import User @@ -91,7 +93,12 @@ async def home( ) -> HTMLResponse: now = datetime.now() if oidc_provider and ( - (provider := providers_settings.get(oidc_provider.name)) is not None + ( + provider := providers_settings.get( + request.session.get("oidc_provider_id", "") + ) + ) + is not None ): resources = provider.resources else: @@ -127,22 +134,22 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") - if ( - code_challenge_method := providers_settings[ - oidc_provider_id - ].code_challenge_method - ) is not None: - client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method) - code_verifier = generate_code_verifier() - logger.debug("TODO: PKCE") - else: - code_verifier = None + # if ( + # code_challenge_method := providers_settings[ + # oidc_provider_id + # ].code_challenge_method + # ) is not None: + # #client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method) + # code_verifier = generate_code_verifier() + # logger.debug("TODO: PKCE") + # else: + # code_verifier = None try: response = await provider.authorize_redirect( request, redirect_uri, access_type="offline", - code_verifier=code_verifier, + code_verifier=None, ) return response except HTTPError: @@ -261,11 +268,14 @@ async def get_resource( token: Annotated[OAuth2Token, Depends(get_token)], ) -> JSONResponse: """Generic path for testing a resource provided by a provider""" + assert user is not None if oidc_provider is None: raise HTTPException( status.HTTP_406_NOT_ACCEPTABLE, detail="No such oidc provider" ) - if (provider := providers_settings.get(oidc_provider.name)) is None: + if ( + provider := providers_settings.get(request.session.get("oidc_provider_id", "")) + ) is None: raise HTTPException( status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting" ) @@ -299,18 +309,21 @@ async def public() -> HTMLResponse: async def get_protected( user: Annotated[User, Depends(get_current_user)] ) -> HTMLResponse: + assert user is not None return HTMLResponse("

Only authenticated users can see this

") @app.get("/protected-by-foorole") @hasrole("foorole") async def get_protected_by_foorole(request: Request) -> HTMLResponse: + assert request is not None return HTMLResponse("

Only users with foorole can see this

") @app.get("/protected-by-barrole") @hasrole("barrole") async def get_protected_by_barrole(request: Request) -> HTMLResponse: + assert request is not None return HTMLResponse("

Protected by barrole

") @@ -318,12 +331,14 @@ async def get_protected_by_barrole(request: Request) -> HTMLResponse: @hasrole("barrole") @hasrole("foorole") async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse: + assert request is not None return HTMLResponse("

Only users with foorole and barrole can see this

") @app.get("/protected-by-foorole-or-barrole") @hasrole(["foorole", "barrole"]) async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse: + assert request is not None return HTMLResponse("

Only users with foorole or barrole can see this

") @@ -333,6 +348,7 @@ async def get_introspect( oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], token: Annotated[OAuth2Token, Depends(get_token)], ) -> JSONResponse: + assert request is not None if ( response := await oidc_provider.post( oidc_provider.server_metadata["introspection_endpoint"], @@ -352,6 +368,7 @@ async def get_forgejo_user_info( oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], token: Annotated[OAuth2Token, Depends(get_token)], ) -> HTMLResponse: + assert request is not None if ( response := await oidc_provider.get( "/api/v1/user/repos", diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index ab4dc77..ac732e4 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -58,8 +58,6 @@ Auth + barrole protected content Auth + foorole and barrole protected content - Other - OAuth2 test (forgejo user info) Introspect token (401 expected) {% if resources %} From 5f2901d55896d9a191e58c7ca161c2304a684daa Mon Sep 17 00:00:00 2001 From: phil Date: Sun, 19 Jan 2025 16:45:21 +0100 Subject: [PATCH 08/79] Cleanup --- src/oidc_test/main.py | 41 +++++++++++------------------------------ 1 file changed, 11 insertions(+), 30 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 92cddf7..ed4c6a1 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -249,6 +249,8 @@ async def non_compliant_logout( oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], ): """A page for non-compliant OAuth2 servers that we cannot log out.""" + # Clear the remain of the session + request.session.pop("oidc_provider_id", None) return templates.TemplateResponse( name="non_compliant_logout.html", request=request, @@ -268,7 +270,7 @@ async def get_resource( token: Annotated[OAuth2Token, Depends(get_token)], ) -> JSONResponse: """Generic path for testing a resource provided by a provider""" - assert user is not None + assert user is not None # Just to keep QA checks happy if oidc_provider is None: raise HTTPException( status.HTTP_406_NOT_ACCEPTABLE, detail="No such oidc provider" @@ -309,21 +311,21 @@ async def public() -> HTMLResponse: async def get_protected( user: Annotated[User, Depends(get_current_user)] ) -> HTMLResponse: - assert user is not None + assert user is not None # Just to keep QA checks happy return HTMLResponse("

Only authenticated users can see this

") @app.get("/protected-by-foorole") @hasrole("foorole") async def get_protected_by_foorole(request: Request) -> HTMLResponse: - assert request is not None + assert request is not None # Just to keep QA checks happy return HTMLResponse("

Only users with foorole can see this

") @app.get("/protected-by-barrole") @hasrole("barrole") async def get_protected_by_barrole(request: Request) -> HTMLResponse: - assert request is not None + assert request is not None # Just to keep QA checks happy return HTMLResponse("

Protected by barrole

") @@ -331,14 +333,14 @@ async def get_protected_by_barrole(request: Request) -> HTMLResponse: @hasrole("barrole") @hasrole("foorole") async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse: - assert request is not None + assert request is not None # Just to keep QA checks happy return HTMLResponse("

Only users with foorole and barrole can see this

") @app.get("/protected-by-foorole-or-barrole") @hasrole(["foorole", "barrole"]) async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse: - assert request is not None + assert request is not None # Just to keep QA checks happy return HTMLResponse("

Only users with foorole or barrole can see this

") @@ -348,7 +350,7 @@ async def get_introspect( oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], token: Annotated[OAuth2Token, Depends(get_token)], ) -> JSONResponse: - assert request is not None + assert request is not None # Just to keep QA checks happy if ( response := await oidc_provider.post( oidc_provider.server_metadata["introspection_endpoint"], @@ -361,31 +363,10 @@ async def get_introspect( raise HTTPException(status_code=response.status_code, detail=response.text) -@app.get("/oauth2-forgejo-test") -async def get_forgejo_user_info( - request: Request, - user: Annotated[User, Depends(get_current_user)], - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], - token: Annotated[OAuth2Token, Depends(get_token)], -) -> HTMLResponse: - assert request is not None - if ( - response := await oidc_provider.get( - "/api/v1/user/repos", - # headers={"Authorization": f"token {token['access_token']}"}, - token=token, - ) - ).is_success: - repos = response.json() - names = [repo["name"] for repo in repos] - return HTMLResponse(f"{user.name} has {len(repos)} repos: {', '.join(names)}") - else: - raise HTTPException(status_code=response.status_code, detail=response.text) - - # Snippet for running standalone # Mostly useful for the --version option, -# as running with uvicorn is easy and provides flaxibility +# as running with uvicorn is easy and provides better flexibility, eg. +# uvicorn --host foo oidc_test.main:app --reload def main(): From 572d2a7b0d0973ec7fa56efdbc68cb9838dad74e Mon Sep 17 00:00:00 2001 From: phil Date: Mon, 20 Jan 2025 01:16:17 +0100 Subject: [PATCH 09/79] Cleanup --- TODO | 2 ++ src/oidc_test/auth_utils.py | 3 +-- src/oidc_test/main.py | 11 ++++++++--- src/oidc_test/models.py | 6 +----- src/oidc_test/settings.py | 6 ++++++ 5 files changed, 18 insertions(+), 10 deletions(-) diff --git a/TODO b/TODO index 5d7e575..93e80a9 100644 --- a/TODO +++ b/TODO @@ -1,3 +1,5 @@ https://docs.authlib.org/en/latest/oauth/2/intro.html#intro-oauth2 https://www.keycloak.org/docs/latest/authorization_services/index.html + +https://thinhdanggroup.github.io/oauth2-python/ diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 880c111..0e5156b 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -1,6 +1,5 @@ from typing import Union from functools import wraps -from datetime import datetime import logging from fastapi import HTTPException, Request, status @@ -54,7 +53,7 @@ async def get_current_user(request: Request) -> User: logger.info(f"Token expired for user {user.name}") try: userinfo = await oidc_provider.fetch_access_token( - refresh_token=token.refresh_token + refresh_token=token.get("refresh_token") ) except OAuthError as err: logger.exception(err) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index ed4c6a1..0f49dfa 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -66,7 +66,7 @@ for provider in settings.oidc.providers: name=provider.id, server_metadata_url=provider.openid_configuration, client_kwargs={ - "scope": "openid email", # offline_access profile", + "scope": "openid email offline_access profile", }, client_id=provider.client_id, client_secret=provider.client_secret, @@ -170,7 +170,7 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: except OAuthError as error: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) # Remember the oidc_provider in the session - # logger.debug(f"Scope: {token['scope']}") + # logger.info(f"Scope: {token['scope']}") request.session["oidc_provider_id"] = oidc_provider_id # # One could process the full decoded token which contains extra information @@ -351,9 +351,14 @@ async def get_introspect( token: Annotated[OAuth2Token, Depends(get_token)], ) -> JSONResponse: assert request is not None # Just to keep QA checks happy + if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="No intrispection endpoint found for the OIDC provider", + ) if ( response := await oidc_provider.post( - oidc_provider.server_metadata["introspection_endpoint"], + url, token=token, data={"token": token["access_token"]}, ) diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 3d484aa..b6a267e 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -6,11 +6,8 @@ from pydantic import ( AnyHttpUrl, EmailStr, ConfigDict, - GetCoreSchemaHandler, ) -from pydantic_core import CoreSchema, core_schema from authlib.integrations.starlette_client.apps import StarletteOAuth2App -from authlib.oauth2.rfc6749 import OAuth2Token as OAuth2Token_authlib from sqlmodel import SQLModel, Field @@ -19,10 +16,9 @@ class Role(SQLModel, extra="ignore"): class UserBase(SQLModel, extra="ignore"): - id: str | None = None sid: str | None = None - name: str + name: str | None = None email: EmailStr | None = None picture: AnyHttpUrl | None = None roles: list[Role] = [] diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 3a9447c..ef481c8 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -45,6 +45,12 @@ class OIDCProvider(BaseModel): return "auth/" + self.id +class ResourceProvider(BaseModel): + id: str + name: str + resources: list[Resource] = [] + + class OIDCSettings(BaseModel): show_session_details: bool = False providers: list[OIDCProvider] = [] From 7ab715da5a62721cbd5deae4c1cbd263706bcfe6 Mon Sep 17 00:00:00 2001 From: phil Date: Mon, 20 Jan 2025 04:35:33 +0100 Subject: [PATCH 10/79] Add resource provider settings --- src/oidc_test/settings.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index ef481c8..048aa57 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -4,7 +4,7 @@ import random from typing import Type, Tuple from pathlib import Path -from pydantic import BaseModel, computed_field +from pydantic import BaseModel, computed_field, AnyUrl from pydantic_settings import ( BaseSettings, SettingsConfigDict, @@ -18,7 +18,6 @@ class Resource(BaseModel): id: str name: str - url: str class OIDCProvider(BaseModel): @@ -48,6 +47,7 @@ class OIDCProvider(BaseModel): class ResourceProvider(BaseModel): id: str name: str + base_url: AnyUrl resources: list[Resource] = [] @@ -61,6 +61,7 @@ class Settings(BaseSettings): """Settings wil be read from an .env file""" oidc: OIDCSettings = OIDCSettings() + resource_providers: list[ResourceProvider] = [] secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False From dc93c7c05b0682f6c744c5e3142726638a9a6f2a Mon Sep 17 00:00:00 2001 From: phil Date: Sun, 26 Jan 2025 19:08:49 +0100 Subject: [PATCH 11/79] Add user self-care link & setting for supporting providers --- src/oidc_test/main.py | 20 ++++++++++++++++++-- src/oidc_test/settings.py | 9 +++++++++ src/oidc_test/templates/base.html | 2 +- src/oidc_test/templates/home.html | 7 +++++-- 4 files changed, 33 insertions(+), 5 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 0f49dfa..83ee101 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -94,15 +94,16 @@ async def home( now = datetime.now() if oidc_provider and ( ( - provider := providers_settings.get( + oidc_provider_settings := providers_settings.get( request.session.get("oidc_provider_id", "") ) ) is not None ): - resources = provider.resources + resources = oidc_provider_settings.resources else: resources = [] + oidc_provider_settings = None return templates.TemplateResponse( name="home.html", request=request, @@ -110,6 +111,7 @@ async def home( "settings": settings.model_dump(), "user": user, "now": now, + "oidc_provider_settings": oidc_provider_settings, "resources": resources, "user_info_details": ( pretty_details(user, now) @@ -212,6 +214,20 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: ) +@app.get("/account") +async def account( + request: Request, + oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], +) -> RedirectResponse: + if ( + provider := providers_settings.get(request.session.get("oidc_provider_id", "")) + ) is None: + raise HTTPException( + status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting" + ) + return RedirectResponse(f"{provider.url}/account") + + @app.get("/logout") async def logout( request: Request, diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 048aa57..38443fa 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -32,6 +32,7 @@ class OIDCProvider(BaseModel): code_challenge_method: str | None = None hint: str = "No hint" resources: list[Resource] = [] + account_url_suffix: str | None = None @computed_field @property @@ -43,6 +44,14 @@ class OIDCProvider(BaseModel): def token_url(self) -> str: return "auth/" + self.id + @computed_field + @property + def account_url(self) -> str | None: + if self.account_url_suffix: + return self.url + self.account_url_suffix + else: + return None + class ResourceProvider(BaseModel): id: str diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index d2aa44b..3ff5f65 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -1,6 +1,6 @@ - FastAPI OIDC test + OIDC (FastAPI) test diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index ac732e4..f103be9 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -42,7 +42,10 @@ Provider: {{ user.oidc_provider.name }} - Logout + {% if oidc_provider_settings.account_url %} + + {% endif %} + {% endif %}
@@ -66,7 +69,7 @@

{% endif %} From 5b6c6f1aacc024d8eaa1ece1c0e7bfbcd97b5d08 Mon Sep 17 00:00:00 2001 From: phil Date: Sun, 26 Jan 2025 23:37:56 +0100 Subject: [PATCH 12/79] Fix account url, use template for settings --- src/oidc_test/main.py | 19 ++++++++++++------- src/oidc_test/settings.py | 23 +++++++++++++++++------ src/oidc_test/templates/base.html | 2 +- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 83ee101..36d9d76 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -59,7 +59,7 @@ app.add_middleware( # Add oidc providers to authlib from the settings # fastapi_providers: dict[str, OpenIdConnect] = {} -providers_settings: dict[str, OIDCProvider] = {} +oidc_providers_settings: dict[str, OIDCProvider] = {} for provider in settings.oidc.providers: authlib_oauth.register( @@ -80,7 +80,7 @@ for provider in settings.oidc.providers: # fastapi_providers[provider.id] = OpenIdConnect( # openIdConnectUrl=provider.openid_configuration # ) - providers_settings[provider.id] = provider + oidc_providers_settings[provider.id] = provider @app.get("/") @@ -94,7 +94,7 @@ async def home( now = datetime.now() if oidc_provider and ( ( - oidc_provider_settings := providers_settings.get( + oidc_provider_settings := oidc_providers_settings.get( request.session.get("oidc_provider_id", "") ) ) @@ -111,6 +111,7 @@ async def home( "settings": settings.model_dump(), "user": user, "now": now, + "oidc_provider": oidc_provider, "oidc_provider_settings": oidc_provider_settings, "resources": resources, "user_info_details": ( @@ -137,7 +138,7 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") # if ( - # code_challenge_method := providers_settings[ + # code_challenge_method := oidc_providers_settings[ # oidc_provider_id # ].code_challenge_method # ) is not None: @@ -220,12 +221,14 @@ async def account( oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], ) -> RedirectResponse: if ( - provider := providers_settings.get(request.session.get("oidc_provider_id", "")) + provider := oidc_providers_settings.get( + request.session.get("oidc_provider_id", "") + ) ) is None: raise HTTPException( status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting" ) - return RedirectResponse(f"{provider.url}/account") + return RedirectResponse(f"{provider.account_url}") @app.get("/logout") @@ -292,7 +295,9 @@ async def get_resource( status.HTTP_406_NOT_ACCEPTABLE, detail="No such oidc provider" ) if ( - provider := providers_settings.get(request.session.get("oidc_provider_id", "")) + provider := oidc_providers_settings.get( + request.session.get("oidc_provider_id", "") + ) ) is None: raise HTTPException( status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting" diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 38443fa..81d5099 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -11,6 +11,9 @@ from pydantic_settings import ( PydanticBaseSettingsSource, YamlConfigSettingsSource, ) +from starlette.requests import Request + +from .models import User class Resource(BaseModel): @@ -32,7 +35,7 @@ class OIDCProvider(BaseModel): code_challenge_method: str | None = None hint: str = "No hint" resources: list[Resource] = [] - account_url_suffix: str | None = None + account_url_template: str | None = None @computed_field @property @@ -44,11 +47,19 @@ class OIDCProvider(BaseModel): def token_url(self) -> str: return "auth/" + self.id - @computed_field - @property - def account_url(self) -> str | None: - if self.account_url_suffix: - return self.url + self.account_url_suffix + def get_account_url(self, request: Request, user: User) -> str | None: + if self.account_url_template: + if not ( + self.url.endswith("/") or self.account_url_template.startswith("/") + ): + sep = "/" + else: + sep = "" + return ( + self.url + + sep + + self.account_url_template.format(request=request, user=user) + ) else: return None diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 3ff5f65..3bdb3f3 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -5,7 +5,7 @@ -

OIDC-test

+

OIDC-test - FastAPI client

{% block content %} {% endblock %} From 61be70054b9e9441e8b70fea1e4ec13fb5d6ce5b Mon Sep 17 00:00:00 2001 From: phil Date: Sun, 26 Jan 2025 23:42:55 +0100 Subject: [PATCH 13/79] Fix account url, use template for settings --- src/oidc_test/templates/home.html | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index f103be9..c062101 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -40,10 +40,10 @@ {% endif %}
Provider: - {{ user.oidc_provider.name }} + {{ oidc_provider_settings.name }}
- {% if oidc_provider_settings.account_url %} - + {% if oidc_provider_settings.account_url_template %} + {% endif %} From 5b31ef888c3380b7d35e75a9949a6357c0744d02 Mon Sep 17 00:00:00 2001 From: phil Date: Tue, 28 Jan 2025 19:48:35 +0100 Subject: [PATCH 14/79] Add resource provider --- pyproject.toml | 2 + src/oidc_test/auth_utils.py | 100 ++++++++++++++++++++++++------- src/oidc_test/database.py | 5 ++ src/oidc_test/main.py | 84 ++++++++------------------ src/oidc_test/resource_server.py | 21 +++++++ src/oidc_test/settings.py | 10 ++++ uv.lock | 13 ++++ 7 files changed, 152 insertions(+), 83 deletions(-) create mode 100644 src/oidc_test/resource_server.py diff --git a/pyproject.toml b/pyproject.toml index 4509e5b..980bcfc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,10 +9,12 @@ dependencies = [ "authlib>=1.4.0", "cachetools>=5.5.0", "fastapi[standard]>=0.115.6", + "httpx>=0.28.1", "itsdangerous>=2.2.0", "passlib[bcrypt]>=1.7.4", "pkce>=1.0.3", "pydantic-settings>=2.7.1", + "pyjwt>=2.10.1", "python-jose[cryptography]>=3.3.0", "requests>=2.32.3", "sqlmodel>=0.0.22", diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 0e5156b..2b3d0fd 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -1,22 +1,63 @@ -from typing import Union +from typing import Union, Annotated, Tuple from functools import wraps import logging -from fastapi import HTTPException, Request, status +from fastapi import HTTPException, Request, Depends, status +from fastapi.security import OAuth2PasswordBearer from authlib.oauth2.rfc6749 import OAuth2Token from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App +from jwt import decode # from authlib.oauth1.auth import OAuthToken # from authlib.oauth2.auth import OAuth2Token from .models import User -from .database import db -from .settings import settings +from .database import db, UserNotInDB +from .settings import settings, OIDCProvider logger = logging.getLogger(__name__) -OIDC_PROVIDERS = set([provider.id for provider in settings.oidc.providers]) +oidc_providers_settings: dict[str, OIDCProvider] = dict([(provider.id, provider) for provider in settings.oidc.providers]) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +def fetch_token(name, request): + breakpoint() + ... + # if name in oidc_providers: + # model = OAuth2Token + # else: + # model = OAuthToken + + # token = model.find(name=name, user=request.user) + # return token.to_token() + + +def update_token(*args, **kwargs): + breakpoint() + ... + + +authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) + + +# Add oidc providers to authlib from the settings +for id, provider in oidc_providers_settings.items(): + authlib_oauth.register( + name=id, + server_metadata_url=provider.openid_configuration, + client_kwargs={ + "scope": "openid email offline_access profile", + }, + client_id=provider.client_id, + client_secret=provider.client_secret, + api_base_url=provider.url, + # For PKCE (not implemented yet): + # code_challenge_method="S256", + # fetch_token=fetch_token, + # update_token=update_token, + # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) + ) def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: """Return the oidc_provider from a request object, from the session. @@ -115,21 +156,34 @@ def get_token_info(token: dict) -> dict: return token_info -def fetch_token(name, request): - breakpoint() - ... - # if name in OIDC_PROVIDERS: - # model = OAuth2Token - # else: - # model = OAuthToken - - # token = model.find(name=name, user=request.user) - # return token.to_token() - - -def update_token(*args, **kwargs): - breakpoint() - ... - - -authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) +async def get_resource_user( + token: Annotated[str, Depends(oauth2_scheme)], + request: Request, +) -> User: + # TODO: decode token (ah!) + # See https://fastapi.tiangolo.com/tutorial/security/oauth2-jwt/#hash-and-verify-the-passwords + if (auth_provider_id := request.headers.get("auth_provider")) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field") + if (auth_provider := oidc_providers_settings.get(auth_provider_id)) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'") + if (key := auth_provider.get_key()) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown") + try: + payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test") + except Exception as err: + logger.info("Cannot decode token, see below") + logger.exception(err) + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token") + if (user_id := payload.get('sub')) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found") + try: + user = await db.get_user(user_id) + except UserNotInDB: + logger.info(f"User {user_id} not found in DB, creating it (real apps can behave differently") + user = await db.add_user( + sub=payload['sub'], + user_info=payload, + oidc_provider=getattr(authlib_oauth, auth_provider_id), + user_info_from_endpoint={} + ) + return user diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 4b3f529..9d72081 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -10,6 +10,9 @@ from authlib.oauth2.rfc6749 import OAuth2Token logger = logging.getLogger(__name__) +class UserNotInDB(Exception): + pass + class Database: users: dict[str, User] = {} tokens: dict[str, OAuth2Token] = {} @@ -37,6 +40,8 @@ class Database: return user async def get_user(self, sub: str) -> User: + if sub not in self.users: + raise UserNotInDB return self.users[sub] async def add_token(self, token: OAuth2Token, user: User) -> None: diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 36d9d76..351cc2f 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -13,6 +13,7 @@ from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.templating import Jinja2Templates +from fastapi.middleware.cors import CORSMiddleware from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.base_client import OAuthError @@ -23,7 +24,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair -from .settings import settings, OIDCProvider +from .settings import settings from .models import User from .auth_utils import ( get_oidc_provider, @@ -31,21 +32,37 @@ from .auth_utils import ( hasrole, get_current_user_or_none, get_current_user, + get_resource_user, authlib_oauth, get_token, + oidc_providers_settings, ) from .auth_misc import pretty_details from .database import db +from .resource_server import get_resource logger = logging.getLogger("uvicorn.error") templates = Jinja2Templates(Path(__file__).parent / "templates") +origins = [ + "https://tiptop:3002", + "https://philo.ydns.eu/", +] + app = FastAPI( title="OIDC auth test", ) +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + app.mount( "/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static" ) @@ -56,32 +73,6 @@ app.add_middleware( secret_key=settings.secret_key, ) -# Add oidc providers to authlib from the settings - -# fastapi_providers: dict[str, OpenIdConnect] = {} -oidc_providers_settings: dict[str, OIDCProvider] = {} - -for provider in settings.oidc.providers: - authlib_oauth.register( - name=provider.id, - server_metadata_url=provider.openid_configuration, - client_kwargs={ - "scope": "openid email offline_access profile", - }, - client_id=provider.client_id, - client_secret=provider.client_secret, - api_base_url=provider.url, - # For PKCE (not implemented yet): - # code_challenge_method="S256", - # fetch_token=fetch_token, - # update_token=update_token, - # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) - ) - # fastapi_providers[provider.id] = OpenIdConnect( - # openIdConnectUrl=provider.openid_configuration - # ) - oidc_providers_settings[provider.id] = provider - @app.get("/") async def home( @@ -281,43 +272,16 @@ async def non_compliant_logout( @app.get("/resource/{id}") -async def get_resource( +async def get_resource_( id: str, request: Request, - user: Annotated[User, Depends(get_current_user)], - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], - token: Annotated[OAuth2Token, Depends(get_token)], + # user: Annotated[User, Depends(get_current_user)], + # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + # token: Annotated[OAuth2Token, Depends(get_token)], + user: Annotated[User, Depends(get_resource_user)], ) -> JSONResponse: """Generic path for testing a resource provided by a provider""" - assert user is not None # Just to keep QA checks happy - if oidc_provider is None: - raise HTTPException( - status.HTTP_406_NOT_ACCEPTABLE, detail="No such oidc provider" - ) - if ( - provider := oidc_providers_settings.get( - request.session.get("oidc_provider_id", "") - ) - ) is None: - raise HTTPException( - status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting" - ) - try: - resource = next(x for x in provider.resources if x.id == id) - except StopIteration: - raise HTTPException( - status.HTTP_406_NOT_ACCEPTABLE, detail="No such resource for this provider" - ) - if ( - response := await oidc_provider.get( - resource.url, - # headers={"Authorization": f"token {token['access_token']}"}, - token=token, - ) - ).is_success: - return JSONResponse(response.json()) - else: - raise HTTPException(status_code=response.status_code, detail=response.text) + return JSONResponse(await get_resource(id, user)) # Routes for test diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py new file mode 100644 index 0000000..a9dfe3a --- /dev/null +++ b/src/oidc_test/resource_server.py @@ -0,0 +1,21 @@ +from datetime import datetime +from httpx import AsyncClient + +from .models import User + +async def get_resource(id: str, user: User) -> dict: + pname = getattr(user.oidc_provider, "name", "?") + resp = { + "hello" : f"Hi {user.name} from an OAuth resource provider.", + "comment": f"I received a request for '{id}' with an access token signed by {pname}." + } + if id == "time": + resp["time"] = datetime.now().strftime("%c") + elif id == "bs": + async with AsyncClient() as client: + bs = await client.get("https://corporatebs-generator.sameerkumar.website/") + resp['bs'] = bs.json().get("phrase", "Sorry, i am out of BS today.") + else: + resp['sorry'] = f"I don't known how to give '{id}' but i know corporate bs." + + return resp diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 81d5099..c511f86 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -36,6 +36,7 @@ class OIDCProvider(BaseModel): hint: str = "No hint" resources: list[Resource] = [] account_url_template: str | None = None + key: str | None = None @computed_field @property @@ -63,6 +64,15 @@ class OIDCProvider(BaseModel): else: return None + def get_key(self) -> str | None: + """Return the public key formatted for """ + if self.key is None: + return None + return f""" + -----BEGIN PUBLIC KEY----- + {self.key} + -----END PUBLIC KEY----- + """ class ResourceProvider(BaseModel): id: str diff --git a/uv.lock b/uv.lock index 6ceb4ca..01b64de 100644 --- a/uv.lock +++ b/uv.lock @@ -488,10 +488,12 @@ dependencies = [ { name = "authlib" }, { name = "cachetools" }, { name = "fastapi", extra = ["standard"] }, + { name = "httpx" }, { name = "itsdangerous" }, { name = "passlib", extra = ["bcrypt"] }, { name = "pkce" }, { name = "pydantic-settings" }, + { name = "pyjwt" }, { name = "python-jose", extra = ["cryptography"] }, { name = "requests" }, { name = "sqlmodel" }, @@ -508,10 +510,12 @@ requires-dist = [ { name = "authlib", specifier = ">=1.4.0" }, { name = "cachetools", specifier = ">=5.5.0" }, { name = "fastapi", extras = ["standard"], specifier = ">=0.115.6" }, + { name = "httpx", specifier = ">=0.28.1" }, { name = "itsdangerous", specifier = ">=2.2.0" }, { name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.4" }, { name = "pkce", specifier = ">=1.0.3" }, { name = "pydantic-settings", specifier = ">=2.7.1" }, + { name = "pyjwt", specifier = ">=2.10.1" }, { name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" }, { name = "requests", specifier = ">=2.32.3" }, { name = "sqlmodel", specifier = ">=0.0.22" }, @@ -694,6 +698,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/3f/01c8b82017c199075f8f788d0d906b9ffbbc5a47dc9918a945e13d5a2bda/pygments-2.18.0-py3-none-any.whl", hash = "sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a", size = 1205513 }, ] +[[package]] +name = "pyjwt" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997 }, +] + [[package]] name = "pytest" version = "8.3.4" From f9108347369a24dd2998323d140c5fa4dcab4a88 Mon Sep 17 00:00:00 2001 From: phil Date: Wed, 29 Jan 2025 14:03:33 +0100 Subject: [PATCH 15/79] Fetch provider info at boot time: get public key from there instead of in settings --- src/oidc_test/auth_utils.py | 63 ++++++++++++++++++++++++------------- src/oidc_test/main.py | 14 +++++++-- src/oidc_test/settings.py | 13 +++++--- 3 files changed, 61 insertions(+), 29 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 2b3d0fd..d64b6cf 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -6,7 +6,8 @@ from fastapi import HTTPException, Request, Depends, status from fastapi.security import OAuth2PasswordBearer from authlib.oauth2.rfc6749 import OAuth2Token from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App -from jwt import decode +from jwt import ExpiredSignatureError, InvalidKeyError, decode +from httpx import AsyncClient # from authlib.oauth1.auth import OAuthToken # from authlib.oauth2.auth import OAuth2Token @@ -41,23 +42,35 @@ def update_token(*args, **kwargs): authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) +def init_providers(): # Add oidc providers to authlib from the settings -for id, provider in oidc_providers_settings.items(): - authlib_oauth.register( - name=id, - server_metadata_url=provider.openid_configuration, - client_kwargs={ - "scope": "openid email offline_access profile", - }, - client_id=provider.client_id, - client_secret=provider.client_secret, - api_base_url=provider.url, - # For PKCE (not implemented yet): - # code_challenge_method="S256", - # fetch_token=fetch_token, - # update_token=update_token, - # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) - ) + for id, provider in oidc_providers_settings.items(): + authlib_oauth.register( + name=id, + server_metadata_url=provider.openid_configuration, + client_kwargs={ + "scope": "openid email offline_access profile", + }, + client_id=provider.client_id, + client_secret=provider.client_secret, + api_base_url=provider.url, + # For PKCE (not implemented yet): + # code_challenge_method="S256", + # fetch_token=fetch_token, + # update_token=update_token, + # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) + ) + +init_providers() + +async def get_providers_info(): + # Get the public key: + async with AsyncClient() as client: + for provider_settings in oidc_providers_settings.values(): + if provider_settings.info_url: + provider_info = await client.get(provider_settings.url) + provider_settings.info = provider_info.json() + def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: """Return the oidc_provider from a request object, from the session. @@ -156,20 +169,26 @@ def get_token_info(token: dict) -> dict: return token_info -async def get_resource_user( +async def get_user_from_token( token: Annotated[str, Depends(oauth2_scheme)], request: Request, ) -> User: - # TODO: decode token (ah!) - # See https://fastapi.tiangolo.com/tutorial/security/oauth2-jwt/#hash-and-verify-the-passwords if (auth_provider_id := request.headers.get("auth_provider")) is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field") - if (auth_provider := oidc_providers_settings.get(auth_provider_id)) is None: + if (auth_provider_settings := oidc_providers_settings.get(auth_provider_id)) is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'") - if (key := auth_provider.get_key()) is None: + oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) + await oidc_provider.load_server_metadata() + if (key := auth_provider_settings.get_public_key()) is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown") try: payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test") + except ExpiredSignatureError as err: + logger.info(f"Expired signature: {err}") + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Expired signature (refresh not implemented yet)") + except InvalidKeyError as err: + logger.info(f"Invalid key: {err}") + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key") except Exception as err: logger.info("Cannot decode token, see below") logger.exception(err) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 351cc2f..6942759 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -7,6 +7,7 @@ from pathlib import Path from datetime import datetime import logging from urllib.parse import urlencode +from contextlib import asynccontextmanager from httpx import HTTPError from fastapi import Depends, FastAPI, HTTPException, Request, status @@ -32,10 +33,11 @@ from .auth_utils import ( hasrole, get_current_user_or_none, get_current_user, - get_resource_user, + get_user_from_token, authlib_oauth, get_token, oidc_providers_settings, + get_providers_info, ) from .auth_misc import pretty_details from .database import db @@ -51,10 +53,18 @@ origins = [ "https://philo.ydns.eu/", ] +@asynccontextmanager +async def lifespan(app: FastAPI): + await get_providers_info() + yield + + app = FastAPI( title="OIDC auth test", + lifespan=lifespan ) + app.add_middleware( CORSMiddleware, allow_origins=origins, @@ -278,7 +288,7 @@ async def get_resource_( # user: Annotated[User, Depends(get_current_user)], # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], # token: Annotated[OAuth2Token, Depends(get_token)], - user: Annotated[User, Depends(get_resource_user)], + user: Annotated[User, Depends(get_user_from_token)], ) -> JSONResponse: """Generic path for testing a resource provided by a provider""" return JSONResponse(await get_resource(id, user)) diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index c511f86..00c3f23 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -36,7 +36,9 @@ class OIDCProvider(BaseModel): hint: str = "No hint" resources: list[Resource] = [] account_url_template: str | None = None - key: str | None = None + info_url: str | None = None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key) + info: dict[str, str | int ] | None = None # Info fetched from info_url, eg. public key + public_key: str | None = None @computed_field @property @@ -64,13 +66,14 @@ class OIDCProvider(BaseModel): else: return None - def get_key(self) -> str | None: - """Return the public key formatted for """ - if self.key is None: + def get_public_key(self) -> str | None: + """Return the public key formatted for decoding token""" + public_key = self.public_key or (self.info is not None and self.info["public_key"]) + if public_key is None: return None return f""" -----BEGIN PUBLIC KEY----- - {self.key} + {public_key} -----END PUBLIC KEY----- """ From b3e19b3e4083cf9531fa9b52dc32501f441feab6 Mon Sep 17 00:00:00 2001 From: phil Date: Thu, 30 Jan 2025 20:40:04 +0100 Subject: [PATCH 16/79] Resource server: read the required scope in access token --- src/oidc_test/auth_utils.py | 58 +++++++++++++++++------- src/oidc_test/main.py | 9 ++-- src/oidc_test/resource_server.py | 75 +++++++++++++++++++++++++++----- src/oidc_test/settings.py | 24 +++++++--- 4 files changed, 129 insertions(+), 37 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index d64b6cf..2fcfc76 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -1,4 +1,4 @@ -from typing import Union, Annotated, Tuple +from typing import Union, Annotated from functools import wraps import logging @@ -18,10 +18,13 @@ from .settings import settings, OIDCProvider logger = logging.getLogger(__name__) -oidc_providers_settings: dict[str, OIDCProvider] = dict([(provider.id, provider) for provider in settings.oidc.providers]) +oidc_providers_settings: dict[str, OIDCProvider] = dict( + [(provider.id, provider) for provider in settings.oidc.providers] +) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + def fetch_token(name, request): breakpoint() ... @@ -43,7 +46,7 @@ authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_t def init_providers(): -# Add oidc providers to authlib from the settings + # Add oidc providers to authlib from the settings for id, provider in oidc_providers_settings.items(): authlib_oauth.register( name=id, @@ -61,8 +64,10 @@ def init_providers(): # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) ) + init_providers() + async def get_providers_info(): # Get the public key: async with AsyncClient() as client: @@ -174,18 +179,35 @@ async def get_user_from_token( request: Request, ) -> User: if (auth_provider_id := request.headers.get("auth_provider")) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field") - if (auth_provider_settings := oidc_providers_settings.get(auth_provider_id)) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'") - oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) - await oidc_provider.load_server_metadata() + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + "Request headers must have a 'auth_provider' field", + ) + if ( + auth_provider_settings := oidc_providers_settings.get(auth_provider_id) + ) is None: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" + ) if (key := auth_provider_settings.get_public_key()) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Key for provider '{auth_provider_id}' unknown") + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + f"Key for provider '{auth_provider_id}' unknown", + ) try: - payload = decode(token, key=key, algorithms=["RS256"], audience="oidc-test") + payload = decode( + token, + key=key, + algorithms=["RS256"], + audience="oidc-test", + options={"verify_signature": not settings.insecure.skip_verify_signature}, + ) except ExpiredSignatureError as err: logger.info(f"Expired signature: {err}") - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Expired signature (refresh not implemented yet)") + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + "Expired signature (refresh not implemented yet)", + ) except InvalidKeyError as err: logger.info(f"Invalid key: {err}") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key") @@ -193,16 +215,20 @@ async def get_user_from_token( logger.info("Cannot decode token, see below") logger.exception(err) raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token") - if (user_id := payload.get('sub')) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found") + if (user_id := payload.get("sub")) is None: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found" + ) try: user = await db.get_user(user_id) except UserNotInDB: - logger.info(f"User {user_id} not found in DB, creating it (real apps can behave differently") + logger.info( + f"User {user_id} not found in DB, creating it (real apps can behave differently" + ) user = await db.add_user( - sub=payload['sub'], + sub=payload["sub"], user_info=payload, oidc_provider=getattr(authlib_oauth, auth_provider_id), - user_info_from_endpoint={} + user_info_from_endpoint={}, ) return user diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 6942759..f6ce405 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -53,16 +53,14 @@ origins = [ "https://philo.ydns.eu/", ] + @asynccontextmanager async def lifespan(app: FastAPI): await get_providers_info() yield -app = FastAPI( - title="OIDC auth test", - lifespan=lifespan -) +app = FastAPI(title="OIDC auth test", lifespan=lifespan) app.add_middleware( @@ -284,7 +282,6 @@ async def non_compliant_logout( @app.get("/resource/{id}") async def get_resource_( id: str, - request: Request, # user: Annotated[User, Depends(get_current_user)], # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], # token: Annotated[OAuth2Token, Depends(get_token)], @@ -294,7 +291,7 @@ async def get_resource_( return JSONResponse(await get_resource(id, user)) -# Routes for test +# Routes for RBAC based tests @app.get("/public") diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index a9dfe3a..0186064 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -1,21 +1,76 @@ from datetime import datetime +import logging + from httpx import AsyncClient +from fastapi import HTTPException, status +from jwt import ExpiredSignatureError, InvalidKeyError, decode from .models import User +from .auth_utils import oidc_providers_settings +from .settings import settings + +logger = logging.getLogger(__name__) + async def get_resource(id: str, user: User) -> dict: pname = getattr(user.oidc_provider, "name", "?") resp = { - "hello" : f"Hi {user.name} from an OAuth resource provider.", - "comment": f"I received a request for '{id}' with an access token signed by {pname}." + "hello": f"Hi {user.name} from an OAuth resource provider.", + "comment": f"I received a request for '{id}' with an access token signed by {pname}.", } - if id == "time": - resp["time"] = datetime.now().strftime("%c") - elif id == "bs": - async with AsyncClient() as client: - bs = await client.get("https://corporatebs-generator.sameerkumar.website/") - resp['bs'] = bs.json().get("phrase", "Sorry, i am out of BS today.") + scope = f"get:{id}" + user_scopes = user.userinfo["scope"].split(" ") + if scope in user_scopes: + if id == "time": + resp["time"] = datetime.now().strftime("%c") + elif id == "bs": + async with AsyncClient() as client: + bs = await client.get( + "https://corporatebs-generator.sameerkumar.website/" + ) + resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.") + else: + resp["sorry"] = f"I don't known how to give '{id}' but i know corporate bs." else: - resp['sorry'] = f"I don't known how to give '{id}' but i know corporate bs." - + resp["sorry"] = ( + f"I don't serve the ressource {id} to you because" + "there is no scope {scope} in the access token," + ) return resp + + # assert user.oidc_provider is not None + ### Get some info (TODO: refactor) + # if (auth_provider_id := user.oidc_provider.name) is None: + # raise HTTPException( + # status.HTTP_401_UNAUTHORIZED, + # "Request headers must have a 'auth_provider' field", + # ) + # if ( + # auth_provider_settings := oidc_providers_settings.get(auth_provider_id) + # ) is None: + # raise HTTPException( + # status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" + # ) + # if (key := auth_provider_settings.get_public_key()) is None: + # raise HTTPException( + # status.HTTP_401_UNAUTHORIZED, + # f"Key for provider '{auth_provider_id}' unknown", + # ) + # logger.warn(f"refresh with scope {scope}") + # breakpoint() + # refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope) + ### Decode the new token + # try: + # payload = decode( + # refreshed_auth_info["access_token"], + # key=key, + # algorithms=["RS256"], + # audience="account", + # options={"verify_signature": not settings.insecure.skip_verify_signature}, + # ) + # except ExpiredSignatureError as err: + # logger.info(f"Expired signature: {err}") + # raise HTTPException( + # status.HTTP_401_UNAUTHORIZED, + # "Expired signature (refresh not implemented yet)", + # ) diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 00c3f23..399fbac 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -36,8 +36,12 @@ class OIDCProvider(BaseModel): hint: str = "No hint" resources: list[Resource] = [] account_url_template: str | None = None - info_url: str | None = None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key) - info: dict[str, str | int ] | None = None # Info fetched from info_url, eg. public key + info_url: str | None = ( + None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key) + ) + info: dict[str, str | int] | None = ( + None # Info fetched from info_url, eg. public key + ) public_key: str | None = None @computed_field @@ -68,7 +72,9 @@ class OIDCProvider(BaseModel): def get_public_key(self) -> str | None: """Return the public key formatted for decoding token""" - public_key = self.public_key or (self.info is not None and self.info["public_key"]) + public_key = self.public_key or ( + self.info is not None and self.info["public_key"] + ) if public_key is None: return None return f""" @@ -77,6 +83,7 @@ class OIDCProvider(BaseModel): -----END PUBLIC KEY----- """ + class ResourceProvider(BaseModel): id: str name: str @@ -90,15 +97,22 @@ class OIDCSettings(BaseModel): swagger_provider: str = "" +class Insecure(BaseModel): + """Warning: changing these defaults are only suitable for debugging""" + + skip_verify_signature: bool = False + + class Settings(BaseSettings): """Settings wil be read from an .env file""" + model_config = SettingsConfigDict(env_nested_delimiter="__") + oidc: OIDCSettings = OIDCSettings() resource_providers: list[ResourceProvider] = [] secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False - - model_config = SettingsConfigDict(env_nested_delimiter="__") + insecure: Insecure = Insecure() @classmethod def settings_customise_sources( From 815a4503df613e5f8f7819522a8064159665f813 Mon Sep 17 00:00:00 2001 From: phil Date: Fri, 31 Jan 2025 00:12:50 +0100 Subject: [PATCH 17/79] Add cors origins setting --- src/oidc_test/main.py | 8 +------- src/oidc_test/settings.py | 1 + 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index f6ce405..e02a627 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -48,12 +48,6 @@ logger = logging.getLogger("uvicorn.error") templates = Jinja2Templates(Path(__file__).parent / "templates") -origins = [ - "https://tiptop:3002", - "https://philo.ydns.eu/", -] - - @asynccontextmanager async def lifespan(app: FastAPI): await get_providers_info() @@ -65,7 +59,7 @@ app = FastAPI(title="OIDC auth test", lifespan=lifespan) app.add_middleware( CORSMiddleware, - allow_origins=origins, + allow_origins=settings.cors_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 399fbac..329b9c0 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -113,6 +113,7 @@ class Settings(BaseSettings): secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False insecure: Insecure = Insecure() + cors_origins: list[str] = [] @classmethod def settings_customise_sources( From f7ea132b7cf29ec3c7eccdeb390237d3ea536384 Mon Sep 17 00:00:00 2001 From: phil Date: Fri, 31 Jan 2025 11:43:11 +0100 Subject: [PATCH 18/79] Fix resource server error message --- src/oidc_test/resource_server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 0186064..34a77cc 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -33,8 +33,7 @@ async def get_resource(id: str, user: User) -> dict: resp["sorry"] = f"I don't known how to give '{id}' but i know corporate bs." else: resp["sorry"] = ( - f"I don't serve the ressource {id} to you because" - "there is no scope {scope} in the access token," + f"I don't serve the ressource {id} to you because there is no scope {scope} in the access token," ) return resp From 17bf34a8a10ac0c76eeb598cfb9713750c2b8eb8 Mon Sep 17 00:00:00 2001 From: phil Date: Sat, 1 Feb 2025 02:01:53 +0100 Subject: [PATCH 19/79] Fix error handling in resource server --- src/oidc_test/database.py | 1 + src/oidc_test/main.py | 1 - src/oidc_test/resource_server.py | 135 ++++++++++++++++++------------- 3 files changed, 78 insertions(+), 59 deletions(-) diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 9d72081..0252d28 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -13,6 +13,7 @@ logger = logging.getLogger(__name__) class UserNotInDB(Exception): pass + class Database: users: dict[str, User] = {} tokens: dict[str, OAuth2Token] = {} diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index e02a627..739dd1b 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -56,7 +56,6 @@ async def lifespan(app: FastAPI): app = FastAPI(title="OIDC auth test", lifespan=lifespan) - app.add_middleware( CORSMiddleware, allow_origins=settings.cors_origins, diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 34a77cc..67f87dd 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -2,74 +2,93 @@ from datetime import datetime import logging from httpx import AsyncClient -from fastapi import HTTPException, status -from jwt import ExpiredSignatureError, InvalidKeyError, decode from .models import User -from .auth_utils import oidc_providers_settings -from .settings import settings logger = logging.getLogger(__name__) -async def get_resource(id: str, user: User) -> dict: +async def get_resource(resource_id: str, user: User) -> dict: + """ + Resource processing: build an informative rely as a simple showcase + """ pname = getattr(user.oidc_provider, "name", "?") resp = { - "hello": f"Hi {user.name} from an OAuth resource provider.", - "comment": f"I received a request for '{id}' with an access token signed by {pname}.", + "hello": f"Hi {user.name} from an OAuth resource provider", + "comment": f"I received a request for '{resource_id}' " + + f"with an access token signed by {pname}", } - scope = f"get:{id}" - user_scopes = user.userinfo["scope"].split(" ") - if scope in user_scopes: - if id == "time": - resp["time"] = datetime.now().strftime("%c") - elif id == "bs": - async with AsyncClient() as client: - bs = await client.get( - "https://corporatebs-generator.sameerkumar.website/" - ) - resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.") + # For the demo, resource resource_id matches a scope get:resource_id, + # but this has to be refined for production + required_scope = f"get:{resource_id}" + # Check if the required scope is in the scopes allowed in userinfo + if "required_scope" in user.userinfo: + user_scopes = user.userinfo["required_scope"].split(" ") + if required_scope in user_scopes: + await process(user, required_scope, resp) else: - resp["sorry"] = f"I don't known how to give '{id}' but i know corporate bs." + ## For the showcase, giving a explanation. + ## Alternatively, raise HTTP_401_UNAUTHORIZED + resp["sorry"] = ( + f"No scope {required_scope} in the access token " + + "but it is required for accessing this resource." + ) else: - resp["sorry"] = ( - f"I don't serve the ressource {id} to you because there is no scope {scope} in the access token," - ) + resp["sorry"] = "There is no scope in id token" return resp - # assert user.oidc_provider is not None - ### Get some info (TODO: refactor) - # if (auth_provider_id := user.oidc_provider.name) is None: - # raise HTTPException( - # status.HTTP_401_UNAUTHORIZED, - # "Request headers must have a 'auth_provider' field", - # ) - # if ( - # auth_provider_settings := oidc_providers_settings.get(auth_provider_id) - # ) is None: - # raise HTTPException( - # status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" - # ) - # if (key := auth_provider_settings.get_public_key()) is None: - # raise HTTPException( - # status.HTTP_401_UNAUTHORIZED, - # f"Key for provider '{auth_provider_id}' unknown", - # ) - # logger.warn(f"refresh with scope {scope}") - # breakpoint() - # refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope) - ### Decode the new token - # try: - # payload = decode( - # refreshed_auth_info["access_token"], - # key=key, - # algorithms=["RS256"], - # audience="account", - # options={"verify_signature": not settings.insecure.skip_verify_signature}, - # ) - # except ExpiredSignatureError as err: - # logger.info(f"Expired signature: {err}") - # raise HTTPException( - # status.HTTP_401_UNAUTHORIZED, - # "Expired signature (refresh not implemented yet)", - # ) + +async def process(user, resource_id, resp): + """ + Too simple to be serious. + It's a good fit for a plugin architecture for production + """ + assert user is not None + if resource_id == "time": + resp["time"] = datetime.now().strftime("%c") + elif resource_id == "bs": + async with AsyncClient() as client: + bs = await client.get("https://corporatebs-generator.sameerkumar.website/") + resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.") + else: + resp["sorry"] = ( + f"I don't known how to give '{resource_id}' but i know corporate bs." + ) + + +# assert user.oidc_provider is not None +### Get some info (TODO: refactor) +# if (auth_provider_id := user.oidc_provider.name) is None: +# raise HTTPException( +# status.HTTP_401_UNAUTHORIZED, +# "Request headers must have a 'auth_provider' field", +# ) +# if ( +# auth_provider_settings := oidc_providers_settings.get(auth_provider_id) +# ) is None: +# raise HTTPException( +# status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" +# ) +# if (key := auth_provider_settings.get_public_key()) is None: +# raise HTTPException( +# status.HTTP_401_UNAUTHORIZED, +# f"Key for provider '{auth_provider_id}' unknown", +# ) +# logger.warn(f"refresh with scope {scope}") +# breakpoint() +# refreshed_auth_info = await user.oidc_provider.fetch_access_token(scope=scope) +### Decode the new token +# try: +# payload = decode( +# refreshed_auth_info["access_token"], +# key=key, +# algorithms=["RS256"], +# audience="account", +# options={"verify_signature": not settings.insecure.skip_verify_signature}, +# ) +# except ExpiredSignatureError as err: +# logger.info(f"Expired signature: {err}") +# raise HTTPException( +# status.HTTP_401_UNAUTHORIZED, +# "Expired signature (refresh not implemented yet)", +# ) From e90a1cc920d4296589f88fa8e5a0556660339fb3 Mon Sep 17 00:00:00 2001 From: phil Date: Sat, 1 Feb 2025 02:16:40 +0100 Subject: [PATCH 20/79] Update README --- README.md | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 0dc11be..9e00474 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,12 @@ as a template for integration in other FastAPI/SQLModel applications. Feedback welcome. +## Resource server + +It also functions as a resource server in a OAuth architecture. +See a sibling test project, a web based OIDC/OAuth: +[oidc-vue-test](https://code.philo.ydns.eu/philorg/oidc-vue-test). + ## RBAC The application is also a playground for RBAC (Role Based Access control) @@ -45,7 +51,7 @@ given by the OIDC providers. For example: -```text +```yaml oidc: secret_key: "ASecretNoOneKnows" show_session_details: yes @@ -60,6 +66,7 @@ oidc: - id: keycloak name: Keycloak at somewhere url: "https://" + account_url_template: "/account" client_id: "" client_secret: "client_secret_generated_by_keycloak" hint: "User: foo, password: foofoo" @@ -67,14 +74,27 @@ oidc: - id: codeberg name: Codeberg url: "https://codeberg.org" + account_url_template: "/user/settings" client_id: "" client_secret: "client_secret_generated_by_codeberg" + resources: + - name: List of repos + id: repos + url: /api/v1/user/repos + - name: List of OAuth2 applications + id: oauth2_applications + url: /api/v1/user/applications/oauth2 + +cors_origins: + - https://some.client + - https://localhost:8000 ``` The application reads the `OIDC_TEST_SETTINGS_FILE` environment variable to determine the location of this file at startup. -For example, to run on port 8000 in a container, with the setting file in the current working directory: +For example, to run on port 8000 in a container, +with the setting file in the current working directory: ```sh podman run -p 8000:80 --env OIDC_TEST_CONFIG_FILE=/app/settings.yaml --mount type=bind,source=settings.yaml,destination=/app/settings.yaml code.philo.ydns.eu/philorg/oidc-fastapi-test:latest From e9bc6c671ae23e0eb4f74f8f207eda6d02625f28 Mon Sep 17 00:00:00 2001 From: phil Date: Sat, 1 Feb 2025 11:30:45 +0100 Subject: [PATCH 21/79] Cosmetic --- src/oidc_test/static/styles.css | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index 6065a91..a4a0178 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -2,11 +2,15 @@ body { font-family: Arial, Helvetica, sans-serif; background-color: floralwhite; margin: 0; + font-family: system-ui; } h1 { text-align: center; background-color: #f7c7867d; margin: 0 0 0.2em 0; + box-shadow: 0px 0.2em 0.2em #f7c7867d; + text-shadow: 0 0 2px #00000080; + font-weight: 200; } p { margin: 0.2em; @@ -146,7 +150,6 @@ hr { max-width: 13em; } .providers .error { - color: darkred; padding: 3px 6px; text-align: center; font-weight: bold; From 8b8bbcd7a0e0b49830b4e7da3ece323435060f48 Mon Sep 17 00:00:00 2001 From: phil Date: Sat, 1 Feb 2025 18:51:17 +0100 Subject: [PATCH 22/79] Fix resource server error with scope --- src/oidc_test/auth_utils.py | 2 ++ src/oidc_test/database.py | 2 ++ src/oidc_test/models.py | 7 +++++++ src/oidc_test/resource_server.py | 24 +++++++++--------------- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 2fcfc76..ed3350c 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -221,6 +221,7 @@ async def get_user_from_token( ) try: user = await db.get_user(user_id) + user.access_token = payload except UserNotInDB: logger.info( f"User {user_id} not found in DB, creating it (real apps can behave differently" @@ -230,5 +231,6 @@ async def get_user_from_token( user_info=payload, oidc_provider=getattr(authlib_oauth, auth_provider_id), user_info_from_endpoint={}, + access_token=payload, ) return user diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 0252d28..5dec7fc 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -26,8 +26,10 @@ class Database: user_info: dict, oidc_provider: StarletteOAuth2App, user_info_from_endpoint: dict, + access_token: dict, ) -> User: user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider) + user.access_token = access_token try: raw_roles = user_info_from_endpoint["resource_access"][ oidc_provider.client_id diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index b6a267e..542a9c4 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -32,6 +32,7 @@ class User(UserBase): also the key for the database 'table'""", ) userinfo: dict = {} + access_token: dict = {} oidc_provider: StarletteOAuth2App | None = None @classmethod @@ -49,3 +50,9 @@ class User(UserBase): @cached_property def roles_as_set(self) -> set[str]: return set([role.name for role in self.roles]) + + def has_scope(self, scope: str) -> bool: + """Check if the scope is present in user info or access token""" + info_scopes = self.userinfo.get("scope", "").split(" ") + access_token_scopes = self.access_token.get("scope", "").split(" ") + return scope in set(info_scopes + access_token_scopes) diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 67f87dd..ecaa597 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -22,19 +22,15 @@ async def get_resource(resource_id: str, user: User) -> dict: # but this has to be refined for production required_scope = f"get:{resource_id}" # Check if the required scope is in the scopes allowed in userinfo - if "required_scope" in user.userinfo: - user_scopes = user.userinfo["required_scope"].split(" ") - if required_scope in user_scopes: - await process(user, required_scope, resp) - else: - ## For the showcase, giving a explanation. - ## Alternatively, raise HTTP_401_UNAUTHORIZED - resp["sorry"] = ( - f"No scope {required_scope} in the access token " - + "but it is required for accessing this resource." - ) + if user.has_scope(required_scope): + await process(user, resource_id, resp) else: - resp["sorry"] = "There is no scope in id token" + ## For the showcase, giving a explanation. + ## Alternatively, raise HTTP_401_UNAUTHORIZED + resp["sorry"] = ( + f"No scope {required_scope} in the access token " + + "but it is required for accessing this resource." + ) return resp @@ -51,9 +47,7 @@ async def process(user, resource_id, resp): bs = await client.get("https://corporatebs-generator.sameerkumar.website/") resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.") else: - resp["sorry"] = ( - f"I don't known how to give '{resource_id}' but i know corporate bs." - ) + resp["sorry"] = f"I don't known how to give '{resource_id}'." # assert user.oidc_provider is not None From e1dac777388ca2f6850f88d9afccb4d2f394811c Mon Sep 17 00:00:00 2001 From: phil Date: Sun, 2 Feb 2025 15:54:44 +0100 Subject: [PATCH 23/79] Decode access token, refactor --- src/oidc_test/auth_utils.py | 22 ++++++---------------- src/oidc_test/database.py | 11 +++++++++-- src/oidc_test/main.py | 15 ++++++++++++--- src/oidc_test/settings.py | 16 ++++++++++++++-- src/oidc_test/static/styles.css | 1 + 5 files changed, 42 insertions(+), 23 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index ed3350c..33ca582 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -6,14 +6,14 @@ from fastapi import HTTPException, Request, Depends, status from fastapi.security import OAuth2PasswordBearer from authlib.oauth2.rfc6749 import OAuth2Token from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App -from jwt import ExpiredSignatureError, InvalidKeyError, decode +from jwt import ExpiredSignatureError, InvalidKeyError from httpx import AsyncClient # from authlib.oauth1.auth import OAuthToken # from authlib.oauth2.auth import OAuth2Token from .models import User -from .database import db, UserNotInDB +from .database import TokenNotInDb, db, UserNotInDB from .settings import settings, OIDCProvider logger = logging.getLogger(__name__) @@ -126,9 +126,10 @@ async def get_current_user(request: Request) -> User: async def get_token(request: Request) -> OAuth2Token: """Return the token from a request object, from the session. It can be used in Depends()""" - if (token := await db.get_token(request.session.get("token"))) is None: + try: + return await db.get_token(request.session["token"]) + except (KeyError, TokenNotInDb): raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") - return token async def get_current_user_or_none(request: Request) -> User | None: @@ -189,19 +190,8 @@ async def get_user_from_token( raise HTTPException( status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" ) - if (key := auth_provider_settings.get_public_key()) is None: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - f"Key for provider '{auth_provider_id}' unknown", - ) try: - payload = decode( - token, - key=key, - algorithms=["RS256"], - audience="oidc-test", - options={"verify_signature": not settings.insecure.skip_verify_signature}, - ) + payload = auth_provider_settings.decode(token) except ExpiredSignatureError as err: logger.info(f"Expired signature: {err}") raise HTTPException( diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 5dec7fc..b2cf1b9 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -14,6 +14,10 @@ class UserNotInDB(Exception): pass +class TokenNotInDb(Exception): + pass + + class Database: users: dict[str, User] = {} tokens: dict[str, OAuth2Token] = {} @@ -50,8 +54,11 @@ class Database: async def add_token(self, token: OAuth2Token, user: User) -> None: self.tokens[token["id_token"]] = token - async def get_token(self, id_token: str) -> OAuth2Token | None: - return self.tokens.get(id_token) + async def get_token(self, id_token: str) -> OAuth2Token: + try: + return self.tokens[id_token] + except KeyError: + raise TokenNotInDb db = Database() diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 739dd1b..ef19245 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -187,11 +187,20 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: # Build and remember the user in the session request.session["user_sub"] = sub # Store the user in the database + try: + oidc_provider_settings = oidc_providers_settings[oidc_provider_id] + access_token = oidc_provider_settings.decode(token["access_token"]) + except Exception: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + detail="Cannot decode token or verify its signature", + ) user = await db.add_user( sub, user_info=userinfo, oidc_provider=oidc_provider, user_info_from_endpoint=user_info_from_endpoint, + access_token=access_token, ) # Add the id_token to the session request.session["token"] = token["id_token"] @@ -213,14 +222,14 @@ async def account( oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], ) -> RedirectResponse: if ( - provider := oidc_providers_settings.get( + oidc_provider_settings := oidc_providers_settings.get( request.session.get("oidc_provider_id", "") ) ) is None: raise HTTPException( - status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting" + status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings" ) - return RedirectResponse(f"{provider.account_url}") + return RedirectResponse(f"{oidc_provider_settings.account_url}") @app.get("/logout") diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 329b9c0..46d857d 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -4,6 +4,7 @@ import random from typing import Type, Tuple from pathlib import Path +from jwt import decode from pydantic import BaseModel, computed_field, AnyUrl from pydantic_settings import ( BaseSettings, @@ -43,6 +44,7 @@ class OIDCProvider(BaseModel): None # Info fetched from info_url, eg. public key ) public_key: str | None = None + signature_alg: str = "RS256" @computed_field @property @@ -70,19 +72,29 @@ class OIDCProvider(BaseModel): else: return None - def get_public_key(self) -> str | None: + def get_public_key(self) -> str: """Return the public key formatted for decoding token""" public_key = self.public_key or ( self.info is not None and self.info["public_key"] ) if public_key is None: - return None + raise AttributeError(f"Cannot get public key for {self.name}") return f""" -----BEGIN PUBLIC KEY----- {public_key} -----END PUBLIC KEY----- """ + def decode(self, token: str) -> dict: + """Decode the token with signature check""" + return decode( + token, + self.get_public_key(), + algorithms=[self.signature_alg], + audience=["oidc-test", "oidc-test-web"], + options={"verify_signature": not settings.insecure.skip_verify_signature}, + ) + class ResourceProvider(BaseModel): id: str diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index a4a0178..cc84736 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -104,6 +104,7 @@ hr { .role { padding: 3px 6px; background-color: #44228840; + border-radius: 6px; } /* For home */ From dc181bd3a841fcb3b598877035b217c056e8d192 Mon Sep 17 00:00:00 2001 From: phil Date: Mon, 3 Feb 2025 13:20:33 +0100 Subject: [PATCH 24/79] Store raw access token within user; get resource --- src/oidc_test/auth_utils.py | 13 +++++----- src/oidc_test/database.py | 2 +- src/oidc_test/main.py | 8 ++++--- src/oidc_test/models.py | 16 ++++++++++--- src/oidc_test/static/styles.css | 24 +++++++++---------- src/oidc_test/static/utils.js | 21 ++++++++++++++++ src/oidc_test/templates/home.html | 40 +++++++++++++++++++++++++++---- 7 files changed, 94 insertions(+), 30 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 33ca582..4c3b98c 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -25,8 +25,8 @@ oidc_providers_settings: dict[str, OIDCProvider] = dict( oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -def fetch_token(name, request): - breakpoint() +async def fetch_token(name, request): + logger.warn("TODO: fetch_token") ... # if name in oidc_providers: # model = OAuth2Token @@ -37,8 +37,8 @@ def fetch_token(name, request): # return token.to_token() -def update_token(*args, **kwargs): - breakpoint() +async def update_token(*args, **kwargs): + logger.warn("TODO: update_token") ... @@ -211,7 +211,8 @@ async def get_user_from_token( ) try: user = await db.get_user(user_id) - user.access_token = payload + if user.access_token != token: + user.access_token = token except UserNotInDB: logger.info( f"User {user_id} not found in DB, creating it (real apps can behave differently" @@ -221,6 +222,6 @@ async def get_user_from_token( user_info=payload, oidc_provider=getattr(authlib_oauth, auth_provider_id), user_info_from_endpoint={}, - access_token=payload, + access_token=token, ) return user diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index b2cf1b9..1b682ef 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -30,7 +30,7 @@ class Database: user_info: dict, oidc_provider: StarletteOAuth2App, user_info_from_endpoint: dict, - access_token: dict, + access_token: str, ) -> User: user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider) user.access_token = access_token diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index ef19245..3d95009 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -200,7 +200,7 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: user_info=userinfo, oidc_provider=oidc_provider, user_info_from_endpoint=user_info_from_endpoint, - access_token=access_token, + access_token=token["access_token"], ) # Add the id_token to the session request.session["token"] = token["id_token"] @@ -229,7 +229,7 @@ async def account( raise HTTPException( status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings" ) - return RedirectResponse(f"{oidc_provider_settings.account_url}") + return RedirectResponse(f"{oidc_provider_settings.account_url_template}") @app.get("/logout") @@ -243,7 +243,9 @@ async def logout( if ( provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint") ) is None: - logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}") + logger.warn( + f"Cannot find end_session_endpoint for provider {oidc_provider.name}" + ) return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") if (token := await db.get_token(request.session.pop("token", None))) is None: diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 542a9c4..db5d6ad 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -25,14 +25,14 @@ class UserBase(SQLModel, extra="ignore"): class User(UserBase): - model_config = ConfigDict(arbitrary_types_allowed=True) + model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore sub: str = Field( description="""subject id of the user given by the oidc provider, also the key for the database 'table'""", ) userinfo: dict = {} - access_token: dict = {} + access_token: str | None = None oidc_provider: StarletteOAuth2App | None = None @classmethod @@ -54,5 +54,15 @@ class User(UserBase): def has_scope(self, scope: str) -> bool: """Check if the scope is present in user info or access token""" info_scopes = self.userinfo.get("scope", "").split(" ") - access_token_scopes = self.access_token.get("scope", "").split(" ") + access_token_scopes = self.access_token_parsed().get("scope", "").split(" ") return scope in set(info_scopes + access_token_scopes) + + def access_token_parsed(self): + assert self.access_token is not None + assert self.oidc_provider is not None + assert self.oidc_provider.name is not None + from .auth_utils import oidc_providers_settings + + return oidc_providers_settings[self.oidc_provider.name].decode( + self.access_token + ) diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index cc84736..4552ca0 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -3,9 +3,9 @@ body { background-color: floralwhite; margin: 0; font-family: system-ui; + text-align: center; } h1 { - text-align: center; background-color: #f7c7867d; margin: 0 0 0.2em 0; box-shadow: 0px 0.2em 0.2em #f7c7867d; @@ -21,9 +21,6 @@ hr { .hidden { display: none; } -.center { - text-align: center; -} .content { width: 100%; display: flex; @@ -55,7 +52,6 @@ hr { border: 2px solid darkkhaki; padding: 3px 6px; text-decoration: none; - text-align: center; color: black; } .user-info a.logout:hover { @@ -70,7 +66,6 @@ hr { margin: 0; } .debug-auth p { - text-align: center; border-bottom: 1px solid black; } .debug-auth ul { @@ -101,16 +96,24 @@ hr { .hasResponseStatus.status-503 { background-color: #ffA88050; } -.role { + +.role, .scope { padding: 3px 6px; - background-color: #44228840; + margin: 3px; border-radius: 6px; } +.role { + background-color: #44228840; +} + +.scope { + background-color: #8888FF80; +} + /* For home */ .login-box { - text-align: center; background-color: antiquewhite; margin: 0.5em auto; width: fit-content; @@ -137,7 +140,6 @@ hr { max-height: 2em; } .providers .provider .link div { - text-align: center; background-color: #f7c7867d; border-radius: 8px; padding: 6px; @@ -152,13 +154,11 @@ hr { } .providers .error { padding: 3px 6px; - text-align: center; font-weight: bold; flex: 1 1 auto; } .content .links-to-check { display: flex; - text-align: center; justify-content: center; gap: 0.5em; flex-flow: wrap; diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 6b40d3d..142fa6e 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -17,3 +17,24 @@ function checkPerms(className) { Array.from(elem.children).forEach(elem => checkHref(elem)) ) } + +async function get_resource(id, token, authProvider) { + //if (!keycloak.keycloak) { return } + const resp = await fetch("resource/" + id, { + method: "GET", + headers: new Headers({ + "Content-type": "application/json", + "Authorization": `Bearer ${token}`, + "auth_provider": authProvider, + }), + }) + /* + resource.value = resp['data'] + msg.value = "" + } + ).catch ( + err => msg.value = err + ) +*/ + console.log(await resp.json()) +} diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index c062101..bba2f2a 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -30,6 +30,10 @@ {% endif %}
{{ user.email }}
+
+ Provider: + {{ oidc_provider_settings.name }} +
{% if user.roles %}
Roles: @@ -38,17 +42,43 @@ {% endfor %}
{% endif %} -
- Provider: - {{ oidc_provider_settings.name }} -
+ {% if user.access_token.scope %} +
+ Scopes: + {% for scope in user.access_token.scope.split(' ') %} + {{ scope }} + {% endfor %} +
+ {% endif %} {% if oidc_provider_settings.account_url_template %} - + {% endif %} {% endif %}
+

+ Fetch resources from the resource server with your authentication token: +

+
+ + +
+
+
+
+
{{ key }}
+
{{ value }}
+
{{ value }}
+
+
+
+
{{ msg }}
+

These links should get different response codes depending on the authorization: From af49242192c29a38b58819ebd0f2e8ae55bd87b6 Mon Sep 17 00:00:00 2001 From: phil Date: Tue, 4 Feb 2025 02:27:32 +0100 Subject: [PATCH 25/79] Add self resouce provider --- src/oidc_test/auth_utils.py | 25 ++++++++++++++++++------- src/oidc_test/main.py | 1 + src/oidc_test/resource_server.py | 24 +++++++++++++++--------- src/oidc_test/settings.py | 1 + src/oidc_test/static/styles.css | 20 ++++++++++++++++++++ src/oidc_test/static/utils.js | 18 +++++++++++++++++- src/oidc_test/templates/home.html | 17 +++++------------ 7 files changed, 77 insertions(+), 29 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 4c3b98c..d96aba9 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -4,13 +4,12 @@ import logging from fastapi import HTTPException, Request, Depends, status from fastapi.security import OAuth2PasswordBearer -from authlib.oauth2.rfc6749 import OAuth2Token from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App from jwt import ExpiredSignatureError, InvalidKeyError from httpx import AsyncClient # from authlib.oauth1.auth import OAuthToken -# from authlib.oauth2.auth import OAuth2Token +from authlib.oauth2.auth import OAuth2Token from .models import User from .database import TokenNotInDb, db, UserNotInDB @@ -21,7 +20,6 @@ logger = logging.getLogger(__name__) oidc_providers_settings: dict[str, OIDCProvider] = dict( [(provider.id, provider) for provider in settings.oidc.providers] ) - oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -37,9 +35,19 @@ async def fetch_token(name, request): # return token.to_token() -async def update_token(*args, **kwargs): - logger.warn("TODO: update_token") - ... +async def update_token(name, token, refresh_token=None, access_token=None): + breakpoint() + if refresh_token: + item = OAuth2Token.find(name=name, refresh_token=refresh_token) + elif access_token: + item = OAuth2Token.find(name=name, access_token=access_token) + else: + return + # update old token + item.access_token = token["access_token"] + item.refresh_token = token.get("refresh_token") + item.expires_at = token["expires_at"] + item.save() authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) @@ -52,7 +60,10 @@ def init_providers(): name=id, server_metadata_url=provider.openid_configuration, client_kwargs={ - "scope": "openid email offline_access profile", + "scope": " ".join( + ["openid", "email", "offline_access", "profile"] + + provider.resource_provider_scopes + ), }, client_id=provider.client_id, client_secret=provider.client_secret, diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 3d95009..e14b4a8 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -102,6 +102,7 @@ async def home( context={ "settings": settings.model_dump(), "user": user, + "access_token_scope": user.access_token_parsed()["scope"] if user else None, "now": now, "oidc_provider": oidc_provider, "oidc_provider_settings": oidc_provider_settings, diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index ecaa597..fbee866 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -2,6 +2,7 @@ from datetime import datetime import logging from httpx import AsyncClient +from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from .models import User @@ -22,15 +23,20 @@ async def get_resource(resource_id: str, user: User) -> dict: # but this has to be refined for production required_scope = f"get:{resource_id}" # Check if the required scope is in the scopes allowed in userinfo - if user.has_scope(required_scope): - await process(user, resource_id, resp) - else: - ## For the showcase, giving a explanation. - ## Alternatively, raise HTTP_401_UNAUTHORIZED - resp["sorry"] = ( - f"No scope {required_scope} in the access token " - + "but it is required for accessing this resource." - ) + try: + if user.has_scope(required_scope): + await process(user, resource_id, resp) + else: + ## For the showcase, giving a explanation. + ## Alternatively, raise HTTP_401_UNAUTHORIZED + resp["sorry"] = ( + f"No scope {required_scope} in the access token " + + "but it is required for accessing this resource." + ) + except ExpiredSignatureError: + resp["sorry"] = "The token's signature has expired" + except InvalidTokenError: + resp["sorry"] = "The token is invalid" return resp diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 46d857d..4d08ada 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -45,6 +45,7 @@ class OIDCProvider(BaseModel): ) public_key: str | None = None signature_alg: str = "RS256" + resource_provider_scopes: list[str] = [] @computed_field @property diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index 4552ca0..b0753b7 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -170,3 +170,23 @@ hr { border-radius: 8px; } +.resource { + padding: 0.5em; + display: flex; + gap: 0.5em; + flex-direction: column; + width: fit-content; + align-items: center; + margin: 5px auto; + box-shadow: 0px 0px 10px #90c3eeA0; + background-color: #90c3eeA0; + border-radius: 8px; +} + +.resources { + display: flex; +} + +.key { + font-weight: bold; +} diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 142fa6e..9cc8040 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -36,5 +36,21 @@ async function get_resource(id, token, authProvider) { err => msg.value = err ) */ - console.log(await resp.json()) + const resource = await resp.json() + const rootElem = document.getElementById('resource') + rootElem.innerHTML = "" + Object.entries(resource).forEach( + ([k, v]) => { + let r = document.createElement('div') + let kElem = document.createElement('div') + kElem.innerText = k + kElem.className = "key" + let vElem = document.createElement('div') + vElem.innerText = v + vElem.className = "value" + r.appendChild(kElem) + r.appendChild(vElem) + rootElem.appendChild(r) + } + ) } diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index bba2f2a..55bd844 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -42,10 +42,10 @@ {% endfor %}

{% endif %} - {% if user.access_token.scope %} + {% if access_token_scope %}
Scopes: - {% for scope in user.access_token.scope.split(' ') %} + {% for scope in access_token_scope.split(' ') %} {{ scope }} {% endfor %}
@@ -61,6 +61,7 @@ {% endif %}
+ {% if user %}

Fetch resources from the resource server with your authentication token:

@@ -68,17 +69,9 @@ -
-
-
-
{{ key }}
-
{{ value }}
-
{{ value }}
-
-
-
-
{{ msg }}
+

+ {% endif %}

These links should get different response codes depending on the authorization: From fefe44acfef0a16c492070ac8bc9f43d4637aa2c Mon Sep 17 00:00:00 2001 From: phil Date: Tue, 4 Feb 2025 03:03:28 +0100 Subject: [PATCH 26/79] CI: build only if git clean tag --- .forgejo/workflows/build.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.forgejo/workflows/build.yaml b/.forgejo/workflows/build.yaml index df52e0b..e02bf47 100644 --- a/.forgejo/workflows/build.yaml +++ b/.forgejo/workflows/build.yaml @@ -48,11 +48,13 @@ jobs: run: sed "s/0.0.0/$VERSION/" -i pyproject.toml - name: Workaround for bug of podman-login + if: fromJSON(steps.builder.outputs.run) run: | mkdir -p $HOME/.docker echo "{ \"auths\": {} }" > $HOME/.docker/config.json - name: Log in to the container registry (with another workaround) + if: fromJSON(steps.builder.outputs.run) uses: actions/podman-login@v1 with: registry: ${{ vars.REGISTRY }} @@ -61,6 +63,7 @@ jobs: auth_file_path: /tmp/auth.json - name: Build the container image + if: fromJSON(steps.builder.outputs.run) uses: actions/buildah-build@v1 with: image: oidc-fastapi-test @@ -71,6 +74,7 @@ jobs: ./Containerfile - name: Push the image to the registry + if: fromJSON(steps.builder.outputs.run) uses: actions/push-to-registry@v2 with: registry: "docker://${{ vars.REGISTRY }}/${{ vars.ORGANISATION }}" @@ -78,9 +82,11 @@ jobs: tags: latest ${{ steps.version.outputs.version }} - name: Build wheel + if: fromJSON(steps.builder.outputs.run) run: uv build --wheel - name: Publish Python package (home) + if: fromJSON(steps.builder.outputs.run) env: LOCAL_PYPI_TOKEN: ${{ secrets.LOCAL_PYPI_TOKEN }} run: uv publish --publish-url https://code.philo.ydns.eu/api/packages/philorg/pypi --token $LOCAL_PYPI_TOKEN From aa86f8135843232739df3790aa137b3da4d78f85 Mon Sep 17 00:00:00 2001 From: phil Date: Tue, 4 Feb 2025 03:38:33 +0100 Subject: [PATCH 27/79] Fix home when token cannot be decoded --- src/oidc_test/main.py | 27 +++++++++++++++++++++++---- src/oidc_test/models.py | 11 +++++++++-- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index e14b4a8..aac258b 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -15,6 +15,7 @@ from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.templating import Jinja2Templates from fastapi.middleware.cors import CORSMiddleware +from jwt import InvalidKeyError, InvalidTokenError from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.base_client import OAuthError @@ -96,13 +97,24 @@ async def home( else: resources = [] oidc_provider_settings = None + + if user is None: + access_token_scope = None + else: + try: + access_token_scope = user.decode_access_token()["scope"] + except InvalidTokenError as err: + access_token_scope = None + logger.info("Invalid token") + logger.exception(err) + return templates.TemplateResponse( name="home.html", request=request, context={ "settings": settings.model_dump(), "user": user, - "access_token_scope": user.access_token_parsed()["scope"] if user else None, + "access_token_scope": access_token_scope, "now": now, "oidc_provider": oidc_provider, "oidc_provider_settings": oidc_provider_settings, @@ -187,15 +199,22 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: user_info_from_endpoint = {} # Build and remember the user in the session request.session["user_sub"] = sub - # Store the user in the database + # Verify the token's signature and validity try: oidc_provider_settings = oidc_providers_settings[oidc_provider_id] - access_token = oidc_provider_settings.decode(token["access_token"]) - except Exception: + oidc_provider_settings.decode(token["access_token"]) + except InvalidKeyError: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, + detail="Token invalid key / signature", + ) + except Exception as err: + logger.exception(err) raise HTTPException( status.HTTP_401_UNAUTHORIZED, detail="Cannot decode token or verify its signature", ) + # Store the user in the database user = await db.add_user( sub, user_info=userinfo, diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index db5d6ad..4b1c064 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -1,3 +1,4 @@ +import logging from functools import cached_property from typing import Self @@ -10,6 +11,8 @@ from pydantic import ( from authlib.integrations.starlette_client.apps import StarletteOAuth2App from sqlmodel import SQLModel, Field +logger = logging.getLogger(__name__) + class Role(SQLModel, extra="ignore"): name: str @@ -54,10 +57,14 @@ class User(UserBase): def has_scope(self, scope: str) -> bool: """Check if the scope is present in user info or access token""" info_scopes = self.userinfo.get("scope", "").split(" ") - access_token_scopes = self.access_token_parsed().get("scope", "").split(" ") + try: + access_token_scopes = self.decode_access_token().get("scope", "").split(" ") + except Exception as err: + logger.info(f"Access token cannot be decoded: {err}") + access_token_scopes = [] return scope in set(info_scopes + access_token_scopes) - def access_token_parsed(self): + def decode_access_token(self): assert self.access_token is not None assert self.oidc_provider is not None assert self.oidc_provider.name is not None From 31a783cbf19be0fe2e78f1fdb053c485d379d3a1 Mon Sep 17 00:00:00 2001 From: phil Date: Tue, 4 Feb 2025 18:03:17 +0100 Subject: [PATCH 28/79] Fix token error handling --- src/oidc_test/auth_utils.py | 7 +++---- src/oidc_test/database.py | 4 +++- src/oidc_test/main.py | 8 +++++--- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index d96aba9..fd82ecd 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -113,8 +113,7 @@ async def get_current_user(request: Request) -> User: """ if (user_sub := request.session.get("user_sub")) is None: raise HTTPException(status.HTTP_401_UNAUTHORIZED) - if (token := await db.get_token(request.session["token"])) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Token unknown") + token = await get_token(request) user = await db.get_user(user_sub) ## Check if the token is expired if token.is_expired(): @@ -138,8 +137,8 @@ async def get_token(request: Request) -> OAuth2Token: """Return the token from a request object, from the session. It can be used in Depends()""" try: - return await db.get_token(request.session["token"]) - except (KeyError, TokenNotInDb): + return await db.get_token(request.session.get("token")) + except TokenNotInDb: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 1b682ef..0a30e9c 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -54,7 +54,9 @@ class Database: async def add_token(self, token: OAuth2Token, user: User) -> None: self.tokens[token["id_token"]] = token - async def get_token(self, id_token: str) -> OAuth2Token: + async def get_token(self, id_token: str | None) -> OAuth2Token: + if id_token is None: + raise TokenNotInDb try: return self.tokens[id_token] except KeyError: diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index aac258b..e9ba5b1 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -41,7 +41,7 @@ from .auth_utils import ( get_providers_info, ) from .auth_misc import pretty_details -from .database import db +from .database import TokenNotInDb, db from .resource_server import get_resource logger = logging.getLogger("uvicorn.error") @@ -268,8 +268,10 @@ async def logout( ) return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") - if (token := await db.get_token(request.session.pop("token", None))) is None: - logger.warn("No session in db for the token") + try: + token = await db.get_token(request.session.pop("token", None)) + except TokenNotInDb: + logger.warn("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) logout_url = ( provider_logout_uri From 3dc14ae57ba9aaafeff97a6ed3e20e8f1b35f4f8 Mon Sep 17 00:00:00 2001 From: phil Date: Tue, 4 Feb 2025 18:19:58 +0100 Subject: [PATCH 29/79] Cosmetic --- src/oidc_test/static/styles.css | 38 +++++++++++++++++++++++++++++++++ src/oidc_test/static/utils.js | 3 +++ 2 files changed, 41 insertions(+) diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index b0753b7..426a464 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -21,6 +21,12 @@ hr { .hidden { display: none; } +.center { + text-align: center; +} +.error { + color: darkred; +} .content { width: 100%; display: flex; @@ -111,6 +117,7 @@ hr { background-color: #8888FF80; } + /* For home */ .login-box { @@ -170,6 +177,17 @@ hr { border-radius: 8px; } +.token { + overflow-wrap: anywhere; + font-family: monospace; +} + +.actions { + display: flex; + justify-content: center; + gap: 0.5em; +} + .resource { padding: 0.5em; display: flex; @@ -187,6 +205,26 @@ hr { display: flex; } +.resource { + text-align: center; +} + +.token-info { + margin: 0 1em; +} + .key { font-weight: bold; } + +.token .key, .token .value { + display: inline; +} +.token .value { + padding-left: 1em; +} + +.msg { + text-align: center; + font-weight: bold; +} diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 9cc8040..2fdb32d 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -48,6 +48,9 @@ async function get_resource(id, token, authProvider) { let vElem = document.createElement('div') vElem.innerText = v vElem.className = "value" + if (k == "sorry") { + vElem.classList.add("error") + } r.appendChild(kElem) r.appendChild(vElem) rootElem.appendChild(r) From b86ae4eb112ec142ba816cb391ae4f8baad54b60 Mon Sep 17 00:00:00 2001 From: phil Date: Wed, 5 Feb 2025 02:13:09 +0100 Subject: [PATCH 30/79] Raise HTTPException on resource server error --- src/oidc_test/auth_utils.py | 1 + src/oidc_test/resource_server.py | 17 ++++++++++++----- src/oidc_test/static/styles.css | 4 ++-- src/oidc_test/static/utils.js | 25 +++++++++++++++---------- src/oidc_test/templates/home.html | 8 ++++++-- 5 files changed, 36 insertions(+), 19 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index fd82ecd..0c8dcc7 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -37,6 +37,7 @@ async def fetch_token(name, request): async def update_token(name, token, refresh_token=None, access_token=None): breakpoint() + item = await db.get_token(token["id_token"]) if refresh_token: item = OAuth2Token.find(name=name, refresh_token=refresh_token) elif access_token: diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index fbee866..635a91b 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -3,6 +3,8 @@ import logging from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError +from fastapi import HTTPException, status +from starlette.status import HTTP_401_UNAUTHORIZED from .models import User @@ -29,14 +31,17 @@ async def get_resource(resource_id: str, user: User) -> dict: else: ## For the showcase, giving a explanation. ## Alternatively, raise HTTP_401_UNAUTHORIZED - resp["sorry"] = ( + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"No scope {required_scope} in the access token " - + "but it is required for accessing this resource." + + "but it is required for accessing this resource.", ) except ExpiredSignatureError: - resp["sorry"] = "The token's signature has expired" + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, "The token's signature has expired" + ) except InvalidTokenError: - resp["sorry"] = "The token is invalid" + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid") return resp @@ -53,7 +58,9 @@ async def process(user, resource_id, resp): bs = await client.get("https://corporatebs-generator.sameerkumar.website/") resp["bs"] = bs.json().get("phrase", "Sorry, i am out of BS today.") else: - resp["sorry"] = f"I don't known how to give '{resource_id}'." + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"I don't known how to give '{resource_id}'." + ) # assert user.oidc_provider is not None diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index 426a464..7e1260b 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -73,6 +73,7 @@ hr { } .debug-auth p { border-bottom: 1px solid black; + text-align: left; } .debug-auth ul { padding: 0; @@ -188,9 +189,8 @@ hr { gap: 0.5em; } -.resource { +.resourceResult { padding: 0.5em; - display: flex; gap: 0.5em; flex-direction: column; width: fit-content; diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 2fdb32d..a982267 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -20,6 +20,8 @@ function checkPerms(className) { async function get_resource(id, token, authProvider) { //if (!keycloak.keycloak) { return } + const msg = document.getElementById("msg") + const resourceElem = document.getElementById('resource') const resp = await fetch("resource/" + id, { method: "GET", headers: new Headers({ @@ -27,18 +29,21 @@ async function get_resource(id, token, authProvider) { "Authorization": `Bearer ${token}`, "auth_provider": authProvider, }), + }).catch(err => { + msg.innerHTML = "Cannot fetch resource: " + err.message + resourceElem.innerHTML = "" }) - /* - resource.value = resp['data'] - msg.value = "" + if (resp === undefined) { + return } - ).catch ( - err => msg.value = err - ) -*/ const resource = await resp.json() - const rootElem = document.getElementById('resource') - rootElem.innerHTML = "" + if (!resp.ok) { + msg.innerHTML = resource["detail"] + resourceElem.innerHTML = "" + return + } + msg.innerHTML = "" + resourceElem.innerHTML = "" Object.entries(resource).forEach( ([k, v]) => { let r = document.createElement('div') @@ -53,7 +58,7 @@ async function get_resource(id, token, authProvider) { } r.appendChild(kElem) r.appendChild(vElem) - rootElem.appendChild(r) + resourceElem.appendChild(r) } ) } diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 55bd844..ce344cc 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -69,7 +69,10 @@

-
+
+
+
+

{% endif %}
@@ -96,6 +99,7 @@ {% endfor %}
{% endif %} + {% if user_info_details %}
@@ -103,7 +107,7 @@
    {% for key, value in user_info_details.items() %}
  • - {{ key }}: {{ value }} + {{ key }}: {{ value }}
  • {% endfor %}
From 76da695b66d306eda3e57cf0a562d33b4c004ac9 Mon Sep 17 00:00:00 2001 From: phil Date: Thu, 6 Feb 2025 13:27:14 +0100 Subject: [PATCH 31/79] Set black config - line length --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 980bcfc..b1e6504 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,3 +35,6 @@ packages = ["src/oidc_test"] [tool.uv] package = true + +[tool.black] +line-length = 98 From 5c9ed9724e1a53bee8230adeca3174e0ccc8c086 Mon Sep 17 00:00:00 2001 From: phil Date: Thu, 6 Feb 2025 13:27:45 +0100 Subject: [PATCH 32/79] Add logging conf for debugging --- log_conf.yaml | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 log_conf.yaml diff --git a/log_conf.yaml b/log_conf.yaml new file mode 100644 index 0000000..a6bb0b4 --- /dev/null +++ b/log_conf.yaml @@ -0,0 +1,34 @@ +version: 1 +disable_existing_loggers: False +formatters: + default: + "()": uvicorn.logging.DefaultFormatter + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + access: + "()": uvicorn.logging.AccessFormatter + format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +handlers: + default: + formatter: default + class: logging.StreamHandler + stream: ext://sys.stderr + access: + formatter: access + class: logging.StreamHandler + stream: ext://sys.stdout +loggers: + uvicorn.error: + level: INFO + handlers: + - default + propagate: no + uvicorn.access: + level: INFO + handlers: + - access + propagate: no + "oidc-test": + level: DEBUG + handlers: + - default + propagate: yes From ee8ba3d2df9bcf0baa0f799aac00ac9e010cfd5d Mon Sep 17 00:00:00 2001 From: phil Date: Thu, 6 Feb 2025 13:30:35 +0100 Subject: [PATCH 33/79] Get roles from access token, remove user info inspection, refreactorings --- src/oidc_test/auth_utils.py | 56 ++++++++++++++++------------- src/oidc_test/database.py | 58 ++++++++++++++++++++---------- src/oidc_test/main.py | 60 +++++++++++++------------------- src/oidc_test/models.py | 18 +++------- src/oidc_test/resource_server.py | 2 +- src/oidc_test/settings.py | 29 ++++++++++++--- 6 files changed, 126 insertions(+), 97 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 0c8dcc7..281511d 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -13,13 +13,10 @@ from authlib.oauth2.auth import OAuth2Token from .models import User from .database import TokenNotInDb, db, UserNotInDB -from .settings import settings, OIDCProvider +from .settings import settings, OIDCProvider, oidc_providers_settings -logger = logging.getLogger(__name__) +logger = logging.getLogger("oidc-test") -oidc_providers_settings: dict[str, OIDCProvider] = dict( - [(provider.id, provider) for provider in settings.oidc.providers] -) oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") @@ -36,19 +33,16 @@ async def fetch_token(name, request): async def update_token(name, token, refresh_token=None, access_token=None): - breakpoint() - item = await db.get_token(token["id_token"]) - if refresh_token: - item = OAuth2Token.find(name=name, refresh_token=refresh_token) - elif access_token: - item = OAuth2Token.find(name=name, access_token=access_token) - else: - return + oidc_provider_settings = oidc_providers_settings[name] + sid: str = oidc_provider_settings.decode(token["id_token"])["sid"] + item = await db.get_token(oidc_provider_settings, sid) # update old token - item.access_token = token["access_token"] - item.refresh_token = token.get("refresh_token") - item.expires_at = token["expires_at"] - item.save() + item["access_token"] = token.get("access_token") + item["refresh_token"] = token.get("refresh_token") + item["expires_at"] = token["expires_at"] + logger.info(f"Token {sid} refreshed") + # It's a fake db and only in memory, so there's nothing to save + # await item.save() authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token) @@ -138,8 +132,17 @@ async def get_token(request: Request) -> OAuth2Token: """Return the token from a request object, from the session. It can be used in Depends()""" try: - return await db.get_token(request.session.get("token")) - except TokenNotInDb: + oidc_provider_settings = oidc_providers_settings[ + request.session.get("oidc_provider_id", "") + ] + except KeyError: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider") + try: + return await db.get_token( + oidc_provider_settings, + request.session.get("sid"), + ) + except (TokenNotInDb, InvalidKeyError): raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") @@ -190,14 +193,16 @@ async def get_user_from_token( token: Annotated[str, Depends(oauth2_scheme)], request: Request, ) -> User: - if (auth_provider_id := request.headers.get("auth_provider")) is None: + try: + auth_provider_id = request.headers["auth_provider"] + except KeyError: raise HTTPException( status.HTTP_401_UNAUTHORIZED, "Request headers must have a 'auth_provider' field", ) - if ( - auth_provider_settings := oidc_providers_settings.get(auth_provider_id) - ) is None: + try: + auth_provider_settings = oidc_providers_settings[auth_provider_id] + except KeyError: raise HTTPException( status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" ) @@ -216,7 +221,9 @@ async def get_user_from_token( logger.info("Cannot decode token, see below") logger.exception(err) raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token") - if (user_id := payload.get("sub")) is None: + try: + user_id = payload["sub"] + except KeyError: raise HTTPException( status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found" ) @@ -232,7 +239,6 @@ async def get_user_from_token( sub=payload["sub"], user_info=payload, oidc_provider=getattr(authlib_oauth, auth_provider_id), - user_info_from_endpoint={}, access_token=token, ) return user diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 0a30e9c..360ef11 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -3,11 +3,12 @@ import logging from authlib.integrations.starlette_client.apps import StarletteOAuth2App - -from .models import User, Role from authlib.oauth2.rfc6749 import OAuth2Token -logger = logging.getLogger(__name__) +from .settings import OIDCProvider, oidc_providers_settings +from .models import User, Role + +logger = logging.getLogger("oidc-test") class UserNotInDB(Exception): @@ -29,20 +30,34 @@ class Database: sub: str, user_info: dict, oidc_provider: StarletteOAuth2App, - user_info_from_endpoint: dict, access_token: str, + access_token_decoded: dict | None = None, ) -> User: - user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider) + if access_token_decoded is None: + assert oidc_provider.name is not None + oidc_provider_settings = oidc_providers_settings[oidc_provider.name] + access_token_decoded = oidc_provider_settings.decode(access_token) + user = User(**user_info) + user.userinfo = user_info + user.oidc_provider = oidc_provider user.access_token = access_token + user.access_token_decoded = access_token_decoded + # Add roles provided in the access token + roles = set() try: - raw_roles = user_info_from_endpoint["resource_access"][ - oidc_provider.client_id - ]["roles"] - except Exception as err: - logger.debug(f"Cannot read additional roles: {err}") - raw_roles = [] - for raw_role in raw_roles: - user.roles.append(Role(name=raw_role)) + r = access_token_decoded["resource_access"][oidc_provider.client_id]["roles"] + roles.update(r) + except KeyError: + pass + try: + r = access_token_decoded["realm_access"]["roles"] + if isinstance(r, str): + roles.add(r) + else: + roles.update(r) + except KeyError: + pass + user.roles = [Role(name=role_name) for role_name in roles] self.users[sub] = user return user @@ -51,14 +66,21 @@ class Database: raise UserNotInDB return self.users[sub] - async def add_token(self, token: OAuth2Token, user: User) -> None: - self.tokens[token["id_token"]] = token + async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None: + """Store a token using as key the sid (auth provider's session id) + in the id_token""" + sid = token["userinfo"]["sid"] + self.tokens[sid] = token - async def get_token(self, id_token: str | None) -> OAuth2Token: - if id_token is None: + async def get_token( + self, + oidc_provider_settings: OIDCProvider, + sid: str | None, + ) -> OAuth2Token: + if sid is None: raise TokenNotInDb try: - return self.tokens[id_token] + return self.tokens[sid] except KeyError: raise TokenNotInDb diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index e9ba5b1..60482bd 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.templating import Jinja2Templates from fastapi.middleware.cors import CORSMiddleware -from jwt import InvalidKeyError, InvalidTokenError +from jwt import InvalidTokenError, PyJWTError from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.base_client import OAuthError @@ -26,7 +26,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair -from .settings import settings +from .settings import settings, oidc_providers_settings from .models import User from .auth_utils import ( get_oidc_provider, @@ -37,14 +37,13 @@ from .auth_utils import ( get_user_from_token, authlib_oauth, get_token, - oidc_providers_settings, get_providers_info, ) from .auth_misc import pretty_details from .database import TokenNotInDb, db from .resource_server import get_resource -logger = logging.getLogger("uvicorn.error") +logger = logging.getLogger("oidc-test") templates = Jinja2Templates(Path(__file__).parent / "templates") @@ -189,43 +188,28 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: request.session["oidc_provider_id"] = oidc_provider_id # User id (sub) given by oidc provider sub = userinfo["sub"] - # Get additional data from userinfo endpoint - try: - user_info_from_endpoint = await oidc_provider.userinfo( - token=token, follow_redirects=True - ) - except Exception as err: - logger.warn(f"Cannot get userinfo from endpoint: {err}") - user_info_from_endpoint = {} # Build and remember the user in the session request.session["user_sub"] = sub - # Verify the token's signature and validity + # Store the user in the database, which also verifies the token validity and signature try: - oidc_provider_settings = oidc_providers_settings[oidc_provider_id] - oidc_provider_settings.decode(token["access_token"]) - except InvalidKeyError: + user = await db.add_user( + sub, + user_info=userinfo, + oidc_provider=oidc_provider, + access_token=token["access_token"], + ) + except PyJWTError as err: raise HTTPException( status.HTTP_401_UNAUTHORIZED, - detail="Token invalid key / signature", + detail=f"Token invalid: {err.__class__.__name__}", ) - except Exception as err: - logger.exception(err) - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, - detail="Cannot decode token or verify its signature", - ) - # Store the user in the database - user = await db.add_user( - sub, - user_info=userinfo, - oidc_provider=oidc_provider, - user_info_from_endpoint=user_info_from_endpoint, - access_token=token["access_token"], - ) - # Add the id_token to the session - request.session["token"] = token["id_token"] + assert isinstance(user, User) + # Add the provider session id to the session + request.session["sid"] = userinfo["sid"] # Add the token to the db because it is used for logout - await db.add_token(token, user) + assert oidc_provider.name is not None + oidc_provider_settings = oidc_providers_settings[oidc_provider.name] + await db.add_token(oidc_provider_settings, token) # Send the user to the home: (s)he is authenticated return RedirectResponse(url=request.url_for("home")) else: @@ -268,8 +252,14 @@ async def logout( ) return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") + oidc_provider_settings = oidc_providers_settings.get( + request.session.get("oidc_provider_id", "") + ) + assert oidc_provider_settings is not None try: - token = await db.get_token(request.session.pop("token", None)) + token = await db.get_token( + oidc_provider_settings, request.session.pop("sid", None) + ) except TokenNotInDb: logger.warn("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 4b1c064..fc0dba7 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -1,6 +1,6 @@ import logging from functools import cached_property -from typing import Self +from typing import Self, Any from pydantic import ( computed_field, @@ -11,7 +11,7 @@ from pydantic import ( from authlib.integrations.starlette_client.apps import StarletteOAuth2App from sqlmodel import SQLModel, Field -logger = logging.getLogger(__name__) +logger = logging.getLogger("oidc-test") class Role(SQLModel, extra="ignore"): @@ -36,19 +36,9 @@ class User(UserBase): ) userinfo: dict = {} access_token: str | None = None + access_token_decoded: dict[str, Any] | None = None oidc_provider: StarletteOAuth2App | None = None - @classmethod - def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self: - user = cls(**userinfo) - user.userinfo = userinfo - user.oidc_provider = oidc_provider - # Add roles if they are provided in the token - if raw_ra := userinfo.get("realm_access"): - if raw_roles := raw_ra.get("roles"): - user.roles = [Role(name=raw_role) for raw_role in raw_roles] - return user - @computed_field @cached_property def roles_as_set(self) -> set[str]: @@ -68,7 +58,7 @@ class User(UserBase): assert self.access_token is not None assert self.oidc_provider is not None assert self.oidc_provider.name is not None - from .auth_utils import oidc_providers_settings + from .settings import oidc_providers_settings return oidc_providers_settings[self.oidc_provider.name].decode( self.access_token diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 635a91b..0d90533 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -8,7 +8,7 @@ from starlette.status import HTTP_401_UNAUTHORIZED from .models import User -logger = logging.getLogger(__name__) +logger = logging.getLogger("oidc-test") async def get_resource(resource_id: str, user: User) -> dict: diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 4d08ada..2544bd7 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -1,8 +1,9 @@ from os import environ import string import random -from typing import Type, Tuple +from typing import Type, Tuple, Any from pathlib import Path +import logging from jwt import decode from pydantic import BaseModel, computed_field, AnyUrl @@ -16,6 +17,8 @@ from starlette.requests import Request from .models import User +logger = logging.getLogger("oidc-test") + class Resource(BaseModel): """A resource with an URL that can be accessed with an OAuth2 access token""" @@ -86,14 +89,27 @@ class OIDCProvider(BaseModel): -----END PUBLIC KEY----- """ - def decode(self, token: str) -> dict: + def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]: """Decode the token with signature check""" + decoded = decode( + token, + self.get_public_key(), + algorithms=[self.signature_alg], + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": False, + "verify_aud": False, + }, # not settings.insecure.skip_verify_signature}, + ) + logger.debug(str(decoded)) return decode( token, self.get_public_key(), algorithms=[self.signature_alg], - audience=["oidc-test", "oidc-test-web"], - options={"verify_signature": not settings.insecure.skip_verify_signature}, + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": verify_signature, + }, # not settings.insecure.skip_verify_signature}, ) @@ -156,3 +172,8 @@ class Settings(BaseSettings): settings = Settings() + + +oidc_providers_settings: dict[str, OIDCProvider] = dict( + [(provider.id, provider) for provider in settings.oidc.providers] +) From d39adf41eff5dac79bde268c01602f5f6072385d Mon Sep 17 00:00:00 2001 From: phil Date: Fri, 7 Feb 2025 13:57:17 +0100 Subject: [PATCH 34/79] Create a sub-app for resource server move all resources to resource server; use token bearer instead of session cookie for resources and use fetch instead of XMLHttpRequest for checking resource status; add UserWithRole class for fastapi depends (instead of has_role decorator); add asserts for typing QC; code formatting; comment out introspect endpoint processing --- src/oidc_test/auth_utils.py | 28 ++++++- src/oidc_test/database.py | 2 + src/oidc_test/main.py | 133 ++++-------------------------- src/oidc_test/models.py | 6 +- src/oidc_test/resource_server.py | 122 +++++++++++++++++++++++++-- src/oidc_test/static/utils.js | 32 ++++--- src/oidc_test/templates/base.html | 2 +- src/oidc_test/templates/home.html | 16 ++-- 8 files changed, 188 insertions(+), 153 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 281511d..1004527 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -13,7 +13,7 @@ from authlib.oauth2.auth import OAuth2Token from .models import User from .database import TokenNotInDb, db, UserNotInDB -from .settings import settings, OIDCProvider, oidc_providers_settings +from .settings import oidc_providers_settings logger = logging.getLogger("oidc-test") @@ -21,6 +21,8 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") async def fetch_token(name, request): + assert name is not None + assert request is not None logger.warn("TODO: fetch_token") ... # if name in oidc_providers: @@ -37,8 +39,10 @@ async def update_token(name, token, refresh_token=None, access_token=None): sid: str = oidc_provider_settings.decode(token["id_token"])["sid"] item = await db.get_token(oidc_provider_settings, sid) # update old token - item["access_token"] = token.get("access_token") - item["refresh_token"] = token.get("refresh_token") + if access_token is not None: + item["access_token"] = token.get("access_token") + if refresh_token is not None: + item["refresh_token"] = refresh_token item["expires_at"] = token["expires_at"] logger.info(f"Token {sid} refreshed") # It's a fake db and only in memory, so there's nothing to save @@ -119,6 +123,7 @@ async def get_current_user(request: Request) -> User: userinfo = await oidc_provider.fetch_access_token( refresh_token=token.get("refresh_token") ) + assert userinfo is not None except OAuthError as err: logger.exception(err) # raise HTTPException( @@ -242,3 +247,20 @@ async def get_user_from_token( access_token=token, ) return user + + +class UserWithRole: + roles: set[str] + + def __init__(self, roles: str | list[str] | tuple[str] | set[str]): + if isinstance(roles, str): + self.roles = set([roles]) + elif isinstance(roles, (list, tuple, set)): + self.roles = set(roles) + + def __call__(self, user: User = Depends(get_user_from_token)) -> User: + if not any(self.roles.intersection(user.roles_as_set)): + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Not of any required role {', '.join(self.roles)}" + ) + return user diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index 360ef11..d3bdd4e 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -69,6 +69,7 @@ class Database: async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None: """Store a token using as key the sid (auth provider's session id) in the id_token""" + assert isinstance(oidc_provider_settings, OIDCProvider) sid = token["userinfo"]["sid"] self.tokens[sid] = token @@ -77,6 +78,7 @@ class Database: oidc_provider_settings: OIDCProvider, sid: str | None, ) -> OAuth2Token: + assert isinstance(oidc_provider_settings, OIDCProvider) if sid is None: raise TokenNotInDb try: diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 60482bd..47d0c39 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -12,7 +12,7 @@ from contextlib import asynccontextmanager from httpx import HTTPError from fastapi import Depends, FastAPI, HTTPException, Request, status from fastapi.staticfiles import StaticFiles -from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse +from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates from fastapi.middleware.cors import CORSMiddleware from jwt import InvalidTokenError, PyJWTError @@ -31,17 +31,13 @@ from .models import User from .auth_utils import ( get_oidc_provider, get_oidc_provider_or_none, - hasrole, get_current_user_or_none, - get_current_user, - get_user_from_token, authlib_oauth, - get_token, get_providers_info, ) from .auth_misc import pretty_details from .database import TokenNotInDb, db -from .resource_server import get_resource +from .resource_server import resource_server logger = logging.getLogger("oidc-test") @@ -50,6 +46,7 @@ templates = Jinja2Templates(Path(__file__).parent / "templates") @asynccontextmanager async def lifespan(app: FastAPI): + assert app is not None await get_providers_info() yield @@ -64,24 +61,21 @@ app.add_middleware( allow_headers=["*"], ) -app.mount( - "/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static" -) - # SessionMiddleware is required by authlib app.add_middleware( SessionMiddleware, secret_key=settings.secret_key, ) +app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static") +app.mount("/resource", resource_server, name="resource_server") + @app.get("/") async def home( request: Request, user: Annotated[User, Depends(get_current_user_or_none)], - oidc_provider: Annotated[ - StarletteOAuth2App | None, Depends(get_oidc_provider_or_none) - ], + oidc_provider: Annotated[StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)], ) -> HTMLResponse: now = datetime.now() if oidc_provider and ( @@ -119,9 +113,7 @@ async def home( "oidc_provider_settings": oidc_provider_settings, "resources": resources, "user_info_details": ( - pretty_details(user, now) - if user and settings.oidc.show_session_details - else None + pretty_details(user, now) if user and settings.oidc.show_session_details else None ), }, ) @@ -215,24 +207,19 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: else: # Not sure if it's correct to redirect to plain login # if no userinfo is provided - return RedirectResponse( - url=request.url_for("login", oidc_provider_id=oidc_provider_id) - ) + return RedirectResponse(url=request.url_for("login", oidc_provider_id=oidc_provider_id)) @app.get("/account") async def account( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], ) -> RedirectResponse: if ( oidc_provider_settings := oidc_providers_settings.get( request.session.get("oidc_provider_id", "") ) ) is None: - raise HTTPException( - status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings" - ) + raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings") return RedirectResponse(f"{oidc_provider_settings.account_url_template}") @@ -244,12 +231,8 @@ async def logout( # Clear session request.session.pop("user_sub", None) # Get provider's endpoint - if ( - provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint") - ) is None: - logger.warn( - f"Cannot find end_session_endpoint for provider {oidc_provider.name}" - ) + if (provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")) is None: + logger.warn(f"Cannot find end_session_endpoint for provider {oidc_provider.name}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") oidc_provider_settings = oidc_providers_settings.get( @@ -257,9 +240,7 @@ async def logout( ) assert oidc_provider_settings is not None try: - token = await db.get_token( - oidc_provider_settings, request.session.pop("sid", None) - ) + token = await db.get_token(oidc_provider_settings, request.session.pop("sid", None)) except TokenNotInDb: logger.warn("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) @@ -292,90 +273,6 @@ async def non_compliant_logout( ) -# Route for OAuth resource server - - -@app.get("/resource/{id}") -async def get_resource_( - id: str, - # user: Annotated[User, Depends(get_current_user)], - # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], - # token: Annotated[OAuth2Token, Depends(get_token)], - user: Annotated[User, Depends(get_user_from_token)], -) -> JSONResponse: - """Generic path for testing a resource provided by a provider""" - return JSONResponse(await get_resource(id, user)) - - -# Routes for RBAC based tests - - -@app.get("/public") -async def public() -> HTMLResponse: - return HTMLResponse("

Not protected

") - - -@app.get("/protected") -async def get_protected( - user: Annotated[User, Depends(get_current_user)] -) -> HTMLResponse: - assert user is not None # Just to keep QA checks happy - return HTMLResponse("

Only authenticated users can see this

") - - -@app.get("/protected-by-foorole") -@hasrole("foorole") -async def get_protected_by_foorole(request: Request) -> HTMLResponse: - assert request is not None # Just to keep QA checks happy - return HTMLResponse("

Only users with foorole can see this

") - - -@app.get("/protected-by-barrole") -@hasrole("barrole") -async def get_protected_by_barrole(request: Request) -> HTMLResponse: - assert request is not None # Just to keep QA checks happy - return HTMLResponse("

Protected by barrole

") - - -@app.get("/protected-by-foorole-and-barrole") -@hasrole("barrole") -@hasrole("foorole") -async def get_protected_by_foorole_and_barrole(request: Request) -> HTMLResponse: - assert request is not None # Just to keep QA checks happy - return HTMLResponse("

Only users with foorole and barrole can see this

") - - -@app.get("/protected-by-foorole-or-barrole") -@hasrole(["foorole", "barrole"]) -async def get_protected_by_foorole_or_barrole(request: Request) -> HTMLResponse: - assert request is not None # Just to keep QA checks happy - return HTMLResponse("

Only users with foorole or barrole can see this

") - - -@app.get("/introspect") -async def get_introspect( - request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], - token: Annotated[OAuth2Token, Depends(get_token)], -) -> JSONResponse: - assert request is not None # Just to keep QA checks happy - if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="No intrispection endpoint found for the OIDC provider", - ) - if ( - response := await oidc_provider.post( - url, - token=token, - data={"token": token["access_token"]}, - ) - ).is_success: - return response.json() - else: - raise HTTPException(status_code=response.status_code, detail=response.text) - - # Snippet for running standalone # Mostly useful for the --version option, # as running with uvicorn is easy and provides better flexibility, eg. @@ -397,9 +294,7 @@ def main(): parser.add_argument( "-p", "--port", type=int, default=80, help="Port to listen to (default: 80)" ) - parser.add_argument( - "-v", "--version", action="store_true", help="Print version and exit" - ) + parser.add_argument("-v", "--version", action="store_true", help="Print version and exit") args = parser.parse_args() if args.version: diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index fc0dba7..9554bd5 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -1,6 +1,6 @@ import logging from functools import cached_property -from typing import Self, Any +from typing import Any from pydantic import ( computed_field, @@ -60,6 +60,4 @@ class User(UserBase): assert self.oidc_provider.name is not None from .settings import oidc_providers_settings - return oidc_providers_settings[self.oidc_provider.name].decode( - self.access_token - ) + return oidc_providers_settings[self.oidc_provider.name].decode(self.access_token) diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index 0d90533..d5e2aaa 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -1,15 +1,127 @@ from datetime import datetime +from typing import Annotated import logging from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError -from fastapi import HTTPException, status -from starlette.status import HTTP_401_UNAUTHORIZED +from fastapi import FastAPI, HTTPException, Depends, Request, status +from fastapi.responses import HTMLResponse, JSONResponse +from fastapi.middleware.cors import CORSMiddleware + +# from starlette.middleware.sessions import SessionMiddleware +# from authlib.integrations.starlette_client.apps import StarletteOAuth2App +# from authlib.oauth2.rfc6749 import OAuth2Token from .models import User +from .auth_utils import ( + get_user_from_token, + UserWithRole, + get_oidc_provider, + get_token, +) +from .settings import settings logger = logging.getLogger("oidc-test") +resource_server = FastAPI() + + +resource_server.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# SessionMiddleware is required by authlib +# resource_server.add_middleware( +# SessionMiddleware, +# secret_key=settings.secret_key, +# ) + +# Route for OAuth resource server + + +# Routes for RBAC based tests + + +@resource_server.get("/public") +async def public() -> HTMLResponse: + return HTMLResponse("

Not protected

") + + +@resource_server.get("/protected") +async def get_protected(user: Annotated[User, Depends(get_user_from_token)]) -> HTMLResponse: + assert user is not None # Just to keep QA checks happy + return HTMLResponse("

Only authenticated users can see this

") + + +@resource_server.get("/protected-by-foorole") +async def get_protected_by_foorole( + user: Annotated[User, Depends(UserWithRole("foorole"))] +) -> HTMLResponse: + return HTMLResponse("

Only users with foorole can see this

") + + +@resource_server.get("/protected-by-barrole") +async def get_protected_by_barrole( + user: Annotated[User, Depends(UserWithRole("barrole"))] +) -> HTMLResponse: + return HTMLResponse("

Protected by barrole

") + + +@resource_server.get("/protected-by-foorole-and-barrole") +async def get_protected_by_foorole_and_barrole( + user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], +) -> HTMLResponse: + assert user is not None # Just to keep QA checks happy + return HTMLResponse("

Only users with foorole and barrole can see this

") + + +@resource_server.get("/protected-by-foorole-or-barrole") +async def get_protected_by_foorole_or_barrole( + user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))] +) -> HTMLResponse: + assert user is not None # Just to keep QA checks happy + return HTMLResponse("

Only users with foorole or barrole can see this

") + + +# @resource_server.get("/introspect") +# async def get_introspect( +# request: Request, +# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], +# token: Annotated[OAuth2Token, Depends(get_token)], +# ) -> JSONResponse: +# assert request is not None # Just to keep QA checks happy +# if (url := oidc_provider.server_metadata.get("introspection_endpoint")) is None: +# raise HTTPException( +# status_code=status.HTTP_401_UNAUTHORIZED, +# detail="No introspection endpoint found for the OIDC provider", +# ) +# if ( +# response := await oidc_provider.post( +# url, +# token=token, +# data={"token": token["access_token"]}, +# ) +# ).is_success: +# return response.json() +# else: +# raise HTTPException(status_code=response.status_code, detail=response.text) + + +@resource_server.get("/{id}") +async def get_resource_( + id: str, + # user: Annotated[User, Depends(get_current_user)], + # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + # token: Annotated[OAuth2Token, Depends(get_token)], + user: Annotated[User, Depends(get_user_from_token)], +) -> JSONResponse: + """Generic path for testing a resource provided by a provider""" + return JSONResponse(await get_resource(id, user)) + async def get_resource(resource_id: str, user: User) -> dict: """ @@ -34,12 +146,10 @@ async def get_resource(resource_id: str, user: User) -> dict: raise HTTPException( status.HTTP_401_UNAUTHORIZED, f"No scope {required_scope} in the access token " - + "but it is required for accessing this resource.", + + "but it is required for accessing this resource", ) except ExpiredSignatureError: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, "The token's signature has expired" - ) + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token's signature has expired") except InvalidTokenError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "The token is invalid") return resp diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index a982267..e6c4bfc 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -1,20 +1,28 @@ -function checkHref(elem) { - var xmlHttp = new XMLHttpRequest() - xmlHttp.onreadystatechange = function () { - if (xmlHttp.readyState == 4) { - elem.classList.add("hasResponseStatus") - elem.classList.add("status-" + xmlHttp.status) - elem.title = "Response code: " + xmlHttp.status + " - " + xmlHttp.statusText - } +async function checkHref(elem, token, authProvider) { + const msg = document.getElementById("msg") + const resp = await fetch(elem.href, { + headers: new Headers({ + "Content-type": "application/json", + "Authorization": `Bearer ${token}`, + "auth_provider": authProvider, + }), + }).catch(err => { + msg.innerHTML = "Cannot fetch resource: " + err.message + resourceElem.innerHTML = "" + }) + if (resp === undefined) { + return + } else { + elem.classList.add("hasResponseStatus") + elem.classList.add("status-" + resp.status) + elem.title = "Response code: " + resp.status + " - " + resp.statusText } - xmlHttp.open("GET", elem.href, true) // true for asynchronous - xmlHttp.send(null) } -function checkPerms(className) { +function checkPerms(className, token, authProvider) { var rootElems = document.getElementsByClassName(className) Array.from(rootElems).forEach(elem => - Array.from(elem.children).forEach(elem => checkHref(elem)) + Array.from(elem.children).forEach(elem => checkHref(elem, token, authProvider)) ) } diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 3bdb3f3..0fe1a6b 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -4,7 +4,7 @@ - +

OIDC-test - FastAPI client

{% block content %} {% endblock %} diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index ce344cc..09c313f 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -80,14 +80,14 @@ These links should get different response codes depending on the authorization:

{% if resources %}

From 3eb6dc3dcf4be7350aa4cda6b3157183fe77d5d8 Mon Sep 17 00:00:00 2001 From: phil Date: Fri, 7 Feb 2025 16:09:49 +0100 Subject: [PATCH 35/79] Migrate all resources to json contents; improve token decoding & logging error messages --- src/oidc_test/auth_utils.py | 17 ++++++++---- src/oidc_test/resource_server.py | 43 +++++++++++++++-------------- src/oidc_test/settings.py | 46 ++++++++++++------------------- src/oidc_test/static/styles.css | 10 ++----- src/oidc_test/static/utils.js | 3 +- src/oidc_test/templates/home.html | 45 ++++++++++++++---------------- 6 files changed, 77 insertions(+), 87 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 1004527..3303e58 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -5,7 +5,7 @@ import logging from fastapi import HTTPException, Request, Depends, status from fastapi.security import OAuth2PasswordBearer from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App -from jwt import ExpiredSignatureError, InvalidKeyError +from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError from httpx import AsyncClient # from authlib.oauth1.auth import OAuthToken @@ -147,8 +147,8 @@ async def get_token(request: Request) -> OAuth2Token: oidc_provider_settings, request.session.get("sid"), ) - except (TokenNotInDb, InvalidKeyError): - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") + except (TokenNotInDb, InvalidKeyError, DecodeError) as err: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, err.__class__.__name__) async def get_current_user_or_none(request: Request) -> User | None: @@ -208,9 +208,14 @@ async def get_user_from_token( try: auth_provider_settings = oidc_providers_settings[auth_provider_id] except KeyError: - raise HTTPException( - status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" - ) + if auth_provider_id == "": + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth provider") + else: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" + ) + if token == "": + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") try: payload = auth_provider_settings.decode(token) except ExpiredSignatureError as err: diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index d5e2aaa..cb944ed 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -4,8 +4,7 @@ import logging from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError -from fastapi import FastAPI, HTTPException, Depends, Request, status -from fastapi.responses import HTMLResponse, JSONResponse +from fastapi import FastAPI, HTTPException, Depends, status from fastapi.middleware.cors import CORSMiddleware # from starlette.middleware.sessions import SessionMiddleware @@ -16,8 +15,8 @@ from .models import User from .auth_utils import ( get_user_from_token, UserWithRole, - get_oidc_provider, - get_token, + # get_oidc_provider, + # get_token, ) from .settings import settings @@ -47,44 +46,46 @@ resource_server.add_middleware( @resource_server.get("/public") -async def public() -> HTMLResponse: - return HTMLResponse("

Not protected

") +async def public() -> dict: + return {"msg": "Not protected"} @resource_server.get("/protected") -async def get_protected(user: Annotated[User, Depends(get_user_from_token)]) -> HTMLResponse: +async def get_protected(user: Annotated[User, Depends(get_user_from_token)]): assert user is not None # Just to keep QA checks happy - return HTMLResponse("

Only authenticated users can see this

") + return {"msg": "Only authenticated users can see this"} @resource_server.get("/protected-by-foorole") async def get_protected_by_foorole( - user: Annotated[User, Depends(UserWithRole("foorole"))] -) -> HTMLResponse: - return HTMLResponse("

Only users with foorole can see this

") + user: Annotated[User, Depends(UserWithRole("foorole"))], +): + assert user is not None + return {"msg": "Only users with foorole can see this"} @resource_server.get("/protected-by-barrole") async def get_protected_by_barrole( - user: Annotated[User, Depends(UserWithRole("barrole"))] -) -> HTMLResponse: - return HTMLResponse("

Protected by barrole

") + user: Annotated[User, Depends(UserWithRole("barrole"))], +): + assert user is not None + return {"msg": "Protected by barrole"} @resource_server.get("/protected-by-foorole-and-barrole") async def get_protected_by_foorole_and_barrole( user: Annotated[User, Depends(UserWithRole("foorole")), Depends(UserWithRole("barrole"))], -) -> HTMLResponse: +): assert user is not None # Just to keep QA checks happy - return HTMLResponse("

Only users with foorole and barrole can see this

") + return {"msg": "Only users with foorole and barrole can see this"} @resource_server.get("/protected-by-foorole-or-barrole") async def get_protected_by_foorole_or_barrole( - user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))] -) -> HTMLResponse: + user: Annotated[User, Depends(UserWithRole(["foorole", "barrole"]))], +): assert user is not None # Just to keep QA checks happy - return HTMLResponse("

Only users with foorole or barrole can see this

") + return {"msg": "Only users with foorole or barrole can see this"} # @resource_server.get("/introspect") @@ -118,9 +119,9 @@ async def get_resource_( # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], # token: Annotated[OAuth2Token, Depends(get_token)], user: Annotated[User, Depends(get_user_from_token)], -) -> JSONResponse: +): """Generic path for testing a resource provided by a provider""" - return JSONResponse(await get_resource(id, user)) + return await get_resource(id, user) async def get_resource(resource_id: str, user: User) -> dict: diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index 2544bd7..b601739 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -43,9 +43,7 @@ class OIDCProvider(BaseModel): info_url: str | None = ( None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key) ) - info: dict[str, str | int] | None = ( - None # Info fetched from info_url, eg. public key - ) + info: dict[str, str | int] | None = None # Info fetched from info_url, eg. public key public_key: str | None = None signature_alg: str = "RS256" resource_provider_scopes: list[str] = [] @@ -62,25 +60,17 @@ class OIDCProvider(BaseModel): def get_account_url(self, request: Request, user: User) -> str | None: if self.account_url_template: - if not ( - self.url.endswith("/") or self.account_url_template.startswith("/") - ): + if not (self.url.endswith("/") or self.account_url_template.startswith("/")): sep = "/" else: sep = "" - return ( - self.url - + sep - + self.account_url_template.format(request=request, user=user) - ) + return self.url + sep + self.account_url_template.format(request=request, user=user) else: return None def get_public_key(self) -> str: """Return the public key formatted for decoding token""" - public_key = self.public_key or ( - self.info is not None and self.info["public_key"] - ) + public_key = self.public_key or (self.info is not None and self.info["public_key"]) if public_key is None: raise AttributeError(f"Cannot get public key for {self.name}") return f""" @@ -91,17 +81,18 @@ class OIDCProvider(BaseModel): def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]: """Decode the token with signature check""" - decoded = decode( - token, - self.get_public_key(), - algorithms=[self.signature_alg], - audience=["account", "oidc-test", "oidc-test-web"], - options={ - "verify_signature": False, - "verify_aud": False, - }, # not settings.insecure.skip_verify_signature}, - ) - logger.debug(str(decoded)) + if settings.debug_token: + decoded = decode( + token, + self.get_public_key(), + algorithms=[self.signature_alg], + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": False, + "verify_aud": False, + }, # not settings.insecure.skip_verify_signature}, + ) + logger.debug(str(decoded)) return decode( token, self.get_public_key(), @@ -143,6 +134,7 @@ class Settings(BaseSettings): log: bool = False insecure: Insecure = Insecure() cors_origins: list[str] = [] + debug_token: bool = False @classmethod def settings_customise_sources( @@ -161,9 +153,7 @@ class Settings(BaseSettings): settings_cls, Path( Path( - environ.get( - "OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml" - ), + environ.get("OIDC_TEST_SETTINGS_FILE", Path.cwd() / "settings.yaml"), ) ), ), diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index 7e1260b..6262d79 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -171,11 +171,13 @@ hr { gap: 0.5em; flex-flow: wrap; } -.content .links-to-check a { +.content .links-to-check button { color: black; padding: 5px 10px; text-decoration: none; border-radius: 8px; + border: none; + cursor: pointer; } .token { @@ -183,12 +185,6 @@ hr { font-family: monospace; } -.actions { - display: flex; - justify-content: center; - gap: 0.5em; -} - .resourceResult { padding: 0.5em; gap: 0.5em; diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index e6c4bfc..6ea8da2 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -1,6 +1,7 @@ async function checkHref(elem, token, authProvider) { const msg = document.getElementById("msg") - const resp = await fetch(elem.href, { + const url = `resource/${elem.getAttribute("resource-id")}` + const resp = await fetch(url, { headers: new Headers({ "Content-type": "application/json", "Authorization": `Bearer ${token}`, diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 09c313f..92b7068 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -61,33 +61,30 @@
{% endif %}
- {% if user %} -

- Fetch resources from the resource server with your authentication token: -

-
- - -
-
-
-
-
-
- {% endif %}
-

- These links should get different response codes depending on the authorization: +

+ Resources validated by scope:

+

+ Resources validated by role: +

+ +
+
+
{% if resources %}

From ff72f0cae585858e400a9bc8f7d3fe1727035c44 Mon Sep 17 00:00:00 2001 From: phil Date: Sat, 8 Feb 2025 01:55:36 +0100 Subject: [PATCH 36/79] Display full token info --- src/oidc_test/auth_utils.py | 14 +++++++++-- src/oidc_test/main.py | 41 +++++++++++++++++++------------ src/oidc_test/settings.py | 1 + src/oidc_test/static/styles.css | 4 +-- src/oidc_test/templates/home.html | 41 ++++++++++++++++++++++--------- 5 files changed, 70 insertions(+), 31 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 3303e58..26f3779 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -1,3 +1,4 @@ +import re from typing import Union, Annotated from functools import wraps import logging @@ -133,9 +134,18 @@ async def get_current_user(request: Request) -> User: return user +async def get_token_or_none(request: Request) -> OAuth2Token | None: + """Return the auth token from the session or None. + Can be used in Depends()""" + try: + return await get_token(request) + except HTTPException: + return None + + async def get_token(request: Request) -> OAuth2Token: - """Return the token from a request object, from the session. - It can be used in Depends()""" + """Return the token from the session. + Can be used in Depends()""" try: oidc_provider_settings = oidc_providers_settings[ request.session.get("oidc_provider_id", "") diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 47d0c39..4a037eb 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -34,6 +34,7 @@ from .auth_utils import ( get_current_user_or_none, authlib_oauth, get_providers_info, + get_token_or_none, ) from .auth_misc import pretty_details from .database import TokenNotInDb, db @@ -76,6 +77,7 @@ async def home( request: Request, user: Annotated[User, Depends(get_current_user_or_none)], oidc_provider: Annotated[StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)], + token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], ) -> HTMLResponse: now = datetime.now() if oidc_provider and ( @@ -101,22 +103,29 @@ async def home( logger.info("Invalid token") logger.exception(err) - return templates.TemplateResponse( - name="home.html", - request=request, - context={ - "settings": settings.model_dump(), - "user": user, - "access_token_scope": access_token_scope, - "now": now, - "oidc_provider": oidc_provider, - "oidc_provider_settings": oidc_provider_settings, - "resources": resources, - "user_info_details": ( - pretty_details(user, now) if user and settings.oidc.show_session_details else None - ), - }, - ) + context = { + "settings": settings.model_dump(), + "user": user, + "access_token_scope": access_token_scope, + "now": now, + "oidc_provider": oidc_provider, + "oidc_provider_settings": oidc_provider_settings, + "resources": resources, + } + if token is None: + context["id_token_parsed"] = None + context["access_token_parsed"] = None + context["refresh_token_parsed"] = None + else: + assert oidc_provider is not None + assert oidc_provider.name is not None + oidc_provider_settings = oidc_providers_settings[oidc_provider.name] + context["id_token_parsed"] = pretty_details(user, now) + context["access_token_parsed"] = oidc_provider_settings.decode(token["access_token"]) + context["refresh_token_parsed"] = oidc_provider_settings.decode( + token["refresh_token"], verify_signature=False + ) + return templates.TemplateResponse(name="home.html", request=request, context=context) # Endpoints for the login / authorization process diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index b601739..e448c1e 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -135,6 +135,7 @@ class Settings(BaseSettings): insecure: Insecure = Insecure() cors_origins: list[str] = [] debug_token: bool = False + show_token: bool = False @classmethod def settings_customise_sources( diff --git a/src/oidc_test/static/styles.css b/src/oidc_test/static/styles.css index 6262d79..367ea99 100644 --- a/src/oidc_test/static/styles.css +++ b/src/oidc_test/static/styles.css @@ -73,7 +73,6 @@ hr { } .debug-auth p { border-bottom: 1px solid black; - text-align: left; } .debug-auth ul { padding: 0; @@ -185,8 +184,9 @@ hr { font-family: monospace; } -.resourceResult { +.resource { padding: 0.5em; + display: flex; gap: 0.5em; flex-direction: column; width: fit-content; diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 92b7068..9da5392 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -97,19 +97,38 @@

{% endif %} - {% if user_info_details %} -
-
-

User info

-
    - {% for key, value in user_info_details.items() %} -
  • - {{ key }}: {{ value }} -
  • + {% if settings.show_token and id_token_parsed %} +
    +
    +
    +

    id token

    +
    + {% for key, value in id_token_parsed.items() %} +
    +
    {{ key }}
    +
    {{ value }}
    +
    {% endfor %} -
+
+

access token

+
+ {% for key, value in access_token_parsed.items() %} +
+
{{ key }}
+
{{ value }}
+
+ {% endfor %} +
+

refresh token

+
+ {% for key, value in refresh_token_parsed.items() %} +
+
{{ key }}
+
{{ value }}
+
+ {% endfor %} +
-
Now is: {{ now.strftime("%T, %D") }}
{% endif %} {% endblock %} From 923a63f5d527e5a4128d193bb9ac7adef956cf84 Mon Sep 17 00:00:00 2001 From: phil Date: Sat, 8 Feb 2025 18:32:02 +0100 Subject: [PATCH 37/79] Add refresh token button --- src/oidc_test/auth_utils.py | 19 +++++++++++-------- src/oidc_test/main.py | 30 +++++++++++++++++++++++++++--- src/oidc_test/models.py | 9 +++++++-- src/oidc_test/static/utils.js | 1 + src/oidc_test/templates/base.html | 2 +- src/oidc_test/templates/home.html | 21 +++++++++++---------- 6 files changed, 58 insertions(+), 24 deletions(-) diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index 26f3779..cab14b2 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -13,7 +13,7 @@ from httpx import AsyncClient from authlib.oauth2.auth import OAuth2Token from .models import User -from .database import TokenNotInDb, db, UserNotInDB +from .database import db, TokenNotInDb, UserNotInDB from .settings import oidc_providers_settings logger = logging.getLogger("oidc-test") @@ -36,14 +36,14 @@ async def fetch_token(name, request): async def update_token(name, token, refresh_token=None, access_token=None): + """Update the token in the database""" oidc_provider_settings = oidc_providers_settings[name] sid: str = oidc_provider_settings.decode(token["id_token"])["sid"] item = await db.get_token(oidc_provider_settings, sid) # update old token - if access_token is not None: - item["access_token"] = token.get("access_token") - if refresh_token is not None: - item["refresh_token"] = refresh_token + item["access_token"] = token["access_token"] + item["refresh_token"] = token["refresh_token"] + item["id_token"] = token["id_token"] item["expires_at"] = token["expires_at"] logger.info(f"Token {sid} refreshed") # It's a fake db and only in memory, so there's nothing to save @@ -70,8 +70,8 @@ def init_providers(): api_base_url=provider.url, # For PKCE (not implemented yet): # code_challenge_method="S256", - # fetch_token=fetch_token, - # update_token=update_token, + fetch_token=fetch_token, + update_token=update_token, # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) ) @@ -101,7 +101,10 @@ def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: def get_oidc_provider(request: Request) -> StarletteOAuth2App: if (oidc_provider := get_oidc_provider_or_none(request)) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") + if oidc_provider is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No provider") + else: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") else: return oidc_provider diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 4a037eb..81b354f 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -35,6 +35,8 @@ from .auth_utils import ( authlib_oauth, get_providers_info, get_token_or_none, + get_token, + update_token, ) from .auth_misc import pretty_details from .database import TokenNotInDb, db @@ -97,7 +99,7 @@ async def home( access_token_scope = None else: try: - access_token_scope = user.decode_access_token()["scope"] + access_token_scope = user.get_scope(verify_signature=False) except InvalidTokenError as err: access_token_scope = None logger.info("Invalid token") @@ -113,15 +115,22 @@ async def home( "resources": resources, } if token is None: + context["access_token"] = None context["id_token_parsed"] = None context["access_token_parsed"] = None context["refresh_token_parsed"] = None else: + context["access_token"] = token["access_token"] assert oidc_provider is not None assert oidc_provider.name is not None oidc_provider_settings = oidc_providers_settings[oidc_provider.name] - context["id_token_parsed"] = pretty_details(user, now) - context["access_token_parsed"] = oidc_provider_settings.decode(token["access_token"]) + # context["id_token_parsed"] = pretty_details(user, now) + context["id_token_parsed"] = oidc_provider_settings.decode( + token["id_token"], verify_signature=False + ) + context["access_token_parsed"] = oidc_provider_settings.decode( + token["access_token"], verify_signature=False + ) context["refresh_token_parsed"] = oidc_provider_settings.decode( token["refresh_token"], verify_signature=False ) @@ -282,6 +291,21 @@ async def non_compliant_logout( ) +@app.get("/refresh") +async def refresh( + request: Request, + oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + token: Annotated[OAuth2Token, Depends(get_token)], +) -> RedirectResponse: + """Manually refresh token""" + new_token = await oidc_provider.fetch_access_token( + refresh_token=token["refresh_token"], + grant_type="refresh_token", + ) + await update_token(oidc_provider.name, new_token) + return RedirectResponse(url=request.url_for("home")) + + # Snippet for running standalone # Mostly useful for the --version option, # as running with uvicorn is easy and provides better flexibility, eg. diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 9554bd5..8aee2e6 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -54,10 +54,15 @@ class User(UserBase): access_token_scopes = [] return scope in set(info_scopes + access_token_scopes) - def decode_access_token(self): + def decode_access_token(self, verify_signature: bool = True): assert self.access_token is not None assert self.oidc_provider is not None assert self.oidc_provider.name is not None from .settings import oidc_providers_settings - return oidc_providers_settings[self.oidc_provider.name].decode(self.access_token) + return oidc_providers_settings[self.oidc_provider.name].decode( + self.access_token, verify_signature=verify_signature + ) + + def get_scope(self, verify_signature: bool = True): + return self.decode_access_token(verify_signature=verify_signature)["scope"] diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js index 6ea8da2..8e8ad59 100644 --- a/src/oidc_test/static/utils.js +++ b/src/oidc_test/static/utils.js @@ -2,6 +2,7 @@ async function checkHref(elem, token, authProvider) { const msg = document.getElementById("msg") const url = `resource/${elem.getAttribute("resource-id")}` const resp = await fetch(url, { + method: "GET", headers: new Headers({ "Content-type": "application/json", "Authorization": `Bearer ${token}`, diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 0fe1a6b..2ce758c 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -4,7 +4,7 @@ - +

OIDC-test - FastAPI client

{% block content %} {% endblock %} diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 9da5392..08bcf43 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -57,6 +57,7 @@ Account management {% endif %} + {% endif %} @@ -66,21 +67,21 @@ Resources validated by scope:

Resources validated by role:

From 38b983c2a51ff1866e3306896a9e5e960bbe984b Mon Sep 17 00:00:00 2001 From: phil Date: Sat, 8 Feb 2025 19:05:13 +0100 Subject: [PATCH 38/79] Fix scope --- src/oidc_test/main.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 81b354f..03d13d7 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -95,20 +95,9 @@ async def home( resources = [] oidc_provider_settings = None - if user is None: - access_token_scope = None - else: - try: - access_token_scope = user.get_scope(verify_signature=False) - except InvalidTokenError as err: - access_token_scope = None - logger.info("Invalid token") - logger.exception(err) - context = { "settings": settings.model_dump(), "user": user, - "access_token_scope": access_token_scope, "now": now, "oidc_provider": oidc_provider, "oidc_provider_settings": oidc_provider_settings, @@ -124,13 +113,15 @@ async def home( assert oidc_provider is not None assert oidc_provider.name is not None oidc_provider_settings = oidc_providers_settings[oidc_provider.name] + access_token_parsed = oidc_provider_settings.decode( + token["access_token"], verify_signature=False + ) + context["access_token_scope"] = access_token_parsed["scope"] # context["id_token_parsed"] = pretty_details(user, now) context["id_token_parsed"] = oidc_provider_settings.decode( token["id_token"], verify_signature=False ) - context["access_token_parsed"] = oidc_provider_settings.decode( - token["access_token"], verify_signature=False - ) + context["access_token_parsed"] = access_token_parsed context["refresh_token_parsed"] = oidc_provider_settings.decode( token["refresh_token"], verify_signature=False ) From c5bb4f4319445ba145ab08067593a6957ce37f42 Mon Sep 17 00:00:00 2001 From: phil Date: Sun, 9 Feb 2025 06:20:48 +0100 Subject: [PATCH 39/79] Refactor most code, isolate authlib somehow --- src/oidc_test/auth_misc.py | 29 ------- src/oidc_test/auth_provider.py | 43 ++++++++++ src/oidc_test/auth_utils.py | 80 +++++++++++-------- src/oidc_test/database.py | 27 +++---- src/oidc_test/main.py | 128 ++++++++++++------------------ src/oidc_test/models.py | 10 +-- src/oidc_test/resource_server.py | 9 ++- src/oidc_test/settings.py | 45 ++--------- src/oidc_test/templates/base.html | 2 +- src/oidc_test/templates/home.html | 28 +++---- 10 files changed, 183 insertions(+), 218 deletions(-) delete mode 100644 src/oidc_test/auth_misc.py create mode 100644 src/oidc_test/auth_provider.py diff --git a/src/oidc_test/auth_misc.py b/src/oidc_test/auth_misc.py deleted file mode 100644 index a4e9ea3..0000000 --- a/src/oidc_test/auth_misc.py +++ /dev/null @@ -1,29 +0,0 @@ -from datetime import datetime, timedelta -from collections import OrderedDict - -from .models import User - -time_keys = set(("iat", "exp", "auth_time", "updated_at")) - - -def pretty_details(user: User, now: datetime) -> OrderedDict: - details = OrderedDict() - # breakpoint() - for key in sorted(time_keys): - try: - dt = datetime.fromtimestamp(user.userinfo[key]) - except (KeyError, TypeError): - pass - else: - td = now - dt - td = timedelta(days=td.days, seconds=td.seconds) - if td.days < 0: - ptd = f"in {-td} h:m:s" - else: - ptd = f"{td} h:m:s ago" - details[key] = f"{user.userinfo[key]} - {dt} ({ptd})" - for key in sorted(user.userinfo): - if key in time_keys: - continue - details[key] = user.userinfo[key] - return details diff --git a/src/oidc_test/auth_provider.py b/src/oidc_test/auth_provider.py new file mode 100644 index 0000000..bed4596 --- /dev/null +++ b/src/oidc_test/auth_provider.py @@ -0,0 +1,43 @@ +from typing import Any +from jwt import decode +import logging + +from authlib.integrations.starlette_client.apps import StarletteOAuth2App + +from .settings import AuthProviderSettings, settings + +logger = logging.getLogger("oidc-test") + + +class Provider(AuthProviderSettings): + class Config: + arbitrary_types_allowed = True + + authlib_client: StarletteOAuth2App = StarletteOAuth2App(None) + + def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]: + """Decode the token with signature check""" + if settings.debug_token: + decoded = decode( + token, + self.get_public_key(), + algorithms=[self.signature_alg], + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": False, + "verify_aud": False, + }, # not settings.insecure.skip_verify_signature}, + ) + logger.debug(str(decoded)) + return decode( + token, + self.get_public_key(), + algorithms=[self.signature_alg], + audience=["account", "oidc-test", "oidc-test-web"], + options={ + "verify_signature": verify_signature, + }, # not settings.insecure.skip_verify_signature}, + ) + + +providers: dict[str, Provider] = {} diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py index cab14b2..e62fe39 100644 --- a/src/oidc_test/auth_utils.py +++ b/src/oidc_test/auth_utils.py @@ -1,4 +1,3 @@ -import re from typing import Union, Annotated from functools import wraps import logging @@ -14,7 +13,8 @@ from authlib.oauth2.auth import OAuth2Token from .models import User from .database import db, TokenNotInDb, UserNotInDB -from .settings import oidc_providers_settings +from .settings import settings +from .auth_provider import providers, Provider logger = logging.getLogger("oidc-test") @@ -35,11 +35,13 @@ async def fetch_token(name, request): # return token.to_token() -async def update_token(name, token, refresh_token=None, access_token=None): +async def update_token( + provider_id, token, refresh_token: str | None = None, access_token: str | None = None +): """Update the token in the database""" - oidc_provider_settings = oidc_providers_settings[name] - sid: str = oidc_provider_settings.decode(token["id_token"])["sid"] - item = await db.get_token(oidc_provider_settings, sid) + provider = providers[provider_id] + sid: str = provider.decode(token["id_token"])["sid"] + item = await db.get_token(provider, sid) # update old token item["access_token"] = token["access_token"] item["refresh_token"] = token["refresh_token"] @@ -54,10 +56,12 @@ authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_t def init_providers(): - # Add oidc providers to authlib from the settings - for id, provider in oidc_providers_settings.items(): + """Add oidc providers to authlib from the settings + and build the providers dict""" + for provider_settings in settings.auth.providers: + provider = Provider(**provider_settings.model_dump()) authlib_oauth.register( - name=id, + name=provider.id, server_metadata_url=provider.openid_configuration, client_kwargs={ "scope": " ".join( @@ -74,6 +78,8 @@ def init_providers(): update_token=update_token, # client_id="some-client-id", # if enabled, authlib will also check that the access token belongs to this client id (audience) ) + provider.authlib_client = getattr(authlib_oauth, provider.id) + providers[provider.id] = provider init_providers() @@ -82,33 +88,41 @@ init_providers() async def get_providers_info(): # Get the public key: async with AsyncClient() as client: - for provider_settings in oidc_providers_settings.values(): - if provider_settings.info_url: - provider_info = await client.get(provider_settings.url) - provider_settings.info = provider_info.json() + for provider in providers.values(): + if provider.info_url: + provider_info = await client.get(provider.url) + provider.info = provider_info.json() -def get_oidc_provider_or_none(request: Request) -> StarletteOAuth2App | None: +def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None: """Return the oidc_provider from a request object, from the session. It can be used in Depends()""" - if (oidc_provider_id := request.session.get("oidc_provider_id")) is None: - return - try: - return getattr(authlib_oauth, str(oidc_provider_id)) - except AttributeError: + if (auth_provider_id := request.session.get("auth_provider_id")) is None: return + return getattr(authlib_oauth, str(auth_provider_id), None) -def get_oidc_provider(request: Request) -> StarletteOAuth2App: - if (oidc_provider := get_oidc_provider_or_none(request)) is None: - if oidc_provider is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No provider") - else: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") +def get_auth_provider_client(request: Request) -> StarletteOAuth2App: + if (oidc_provider := get_auth_provider_client_or_none(request)) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") else: return oidc_provider +def get_auth_provider_or_none(request: Request) -> Provider | None: + """Return the oidc_provider settings from a request object, from the session. + It can be used in Depends()""" + if (auth_provider_id := request.session.get("auth_provider_id")) is None: + return + return providers.get(auth_provider_id) + + +def get_auth_provider(request: Request) -> Provider: + if (provider := get_auth_provider_or_none(request)) is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") + return provider + + async def get_current_user(request: Request) -> User: """Get the current user from a request object. Also validates the token expiration time. @@ -120,11 +134,11 @@ async def get_current_user(request: Request) -> User: user = await db.get_user(user_sub) ## Check if the token is expired if token.is_expired(): - oidc_provider = get_oidc_provider(request=request) + provider = get_auth_provider(request=request) ## Ask a new refresh token from the provider logger.info(f"Token expired for user {user.name}") try: - userinfo = await oidc_provider.fetch_access_token( + userinfo = await provider.authlib_client.fetch_access_token( refresh_token=token.get("refresh_token") ) assert userinfo is not None @@ -150,14 +164,12 @@ async def get_token(request: Request) -> OAuth2Token: """Return the token from the session. Can be used in Depends()""" try: - oidc_provider_settings = oidc_providers_settings[ - request.session.get("oidc_provider_id", "") - ] + provider = providers[request.session.get("auth_provider_id", "")] except KeyError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider") try: return await db.get_token( - oidc_provider_settings, + provider, request.session.get("sid"), ) except (TokenNotInDb, InvalidKeyError, DecodeError) as err: @@ -219,7 +231,7 @@ async def get_user_from_token( "Request headers must have a 'auth_provider' field", ) try: - auth_provider_settings = oidc_providers_settings[auth_provider_id] + provider = providers[auth_provider_id] except KeyError: if auth_provider_id == "": raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No auth provider") @@ -230,7 +242,7 @@ async def get_user_from_token( if token == "": raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") try: - payload = auth_provider_settings.decode(token) + payload = provider.decode(token) except ExpiredSignatureError as err: logger.info(f"Expired signature: {err}") raise HTTPException( @@ -261,7 +273,7 @@ async def get_user_from_token( user = await db.add_user( sub=payload["sub"], user_info=payload, - oidc_provider=getattr(authlib_oauth, auth_provider_id), + auth_provider=getattr(authlib_oauth, auth_provider_id), access_token=token, ) return user diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py index d3bdd4e..3493429 100644 --- a/src/oidc_test/database.py +++ b/src/oidc_test/database.py @@ -2,11 +2,10 @@ import logging -from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.oauth2.rfc6749 import OAuth2Token -from .settings import OIDCProvider, oidc_providers_settings from .models import User, Role +from .auth_provider import Provider, providers logger = logging.getLogger("oidc-test") @@ -29,23 +28,23 @@ class Database: self, sub: str, user_info: dict, - oidc_provider: StarletteOAuth2App, + auth_provider: Provider, access_token: str, access_token_decoded: dict | None = None, ) -> User: if access_token_decoded is None: - assert oidc_provider.name is not None - oidc_provider_settings = oidc_providers_settings[oidc_provider.name] - access_token_decoded = oidc_provider_settings.decode(access_token) + assert auth_provider.name is not None + provider = providers[auth_provider.id] + access_token_decoded = provider.decode(access_token) + user_info["auth_provider_id"] = auth_provider.id user = User(**user_info) user.userinfo = user_info - user.oidc_provider = oidc_provider - user.access_token = access_token - user.access_token_decoded = access_token_decoded + # user.access_token = access_token + # user.access_token_decoded = access_token_decoded # Add roles provided in the access token roles = set() try: - r = access_token_decoded["resource_access"][oidc_provider.client_id]["roles"] + r = access_token_decoded["resource_access"][auth_provider.client_id]["roles"] roles.update(r) except KeyError: pass @@ -66,19 +65,19 @@ class Database: raise UserNotInDB return self.users[sub] - async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None: + async def add_token(self, provider: Provider, token: OAuth2Token) -> None: """Store a token using as key the sid (auth provider's session id) in the id_token""" - assert isinstance(oidc_provider_settings, OIDCProvider) + assert isinstance(provider, Provider) sid = token["userinfo"]["sid"] self.tokens[sid] = token async def get_token( self, - oidc_provider_settings: OIDCProvider, + provider: Provider, sid: str | None, ) -> OAuth2Token: - assert isinstance(oidc_provider_settings, OIDCProvider) + assert isinstance(provider, Provider) if sid is None: raise TokenNotInDb try: diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py index 03d13d7..304df92 100644 --- a/src/oidc_test/main.py +++ b/src/oidc_test/main.py @@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.templating import Jinja2Templates from fastapi.middleware.cors import CORSMiddleware -from jwt import InvalidTokenError, PyJWTError +from jwt import PyJWTError from starlette.middleware.sessions import SessionMiddleware from authlib.integrations.starlette_client.apps import StarletteOAuth2App from authlib.integrations.base_client import OAuthError @@ -26,11 +26,12 @@ from authlib.oauth2.rfc6749 import OAuth2Token # from fastapi.security import OpenIdConnect # from pkce import generate_code_verifier, generate_pkce_pair -from .settings import settings, oidc_providers_settings +from .settings import settings +from .auth_provider import Provider, providers from .models import User from .auth_utils import ( - get_oidc_provider, - get_oidc_provider_or_none, + get_auth_provider, + get_auth_provider_or_none, get_current_user_or_none, authlib_oauth, get_providers_info, @@ -38,7 +39,6 @@ from .auth_utils import ( get_token, update_token, ) -from .auth_misc import pretty_details from .database import TokenNotInDb, db from .resource_server import resource_server @@ -78,51 +78,30 @@ app.mount("/resource", resource_server, name="resource_server") async def home( request: Request, user: Annotated[User, Depends(get_current_user_or_none)], - oidc_provider: Annotated[StarletteOAuth2App | None, Depends(get_oidc_provider_or_none)], + provider: Annotated[Provider | None, Depends(get_auth_provider_or_none)], token: Annotated[OAuth2Token | None, Depends(get_token_or_none)], ) -> HTMLResponse: - now = datetime.now() - if oidc_provider and ( - ( - oidc_provider_settings := oidc_providers_settings.get( - request.session.get("oidc_provider_id", "") - ) - ) - is not None - ): - resources = oidc_provider_settings.resources - else: - resources = [] - oidc_provider_settings = None - context = { "settings": settings.model_dump(), "user": user, - "now": now, - "oidc_provider": oidc_provider, - "oidc_provider_settings": oidc_provider_settings, - "resources": resources, + "now": datetime.now(), + "auth_provider": provider, } - if token is None: + if provider is None or token is None: context["access_token"] = None context["id_token_parsed"] = None context["access_token_parsed"] = None context["refresh_token_parsed"] = None + context["resources"] = None else: context["access_token"] = token["access_token"] - assert oidc_provider is not None - assert oidc_provider.name is not None - oidc_provider_settings = oidc_providers_settings[oidc_provider.name] - access_token_parsed = oidc_provider_settings.decode( - token["access_token"], verify_signature=False - ) + context["resources"] = provider.resources + access_token_parsed = provider.decode(token["access_token"], verify_signature=False) context["access_token_scope"] = access_token_parsed["scope"] # context["id_token_parsed"] = pretty_details(user, now) - context["id_token_parsed"] = oidc_provider_settings.decode( - token["id_token"], verify_signature=False - ) + context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False) context["access_token_parsed"] = access_token_parsed - context["refresh_token_parsed"] = oidc_provider_settings.decode( + context["refresh_token_parsed"] = provider.decode( token["refresh_token"], verify_signature=False ) return templates.TemplateResponse(name="home.html", request=request, context=context) @@ -131,20 +110,20 @@ async def home( # Endpoints for the login / authorization process -@app.get("/login/{oidc_provider_id}") -async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: +@app.get("/login/{auth_provider_id}") +async def login(request: Request, auth_provider_id: str) -> RedirectResponse: """Login with the provider id, giving the browser a redirect to its authorize page. - The provider is expected to send the browser back to our own /auth/{oidc_provider_id} url + The provider is expected to send the browser back to our own /auth/{auth_provider_id} url with the token. """ - redirect_uri = request.url_for("auth", oidc_provider_id=oidc_provider_id) + redirect_uri = request.url_for("auth", auth_provider_id=auth_provider_id) try: - provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) + provider: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") # if ( - # code_challenge_method := oidc_providers_settings[ - # oidc_provider_id + # code_challenge_method := providers[ + # auth_provider_id # ].code_challenge_method # ) is not None: # #client = AsyncOAuth2Client(..., code_challenge_method=code_challenge_method) @@ -164,30 +143,30 @@ async def login(request: Request, oidc_provider_id: str) -> RedirectResponse: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot reach provider") -@app.get("/auth/{oidc_provider_id}") -async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: +@app.get("/auth/{auth_provider_id}") +async def auth(request: Request, auth_provider_id: str) -> RedirectResponse: """Decrypt the auth token, store it to the session (cookie based) and response to the browser with a redirect to a "welcome user" page. """ try: - oidc_provider: StarletteOAuth2App = getattr(authlib_oauth, oidc_provider_id) + authlib_client: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id) except AttributeError: raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider") try: - token: OAuth2Token = await oidc_provider.authorize_access_token(request) + token: OAuth2Token = await authlib_client.authorize_access_token(request) except OAuthError as error: raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error) - # Remember the oidc_provider in the session + # Remember the authlib_client in the session # logger.info(f"Scope: {token['scope']}") - request.session["oidc_provider_id"] = oidc_provider_id + request.session["auth_provider_id"] = auth_provider_id # # One could process the full decoded token which contains extra information # eg for updates. Here we are only interested in roles # if userinfo := token.get("userinfo"): - # Remember the oidc_provider in the session - request.session["oidc_provider_id"] = oidc_provider_id - # User id (sub) given by oidc provider + # Remember the authlib_client in the session + request.session["auth_provider_id"] = auth_provider_id + # User id (sub) given by auth provider sub = userinfo["sub"] # Build and remember the user in the session request.session["user_sub"] = sub @@ -196,7 +175,7 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: user = await db.add_user( sub, user_info=userinfo, - oidc_provider=oidc_provider, + auth_provider=providers[auth_provider_id], access_token=token["access_token"], ) except PyJWTError as err: @@ -208,48 +187,41 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse: # Add the provider session id to the session request.session["sid"] = userinfo["sid"] # Add the token to the db because it is used for logout - assert oidc_provider.name is not None - oidc_provider_settings = oidc_providers_settings[oidc_provider.name] - await db.add_token(oidc_provider_settings, token) + provider = providers[auth_provider_id] + await db.add_token(provider, token) # Send the user to the home: (s)he is authenticated return RedirectResponse(url=request.url_for("home")) else: # Not sure if it's correct to redirect to plain login # if no userinfo is provided - return RedirectResponse(url=request.url_for("login", oidc_provider_id=oidc_provider_id)) + return RedirectResponse(url=request.url_for("login", auth_provider_id=auth_provider_id)) @app.get("/account") async def account( - request: Request, + provider: Annotated[Provider, Depends(get_auth_provider)], ) -> RedirectResponse: - if ( - oidc_provider_settings := oidc_providers_settings.get( - request.session.get("oidc_provider_id", "") - ) - ) is None: - raise HTTPException(status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings") - return RedirectResponse(f"{oidc_provider_settings.account_url_template}") + """Redirect to the auth provider account management, + if account_url_template is in the provider's settings""" + return RedirectResponse(f"{provider.account_url_template}") @app.get("/logout") async def logout( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + provider: Annotated[Provider, Depends(get_auth_provider)], ) -> RedirectResponse: # Clear session request.session.pop("user_sub", None) # Get provider's endpoint - if (provider_logout_uri := oidc_provider.server_metadata.get("end_session_endpoint")) is None: - logger.warn(f"Cannot find end_session_endpoint for provider {oidc_provider.name}") + if ( + provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint") + ) is None: + logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}") return RedirectResponse(request.url_for("non_compliant_logout")) post_logout_uri = request.url_for("home") - oidc_provider_settings = oidc_providers_settings.get( - request.session.get("oidc_provider_id", "") - ) - assert oidc_provider_settings is not None try: - token = await db.get_token(oidc_provider_settings, request.session.pop("sid", None)) + token = await db.get_token(provider, request.session.pop("sid", None)) except TokenNotInDb: logger.warn("No session in db for the token or no token") return RedirectResponse(request.url_for("home")) @@ -270,30 +242,30 @@ async def logout( @app.get("/non-compliant-logout") async def non_compliant_logout( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + provider: Annotated[StarletteOAuth2App, Depends(get_auth_provider)], ): """A page for non-compliant OAuth2 servers that we cannot log out.""" # Clear the remain of the session - request.session.pop("oidc_provider_id", None) + request.session.pop("auth_provider_id", None) return templates.TemplateResponse( name="non_compliant_logout.html", request=request, - context={"oidc_provider": oidc_provider, "home_url": request.url_for("home")}, + context={"oidc_provider": provider, "home_url": request.url_for("home")}, ) @app.get("/refresh") async def refresh( request: Request, - oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], + provider: Annotated[Provider, Depends(get_auth_provider)], token: Annotated[OAuth2Token, Depends(get_token)], ) -> RedirectResponse: """Manually refresh token""" - new_token = await oidc_provider.fetch_access_token( + new_token = await provider.authlib_client.fetch_access_token( refresh_token=token["refresh_token"], grant_type="refresh_token", ) - await update_token(oidc_provider.name, new_token) + await update_token(provider.id, new_token) return RedirectResponse(url=request.url_for("home")) diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py index 8aee2e6..eda63a6 100644 --- a/src/oidc_test/models.py +++ b/src/oidc_test/models.py @@ -8,7 +8,6 @@ from pydantic import ( EmailStr, ConfigDict, ) -from authlib.integrations.starlette_client.apps import StarletteOAuth2App from sqlmodel import SQLModel, Field logger = logging.getLogger("oidc-test") @@ -37,7 +36,7 @@ class User(UserBase): userinfo: dict = {} access_token: str | None = None access_token_decoded: dict[str, Any] | None = None - oidc_provider: StarletteOAuth2App | None = None + auth_provider_id: str @computed_field @cached_property @@ -56,11 +55,10 @@ class User(UserBase): def decode_access_token(self, verify_signature: bool = True): assert self.access_token is not None - assert self.oidc_provider is not None - assert self.oidc_provider.name is not None - from .settings import oidc_providers_settings + assert self.auth_provider_id is not None + from .auth_provider import providers - return oidc_providers_settings[self.oidc_provider.name].decode( + return providers[self.auth_provider_id].decode( self.access_token, verify_signature=verify_signature ) diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py index cb944ed..e5670ed 100644 --- a/src/oidc_test/resource_server.py +++ b/src/oidc_test/resource_server.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import Annotated import logging +from authlib.jose import Key from httpx import AsyncClient from jwt.exceptions import ExpiredSignatureError, InvalidTokenError from fastapi import FastAPI, HTTPException, Depends, status @@ -15,10 +16,9 @@ from .models import User from .auth_utils import ( get_user_from_token, UserWithRole, - # get_oidc_provider, - # get_token, ) from .settings import settings +from .auth_provider import providers logger = logging.getLogger("oidc-test") @@ -128,7 +128,10 @@ async def get_resource(resource_id: str, user: User) -> dict: """ Resource processing: build an informative rely as a simple showcase """ - pname = getattr(user.oidc_provider, "name", "?") + try: + pname = providers[user.auth_provider_id].name + except KeyError: + pname = "?" resp = { "hello": f"Hi {user.name} from an OAuth resource provider", "comment": f"I received a request for '{resource_id}' " diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py index e448c1e..9a789a0 100644 --- a/src/oidc_test/settings.py +++ b/src/oidc_test/settings.py @@ -1,11 +1,9 @@ from os import environ import string import random -from typing import Type, Tuple, Any +from typing import Type, Tuple from pathlib import Path -import logging -from jwt import decode from pydantic import BaseModel, computed_field, AnyUrl from pydantic_settings import ( BaseSettings, @@ -17,8 +15,6 @@ from starlette.requests import Request from .models import User -logger = logging.getLogger("oidc-test") - class Resource(BaseModel): """A resource with an URL that can be accessed with an OAuth2 access token""" @@ -27,8 +23,8 @@ class Resource(BaseModel): name: str -class OIDCProvider(BaseModel): - """OIDC provider, can also be a resource server""" +class AuthProviderSettings(BaseModel): + """Auth provider, can also be a resource server""" id: str name: str @@ -79,30 +75,6 @@ class OIDCProvider(BaseModel): -----END PUBLIC KEY----- """ - def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]: - """Decode the token with signature check""" - if settings.debug_token: - decoded = decode( - token, - self.get_public_key(), - algorithms=[self.signature_alg], - audience=["account", "oidc-test", "oidc-test-web"], - options={ - "verify_signature": False, - "verify_aud": False, - }, # not settings.insecure.skip_verify_signature}, - ) - logger.debug(str(decoded)) - return decode( - token, - self.get_public_key(), - algorithms=[self.signature_alg], - audience=["account", "oidc-test", "oidc-test-web"], - options={ - "verify_signature": verify_signature, - }, # not settings.insecure.skip_verify_signature}, - ) - class ResourceProvider(BaseModel): id: str @@ -111,9 +83,9 @@ class ResourceProvider(BaseModel): resources: list[Resource] = [] -class OIDCSettings(BaseModel): +class AuthSettings(BaseModel): show_session_details: bool = False - providers: list[OIDCProvider] = [] + providers: list[AuthProviderSettings] = [] swagger_provider: str = "" @@ -128,7 +100,7 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_nested_delimiter="__") - oidc: OIDCSettings = OIDCSettings() + auth: AuthSettings = AuthSettings() resource_providers: list[ResourceProvider] = [] secret_key: str = "".join(random.choice(string.ascii_letters) for _ in range(16)) log: bool = False @@ -163,8 +135,3 @@ class Settings(BaseSettings): settings = Settings() - - -oidc_providers_settings: dict[str, OIDCProvider] = dict( - [(provider.id, provider) for provider in settings.oidc.providers] -) diff --git a/src/oidc_test/templates/base.html b/src/oidc_test/templates/base.html index 2ce758c..4cb56f5 100644 --- a/src/oidc_test/templates/base.html +++ b/src/oidc_test/templates/base.html @@ -4,7 +4,7 @@ - +

OIDC-test - FastAPI client

{% block content %} {% endblock %} diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html index 08bcf43..7275f2d 100644 --- a/src/oidc_test/templates/home.html +++ b/src/oidc_test/templates/home.html @@ -8,7 +8,7 @@