from uuid import uuid1
from io import BytesIO
from asyncio import create_task
from json import loads, dumps
from pickle import dump, HIGHEST_PROTOCOL, loads as loads_pickle
from time import time
import logging

import pandas as pd
import geopandas as gpd  # type: ignore[import-untyped]
from asyncpg import connect
from asyncpg.connection import Connection
from asyncpg.exceptions import UndefinedTableError, InterfaceError
from sqlalchemy import text
from redis import asyncio as aioredis

from gisaf.config import conf
# from gisaf.models.live import LiveModel
from gisaf.models.map_bases import MaplibreStyle
from gisaf.utils import (SHAPELY_TYPE_TO_MAPBOX_TYPE, DEFAULT_MAPBOX_LAYOUT,
                         DEFAULT_MAPBOX_PAINT, gisTypeSymbolMap)
from gisaf.registry import registry
#from .models.geom import GeomGroup, GeomModel
from gisaf.models.geo_models_base import LiveGeoModel
from gisaf.database import db_session

logger = logging.getLogger(__name__)

ttag_function = """
CREATE OR REPLACE FUNCTION gisaf.ttag()
RETURNS trigger LANGUAGE plpgsql AS
$$
BEGIN
    PERFORM(select pg_notify('gisaf_ttag', TG_TABLE_SCHEMA || '.' || TG_TABLE_NAME));
    RETURN NULL;
END;
$$;
"""

ttag_drop_trigger = 'DROP TRIGGER IF EXISTS gisaf_ttag ON "{schema}"."{table}";'

ttag_create_trigger = """
CREATE TRIGGER gisaf_ttag
AFTER INSERT OR UPDATE OR DELETE
ON "{schema}"."{table}"
FOR EACH STATEMENT
EXECUTE FUNCTION gisaf.ttag();
"""

## From https://dba.stackexchange.com/questions/121717/get-triggers-table-names-in-postgresql
get_all_triggers = """
SELECT trg.tgname AS trigger_name,
       CASE trg.tgtype::INTEGER & 66
           WHEN 2 THEN 'BEFORE'
           WHEN 64 THEN 'INSTEAD OF'
           ELSE 'AFTER'
       END AS trigger_type,
       CASE trg.tgtype::INTEGER & cast(28 AS INT2)
           WHEN 16 THEN 'UPDATE'
           WHEN 8 THEN 'DELETE'
           WHEN 4 THEN 'INSERT'
           WHEN 20 THEN 'INSERT, UPDATE'
           WHEN 28 THEN 'INSERT, UPDATE, DELETE'
           WHEN 24 THEN 'UPDATE, DELETE'
           WHEN 12 THEN 'INSERT, DELETE'
       END AS trigger_event,
       ns.nspname || '.' || tbl.relname AS trigger_table,
       obj_description(trg.oid) AS remarks,
       CASE WHEN trg.tgenabled = 'O' THEN 'ENABLED' ELSE 'DISABLED' END AS status,
       CASE trg.tgtype::INTEGER & 1
           WHEN 1 THEN 'ROW'::TEXT
           ELSE 'STATEMENT'::TEXT
       END AS trigger_level,
       n.nspname || '.' || proc.proname AS function_name
FROM pg_trigger trg
JOIN pg_proc proc ON proc.oid = trg.tgfoid
JOIN pg_catalog.pg_namespace n ON n.oid = proc.pronamespace
JOIN pg_class tbl ON trg.tgrelid = tbl.oid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE trg.tgname not like 'RI_ConstraintTrigger%'
  AND trg.tgname not like 'pg_sync_pg%';
"""
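## A minimal sketch (for illustration only; the DSN below is an assumption) of
## the notification chain defined above: once gisaf.ttag() and a gisaf_ttag
## trigger are installed, any INSERT/UPDATE/DELETE on the table emits a NOTIFY
## on the 'gisaf_ttag' channel with '<schema>.<table>' as payload:
#
#   import asyncio
#   from asyncpg import connect
#
#   async def watch_changes(dsn='postgresql://localhost/gisaf'):
#       conn = await connect(dsn)
#       await conn.add_listener(
#           'gisaf_ttag',
#           lambda _conn, _pid, _channel, payload: print('changed:', payload))
#       await asyncio.sleep(3600)  ## writes to triggered tables print here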
class RedisError(Exception):
    pass


class Store:
    """
    Store for redis:
    - redis: RedisConnection
    - pub (/sub) connections
    """
    asyncpg_conn: Connection

    async def setup(self, with_registry=False):
        """
        Setup the live service for the main Gisaf application:
        - Create connection for the publishers
        - Create connection for redis listeners (websocket service)
        """
        await self.create_connections()
        if with_registry:
            await self.get_live_layer_defs()

    async def create_connections(self):
        """
        Create the connection for the publisher
        XXX: this should be renamed to be explicit
        """
        self.redis = aioredis.from_url(conf.gisaf_live.redis)
        self.pub = self.redis.pubsub()
        ## Identify this client to the Redis server with a unique name
        self.uuid = str(uuid1())
        await self.redis.client_setname(self.uuid)

    def get_json_channel(self, store_name):
        """
        Name of the Redis channel for the json representation
        """
        return f'{store_name}:json'

    def get_gdf_channel(self, store_name):
        """
        Name of the Redis channel for the source gdf, in pickle format
        """
        return f'{store_name}:gdf'

    def get_layer_def_channel(self, store_name):
        """
        Name of the Redis channel for the layer definition
        """
        return f'{store_name}:layer_def'

    def get_mapbox_layout_channel(self, store_name):
        """
        Name of the Redis channel for the mapbox layout style definition
        """
        return f'{store_name}:mapbox_layout'

    def get_mapbox_paint_channel(self, store_name):
        """
        Name of the Redis channel for the mapbox paint style definition
        """
        return f'{store_name}:mapbox_paint'

    async def store_json(self, model, geojson, **kwargs):
        """
        Store the json representation of the gdf for caching
        """
        ## Save in Redis channel
        channel = self.get_json_channel(model.get_store_name())
        await self.redis.set(channel, geojson)
        ## XXX: publish to websocket?
        #await self.pub.publish(self.get_json_channel(store_name), data)

    async def publish(self, *args, **kwargs):
        """
        Wrapper for publishing to the redis pubsub channel
        """
        return await self.redis.publish(*args, **kwargs)

    async def publish_gdf(self, live_name, gdf, **kwargs):
        """
        Create or update the live layer, store in redis.
        Additionally, publish to the channel for websocket live updates
        to ws_clients
        """
        if gdf is None:
            gdf = gpd.GeoDataFrame(data={'geom': []}, geometry='geom')  # type: ignore
        if isinstance(gdf.index, pd.MultiIndex):
            raise ValueError('Gisaf live does not accept dataframes with multi index')
        return await self._store_live_to_redis(live_name, gdf, **kwargs)
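    ## Example (a sketch; the layer name and coordinates are made up) of
    ## feeding a live layer from a data collector, using the module-level
    ## `store` singleton defined at the bottom of this file:
    #
    #   import geopandas as gpd
    #   from shapely.geometry import Point
    #
    #   gdf = gpd.GeoDataFrame({'geom': [Point(77.6, 12.0)]},
    #                          geometry='geom', crs=4326)
    #   await store.publish_gdf('vehicle_tracker', gdf, z_index=550)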
    async def _store_live_to_redis(self, live_name, gdf,
                                   properties=None,
                                   mapbox_paint=None, mapbox_layout=None,
                                   info=None, viewable_role=None, z_index=499,
                                   description='', status='L',
                                   symbol=None, color=None, attribution=None,
                                   **kwargs):
        """
        Store and publish the live layer data and metadata to redis channels
        """
        store_name = f'live:{live_name}'
        ## Filter out empty geometries
        gdf = gdf[gdf[gdf.geometry.name].notnull()]
        ## Reproject eventually
        if 'status' not in gdf.columns:
            gdf['status'] = status
        if 'popup' not in gdf.columns:
            gdf['popup'] = 'Live: ' + live_name + ' #' + gdf.index.astype('U')
        if len(gdf) > 0:
            gdf = gdf.to_crs(conf.crs.geojson)
            gis_type = gdf.geom_type.iloc[0]
        else:
            gis_type = 'Point'  ## FIXME: cannot be inferred from the gdf?
        mapbox_type = SHAPELY_TYPE_TO_MAPBOX_TYPE.get(gis_type, None)
        if not mapbox_paint:
            ## Copy, to avoid mutating the module-level default styles
            mapbox_paint = dict(DEFAULT_MAPBOX_PAINT.get(mapbox_type, {}))
            if color:
                if mapbox_type == 'symbol':
                    mapbox_paint['text-color'] = color
        if not mapbox_layout:
            mapbox_layout = dict(DEFAULT_MAPBOX_LAYOUT.get(mapbox_type, {}))
            if symbol:
                mapbox_layout['text-field'] = symbol
        if not symbol:
            symbol = gisTypeSymbolMap.get(gis_type, '\ue02e')
        if properties is None:
            properties = []
        ## Add a column for json representation
        columns = {'status', 'popup', gdf.geometry.name, 'store', 'id'}
        geojson = gdf[list(columns.intersection(gdf.columns).union(properties))].to_json()
        ## Publish to websocket
        await self.redis.publish(self.get_json_channel(store_name), geojson)
        layer_def_data = dumps({
            'store': store_name,
            'z_index': z_index,
            'count': len(gdf),
            'type': mapbox_type,
            'gis_type': gis_type,
            'symbol': symbol,
            'name': live_name,
            'description': description,
            'viewable_role': viewable_role,
            'attribution': attribution,
            'is_live': True,
        })
        ## Pickle the dataframe
        with BytesIO() as buf:
            dump(gdf, buf, protocol=HIGHEST_PROTOCOL)
            buf.seek(0)
            await self.redis.set(self.get_gdf_channel(store_name), buf.read())
        ## Save in Redis channels
        await self.redis.set(self.get_json_channel(store_name), geojson)
        await self.redis.set(
            self.get_mapbox_layout_channel(store_name),
            dumps(mapbox_layout)
        )
        await self.redis.set(self.get_mapbox_paint_channel(store_name), dumps(mapbox_paint))
        await self.redis.set(self.get_layer_def_channel(store_name), layer_def_data)
        ## Update the layers/stores registry
        ## XXX: Commenting out the update of live layers:
        ## this should be triggered from a redis listener
        #await self.get_live_layer_defs()
        return geojson

    async def get_listener(self, channel):
        return await self.pub.psubscribe(channel)

    async def remove_layer(self, store_name):
        """
        Remove the layer from Gisaf Live (channel)
        """
        await self.redis.delete(self.get_json_channel(store_name))
        await self.redis.delete(self.get_layer_def_channel(store_name))
        await self.redis.delete(self.get_gdf_channel(store_name))
        await self.redis.delete(self.get_mapbox_layout_channel(store_name))
        await self.redis.delete(self.get_mapbox_paint_channel(store_name))
        ## Update the layers/stores registry
        await self.get_live_layer_defs()

    async def has_channel(self, store_name):
        return len(await self.redis.keys(self.get_json_channel(store_name))) > 0

    async def get_live_layer_def_channels(self):
        try:
            return [k.decode() for k in await self.redis.keys('live:*:layer_def')]
        except aioredis.exceptions.ConnectionError as err:
            raise RedisError('Cannot use Redis server, please restart Gisaf') from err

    async def get_layer_def(self, store_name):
        return loads(await self.redis.get(self.get_layer_def_channel(store_name)))

    async def get_live_layer_defs(self):  # -> list[LiveGeoModel]:
        registry.geom_live_defs = {}
        for channel in sorted(await self.get_live_layer_def_channels()):
            model_info = loads(await self.redis.get(channel))
            registry.geom_live_defs[model_info['store']] = model_info
        registry.update_live_layers()

    async def get_maplibre_style(self, store_name) -> MaplibreStyle:
        """
        Get the maplibre (mapbox) style, ie. the paint and layout definitions,
        for the given store name
        """
        paint_raw = await self.redis.get(self.get_mapbox_paint_channel(store_name))
        layout_raw = await self.redis.get(self.get_mapbox_layout_channel(store_name))
        return MaplibreStyle(
            paint=loads(paint_raw.decode()),
            layout=loads(layout_raw.decode()),
        )
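    ## For reference, a live layer named 'vehicle_tracker' (a made-up example)
    ## occupies these Redis keys, all set by _store_live_to_redis and named by
    ## the get_*_channel helpers above:
    #
    #   live:vehicle_tracker:json            GeoJSON FeatureCollection
    #   live:vehicle_tracker:gdf             pickled GeoDataFrame
    #   live:vehicle_tracker:layer_def       json metadata (store, type, symbol...)
    #   live:vehicle_tracker:mapbox_layout   json layout style
    #   live:vehicle_tracker:mapbox_paint    json paint style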
    async def get_layer_as_json(self, store_name):
        """
        Get the json from the store name (layer_def)
        """
        return await self.redis.get(self.get_json_channel(store_name))

    async def get_gdf(self, store_name, reproject=False):
        raw_data = await self.redis.get(self.get_gdf_channel(store_name))
        if raw_data is None:
            raise RedisError(f'Cannot get {store_name}: no data')
        try:
            gdf = loads_pickle(raw_data)
        except Exception as err:
            logger.exception(err)
            raise RedisError(f'Cannot get {store_name}: pickle error from redis store: '
                             f'{err.__class__.__name__}, {err.args[0]}')
        if len(gdf) == 0:
            raise RedisError(f'Cannot get {store_name}: empty')
        if reproject:
            gdf.to_crs(conf.crs['for_proj'], inplace=True)
        return gdf

    async def get_feature_info(self, store_name, id):
        gdf = await self.get_gdf(store_name)
        ## FIXME: requires the gdf to have an integer index, used as feature['id'] on the map
        return gdf.loc[int(id)]

    async def set_ttag(self, store_name, now):
        """
        Set the ttag for the store as 'now'
        """
        #logger.debug(f'ttag {store_name} at {now}')
        await self.redis.set(f'ttag:{store_name}', now)

    def create_task_store_ttag(self, connection, pid, channel, store_name):
        """
        Postgres/asyncpg listener for the trigger on data change.
        A task is created because this function is not asynchronous.
        """
        if store_name in registry.stores:
            create_task(self.set_ttag(store_name, time()))
        else:
            logger.warning(f'Notify received for a non-existing store: {store_name}')

    async def get_ttag(self, store_name):
        """
        Get the ttag for the given store.
        ttag is the time stamp of the last modification of the store.
        If no ttag is known, create one as now.
        """
        ttag = await self.redis.get(f'ttag:{store_name}')
        if ttag:
            return ttag.decode()
        else:
            ## No ttag: Gisaf doesn't know when the last update was,
            ## ie. it was restarted and the ttags are cleared on startup.
            ## Set a ttag now, using the current epoch time in seconds in hex,
            ## double quoted, with a W/ prefix as it's basically a weak ETag
            weak_now_hex = f'W/"{hex(int(time()))[2:]}"'
            await self.set_ttag(store_name, weak_now_hex)
            return weak_now_hex

    async def delete_all_ttags(self) -> None:
        """
        Delete all ttags in redis
        """
        ## Equivalent command line: redis-cli del (redis-cli --scan --pattern 'ttag:*')
        keys = await self.redis.keys('ttag:*')
        if keys:
            await self.redis.delete(*keys)
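    ## Sketch of how a ttag is meant to be consumed (the handler shown is an
    ## assumption for illustration, not Gisaf's actual endpoint): since the
    ## ttag is a weak ETag, an HTTP handler can answer conditional requests
    ## without rebuilding the layer:
    #
    #   ttag = await store.get_ttag(store_name)
    #   if request.headers.get('If-None-Match') == ttag:
    #       return Response(status_code=304)  ## client cache is still valid
    #   response.headers['ETag'] = ttag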
    async def _setup_db_cache_system(self) -> None:
        """
        Setup the caching system:
        - clear all Redis store at startup
        - make sure the triggers and the "change" (insert, update, delete)
          event emitter function are setup on the database server
        - listen to the DB event emitter: setup a callback function
        """
        ## Delete all the ttags, for safety:
        ## eg. the database was changed while Gisaf wasn't running,
        ## so the redis store wasn't updated
        await self.delete_all_ttags()
        async with db_session() as session:
            ## Create the function in the database
            await session.exec(text(ttag_function))
            ## Create DB triggers on the tables of the models
            all_triggers_resp = await session.exec(text(get_all_triggers))
            all_triggers = all_triggers_resp.mappings().all()
            stores_with_trigger = {t['trigger_table'] for t in all_triggers
                                   if t['trigger_name'] == 'gisaf_ttag'}
            missing_trigger_tables = set(registry.geom).difference(stores_with_trigger)
            # model: SQLModel = registry.stores.loc[store_name, 'model']
            if len(missing_trigger_tables) > 0:
                logger.info('Create Postgres modification triggers for '
                            f'{len(missing_trigger_tables)} tables')
                for store_name in missing_trigger_tables:
                    ## XXX: TODO: See https://stackoverflow.com/questions/7888846/trigger-in-sqlachemy
                    model = registry.geom[store_name]
                    try:
                        await session.exec(text(
                            ttag_create_trigger.format(
                                schema=model.__table__.schema,
                                table=model.__table__.name)
                        ))
                    except UndefinedTableError:
                        logger.warning(f'table {store_name} does not exist in '
                                       'the database: skip modification trigger')
            ## Setup triggers on Category and Qml, for Mapbox layer styling
            for schema, table in (('gisaf_map', 'qml'), ('gisaf_survey', 'category')):
                triggers = [t for t in all_triggers
                            if t['trigger_name'] == 'gisaf_ttag'
                            and t['trigger_table'] == f'{schema}.{table}']
                if len(triggers) == 0:
                    await session.exec(text(
                        ttag_create_trigger.format(schema=schema, table=table)
                    ))
        ## Listen: define the callback function
        self.asyncpg_conn = await connect(conf.db.get_pg_url())
        await self.asyncpg_conn.add_listener('gisaf_ttag', self.create_task_store_ttag)

    async def _close_permanent_db_connection(self):
        """
        Called at server shutdown: remove the listener and close the connection
        """
        try:
            await self.asyncpg_conn.remove_listener(
                'gisaf_ttag', self.create_task_store_ttag)
        except InterfaceError as err:
            logger.warning('Cannot remove asyncpg listener in '
                           f'_close_permanent_db_connection: {err}')
        await self.asyncpg_conn.close()


async def setup_redis():
    global store
    await store.setup(with_registry=True)


async def setup_redis_cache():
    global store
    await store._setup_db_cache_system()


async def shutdown_redis():
    global store
    if not hasattr(store, 'asyncpg_conn'):
        return
    await store._close_permanent_db_connection()


store = Store()
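## Typical wiring (a sketch; the actual hook names depend on the application
## framework): the module-level helpers above are meant to be called from the
## app's startup/shutdown events, in this order:
#
#   await setup_redis()        ## connect to Redis, load the live layer defs
#   await setup_redis_cache()  ## install the DB triggers and NOTIFY listener
#   ...                        ## serve requests
#   await shutdown_redis()     ## remove the listener, close the connection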