gisaf-backend/src/gisaf/models/models_base.py

from typing import Any, Dict, Tuple, Type, ClassVar
import logging
from sqlmodel import Field, SQLModel, MetaData, JSON, TEXT, Relationship, Column, select
from pydantic import computed_field
import numpy as np
import pandas as pd
import geopandas as gpd
import shapely
from sqlalchemy.sql import sqltypes
from geoalchemy2.types import Geometry
from gisaf.database import BaseModel
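
## Map SQLAlchemy column types to Pandas dtypes, used by the cast step of
## the (commented-out) get_df below. Note: 'Int64' (capital I) is Pandas'
## nullable integer dtype, which keeps integer columns intact when they
## contain NaN.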
pandas_cast_map = {
    sqltypes.Integer: 'Int64',
    sqltypes.Float: 'float64',
}

logger = logging.getLogger('model_base_base')


class Model(BaseModel):
    """
    Base mixin class for models that can be converted to a Pandas dataframe with get_df
    """
    # status: ClassVar[str] = 'E'

    def __new__(cls, *args, **kwargs):
        ## Lazily attach a default select() query to the class at first
        ## instantiation, unless a subclass already defines its own
        if not hasattr(cls, 'query'):
            cls.query = select(cls)
        return super().__new__(cls, *args, **kwargs)

    class Meta:
        filtered_columns_on_map: list[str] = []
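
    ## Note (assumption from the attribute name, not verified in this repo):
    ## subclasses override Meta.filtered_columns_on_map with the names of
    ## columns that can be filtered on the web map; the default is none.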

    @classmethod
    def get_store_name(cls):
        if hasattr(cls, '__table__'):
            return cls.__table__.fullname
        elif hasattr(cls, '__table_args__') and 'schema' in cls.__table_args__:
            ## __table_args__ is a mapping here, so use item access
            ## ('.schema' attribute access on a dict was a bug)
            return f"{cls.__table_args__['schema']}.{cls.__tablename__}"
        else:
            return f'{cls.metadata.schema}.{cls.__tablename__}'

    @classmethod
    def get_table_name_prefix(cls):
        return f"{cls.metadata.schema}_{cls.__tablename__}"

    # @classmethod
    # async def get_df(cls, where=None,
    #                  with_related=None, recursive=True,
    #                  cast=True,
    #                  with_only_columns=None,
    #                  geom_as_ewkt=False,
    #                  **kwargs):
    #     """
    #     Return a Pandas dataframe of all records
    #     Optional arguments:
    #     * where: an SQLAlchemy where clause
    #     * with_related: automatically get data from related columns, following the foreign keys in the model definitions
    #     * cast: automatically transform various data in their best python types (eg. with date, time...)
    #     * with_only_columns: fetch only these columns (list of column names)
    #     * geom_as_ewkt: convert geometry columns to EWKB (handy for eg. using upsert_df)
    #     :return:
    #     """
    #     breakpoint()
    #     if hasattr(cls, 'query'):
    #         query = cls.query
    #     else:
    #         query = select(cls)
    #     # if with_related is not False:
    #     #     if with_related or getattr(cls, 'get_gdf_with_related', False):
    #     #         joins = get_join_with(cls, recursive)
    #     #         model_loader = cls.load(**joins)
    #     #         query = _get_query_with_table_names(model_loader)
    #     if where is not None:
    #         query.append_whereclause(where)
    #     if with_only_columns:
    #         query = query.with_only_columns([getattr(cls, colname) for colname in with_only_columns])
    #     ## Got idea from https://github.com/MagicStack/asyncpg/issues/173.
    #     breakpoint()
    #     async with query.bind.raw_pool.acquire() as conn:
    #         ## Convert hstore fields to dict
    #         await conn.set_builtin_type_codec('hstore', codec_name='pg_contrib.hstore')
    #         compiled = query.compile()
    #         stmt = await conn.prepare(compiled.string)
    #         columns = [a.name for a in stmt.get_attributes()]
    #         data = await stmt.fetch(*[compiled.params.get(param) for param in compiled.positiontup])
    #         df = pd.DataFrame(data, columns=columns)
    #     ## Convert primary key columns to Int64:
    #     ## allows NaN, fixing type conversion to float with merge
    #     for pk in [c.name for c in cls.__table__.primary_key.columns]:
    #         if pk in df.columns and df[pk].dtype == 'int64':
    #             df[pk] = df[pk].astype('Int64')
    #     if cast:
    #         ## Cast the type for known types (datetime, ...)
    #         for column_name in df.columns:
    #             col = getattr(query.columns, column_name, None)
    #             if col is None:
    #                 logger.debug(f'Cannot get column {column_name} in query for model {cls.__name__}')
    #                 continue
    #             column_type = getattr(query.columns, column_name).type
    #             ## XXX: Needs refinement, eg. nullable -> Int64 ...
    #             if column_type.__class__ in pandas_cast_map:
    #                 df[column_name] = df[column_name].astype(pandas_cast_map[column_type.__class__])
    #             elif isinstance(column_type, (sqltypes.Date, sqltypes.DateTime)):
    #                 ## Dates, times
    #                 df[column_name] = pd.to_datetime(df[column_name])
    #             #elif isinstance(column_type, (sqltypes.Integer, sqltypes.Float)):
    #             #    ## Numeric
    #             #    df[column_name] = pd.to_numeric(df[column_name], errors='coerce')
    #             ## XXX: keeping this note about the "char" SQL type, but the fix of #9694 makes it unnecessary
    #             #elif isinstance(column_type, sqltypes.CHAR) or (isinstance(column_type, sqltypes.String) and column_type.length == 1):
    #             #    ## Workaround for bytes being used for string of length 1 (not sure - why???)
    #             #    df[column_name] = df[column_name].str.decode('utf-8')
    #     ## Rename the columns, removing the schema_table prefix for the columns in that model
    #     prefix = cls.get_table_name_prefix()
    #     prefix_length = len(prefix) + 1
    #     rename_map = {colname: colname[prefix_length:] for colname in df.columns if colname.startswith(prefix)}
    #     df.rename(columns=rename_map, inplace=True)
    #     ## Optionally convert geometry columns to EWKB
    #     if geom_as_ewkt:
    #         geometry_columns = [col.name for col in cls.__table__.columns if isinstance(col.type, Geometry)]
    #         for column in geometry_columns:
    #             ## Fixed from df.geom: convert each geometry column, not only 'geom'
    #             df[column] = shapely.to_wkb(shapely.from_wkb(df[column]), hex=True, include_srid=True)
    #     return df
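
    ## ------------------------------------------------------------------
    ## A minimal, hedged sketch of what a modern get_df could look like,
    ## using SQLAlchemy 2.x result mappings instead of the raw asyncpg pool
    ## used by the commented-out version above. This is NOT the original
    ## implementation: `session_factory` (an async sessionmaker) is an
    ## assumed parameter, and the related-model joins, type casting and
    ## EWKB conversion of the original are intentionally left out.
    @classmethod
    async def get_df_sketch(cls, session_factory, where=None) -> pd.DataFrame:
        """Illustrative only: fetch records as a Pandas dataframe."""
        ## Select the table rather than the ORM class, to get plain rows
        ## keyed by column name
        query = select(cls.__table__)
        if where is not None:
            query = query.where(where)
        async with session_factory() as session:
            result = await session.execute(query)
            ## RowMapping objects are dict-like, so DataFrame accepts them
            df = pd.DataFrame(result.mappings().all())
        ## Primary keys as nullable Int64, as in the original get_df:
        ## avoids int -> float coercion when a merge introduces NaN
        for pk in [c.name for c in cls.__table__.primary_key.columns]:
            if pk in df.columns and df[pk].dtype == 'int64':
                df[pk] = df[pk].astype('Int64')
        return df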