from uuid import uuid1
from io import BytesIO
from asyncio import create_task
from json import loads, dumps
from pickle import dump, HIGHEST_PROTOCOL, loads as loads_pickle
from time import time
import logging

import pandas as pd
import geopandas as gpd
from asyncpg import connect
from asyncpg.connection import Connection
from asyncpg.exceptions import UndefinedTableError, InterfaceError
from sqlalchemy import text
from redis import asyncio as aioredis

from gisaf.config import conf
# from gisaf.models.live import LiveModel
from gisaf.utils import (SHAPELY_TYPE_TO_MAPBOX_TYPE, DEFAULT_MAPBOX_LAYOUT,
                         DEFAULT_MAPBOX_PAINT, gisTypeSymbolMap)
from gisaf.registry import registry
#from .models.geom import GeomGroup, GeomModel
from gisaf.models.geo_models_base import LiveGeoModel
from gisaf.database import db_session

logger = logging.getLogger(__name__)

ttag_function = """
CREATE OR REPLACE FUNCTION gisaf.ttag() RETURNS trigger
    LANGUAGE plpgsql
    AS $$
BEGIN
    PERFORM(SELECT pg_notify('gisaf_ttag', TG_TABLE_SCHEMA || '.' || TG_TABLE_NAME));
    RETURN NULL;
END;
$$;
"""

ttag_drop_trigger = 'DROP TRIGGER IF EXISTS gisaf_ttag ON "{schema}"."{table}";'

ttag_create_trigger = """
CREATE TRIGGER gisaf_ttag
    AFTER INSERT OR UPDATE OR DELETE
    ON "{schema}"."{table}"
    FOR EACH STATEMENT
    EXECUTE FUNCTION gisaf.ttag();
"""

## From https://dba.stackexchange.com/questions/121717/get-triggers-table-names-in-postgresql
get_all_triggers = """
SELECT trg.tgname AS trigger_name,
    CASE trg.tgtype::INTEGER & 66
        WHEN 2 THEN 'BEFORE'
        WHEN 64 THEN 'INSTEAD OF'
        ELSE 'AFTER'
    END AS trigger_type,
    CASE trg.tgtype::INTEGER & CAST(28 AS INT2)
        WHEN 16 THEN 'UPDATE'
        WHEN 8 THEN 'DELETE'
        WHEN 4 THEN 'INSERT'
        WHEN 20 THEN 'INSERT, UPDATE'
        WHEN 28 THEN 'INSERT, UPDATE, DELETE'
        WHEN 24 THEN 'UPDATE, DELETE'
        WHEN 12 THEN 'INSERT, DELETE'
    END AS trigger_event,
    ns.nspname || '.' || tbl.relname AS trigger_table,
    obj_description(trg.oid) AS remarks,
    CASE
        WHEN trg.tgenabled = 'O' THEN 'ENABLED'
        ELSE 'DISABLED'
    END AS status,
    CASE trg.tgtype::INTEGER & 1
        WHEN 1 THEN 'ROW'::TEXT
        ELSE 'STATEMENT'::TEXT
    END AS trigger_level,
    n.nspname || '.' || proc.proname AS function_name
FROM pg_trigger trg
JOIN pg_proc proc ON proc.oid = trg.tgfoid
JOIN pg_catalog.pg_namespace n ON n.oid = proc.pronamespace
JOIN pg_class tbl ON trg.tgrelid = tbl.oid
JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
WHERE trg.tgname NOT LIKE 'RI_ConstraintTrigger%'
    AND trg.tgname NOT LIKE 'pg_sync_pg%';
"""
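## Illustration of the notification flow (not executed anywhere in this
## module; "gisaf_survey"."well" is a made-up table name): rendering
## ttag_create_trigger for that table would install
##
##     CREATE TRIGGER gisaf_ttag
##         AFTER INSERT OR UPDATE OR DELETE
##         ON "gisaf_survey"."well"
##         FOR EACH STATEMENT
##         EXECUTE FUNCTION gisaf.ttag();
##
## after which any write to the table makes Postgres emit a notification on
## the 'gisaf_ttag' channel with payload 'gisaf_survey.well', picked up by
## Store.create_task_store_ttag below to refresh that store's ttag.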

class RedisError(Exception):
    pass


class Store:
    """
    Store for redis:
    - redis: RedisConnection
    - pub (/sub) connections
    """
    asyncpg_conn: Connection

    async def setup(self):
        """
        Setup the live service for the main Gisaf application:
        - Create the connection for the publishers
        - Create the connection for the redis listeners (websocket service)
        """
        await self.create_connections()
        await self.get_live_layer_defs()

    async def create_connections(self):
        """
        Create the connection for the publisher
        XXX: this should be renamed to be explicit
        """
        self.redis = aioredis.from_url(conf.gisaf_live.redis)
        self.pub = self.redis.pubsub()
        ## Use a single uuid as both the client name and this store's id
        self.uuid = str(uuid1())
        await self.redis.client_setname(self.uuid)

    def get_json_channel(self, store_name):
        """
        Name of the Redis channel for the json representation
        """
        return f'{store_name}:json'

    def get_gdf_channel(self, store_name):
        """
        Name of the Redis channel for the source gdf, in pickle format
        """
        return f'{store_name}:gdf'

    def get_layer_def_channel(self, store_name):
        """
        Name of the Redis channel for the layer definition
        """
        return f'{store_name}:layer_def'

    def get_mapbox_layout_channel(self, store_name):
        """
        Name of the Redis channel for the mapbox layout style definition
        """
        return f'{store_name}:mapbox_layout'

    def get_mapbox_paint_channel(self, store_name):
        """
        Name of the Redis channel for the mapbox paint style definition
        """
        return f'{store_name}:mapbox_paint'

    async def store_json(self, model, geojson, **kwargs):
        """
        Store the json representation of the gdf for caching
        """
        ## Save in Redis channel
        channel = self.get_json_channel(model.get_store_name())
        await self.redis.set(channel, geojson)
        ## XXX: publish to websocket?
        #await self.pub.publish(self.get_json_channel(store_name), data)

    async def publish(self, *args, **kwargs):
        """
        Wrapper for publishing to the redis pubsub channel
        """
        return await self.redis.publish(*args, **kwargs)

    async def publish_gdf(self, live_name, gdf, **kwargs):
        """
        Create or update the live layer and store it in redis.
        Additionally, publish to the channel for websocket live updates
        to ws_clients
        """
        if gdf is None:
            gdf = gpd.GeoDataFrame(data={'geom': []}, geometry='geom')
        if isinstance(gdf.index, pd.MultiIndex):
            raise ValueError('Gisaf live does not accept dataframes with multi index')
        return await self._store_live_to_redis(live_name, gdf, **kwargs)
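    ## Usage sketch for publish_gdf (illustrative only: assumes an async
    ## context and the module-level `store`; the layer name 'gps_tracker'
    ## is made up):
    ##
    ##     from shapely.geometry import Point
    ##     gdf = gpd.GeoDataFrame({'geom': [Point(77.6, 12.0)]}, geometry='geom')
    ##     await store.publish_gdf('gps_tracker', gdf,
    ##                             description='Live GPS positions')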
    async def _store_live_to_redis(self, live_name, gdf,
            properties=None,
            mapbox_paint=None, mapbox_layout=None,
            info=None,
            viewable_role=None,
            z_index=499,
            description='',
            status='L',
            symbol=None,
            color=None,
            attribution=None,
            **kwargs):
        """
        Store and publish the live layer data and metadata to redis channels
        """
        store_name = f'live:{live_name}'
        ## Filter out empty geometries
        gdf = gdf[gdf[gdf.geometry.name].notnull()]
        if 'status' not in gdf.columns:
            gdf['status'] = status
        if 'popup' not in gdf.columns:
            gdf['popup'] = 'Live: ' + live_name + ' #' + gdf.index.astype('U')
        if len(gdf) > 0:
            ## Reproject if needed
            gdf = gdf.to_crs(conf.crs.geojson)
            gis_type = gdf.gis_type.iloc[0]
        else:
            gis_type = 'Point'  ## FIXME: cannot be inferred from the gdf?
        mapbox_type = SHAPELY_TYPE_TO_MAPBOX_TYPE.get(gis_type, None)
        if not mapbox_paint:
            mapbox_paint = DEFAULT_MAPBOX_PAINT.get(mapbox_type, {})
            if color:
                if mapbox_type == 'symbol':
                    mapbox_paint['text-color'] = color
        if not mapbox_layout:
            mapbox_layout = DEFAULT_MAPBOX_LAYOUT.get(mapbox_type, {})
            if symbol:
                mapbox_layout['text-field'] = symbol
        if not symbol:
            symbol = gisTypeSymbolMap.get(gis_type, '\ue02e')
        if properties is None:
            properties = []
        ## Select the columns for the json representation
        columns = {'status', 'popup', gdf.geometry.name, 'store', 'id'}
        geojson = gdf[list(columns.intersection(gdf.columns).union(properties))].to_json()
        ## Publish to websocket
        await self.redis.publish(self.get_json_channel(store_name), geojson)
        layer_def_data = dumps({
            'store': store_name,
            'z_index': z_index,
            'count': len(gdf),
            'mapbox_type': mapbox_type,
            'gis_type': gis_type,
            'symbol': symbol,
            'name': live_name,
            'description': description,
            'viewable_role': viewable_role,
            'attribution': attribution,
            'is_live': True,
        })
        ## Pickle the dataframe
        with BytesIO() as buf:
            dump(gdf, buf, protocol=HIGHEST_PROTOCOL)
            buf.seek(0)
            #df_blob = buf.read()
            await self.redis.set(self.get_gdf_channel(store_name), buf.read())
        ## Save in Redis channels
        await self.redis.set(self.get_json_channel(store_name), geojson)
        await self.redis.set(
            self.get_mapbox_layout_channel(store_name),
            dumps(mapbox_layout)
        )
        await self.redis.set(self.get_mapbox_paint_channel(store_name), dumps(mapbox_paint))
        await self.redis.set(self.get_layer_def_channel(store_name), layer_def_data)
        ## Update the layers/stores registry
        await self.get_live_layer_defs()
        return geojson

    async def get_listener(self, channel):
        return await self.pub.psubscribe(channel)

    async def remove_layer(self, store_name):
        """
        Remove the layer from Gisaf Live (channel)
        """
        await self.redis.delete(self.get_json_channel(store_name))
        await self.redis.delete(self.get_layer_def_channel(store_name))
        await self.redis.delete(self.get_gdf_channel(store_name))
        await self.redis.delete(self.get_mapbox_layout_channel(store_name))
        await self.redis.delete(self.get_mapbox_paint_channel(store_name))
        ## Update the layers/stores registry
        await self.get_live_layer_defs()

    async def has_channel(self, store_name):
        return len(await self.redis.keys(self.get_json_channel(store_name))) > 0

    async def get_live_layer_def_channels(self):
        try:
            return [k.decode() for k in await self.redis.keys('live:*:layer_def')]
        except aioredis.exceptions.ConnectionError as err:
            raise RedisError('Cannot use Redis server, please restart Gisaf') from err

    async def get_layer_def(self, store_name):
        return loads(await self.redis.get(self.get_layer_def_channel(store_name)))

    async def get_live_layer_defs(self):  # -> list[LiveGeoModel]:
        registry.geom_live_defs = {}
        for channel in sorted(await self.get_live_layer_def_channels()):
            model_info = loads(await self.redis.get(channel))
            registry.geom_live_defs[model_info['store']] = model_info
        registry.update_live_layers()

    async def get_mapbox_style(self, store_name):
        """
        Get the mapbox style (paint and layout) of the given store (layer_def)
        """
        paint = await self.redis.get(self.get_mapbox_paint_channel(store_name))
        layout = await self.redis.get(self.get_mapbox_layout_channel(store_name))
        style = {}
        if paint is not None:
            style['paint'] = paint.decode()
        if layout is not None:
            style['layout'] = layout.decode()
        return style

    async def get_layer_as_json(self, store_name):
        """
        Get the json from the store name (layer_def)
        """
        return await self.redis.get(self.get_json_channel(store_name))
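    ## Retrieval sketch (illustrative; 'live:gps_tracker' matches the
    ## store_name that _store_live_to_redis builds for live_name 'gps_tracker'):
    ##
    ##     geojson = await store.get_layer_as_json('live:gps_tracker')
    ##     gdf = await store.get_gdf('live:gps_tracker', reproject=True)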
    async def get_gdf(self, store_name, reproject=False):
        raw_data = await self.redis.get(self.get_gdf_channel(store_name))
        if raw_data is None:
            raise RedisError(f'Cannot get {store_name}: no data')
        try:
            gdf = loads_pickle(raw_data)
        except Exception as err:
            logger.exception(err)
            raise RedisError(f'Cannot get {store_name}: pickle error from redis store: '
                             f'{err.__class__.__name__}, {err.args[0]}')
        if len(gdf) == 0:
            raise RedisError(f'Cannot get {store_name}: empty')
        if reproject:
            gdf.to_crs(conf.crs['for_proj'], inplace=True)
        return gdf

    async def get_feature_info(self, store_name, id):
        gdf = await self.get_gdf(store_name)
        ## FIXME: requires the gdf to have an integer index, used as feature['id'] on the map
        return gdf.loc[int(id)]

    async def set_ttag(self, store_name, now):
        """
        Set the ttag for the store as 'now'
        """
        #logger.debug(f'ttag {store_name} at {now}')
        await self.redis.set(f'ttag:{store_name}', now)

    def create_task_store_ttag(self, connection, pid, channel, store_name):
        """
        Postgres/asyncpg listener for the trigger on data change.
        A task is created because this callback function is not asynchronous.
        """
        if store_name in registry.stores:
            create_task(self.set_ttag(store_name, time()))
        else:
            logger.warning(f'Notify received for an unknown store: {store_name}')

    async def get_ttag(self, store_name):
        """
        Get the ttag for the given store.
        The ttag is the time stamp of the last modification of the store.
        If no ttag is known, create one as of now.
        """
        ttag = await self.redis.get(f'ttag:{store_name}')
        if ttag:
            return ttag.decode()
        else:
            ## No ttag: Gisaf doesn't know when the last update happened,
            ## ie. it was restarted and the ttags are cleared on startup.
            ## Set a ttag now, using the current epoch time in seconds in hex,
            ## double quoted, with a W/ prefix as it's basically a weak ETag
            weak_now_hex = f'W/"{hex(int(time()))[2:]}"'
            await self.set_ttag(store_name, weak_now_hex)
            return weak_now_hex

    async def delete_all_ttags(self) -> None:
        """
        Delete all ttags in redis
        """
        ## Equivalent command line: redis-cli del (redis-cli --scan --pattern 'ttag:*')
        keys = await self.redis.keys('ttag:*')
        if keys:
            await self.redis.delete(*keys)
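    ## The value returned by get_ttag is formatted as a weak HTTP ETag; a
    ## hypothetical request handler (not part of this module) could use it
    ## for conditional responses:
    ##
    ##     ttag = await store.get_ttag(store_name)
    ##     if request.headers.get('If-None-Match') == ttag:
    ##         return Response(status_code=304)   # client cache is current
    ##     response.headers['ETag'] = ttag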
    async def _setup_db_cache_system(self) -> None:
        """
        Setup the caching system:
        - clear all Redis store at startup
        - make sure the triggers and the "change" (insert, update, delete)
          event emitter function are setup on the database server
        - listen to the DB event emitter: setup a callback function
        """
        ## Delete all the ttags, for safety:
        ## eg. the database was changed while Gisaf wasn't running,
        ## so the redis store wasn't updated
        await self.delete_all_ttags()
        async with db_session() as session:
            ## Create the function in the database
            await session.exec(text(ttag_function))
            ## Create DB triggers on the tables of the models
            all_triggers_resp = await session.exec(text(get_all_triggers))
            all_triggers = all_triggers_resp.mappings().all()
            stores_with_trigger = {t['trigger_table'] for t in all_triggers
                                   if t['trigger_name'] == 'gisaf_ttag'}
            missing_trigger_tables = set(registry.geom).difference(stores_with_trigger)
            # model: SQLModel = registry.stores.loc[store_name, 'model']
            if len(missing_trigger_tables) > 0:
                logger.info('Create Postgres modification triggers for '
                            f'{len(missing_trigger_tables)} tables')
                for store_name in missing_trigger_tables:
                    ## XXX: TODO: See https://stackoverflow.com/questions/7888846/trigger-in-sqlachemy
                    model = registry.geom[store_name]
                    try:
                        await session.exec(text(
                            ttag_create_trigger.format(
                                schema=model.__table__.schema,
                                table=model.__table__.name)
                        ))
                    except UndefinedTableError:
                        logger.warning(f'table {store_name} does not exist in '
                                       'the database: skip modification trigger')
            ## Setup triggers on Category and Qml, for Mapbox layer styling
            for schema, table in (('gisaf_map', 'qml'), ('gisaf_survey', 'category')):
                triggers = [t for t in all_triggers
                            if t['trigger_name'] == 'gisaf_ttag'
                            and t['trigger_table'] == f'{schema}.{table}']
                if len(triggers) == 0:
                    await session.exec(text(
                        ttag_create_trigger.format(schema=schema, table=table)
                    ))
        ## Listen: define the callback function
        self.asyncpg_conn = await connect(conf.db.get_pg_url())
        await self.asyncpg_conn.add_listener('gisaf_ttag', self.create_task_store_ttag)

    async def _close_permanent_db_connection(self):
        """
        Called at server shutdown: remove the listener and close the connection
        """
        try:
            await self.asyncpg_conn.remove_listener(
                'gisaf_ttag', self.create_task_store_ttag)
        except InterfaceError as err:
            logger.warning('Cannot remove asyncpg listener in '
                           f'_close_permanent_db_connection: {err}')
        await self.asyncpg_conn.close()


async def setup_redis():
    global store
    await store.setup()


async def setup_redis_cache():
    global store
    await store._setup_db_cache_system()


async def shutdown_redis():
    global store
    await store._close_permanent_db_connection()


store = Store()
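## Typical lifecycle wiring (illustrative; the actual hook names depend on
## the web framework hosting Gisaf):
##
##     await setup_redis()        # connect to Redis, load the live layer defs
##     await setup_redis_cache()  # clear ttags, install triggers, listen to the DB
##     ...
##     await shutdown_redis()     # remove the asyncpg listener, close the connection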