Allow sqlmodel queries, with relations. Remove join_with mechanisms coming from gino. Handle with_only_columns in get_df and get_gdf. Implement feature-info.
94 lines
No EOL
3.7 KiB
Python
94 lines
No EOL
3.7 KiB
Python
from contextlib import asynccontextmanager
|
|
from typing import Annotated, Literal, Any
|
|
from collections.abc import AsyncGenerator
|
|
|
|
from sqlalchemy.ext.asyncio import create_async_engine
|
|
from sqlalchemy.orm import joinedload, QueryableAttribute, InstrumentedAttribute
|
|
from sqlmodel import SQLModel, select
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
from fastapi import Depends
|
|
import pandas as pd
|
|
import geopandas as gpd
|
|
|
|
from gisaf.config import conf
|
|
|
|
# Async SQLAlchemy engine used by get_db_session and db_session.
# The database URL, SQL echo flag and connection-pool sizing all come
# from the ``db`` section of the gisaf configuration.
engine = create_async_engine(
    conf.db.get_sqla_url(),
    echo=conf.db.echo,
    pool_size=conf.db.pool_size,
    max_overflow=conf.db.max_overflow,
)
|
|
|
|
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to the module-level ``engine``.

    Written as an async generator so it can serve as a FastAPI dependency
    (see ``fastapi_db_session``). The session is closed when the
    ``async with`` block exits, i.e. after the consumer is done with it.

    The explicit ``None`` send type keeps the annotation valid for type
    checkers whose stubs predate the Python 3.13 default for it.
    """
    async with AsyncSession(engine) as session:
        yield session
|
|
|
|
@asynccontextmanager
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Async context manager providing an ``AsyncSession`` on ``engine``.

    Usage: ``async with db_session() as session: ...``. The session is
    closed when the ``async with`` block exits.

    The generator is annotated with an explicit ``None`` send type so the
    annotation is accepted by type checkers on all supported versions.
    """
    async with AsyncSession(engine) as session:
        yield session
|
|
|
|
def pandas_query(session, query, **kwargs):
    """Run ``query`` on ``session``'s connection and return a DataFrame.

    This is the synchronous callable that ``BaseModel._get_df`` hands to
    ``session.run_sync``, so ``session`` is the synchronous session it
    receives there.

    :param session: object whose ``connection()`` returns a connection
        accepted by :func:`pandas.read_sql_query`.
    :param query: SQL string or SQLAlchemy selectable.
    :param kwargs: extra keyword arguments forwarded to
        :func:`pandas.read_sql_query` (e.g. ``params``, ``index_col``,
        ``parse_dates``). Previously any such argument raised TypeError.
    :returns: the query result as a :class:`pandas.DataFrame`.
    """
    return pd.read_sql_query(query, session.connection(), **kwargs)
|
|
|
|
def geopandas_query(session, query, *, crs=None, cast=True, **kwargs):
    """Run ``query`` on ``session``'s connection and return a GeoDataFrame.

    The frame is built with :meth:`geopandas.GeoDataFrame.from_postgis`.

    :param session: object whose ``connection()`` returns a connection
        usable by ``from_postgis``.
    :param query: SQL string or SQLAlchemy selectable.
    :param crs: coordinate reference system passed to ``from_postgis``.
    :param cast: accepted for API compatibility; currently unused.
    :param kwargs: extra keyword arguments forwarded to ``from_postgis``
        (e.g. ``geom_col``, ``params``), mirroring ``pandas_query``.
    :returns: the query result as a :class:`geopandas.GeoDataFrame`.
    """
    return gpd.GeoDataFrame.from_postgis(query, session.connection(),
                                         crs=crs, **kwargs)
