from collections.abc import Sequence
from datetime import datetime
from typing import Annotated, Any, Literal
import uuid

from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import joinedload, QueryableAttribute
from geoalchemy2 import Geometry, WKBElement  # type: ignore
from sqlmodel import (SQLModel, Field, String, Relationship, JSON, select)
import pandas as pd
import geopandas as gpd  # type: ignore

from treetrail.utils import pandas_query, geopandas_query
from treetrail.config import Map, conf, App
from treetrail.database import db_session


class BaseModel(SQLModel):
    """Common base for the TreeTrail models, adding DataFrame loaders."""

    @classmethod
    def selectinload(cls) -> list[Literal['*'] | QueryableAttribute[Any]]:
        """Return the relationships to join when loading a DataFrame.

        Despite its name, `_get_df` loads these with `joinedload`.
        The base implementation returns an empty list; subclasses may
        override it.
        """
        return []

    @classmethod
    async def get_df(cls, **kwargs) -> pd.DataFrame:
        """Load rows of this model into a pandas DataFrame.

        All keyword arguments are forwarded to `_get_df`.
        """
        return await cls._get_df(pandas_query, **kwargs)

    @classmethod
    async def get_gdf(cls, **kwargs) -> gpd.GeoDataFrame:
        """Load rows of this model into a GeoDataFrame.

        All keyword arguments are forwarded to `_get_df`; `model=cls` is
        additionally passed through to `geopandas_query`.
        """
        return await cls._get_df(geopandas_query, model=cls, **kwargs)  # type: ignore

    @classmethod
    async def _get_df(cls, method, *,
                      where=None,
                      with_related=True,
                      with_only_columns: Sequence[str] = (),
                      simplify_tolerance: float | None = None,
                      preserve_topology: bool | None = None,
                      **kwargs) -> pd.DataFrame | gpd.GeoDataFrame:
        """Run a SELECT on this model and return the result as a DataFrame.

        Args:
            method: callable executed as
                `session.run_sync(method, query, **kwargs)`, e.g.
                `pandas_query` or `geopandas_query`.
            where: optional SQL expression applied with `query.where`.
            with_related: when true (and no column subset is requested),
                join the relationships returned by `selectinload()` with
                `joinedload` and rename their columns to
                `<schema>_<table>_<column>`.
            with_only_columns: if non-empty, select only these attributes
                plus the primary key column(s); relationships are then not
                joined.
            simplify_tolerance: only used with `geopandas_query`: simplify
                the `geom` column with this tolerance divided by
                `conf.geo.simplify_geom_factor`.
            preserve_topology: passed to `simplify`; when None,
                `conf.geo.simplify_preserve_topology` is used.
            **kwargs: forwarded to `method`.

        Returns:
            The DataFrame, indexed by the model's primary key column(s).
        """
        joined_tables = cls.selectinload()
        pk_names = [col.name
                    for col in cls.__table__.primary_key.columns]  # type: ignore
        async with db_session() as session:
            if not with_only_columns:
                query = select(cls)
                # Columns that come straight from the model's fields
                n_base_columns = len(cls.model_fields)
                use_joins = bool(with_related and len(joined_tables) > 0)
            else:
                # Requested columns first, then any missing primary key
                # column: de-duplicated, in a deterministic order.
                # (set.add() only accepts one argument, which broke
                # composite primary keys.)
                # TODO: use SQLModel model_fields instead of __table__
                columns = list(dict.fromkeys([*with_only_columns, *pk_names]))
                query = select(*(getattr(cls, col) for col in columns))
                n_base_columns = len(columns)
                # Loader options like joinedload target entity selects,
                # not column-only selects: skip the joined tables here.
                use_joins = False
            if where is not None:
                query = query.where(where)
            ## Get the joined tables
            if use_joins:
                query = query.options(
                    *(joinedload(jt) for jt in joined_tables))
            df = await session.run_sync(method, query, **kwargs)

        if method is geopandas_query and simplify_tolerance is not None:
            df['geom'] = df['geom'].simplify(
                simplify_tolerance / conf.geo.simplify_geom_factor,
                preserve_topology=(conf.geo.simplify_preserve_topology
                                   if preserve_topology is None
                                   else preserve_topology)
            )

        ## Change column names to reflect the joined tables.
        ## Leave the first columns unchanged, as their names come straight
        ## from the model's fields (or from the requested column subset).
        if use_joins:
            ## Match column names with the joined tables.
            ## Important: this assumes that the order of the joined tables
            ## and of their columns is preserved by pandas' read_sql.
            joined_columns = iter(df.columns[n_base_columns:])
            renames: dict[str, str] = {}
            for joined_table in joined_tables:
                target = joined_table.property.target  # type: ignore
                target_name = target.name
                for col in target.columns:
                    ## Take the next joined column and give it a new name
                    renames[next(joined_columns)] = \
                        f'{target.schema}_{target_name}_{col.name}'
            df.rename(columns=renames, inplace=True)

        ## Finally, set the index of the df as the primary key of cls
        df.set_index(pk_names, inplace=True)
        return df


class TreeTrail(BaseModel, table=True):
    """Link table between `tree` and `trail` (used as `link_model`)."""
    __tablename__: str = 'tree_trail'  # type: ignore

    tree_id: uuid.UUID | None = Field(
        default=None,
        foreign_key='tree.id',
        primary_key=True
    )
    trail_id: int | None = Field(
        default=None,
        foreign_key='trail.id',
        primary_key=True
    )


class Trail(BaseModel, table=True):
    """A trail: a line geometry linked to a set of trees."""
    __tablename__: str = "trail"  # type: ignore

    id: int = Field(primary_key=True)
    name: str
    description: str
    create_date: datetime = Field(default_factory=datetime.now)
    geom: Annotated[str, WKBElement] = Field(
        sa_type=Geometry('LINESTRING', srid=4326, dimension=2),
    )
    photo: str = Field(sa_type=String(250))  # type: ignore
    trees: list['Tree'] = Relationship(
        link_model=TreeTrail,
        back_populates="trails")
    # Role name allowed to view this trail (FK to role.name)
    viewable_role_id: str | None = Field(foreign_key='role.name', index=True)
    viewable_role: 'Role' = Relationship(back_populates='viewable_trails')

    # __mapper_args__ = {"eager_defaults": True}


# Life stage codes; their meaning is not defined in this module.
life_stages = ('Y', 'MA', 'M', 'OM', 'A')


class Tree(BaseModel, table=True):
    """A tree: a point geometry with a plant, contributor and free-form data."""
    __tablename__: str = "tree"  # type: ignore

    id: uuid.UUID | None = Field(
        default_factory=uuid.uuid1,
        primary_key=True,
        index=True,
        nullable=False,
    )
    create_date: datetime = Field(default_factory=datetime.now)
    # ALTER TABLE tree ADD CONSTRAINT tree_plant_id_fkey FOREIGN KEY (plantekey_id) REFERENCES plant(id); # noqa: E501
    plantekey_id: str = Field(foreign_key='plant.id')
    geom: Annotated[str, WKBElement] = Field(
        sa_type=Geometry('POINT', srid=4326, dimension=2))
    photo: str | None = Field(sa_type=String(250))  # type: ignore
    height: float | None
    comments: str | None
    # ALTER TABLE public.tree ADD contributor_id varchar(50) NULL;
    # ALTER TABLE public.tree ADD CONSTRAINT contributor_fk FOREIGN KEY (contributor_id) REFERENCES public."user"(username);  # noqa: E501
    contributor_id: str = Field(foreign_key='user.username', index=True)
    contributor: 'User' = Relationship()
    # Role name allowed to view this tree (FK to role.name)
    viewable_role_id: str | None = Field(foreign_key='role.name', index=True)
    viewable_role: 'Role' = Relationship(back_populates='viewable_trees')

    # CREATE EXTENSION hstore;
    # ALTER TABLE tree ADD COLUMN data JSONB;
    # MutableDict tracks in-place changes to the JSONB dict
    data: dict = Field(sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
                       default_factory=dict)  # type: ignore
    trails: list[Trail] = Relationship(
        link_model=TreeTrail,
        back_populates="trees")

    __mapper_args__ = {"eager_defaults": True}

    @classmethod
    def get_tree_insert_params(
        cls,
        plantekey_id: str,
        lng,
        lat,
        username,
        details: dict,
    ) -> dict:
        """Build the column values for inserting a new tree.

        Args:
            plantekey_id: id of the plant (FK to plant.id).
            lng: longitude, written first (x) in the WKT point.
            lat: latitude, written second (y) in the WKT point.
            username: stored as `contributor_id`.
            details: extra attributes. Truthy `comments` and `height` go
                into their own columns; every other truthy value goes
                into `data`. Falsy values are dropped. The dict passed
                in is not modified.

        Returns:
            A dict of column name -> value.
        """
        # Work on a copy so the caller's dict is left untouched
        details = dict(details)
        params = {
            'plantekey_id': plantekey_id,
            'geom': f'POINT({lng} {lat})',
            'contributor_id': username
        }
        ## Consume some details in their respective field...
        if p := details.pop('comments', None):
            params['comments'] = p
        if p := details.pop('height', None):
            params['height'] = p
        # ... and store the rest in data
        params['data'] = {k: v for k, v in details.items() if v}
        return params


class UserRoleLink(SQLModel, table=True):
    """Link table between `user` and `role` (used as `link_model`)."""
    __tablename__: str = 'roles_users'  # type: ignore

    user_id: str | None = Field(
        default=None,
        foreign_key='user.username',
        primary_key=True
    )
    role_id: str | None = Field(
        default=None,
        foreign_key='role.name',
        primary_key=True
    )


class UserBase(BaseModel):
    """Fields shared by `User` and `UserWithRoles` (no table)."""
    username: str = Field(sa_type=String(50), primary_key=True)  # type: ignore
    full_name: str | None = None
    email: str | None = None


class User(UserBase, table=True):
    """A user account, linked to roles through `roles_users`."""
    __tablename__: str = "user"  # type: ignore

    roles: list["Role"] = Relationship(back_populates="users",
                                       link_model=UserRoleLink)
    password: str
    disabled: bool = False


class UserWithRoles(UserBase):
    """User fields plus the list of roles (no table, no password)."""
    roles: list['Role']


class Role(BaseModel, table=True):
    """A role, identified by its name, with its users and viewable items."""
    __tablename__: str = "role"  # type: ignore

    name: str = Field(sa_type=String(50), primary_key=True)  # type: ignore
    users: list[User] = Relationship(back_populates="roles",
                                     link_model=UserRoleLink)
    viewable_trees: list[Tree] = Relationship(back_populates='viewable_role')
    viewable_zones: list['Zone'] = Relationship(back_populates='viewable_role')
    viewable_trails: list[Trail] = Relationship(back_populates='viewable_role')


class POI(BaseModel, table=True):
    """A point of interest with a 3D point geometry."""
    __tablename__: str = "poi"  # type: ignore

    id: int = Field(primary_key=True)
    # sa_type (not sa_column): SQLModel ignores an `sa_column` that is not
    # a Column, which silently dropped these length limits.
    name: str = Field(sa_type=String(200))  # type: ignore
    description: str | None = None
    create_date: datetime = Field(default_factory=datetime.now)
    geom: Annotated[str, WKBElement] = Field(
        sa_type=Geometry('POINTZ', srid=4326, dimension=3))
    photo: str = Field(sa_type=String(250))  # type: ignore
    type: str = Field(sa_type=String(25))  # type: ignore
    data: dict = Field(sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
                       default_factory=dict)  # type: ignore


class Zone(BaseModel, table=True):
    """A zone with a multipolygon geometry."""
    __tablename__: str = "zone"  # type: ignore

    id: int = Field(primary_key=True)
    name: str = Field(sa_type=String(200))  # type:ignore
    description: str
    create_date: datetime = Field(default_factory=datetime.now)
    geom: Annotated[str, WKBElement] = Field(
        sa_type=Geometry('MULTIPOLYGON', srid=4326))
    photo: str | None = Field(sa_type=String(250))  # type:ignore
    type: str = Field(sa_type=String(30))  # type:ignore
    data: dict | None = Field(sa_type=MutableDict.as_mutable(JSONB),  # type:ignore
                              default_factory=dict)  # type:ignore
    # Role name allowed to view this zone (FK to role.name)
    viewable_role_id: str | None = Field(foreign_key='role.name', index=True)
    viewable_role: 'Role' = Relationship(back_populates='viewable_zones')


class MapStyle(BaseModel, table=True):
    """Per-layer map style: `paint` and `layout` stored as JSON."""
    __tablename__: str = "map_style"  # type: ignore

    id: int = Field(primary_key=True)
    layer: str = Field(sa_type=String(100), nullable=False)  # type:ignore
    # none_as_null: a Python None is stored as SQL NULL, not JSON 'null'
    paint: dict[str, Any] | None = Field(sa_type=JSON(none_as_null=True))  # type:ignore
    layout: dict[str, Any] | None = Field(sa_type=JSON(none_as_null=True))  # type:ignore


class VersionedComponent(BaseModel):
    """A component identified by its version string."""
    version: str


class BaseMapStyles(BaseModel):
    """Names of embedded base map styles and a name -> value map of external ones."""
    embedded: list[str]
    external: dict[str, str]


class Bootstrap(BaseModel):
    """Aggregate of versions, app/map configuration, user and base map styles.

    Presumably the initial payload sent to the client — confirm with callers.
    """
    client: VersionedComponent
    server: VersionedComponent
    app: App
    user: UserWithRoles | None  # type:ignore
    map: Map
    baseMapStyles: BaseMapStyles