from functools import cached_property
from typing import Self

from authlib.integrations.starlette_client.apps import StarletteOAuth2App
from pydantic import AnyHttpUrl, ConfigDict, EmailStr, computed_field
from sqlmodel import Field, SQLModel


class Role(SQLModel, extra="ignore"):
    """A role, identified only by its name."""

    name: str


class UserBase(SQLModel, extra="ignore"):
    """Profile fields shared by user models.

    Only ``name`` is required; every other field has a default.
    """

    id: str | None = None
    sid: str | None = None
    name: str
    email: EmailStr | None = None
    picture: AnyHttpUrl | None = None
    roles: list[Role] = []


class User(UserBase):
    """A user together with the raw userinfo and the OIDC client it came from."""

    # Needed so the non-pydantic ``StarletteOAuth2App`` can be used as a field type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    sub: str = Field(
        description=(
            "subject id of the user given by the oidc provider, "
            "also the key for the database 'table'"
        ),
    )
    userinfo: dict = {}
    oidc_provider: StarletteOAuth2App | None = None

    @classmethod
    def from_auth(cls, userinfo: dict, oidc_provider: StarletteOAuth2App) -> Self:
        """Build a user from the userinfo mapping of an authentication.

        Args:
            userinfo: Mapping whose entries are passed as keyword arguments
                to the model constructor; it is also stored unchanged on
                the returned user.
            oidc_provider: OAuth2 client to attach to the returned user.

        Returns:
            The new user. ``roles`` is filled from
            ``userinfo["realm_access"]["roles"]`` when that entry exists
            and is non-empty.
        """
        user = cls(**userinfo)
        user.userinfo = userinfo
        user.oidc_provider = oidc_provider

        # Add roles if they are provided in the token. A missing or empty
        # "realm_access" entry is treated like an empty mapping.
        if raw_roles := (userinfo.get("realm_access") or {}).get("roles"):
            user.roles = [Role(name=raw_role) for raw_role in raw_roles]

        return user

    @computed_field
    @cached_property
    def roles_as_set(self) -> set[str]:
        """Names of the user's roles as a set.

        The value is computed on first access and then cached, so later
        changes to ``roles`` are not reflected.
        """
        return {role.name for role in self.roles}


class OAuth2Token(SQLModel):
    """An OAuth2 token and the user it belongs to."""

    name: str = Field(primary_key=True)
    token_type: str  # = Field(max_length=40)
    access_token: str  # = Field(max_length=2000)
    raw_id_token: str
    refresh_token: str  # = Field(max_length=200)
    expires_at: int  # = PositiveIntegerField()
    user: User  # = ForeignKey(User)

    def to_token(self) -> dict[str, str | int]:
        """Return the token as a plain dict.

        Only ``access_token``, ``token_type``, ``refresh_token`` and
        ``expires_at`` are included; the id token and the user are not.
        """
        return {
            "access_token": self.access_token,
            "token_type": self.token_type,
            "refresh_token": self.refresh_token,
            "expires_at": self.expires_at,
        }

    @classmethod
    def from_dict(cls, token_dict: dict, user: User) -> Self:
        """Create a token record from a token mapping.

        The access token doubles as the record's ``name`` (its primary key).

        Args:
            token_dict: Mapping that must contain ``access_token``,
                ``id_token``, ``token_type``, ``refresh_token`` and
                ``expires_at``.
            user: The user the token belongs to.

        Raises:
            KeyError: If any of the required keys is missing from
                ``token_dict``.
        """
        access_token = token_dict["access_token"]
        return cls(
            name=access_token,
            access_token=access_token,
            raw_id_token=token_dict["id_token"],
            token_type=token_dict["token_type"],
            refresh_token=token_dict["refresh_token"],
            expires_at=token_dict["expires_at"],
            user=user,
        )