|
|
|
|
class BaseModel(SQLModel):
    """Base class for gisaf SQLModel models.

    Adds class-level helpers that load the model's rows into pandas
    (:meth:`get_df`) or geopandas (:meth:`get_gdf`) DataFrames, optionally
    joining related tables declared by :meth:`selectinload`.
    """

    @classmethod
    def selectinload(cls) -> list[Literal['*'] | QueryableAttribute[Any]]:
        """Return the relationship attributes to join when building frames.

        The base implementation returns an empty list (no joins); subclasses
        override it. Despite the name, :meth:`_get_df` applies the returned
        attributes with ``joinedload``.
        """
        return []

    @classmethod
    async def get_df(cls, *,
                     where=None, with_related=True, **kwargs
                     ) -> pd.DataFrame:
        """Load this model's rows into a pandas DataFrame.

        :param where: optional SQL expression used as a WHERE clause.
        :param with_related: when true, join the relationships listed by
            :meth:`selectinload`.
        :param kwargs: forwarded to :meth:`_get_df` (e.g.
            ``with_only_columns``); anything it does not consume is passed
            on to :func:`pandas_query`.
        """
        # Forward the caller's arguments; they used to be replaced by
        # hard-coded defaults, silently ignoring any filter.
        return await cls._get_df(pandas_query,
                                 where=where, with_related=with_related,
                                 **kwargs)

    @classmethod
    async def get_gdf(cls, *,
                      where=None, with_related=True, **kwargs
                      ) -> gpd.GeoDataFrame:
        """Load this model's rows into a geopandas GeoDataFrame.

        Same parameters as :meth:`get_df`; remaining keyword arguments are
        passed on to :func:`geopandas_query` (e.g. ``crs``).
        """
        return await cls._get_df(geopandas_query,
                                 where=where, with_related=with_related,
                                 **kwargs)

    @classmethod
    async def _get_df(cls, method, *,
                      where=None, with_related=True, with_only_columns=None,
                      **kwargs
                      ) -> pd.DataFrame | gpd.GeoDataFrame:
        """Select rows of this model and return them as a DataFrame.

        :param method: synchronous ``(session, query, **kwargs)`` callable
            that executes the query and builds the frame (``pandas_query``
            or ``geopandas_query``); it is run through ``session.run_sync``.
        :param where: optional SQL expression added as a WHERE clause.
        :param with_related: join the relationships listed by
            :meth:`selectinload` (only when selecting the whole model).
        :param with_only_columns: optional iterable of attribute names to
            select instead of the whole model. The primary key column(s) are
            always added so the index can be set.
        :param kwargs: passed through to ``method``.
        :returns: the frame, indexed by the model's primary key column(s);
            joined-table columns are renamed ``{schema}_{table}_{column}``.
        """
        pk_names = [col.name
                    for col in cls.__table__.primary_key.columns]  # type: ignore
        joined_tables = []
        if with_only_columns:
            # dict.fromkeys de-duplicates while keeping the caller's order
            # (a set made the column order arbitrary), and merging the whole
            # primary key also works for composite keys.
            columns = list(dict.fromkeys([*with_only_columns, *pk_names]))
            query = select(*(getattr(cls, col) for col in columns))
            # joinedload is an ORM loader option tied to a selected entity.
            # With a column-only select it has no entity to attach to, and
            # the renaming below (which assumes all model fields come first)
            # could not line up, so relations are not joined here.
        else:
            query = select(cls)
            if with_related:
                joined_tables = cls.selectinload()
        if where is not None:
            # Select is generative: where() returns a new statement
            # (append_whereclause is deprecated/removed in SQLAlchemy 2).
            query = query.where(where)
        if joined_tables:
            query = query.options(*(joinedload(jt) for jt in joined_tables))

        async with db_session() as session:
            df = await session.run_sync(method, query, **kwargs)

        # Only rename when joins were actually applied; otherwise there are
        # no joined columns to rename.
        if joined_tables:
            df.rename(columns=cls._joined_column_renames(df.columns,
                                                         joined_tables),
                      inplace=True)
        # Finally, index the frame by the model's primary key
        df.set_index(pk_names, inplace=True)
        return df

    @classmethod
    def _joined_column_renames(cls, columns, joined_tables) -> dict[str, str]:
        """Map raw joined-table column names to ``{schema}_{table}_{column}``.

        :param columns: the frame's columns, in query order.
        :param joined_tables: relationship attributes that were joined.
        :raises IndexError: if the frame has fewer joined columns than the
            joined tables declare.
        """
        # The first columns keep their names: they come straight from the
        # model's fields. Important: this assumes that the order of the
        # joined tables and of their columns is preserved by pandas' read_sql.
        joined_columns = list(columns[len(cls.model_fields):])
        targets = [jt.property.target for jt in joined_tables]  # type: ignore
        new_names = [f'{target.schema}_{target.name}_{col.name}'
                     for target in targets
                     for col in target.columns]
        if len(new_names) > len(joined_columns):
            raise IndexError(
                f'{cls.__name__}: expected {len(new_names)} joined columns, '
                f'got {len(joined_columns)}')
        return dict(zip(joined_columns, new_names))
|
|
|
|
|
|
# Annotated type for FastAPI endpoint parameters: declaring a parameter as
# ``session: fastapi_db_session`` makes FastAPI inject the AsyncSession
# yielded by get_db_session.
fastapi_db_session = Annotated[
    AsyncSession,
    Depends(get_db_session),
]