diff --git a/src/gisaf/redis_tools.py b/src/gisaf/redis_tools.py
index 72e3bf1..663795a 100644
--- a/src/gisaf/redis_tools.py
+++ b/src/gisaf/redis_tools.py
@@ -7,7 +7,7 @@ from time import time
 import logging
 
 import pandas as pd
-import geopandas as gpd # type: ignore[import-untyped]
+import geopandas as gpd  # type: ignore[import-untyped]
 from asyncpg import connect
 from asyncpg.connection import Connection
 from asyncpg.exceptions import UndefinedTableError, InterfaceError
@@ -15,12 +15,18 @@ from sqlalchemy import text
 from redis import asyncio as aioredis
 
 from gisaf.config import conf
+
 # from gisaf.models.live import LiveModel
 from gisaf.models.map_bases import MaplibreStyle
-from gisaf.utils import (SHAPELY_TYPE_TO_MAPBOX_TYPE, DEFAULT_MAPBOX_LAYOUT,
-                         DEFAULT_MAPBOX_PAINT, gisTypeSymbolMap)
+from gisaf.utils import (
+    SHAPELY_TYPE_TO_MAPBOX_TYPE,
+    DEFAULT_MAPBOX_LAYOUT,
+    DEFAULT_MAPBOX_PAINT,
+    gisTypeSymbolMap,
+)
 from gisaf.registry import registry
-#from .models.geom import GeomGroup, GeomModel
+
+# from .models.geom import GeomGroup, GeomModel
 from gisaf.models.geo_models_base import LiveGeoModel
 from gisaf.database import db_session
 
@@ -84,15 +90,18 @@ WHERE
     AND trg.tgname not like 'pg_sync_pg%';
 """
 
+
 class RedisError(Exception):
     pass
 
+
 class Store:
     """
     Store for redis:
     - redis: RedisConnection
     - pub (/sub) connections
     """
+
     asyncpg_conn: Connection
 
     async def setup(self, with_registry=False):
@@ -116,36 +125,35 @@ class Store:
         self.uuid = await self.redis.client_getname()
         self.uuid = str(uuid1())
 
-
     def get_json_channel(self, store_name):
         """
         Name of the Redis channel for the json representation
         """
-        return f'{store_name}:json'
+        return f"{store_name}:json"
 
     def get_gdf_channel(self, store_name):
         """
         Name of the Redis channel for the source gdf, in pickle format
         """
-        return f'{store_name}:gdf'
+        return f"{store_name}:gdf"
 
     def get_layer_def_channel(self, store_name):
         """
         Name of the Redis channel for the layer definition
         """
-        return f'{store_name}:layer_def'
+        return f"{store_name}:layer_def"
 
     def get_mapbox_layout_channel(self, store_name):
         """
         Name of the Redis channel for the mapbox layout style definition
         """
-        return f'{store_name}:mapbox_layout'
+        return f"{store_name}:mapbox_layout"
 
     def get_mapbox_paint_channel(self, store_name):
         """
         Name of the Redis channel for the mapbox paint style definition
         """
-        return f'{store_name}:mapbox_paint'
+        return f"{store_name}:mapbox_paint"
 
     async def store_json(self, model, geojson, **kwargs):
         """
@@ -155,7 +163,7 @@ class Store:
         channel = self.get_json_channel(model.get_store_name())
         await self.redis.set(channel, geojson)
         ## XXX: publish to websocket?
-        #await self.pub.publish(self.get_json_channel(store_name), data)
+        # await self.pub.publish(self.get_json_channel(store_name), data)
 
     async def publish(self, *args, **kwargs):
         """
@@ -169,84 +177,101 @@ class Store:
         Additionally, publish to the channel for websocket live updates to ws_clients
         """
         if gdf is None:
-            gdf = gpd.GeoDataFrame(data={'geom': []}, geometry='geom') # type: ignore
+            gdf = gpd.GeoDataFrame(data={"geom": []}, geometry="geom")  # type: ignore
         if isinstance(gdf.index, pd.MultiIndex):
-            raise ValueError('Gisaf live does not accept dataframes with multi index')
+            raise ValueError("Gisaf live does not accept dataframes with multi index")
         return await self._store_live_to_redis(live_name, gdf, **kwargs)
 
