from uuid import uuid1
from io import BytesIO
from asyncio import create_task
from json import loads, dumps
from pickle import dump, HIGHEST_PROTOCOL, loads as loads_pickle
from time import time
import logging

import pandas as pd
import geopandas as gpd
from asyncpg.exceptions import UndefinedTableError, InterfaceError
from redis import asyncio as aioredis

from .config import conf
## NOTE: assumed import path for the application's database wrapper, used in
## Store._setup_db_cache_system to acquire a permanent connection
from .database import db
# from gisaf.models.live import LiveModel
from .utils import (SHAPELY_TYPE_TO_MAPBOX_TYPE, DEFAULT_MAPBOX_LAYOUT,
                    DEFAULT_MAPBOX_PAINT, gisTypeSymbolMap)
from .registry import registry
#from .models.geom import GeomGroup, GeomModel
from .models.geo_models_base import LiveGeoModel

logger = logging.getLogger(__name__)

ttag_function = """
CREATE OR REPLACE FUNCTION gisaf.ttag() RETURNS trigger LANGUAGE plpgsql AS $$
BEGIN
    PERFORM(SELECT pg_notify('gisaf_ttag', TG_TABLE_SCHEMA || '.' || TG_TABLE_NAME));
    RETURN NULL;
END;
$$;
"""

ttag_drop_trigger = 'DROP TRIGGER IF EXISTS gisaf_ttag ON "{schema}"."{table}";'

ttag_create_trigger = """
CREATE TRIGGER gisaf_ttag
AFTER INSERT OR UPDATE OR DELETE ON "{schema}"."{table}"
FOR EACH STATEMENT EXECUTE FUNCTION gisaf.ttag();
"""
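
## Example (sketch): once gisaf.ttag() and a gisaf_ttag trigger are installed
## on a table, every data-modifying statement on it emits a NOTIFY on the
## 'gisaf_ttag' channel, with the qualified table name as payload, eg. in psql:
##
##   LISTEN gisaf_ttag;
##   UPDATE gisaf_survey.category SET ...;
##   -- Asynchronous notification "gisaf_ttag" with payload
##   -- "gisaf_survey.category" received from server process ...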

## From https://dba.stackexchange.com/questions/121717/get-triggers-table-names-in-postgresql
get_all_triggers = """
SELECT
    trg.tgname AS trigger_name,
    CASE trg.tgtype::INTEGER & 66
        WHEN 2 THEN 'BEFORE'
        WHEN 64 THEN 'INSTEAD OF'
        ELSE 'AFTER'
    END AS trigger_type,
    CASE trg.tgtype::INTEGER & CAST(28 AS INT2)
        WHEN 16 THEN 'UPDATE'
        WHEN 8 THEN 'DELETE'
        WHEN 4 THEN 'INSERT'
        WHEN 20 THEN 'INSERT, UPDATE'
        WHEN 28 THEN 'INSERT, UPDATE, DELETE'
        WHEN 24 THEN 'UPDATE, DELETE'
        WHEN 12 THEN 'INSERT, DELETE'
    END AS trigger_event,
    ns.nspname || '.' || tbl.relname AS trigger_table,
    obj_description(trg.oid) AS remarks,
    CASE WHEN trg.tgenabled = 'O' THEN 'ENABLED' ELSE 'DISABLED' END AS status,
    CASE trg.tgtype::INTEGER & 1
        WHEN 1 THEN 'ROW'::TEXT
        ELSE 'STATEMENT'::TEXT
    END AS trigger_level,
    n.nspname || '.' || proc.proname AS function_name
FROM pg_trigger trg
JOIN pg_proc proc ON proc.oid = trg.tgfoid
JOIN pg_catalog.pg_namespace n ON n.oid = proc.pronamespace
JOIN pg_class tbl ON trg.tgrelid = tbl.oid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE trg.tgname NOT LIKE 'RI_ConstraintTrigger%'
    AND trg.tgname NOT LIKE 'pg_sync_pg%';
"""


class RedisError(Exception):
    pass


class Store:
    """
    Store for Redis:
    - redis: the Redis connection
    - pub(/sub) connections
    """
    async def setup(self):
        """
        Set up the live service for the main Gisaf application:
        - create the connection for the publishers
        - create the connection for the Redis listeners (websocket service)
        """
        await self.create_connections()
        await self.get_live_layer_defs()

    async def create_connections(self):
        """
        Create the connection for the publisher
        XXX: this should be renamed to be explicit
        """
        self.redis = aioredis.from_url(conf.gisaf_live.redis)
        self.pub = self.redis.pubsub()
        ## Identify this client to the Redis server with a unique name
        self.uuid = str(uuid1())
        await self.redis.client_setname(self.uuid)

    def get_json_channel(self, store_name):
        """
        Name of the Redis channel for the json representation
        """
        return f'{store_name}:json'

    def get_gdf_channel(self, store_name):
        """
        Name of the Redis channel for the source gdf, in pickle format
        """
        return f'{store_name}:gdf'

    def get_layer_def_channel(self, store_name):
        """
        Name of the Redis channel for the layer definition
        """
        return f'{store_name}:layer_def'

    def get_mapbox_layout_channel(self, store_name):
        """
        Name of the Redis channel for the mapbox layout style definition
        """
        return f'{store_name}:mapbox_layout'

    def get_mapbox_paint_channel(self, store_name):
        """
        Name of the Redis channel for the mapbox paint style definition
        """
        return f'{store_name}:mapbox_paint'

    async def store_json(self, model, geojson, **kwargs):
        """
        Store the json representation of the gdf for caching
        """
        ## Save in Redis channel
        channel = self.get_json_channel(model.get_store_name())
        await self.redis.set(channel, geojson)
        ## XXX: publish to websocket?
        #await self.pub.publish(self.get_json_channel(store_name), data)

    async def publish(self, *args, **kwargs):
        """
        Wrapper for publishing to the Redis pubsub channel
        """
        return await self.redis.publish(*args, **kwargs)

    async def publish_gdf(self, live_name, gdf, **kwargs):
        """
        Create or update the live layer and store it in Redis.
        Additionally, publish to the channel for websocket live updates
        to ws_clients
        """
        if gdf is None:
            gdf = gpd.GeoDataFrame(data={'geom': []}, geometry='geom')
        if isinstance(gdf.index, pd.MultiIndex):
            raise ValueError('Gisaf live does not accept dataframes with multi index')
        return await self._store_live_to_redis(live_name, gdf, **kwargs)
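
    ## Usage sketch (illustrative, assuming an async context): publish a
    ## GeoDataFrame `gdf` through the module-level `store` singleton defined
    ## at the bottom of this file; the layer name and style are made up:
    ##
    ##   await store.publish_gdf(
    ##       'vehicle_tracker', gdf,
    ##       mapbox_paint={'circle-color': 'red', 'circle-radius': 6},
    ##       description='Live vehicle positions',
    ##   )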

    async def _store_live_to_redis(self, live_name, gdf,
                                   properties=None,
                                   mapbox_paint=None, mapbox_layout=None,
                                   info=None, viewable_role=None,
                                   z_index=499, description='',
                                   status='L', symbol=None, color=None,
                                   attribution=None,
                                   **kwargs):
        """
        Store and publish the live layer data and metadata to Redis channels
        """
        store_name = f'live:{live_name}'
        ## Filter out empty geometries
        gdf = gdf[gdf[gdf.geometry.name].notnull()]
        if 'status' not in gdf.columns:
            gdf['status'] = status
        if 'popup' not in gdf.columns:
            gdf['popup'] = 'Live: ' + live_name + ' #' + gdf.index.astype('U')
        if len(gdf) > 0:
            ## Reproject if needed
            gdf = gdf.to_crs(conf.crs.geojson)
            gis_type = gdf.geom_type.iloc[0]
        else:
            gis_type = 'Point'  ## FIXME: cannot be inferred from the gdf?
        mapbox_type = SHAPELY_TYPE_TO_MAPBOX_TYPE.get(gis_type, None)
        if not mapbox_paint:
            mapbox_paint = DEFAULT_MAPBOX_PAINT.get(mapbox_type, {})
            if color and mapbox_type == 'symbol':
                mapbox_paint['text-color'] = color
        if not mapbox_layout:
            mapbox_layout = DEFAULT_MAPBOX_LAYOUT.get(mapbox_type, {})
            if symbol:
                mapbox_layout['text-field'] = symbol
        if not symbol:
            symbol = gisTypeSymbolMap.get(gis_type, '\ue02e')
        if properties is None:
            properties = []
        ## Columns to include in the json representation
        columns = {'status', 'popup', gdf.geometry.name, 'store', 'id'}
        geojson = gdf[
            list(columns.intersection(gdf.columns).union(properties))
        ].to_json()
        ## Publish to websocket
        await self.redis.publish(self.get_json_channel(store_name), geojson)
        layer_def_data = dumps({
            'store': store_name,
            'z_index': z_index,
            'count': len(gdf),
            'mapbox_type': mapbox_type,
            'gis_type': gis_type,
            'symbol': symbol,
            'name': live_name,
            'description': description,
            'viewable_role': viewable_role,
            'attribution': attribution,
            'is_live': True,
        })
        ## Pickle the dataframe
        with BytesIO() as buf:
            dump(gdf, buf, protocol=HIGHEST_PROTOCOL)
            buf.seek(0)
            await self.redis.set(self.get_gdf_channel(store_name), buf.read())
        ## Save in Redis channels
        await self.redis.set(self.get_json_channel(store_name), geojson)
        await self.redis.set(
            self.get_mapbox_layout_channel(store_name),
            dumps(mapbox_layout)
        )
        await self.redis.set(
            self.get_mapbox_paint_channel(store_name),
            dumps(mapbox_paint)
        )
        await self.redis.set(self.get_layer_def_channel(store_name), layer_def_data)
        ## Update the layers/stores registry
        await self.get_live_layer_defs()
        return geojson

    async def get_listener(self, channel):
        return await self.pub.psubscribe(channel)

    async def remove_layer(self, store_name):
        """
        Remove the layer from Gisaf Live (channel)
        """
        await self.redis.delete(self.get_json_channel(store_name))
        await self.redis.delete(self.get_layer_def_channel(store_name))
        await self.redis.delete(self.get_gdf_channel(store_name))
        await self.redis.delete(self.get_mapbox_layout_channel(store_name))
        await self.redis.delete(self.get_mapbox_paint_channel(store_name))
        ## Update the layers/stores registry
        await self.get_live_layer_defs()

    async def has_channel(self, store_name):
        return len(await self.redis.keys(self.get_json_channel(store_name))) > 0

    async def get_live_layer_def_channels(self):
        try:
            return [k.decode() for k in await self.redis.keys('live:*:layer_def')]
        except aioredis.exceptions.ConnectionError as err:
            raise RedisError('Cannot use Redis server, please restart Gisaf') from err

    async def get_layer_def(self, store_name):
        return loads(await self.redis.get(self.get_layer_def_channel(store_name)))

    async def get_live_layer_defs(self):  # -> list[LiveGeoModel]:
        registry.geom_live_defs = {}
        for channel in sorted(await self.get_live_layer_def_channels()):
            model_info = loads(await self.redis.get(channel))
            registry.geom_live_defs[model_info['store']] = model_info
        registry.update_live_layers()

    async def get_mapbox_style(self, store_name):
        """
        Get the mapbox style (paint and layout) for the given store name
        """
        paint = await self.redis.get(self.get_mapbox_paint_channel(store_name))
        layout = await self.redis.get(self.get_mapbox_layout_channel(store_name))
        style = {}
        if paint is not None:
            style['paint'] = paint.decode()
        if layout is not None:
            style['layout'] = layout.decode()
        return style

    async def get_layer_as_json(self, store_name):
        """
        Get the json (GeoJSON) representation for the given store name
        """
        return await self.redis.get(self.get_json_channel(store_name))
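
    ## Usage sketch (illustrative): reading a published layer back; store
    ## names are prefixed with 'live:' by _store_live_to_redis:
    ##
    ##   style = await store.get_mapbox_style('live:vehicle_tracker')
    ##   geojson = await store.get_layer_as_json('live:vehicle_tracker')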

    async def get_gdf(self, store_name, reproject=False):
        raw_data = await self.redis.get(self.get_gdf_channel(store_name))
        if raw_data is None:
            raise RedisError(f'Cannot get {store_name}: no data')
        try:
            gdf = loads_pickle(raw_data)
        except Exception as err:
            logger.exception(err)
            raise RedisError(
                f'Cannot get {store_name}: pickle error from redis store: '
                f'{err.__class__.__name__}, {err.args[0]}'
            )
        if len(gdf) == 0:
            raise RedisError(f'Cannot get {store_name}: empty')
        if reproject:
            gdf.to_crs(conf.crs['for_proj'], inplace=True)
        return gdf

    async def get_feature_info(self, store_name, id):
        gdf = await self.get_gdf(store_name)
        ## FIXME: requires the gdf to have an integer index, used as feature['id'] on the map
        return gdf.loc[int(id)]

    async def set_ttag(self, store_name, now):
        """
        Set the ttag for the store as 'now'
        """
        #logger.debug(f'ttag {store_name} at {now}')
        await self.redis.set(f'ttag:{store_name}', now)

    def create_task_store_ttag(self, connection, pid, channel, store_name):
        """
        Postgres/asyncpg listener for the trigger on data change.
        A task is created because this callback is not asynchronous.
        """
        create_task(self.set_ttag(store_name, time()))

    async def get_ttag(self, store_name):
        """
        Get the ttag for the given store.
        The ttag is the time stamp of the last modification of the store.
        If no ttag is known, create one as now.
        """
        ttag = await self.redis.get(f'ttag:{store_name}')
        if ttag:
            return ttag.decode()
        else:
            ## No ttag: Gisaf doesn't know when the last update was,
            ## ie. it was restarted and the ttags are cleared on startup.
            ## Set a ttag now, using the current epoch time in seconds in hex,
            ## double quoted, with a W/ prefix, as it's basically a weak ETag
            weak_now_hex = f'W/"{hex(int(time()))[2:]}"'
            await self.set_ttag(store_name, weak_now_hex)
            return weak_now_hex

    async def delete_all_ttags(self):
        """
        Delete all ttags in Redis
        """
        ## Equivalent command line: redis-cli del (redis-cli --scan --pattern 'ttag:*')
        keys = await self.redis.keys('ttag:*')
        if keys:
            await self.redis.delete(*keys)
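
    ## Worked example (illustrative values): with time() == 1700000000,
    ## hex(1700000000) == '0x6553f100', so the generated ttag is
    ## 'W/"6553f100"'. A web handler can send this as the ETag response
    ## header and compare it against the If-None-Match request header to
    ## answer 304 Not Modified when the store hasn't changed.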

    async def _setup_db_cache_system(self):
        """
        Set up the caching system:
        - clear all Redis stores at startup
        - make sure the triggers and the "change" (insert, update, delete)
          event emitter function are set up on the database server
        - listen to the DB event emitter: set up a callback function
        """
        ## Set up the function and triggers on tables.
        ## Keep the connection alive: don't use a "with" block.
        ## It needs to be closed correctly: see _close_permanent_db_connection
        self._permanent_conn = await db.acquire()
        self._permanent_raw_conn = await self._permanent_conn.get_raw_connection()
        ## Create the function in the database
        await self._permanent_raw_conn.execute(ttag_function)
        ## Delete all the ttags, for safety:
        ## eg. the database was changed while Gisaf wasn't running,
        ## so the redis store wasn't updated
        await store.delete_all_ttags()
        ## Create DB triggers on the tables of the models
        all_triggers = await self._permanent_raw_conn.fetch(get_all_triggers)
        stores_with_trigger = {t['trigger_table'] for t in all_triggers
                               if t['trigger_name'] == 'gisaf_ttag'}
        missing_trigger_tables = set(registry.geom).difference(stores_with_trigger)
        if len(missing_trigger_tables) > 0:
            logger.info('Create Postgres modification triggers for '
                        f'{len(missing_trigger_tables)} tables')
            for store_name in missing_trigger_tables:
                model = registry.geom[store_name]
                try:
                    await self._permanent_raw_conn.execute(ttag_create_trigger.format(
                        schema=model.__table__.schema,
                        table=model.__table__.name))
                except UndefinedTableError:
                    logger.warning(f'table {store_name} does not exist in '
                                   'the database: skip modification trigger')
        ## Set up triggers on Category and Qml, for Mapbox layer styling
        for schema, table in (('gisaf_map', 'qml'), ('gisaf_survey', 'category')):
            triggers = [t for t in all_triggers
                        if t['trigger_name'] == 'gisaf_ttag'
                        and t['trigger_table'] == f'{schema}.{table}']
            if len(triggers) == 0:
                await self._permanent_raw_conn.execute(
                    ttag_create_trigger.format(schema=schema, table=table))
        ## Listen: define the callback function
        await self._permanent_raw_conn.add_listener('gisaf_ttag',
                                                    store.create_task_store_ttag)

    async def _close_permanent_db_connection(self):
        """
        Called at aiohttp server shutdown: remove the listener and close the connections
        """
        try:
            await self._permanent_raw_conn.remove_listener(
                'gisaf_ttag', store.create_task_store_ttag)
        except InterfaceError as err:
            logger.warning('Cannot remove asyncpg listener in '
                           f'_close_permanent_db_connection: {err}')
        await self._permanent_raw_conn.close()
        await self._permanent_conn.release()


async def setup_redis():
    await store.setup()


async def setup_redis_cache():
    await store._setup_db_cache_system()


async def shutdown_redis():
    await store._close_permanent_db_connection()


store = Store()
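
## Usage sketch (illustrative): wiring the hooks into an aiohttp application
## lifecycle, as hinted by _close_permanent_db_connection's docstring; the
## `app` object and handler names are assumptions for the example:
##
##   async def on_startup(app):
##       await setup_redis()
##       await setup_redis_cache()
##
##   async def on_cleanup(app):
##       await shutdown_redis()
##
##   app.on_startup.append(on_startup)
##   app.on_cleanup.append(on_cleanup)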