Decode access token, refactor
All checks were successful
/ build (push) Successful in 15s
/ test (push) Successful in 5s

This commit is contained in:
phil 2025-02-02 15:54:44 +01:00
parent 8b8bbcd7a0
commit e1dac77738
5 changed files with 42 additions and 23 deletions

View file

@ -6,14 +6,14 @@ from fastapi import HTTPException, Request, Depends, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from authlib.oauth2.rfc6749 import OAuth2Token from authlib.oauth2.rfc6749 import OAuth2Token
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
from jwt import ExpiredSignatureError, InvalidKeyError, decode from jwt import ExpiredSignatureError, InvalidKeyError
from httpx import AsyncClient from httpx import AsyncClient
# from authlib.oauth1.auth import OAuthToken # from authlib.oauth1.auth import OAuthToken
# from authlib.oauth2.auth import OAuth2Token # from authlib.oauth2.auth import OAuth2Token
from .models import User from .models import User
from .database import db, UserNotInDB from .database import TokenNotInDb, db, UserNotInDB
from .settings import settings, OIDCProvider from .settings import settings, OIDCProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -126,9 +126,10 @@ async def get_current_user(request: Request) -> User:
async def get_token(request: Request) -> OAuth2Token: async def get_token(request: Request) -> OAuth2Token:
"""Return the token from a request object, from the session. """Return the token from a request object, from the session.
It can be used in Depends()""" It can be used in Depends()"""
if (token := await db.get_token(request.session.get("token"))) is None: try:
return await db.get_token(request.session["token"])
except (KeyError, TokenNotInDb):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
return token
async def get_current_user_or_none(request: Request) -> User | None: async def get_current_user_or_none(request: Request) -> User | None:
@ -189,19 +190,8 @@ async def get_user_from_token(
raise HTTPException( raise HTTPException(
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'" status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
) )
if (key := auth_provider_settings.get_public_key()) is None:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
f"Key for provider '{auth_provider_id}' unknown",
)
try: try:
payload = decode( payload = auth_provider_settings.decode(token)
token,
key=key,
algorithms=["RS256"],
audience="oidc-test",
options={"verify_signature": not settings.insecure.skip_verify_signature},
)
except ExpiredSignatureError as err: except ExpiredSignatureError as err:
logger.info(f"Expired signature: {err}") logger.info(f"Expired signature: {err}")
raise HTTPException( raise HTTPException(

View file

@ -14,6 +14,10 @@ class UserNotInDB(Exception):
pass pass
class TokenNotInDb(Exception):
pass
class Database: class Database:
users: dict[str, User] = {} users: dict[str, User] = {}
tokens: dict[str, OAuth2Token] = {} tokens: dict[str, OAuth2Token] = {}
@ -50,8 +54,11 @@ class Database:
async def add_token(self, token: OAuth2Token, user: User) -> None: async def add_token(self, token: OAuth2Token, user: User) -> None:
self.tokens[token["id_token"]] = token self.tokens[token["id_token"]] = token
async def get_token(self, id_token: str) -> OAuth2Token | None: async def get_token(self, id_token: str) -> OAuth2Token:
return self.tokens.get(id_token) try:
return self.tokens[id_token]
except KeyError:
raise TokenNotInDb
db = Database() db = Database()

View file

@ -187,11 +187,20 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
# Build and remember the user in the session # Build and remember the user in the session
request.session["user_sub"] = sub request.session["user_sub"] = sub
# Store the user in the database # Store the user in the database
try:
oidc_provider_settings = oidc_providers_settings[oidc_provider_id]
access_token = oidc_provider_settings.decode(token["access_token"])
except Exception:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED,
detail="Cannot decode token or verify its signature",
)
user = await db.add_user( user = await db.add_user(
sub, sub,
user_info=userinfo, user_info=userinfo,
oidc_provider=oidc_provider, oidc_provider=oidc_provider,
user_info_from_endpoint=user_info_from_endpoint, user_info_from_endpoint=user_info_from_endpoint,
access_token=access_token,
) )
# Add the id_token to the session # Add the id_token to the session
request.session["token"] = token["id_token"] request.session["token"] = token["id_token"]
@ -213,14 +222,14 @@ async def account(
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)], oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
) -> RedirectResponse: ) -> RedirectResponse:
if ( if (
provider := oidc_providers_settings.get( oidc_provider_settings := oidc_providers_settings.get(
request.session.get("oidc_provider_id", "") request.session.get("oidc_provider_id", "")
) )
) is None: ) is None:
raise HTTPException( raise HTTPException(
status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting" status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings"
) )
return RedirectResponse(f"{provider.account_url}") return RedirectResponse(f"{oidc_provider_settings.account_url}")
@app.get("/logout") @app.get("/logout")

View file

@ -4,6 +4,7 @@ import random
from typing import Type, Tuple from typing import Type, Tuple
from pathlib import Path from pathlib import Path
from jwt import decode
from pydantic import BaseModel, computed_field, AnyUrl from pydantic import BaseModel, computed_field, AnyUrl
from pydantic_settings import ( from pydantic_settings import (
BaseSettings, BaseSettings,
@ -43,6 +44,7 @@ class OIDCProvider(BaseModel):
None # Info fetched from info_url, eg. public key None # Info fetched from info_url, eg. public key
) )
public_key: str | None = None public_key: str | None = None
signature_alg: str = "RS256"
@computed_field @computed_field
@property @property
@ -70,19 +72,29 @@ class OIDCProvider(BaseModel):
else: else:
return None return None
def get_public_key(self) -> str | None: def get_public_key(self) -> str:
"""Return the public key formatted for decoding token""" """Return the public key formatted for decoding token"""
public_key = self.public_key or ( public_key = self.public_key or (
self.info is not None and self.info["public_key"] self.info is not None and self.info["public_key"]
) )
if public_key is None: if public_key is None:
return None raise AttributeError(f"Cannot get public key for {self.name}")
return f""" return f"""
-----BEGIN PUBLIC KEY----- -----BEGIN PUBLIC KEY-----
{public_key} {public_key}
-----END PUBLIC KEY----- -----END PUBLIC KEY-----
""" """
def decode(self, token: str) -> dict:
"""Decode the token with signature check"""
return decode(
token,
self.get_public_key(),
algorithms=[self.signature_alg],
audience=["oidc-test", "oidc-test-web"],
options={"verify_signature": not settings.insecure.skip_verify_signature},
)
class ResourceProvider(BaseModel): class ResourceProvider(BaseModel):
id: str id: str

View file

@ -104,6 +104,7 @@ hr {
.role { .role {
padding: 3px 6px; padding: 3px 6px;
background-color: #44228840; background-color: #44228840;
border-radius: 6px;
} }
/* For home */ /* For home */