gisaf-backend/src/gisaf/database.py
phil ec71b6ed15 Remove custom sqlalchemy metadata, manage with __table_args__
Allow sqlmodel queries, with relations
Remove join_with mechanisms coming from gino
Handle with_only_columns in get_df and get_gdf
Implement feature-info
2024-01-04 18:50:23 +05:30

94 lines
No EOL
3.7 KiB
Python

from contextlib import asynccontextmanager
from typing import Annotated, Literal, Any
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import joinedload, QueryableAttribute, InstrumentedAttribute
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi import Depends
import pandas as pd
import geopandas as gpd
from gisaf.config import conf
## Module-wide async SQLAlchemy engine. The URL and all the options are
## read from the ``db`` section of the gisaf configuration object.
engine = create_async_engine(
    conf.db.get_sqla_url(),
    pool_size=conf.db.pool_size,
    max_overflow=conf.db.max_overflow,
    echo=conf.db.echo,
)
async def get_db_session() -> AsyncGenerator[AsyncSession]:
    """Yield an ``AsyncSession`` bound to the module-level engine.

    This is a plain async generator (it is not decorated as a context
    manager); the module wraps it in ``Depends`` to build the
    ``fastapi_db_session`` annotation. The session's context is exited
    once the generator is resumed or closed after the ``yield``.
    """
    db = AsyncSession(engine)
    async with db:
        yield db
@asynccontextmanager
async def db_session() -> AsyncGenerator[AsyncSession]:
    """Async context manager giving a session on the module-level engine.

    Intended use: ``async with db_session() as session: ...``. The
    session's own context is exited when the ``async with`` block ends.
    """
    async with AsyncSession(engine) as active_session:
        yield active_session
def pandas_query(session, query):
    """Execute *query* and return its result as a pandas ``DataFrame``.

    ``session`` must expose a synchronous ``connection()`` method; the
    object it returns is handed to ``pandas.read_sql_query`` as the
    connection.
    """
    connection = session.connection()
    return pd.read_sql_query(sql=query, con=connection)
def geopandas_query(session, query, *, crs=None, cast=True):
    """Execute *query* and return its result as a ``GeoDataFrame``.

    The query is run through ``GeoDataFrame.from_postgis`` on the
    connection returned by ``session.connection()``; *crs* is forwarded
    unchanged.

    NOTE(review): ``cast`` is accepted but not used by this function.
    """
    connection = session.connection()
    return gpd.GeoDataFrame.from_postgis(query, connection, crs=crs)
class BaseModel(SQLModel):
    """SQLModel base class with helpers loading a table as a (Geo)DataFrame."""

    @classmethod
    def selectinload(cls) -> list[Literal['*'] | QueryableAttribute[Any]]:
        """Return the relationship attributes to load along with the model.

        The base implementation declares none. ``_get_df`` passes each
        entry to ``joinedload`` and reads its ``.property.target`` to
        name the resulting columns.
        """
        return []

    @classmethod
    async def get_df(cls, *,
                     where=None, with_related=True, **kwargs
                     ) -> pd.DataFrame:
        """Load the model's table as a pandas ``DataFrame``.

        :param where: optional filter expression, applied with
            ``Select.where``; ``None`` means no filter.
        :param with_related: when true, also load the relations listed
            by ``selectinload()``.
        :param kwargs: ``with_only_columns`` is consumed by ``_get_df``;
            anything else is forwarded to ``pandas_query``.
        :return: a DataFrame indexed by the primary key column(s).
        """
        return await cls._get_df(pandas_query, where=where,
                                 with_related=with_related, **kwargs)

    @classmethod
    async def get_gdf(cls, *,
                      where=None, with_related=True, **kwargs
                      ) -> gpd.GeoDataFrame:
        """Load the model's table as a ``GeoDataFrame``.

        :param where: optional filter expression, applied with
            ``Select.where``; ``None`` means no filter.
        :param with_related: when true, also load the relations listed
            by ``selectinload()``.
        :param kwargs: ``with_only_columns`` is consumed by ``_get_df``;
            anything else (e.g. ``crs``) is forwarded to
            ``geopandas_query``.
        :return: a GeoDataFrame indexed by the primary key column(s).
        """
        return await cls._get_df(geopandas_query, where=where,
                                 with_related=with_related, **kwargs)

    @classmethod
    async def _get_df(cls, method, *,
                      where=None, with_related=True,
                      with_only_columns: list[str] | None = None,
                      **kwargs
                      ) -> pd.DataFrame | gpd.GeoDataFrame:
        """Build the select query, run it with *method*, tidy the frame.

        :param method: callable invoked as ``method(sync_session, query,
            **kwargs)`` through ``AsyncSession.run_sync``
            (``pandas_query`` or ``geopandas_query`` in this module).
        :param where: optional filter expression; ``None`` for no filter.
        :param with_related: when true, load the relations listed by
            ``selectinload()`` and prefix their columns with
            ``<schema>_<table>_``.
        :param with_only_columns: names of the model attributes to
            select; empty or ``None`` selects the whole model. The
            primary key columns are always added, since they become the
            index of the result.
        :param kwargs: extra keyword arguments forwarded to *method*.
        :return: the frame produced by *method*, indexed by the primary
            key column(s).
        """
        pk_names = [c.name for c in cls.__table__.primary_key.columns]  # type: ignore
        async with db_session() as session:
            if not with_only_columns:
                query = select(cls)
                own_column_count = len(cls.model_fields)
            else:
                ## De-duplicate while keeping the requested order, and make
                ## sure that every primary key column is selected
                ## (there can be more than one)
                columns = dict.fromkeys([*with_only_columns, *pk_names])
                query = select(*(getattr(cls, col) for col in columns))
                own_column_count = len(columns)
            if where is not None:
                ## Select is generative: where() returns a new statement
                query = query.where(where)
            ## Get the joined tables
            joined_tables = cls.selectinload()
            load_related = bool(with_related and joined_tables)
            if load_related:
                query = query.options(*(joinedload(jt) for jt in joined_tables))
            df = await session.run_sync(method, query, **kwargs)
            if load_related:
                ## Change column names to reflect the joined tables
                ## Leave the first columns unchanged, as their names come
                ## straight from the model's fields
                joined_columns = iter(df.columns[own_column_count:])
                renames: dict[str, str] = {}
                ## Match column names with the joined tables
                ## Important: this assumes that orders of the joined tables
                ## and their columns is preserved by pandas' read_sql
                for joined_table in joined_tables:
                    target = joined_table.property.target  # type: ignore
                    target_name = target.name
                    ## target.columns comes first in zip(), so that no name
                    ## is consumed from joined_columns once a table is done;
                    ## the pairing stops quietly if the frame is short
                    for col, current_name in zip(target.columns, joined_columns):
                        renames[current_name] = f'{target.schema}_{target_name}_{col.name}'
                df.rename(columns=renames, inplace=True)
            ## Finally, set the index of the df as the index of cls
            df.set_index(pk_names, inplace=True)
        return df
## Annotated type for FastAPI parameters: an ``AsyncSession`` declared
## with ``get_db_session`` as its dependency.
fastapi_db_session = Annotated[
    AsyncSession,
    Depends(get_db_session),
]