77 lines
2.4 KiB
Python
77 lines
2.4 KiB
Python
from functools import cached_property
|
|
from typing import Self
|
|
|
|
from pydantic import computed_field, AnyHttpUrl, EmailStr, ConfigDict
|
|
from authlib.integrations.starlette_client.apps import StarletteOAuth2App
|
|
from sqlmodel import SQLModel, Field
|
|
|
|
|
|
class Role(SQLModel, extra="ignore"):
    """A single authorization role; any input keys other than ``name`` are ignored."""

    # Identifier of the role.
    name: str
|
|
|
|
|
|
class UserBase(SQLModel, extra="ignore"):
    """Fields shared by every user model.

    Only ``name`` is required; all other fields are optional. Input keys that
    do not match a field are ignored (``extra="ignore"``).
    """

    # Optional identifiers, unset (None) by default.
    id: str | None = None
    sid: str | None = None

    # Display name; the only required field.
    name: str

    # Optional contact / profile data, validated as an e-mail and an HTTP(S) URL.
    email: EmailStr | None = None
    picture: AnyHttpUrl | None = None

    # Roles granted to the user; empty unless populated by the caller.
    roles: list[Role] = []
|
|
|
|
|
|
class User(UserBase):
    """An authenticated user built from an OIDC provider's userinfo claims.

    Extends ``UserBase`` with:
    - the provider subject id (``sub``);
    - the raw claims dict (``userinfo``);
    - the provider client that authenticated the user (``oidc_provider``).
    """

    # StarletteOAuth2App is not a type pydantic knows how to validate, so
    # arbitrary types must be allowed for ``oidc_provider`` to be a field.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    sub: str = Field(
        description="""subject id of the user given by the oidc provider,
        also the key for the database 'table'""",
    )

    # Claims dict exactly as passed to ``from_auth``.
    userinfo: dict = {}

    # OAuth2 client of the provider that authenticated this user, if known.
    oidc_provider: StarletteOAuth2App | None = None

    @classmethod
    def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self:
        """Build a ``User`` from the claims returned by an OIDC provider.

        Args:
            userinfo: Claims dict. Its items are passed as keyword arguments
                to the constructor, so it must contain the required fields
                (``sub`` and ``name``). It is also stored unchanged on
                ``user.userinfo``.
            oidc_provider: The OAuth2 client that produced the claims.

        Returns:
            The new user. If ``userinfo["realm_access"]["roles"]`` is present
            and non-empty, ``roles`` is filled with one ``Role`` per entry.

        Raises:
            pydantic.ValidationError: if a required field is missing or invalid.
        """
        user = cls(**userinfo)
        user.userinfo = userinfo
        user.oidc_provider = oidc_provider

        # Roles are only filled in when the provider includes a
        # ``realm_access`` claim with a non-empty ``roles`` list.
        if raw_ra := userinfo.get("realm_access"):
            if raw_roles := raw_ra.get("roles"):
                user.roles = [Role(name=raw_role) for raw_role in raw_roles]

        return user

    @computed_field
    @property
    def roles_as_set(self) -> set[str]:
        """Names of the user's roles as a set, for membership checks.

        A plain property rather than a cached one, so the result always
        reflects the current ``roles``. A cached value would go stale once
        ``roles`` is reassigned, mutated, or changed via ``model_copy``.
        """
        return {role.name for role in self.roles}
|
|
|
|
|
|
class OAuth2Token(SQLModel):
    """Token set issued by the OAuth2/OIDC provider for one authenticated user.

    The constraints in the trailing comments (max lengths, positive integer,
    foreign key) are notes only; none of them is enforced by this model.
    """

    # Primary key; ``from_dict`` fills it with the access token string.
    name: str = Field(primary_key=True)
    token_type: str  # = Field(max_length=40)
    access_token: str  # = Field(max_length=2000)
    # The id_token value as received in the provider's token dict.
    raw_id_token: str
    refresh_token: str  # = Field(max_length=200)
    expires_at: int  # = PositiveIntegerField()
    # The owning user is embedded directly (notionally ForeignKey(User)).
    user: User

    def to_token(self):
        """Return the token as a plain dict.

        Keys: ``access_token``, ``token_type``, ``refresh_token`` and
        ``expires_at``. The id token, the ``name`` key and the user are not
        included.
        """
        return {
            "access_token": self.access_token,
            "token_type": self.token_type,
            "refresh_token": self.refresh_token,
            "expires_at": self.expires_at,
        }

    @classmethod
    def from_dict(cls, token_dict: dict, user: User) -> Self:
        """Create a token record from a provider token dict owned by ``user``.

        The access token is used both as the primary key ``name`` and as
        ``access_token``.

        Raises:
            KeyError: if any of ``access_token``, ``id_token``,
                ``token_type``, ``refresh_token`` or ``expires_at`` is missing.
        """
        access = token_dict["access_token"]
        return cls(
            name=access,
            access_token=access,
            raw_id_token=token_dict["id_token"],
            token_type=token_dict["token_type"],
            refresh_token=token_dict["refresh_token"],
            expires_at=token_dict["expires_at"],
            user=user,
        )
|