import logging
import asyncio
from functools import wraps
from json import JSONEncoder
from math import isnan
from time import time
import datetime
from typing import Any

from numpy import ndarray
import pandas as pd

from sqlalchemy.dialects.postgresql import insert
from sqlmodel import SQLModel, delete

from gisaf.config import conf
from gisaf.database import db_session

SHAPELY_TYPE_TO_MAPBOX_TYPE = {
    'Point': 'symbol',
    'LineString': 'line',
    'Polygon': 'fill',
    'MultiPolygon': 'fill',
}

DEFAULT_MAPBOX_LAYOUT: dict[str, dict[str, Any]] = {
    'symbol': {
        'text-line-height': 1,
        'text-padding': 0,
        'text-allow-overlap': True,
        'text-field': '\ue32b',
        'icon-optional': True,
        'text-font': ['GisafSymbols'],
        'text-size': 24,
    }
}

DEFAULT_MAPBOX_PAINT = {
    'symbol': {
        'text-translate-anchor': 'viewport',
        'text-color': '#000000',
    },
    'line': {
        'line-color': 'red',
        'line-opacity': 0.70,
        'line-width': 2,
        'line-blur': 0.5,
    },
    'fill': {
        'fill-color': 'blue',
        'fill-opacity': 0.50,
    }
}

MAPBOX_COLOR_ATTRIBUTE_NAME = {
    'symbol': 'text-color',
    'line': 'line-color',
    'fill': 'fill-color',
}

MAPBOX_OPACITY_ATTRIBUTE_NAME = {
    'symbol': 'text-opacity',
    'line': 'line-opacity',
    'fill': 'fill-opacity',
}

gisTypeSymbolMap = {
    'Point': '\ue32b',
    'Line': '\ue32c',
    'Polygon': '\ue32d',
    'MultiPolygon': '\ue32d',
}

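# Example (illustrative, not part of the original module): a hypothetical helper
# showing how the mappings above can be combined to derive a default Mapbox style
# for a Shapely geometry type. The function name and return shape are assumptions,
# sketched here only to document how the constants relate to each other.
def example_default_layer_style(shapely_type: str) -> dict[str, Any]:
    """Sketch: look up the Mapbox layer type for a geometry, then its defaults."""
    mapbox_type = SHAPELY_TYPE_TO_MAPBOX_TYPE[shapely_type]    # e.g. 'Point' -> 'symbol'
    return {
        'type': mapbox_type,
        'layout': DEFAULT_MAPBOX_LAYOUT.get(mapbox_type, {}),  # only 'symbol' has layout defaults
        'paint': DEFAULT_MAPBOX_PAINT.get(mapbox_type, {}),
        'color-attribute': MAPBOX_COLOR_ATTRIBUTE_NAME[mapbox_type],
        'opacity-attribute': MAPBOX_OPACITY_ATTRIBUTE_NAME[mapbox_type],
    }
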
# survey_to_db_project_func = pyproj.Transformer.from_crs(
#     conf.geo.raw_survey.spatial_sys_ref,
#     conf.geo.srid,
#     always_xy=True
# ).transform

def dict_array_to_list(d: dict) -> dict:
    '''Convert any ndarray in a dict to a plain Python list.
    Useful for transforming a DataFrame into a serializable object'''
    for k, v in d.items():
        if isinstance(v, dict):
            dict_array_to_list(v)
        else:
            if isinstance(v, ndarray):
                d[k] = v.tolist()
    return d

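# Example (illustrative): make a DataFrame's column-wise dict serializable by
# converting its numpy arrays to lists. The variable names are hypothetical.
#
#   df = pd.DataFrame({'a': [1, 2], 'b': [3.0, 4.0]})
#   d = {name: col.values for name, col in df.items()}   # values are ndarrays
#   dict_array_to_list(d)                                 # {'a': [1, 2], 'b': [3.0, 4.0]}
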
class NumpyEncoder(JSONEncoder):
    """
    Encoder that can serialize numpy arrays and datetime objects
    """
    def default(self, obj):
        if isinstance(obj, datetime.datetime):
            return obj.isoformat()
        if isinstance(obj, datetime.date):
            return obj.isoformat()
        if isinstance(obj, datetime.timedelta):
            return (datetime.datetime.min + obj).time().isoformat()
        if isinstance(obj, ndarray):
            # return obj.tolist()
            ## TODO: convert NaT to None
            return [None if isinstance(rr, float) and isnan(rr) else rr for rr in obj]
        if isinstance(obj, float) and isnan(obj):
            return None
        if isinstance(obj, bytes):
            return obj.decode()
        return JSONEncoder.default(self, obj)


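# Example (illustrative): serialize a dict containing numpy arrays, NaN and timestamps.
# `values` is a hypothetical ndarray.
#
#   from json import dumps
#   dumps({'values': values, 'when': datetime.datetime.now()}, cls=NumpyEncoder)

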
# class GraphQlObjectTypeEncoder(JSONEncoder):
#     """
#     Encoder that can serialize basic Graphene ObjectTypes
#     """
#     def default(self, obj):
#         if isinstance(obj, datetime.datetime):
#             return obj.isoformat()
#         if isinstance(obj, datetime.date):
#             return obj.isoformat()
#         if isinstance(obj, ObjectType):
#             return obj.__dict__


# def json_response(data, body=None, status=200,
#                   reason=None, headers=None, content_type='application/json', check_circular=True,
#                   **kwargs):
#     text = dumps(data, cls=NumpyEncoder, separators=(',', ':'), check_circular=check_circular)
#     return web.Response(text=text, body=body, status=status, reason=reason,
#                         headers=headers, content_type=content_type, **kwargs)


# def get_join_with(cls, recursive=True):
#     """
#     Helper function for loading related tables with a Gino loader (left outer join)
#     Should work recursively...
#     Eg:
#     cls.load(**get_join_with(cls)).query.gino.all()
#     :param cls:
#     :return:
#     """
#     if hasattr(cls, 'dyn_join_with'):
#         joins = cls.dyn_join_with()
#     else:
#         joins = {}
#     if hasattr(cls, '_join_with'):
#         joins.update(cls._join_with)
#     if not recursive:
#         return joins
#     recursive_joins = {}
#     for name, join in joins.items():
#         more_joins = get_join_with(join)
#         if more_joins:
#             aliased = {name: join.alias() for name, join in more_joins.items()}
#             recursive_joins[name] = join.load(**aliased)
#         else:
#             recursive_joins[name] = join
#     return recursive_joins

# def get_joined_query(cls):
#     """
#     Helper function to get a query from a model with all the related tables loaded
#     :param cls:
#     :return:
#     """
#     return cls.load(**get_join_with(cls)).query


def timeit(f):
    """
    Decorator for timing *non async* methods (development tool for performance analysis)
    """
    @wraps(f)
    def wrap(*args, **kw):
        ts = time()
        result = f(*args, **kw)
        te = time()
        logging.debug('func:{} args:{}, {} took: {:2.4f} sec'.format(f.__name__, args, kw, te - ts))
        return result
    return wrap


def atimeit(func):
    """
    Decorator for timing *async* methods (development tool for performance analysis)
    """
    async def process(func, *args, **params):
        if asyncio.iscoroutinefunction(func):
            # logging.debug('this function is a coroutine: {}'.format(func.__name__))
            return await func(*args, **params)
        else:
            # logging.debug('this is not a coroutine')
            return func(*args, **params)

    async def helper(*args, **params):
        # logging.debug('{}.time'.format(func.__name__))
        start = time()
        result = await process(func, *args, **params)

        # Test normal function route...
        # result = await process(lambda *a, **p: print(*a, **p), *args, **params)

        logging.debug("{} {}".format(func.__name__, time() - start))
        return result

    return helper


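# Example (illustrative): timing a sync and an async function during development.
# The function names are hypothetical.
#
#   @timeit
#   def crunch(n):
#       return sum(range(n))
#
#   @atimeit
#   async def fetch():
#       await asyncio.sleep(0.1)
#
#   crunch(1_000_000)       # logs "func:crunch args:(1000000,), {} took: ... sec"
#   # await fetch()         # inside an event loop; logs "fetch <elapsed seconds>"

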
async def delete_df(df: pd.DataFrame, model: SQLModel):
|
|
"""
|
|
Delete all data in the model's table in the database
|
|
that matches data in the pandas dataframe.
|
|
"""
|
|
if len(df) == 0:
|
|
return
|
|
ids = df.reset_index()['id'].values
|
|
statement = delete(model).where(model.id.in_(ids))
|
|
async with db_session() as session:
|
|
await session.exec(statement)
|
|
await session.commit()
|
|
|
|
|
|
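# Example (illustrative): remove previously imported rows for a hypothetical model.
#
#   df = pd.DataFrame({'id': [12, 13, 14]})
#   # await delete_df(df, SomeModel)   # inside an async context; SomeModel is a SQLModel table

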
# async def upsert_df(df, model):
#     """
#     Insert or update all data in the model's table in the database
#     that's present in the pandas dataframe.
#     Use postgres insert ... on conflict update...
#     with a series of inserts, one row at a time.
#     For GeoDataFrame: the "geometry" column (df._geometry_column_name) is not honored
#     (yet). It's the caller's responsibility to have a proper column name
#     (typically "geom" in Gisaf models) with an EWKT or EWKB representation of the geometry.
#     """
#     ## See: https://stackoverflow.com/questions/33307250/postgresql-on-conflict-in-sqlalchemy
#
#     if len(df) == 0:
#         return df
#
#     table = model.__table__
#
#     ## Generate the 'upsert' statement, using fake values but defining columns
#     columns = {c.name for c in table.columns}
#     values = {col: None for col in df.columns if col in columns}
#     insrt_stmnt = insert(table).inline().values(values).returning(table.primary_key.columns)
#     df_columns = set(df.columns)
#     do_update_stmt = insrt_stmnt.on_conflict_do_update(
#         constraint=table.primary_key,
#         set_={
#             k.name: getattr(insrt_stmnt.excluded, k.name)
#             for k in insrt_stmnt.excluded
#             if k.name in df_columns and
#             k.name not in [c.name for c in table.primary_key.columns]
#         }
#     )
#     ## Filter and reorder the df columns
#     ## in order to match the order of columns in the insert statement
#     df = df[[col for col in do_update_stmt.compile().positiontup
#              if col in df_columns]].copy()
#
#     def convert_to_object(value):
#         """
#         Quick (but slow) and dirty: clean up values (nan, nat) for inserting to postgres via asyncpg
#         """
#         if isinstance(value, float) and isnan(value):
#             return None
#         elif pd.isna(value):
#             return None
#         else:
#             return value
#
#     # def encode_geometry(geometry):
#     #     if not hasattr(geometry, '__geo_interface__'):
#     #         raise TypeError('{g} does not conform to '
#     #                         'the geo interface'.format(g=geometry))
#     #     shape = shapely.geometry.asShape(geometry)
#     #     return shapely.wkb.dumps(shape)
#
#     # def decode_geometry(wkb):
#     #     return shapely.wkb.loads(wkb)
#
#     ## pks: list of dicts of primary keys
#     pks = {pk.name: [] for pk in table.primary_key.columns}
#     async with db.bind.raw_pool.acquire() as conn:
#         ## Set standard encoder for HSTORE, geometry
#         await conn.set_builtin_type_codec('hstore', codec_name='pg_contrib.hstore')
#
#         # await conn.set_type_codec(
#         #     'geometry',  # also works for 'geography'
#         #     encoder=encode_geometry,
#         #     decoder=decode_geometry,
#         #     format='binary',
#         # )
#         # await conn.set_type_codec(
#         #     'json',
#         #     encoder=json.dumps,
#         #     decoder=json.loads,
#         #     schema='pg_catalog'
#         # )
#         ## For a sequence of inserts:
#         insrt_stmnt_single = await conn.prepare(str(do_update_stmt))
#         async with conn.transaction():
#             for row in df.itertuples(index=False):
#                 converted_row = [convert_to_object(v) for v in row]
#                 returned = await insrt_stmnt_single.fetch(*converted_row)
#                 for returned_single in returned:
#                     for pk, value in returned_single.items():
#                         pks[pk].append(value)
#     ## Return a copy of the original df, with actual DB columns, data and the primary keys
#     for pk, values in pks.items():
#         df[pk] = values
#     return df

def postgres_upsert(table, conn, keys, data_iter):
    """
    Insertion method for pandas to_sql: PostgreSQL "INSERT ... ON CONFLICT DO UPDATE"
    on the table's primary key constraint.
    """
    # See https://stackoverflow.com/questions/61366664/how-to-upsert-pandas-dataframe-to-postgresql-table
    # Comment by @HopefullyThisHelps
    data = [dict(zip(keys, row)) for row in data_iter]
    insert_statement = insert(table.table).values(data)
    upsert_statement = insert_statement.on_conflict_do_update(
        constraint=f"{table.table.name}_pkey",
        set_={c.key: c for c in insert_statement.excluded},
    )
    conn.execute(upsert_statement)

async def upsert_df(df: pd.DataFrame, model: SQLModel, chunksize: int = 1000):
    """
    Insert or update (upsert) the dataframe's rows in the model's table,
    running the synchronous pandas to_sql in a process pool to avoid blocking the event loop.
    """
    if len(df) == 0:
        return df
    from functools import partial
    import concurrent.futures
    import asyncio
    loop = asyncio.get_running_loop()
    with concurrent.futures.ProcessPoolExecutor() as pool:
        await loop.run_in_executor(
            pool,
            partial(
                df.to_sql,
                model.__tablename__,
                conf.db.get_pg_url(),  # Cannot use sync_engine in run_in_executor
                                       # because it's not picklable
                schema=model.__table__.schema,  # type: ignore
                if_exists="append",
                index=False,
                method=postgres_upsert,
                chunksize=chunksize,
            ),
        )


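# Example (illustrative): bulk upsert a dataframe whose columns match a model's table.
#
#   df = pd.DataFrame({'id': [1, 2], 'name': ['a', 'b']})
#   # await upsert_df(df, SomeModel)   # inside an async context; SomeModel is a SQLModel table
#   # Rows with existing primary keys are updated, new ones are inserted.

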
# async def upsert_df(df, model):
#     """
#     Experiment with a port of pandas.io.sql for asyncpg: sql_async
#     """
#     from .sql_async import SQLDatabase, SQLTable
#
#     table = model.__table__
#
#     async with db.bind.raw_pool.acquire() as conn:
#         sql_db = SQLDatabase(conn)
#         result = await sql_db.to_sql(df, table.name, if_exists='replace', index=False)
#         return f'{len(df)} records imported (create or update)'


# async def upsert_df_bulk(df, model):
#     """
#     Insert or update all data in the pandas dataframe to the model's table in the database.
#     Use postgres insert ... on conflict update...
#     in a bulk insert with all data in one request.
#     """
#     raise NotImplementedError('Needs fix, use upsert_df instead')
#     ## See: https://stackoverflow.com/questions/33307250/postgresql-on-conflict-in-sqlalchemy
#     insrt_vals = df.to_dict(orient='records')
#
#     insrt_stmnt = insert(model.__table__).values(insrt_vals)
#     do_update_stmt = insrt_stmnt.on_conflict_do_update(
#         constraint=model.__table__.primary_key,
#         set_={
#             k.name: getattr(insrt_stmnt.excluded, k.name)
#             for k in insrt_stmnt.excluded
#             if k.name not in [c.name for c in model.__table__.primary_key.columns]
#         }
#     )
#     async with db.bind.raw_pool.acquire() as conn:
#         ## For a sequence of inserts:
#         insrt_stmnt_single = await conn.prepare(str(insert(model.__table__)))
#         async with conn.transaction():
#             ## TODO: flatten the insrt_vals so that they match the request's $n placeholders
#             await conn.execute(do_update_stmt, insrt_vals)


# def encode_geometry(geometry):
#     if not hasattr(geometry, '__geo_interface__'):
#         raise TypeError('{g} does not conform to '
#                         'the geo interface'.format(g=geometry))
#     shape = shapely.geometry.asShape(geometry)
#     geos.lgeos.GEOSSetSRID(shape._geom, conf.raw_survey['srid'])
#     return shapely.wkb.dumps(shape, include_srid=True)

# def decode_geometry(wkb):
#     return shapely.wkb.loads(wkb)


## XXX: dev notes
## What's the best way to save a dataframe to the DB?
## 1/ df.to_sql might have been an easy solution, but it doesn't support async operations
#
## 2/ Experiment with COPY (copy_records_to_table, see below): it doesn't update records.
# async with db.bind.raw_pool.acquire() as conn:
#     await conn.set_type_codec(
#         'geometry',  # also works for 'geography'
#         encoder=encode_geometry,
#         decoder=decode_geometry,
#         format='binary',
#     )
#     async with conn.transaction():
#         ## See https://github.com/MagicStack/asyncpg/issues/245
#         s = await conn.copy_records_to_table(
#             model.__table__.name,
#             schema_name=model.__table__.schema,
#             records=[tuple(x) for x in gdf_for_db.values],
#             columns=list(gdf_for_db.columns),
#             timeout=10
#         )
#
## 3/ SQLAlchemy/asyncpg multiple inserts, then updates
### Build SQL statements
# insert = db.insert(model.__table__).compile()
# update = db.update(model.__table__).compile()
### Reorder the columns of the dataframe
# gdf_for_db = gdf_for_db[insert.positiontup]
### Find the records whose id is already present in the DB, and segregate the df
# existing_records = await model.get_df(with_only_columns=['id'])
# gdf_insert = gdf_for_db[~gdf_for_db.id.isin(existing_records.id)]
# gdf_update = gdf_for_db[gdf_for_db.id.isin(existing_records.id)]
# async with db.bind.raw_pool.acquire() as conn:
#     await conn.executemany(insert.string, [tuple(x) for x in gdf_insert.values])
#     await conn.executemany(update.string, [tuple(x) for x in gdf_update.values])
##
## 4/ Fall back to gino. Bad luck, there's no equivalent to "merge", so the strategy is:
##    - get all record ids in the DB
##    - build the set of records that need update, and others that need insert
##    - do these operations (possibly in bulk)
#
## 5/ Make a utility lib for other use cases...