-    async def _store_live_to_redis(self, live_name, gdf, properties=None,
-                                   mapbox_paint=None, mapbox_layout=None, info=None,
-                                   viewable_role=None, z_index=499, description='',
-                                   status='L', symbol=None, color=None, attribution=None,
-                                   **kwargs):
+    async def _store_live_to_redis(
+        self,
+        live_name,
+        gdf,
+        properties=None,
+        mapbox_paint=None,
+        mapbox_layout=None,
+        info=None,
+        viewable_role=None,
+        z_index=499,
+        description="",
+        status="L",
+        symbol=None,
+        color=None,
+        attribution=None,
+        **kwargs,
+    ):
         """
         Store and publish the live layer data and metadata to redis channels
         """
-        store_name = f'live:{live_name}'
+        store_name = f"live:{live_name}"
         ## Filter empty geometries
         gdf = gdf[gdf[gdf.geometry.name].notnull()]
-        ## Reproject eventually
-        if 'status' not in gdf.columns:
-            gdf['status'] = status
-        if 'popup' not in gdf.columns:
-            gdf['popup'] = 'Live: ' + live_name + ' #' + gdf.index.astype('U')
+        ## Reproject if needed
+        if "status" not in gdf.columns:
+            gdf["status"] = status
+        if "popup" not in gdf.columns:
+            gdf["popup"] = "Live: " + live_name + " #" + gdf.index.astype("U")
         if len(gdf) > 0:
             gdf = gdf.to_crs(conf.crs.geojson)
             gis_type = gdf.geom_type.iloc[0]
         else:
-            gis_type = 'Point' ## FIXME: cannot be inferred from the gdf?
+            gis_type = "Point"  ## FIXME: cannot be inferred from the gdf?
         mapbox_type = SHAPELY_TYPE_TO_MAPBOX_TYPE.get(gis_type, None)
         if not mapbox_paint:
             mapbox_paint = DEFAULT_MAPBOX_PAINT.get(mapbox_type, {})
             if color:
-                if mapbox_type == 'symbol':
-                    mapbox_paint['text-color'] = color
+                if mapbox_type == "symbol":
+                    mapbox_paint["text-color"] = color
         if not mapbox_layout:
             mapbox_layout = DEFAULT_MAPBOX_LAYOUT.get(mapbox_type, {})
             if symbol:
-                mapbox_layout['text-field'] = symbol
+                mapbox_layout["text-field"] = symbol
         if not symbol:
-            symbol = gisTypeSymbolMap.get(gis_type, '\ue02e')
+            symbol = gisTypeSymbolMap.get(gis_type, "\ue02e")
         if properties is None:
             properties = []
         ## Add a column for json representation
-        columns = {'status', 'popup', gdf.geometry.name, 'store', 'id'}
-        geojson = gdf[list(columns.intersection(gdf.columns).union(properties))].to_json()
+        columns = {"status", "popup", gdf.geometry.name, "store", "id"}
+        geojson = gdf[
+            list(columns.intersection(gdf.columns).union(properties))
+        ].to_json()
         ## Publish to websocket
         await self.redis.publish(self.get_json_channel(store_name), geojson)
-        layer_def_data = dumps({
-            'store': store_name,
-            'z_index': z_index,
-            'count': len(gdf),
-            'type': mapbox_type,
-            'gis_type': gis_type,
-            'symbol': symbol,
-            'name': live_name,
-            'description': description,
-            'viewable_role': viewable_role,
-            'attribution': attribution,
-            'is_live': True,
-        })
+        layer_def_data = dumps(
+            {
+                "store": store_name,
+                "z_index": z_index,
+                "count": len(gdf),
+                "type": mapbox_type,
+                "gis_type": gis_type,
+                "symbol": symbol,
+                "name": live_name,
+                "description": description,
+                "viewable_role": viewable_role,
+                "attribution": attribution,
+                "is_live": True,
+            }
+        )
         ## Pickle the dataframe
         with BytesIO() as buf:
             dump(gdf, buf, protocol=HIGHEST_PROTOCOL)
             buf.seek(0)
