Migrate joins to sqlalchemy's query options
Use native pandas read_sql_query and geopandas from_postgis. Fix definition of status in models. Fix table names. Fix category fields.
This commit is contained in:
parent
956147aea8
commit
75bedb3e91
8 changed files with 236 additions and 190 deletions
|
@ -1,11 +1,14 @@
|
|||
from contextlib import asynccontextmanager
|
||||
from typing import Annotated
|
||||
from typing import Annotated, Literal, Any
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import joinedload, QueryableAttribute, InstrumentedAttribute
|
||||
from sqlmodel import SQLModel, select
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from fastapi import Depends
|
||||
import pandas as pd
|
||||
import geopandas as gpd
|
||||
|
||||
from gisaf.config import conf
|
||||
|
||||
|
@ -28,4 +31,52 @@ async def db_session() -> AsyncGenerator[AsyncSession]:
|
|||
def pandas_query(session, query):
    """Run *query* on the session's connection and return a pandas DataFrame.

    This is a plain synchronous callable taking the session as its first
    argument.
    """
    connection = session.connection()
    return pd.read_sql_query(query, connection)
|
||||
|
||||
def geopandas_query(session, query, *, crs=None, cast=True):
    """Run *query* on the session's connection and return a GeoDataFrame.

    The result is built with ``GeoDataFrame.from_postgis``; *crs* is
    forwarded to it unchanged.

    NOTE: *cast* is accepted but not used by this function.
    """
    connection = session.connection()
    return gpd.GeoDataFrame.from_postgis(query, connection, crs=crs)
|
||||
|
||||
class BaseModel(SQLModel):
    """SQLModel base class that can load its table as a (Geo)DataFrame."""

    @classmethod
    def selectinload(cls) -> list[Literal['*'] | QueryableAttribute[Any]]:
        """Return the relationship attributes to eager-load with this model.

        The base implementation declares no related tables; subclasses are
        expected to override it (presumably returning relationship
        attributes such as ``cls.some_relation`` -- confirm in the models).
        """
        return []

    @classmethod
    async def get_df(cls, where=None, with_related=True, **kwargs) -> pd.DataFrame:
        """Load the model's rows as a pandas DataFrame.

        :param where: optional where clause applied to the select
        :param with_related: also load the tables given by ``selectinload``
        :param kwargs: forwarded to ``pandas_query``
        """
        ## Forward the caller's arguments (they used to be overridden
        ## with where=None / with_related=True)
        return await cls._get_df(pandas_query, where=where,
                                 with_related=with_related, **kwargs)

    @classmethod
    async def get_gdf(cls, *, where=None, with_related=True, **kwargs) -> gpd.GeoDataFrame:
        """Load the model's rows as a geopandas GeoDataFrame.

        :param where: optional where clause applied to the select
        :param with_related: also load the tables given by ``selectinload``
        :param kwargs: forwarded to ``geopandas_query`` (eg. ``crs``)
        """
        return await cls._get_df(geopandas_query, where=where,
                                 with_related=with_related, **kwargs)

    @classmethod
    async def _get_df(cls, method, *, where=None, with_related=True, **kwargs) -> pd.DataFrame | gpd.GeoDataFrame:
        """Build the select for ``cls``, run it through *method* and
        post-process the resulting data frame.

        :param method: synchronous callable ``(session, query, **kwargs)``
            returning a data frame; executed with ``session.run_sync``
        :param where: optional where clause applied to the select
        :param with_related: join the tables given by ``selectinload`` and
            rename their columns to ``<schema>_<table>_<column>``
        :param kwargs: forwarded to *method*
        :return: the data frame, indexed by the primary key column(s) of cls
        """
        query = select(cls)
        if where is not None:
            ## Select objects are generative: where() returns a new query
            query = query.where(where)
        ## Get the joined tables, only when they are requested: this also
        ## keeps the renaming below from running on columns that
        ## were never fetched
        joined_tables = cls.selectinload() if with_related else []
        if joined_tables:
            query = query.options(*(joinedload(jt) for jt in joined_tables))
        async with db_session() as session:
            df = await session.run_sync(method, query, **kwargs)
        ## Change column names to reflect the joined tables
        ## Leave the first columns unchanged, as their names come straight
        ## from the model's fields
        joined_columns = list(df.columns[len(cls.model_fields):])
        renames: dict[str, str] = {}
        ## Match column names with the joined tables
        ## Important: this assumes that orders of the joined tables
        ## and their columns is preserved by pandas' read_sql
        for joined_table in joined_tables:
            target = joined_table.property.target  # type: ignore
            target_name = target.name
            for col in target.columns:
                ## Pop the column from the column list and make a new name
                renames[joined_columns.pop(0)] = f'{target.schema}_{target_name}_{col.name}'
        df.rename(columns=renames, inplace=True)
        ## Finally, set the index of the df as the index of cls
        df.set_index([c.name for c in cls.__table__.primary_key.columns],  # type: ignore
                     inplace=True)
        return df
|
||||
|
||||
|
||||
## Annotated type for FastAPI dependency injection: a parameter declared
## with this type is an AsyncSession resolved through Depends(get_db_session)
fastapi_db_session = Annotated[AsyncSession, Depends(get_db_session)]
|
Loading…
Add table
Add a link
Reference in a new issue