Migrate joins to sqlalchemy's query options

Use native pandas read_sql_query and geopandas from_postgis
Fix definiiton of status in models
Fix table names
Fix category fields
This commit is contained in:
phil 2024-01-02 00:09:08 +05:30
parent 956147aea8
commit 75bedb3e91
8 changed files with 236 additions and 190 deletions

View file

@ -1,11 +1,14 @@
from contextlib import asynccontextmanager
from typing import Annotated
from typing import Annotated, Literal, Any
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import joinedload, QueryableAttribute, InstrumentedAttribute
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi import Depends
import pandas as pd
import geopandas as gpd
from gisaf.config import conf
@ -28,4 +31,52 @@ async def db_session() -> AsyncGenerator[AsyncSession]:
def pandas_query(session, query):
return pd.read_sql_query(query, session.connection())
def geopandas_query(session, query, *, crs=None, cast=True):
return gpd.GeoDataFrame.from_postgis(query, session.connection(), crs=crs)
class BaseModel(SQLModel):
@classmethod
def selectinload(cls) -> list[Literal['*'] | QueryableAttribute[Any]]:
return []
@classmethod
async def get_df(cls, where=None, with_related=True, **kwargs) -> pd.DataFrame:
return await cls._get_df(pandas_query, where=None, with_related=True, **kwargs)
@classmethod
async def get_gdf(cls, *, where=None, with_related=True, **kwargs) -> gpd.GeoDataFrame:
return await cls._get_df(geopandas_query, where=None, with_related=True, **kwargs)
@classmethod
async def _get_df(cls, method, *, where=None, with_related=True, **kwargs) -> pd.DataFrame | gpd.GeoDataFrame:
async with db_session() as session:
query = select(cls)
if where is not None:
query.append_whereclause(where)
## Get the joined tables
joined_tables = cls.selectinload()
if with_related and len(joined_tables) > 0:
query = query.options(*(joinedload(jt) for jt in joined_tables))
df = await session.run_sync(method, query, **kwargs)
## Chamge column names to reflect the joined tables
## Leave the first columns unchanged, as their names come straight
## from the model's fields
joined_columns = list(df.columns[len(cls.model_fields):])
renames: dict[str, str] = {}
## Match colum names with the joined tables
## Important: this assumes that orders of the joined tables
## and their columns is preserved by pandas' read_sql
for joined_table in joined_tables:
target = joined_table.property.target # type: ignore
target_name = target.name
for col in target.columns:
## Pop the column from the colujmn list and make a new name
renames[joined_columns.pop(0)] = f'{target.schema}_{target_name}_{col.name}'
df.rename(columns=renames, inplace=True)
## Finally, set the index of the df as the index of cls
df.set_index([c.name for c in cls.__table__.primary_key.columns], # type: ignore
inplace=True)
return df
fastapi_db_session = Annotated[AsyncSession, Depends(get_db_session)]