-            #df_blob = buf.read()
+            # df_blob = buf.read()
             await self.redis.set(self.get_gdf_channel(store_name), buf.read())
         ## Save in Redis channels
         await self.redis.set(self.get_json_channel(store_name), geojson)
         await self.redis.set(
-            self.get_mapbox_layout_channel(store_name),
-            dumps(mapbox_layout)
+            self.get_mapbox_layout_channel(store_name), dumps(mapbox_layout)
+        )
+        await self.redis.set(
+            self.get_mapbox_paint_channel(store_name), dumps(mapbox_paint)
         )
-        await self.redis.set(self.get_mapbox_paint_channel(store_name), dumps(mapbox_paint))
         await self.redis.set(self.get_layer_def_channel(store_name), layer_def_data)
         ## Update the layers/stores registry
-        ## XXX: Commentinhg out the update of live layers:
-        ## This should be triggerred from a redis listener
-        #await self.get_live_layer_defs()
+        ## XXX: Commenting out the update of live layers:
+        ## This should be triggered from a redis listener
+        # await self.get_live_layer_defs()
         return geojson
 
@@ -271,18 +297,18 @@ class Store:
 
     async def get_live_layer_def_channels(self):
         try:
-            return [k.decode() for k in await self.redis.keys('live:*:layer_def')]
+            return [k.decode() for k in await self.redis.keys("live:*:layer_def")]
         except aioredis.exceptions.ConnectionError as err:
-            raise RedisError('Cannot use Redis server, please restart Gisaf')
+            raise RedisError("Cannot use Redis server, please restart Gisaf") from err
 
     async def get_layer_def(self, store_name):
         return loads(await self.redis.get(self.get_layer_def_channel(store_name)))
 
-    async def get_live_layer_defs(self): # -> list[LiveGeoModel]:
+    async def get_live_layer_defs(self):  # -> list[LiveGeoModel]:
         registry.geom_live_defs = {}
         for channel in sorted(await self.get_live_layer_def_channels()):
             model_info = loads(await self.redis.get(channel))
-            registry.geom_live_defs[model_info['store']] = model_info
+            registry.geom_live_defs[model_info["store"]] = model_info
         registry.update_live_layers()
 
     async def get_maplibre_style(self, store_name) -> MaplibreStyle:
@@ -305,16 +331,18 @@ class Store:
     async def get_gdf(self, store_name, reproject=False):
         raw_data = await self.redis.get(self.get_gdf_channel(store_name))
-        if raw_data == None:
-            raise RedisError(f'Cannot get {store_name}: no data')
+        if raw_data is None:
+            raise RedisError(f"Cannot get {store_name}: no data")
         try:
             gdf = loads_pickle(raw_data)
         except Exception as err:
             logger.exception(err)
-            raise RedisError(f'Cannot get {store_name}: pickle error from redis store: {err.__class__.__name__}, {err.args[0]}')
+            raise RedisError(
+                f"Cannot get {store_name}: pickle error from redis store: {err.__class__.__name__}, {err.args[0]}"
+            )
         if len(gdf) == 0:
-            raise RedisError(f'Cannot get {store_name}: empty')
+            raise RedisError(f"Cannot get {store_name}: empty")
         if reproject:
-            gdf.to_crs(conf.crs['for_proj'], inplace=True)
+            gdf.to_crs(conf.crs["for_proj"], inplace=True)
         return gdf
 
     async def get_feature_info(self, store_name, id):
@@ -326,8 +354,8 @@ class Store:
         """
         Set the ttag for the store as 'now'
         """
-        #logger.debug(f'ttag {store_name} at {now}')
-        await self.redis.set(f'ttag:{store_name}', now)
+        # logger.debug(f'ttag {store_name} at {now}')
+        await self.redis.set(f"ttag:{store_name}", now)
 
     def create_task_store_ttag(self, connection, pid, channel, store_name):
         """
@@ -337,7 +365,7 @@ class Store:
         if store_name in registry.stores:
             create_task(self.set_ttag(store_name, time()))
         else:
-            logger.warn(f'Notify received for an unexisting store: {store_name}')
+            logger.warning(f"Notify received for a nonexistent store: {store_name}")
 
     async def get_ttag(self, store_name):
         """
