from typing import Any, Dict, Tuple, Type, ClassVar
import logging

from sqlmodel import Field, SQLModel, MetaData, JSON, TEXT, Relationship, Column, select
from pydantic import computed_field
import numpy as np
import pandas as pd
import geopandas as gpd
import shapely
from sqlalchemy.sql import sqltypes
from geoalchemy2.types import Geometry

from gisaf.database import BaseModel

pandas_cast_map = {
    sqltypes.Integer: 'Int64',
    sqltypes.Float: 'float64',
}

logger = logging.getLogger('model_base_base')


class Model(BaseModel):
    """
    Base mixin class for models that can be converted
    to a Pandas dataframe with get_df
    """
    # status: ClassVar[str] = 'E'

    def __new__(cls, *args, **kwargs):
        ## Give every model class a default select() query
        if not hasattr(cls, 'query'):
            cls.query = select(cls)
        return super().__new__(cls, *args, **kwargs)

    class Meta:
        filtered_columns_on_map: list[str] = []

    @classmethod
    def get_store_name(cls):
        if hasattr(cls, '__table__'):
            return cls.__table__.fullname
        elif hasattr(cls, '__table_args__') and 'schema' in cls.__table_args__:
            return f"{cls.__table_args__['schema']}.{cls.__tablename__}"
        else:
            return f'{cls.metadata.schema}.{cls.__tablename__}'

    @classmethod
    def get_table_name_prefix(cls):
        return f'{cls.metadata.schema}_{cls.__tablename__}'

    # @classmethod
    # async def get_df(cls, where=None,
    #                  with_related=None, recursive=True,
    #                  cast=True,
    #                  with_only_columns=None,
    #                  geom_as_ewkt=False,
    #                  **kwargs):
    #     """
    #     Return a Pandas dataframe of all records.
    #     Optional arguments:
    #     * where: an SQLAlchemy where clause
    #     * with_related: automatically get data from related columns,
    #       following the foreign keys in the model definitions
    #     * cast: automatically transform various data to their best Python
    #       types (eg. date, time...)
    #     * with_only_columns: fetch only these columns (list of column names)
    #     * geom_as_ewkt: convert geometry columns to hex EWKB
    #       (handy for eg. using upsert_df)
    #     :return: pd.DataFrame
    #     """
    #     if hasattr(cls, 'query'):
    #         query = cls.query
    #     else:
    #         query = select(cls)
    #     # if with_related is not False:
    #     #     if with_related or getattr(cls, 'get_gdf_with_related', False):
    #     #         joins = get_join_with(cls, recursive)
    #     #         model_loader = cls.load(**joins)
    #     #         query = _get_query_with_table_names(model_loader)
    #     if where is not None:
    #         query = query.where(where)
    #     if with_only_columns:
    #         query = query.with_only_columns(
    #             *[getattr(cls, colname) for colname in with_only_columns])
    #     ## Got idea from https://github.com/MagicStack/asyncpg/issues/173.
    #     async with query.bind.raw_pool.acquire() as conn:
    #         ## Convert hstore fields to dict
    #         await conn.set_builtin_type_codec('hstore', codec_name='pg_contrib.hstore')
    #         compiled = query.compile()
    #         stmt = await conn.prepare(compiled.string)
    #         columns = [a.name for a in stmt.get_attributes()]
    #         data = await stmt.fetch(*[compiled.params.get(param)
    #                                   for param in compiled.positiontup])
    #         df = pd.DataFrame(data, columns=columns)
    #     ## Convert primary key columns to Int64:
    #     ## allows NaN, fixing type conversion to float with merge
    #     for pk in [c.name for c in cls.__table__.primary_key.columns]:
    #         if pk in df.columns and df[pk].dtype == 'int64':
    #             df[pk] = df[pk].astype('Int64')
    #     if cast:
    #         ## Cast the type for known types (datetime, ...)
    #         for column_name in df.columns:
    #             col = getattr(query.columns, column_name, None)
    #             if col is None:
    #                 logger.debug(f'Cannot get column {column_name} '
    #                              f'in query for model {cls.__name__}')
    #                 continue
    #             column_type = col.type
    #             ## XXX: Needs refinement, eg. nullable -> Int64 ...
    #             if column_type.__class__ in pandas_cast_map:
    #                 df[column_name] = df[column_name].astype(
    #                     pandas_cast_map[column_type.__class__])
    #             elif isinstance(column_type, (sqltypes.Date, sqltypes.DateTime)):
    #                 ## Dates, times
    #                 df[column_name] = pd.to_datetime(df[column_name])
    #             # elif isinstance(column_type, (sqltypes.Integer, sqltypes.Float)):
    #             #     ## Numeric
    #             #     df[column_name] = pd.to_numeric(df[column_name],
    #             #                                     errors='coerce')
    #             ## XXX: Keeping this note about the "char" SQL type,
    #             ## although the fix of #9694 makes it unnecessary:
    #             # elif isinstance(column_type, sqltypes.CHAR) \
    #             #         or (isinstance(column_type, sqltypes.String)
    #             #             and column_type.length == 1):
    #             #     ## Workaround for bytes being used for strings
    #             #     ## of length 1 (reason unclear)
    #             #     df[column_name] = df[column_name].str.decode('utf-8')
    #     ## Rename the columns, removing the schema_table prefix
    #     ## for the columns of this model
    #     prefix = cls.get_table_name_prefix()
    #     prefix_length = len(prefix) + 1
    #     rename_map = {colname: colname[prefix_length:]
    #                   for colname in df.columns if colname.startswith(prefix)}
    #     df.rename(columns=rename_map, inplace=True)
    #     ## Optionally convert geometry columns to hex EWKB
    #     if geom_as_ewkt:
    #         geometry_columns = [col.name for col in cls.__table__.columns
    #                             if isinstance(col.type, Geometry)]
    #         for column in geometry_columns:
    #             df[column] = shapely.to_wkb(shapely.from_wkb(df[column]),
    #                                         hex=True, include_srid=True)
    #     return df
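

## ---------------------------------------------------------------------------
## Sketch: a minimal, working equivalent of the commented-out get_df above,
## using SQLAlchemy's asyncio extension instead of the raw asyncpg pool.
## This is an illustration under assumptions, not the original implementation:
## the function name and the explicit AsyncEngine argument are hypothetical
## (Gisaf wires its engine elsewhere), and the with_related joins, hstore
## codec, column-prefix renaming and geometry conversion are left out.
## ---------------------------------------------------------------------------

from sqlalchemy import select as sa_select
from sqlalchemy.ext.asyncio import AsyncEngine


async def get_df_sketch(model: Type[Model], engine: AsyncEngine,
                        where=None, cast: bool = True) -> pd.DataFrame:
    """Fetch all records of a model into a Pandas dataframe (sketch)."""
    query = sa_select(model.__table__)
    if where is not None:
        query = query.where(where)
    async with engine.connect() as conn:
        result = await conn.execute(query)
        df = pd.DataFrame(result.fetchall(), columns=list(result.keys()))
    ## Primary keys as nullable Int64, as in the original code above
    for pk in (c.name for c in model.__table__.primary_key.columns):
        if pk in df.columns and df[pk].dtype == 'int64':
            df[pk] = df[pk].astype('Int64')
    if cast:
        ## Cast known SQL types to their best pandas dtypes,
        ## reusing the module-level pandas_cast_map
        for column in model.__table__.columns:
            if column.name not in df.columns:
                continue
            if column.type.__class__ in pandas_cast_map:
                df[column.name] = df[column.name].astype(
                    pandas_cast_map[column.type.__class__])
            elif isinstance(column.type, (sqltypes.Date, sqltypes.DateTime)):
                df[column.name] = pd.to_datetime(df[column.name])
    return df

## Usage (hypothetical engine URL):
##   from sqlalchemy.ext.asyncio import create_async_engine
##   engine = create_async_engine('postgresql+asyncpg://user@host/gisaf')
##   df = await get_df_sketch(SomeModel, engine, where=SomeModel.id > 0)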