Misc:
Basic registry, with survey stores Move to standard src/ dir versions: sqlmodel official, pydantic v2 etc...
This commit is contained in:
parent
5494f6085f
commit
049b8c9927
31 changed files with 670 additions and 526 deletions
118
src/gisaf/models/models_base.py
Normal file
118
src/gisaf/models/models_base.py
Normal file
|
@ -0,0 +1,118 @@
|
|||
from typing import Any
|
||||
import logging
|
||||
|
||||
from sqlmodel import Field, SQLModel, MetaData, JSON, TEXT, Relationship, Column
|
||||
from pydantic import computed_field
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import geopandas as gpd
|
||||
import shapely
|
||||
from sqlalchemy.sql import sqltypes
|
||||
from geoalchemy2.types import Geometry
|
||||
|
||||
pandas_cast_map = {
|
||||
sqltypes.Integer: 'Int64',
|
||||
sqltypes.Float: 'float64',
|
||||
}
|
||||
|
||||
logger = logging.getLogger('model_base_base')
|
||||
|
||||
class Model(SQLModel):
|
||||
"""
|
||||
Base mixin class for models that can be converted to a Pandas dataframe with get_df
|
||||
"""
|
||||
|
||||
class Meta:
|
||||
filtered_columns_on_map: list[str] = []
|
||||
|
||||
@classmethod
|
||||
def get_store_name(cls):
|
||||
return "{}.{}".format(cls.metadata.schema, cls.__tablename__)
|
||||
|
||||
@classmethod
|
||||
def get_table_name_prefix(cls):
|
||||
return "{}_{}".format(cls.metadata.schema, cls.__tablename__)
|
||||
|
||||
@classmethod
|
||||
async def get_df(cls, where=None,
|
||||
with_related=None, recursive=True,
|
||||
cast=True,
|
||||
with_only_columns=None,
|
||||
geom_as_ewkt=False,
|
||||
**kwargs):
|
||||
"""
|
||||
Return a Pandas dataframe of all records
|
||||
Optional arguments:
|
||||
* an SQLAlchemy where clause
|
||||
* with_related: automatically get data from related columns, following the foreign keys in the model definitions
|
||||
* cast: automatically transform various data in their best python types (eg. with date, time...)
|
||||
* with_only_columns: fetch only these columns (list of column names)
|
||||
* geom_as_ewkt: convert geometry columns to EWKB (handy for eg. using upsert_df)
|
||||
:return:
|
||||
"""
|
||||
query = cls.query
|
||||
|
||||
if with_related is not False:
|
||||
if with_related or getattr(cls, 'get_gdf_with_related', False):
|
||||
joins = get_join_with(cls, recursive)
|
||||
model_loader = cls.load(**joins)
|
||||
query = _get_query_with_table_names(model_loader)
|
||||
|
||||
if where is not None:
|
||||
query.append_whereclause(where)
|
||||
|
||||
if with_only_columns:
|
||||
query = query.with_only_columns([getattr(cls, colname) for colname in with_only_columns])
|
||||
|
||||
## Got idea from https://github.com/MagicStack/asyncpg/issues/173.
|
||||
async with query.bind.raw_pool.acquire() as conn:
|
||||
## Convert hstore fields to dict
|
||||
await conn.set_builtin_type_codec('hstore', codec_name='pg_contrib.hstore')
|
||||
|
||||
compiled = query.compile()
|
||||
stmt = await conn.prepare(compiled.string)
|
||||
columns = [a.name for a in stmt.get_attributes()]
|
||||
data = await stmt.fetch(*[compiled.params.get(param) for param in compiled.positiontup])
|
||||
df = pd.DataFrame(data, columns=columns)
|
||||
|
||||
## Convert primary key columns to Int64:
|
||||
## allows NaN, fixing type convertion to float with merge
|
||||
for pk in [c.name for c in cls.__table__.primary_key.columns]:
|
||||
if pk in df.columns and df[pk].dtype=='int64':
|
||||
df[pk] = df[pk].astype('Int64')
|
||||
|
||||
if cast:
|
||||
## Cast the type for known types (datetime, ...)
|
||||
for column_name in df.columns:
|
||||
col = getattr(query.columns, column_name, None)
|
||||
if col is None:
|
||||
logger.debug(f'Cannot get column {column_name} in query for model {cls.__name__}')
|
||||
continue
|
||||
column_type = getattr(query.columns, column_name).type
|
||||
## XXX: Needs refinment, eg. nullable -> Int64 ...
|
||||
if column_type.__class__ in pandas_cast_map:
|
||||
df[column_name] = df[column_name].astype(pandas_cast_map[column_type.__class__])
|
||||
elif isinstance(column_type, (sqltypes.Date, sqltypes.DateTime)):
|
||||
## Dates, times
|
||||
df[column_name] = pd.to_datetime(df[column_name])
|
||||
#elif isinstance(column_type, (sqltypes.Integer, sqltypes.Float)):
|
||||
# ## Numeric
|
||||
# df[column_name] = pd.to_numeric(df[column_name], errors='coerce')
|
||||
## XXX: keeping this note about that is about "char" SQL type, but the fix of #9694 makes it unnessary
|
||||
#elif isinstance(column_type, sqltypes.CHAR) or (isinstance(column_type, sqltypes.String) and column_type.length == 1):
|
||||
# ## Workaround for bytes being used for string of length 1 (not sure - why???)
|
||||
# df[column_name] = df[column_name].str.decode('utf-8')
|
||||
|
||||
## Rename the columns, removing the schema_table prefix for the columns in that model
|
||||
prefix = cls.get_table_name_prefix()
|
||||
prefix_length = len(prefix) + 1
|
||||
rename_map = {colname: colname[prefix_length:] for colname in df.columns if colname.startswith(prefix)}
|
||||
df.rename(columns=rename_map, inplace=True)
|
||||
|
||||
## Eventually convert geometry columns to EWKB
|
||||
if geom_as_ewkt:
|
||||
geometry_columns = [col.name for col in cls.__table__.columns if isinstance(col.type, Geometry)]
|
||||
for column in geometry_columns:
|
||||
df[column] = shapely.to_wkb(shapely.from_wkb(df.geom), hex=True, include_srid=True)
|
||||
|
||||
return df
|
Loading…
Add table
Add a link
Reference in a new issue