@@ -345,7 +373,7 @@ class Store:
         ttag is the time stamp of the last modification of the store.
-        If no ttag is know, create one as now.
+        If no ttag is known, create one as now.
""" - ttag = await self.redis.get(f'ttag:{store_name}') + ttag = await self.redis.get(f"ttag:{store_name}") if ttag: return ttag.decode() else: @@ -362,7 +390,7 @@ class Store: Delete all ttags in redis """ ## Equivalient command line: redis-cli del (redis-cli --scan --pattern 'ttag:*') - keys = await self.redis.keys('ttag:*') + keys = await self.redis.keys("ttag:*") if keys: await self.redis.delete(*keys) @@ -385,40 +413,54 @@ class Store: ## Create DB triggers on the tables of the models all_triggers_resp = await session.exec(text(get_all_triggers)) all_triggers = all_triggers_resp.mappings().all() - stores_with_trigger = {t['trigger_table'] - for t in all_triggers - if t['tigger_name'] == 'gisaf_ttag'} + stores_with_trigger = { + t["trigger_table"] + for t in all_triggers + if t["tigger_name"] == "gisaf_ttag" + } missing_triger_tables = set(registry.geom).difference(stores_with_trigger) # model: SQLModel = registry.stores.loc[store_name, 'model'] if len(missing_triger_tables) > 0: - logger.info('Create Postgres modification triggers for ' - f'{len(missing_triger_tables)} tables') + logger.info( + "Create Postgres modification triggers for " + f"{len(missing_triger_tables)} tables" + ) for store_name in missing_triger_tables: ## XXX: TODO: See https://stackoverflow.com/questions/7888846/trigger-in-sqlachemy model = registry.geom[store_name] try: - await session.exec(text( - ttag_create_trigger.format( - schema=model.__table__.schema, - table=model.__table__.name) - )) + await session.exec( + text( + ttag_create_trigger.format( + schema=model.__table__.schema, + table=model.__table__.name, + ) + ) + ) except UndefinedTableError: - logger.warning(f'table {store_name} does not exist in ' - 'the database: skip modification trigger') + logger.warning( + f"table {store_name} does not exist in " + "the database: skip modification trigger" + ) ## Setup triggers on Category and Qml, for Mapbox layer styling - for schema, table in (('gisaf_map', 'qml'), ('gisaf_survey', 'category')): - triggers = [t for t in all_triggers - if t['tigger_name'] == 'gisaf_ttag' - and t['trigger_table'] == f'{schema}.{table}'] + for schema, table in (("gisaf_map", "qml"), ("gisaf_survey", "category")): + triggers = [ + t + for t in all_triggers + if t["tigger_name"] == "gisaf_ttag" + and t["trigger_table"] == f"{schema}.{table}" + ] if len(triggers) == 0: - await session.exec(text( - ttag_create_trigger.format(schema=schema, table=table) - )) + await session.exec( + text(ttag_create_trigger.format(schema=schema, table=table)) + ) ## Listen: define the callback function self.asyncpg_conn = await connect(conf.db.get_pg_url()) - await self.asyncpg_conn.add_listener('gisaf_ttag', store.create_task_store_ttag) + await self.asyncpg_conn.add_listener( + "gisaf_ttag", store.create_task_store_ttag + ) async def _close_permanant_db_connection(self): """ @@ -426,9 +468,12 @@ class Store: """ try: await self.asyncpg_conn.remove_listener( - 'gisaf_ttag', store.create_task_store_ttag) + "gisaf_ttag", store.create_task_store_ttag + ) except InterfaceError as err: - logger.warning(f'Cannot remove asyncpg listener in _close_permanant_db_connection: {err}') + logger.warning( + f"Cannot remove asyncpg listener in _close_permanant_db_connection: {err}" + ) await self.asyncpg_conn.close() @@ -443,7 +488,7 @@ async def setup_redis_cache(): async def shutdown_redis(): - if not hasattr(self, 'asyncpg_conn'): + if not hasattr(self, "asyncpg_conn"): return global store await store._close_permanant_db_connection()