gisaf-backend/src/gisaf/database.py

82 lines
3.4 KiB
Python
Raw Normal View History

from contextlib import asynccontextmanager
from typing import Annotated, Literal, Any
2023-12-23 15:08:42 +05:30
from collections.abc import AsyncGenerator
2023-11-06 17:04:17 +05:30
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import joinedload, QueryableAttribute, InstrumentedAttribute
from sqlmodel import SQLModel, select
2023-11-06 17:04:17 +05:30
from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi import Depends
2023-11-06 17:04:17 +05:30
import pandas as pd
import geopandas as gpd
2023-11-06 17:04:17 +05:30
from gisaf.config import conf
## Single async engine shared by every AsyncSession created in this module,
## configured from the ``db`` section of the application config
engine = create_async_engine(
    conf.db.get_sqla_url(),
    echo=conf.db.echo,
    pool_size=conf.db.pool_size,
    max_overflow=conf.db.max_overflow,
)
async def get_db_session() -> AsyncGenerator[AsyncSession]:
    """Yield an ``AsyncSession`` bound to the module-level engine.

    Written as an async generator so it can be used as a FastAPI dependency
    (``fastapi_db_session`` wraps it with ``Depends``). The ``async with``
    block exits, and so releases the session, when the generator finishes.
    """
    new_session = AsyncSession(engine)
    async with new_session as active_session:
        yield active_session
@asynccontextmanager
async def db_session() -> AsyncGenerator[AsyncSession]:
    """Async context manager providing an ``AsyncSession`` on the module engine.

    Usage: ``async with db_session() as session: ...``
    """
    async with AsyncSession(engine) as db:
        yield db
def pandas_query(session, query, **kwargs):
    """Run *query* on the connection of a sync *session* and return a DataFrame.

    This is the sync callable handed to ``AsyncSession.run_sync`` by
    ``BaseModel._get_df``, which forwards its own extra keyword arguments
    here.

    :param session: object exposing ``connection()`` (a sync ORM Session
        when called through ``run_sync``).
    :param query: SQL statement accepted by ``pandas.read_sql_query``.
    :param kwargs: extra keyword arguments forwarded to
        ``pandas.read_sql_query`` (e.g. ``params``, ``index_col``).
    :return: the query result as a ``pandas.DataFrame``.
    """
    return pd.read_sql_query(query, session.connection(), **kwargs)
def geopandas_query(session, query, *, crs=None, cast=True, **kwargs):
    """Run a PostGIS *query* on a sync *session*'s connection; return a GeoDataFrame.

    Meant to be executed through ``AsyncSession.run_sync`` (see
    ``BaseModel._get_df``), which passes its extra keyword arguments here.

    :param session: object exposing ``connection()`` (a sync ORM Session
        when called through ``run_sync``).
    :param query: SQL statement accepted by ``GeoDataFrame.from_postgis``.
    :param crs: coordinate reference system given to ``from_postgis``.
    :param cast: accepted for backward compatibility; currently unused.
    :param kwargs: extra keyword arguments forwarded to
        ``GeoDataFrame.from_postgis`` (e.g. ``geom_col``, ``index_col``).
    :return: the query result as a ``geopandas.GeoDataFrame``.
    """
    return gpd.GeoDataFrame.from_postgis(
        query, session.connection(), crs=crs, **kwargs
    )
class BaseModel(SQLModel):
    """SQLModel base class with helpers that load the model's rows into a
    pandas DataFrame or a geopandas GeoDataFrame.
    """

    @classmethod
    def selectinload(cls) -> list[Literal['*'] | QueryableAttribute[Any]]:
        """Return the relationship attributes to join when building dataframes.

        ``_get_df`` eager-loads these with ``joinedload`` when ``with_related``
        is true. The base implementation returns an empty list; subclasses
        with related tables override it.
        """
        return []

    @classmethod
    async def get_df(cls, where=None, with_related=True, **kwargs) -> pd.DataFrame:
        """Load the model's rows into a pandas DataFrame.

        :param where: optional SQLAlchemy boolean expression filtering rows.
        :param with_related: join the attributes returned by ``selectinload``.
        :param kwargs: forwarded to ``pandas_query``.
        :return: DataFrame indexed by the model's primary key column(s).
        """
        return await cls._get_df(pandas_query, where=where,
                                 with_related=with_related, **kwargs)

    @classmethod
    async def get_gdf(cls, *, where=None, with_related=True, **kwargs) -> gpd.GeoDataFrame:
        """Load the model's rows into a geopandas GeoDataFrame.

        :param where: optional SQLAlchemy boolean expression filtering rows.
        :param with_related: join the attributes returned by ``selectinload``.
        :param kwargs: forwarded to ``geopandas_query`` (e.g. ``crs``).
        :return: GeoDataFrame indexed by the model's primary key column(s).
        """
        return await cls._get_df(geopandas_query, where=where,
                                 with_related=with_related, **kwargs)

    @classmethod
    async def _get_df(cls, method, *, where=None, with_related=True,
                      **kwargs) -> pd.DataFrame | gpd.GeoDataFrame:
        """Select the model's rows and convert them with *method*.

        :param method: sync callable ``(session, query, **kwargs)`` returning a
            dataframe; it is executed through ``AsyncSession.run_sync``.
        :param where: optional SQLAlchemy boolean expression filtering rows.
        :param with_related: join the attributes returned by ``selectinload``
            and rename their columns to ``<schema>_<table>_<column>``.
        :param kwargs: forwarded to *method*.
        :return: the dataframe, indexed by the model's primary key column(s).
        :raises IndexError: if the result has fewer joined columns than the
            joined tables define.
        """
        query = select(cls)
        if where is not None:
            ## Select is generative: where() returns a new statement rather
            ## than modifying ``query`` in place
            query = query.where(where)
        ## Get the joined tables
        joined_tables = cls.selectinload()
        join_related = bool(with_related and joined_tables)
        if join_related:
            query = query.options(*(joinedload(jt) for jt in joined_tables))
        async with db_session() as session:
            df = await session.run_sync(method, query, **kwargs)
        ## Only rename when the joins were applied: otherwise the result has
        ## no joined columns to rename
        if join_related:
            ## Change column names to reflect the joined tables.
            ## Leave the first columns unchanged, as their names come straight
            ## from the model's fields
            joined_columns = list(df.columns[len(cls.model_fields):])
            renames: dict[str, str] = {}
            ## Match column names with the joined tables by position.
            ## Important: this assumes that the order of the joined tables
            ## and of their columns is preserved by pandas' read_sql
            position = 0
            for joined_table in joined_tables:
                target = joined_table.property.target  # type: ignore
                for col in target.columns:
                    renames[joined_columns[position]] = \
                        f'{target.schema}_{target.name}_{col.name}'
                    position += 1
            df.rename(columns=renames, inplace=True)
        ## Finally, set the index of the df as the primary key of cls
        df.set_index([c.name for c in cls.__table__.primary_key.columns],  # type: ignore
                     inplace=True)
        return df
## Annotated dependency type for FastAPI route parameters: a parameter
## declared as ``session: fastapi_db_session`` receives the session yielded
## by get_db_session
fastapi_db_session = Annotated[
    AsyncSession,
    Depends(get_db_session),
]