Decode access token, refactor
This commit is contained in:
parent
8b8bbcd7a0
commit
e1dac77738
5 changed files with 42 additions and 23 deletions
|
@ -6,14 +6,14 @@ from fastapi import HTTPException, Request, Depends, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from authlib.oauth2.rfc6749 import OAuth2Token
|
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||||
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
|
from authlib.integrations.starlette_client import OAuth, OAuthError, StarletteOAuth2App
|
||||||
from jwt import ExpiredSignatureError, InvalidKeyError, decode
|
from jwt import ExpiredSignatureError, InvalidKeyError
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
# from authlib.oauth1.auth import OAuthToken
|
# from authlib.oauth1.auth import OAuthToken
|
||||||
# from authlib.oauth2.auth import OAuth2Token
|
# from authlib.oauth2.auth import OAuth2Token
|
||||||
|
|
||||||
from .models import User
|
from .models import User
|
||||||
from .database import db, UserNotInDB
|
from .database import TokenNotInDb, db, UserNotInDB
|
||||||
from .settings import settings, OIDCProvider
|
from .settings import settings, OIDCProvider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -126,9 +126,10 @@ async def get_current_user(request: Request) -> User:
|
||||||
async def get_token(request: Request) -> OAuth2Token:
|
async def get_token(request: Request) -> OAuth2Token:
|
||||||
"""Return the token from a request object, from the session.
|
"""Return the token from a request object, from the session.
|
||||||
It can be used in Depends()"""
|
It can be used in Depends()"""
|
||||||
if (token := await db.get_token(request.session.get("token"))) is None:
|
try:
|
||||||
|
return await db.get_token(request.session["token"])
|
||||||
|
except (KeyError, TokenNotInDb):
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user_or_none(request: Request) -> User | None:
|
async def get_current_user_or_none(request: Request) -> User | None:
|
||||||
|
@ -189,19 +190,8 @@ async def get_user_from_token(
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
|
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
|
||||||
)
|
)
|
||||||
if (key := auth_provider_settings.get_public_key()) is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status.HTTP_401_UNAUTHORIZED,
|
|
||||||
f"Key for provider '{auth_provider_id}' unknown",
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
payload = decode(
|
payload = auth_provider_settings.decode(token)
|
||||||
token,
|
|
||||||
key=key,
|
|
||||||
algorithms=["RS256"],
|
|
||||||
audience="oidc-test",
|
|
||||||
options={"verify_signature": not settings.insecure.skip_verify_signature},
|
|
||||||
)
|
|
||||||
except ExpiredSignatureError as err:
|
except ExpiredSignatureError as err:
|
||||||
logger.info(f"Expired signature: {err}")
|
logger.info(f"Expired signature: {err}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
|
|
@ -14,6 +14,10 @@ class UserNotInDB(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TokenNotInDb(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
users: dict[str, User] = {}
|
users: dict[str, User] = {}
|
||||||
tokens: dict[str, OAuth2Token] = {}
|
tokens: dict[str, OAuth2Token] = {}
|
||||||
|
@ -50,8 +54,11 @@ class Database:
|
||||||
async def add_token(self, token: OAuth2Token, user: User) -> None:
|
async def add_token(self, token: OAuth2Token, user: User) -> None:
|
||||||
self.tokens[token["id_token"]] = token
|
self.tokens[token["id_token"]] = token
|
||||||
|
|
||||||
async def get_token(self, id_token: str) -> OAuth2Token | None:
|
async def get_token(self, id_token: str) -> OAuth2Token:
|
||||||
return self.tokens.get(id_token)
|
try:
|
||||||
|
return self.tokens[id_token]
|
||||||
|
except KeyError:
|
||||||
|
raise TokenNotInDb
|
||||||
|
|
||||||
|
|
||||||
db = Database()
|
db = Database()
|
||||||
|
|
|
@ -187,11 +187,20 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
||||||
# Build and remember the user in the session
|
# Build and remember the user in the session
|
||||||
request.session["user_sub"] = sub
|
request.session["user_sub"] = sub
|
||||||
# Store the user in the database
|
# Store the user in the database
|
||||||
|
try:
|
||||||
|
oidc_provider_settings = oidc_providers_settings[oidc_provider_id]
|
||||||
|
access_token = oidc_provider_settings.decode(token["access_token"])
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Cannot decode token or verify its signature",
|
||||||
|
)
|
||||||
user = await db.add_user(
|
user = await db.add_user(
|
||||||
sub,
|
sub,
|
||||||
user_info=userinfo,
|
user_info=userinfo,
|
||||||
oidc_provider=oidc_provider,
|
oidc_provider=oidc_provider,
|
||||||
user_info_from_endpoint=user_info_from_endpoint,
|
user_info_from_endpoint=user_info_from_endpoint,
|
||||||
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
# Add the id_token to the session
|
# Add the id_token to the session
|
||||||
request.session["token"] = token["id_token"]
|
request.session["token"] = token["id_token"]
|
||||||
|
@ -213,14 +222,14 @@ async def account(
|
||||||
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
oidc_provider: Annotated[StarletteOAuth2App, Depends(get_oidc_provider)],
|
||||||
) -> RedirectResponse:
|
) -> RedirectResponse:
|
||||||
if (
|
if (
|
||||||
provider := oidc_providers_settings.get(
|
oidc_provider_settings := oidc_providers_settings.get(
|
||||||
request.session.get("oidc_provider_id", "")
|
request.session.get("oidc_provider_id", "")
|
||||||
)
|
)
|
||||||
) is None:
|
) is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider setting"
|
status.HTTP_406_NOT_ACCEPTABLE, detail="No oidc provider settings"
|
||||||
)
|
)
|
||||||
return RedirectResponse(f"{provider.account_url}")
|
return RedirectResponse(f"{oidc_provider_settings.account_url}")
|
||||||
|
|
||||||
|
|
||||||
@app.get("/logout")
|
@app.get("/logout")
|
||||||
|
|
|
@ -4,6 +4,7 @@ import random
|
||||||
from typing import Type, Tuple
|
from typing import Type, Tuple
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from jwt import decode
|
||||||
from pydantic import BaseModel, computed_field, AnyUrl
|
from pydantic import BaseModel, computed_field, AnyUrl
|
||||||
from pydantic_settings import (
|
from pydantic_settings import (
|
||||||
BaseSettings,
|
BaseSettings,
|
||||||
|
@ -43,6 +44,7 @@ class OIDCProvider(BaseModel):
|
||||||
None # Info fetched from info_url, eg. public key
|
None # Info fetched from info_url, eg. public key
|
||||||
)
|
)
|
||||||
public_key: str | None = None
|
public_key: str | None = None
|
||||||
|
signature_alg: str = "RS256"
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
|
@ -70,19 +72,29 @@ class OIDCProvider(BaseModel):
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_public_key(self) -> str | None:
|
def get_public_key(self) -> str:
|
||||||
"""Return the public key formatted for decoding token"""
|
"""Return the public key formatted for decoding token"""
|
||||||
public_key = self.public_key or (
|
public_key = self.public_key or (
|
||||||
self.info is not None and self.info["public_key"]
|
self.info is not None and self.info["public_key"]
|
||||||
)
|
)
|
||||||
if public_key is None:
|
if public_key is None:
|
||||||
return None
|
raise AttributeError(f"Cannot get public key for {self.name}")
|
||||||
return f"""
|
return f"""
|
||||||
-----BEGIN PUBLIC KEY-----
|
-----BEGIN PUBLIC KEY-----
|
||||||
{public_key}
|
{public_key}
|
||||||
-----END PUBLIC KEY-----
|
-----END PUBLIC KEY-----
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def decode(self, token: str) -> dict:
|
||||||
|
"""Decode the token with signature check"""
|
||||||
|
return decode(
|
||||||
|
token,
|
||||||
|
self.get_public_key(),
|
||||||
|
algorithms=[self.signature_alg],
|
||||||
|
audience=["oidc-test", "oidc-test-web"],
|
||||||
|
options={"verify_signature": not settings.insecure.skip_verify_signature},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ResourceProvider(BaseModel):
|
class ResourceProvider(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
|
|
|
@ -104,6 +104,7 @@ hr {
|
||||||
.role {
|
.role {
|
||||||
padding: 3px 6px;
|
padding: 3px 6px;
|
||||||
background-color: #44228840;
|
background-color: #44228840;
|
||||||
|
border-radius: 6px;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* For home */
|
/* For home */
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue