Continue refactor; fetch resources from the providers' settings
All checks were successful
/ build (push) Successful in 5s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-02-10 02:05:34 +01:00
parent c5bb4f4319
commit 496ce016e3
10 changed files with 217 additions and 113 deletions

View file

@ -1,26 +1,41 @@
from json import JSONDecodeError
from typing import Any
from jwt import decode
import logging
from collections import OrderedDict
from pydantic import ConfigDict
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from httpx import AsyncClient
from .settings import AuthProviderSettings, settings
from .models import User
logger = logging.getLogger("oidc-test")
class NoPublicKey(Exception):
pass
class Provider(AuthProviderSettings):
class Config:
arbitrary_types_allowed = True
# To allow authlib_client as StarletteOAuth2App
model_config = ConfigDict(arbitrary_types_allowed=True) # type:ignore
authlib_client: StarletteOAuth2App = StarletteOAuth2App(None)
info: dict[str, Any] = {}
unknown_auth_user: User
def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]:
def decode(self, token: str, verify_signature: bool | None = None) -> dict[str, Any]:
"""Decode the token with signature check"""
if self.public_key is None:
raise NoPublicKey
if verify_signature is None:
verify_signature = self.skip_verify_signature
if settings.debug_token:
decoded = decode(
token,
self.get_public_key(),
self.public_key,
algorithms=[self.signature_alg],
audience=["account", "oidc-test", "oidc-test-web"],
options={
@ -31,7 +46,7 @@ class Provider(AuthProviderSettings):
logger.debug(str(decoded))
return decode(
token,
self.get_public_key(),
self.public_key,
algorithms=[self.signature_alg],
audience=["account", "oidc-test", "oidc-test-web"],
options={
@ -39,5 +54,42 @@ class Provider(AuthProviderSettings):
}, # not settings.insecure.skip_verify_signature},
)
async def get_info(self):
# Get the public key:
async with AsyncClient() as client:
public_key: str | None = None
if self.info_url is not None:
try:
provider_info = await client.get(self.info_url)
except Exception:
raise NoPublicKey
try:
self.info = provider_info.json()
except JSONDecodeError:
raise NoPublicKey
if "public_key" in self.info:
# For Keycloak
try:
public_key = str(self.info["public_key"])
except KeyError:
raise NoPublicKey
elif "keys" in self.info:
# For Forgejo/Gitea
try:
public_key = str(self.info["keys"][0]["n"])
except KeyError:
raise NoPublicKey
if self.public_key_url is not None:
resp = await client.get(self.public_key_url)
public_key = resp.text
if public_key is None:
raise NoPublicKey
self.public_key = "\n".join(
["-----BEGIN PUBLIC KEY-----", public_key, "-----END PUBLIC KEY-----"]
)
providers: dict[str, Provider] = {}
def get_session_key(self, userinfo):
return userinfo[self.session_key]
providers: OrderedDict[str, Provider] = OrderedDict()

View file

@ -5,8 +5,7 @@ import logging
from fastapi import HTTPException, Request, Depends, status
from fastapi.security import OAuth2PasswordBearer
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError
from httpx import AsyncClient
from jwt import ExpiredSignatureError, InvalidKeyError, DecodeError, PyJWTError
# from authlib.oauth1.auth import OAuthToken
from authlib.oauth2.auth import OAuth2Token
@ -40,7 +39,7 @@ async def update_token(
):
"""Update the token in the database"""
provider = providers[provider_id]
sid: str = provider.decode(token["id_token"])["sid"]
sid: str = provider.get_session_key(provider.decode(token["id_token"]))
item = await db.get_token(provider, sid)
# update old token
item["access_token"] = token["access_token"]
@ -59,7 +58,12 @@ def init_providers():
"""Add oidc providers to authlib from the settings
and build the providers dict"""
for provider_settings in settings.auth.providers:
provider = Provider(**provider_settings.model_dump())
provider_settings_dict = provider_settings.model_dump()
# Add an anonymous user, that cannot be identified but has provided a valid access token
provider_settings_dict["unknown_auth_user"] = User(
sub="", auth_provider_id=provider_settings.id
)
provider = Provider(**provider_settings_dict)
authlib_oauth.register(
name=provider.id,
server_metadata_url=provider.openid_configuration,
@ -85,15 +89,6 @@ def init_providers():
init_providers()
async def get_providers_info():
# Get the public key:
async with AsyncClient() as client:
for provider in providers.values():
if provider.info_url:
provider_info = await client.get(provider.url)
provider.info = provider_info.json()
def get_auth_provider_client_or_none(request: Request) -> StarletteOAuth2App | None:
"""Return the oidc_provider from a request object, from the session.
It can be used in Depends()"""
@ -166,6 +161,8 @@ async def get_token(request: Request) -> OAuth2Token:
try:
provider = providers[request.session.get("auth_provider_id", "")]
except KeyError:
request.session.pop("auth_provider_id", None)
request.session.pop("user_sub", None)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider")
try:
return await db.get_token(
@ -239,29 +236,27 @@ async def get_user_from_token(
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
)
if token == "":
if token == "None":
request.session.pop("auth_provider_id", None)
request.session.pop("user_sub", None)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
try:
payload = provider.decode(token)
except ExpiredSignatureError as err:
logger.info(f"Expired signature: {err}")
except ExpiredSignatureError:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
"Expired signature (refresh not implemented yet)",
)
except InvalidKeyError as err:
logger.info(f"Invalid key: {err}")
except InvalidKeyError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid auth provider key")
except Exception as err:
logger.info("Cannot decode token, see below")
logger.exception(err)
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
except PyJWTError as err:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Cannot decode token: {err.__class__.__name__}"
)
try:
user_id = payload["sub"]
except KeyError:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found"
)
return provider.unknown_auth_user
try:
user = await db.get_user(user_id)
if user.access_token != token:

View file

@ -3,6 +3,7 @@
import logging
from authlib.oauth2.rfc6749 import OAuth2Token
from jwt import PyJWTError
from .models import User, Role
from .auth_provider import Provider, providers
@ -35,7 +36,10 @@ class Database:
if access_token_decoded is None:
assert auth_provider.name is not None
provider = providers[auth_provider.id]
access_token_decoded = provider.decode(access_token)
try:
access_token_decoded = provider.decode(access_token)
except PyJWTError:
access_token_decoded = {}
user_info["auth_provider_id"] = auth_provider.id
user = User(**user_info)
user.userinfo = user_info
@ -68,8 +72,7 @@ class Database:
async def add_token(self, provider: Provider, token: OAuth2Token) -> None:
"""Store a token using as key the sid (auth provider's session id)
in the id_token"""
assert isinstance(provider, Provider)
sid = token["userinfo"]["sid"]
sid = provider.get_session_key(token["userinfo"])
self.tokens[sid] = token
async def get_token(

View file

@ -27,14 +27,13 @@ from authlib.oauth2.rfc6749 import OAuth2Token
# from pkce import generate_code_verifier, generate_pkce_pair
from .settings import settings
from .auth_provider import Provider, providers
from .auth_provider import NoPublicKey, Provider, providers
from .models import User
from .auth_utils import (
get_auth_provider,
get_auth_provider_or_none,
get_current_user_or_none,
authlib_oauth,
get_providers_info,
get_token_or_none,
get_token,
update_token,
@ -50,7 +49,12 @@ templates = Jinja2Templates(Path(__file__).parent / "templates")
@asynccontextmanager
async def lifespan(app: FastAPI):
assert app is not None
await get_providers_info()
for provider in list(providers.values()):
try:
await provider.get_info()
except NoPublicKey:
logger.warn(f"Disable {provider.id}: public key not found")
del providers[provider.id]
yield
@ -82,12 +86,13 @@ async def home(
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
) -> HTMLResponse:
context = {
"settings": settings.model_dump(),
"show_token": settings.show_token,
"user": user,
"now": datetime.now(),
"auth_provider": provider,
}
if provider is None or token is None:
context["providers"] = providers
context["access_token"] = None
context["id_token_parsed"] = None
context["access_token_parsed"] = None
@ -96,14 +101,23 @@ async def home(
else:
context["access_token"] = token["access_token"]
context["resources"] = provider.resources
access_token_parsed = provider.decode(token["access_token"], verify_signature=False)
context["access_token_scope"] = access_token_parsed["scope"]
try:
access_token_parsed = provider.decode(token["access_token"], verify_signature=False)
except PyJWTError as err:
access_token_parsed = {"Cannot parse": err.__class__.__name__}
try:
context["access_token_scope"] = access_token_parsed["scope"]
except KeyError:
context["access_token_scope"] = None
# context["id_token_parsed"] = pretty_details(user, now)
context["id_token_parsed"] = provider.decode(token["id_token"], verify_signature=False)
context["access_token_parsed"] = access_token_parsed
context["refresh_token_parsed"] = provider.decode(
token["refresh_token"], verify_signature=False
)
try:
context["refresh_token_parsed"] = provider.decode(
token["refresh_token"], verify_signature=False
)
except PyJWTError as err:
context["refresh_token_parsed"] = {"Cannot parse": err.__class__.__name__}
return templates.TemplateResponse(name="home.html", request=request, context=context)
@ -144,16 +158,19 @@ async def login(request: Request, auth_provider_id: str) -> RedirectResponse:
@app.get("/auth/{auth_provider_id}")
async def auth(request: Request, auth_provider_id: str) -> RedirectResponse:
async def auth(
request: Request,
auth_provider_id: str,
) -> RedirectResponse:
"""Decrypt the auth token, store it to the session (cookie based)
and response to the browser with a redirect to a "welcome user" page.
"""
try:
authlib_client: StarletteOAuth2App = getattr(authlib_oauth, auth_provider_id)
except AttributeError:
provider = providers[auth_provider_id]
except KeyError:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No such provider")
try:
token: OAuth2Token = await authlib_client.authorize_access_token(request)
token: OAuth2Token = await provider.authlib_client.authorize_access_token(request)
except OAuthError as error:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail=error.error)
# Remember the authlib_client in the session
@ -168,6 +185,14 @@ async def auth(request: Request, auth_provider_id: str) -> RedirectResponse:
request.session["auth_provider_id"] = auth_provider_id
# User id (sub) given by auth provider
sub = userinfo["sub"]
## Get additional data from userinfo endpoint
# try:
# user_info_from_endpoint = await authlib_client.userinfo(
# token=token, follow_redirects=True
# )
# except Exception as err:
# logger.warn(f"Cannot get userinfo from endpoint: {err}")
# user_info_from_endpoint = {}
# Build and remember the user in the session
request.session["user_sub"] = sub
# Store the user in the database, which also verifies the token validity and signature
@ -185,9 +210,8 @@ async def auth(request: Request, auth_provider_id: str) -> RedirectResponse:
)
assert isinstance(user, User)
# Add the provider session id to the session
request.session["sid"] = userinfo["sid"]
request.session["sid"] = provider.get_session_key(userinfo)
# Add the token to the db because it is used for logout
provider = providers[auth_provider_id]
await db.add_token(provider, token)
# Send the user to the home: (s)he is authenticated
return RedirectResponse(url=request.url_for("home"))
@ -211,15 +235,16 @@ async def logout(
request: Request,
provider: Annotated[Provider, Depends(get_auth_provider)],
) -> RedirectResponse:
# Clear session
request.session.pop("user_sub", None)
# Get provider's endpoint
if (
provider_logout_uri := provider.authlib_client.server_metadata.get("end_session_endpoint")
) is None:
logger.warn(f"Cannot find end_session_endpoint for provider {provider.name}")
logger.warn(f"Cannot find end_session_endpoint for provider {provider.id}")
return RedirectResponse(request.url_for("non_compliant_logout"))
post_logout_uri = request.url_for("home")
# Clear session
request.session.pop("user_sub", None)
request.session.pop("auth_provider_id", None)
try:
token = await db.get_token(provider, request.session.pop("sid", None))
except TokenNotInDb:
@ -242,15 +267,16 @@ async def logout(
@app.get("/non-compliant-logout")
async def non_compliant_logout(
request: Request,
provider: Annotated[StarletteOAuth2App, Depends(get_auth_provider)],
provider: Annotated[Provider, Depends(get_auth_provider)],
):
"""A page for non-compliant OAuth2 servers that we cannot log out."""
# Clear the remain of the session
# Clear session
request.session.pop("user_sub", None)
request.session.pop("auth_provider_id", None)
return templates.TemplateResponse(
name="non_compliant_logout.html",
request=request,
context={"oidc_provider": provider, "home_url": request.url_for("home")},
context={"auth_provider": provider, "home_url": request.url_for("home")},
)

View file

@ -49,13 +49,13 @@ class User(UserBase):
try:
access_token_scopes = self.decode_access_token().get("scope", "").split(" ")
except Exception as err:
logger.info(f"Access token cannot be decoded: {err}")
logger.debug(f"Cannot find scope because the access token cannot be decoded: {err}")
access_token_scopes = []
return scope in set(info_scopes + access_token_scopes)
def decode_access_token(self, verify_signature: bool = True):
assert self.access_token is not None
assert self.auth_provider_id is not None
assert self.access_token is not None, "no access_token"
assert self.auth_provider_id is not None, "no auth_provider_id"
from .auth_provider import providers
return providers[self.auth_provider_id].decode(

View file

@ -2,7 +2,7 @@ from datetime import datetime
from typing import Annotated
import logging
from authlib.jose import Key
from authlib.oauth2.auth import OAuth2Token
from httpx import AsyncClient
from jwt.exceptions import ExpiredSignatureError, InvalidTokenError
from fastapi import FastAPI, HTTPException, Depends, status
@ -14,11 +14,12 @@ from fastapi.middleware.cors import CORSMiddleware
from .models import User
from .auth_utils import (
get_token_or_none,
get_user_from_token,
UserWithRole,
)
from .settings import settings
from .auth_provider import providers
from .auth_provider import providers, Provider
logger = logging.getLogger("oidc-test")
@ -113,23 +114,51 @@ async def get_protected_by_foorole_or_barrole(
@resource_server.get("/{id}")
async def get_resource_(
async def get_resource(
id: str,
# user: Annotated[User, Depends(get_current_user)],
# oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
# token: Annotated[OAuth2Token, Depends(get_token)],
user: Annotated[User, Depends(get_user_from_token)],
):
token: Annotated[OAuth2Token | None, Depends(get_token_or_none)],
) -> dict | list:
"""Generic path for testing a resource provided by a provider"""
return await get_resource(id, user)
provider = providers[user.auth_provider_id]
if id in [r.id for r in provider.resources]:
return await get_external_resource(
provider=provider,
id=id,
access_token=token["access_token"] if token else None,
user=user,
)
return await get_resource_(id, user)
async def get_resource(resource_id: str, user: User) -> dict:
async def get_external_resource(
provider: Provider, id: str, access_token: str | None, user: User
) -> dict | list:
resource = [r for r in provider.resources if r.id == id][0]
async with AsyncClient() as client:
resp = await client.get(
url=provider.url + resource.url,
headers={
"Content-type": "application/json",
"Authorization": f"Bearer {access_token}",
},
)
if resp.is_error:
raise HTTPException(resp.status_code, f"Cannot fetch resource: {resp.reason_phrase}")
resp_length = len(resp.text)
if resp_length > 1024:
return {"msg": f"The resource is too long ({resp_length} bytes) to show here"}
else:
return resp.json()
async def get_resource_(resource_id: str, user: User) -> dict:
"""
Resource processing: build an informative rely as a simple showcase
"""
provider = providers[user.auth_provider_id]
try:
pname = providers[user.auth_provider_id].name
pname = provider.name
except KeyError:
pname = "?"
resp = {
@ -164,7 +193,6 @@ async def process(user, resource_id, resp):
Too simple to be serious.
It's a good fit for a plugin architecture for production
"""
assert user is not None
if resource_id == "time":
resp["time"] = datetime.now().strftime("%c")
elif resource_id == "bs":

View file

@ -21,6 +21,7 @@ class Resource(BaseModel):
id: str
name: str
url: str
class AuthProviderSettings(BaseModel):
@ -39,10 +40,12 @@ class AuthProviderSettings(BaseModel):
info_url: str | None = (
None # Used eg. for Keycloak's public key (see https://stackoverflow.com/questions/54318633/getting-keycloaks-public-key)
)
info: dict[str, str | int] | None = None # Info fetched from info_url, eg. public key
public_key: str | None = None
public_key_url: str | None = None
signature_alg: str = "RS256"
resource_provider_scopes: list[str] = []
session_key: str = "sid"
skip_verify_signature: bool = True
@computed_field
@property
@ -64,17 +67,6 @@ class AuthProviderSettings(BaseModel):
else:
return None
def get_public_key(self) -> str:
"""Return the public key formatted for decoding token"""
public_key = self.public_key or (self.info is not None and self.info["public_key"])
if public_key is None:
raise AttributeError(f"Cannot get public key for {self.name}")
return f"""
-----BEGIN PUBLIC KEY-----
{public_key}
-----END PUBLIC KEY-----
"""
class ResourceProvider(BaseModel):
id: str

View file

@ -55,15 +55,24 @@ async function get_resource(id, token, authProvider) {
msg.innerHTML = ""
resourceElem.innerHTML = ""
Object.entries(resource).forEach(
([k, v]) => {
([key, value]) => {
let r = document.createElement('div')
let kElem = document.createElement('div')
kElem.innerText = k
kElem.innerText = key
kElem.className = "key"
let vElem = document.createElement('div')
vElem.innerText = v
if (typeof value == "object") {
Object.entries(value).forEach(v => {
const ne = document.createElement('div')
ne.innerHTML = `<span class="key">${v[0]}</span>: <span class="value">${v[1]}</span>`
vElem.appendChild(ne)
})
}
else {
vElem.innerText = value
}
vElem.className = "value"
if (k == "sorry") {
if (key == "sorry") {
vElem.classList.add("error")
}
r.appendChild(kElem)

View file

@ -5,25 +5,24 @@
with OpenID Connect and OAuth2 with different providers.
</p>
{% if not user %}
<div class="login-box">
<p class="description">Log in with:</p>
<table class="providers">
{% for provider in settings.auth.providers %}
<tr class="provider">
<td>
<a class="link" href="login/{{ provider.id }}"><div>{{ provider.name }}</div></a>
</td>
<td class="hint">{{ provider.hint }}</div>
</td>
</tr>
{% else %}
<div class="error">There is no authentication provider defined.
Hint: check the settings.yaml file.</div>
{% endfor %}
</table>
</div>
{% endif %}
{% if user %}
<div class="login-box">
<p class="description">Log in with:</p>
<table class="providers">
{% for provider in providers.values() %}
<tr class="provider">
<td>
<a class="link" href="login/{{ provider.id }}"><div>{{ provider.name }}</div></a>
</td>
<td class="hint">{{ provider.hint }}</div>
</td>
</tr>
{% else %}
<div class="error">There is no authentication provider defined.
Hint: check the settings.yaml file.</div>
{% endfor %}
</table>
</div>
{% else %}
<div class="user-info">
<p>Hey, {{ user.name }}</p>
{% if user.picture %}
@ -83,22 +82,22 @@
<button resource-id="fast_api_depends" class="hidden" onclick="get_resource('fast_api_depends', '{{ access_token }}', '{{ auth_provider.id }}')">Using FastAPI Depends</button>
<!--<button resource-id="introspect" onclick="get_resource('introspect', '{{ access_token }}', '{{ auth_provider.id }}')">Introspect token (401 expected)</button>-->
</div>
<div class="resourceResult">
<div id="resource" class="resource"></div>
<div id="msg" class="msg error"></div>
</div>
{% if resources %}
<p>
Resources for this provider:
</p>
<div class="links-to-check">
{% for resource in resources %}
<a href="{{ request.url_for('get_resource', id=resource.id) }}">{{ resource.name }}</a>
<button resource-id="{{ resource.id }}" onclick="get_resource('{{ resource.id }}', '{{ access_token }}', '{{ auth_provider.id }}')">{{ resource.name }}</buttona>
{% endfor %}
</div>
{% endif %}
<div class="resourceResult">
<div id="resource" class="resource"></div>
<div id="msg" class="msg error"></div>
</div>
</div>
{% if settings.show_token and id_token_parsed %}
{% if show_token and id_token_parsed %}
<div class="token-info">
<hr>
<div>

View file

@ -6,12 +6,12 @@
authorisation to log in again without asking for credentials.
</p>
<p>
This is because {{ oidc_provider.name }} does not provide "end_session_endpoint" in its metadata
(see: <a href="{{ oidc_provider._server_metadata_url }}">{{ oidc_provider._server_metadata_url }}</a>).
This is because {{ auth_provider.name }} does not provide "end_session_endpoint" in its metadata
(see: <a href="{{ auth_provider.authlib_client._server_metadata_url }}">{{ auth_provider.authlib_client._server_metadata_url }}</a>).
</p>
<p>
You can just also go back to the <a href="{{ home_url }}">application home page</a>, but
it recommended to go to the <a href="{{ oidc_provider.server_metadata['issuer'] }}">OIDC provider's site</a>
it recommended to go to the <a href="{{ auth_provider.authlib_client.server_metadata['issuer'] }}">OIDC provider's site</a>
and log out explicitely from there.
</p>
{% endblock %}