From 496ce016e3d31fdb003146a3f84bc86726275e6d Mon Sep 17 00:00:00 2001
From: phil
Date: Mon, 10 Feb 2025 02:05:34 +0100
Subject: [PATCH] Continue refactor; fetch resources from the providers'
settings
---
src/oidc_test/auth_provider.py | 64 +++++++++++++++--
src/oidc_test/auth_utils.py | 45 ++++++------
src/oidc_test/database.py | 9 ++-
src/oidc_test/main.py | 68 +++++++++++++------
src/oidc_test/models.py | 6 +-
src/oidc_test/resource_server.py | 50 +++++++++++---
src/oidc_test/settings.py | 16 ++---
src/oidc_test/static/utils.js | 17 +++--
src/oidc_test/templates/home.html | 49 +++++++------
.../templates/non_compliant_logout.html | 6 +-
10 files changed, 217 insertions(+), 113 deletions(-)
diff --git a/src/oidc_test/auth_provider.py b/src/oidc_test/auth_provider.py
index bed4596..e50241c 100644
--- a/src/oidc_test/auth_provider.py
+++ b/src/oidc_test/auth_provider.py
@@ -1,26 +1,41 @@
+from json import JSONDecodeError
from typing import Any
from jwt import decode
import logging
+from collections import OrderedDict
+from pydantic import ConfigDict
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
+from httpx import AsyncClient
from .settings import AuthProviderSettings, settings
+from .models import User
logger = logging.getLogger("oidc-test")
+class NoPublicKey(Exception):
+ pass
+
+
class Provider(AuthProviderSettings):
- class Config:
- arbitrary_types_allowed = True
+ # To allow authlib_client as StarletteOAuth2App
+ model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore
authlib_client: StarletteOAuth2App = StarletteOAuth2App(None)
+ info: dict[str, Any] = {}
+ unknown_auth_user: User
- def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]:
+ def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]:
"""Decode the token with signature check"""
+ if self.public_key is None:
+ raise NoPublicKey
+ if verify_signature is None:
+ verify_signature = self.skip_verify_signature
if settings.debug_token:
decoded = decode(
token,
- self.get_public_key(),
+ self.public_key,
algorithms=[self.signature_alg],
audience=["account", "oidc-test", "oidc-test-web"],
options={
@@ -31,7 +46,7 @@ class Provider(AuthProviderSettings):
logger.debug(str(decoded))
return decode(
token,
- self.get_public_key(),
+ self.public_key,
algorithms=[self.signature_alg],
audience=["account", "oidc-test", "oidc-test-web"],
options={
@@ -39,5 +54,42 @@ class Provider(AuthProviderSettings):
}, # not settings.insecure.skip_verify_signature},
)
+ async def get_info(self):
+ # Get the public key:
+ async with AsyncClient() as client:
+ public_key: str | None = None
+ if self.info_url is not None:
+ try:
+ provider_info = await client.get(self.info_url)
+ except Exception:
+ raise NoPublicKey
+ try:
+ self.info = provider_info.json()
+ except JSONDecodeError:
+ raise NoPublicKey
+ if "public_key" in self.info:
+ # For Keycloak
+ try:
+ public_key = str(self.info["public_key"])
+ except KeyError:
+ raise NoPublicKey
+ elif "keys" in self.info:
+ # For Forgejo/Gitea
+ try:
+ public_key = str(self.info["keys"][0]["n"])
+ except KeyError:
+ raise NoPublicKey
+ if self.public_key_url is not None:
+ resp = await client.get(self.public_key_url)
+ public_key = resp.text
+ if public_key is None:
+ raise NoPublicKey
+ self.public_key = "\n".join(
+ ["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"]
+ )
-providers: dict[str, Provider] = {}
+ def get_session_key(self, userinfo):
+ return userinfo[self.session_key]
+
+
+providers: OrderedDict[str, Provider] = OrderedDict()
diff --git a/src/oidc_test/auth_utils.py b/src/oidc_test/auth_utils.py
index e62fe39..8cd5028 100644
--- a/src/oidc_test/auth_utils.py
+++ b/src/oidc_test/auth_utils.py
@@ -5,8 +5,7 @@ import logging
from fastapi import HTTPException, Request, Depends, status
from fastapi.security import OAuth2PasswordBearer
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
-from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError
-from httpx import AsyncClient
+from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError
# from authlib.oauth1.auth import OAuthToken
from authlib.oauth2.auth import OAuth2Token
@@ -40,7 +39,7 @@ async def update_token(
):
"""Update the token in the database"""
provider = providers[provider_id]
- sid: str = provider.decode(token["id_token"])["sid"]
+ sid: str = provider.get_session_key(provider.decode(token["id_token"]))
item = await db.get_token(provider, sid)
# update old token
item["access_token"] = token["access_token"]
@@ -59,7 +58,12 @@ def init_providers():
"""Add oidc providers to authlib from the settings
and build the providers dict"""
for provider_settings in settings.auth.providers:
- provider = Provider(**provider_settings.model_dump())
+ provider_settings_dict = provider_settings.model_dump()
+ # Add an anonymous user, that cannot be identified but has provided a valid access token
+ provider_settings_dict["unknown_auth_user"] = User(
+ sub="", auth_provider_id=provider_settings.id
+ )
+ provider = Provider(**provider_settings_dict)
authlib_oauth.register(
name=provider.id,
server_metadata_url=provider.openid_configuration,
@@ -85,15 +89,6 @@ def init_providers():
init_providers()
-async def get_providers_info():
- # Get the public key:
- async with AsyncClient() as client:
- for provider in providers.values():
- if provider.info_url:
- provider_info = await client.get(provider.url)
- provider.info = provider_info.json()
-
-
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
"""Return the oidc_provider from a request object, from the session.
It can be used in Depends()"""
@@ -166,6 +161,8 @@ async def get_token(request: Request) -> OAuth2Token:
try:
provider = providers[request.session.get("auth_provider_id", "")]
except KeyError:
+ request.session.pop("auth_provider_id", None)
+ request.session.pop("user_sub", None)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider")
try:
return await db.get_token(
@@ -239,29 +236,27 @@ async def get_user_from_token(
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
)
- if token == "":
+ if token == "None":
+ request.session.pop("auth_provider_id", None)
+ request.session.pop("user_sub", None)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
try:
payload = provider.decode(token)
- except ExpiredSignatureError as err:
- logger.info(f"Expired signature: {err}")
+ except ExpiredSignatureError:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
"Expired signature (refresh not implemented yet)",
)
- except InvalidKeyError as err:
- logger.info(f"Invalid key: {err}")
+ except InvalidKeyError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
- except Exception as err:
- logger.info("Cannot decode token, see below")
- logger.exception(err)
- raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
+ except PyJWTError as err:
+ raise HTTPException(
+ status.HTTP_401_UNAUTHORIZED, f"Cannot decode token: {err.__class__.__name__}"
+ )
try:
user_id = payload["sub"]
except KeyError:
- raise HTTPException(
- status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found"
- )
+ return provider.unknown_auth_user
try:
user = await db.get_user(user_id)
if user.access_token != token:
diff --git a/src/oidc_test/database.py b/src/oidc_test/database.py
index 3493429..659fd13 100644
--- a/src/oidc_test/database.py
+++ b/src/oidc_test/database.py
@@ -3,6 +3,7 @@
import logging
from authlib.oauth2.rfc6749 import OAuth2Token
+from jwt import PyJWTError
from .models import User, Role
from .auth_provider import Provider, providers
@@ -35,7 +36,10 @@ class Database:
if access_token_decoded is None:
assert auth_provider.name is not None
provider = providers[auth_provider.id]
- access_token_decoded = provider.decode(access_token)
+ try:
+ access_token_decoded = provider.decode(access_token)
+ except PyJWTError:
+ access_token_decoded = {}
user_info["auth_provider_id"] = auth_provider.id
user = User(**user_info)
user.userinfo = user_info
@@ -68,8 +72,7 @@ class Database:
async def add_token(self, provider: Provider, token: OAuth2Token) -> None:
"""Store a token using as key the sid (auth provider's session id)
in the id_token"""
- assert isinstance(provider, Provider)
- sid = token["userinfo"]["sid"]
+ sid = provider.get_session_key(token["userinfo"])
self.tokens[sid] = token
async def get_token(
diff --git a/src/oidc_test/main.py b/src/oidc_test/main.py
index 304df92..4018997 100644
--- a/src/oidc_test/main.py
+++ b/src/oidc_test/main.py
@@ -27,14 +27,13 @@ from authlib.oauth2.rfc6749 import OAuth2Token
# from pkce import generate_code_verifier, generate_pkce_pair
from .settings import settings
-from .auth_provider import Provider, providers
+from .auth_provider import NoPublicKey, Provider, providers
from .models import User
from .auth_utils import (
get_auth_provider,
get_auth_provider_or_none,
get_current_user_or_none,
authlib_oauth,
- get_providers_info,
get_token_or_none,
get_token,
update_token,
@@ -50,7 +49,12 @@ templates = Jinja2Templates(Path(__file__).parent / "templates")
@asynccontextmanager
async def lifespan(app: FastAPI):
assert app is not None
- await get_providers_info()
+ for provider in list(providers.values()):
+ try:
+ await provider.get_info()
+ except NoPublicKey:
+ logger.warn(f"Disable {provider.id}: public key not found")
+ del providers[provider.id]
yield
@@ -82,12 +86,13 @@ async def home(
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
) -> HTMLResponse:
context = {
- "settings": settings.model_dump(),
+ "show_token": settings.show_token,
"user": user,
"now": datetime.now(),
"auth_provider": provider,
}
if provider is None or token is None:
+ context["providers"] = providers
context["access_token"] = None
context["id_token_parsed"] = None
context["access_token_parsed"] = None
@@ -96,14 +101,23 @@ async def home(
else:
context["access_token"] = token["access_token"]
context["resources"] = provider.resources
- access_token_parsed = provider.decode(token["access_token"], verify_signature=False)
- context["access_token_scope"] = access_token_parsed["scope"]
+ try:
+ access_token_parsed = provider.decode(token["access_token"], verify_signature=False)
+ except PyJWTError as err:
+ access_token_parsed = {"Cannot parse": err.__class__.__name__}
+ try:
+ context["access_token_scope"] = access_token_parsed["scope"]
+ except KeyError:
+ context["access_token_scope"] = None
# context["id_token_parsed"] = pretty_details(user, now)
context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False)
context["access_token_parsed"] = access_token_parsed
- context["refresh_token_parsed"] = provider.decode(
- token["refresh_token"], verify_signature=False
- )
+ try:
+ context["refresh_token_parsed"] = provider.decode(
+ token["refresh_token"], verify_signature=False
+ )
+ except PyJWTError as err:
+ context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__}
return templates.TemplateResponse(name="home.html", request=request, context=context)
@@ -144,16 +158,19 @@ async def login(request: Request, auth_provider_id: str) -> RedirectResponse:
@app.get("/auth/{auth_provider_id}")
-async def auth(request: Request, auth_provider_id: str) -> RedirectResponse:
+async def auth(
+ request: Request,
+ auth_provider_id: str,
+) -> RedirectResponse:
"""Decrypt the auth token, store it to the session (cookie based)
and response to the browser with a redirect to a "welcome user" page.
"""
try:
- authlib_client: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
- except AttributeError:
+ provider = providers[auth_provider_id]
+ except KeyError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try:
- token: OAuth2Token = await authlib_client.authorize_access_token(request)
+ token: OAuth2Token = await provider.authlib_client.authorize_access_token(request)
except OAuthError as error:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
# Remember the authlib_client in the session
@@ -168,6 +185,14 @@ async def auth(request: Request, auth_provider_id: str) -> RedirectResponse:
request.session["auth_provider_id"] = auth_provider_id
# User id (sub) given by auth provider
sub = userinfo["sub"]
+ ## Get additional data from userinfo endpoint
+ # try:
+ # user_info_from_endpoint = await authlib_client.userinfo(
+ # token=token, follow_redirects=True
+ # )
+ # except Exception as err:
+ # logger.warn(f"Cannot get userinfo from endpoint: {err}")
+ # user_info_from_endpoint = {}
# Build and remember the user in the session
request.session["user_sub"] = sub
# Store the user in the database, which also verifies the token validity and signature
@@ -185,9 +210,8 @@ async def auth(request: Request, auth_provider_id: str) -> RedirectResponse:
)
assert isinstance(user, User)
# Add the provider session id to the session
- request.session["sid"] = userinfo["sid"]
+ request.session["sid"] = provider.get_session_key(userinfo)
# Add the token to the db because it is used for logout
- provider = providers[auth_provider_id]
await db.add_token(provider, token)
# Send the user to the home: (s)he is authenticated
return RedirectResponse(url=request.url_for("home"))
@@ -211,15 +235,16 @@ async def logout(
request: Request,
provider: Annotated[Provider, Depends(get_auth_provider)],
) -> RedirectResponse:
- # Clear session
- request.session.pop("user_sub", None)
# Get provider's endpoint
if (
provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint")
) is None:
- logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
+ logger.warn(f"Cannot find end_session_endpoint for provider {provider.id}")
return RedirectResponse(request.url_for("non_compliant_logout"))
post_logout_uri = request.url_for("home")
+ # Clear session
+ request.session.pop("user_sub", None)
+ request.session.pop("auth_provider_id", None)
try:
token = await db.get_token(provider, request.session.pop("sid", None))
except TokenNotInDb:
@@ -242,15 +267,16 @@ async def logout(
@app.get("/non-compliant-logout")
async def non_compliant_logout(
request: Request,
- provider: Annotated[StarletteOAuth2App, Depends(get_auth_provider)],
+ provider: Annotated[Provider, Depends(get_auth_provider)],
):
"""A page for non-compliant OAuth2 servers that we cannot log out."""
- # Clear the remain of the session
+ # Clear session
+ request.session.pop("user_sub", None)
request.session.pop("auth_provider_id", None)
return templates.TemplateResponse(
name="non_compliant_logout.html",
request=request,
- context={"oidc_provider": provider, "home_url": request.url_for("home")},
+ context={"auth_provider": provider, "home_url": request.url_for("home")},
)
diff --git a/src/oidc_test/models.py b/src/oidc_test/models.py
index eda63a6..7c5250b 100644
--- a/src/oidc_test/models.py
+++ b/src/oidc_test/models.py
@@ -49,13 +49,13 @@ class User(UserBase):
try:
access_token_scopes = self.decode_access_token().get("scope", "").split(" ")
except Exception as err:
- logger.info(f"Access token cannot be decoded: {err}")
+ logger.debug(f"Cannot find scope because the access token cannot be decoded: {err}")
access_token_scopes = []
return scope in set(info_scopes + access_token_scopes)
def decode_access_token(self, verify_signature: bool = True):
- assert self.access_token is not None
- assert self.auth_provider_id is not None
+ assert self.access_token is not None, "no access_token"
+ assert self.auth_provider_id is not None, "no auth_provider_id"
from .auth_provider import providers
return providers[self.auth_provider_id].decode(
diff --git a/src/oidc_test/resource_server.py b/src/oidc_test/resource_server.py
index e5670ed..f7f0433 100644
--- a/src/oidc_test/resource_server.py
+++ b/src/oidc_test/resource_server.py
@@ -2,7 +2,7 @@ from datetime import datetime
from typing import Annotated
import logging
-from authlib.jose import Key
+from authlib.oauth2.auth import OAuth2Token
from httpx import AsyncClient
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from fastapi import FastAPI, HTTPException, Depends, status
@@ -14,11 +14,12 @@ from fastapi.middleware.cors import CORSMiddleware
from .models import User
from .auth_utils import (
+ get_token_or_none,
get_user_from_token,
UserWithRole,
)
from .settings import settings
-from .auth_provider import providers
+from .auth_provider import providers, Provider
logger = logging.getLogger("oidc-test")
@@ -113,23 +114,51 @@ async def get_protected_by_foorole_or_barrole(
@resource_server.get("/{id}")
-async def get_resource_(
+async def get_resource(
id: str,
- # user: Annotated[User, Depends(get_current_user)],
- # oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
- # token: Annotated[OAuth2Token, Depends(get_token)],
user: Annotated[User, Depends(get_user_from_token)],
-):
+ token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
+) -> dict | list:
"""Generic path for testing a resource provided by a provider"""
- return await get_resource(id, user)
+ provider = providers[user.auth_provider_id]
+ if id in [r.id for r in provider.resources]:
+ return await get_external_resource(
+ provider=provider,
+ id=id,
+ access_token=token["access_token"] if token else None,
+ user=user,
+ )
+ return await get_resource_(id, user)
-async def get_resource(resource_id: str, user: User) -> dict:
+async def get_external_resource(
+ provider: Provider, id: str, access_token: str | None, user: User
+) -> dict | list:
+ resource = [r for r in provider.resources if r.id == id][0]
+ async with AsyncClient() as client:
+ resp = await client.get(
+ url=provider.url + resource.url,
+ headers={
+ "Content-type": "application/json",
+ "Authorization": f"Bearer {access_token}",
+ },
+ )
+ if resp.is_error:
+ raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}")
+ resp_length = len(resp.text)
+ if resp_length > 1024:
+ return {"msg": f"The resource is too long ({resp_length} bytes) to show here"}
+ else:
+ return resp.json()
+
+
+async def get_resource_(resource_id: str, user: User) -> dict:
"""
Resource processing: build an informative rely as a simple showcase
"""
+ provider = providers[user.auth_provider_id]
try:
- pname = providers[user.auth_provider_id].name
+ pname = provider.name
except KeyError:
pname = "?"
resp = {
@@ -164,7 +193,6 @@ async def process(user, resource_id, resp):
Too simple to be serious.
It's a good fit for a plugin architecture for production
"""
- assert user is not None
if resource_id == "time":
resp["time"] = datetime.now().strftime("%c")
elif resource_id == "bs":
diff --git a/src/oidc_test/settings.py b/src/oidc_test/settings.py
index 9a789a0..86a2b6b 100644
--- a/src/oidc_test/settings.py
+++ b/src/oidc_test/settings.py
@@ -21,6 +21,7 @@ class Resource(BaseModel):
id: str
name: str
+ url: str
class AuthProviderSettings(BaseModel):
@@ -39,10 +40,12 @@ class AuthProviderSettings(BaseModel):
info_url: str | None = (
None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
)
- info: dict[str, str | int] | None = None # Info fetched from info_url, eg. public key
public_key: str | None = None
+ public_key_url: str | None = None
signature_alg: str = "RS256"
resource_provider_scopes: list[str] = []
+ session_key: str = "sid"
+ skip_verify_signature: bool = True
@computed_field
@property
@@ -64,17 +67,6 @@ class AuthProviderSettings(BaseModel):
else:
return None
- def get_public_key(self) -> str:
- """Return the public key formatted for decoding token"""
- public_key = self.public_key or (self.info is not None and self.info["public_key"])
- if public_key is None:
- raise AttributeError(f"Cannot get public key for {self.name}")
- return f"""
- -----BEGIN PUBLIC KEY-----
- {public_key}
- -----END PUBLIC KEY-----
- """
-
class ResourceProvider(BaseModel):
id: str
diff --git a/src/oidc_test/static/utils.js b/src/oidc_test/static/utils.js
index 8e8ad59..6c9fae4 100644
--- a/src/oidc_test/static/utils.js
+++ b/src/oidc_test/static/utils.js
@@ -55,15 +55,24 @@ async function get_resource(id, token, authProvider) {
msg.innerHTML = ""
resourceElem.innerHTML = ""
Object.entries(resource).forEach(
- ([k, v]) => {
+ ([key, value]) => {
let r = document.createElement('div')
let kElem = document.createElement('div')
- kElem.innerText = k
+ kElem.innerText = key
kElem.className = "key"
let vElem = document.createElement('div')
- vElem.innerText = v
+ if (typeof value == "object") {
+ Object.entries(value).forEach(v => {
+ const ne = document.createElement('div')
+ ne.innerHTML = `${v[0]} : ${v[1]} `
+ vElem.appendChild(ne)
+ })
+ }
+ else {
+ vElem.innerText = value
+ }
vElem.className = "value"
- if (k == "sorry") {
+ if (key == "sorry") {
vElem.classList.add("error")
}
r.appendChild(kElem)
diff --git a/src/oidc_test/templates/home.html b/src/oidc_test/templates/home.html
index 7275f2d..da513c9 100644
--- a/src/oidc_test/templates/home.html
+++ b/src/oidc_test/templates/home.html
@@ -5,25 +5,24 @@
with OpenID Connect and OAuth2 with different providers.
{% if not user %}
-
-
Log in with:
-
- {% for provider in settings.auth.providers %}
-
-
- {{ provider.name }}
-
- {{ provider.hint }}
-
-
- {% else %}
- There is no authentication provider defined.
- Hint: check the settings.yaml file.
- {% endfor %}
-
-
- {% endif %}
- {% if user %}
+
+
Log in with:
+
+ {% for provider in providers.values() %}
+
+
+ {{ provider.name }}
+
+ {{ provider.hint }}
+
+
+ {% else %}
+ There is no authentication provider defined.
+ Hint: check the settings.yaml file.
+ {% endfor %}
+
+
+ {% else %}
Hey, {{ user.name }}
{% if user.picture %}
@@ -83,22 +82,22 @@
Using FastAPI Depends
-
{% if resources %}
Resources for this provider:
{% endif %}
+
- {% if settings.show_token and id_token_parsed %}
+ {% if show_token and id_token_parsed %}
diff --git a/src/oidc_test/templates/non_compliant_logout.html b/src/oidc_test/templates/non_compliant_logout.html
index 24a96ae..56758de 100644
--- a/src/oidc_test/templates/non_compliant_logout.html
+++ b/src/oidc_test/templates/non_compliant_logout.html
@@ -6,12 +6,12 @@
authorisation to log in again without asking for credentials.
- This is because {{ oidc_provider.name }} does not provide "end_session_endpoint" in its metadata
- (see: {{ oidc_provider._server_metadata_url }} ).
+ This is because {{ auth_provider.name }} does not provide "end_session_endpoint" in its metadata
+ (see: {{ auth_provider.authlib_client._server_metadata_url }} ).
You can just also go back to the application home page , but
- it recommended to go to the OIDC provider's site
+ it recommended to go to the OIDC provider's site
and log out explicitely from there.
{% endblock %}