Get roles from access token, remove user info inspection, refreactorings
This commit is contained in:
parent
5c9ed9724e
commit
ee8ba3d2df
6 changed files with 126 additions and 97 deletions
|
@ -13,13 +13,10 @@ from authlib.oauth2.auth import OAuth2Token
|
|||
|
||||
from .models import User
|
||||
from .database import TokenNotInDb, db, UserNotInDB
|
||||
from .settings import settings, OIDCProvider
|
||||
from .settings import settings, OIDCProvider, oidc_providers_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
oidc_providers_settings: dict[str, OIDCProvider] = dict(
|
||||
[(provider.id, provider) for provider in settings.oidc.providers]
|
||||
)
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
|
@ -36,19 +33,16 @@ async def fetch_token(name, request):
|
|||
|
||||
|
||||
async def update_token(name, token, refresh_token=None, access_token=None):
|
||||
breakpoint()
|
||||
item = await db.get_token(token["id_token"])
|
||||
if refresh_token:
|
||||
item = OAuth2Token.find(name=name, refresh_token=refresh_token)
|
||||
elif access_token:
|
||||
item = OAuth2Token.find(name=name, access_token=access_token)
|
||||
else:
|
||||
return
|
||||
oidc_provider_settings = oidc_providers_settings[name]
|
||||
sid: str = oidc_provider_settings.decode(token["id_token"])["sid"]
|
||||
item = await db.get_token(oidc_provider_settings, sid)
|
||||
# update old token
|
||||
item.access_token = token["access_token"]
|
||||
item.refresh_token = token.get("refresh_token")
|
||||
item.expires_at = token["expires_at"]
|
||||
item.save()
|
||||
item["access_token"] = token.get("access_token")
|
||||
item["refresh_token"] = token.get("refresh_token")
|
||||
item["expires_at"] = token["expires_at"]
|
||||
logger.info(f"Token {sid} refreshed")
|
||||
# It's a fake db and only in memory, so there's nothing to save
|
||||
# await item.save()
|
||||
|
||||
|
||||
authlib_oauth = OAuth(cache=None, fetch_token=fetch_token, update_token=update_token)
|
||||
|
@ -138,8 +132,17 @@ async def get_token(request: Request) -> OAuth2Token:
|
|||
"""Return the token from a request object, from the session.
|
||||
It can be used in Depends()"""
|
||||
try:
|
||||
return await db.get_token(request.session.get("token"))
|
||||
except TokenNotInDb:
|
||||
oidc_provider_settings = oidc_providers_settings[
|
||||
request.session.get("oidc_provider_id", "")
|
||||
]
|
||||
except KeyError:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid provider")
|
||||
try:
|
||||
return await db.get_token(
|
||||
oidc_provider_settings,
|
||||
request.session.get("sid"),
|
||||
)
|
||||
except (TokenNotInDb, InvalidKeyError):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "No token")
|
||||
|
||||
|
||||
|
@ -190,14 +193,16 @@ async def get_user_from_token(
|
|||
token: Annotated[str, Depends(oauth2_scheme)],
|
||||
request: Request,
|
||||
) -> User:
|
||||
if (auth_provider_id := request.headers.get("auth_provider")) is None:
|
||||
try:
|
||||
auth_provider_id = request.headers["auth_provider"]
|
||||
except KeyError:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
"Request headers must have a 'auth_provider' field",
|
||||
)
|
||||
if (
|
||||
auth_provider_settings := oidc_providers_settings.get(auth_provider_id)
|
||||
) is None:
|
||||
try:
|
||||
auth_provider_settings = oidc_providers_settings[auth_provider_id]
|
||||
except KeyError:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, f"Unknown auth provider '{auth_provider_id}'"
|
||||
)
|
||||
|
@ -216,7 +221,9 @@ async def get_user_from_token(
|
|||
logger.info("Cannot decode token, see below")
|
||||
logger.exception(err)
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Cannot decode token")
|
||||
if (user_id := payload.get("sub")) is None:
|
||||
try:
|
||||
user_id = payload["sub"]
|
||||
except KeyError:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, "Wrong token: 'sub' (user id) not found"
|
||||
)
|
||||
|
@ -232,7 +239,6 @@ async def get_user_from_token(
|
|||
sub=payload["sub"],
|
||||
user_info=payload,
|
||||
oidc_provider=getattr(authlib_oauth, auth_provider_id),
|
||||
user_info_from_endpoint={},
|
||||
access_token=token,
|
||||
)
|
||||
return user
|
||||
|
|
|
@ -3,11 +3,12 @@
|
|||
import logging
|
||||
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
|
||||
from .models import User, Role
|
||||
from authlib.oauth2.rfc6749 import OAuth2Token
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from .settings import OIDCProvider, oidc_providers_settings
|
||||
from .models import User, Role
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
||||
class UserNotInDB(Exception):
|
||||
|
@ -29,20 +30,34 @@ class Database:
|
|||
sub: str,
|
||||
user_info: dict,
|
||||
oidc_provider: StarletteOAuth2App,
|
||||
user_info_from_endpoint: dict,
|
||||
access_token: str,
|
||||
access_token_decoded: dict | None = None,
|
||||
) -> User:
|
||||
user = User.from_auth(userinfo=user_info, oidc_provider=oidc_provider)
|
||||
if access_token_decoded is None:
|
||||
assert oidc_provider.name is not None
|
||||
oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
|
||||
access_token_decoded = oidc_provider_settings.decode(access_token)
|
||||
user = User(**user_info)
|
||||
user.userinfo = user_info
|
||||
user.oidc_provider = oidc_provider
|
||||
user.access_token = access_token
|
||||
user.access_token_decoded = access_token_decoded
|
||||
# Add roles provided in the access token
|
||||
roles = set()
|
||||
try:
|
||||
raw_roles = user_info_from_endpoint["resource_access"][
|
||||
oidc_provider.client_id
|
||||
]["roles"]
|
||||
except Exception as err:
|
||||
logger.debug(f"Cannot read additional roles: {err}")
|
||||
raw_roles = []
|
||||
for raw_role in raw_roles:
|
||||
user.roles.append(Role(name=raw_role))
|
||||
r = access_token_decoded["resource_access"][oidc_provider.client_id]["roles"]
|
||||
roles.update(r)
|
||||
except KeyError:
|
||||
pass
|
||||
try:
|
||||
r = access_token_decoded["realm_access"]["roles"]
|
||||
if isinstance(r, str):
|
||||
roles.add(r)
|
||||
else:
|
||||
roles.update(r)
|
||||
except KeyError:
|
||||
pass
|
||||
user.roles = [Role(name=role_name) for role_name in roles]
|
||||
self.users[sub] = user
|
||||
return user
|
||||
|
||||
|
@ -51,14 +66,21 @@ class Database:
|
|||
raise UserNotInDB
|
||||
return self.users[sub]
|
||||
|
||||
async def add_token(self, token: OAuth2Token, user: User) -> None:
|
||||
self.tokens[token["id_token"]] = token
|
||||
async def add_token(self, oidc_provider_settings: OIDCProvider, token: OAuth2Token) -> None:
|
||||
"""Store a token using as key the sid (auth provider's session id)
|
||||
in the id_token"""
|
||||
sid = token["userinfo"]["sid"]
|
||||
self.tokens[sid] = token
|
||||
|
||||
async def get_token(self, id_token: str | None) -> OAuth2Token:
|
||||
if id_token is None:
|
||||
async def get_token(
|
||||
self,
|
||||
oidc_provider_settings: OIDCProvider,
|
||||
sid: str | None,
|
||||
) -> OAuth2Token:
|
||||
if sid is None:
|
||||
raise TokenNotInDb
|
||||
try:
|
||||
return self.tokens[id_token]
|
||||
return self.tokens[sid]
|
||||
except KeyError:
|
||||
raise TokenNotInDb
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ from fastapi.staticfiles import StaticFiles
|
|||
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from jwt import InvalidKeyError, InvalidTokenError
|
||||
from jwt import InvalidTokenError, PyJWTError
|
||||
from starlette.middleware.sessions import SessionMiddleware
|
||||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from authlib.integrations.base_client import OAuthError
|
||||
|
@ -26,7 +26,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token
|
|||
# from fastapi.security import OpenIdConnect
|
||||
# from pkce import generate_code_verifier, generate_pkce_pair
|
||||
|
||||
from .settings import settings
|
||||
from .settings import settings, oidc_providers_settings
|
||||
from .models import User
|
||||
from .auth_utils import (
|
||||
get_oidc_provider,
|
||||
|
@ -37,14 +37,13 @@ from .auth_utils import (
|
|||
get_user_from_token,
|
||||
authlib_oauth,
|
||||
get_token,
|
||||
oidc_providers_settings,
|
||||
get_providers_info,
|
||||
)
|
||||
from .auth_misc import pretty_details
|
||||
from .database import TokenNotInDb, db
|
||||
from .resource_server import get_resource
|
||||
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
templates = Jinja2Templates(Path(__file__).parent / "templates")
|
||||
|
||||
|
@ -189,43 +188,28 @@ async def auth(request: Request, oidc_provider_id: str) -> RedirectResponse:
|
|||
request.session["oidc_provider_id"] = oidc_provider_id
|
||||
# User id (sub) given by oidc provider
|
||||
sub = userinfo["sub"]
|
||||
# Get additional data from userinfo endpoint
|
||||
try:
|
||||
user_info_from_endpoint = await oidc_provider.userinfo(
|
||||
token=token, follow_redirects=True
|
||||
)
|
||||
except Exception as err:
|
||||
logger.warn(f"Cannot get userinfo from endpoint: {err}")
|
||||
user_info_from_endpoint = {}
|
||||
# Build and remember the user in the session
|
||||
request.session["user_sub"] = sub
|
||||
# Verify the token's signature and validity
|
||||
# Store the user in the database, which also verifies the token validity and signature
|
||||
try:
|
||||
oidc_provider_settings = oidc_providers_settings[oidc_provider_id]
|
||||
oidc_provider_settings.decode(token["access_token"])
|
||||
except InvalidKeyError:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token invalid key / signature",
|
||||
)
|
||||
except Exception as err:
|
||||
logger.exception(err)
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Cannot decode token or verify its signature",
|
||||
)
|
||||
# Store the user in the database
|
||||
user = await db.add_user(
|
||||
sub,
|
||||
user_info=userinfo,
|
||||
oidc_provider=oidc_provider,
|
||||
user_info_from_endpoint=user_info_from_endpoint,
|
||||
access_token=token["access_token"],
|
||||
)
|
||||
# Add the id_token to the session
|
||||
request.session["token"] = token["id_token"]
|
||||
except PyJWTError as err:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"Token invalid: {err.__class__.__name__}",
|
||||
)
|
||||
assert isinstance(user, User)
|
||||
# Add the provider session id to the session
|
||||
request.session["sid"] = userinfo["sid"]
|
||||
# Add the token to the db because it is used for logout
|
||||
await db.add_token(token, user)
|
||||
assert oidc_provider.name is not None
|
||||
oidc_provider_settings = oidc_providers_settings[oidc_provider.name]
|
||||
await db.add_token(oidc_provider_settings, token)
|
||||
# Send the user to the home: (s)he is authenticated
|
||||
return RedirectResponse(url=request.url_for("home"))
|
||||
else:
|
||||
|
@ -268,8 +252,14 @@ async def logout(
|
|||
)
|
||||
return RedirectResponse(request.url_for("non_compliant_logout"))
|
||||
post_logout_uri = request.url_for("home")
|
||||
oidc_provider_settings = oidc_providers_settings.get(
|
||||
request.session.get("oidc_provider_id", "")
|
||||
)
|
||||
assert oidc_provider_settings is not None
|
||||
try:
|
||||
token = await db.get_token(request.session.pop("token", None))
|
||||
token = await db.get_token(
|
||||
oidc_provider_settings, request.session.pop("sid", None)
|
||||
)
|
||||
except TokenNotInDb:
|
||||
logger.warn("No session in db for the token or no token")
|
||||
return RedirectResponse(request.url_for("home"))
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
from functools import cached_property
|
||||
from typing import Self
|
||||
from typing import Self, Any
|
||||
|
||||
from pydantic import (
|
||||
computed_field,
|
||||
|
@ -11,7 +11,7 @@ from pydantic import (
|
|||
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
||||
class Role(SQLModel, extra="ignore"):
|
||||
|
@ -36,19 +36,9 @@ class User(UserBase):
|
|||
)
|
||||
userinfo: dict = {}
|
||||
access_token: str | None = None
|
||||
access_token_decoded: dict[str, Any] | None = None
|
||||
oidc_provider: StarletteOAuth2App | None = None
|
||||
|
||||
@classmethod
|
||||
def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self:
|
||||
user = cls(**userinfo)
|
||||
user.userinfo = userinfo
|
||||
user.oidc_provider = oidc_provider
|
||||
# Add roles if they are provided in the token
|
||||
if raw_ra := userinfo.get("realm_access"):
|
||||
if raw_roles := raw_ra.get("roles"):
|
||||
user.roles = [Role(name=raw_role) for raw_role in raw_roles]
|
||||
return user
|
||||
|
||||
@computed_field
|
||||
@cached_property
|
||||
def roles_as_set(self) -> set[str]:
|
||||
|
@ -68,7 +58,7 @@ class User(UserBase):
|
|||
assert self.access_token is not None
|
||||
assert self.oidc_provider is not None
|
||||
assert self.oidc_provider.name is not None
|
||||
from .auth_utils import oidc_providers_settings
|
||||
from .settings import oidc_providers_settings
|
||||
|
||||
return oidc_providers_settings[self.oidc_provider.name].decode(
|
||||
self.access_token
|
||||
|
|
|
@ -8,7 +8,7 @@ from starlette.status import HTTP_401_UNAUTHORIZED
|
|||
|
||||
from .models import User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
||||
async def get_resource(resource_id: str, user: User) -> dict:
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
from os import environ
|
||||
import string
|
||||
import random
|
||||
from typing import Type, Tuple
|
||||
from typing import Type, Tuple, Any
|
||||
from pathlib import Path
|
||||
import logging
|
||||
|
||||
from jwt import decode
|
||||
from pydantic import BaseModel, computed_field, AnyUrl
|
||||
|
@ -16,6 +17,8 @@ from starlette.requests import Request
|
|||
|
||||
from .models import User
|
||||
|
||||
logger = logging.getLogger("oidc-test")
|
||||
|
||||
|
||||
class Resource(BaseModel):
|
||||
"""A resource with an URL that can be accessed with an OAuth2 access token"""
|
||||
|
@ -86,14 +89,27 @@ class OIDCProvider(BaseModel):
|
|||
-----END PUBLIC KEY-----
|
||||
"""
|
||||
|
||||
def decode(self, token: str) -> dict:
|
||||
def decode(self, token: str, verify_signature: bool = True) -> dict[str, Any]:
|
||||
"""Decode the token with signature check"""
|
||||
decoded = decode(
|
||||
token,
|
||||
self.get_public_key(),
|
||||
algorithms=[self.signature_alg],
|
||||
audience=["account", "oidc-test", "oidc-test-web"],
|
||||
options={
|
||||
"verify_signature": False,
|
||||
"verify_aud": False,
|
||||
}, # not settings.insecure.skip_verify_signature},
|
||||
)
|
||||
logger.debug(str(decoded))
|
||||
return decode(
|
||||
token,
|
||||
self.get_public_key(),
|
||||
algorithms=[self.signature_alg],
|
||||
audience=["oidc-test", "oidc-test-web"],
|
||||
options={"verify_signature": not settings.insecure.skip_verify_signature},
|
||||
audience=["account", "oidc-test", "oidc-test-web"],
|
||||
options={
|
||||
"verify_signature": verify_signature,
|
||||
}, # not settings.insecure.skip_verify_signature},
|
||||
)
|
||||
|
||||
|
||||
|
@ -156,3 +172,8 @@ class Settings(BaseSettings):
|
|||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
oidc_providers_settings: dict[str, OIDCProvider] = dict(
|
||||
[(provider.id, provider) for provider in settings.oidc.providers]
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue