treetrail-backend/src/treetrail/models.py
2024-10-23 16:19:51 +02:00

271 lines
No EOL
10 KiB
Python

from typing import Annotated, Any, Literal
from datetime import datetime
import uuid
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import joinedload, QueryableAttribute
from geoalchemy2 import Geometry, WKBElement # type: ignore
from sqlmodel import (SQLModel, Field, String, Relationship, JSON,
select)
import pandas as pd
import geopandas as gpd # type: ignore
from treetrail.utils import pandas_query, geopandas_query
from treetrail.config import Map, conf, App
from treetrail.database import db_session
class BaseModel(SQLModel):
    """Base class for the models, with helpers that load rows into DataFrames.

    ``get_df`` returns a pandas DataFrame and ``get_gdf`` a geopandas
    GeoDataFrame. Both are indexed by the model's primary key column(s) and
    can include the columns of the related tables listed by ``selectinload``.
    """

    @classmethod
    def selectinload(cls) -> list[Literal['*'] | QueryableAttribute[Any]]:
        """Return the relationships whose tables ``_get_df`` joins in.

        Despite the name, ``_get_df`` loads the returned attributes with
        ``joinedload``. The base implementation returns no relationship.
        """
        return []

    @classmethod
    async def get_df(cls, **kwargs) -> pd.DataFrame:
        """Load the model's rows into a pandas DataFrame.

        Keyword arguments are those accepted by ``_get_df``.
        """
        return await cls._get_df(pandas_query, **kwargs)

    @classmethod
    async def get_gdf(cls, **kwargs) -> gpd.GeoDataFrame:
        """Load the model's rows into a geopandas GeoDataFrame.

        Keyword arguments are those accepted by ``_get_df``; ``model=cls`` is
        also forwarded to ``geopandas_query``.
        """
        return await cls._get_df(geopandas_query, model=cls, **kwargs)  # type: ignore

    @classmethod
    async def _get_df(cls, method, *,
                      where=None, with_related=True, with_only_columns=None,
                      simplify_tolerance: float | None = None,
                      preserve_topology: bool | None = None,
                      **kwargs) -> pd.DataFrame | gpd.GeoDataFrame:
        """Run a SELECT on the model and return the result as a DataFrame.

        :param method: ``pandas_query`` or ``geopandas_query``; called as
            ``session.run_sync(method, query, **kwargs)``.
        :param where: optional SQLAlchemy expression added as a WHERE clause.
        :param with_related: when true, join the tables of the relationships
            returned by ``selectinload``.
        :param with_only_columns: names of model attributes to select instead
            of the whole model. The primary key columns are always added.
            ``None`` or an empty sequence selects every column.
        :param simplify_tolerance: with ``geopandas_query`` only, simplify the
            ``geom`` column using this tolerance divided by
            ``conf.geo.simplify_geom_factor``.
        :param preserve_topology: passed to ``simplify``. When ``None``,
            ``conf.geo.simplify_preserve_topology`` is used.
        :param kwargs: forwarded to ``method``.
        :return: the rows, indexed by the primary key column(s). When joins are
            applied, joined columns are renamed ``{schema}_{table}_{column}``.
        """
        pk_names = [col.name for col in cls.__table__.primary_key.columns]  # type: ignore
        requested = [] if with_only_columns is None else list(with_only_columns)
        joined_tables = cls.selectinload()
        # Joined columns only exist in the result when the joins are applied,
        # so the renaming below must be skipped otherwise.
        join_related = with_related and len(joined_tables) > 0

        if len(requested) == 0:
            query = select(cls)
            # The model's own columns come first, named after its fields.
            base_column_count = len(cls.model_fields)
        else:
            # TODO: use SQLModel model_fields instead of __table__
            # dict.fromkeys de-duplicates while keeping a stable order, and
            # (unlike set.add) accepts composite primary keys.
            selected = list(dict.fromkeys([*requested, *pk_names]))
            query = select(*(getattr(cls, col) for col in selected))
            base_column_count = len(selected)
        if where is not None:
            query = query.where(where)
        if join_related:
            query = query.options(*(joinedload(jt) for jt in joined_tables))

        async with db_session() as session:
            df = await session.run_sync(method, query, **kwargs)

        if method is geopandas_query and simplify_tolerance is not None:
            df['geom'] = df['geom'].simplify(
                simplify_tolerance / conf.geo.simplify_geom_factor,
                preserve_topology=(conf.geo.simplify_preserve_topology
                                   if preserve_topology is None
                                   else preserve_topology),
            )
        if join_related:
            df.rename(
                columns=cls._joined_column_renames(
                    df.columns[base_column_count:], joined_tables),
                inplace=True,
            )
        # Finally, index the rows by the model's primary key
        df.set_index(pk_names, inplace=True)
        return df

    @classmethod
    def _joined_column_renames(cls, columns, joined_tables) -> dict[str, str]:
        """Map joined-table result columns to ``{schema}_{table}_{column}``.

        ``columns`` are the result columns that follow the model's own ones.
        This assumes the result keeps the order of ``joined_tables`` and of
        each target table's columns (as produced by pandas' read_sql).

        :raises ValueError: if there are fewer columns than the joined
            tables declare.
        """
        new_names: list[str] = []
        for joined_table in joined_tables:
            target = joined_table.property.target  # type: ignore
            new_names.extend(f'{target.schema}_{target.name}_{col.name}'
                             for col in target.columns)
        columns = list(columns)
        if len(columns) < len(new_names):
            raise ValueError(
                f'{cls.__name__}: expected at least {len(new_names)} joined '
                f'columns, got {len(columns)}')
        # Extra trailing columns, if any, keep their names.
        return dict(zip(columns, new_names))
class TreeTrail(BaseModel, table=True):
    """Link table between trees and trails (many-to-many).

    Used as ``link_model`` by ``Trail.trees`` and ``Tree.trails``. Both
    columns are part of the composite primary key.
    """
    __tablename__: str = 'tree_trail' # type: ignore
    # References tree.id
    tree_id: uuid.UUID | None = Field(
        default=None,
        foreign_key='tree.id',
        primary_key=True
    )
    # References trail.id
    trail_id: int | None = Field(
        default=None,
        foreign_key='trail.id',
        primary_key=True
    )
class Trail(BaseModel, table=True):
    """A trail stored as a 2D LINESTRING geometry (SRID 4326).

    Trails are linked to trees through the ``TreeTrail`` link table and may
    reference a ``Role`` through ``viewable_role_id`` (presumably the role
    allowed to see the trail — confirm against the access-control code).
    Field order matters: it defines the column order used by ``get_df``.
    """
    __tablename__: str = "trail" # type: ignore
    id: int = Field(primary_key=True)
    name: str
    description: str
    # Set to the local time at object creation when not provided
    create_date: datetime = Field(default_factory=datetime.now)
    geom: Annotated[str, WKBElement] = Field(
        sa_type=Geometry('LINESTRING', srid=4326, dimension=2),
    )
    photo: str = Field(sa_type=String(250)) # type: ignore
    # Many-to-many with Tree through tree_trail
    trees: list['Tree'] = Relationship(
        link_model=TreeTrail,
        back_populates="trails")
    viewable_role_id: str | None = Field(foreign_key='role.name', index=True)
    viewable_role: 'Role' = Relationship(back_populates='viewable_trails')
    # NOTE(review): disabled here, whereas Tree sets eager_defaults — confirm
    # whether this should be enabled or removed.
    # __mapper_args__ = {"eager_defaults": True}
# Tree life-stage codes. The abbreviations are not expanded anywhere in
# this module — TODO confirm their meaning and whether the order matters.
life_stages = (
    'Y',
    'MA',
    'M',
    'OM',
    'A',
)
class Tree(BaseModel, table=True):
    """A tree stored as a 2D POINT geometry (SRID 4326).

    A tree references a plant (``plantekey_id``), the user who contributed
    it, an optional ``Role`` (``viewable_role_id``) and the trails it belongs
    to through ``TreeTrail``. Attributes without a dedicated column are kept
    in the JSONB ``data`` column.
    Field order matters: it defines the column order used by ``get_df``.
    """
    __tablename__: str = "tree"  # type: ignore
    id: uuid.UUID | None = Field(
        default_factory=uuid.uuid1,
        primary_key=True,
        index=True,
        nullable=False,
    )
    create_date: datetime = Field(default_factory=datetime.now)
    # ALTER TABLE tree ADD CONSTRAINT tree_plant_id_fkey FOREIGN KEY (plantekey_id) REFERENCES plant(id); # noqa: E501
    plantekey_id: str = Field(foreign_key='plant.id')
    geom: Annotated[str, WKBElement] = Field(
        sa_type=Geometry('POINT', srid=4326, dimension=2))
    photo: str | None = Field(sa_type=String(250))  # type: ignore
    height: float | None
    comments: str | None
    # ALTER TABLE public.tree ADD contributor_id varchar(50) NULL;
    # ALTER TABLE public.tree ADD CONSTRAINT contributor_fk FOREIGN KEY (contributor_id) REFERENCES public."user"(username);
    # NOTE(review): the migration above makes contributor_id NULLable, but the
    # annotation is a plain ``str`` — confirm which one is intended.
    contributor_id: str = Field(foreign_key='user.username', index=True)
    contributor: 'User' = Relationship()
    viewable_role_id: str | None = Field(foreign_key='role.name', index=True)
    viewable_role: 'Role' = Relationship(back_populates='viewable_trees')
    # CREATE EXTENSION hstore;
    # ALTER TABLE tree ADD COLUMN data JSONB;
    # NOTE(review): a JSONB column does not need the hstore extension; the
    # first line above may be stale.
    # MutableDict tracks top-level key changes so they get flushed.
    data: dict = Field(sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
                       default_factory=dict)  # type: ignore
    # Many-to-many with Trail through tree_trail
    trails: list[Trail] = Relationship(
        link_model=TreeTrail,
        back_populates="trees")
    # Fetch server-generated defaults right after INSERT/UPDATE
    __mapper_args__ = {"eager_defaults": True}

    @classmethod
    def get_tree_insert_params(
        cls,
        plantekey_id: str,
        lng, lat,
        username,
        details: dict,
    ) -> dict:
        """Build the column values used to insert a new tree.

        :param plantekey_id: id of the plant (``plant.id``).
        :param lng: longitude, written first in the WKT point.
        :param lat: latitude, written second in the WKT point.
        :param username: stored as ``contributor_id``.
        :param details: extra attributes. ``comments`` and ``height`` go to
            their own columns when truthy; the other truthy values go to
            ``data``. Falsy values are dropped. The dict is not modified.
        :return: a dict of column name to value, with ``geom`` given as
            ``'POINT(lng lat)'`` WKT.
        """
        # Work on a copy: popping from ``details`` directly would silently
        # remove keys from the caller's dict.
        remaining = dict(details)
        params = {
            'plantekey_id': plantekey_id,
            'geom': f'POINT({lng} {lat})',
            'contributor_id': username,
        }
        # Move some details to their respective columns...
        if comments := remaining.pop('comments', None):
            params['comments'] = comments
        if height := remaining.pop('height', None):
            params['height'] = height
        # ... and store the rest in data
        params['data'] = {k: v for k, v in remaining.items() if v}
        return params
class UserRoleLink(SQLModel, table=True):
    """Link table between users and roles (many-to-many).

    Used as ``link_model`` by ``User.roles`` and ``Role.users``. It derives
    from ``SQLModel`` directly, so it does not get ``BaseModel``'s DataFrame
    helpers. Both columns are part of the composite primary key.
    """
    __tablename__: str = 'roles_users' # type: ignore
    # References user.username
    user_id: str | None = Field(
        default=None,
        foreign_key='user.username',
        primary_key=True
    )
    # References role.name
    role_id: str | None = Field(
        default=None,
        foreign_key='role.name',
        primary_key=True
    )
class UserBase(BaseModel):
    """Fields shared by the ``User`` table and the ``UserWithRoles`` model."""
    # Primary key on the table subclass, limited to 50 characters
    username: str = Field(sa_type=String(50), primary_key=True) # type: ignore
    full_name: str | None = None
    email: str | None = None
class User(UserBase, table=True):
    """User table, adding the role links and account fields to ``UserBase``."""
    __tablename__: str = "user" # type: ignore
    # Many-to-many with Role through roles_users
    roles: list["Role"] = Relationship(back_populates="users",
        link_model=UserRoleLink)
    # NOTE(review): presumably a password hash, not the clear password —
    # confirm against the authentication code.
    password: str
    disabled: bool = False
class UserWithRoles(UserBase):
    """Non-table user model carrying its roles (used by ``Bootstrap.user``)."""
    roles: list['Role']
