Decode access token, refactor
All checks were successful
/ build (push) Successful in 15s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-02-02 15:54:44 +01:00
parent 8b8bbcd7a0
commit e1dac77738
5 changed files with 42 additions and 23 deletions

View file

@ -6,14 +6,14 @@ from fastapi import HTTPException, Request, Depends, status
from fastapi.security import OAuth2PasswordBearer
from authlib.oauth2.rfc6749 import OAuth2Token
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
from jwt import ExpiredSignatureError, InvalidKeyError, decode
from jwt import ExpiredSignatureError, InvalidKeyError
from httpx import AsyncClient
# from authlib.oauth1.auth import OAuthToken
# from authlib.oauth2.auth import OAuth2Token
from .models import User
from .database import db, UserNotInDB
from .database import TokenNotInDb, db, UserNotInDB
from .settings import settings, OIDCProvider
logger = logging.getLogger(__name__)
@ -126,9 +126,10 @@ async def get_current_user(request: Request) -> User:
async def get_token(request: Request) -> OAuth2Token:
"""Return the token from a request object, from the session.
It can be used in Depends()"""
if (token := await db.get_token(request.session.get("token"))) is None:
try:
return await db.get_token(request.session["token"])
except (KeyError, TokenNotInDb):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
return token
async def get_current_user_or_none(request: Request) -> User | None:
@ -189,19 +190,8 @@ async def get_user_from_token(
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
)
if (key := auth_provider_settings.get_public_key()) is None:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
f"Key for provider '{auth_provider_id}' unknown",
)
try:
payload = decode(
token,
key=key,
algorithms=["RS256"],
audience="oidc-test",
options={"verify_signature": not settings.insecure.skip_verify_signature},
)
payload = auth_provider_settings.decode(token)
except ExpiredSignatureError as err:
logger.info(f"Expired signature: {err}")
raise HTTPException(

View file

@ -14,6 +14,10 @@ class UserNotInDB(Exception):
pass
class TokenNotInDb(Exception):
pass
class Database:
users: dict[str, User] = {}
tokens: dict[str, OAuth2Token] = {}
@ -50,8 +54,11 @@ class Database:
async def add_token(self, token: OAuth2Token, user: User) -> None:
self.tokens[token["id_token"]] = token
async def get_token(self, id_token: str) -> OAuth2Token | None:
return self.tokens.get(id_token)
async def get_token(self, id_token: str) -> OAuth2Token:
try:
return self.tokens[id_token]
except KeyError:
raise TokenNotInDb
db = Database()

View file

@ -187,11 +187,20 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
# Build and remember the user in the session
request.session["user_sub"] = sub
# Store the user in the database
try:
oidc_provider_settings = oidc_providers_settings[oidc_provider_id]
access_token = oidc_provider_settings.decode(token["access_token"])
except Exception:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
detail="Cannot decode token or verify its signature",
)
user = await db.add_user(
sub,
user_info=userinfo,
oidc_provider=oidc_provider,
user_info_from_endpoint=user_info_from_endpoint,
access_token=access_token,
)
# Add the id_token to the session
request.session["token"] = token["id_token"]
@ -213,14 +222,14 @@ async def account(
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
) -> RedirectResponse:
if (
provider := oidc_providers_settings.get(
oidc_provider_settings := oidc_providers_settings.get(
request.session.get("oidc_provider_id", "")
)
) is None:
raise HTTPException(
status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting"
status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings"
)
return RedirectResponse(f"{provider.account_url}")
return RedirectResponse(f"{oidc_provider_settings.account_url}")
@app.get("/logout")

View file

@ -4,6 +4,7 @@ import random
from typing import Type, Tuple
from pathlib import Path
from jwt import decode
from pydantic import BaseModel, computed_field, AnyUrl
from pydantic_settings import (
BaseSettings,
@ -43,6 +44,7 @@ class OIDCProvider(BaseModel):
None # Info fetched from info_url, eg. public key
)
public_key: str | None = None
signature_alg: str = "RS256"
@computed_field
@property
@ -70,19 +72,29 @@ class OIDCProvider(BaseModel):
else:
return None
def get_public_key(self) -> str | None:
def get_public_key(self) -> str:
"""Return the public key formatted for decoding token"""
public_key = self.public_key or (
self.info is not None and self.info["public_key"]
)
if public_key is None:
return None
raise AttributeError(f"Cannot get public key for {self.name}")
return f"""
-----BEGIN PUBLIC KEY-----
{public_key}
-----END PUBLIC KEY-----
"""
def decode(self, token: str) -> dict:
"""Decode the token with signature check"""
return decode(
token,
self.get_public_key(),
algorithms=[self.signature_alg],
audience=["oidc-test", "oidc-test-web"],
options={"verify_signature": not settings.insecure.skip_verify_signature},
)
class ResourceProvider(BaseModel):
id: str

View file

@ -104,6 +104,7 @@ hr {
.role {
padding: 3px 6px;
background-color: #44228840;
border-radius: 6px;
}
/* For home */