from contextlib import asynccontextmanager
from typing import Annotated, Literal, Any
from collections.abc import AsyncGenerator

from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import joinedload, QueryableAttribute, InstrumentedAttribute
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi import Depends
import pandas as pd
import geopandas as gpd

from gisaf.config import conf

## Module-wide async engine, configured from the application settings
engine = create_async_engine(
    conf.db.get_sqla_url(),
    echo=conf.db.echo,
    pool_size=conf.db.pool_size,
    max_overflow=conf.db.max_overflow,
)


async def get_db_session() -> AsyncGenerator[AsyncSession]:
    """Yield an ``AsyncSession`` bound to ``engine``.

    Used as a FastAPI dependency (see ``fastapi_db_session``); the session is
    closed when the ``async with`` block exits.
    """
    async with AsyncSession(engine) as session:
        yield session


@asynccontextmanager
async def db_session() -> AsyncGenerator[AsyncSession]:
    """Async context manager providing an ``AsyncSession`` bound to ``engine``."""
    async with AsyncSession(engine) as session:
        yield session


def pandas_query(session, query, **kwargs):
    """Execute ``query`` on the session's connection and return a DataFrame.

    Intended to be called through ``AsyncSession.run_sync``, which passes a
    synchronous session as the first argument.

    Args:
        session: synchronous session (as passed by ``run_sync``).
        query: SQLAlchemy selectable or SQL string.
        **kwargs: forwarded to ``pandas.read_sql_query``.
    """
    return pd.read_sql_query(query, session.connection(), **kwargs)


def geopandas_query(session, query, *, crs=None, cast=True, **kwargs):
    """Execute ``query`` on the session's connection and return a GeoDataFrame.

    Intended to be called through ``AsyncSession.run_sync``.

    Args:
        session: synchronous session (as passed by ``run_sync``).
        query: SQLAlchemy selectable or SQL string.
        crs: coordinate reference system passed to ``from_postgis``.
        cast: accepted for backward compatibility; currently unused.
        **kwargs: forwarded to ``geopandas.GeoDataFrame.from_postgis``.
    """
    return gpd.GeoDataFrame.from_postgis(query, session.connection(), crs=crs, **kwargs)


class BaseModel(SQLModel):
    """SQLModel base with helpers to load a model's table into (Geo)DataFrames."""

    @classmethod
    def selectinload(cls) -> list[Literal['*'] | QueryableAttribute[Any]]:
        """Return the relationship attributes to eager-load in ``get_df``/``get_gdf``.

        Despite the name, ``_get_df`` loads them with ``joinedload``.
        The default loads nothing; subclasses override it.
        NOTE(review): the column-renaming step in ``_get_df`` reads
        ``.property.target`` on each entry, so a bare ``'*'`` would fail there.
        """
        return []

    @classmethod
    async def get_df(cls, where=None, with_related=True, **kwargs) -> pd.DataFrame:
        """Load the model's rows into a pandas DataFrame.

        Args:
            where: optional SQLAlchemy WHERE clause applied to the select.
            with_related: eager-load and include the ``selectinload()`` tables.
            **kwargs: forwarded to ``pandas.read_sql_query``.
        """
        return await cls._get_df(
            pandas_query, where=where, with_related=with_related, **kwargs
        )

    @classmethod
    async def get_gdf(cls, *, where=None, with_related=True, **kwargs) -> gpd.GeoDataFrame:
        """Load the model's rows into a GeoDataFrame.

        Args:
            where: optional SQLAlchemy WHERE clause applied to the select.
            with_related: eager-load and include the ``selectinload()`` tables.
            **kwargs: forwarded to ``geopandas_query`` (e.g. ``crs``).
        """
        return await cls._get_df(
            geopandas_query, where=where, with_related=with_related, **kwargs
        )

    @classmethod
    async def _get_df(
        cls, method, *, where=None, with_related=True, **kwargs
    ) -> pd.DataFrame | gpd.GeoDataFrame:
        """Run ``select(cls)`` through ``method`` and post-process the frame.

        Joined-table columns are renamed to ``<schema>_<table>_<column>``, and
        the frame is indexed by the model's primary key column(s).

        Args:
            method: ``pandas_query`` or ``geopandas_query``.
            where: optional WHERE clause.
            with_related: whether to join the ``selectinload()`` relationships.
            **kwargs: forwarded to ``method``.
        """
        query = select(cls)
        if where is not None:
            ## Select is generative: where() returns a new statement
            ## (append_whereclause is deprecated/removed in SQLAlchemy 2)
            query = query.where(where)
        ## Get the joined tables
        joined_tables = cls.selectinload()
        use_joins = with_related and len(joined_tables) > 0
        if use_joins:
            query = query.options(*(joinedload(jt) for jt in joined_tables))

        async with db_session() as session:
            df = await session.run_sync(method, query, **kwargs)

        ## Only rename when the joins were actually added to the query;
        ## otherwise there are no joined columns to pop and this would fail
        if use_joins:
            ## Change column names to reflect the joined tables.
            ## Leave the first columns unchanged, as their names come straight
            ## from the model's fields
            joined_columns = list(df.columns[len(cls.model_fields):])
            renames: dict[str, str] = {}
            ## Match column names with the joined tables.
            ## Important: this assumes that the order of the joined tables
            ## and of their columns is preserved by pandas' read_sql
            for joined_table in joined_tables:
                target = joined_table.property.target  # type: ignore
                for col in target.columns:
                    ## Pop the next column from the column list and make a new name
                    renames[joined_columns.pop(0)] = f'{target.schema}_{target.name}_{col.name}'
            df.rename(columns=renames, inplace=True)

        ## Finally, set the index of the df as the index of cls
        df.set_index(
            [c.name for c in cls.__table__.primary_key.columns],  # type: ignore
            inplace=True,
        )
        return df


## FastAPI dependency type: injects an AsyncSession from get_db_session
fastapi_db_session = Annotated[AsyncSession, Depends(get_db_session)]