class Role(BaseModel, table=True):
    """Role table, keyed by its name.

    Linked to users through ``roles_users``, and referenced by the
    ``viewable_role_id`` column of trees, zones and trails.
    """
    __tablename__: str = "role" # type: ignore
    name: str = Field(sa_type=String(50), primary_key=True) # type: ignore
    users: list[User] = Relationship(back_populates="roles",
        link_model=UserRoleLink)
    # Reverse sides of Tree/Zone/Trail.viewable_role
    viewable_trees: list[Tree] = Relationship(back_populates='viewable_role')
    viewable_zones: list['Zone'] = Relationship(back_populates='viewable_role')
    viewable_trails: list[Trail] = Relationship(back_populates='viewable_role')
class POI(BaseModel, table=True):
    """A point of interest stored as a 3D POINTZ geometry (SRID 4326).

    Besides the geometry, a POI has a name, an optional description, a
    photo, a type and free-form JSONB ``data``.
    Field order matters: it defines the column order used by ``get_df``.
    """
    __tablename__: str = "poi"  # type: ignore
    id: int = Field(primary_key=True)
    # Column types are given with ``sa_type`` (as in the other models):
    # ``sa_column`` expects a full ``Column``, so a bare ``String(n)`` passed
    # there was ignored and the length limits were lost.
    name: str = Field(sa_type=String(200))  # type: ignore
    description: str | None = None
    create_date: datetime = Field(default_factory=datetime.now)
    geom: Annotated[str, WKBElement] = Field(
        sa_type=Geometry('POINTZ', srid=4326, dimension=3))
    photo: str = Field(sa_type=String(250))  # type: ignore
    type: str = Field(sa_type=String(25))  # type: ignore
    # MutableDict tracks top-level key changes so they get flushed.
    data: dict = Field(sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
                       default_factory=dict)  # type: ignore
class Zone(BaseModel, table=True):
    """A zone stored as a MULTIPOLYGON geometry (SRID 4326).

    A zone has a name, a description, a type, an optional photo, free-form
    JSONB ``data``, and may reference a ``Role`` through
    ``viewable_role_id`` (presumably the role allowed to see it — confirm).
    Field order matters: it defines the column order used by ``get_df``.
    """
    __tablename__: str = "zone" # type: ignore
    id: int = Field(primary_key=True)
    name: str = Field(sa_type=String(200)) # type:ignore
    description: str
    create_date: datetime = Field(default_factory=datetime.now)
    geom: Annotated[str, WKBElement] = Field(
        sa_type=Geometry('MULTIPOLYGON', srid=4326))
    photo: str | None = Field(sa_type=String(250)) # type:ignore
    type: str = Field(sa_type=String(30)) # type:ignore
    # MutableDict tracks top-level key changes so they get flushed.
    data: dict | None = Field(sa_type=MutableDict.as_mutable(JSONB), # type:ignore
        default_factory=dict) # type:ignore
    viewable_role_id: str | None = Field(foreign_key='role.name', index=True)
    viewable_role: 'Role' = Relationship(back_populates='viewable_zones')
class MapStyle(BaseModel, table=True):
    """Style definition attached to a map layer name.

    ``paint`` and ``layout`` are free-form JSON documents (presumably the
    layer's paint/layout properties for the map renderer — confirm).
    """
    __tablename__: str = "map_style" # type: ignore
    id: int = Field(primary_key=True)
    layer: str = Field(sa_type=String(100), nullable=False) # type:ignore
    # none_as_null=True: a Python None is stored as SQL NULL, not JSON null
    paint: dict[str, Any] | None = Field(sa_type=JSON(none_as_null=True)) # type:ignore
    layout: dict[str, Any] | None = Field(sa_type=JSON(none_as_null=True)) # type:ignore
class VersionedComponent(BaseModel):
    """A component described by its version string (see ``Bootstrap``)."""
    version: str
class BaseMapStyles(BaseModel):
    """Available base map styles, split into embedded and external ones."""
    # Names of the embedded styles
    embedded: list[str]
    # Presumably style name -> style URL — confirm against the producer
    external: dict[str, str]
class Bootstrap(BaseModel):
    """Aggregate of versions, configuration, user and base map styles.

    Presumably the payload sent to the client at start-up — confirm against
    the API route that builds it.
    """
    client: VersionedComponent
    server: VersionedComponent
    app: App
    # None presumably when no user is logged in — confirm
    user: UserWithRoles | None # type:ignore
    map: Map
    baseMapStyles: BaseMapStyles