Decode access token, refactor
This commit is contained in:
parent
8b8bbcd7a0
commit
e1dac77738
5 changed files with 42 additions and 23 deletions
|
@ -6,14 +6,14 @@ from fastapi import HTTPException, Request, Depends, status
|
|||
from fastapi.security import OAuth2PasswordBearer
|
||||
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
|
||||
from jwt import ExpiredSignatureError, InvalidKeyError, decode
|
||||
from jwt import ExpiredSignatureError, InvalidKeyError
|
||||
from httpx import AsyncClient
|
||||
|
||||
# from authlib.oauth1.auth import OAuthToken
|
||||
# from authlib.oauth2.auth import OAuth2Token
|
||||
|
||||
from .models import User
|
||||
from .database import db, UserNotInDB
|
||||
from .database import TokenNotInDb, db, UserNotInDB
|
||||
from .settings import settings, OIDCProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -126,9 +126,10 @@ async def get_current_user(request: Request) -> User:
|
|||
async def get_token(request: Request) -> OAuth2Token:
|
||||
"""Return the token from a request object, from the session.
|
||||
It can be used in Depends()"""
|
||||
if (token := await db.get_token(request.session.get("token"))) is None:
|
||||
try:
|
||||
return await db.get_token(request.session["token"])
|
||||
except (KeyError, TokenNotInDb):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
|
||||
return token
|
||||
|
||||
|
||||
async def get_current_user_or_none(request: Request) -> User | None:
|
||||
|
@ -189,19 +190,8 @@ async def get_user_from_token(
|
|||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
|
||||
)
|
||||
if (key := auth_provider_settings.get_public_key()) is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
f"Key for provider '{auth_provider_id}' unknown",
|
||||
)
|
||||
try:
|
||||
payload = decode(
|
||||
token,
|
||||
key=key,
|
||||
algorithms=["RS256"],
|
||||
audience="oidc-test",
|
||||
options={"verify_signature": not settings.insecure.skip_verify_signature},
|
||||
)
|
||||
payload = auth_provider_settings.decode(token)
|
||||
except ExpiredSignatureError as err:
|
||||
logger.info(f"Expired signature: {err}")
|
||||
raise HTTPException(
|
||||
|
|
|
@ -14,6 +14,10 @@ class UserNotInDB(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class TokenNotInDb(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class Database:
|
||||
users: dict[str, User] = {}
|
||||
tokens: dict[str, OAuth2Token] = {}
|
||||
|
@ -50,8 +54,11 @@ class Database:
|
|||
async def add_token(self, token: OAuth2Token, user: User) -> None:
|
||||
self.tokens[token["id_token"]] = token
|
||||
|
||||
async def get_token(self, id_token: str) -> OAuth2Token | None:
|
||||
return self.tokens.get(id_token)
|
||||
async def get_token(self, id_token: str) -> OAuth2Token:
|
||||
try:
|
||||
return self.tokens[id_token]
|
||||
except KeyError:
|
||||
raise TokenNotInDb
|
||||
|
||||
|
||||
db = Database()
|
||||
|
|
|
@ -187,11 +187,20 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
|||
# Build and remember the user in the session
|
||||
request.session["user_sub"] = sub
|
||||
# Store the user in the database
|
||||
try:
|
||||
oidc_provider_settings = oidc_providers_settings[oidc_provider_id]
|
||||
access_token = oidc_provider_settings.decode(token["access_token"])
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Cannot decode token or verify its signature",
|
||||
)
|
||||
user = await db.add_user(
|
||||
sub,
|
||||
user_info=userinfo,
|
||||
oidc_provider=oidc_provider,
|
||||
user_info_from_endpoint=user_info_from_endpoint,
|
||||
access_token=access_token,
|
||||
)
|
||||
# Add the id_token to the session
|
||||
request.session["token"] = token["id_token"]
|
||||
|
@ -213,14 +222,14 @@ async def account(
|
|||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||
) -> RedirectResponse:
|
||||
if (
|
||||
provider := oidc_providers_settings.get(
|
||||
oidc_provider_settings := oidc_providers_settings.get(
|
||||
request.session.get("oidc_provider_id", "")
|
||||
)
|
||||
) is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting"
|
||||
status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings"
|
||||
)
|
||||
return RedirectResponse(f"{provider.account_url}")
|
||||
return RedirectResponse(f"{oidc_provider_settings.account_url}")
|
||||
|
||||
|
||||
@app.get("/logout")
|
||||
|
|
|
@ -4,6 +4,7 @@ import random
|
|||
from typing import Type, Tuple
|
||||
from pathlib import Path
|
||||
|
||||
from jwt import decode
|
||||
from pydantic import BaseModel, computed_field, AnyUrl
|
||||
from pydantic_settings import (
|
||||
BaseSettings,
|
||||
|
@ -43,6 +44,7 @@ class OIDCProvider(BaseModel):
|
|||
None # Info fetched from info_url, eg. public key
|
||||
)
|
||||
public_key: str | None = None
|
||||
signature_alg: str = "RS256"
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
|
@ -70,19 +72,29 @@ class OIDCProvider(BaseModel):
|
|||
else:
|
||||
return None
|
||||
|
||||
def get_public_key(self) -> str | None:
|
||||
def get_public_key(self) -> str:
|
||||
"""Return the public key formatted for decoding token"""
|
||||
public_key = self.public_key or (
|
||||
self.info is not None and self.info["public_key"]
|
||||
)
|
||||
if public_key is None:
|
||||
return None
|
||||
raise AttributeError(f"Cannot get public key for {self.name}")
|
||||
return f"""
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
{public_key}
|
||||
-----END PUBLIC KEY-----
|
||||
"""
|
||||
|
||||
def decode(self, token: str) -> dict:
|
||||
"""Decode the token with signature check"""
|
||||
return decode(
|
||||
token,
|
||||
self.get_public_key(),
|
||||
algorithms=[self.signature_alg],
|
||||
audience=["oidc-test", "oidc-test-web"],
|
||||
options={"verify_signature": not settings.insecure.skip_verify_signature},
|
||||
)
|
||||
|
||||
|
||||
class ResourceProvider(BaseModel):
|
||||
id: str
|
||||
|
|
|
@ -104,6 +104,7 @@ hr {
|
|||
.role {
|
||||
padding: 3px 6px;
|
||||
background-color: #44228840;
|
||||
border-radius: 6px;
|
||||
}
|
||||
|
||||
/* For home */
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue