diff --git a/pdm.lock b/pdm.lock index eb6d3b9..7a6ee68 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform"] lock_version = "4.4" -content_hash = "sha256:0d6cc736afc51fceae2eaff49ffbd91678e0ecb5c6f29e683f12c974c6f9bdac" +content_hash = "sha256:4593cf6b7e4e89f1e407c7b7feeb12c56c84bf16d84b94d1bbe89d3d3ed4ea6d" [[package]] name = "annotated-types" @@ -297,7 +297,7 @@ files = [ [[package]] name = "fastapi" -version = "0.104.1" +version = "0.105.0" requires_python = ">=3.8" summary = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" dependencies = [ @@ -307,8 +307,8 @@ dependencies = [ "typing-extensions>=4.8.0", ] files = [ - {file = "fastapi-0.104.1-py3-none-any.whl", hash = "sha256:752dc31160cdbd0436bb93bad51560b57e525cbb1d4bbf6f4904ceee75548241"}, - {file = "fastapi-0.104.1.tar.gz", hash = "sha256:e5e4540a7c5e1dcfbbcf5b903c234feddcdcd881f191977a1c5dfd917487e7ae"}, + {file = "fastapi-0.105.0-py3-none-any.whl", hash = "sha256:f19ebf6fdc82a3281d10f2cb4774bdfa90238e3b40af3525a0c09fd08ad1c480"}, + {file = "fastapi-0.105.0.tar.gz", hash = "sha256:4d12838819aa52af244580675825e750ad67c9df4614f557a769606af902cf22"}, ] [[package]] @@ -969,6 +969,19 @@ files = [ {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, ] +[[package]] +name = "redis" +version = "5.0.1" +requires_python = ">=3.7" +summary = "Python client for Redis database and key-value store" +dependencies = [ + "async-timeout>=4.0.2; python_full_version <= \"3.11.2\"", +] +files = [ + {file = "redis-5.0.1-py3-none-any.whl", hash = "sha256:ed4802971884ae19d640775ba3b03aa2e7bd5e8fb8dfaed2decce4d0fc48391f"}, + {file = "redis-5.0.1.tar.gz", hash = "sha256:0dab495cd5753069d3bc650a0dde8a8f9edde16fc5691b689a566eda58100d0f"}, +] + [[package]] name = "rsa" version = "4.9" diff --git a/pyproject.toml b/pyproject.toml index ffaa5df..c0592b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "pyshp>=2.3.1", "orjson>=3.9.10", "sqlmodel>=0.0.14", + "redis>=5.0.1", ] requires-python = ">=3.11" readme = "README.md" diff --git a/src/gisaf/_version.py b/src/gisaf/_version.py index 9bab421..0a90112 100644 --- a/src/gisaf/_version.py +++ b/src/gisaf/_version.py @@ -1 +1 @@ -__version__ = '2023.4.dev3+g5494f60.d20231212' \ No newline at end of file +__version__ = '2023.4.dev4+g049b8c9.d20231213' \ No newline at end of file diff --git a/src/gisaf/api.py b/src/gisaf/api.py index ddecf07..bbc78bf 100644 --- a/src/gisaf/api.py +++ b/src/gisaf/api.py @@ -16,7 +16,7 @@ from .models.category import Category, CategoryRead from .config import conf from .models.bootstrap import BootstrapData from .models.store import Store -from .database import get_db_session, pandas_query +from .database import get_db_session, pandas_query, fastapi_db_session as db_session from .security import ( Token, authenticate_user, get_current_user, create_access_token, @@ -31,7 +31,7 @@ api = FastAPI( ) #api.add_middleware(SessionMiddleware, secret_key=conf.crypto.secret) -db_session = Annotated[AsyncSession, Depends(get_db_session)] +#db_session = Annotated[AsyncSession, Depends(get_db_session)] @api.get('/bootstrap') @@ -55,7 +55,7 @@ async def login_for_access_token( access_token = create_access_token( data={"sub": user.username}, expires_delta=timedelta(seconds=conf.crypto.expire)) - return {"access_token": access_token, "token_type": "bearer"} + return Token(access_token=access_token, 
+                 token_type='bearer')
 
 @api.get("/list")
 async def list_data_providers():
diff --git a/src/gisaf/application.py b/src/gisaf/application.py
index 6ca04cf..e60ec62 100644
--- a/src/gisaf/application.py
+++ b/src/gisaf/application.py
@@ -8,8 +8,10 @@ from typing import Any
 from fastapi import FastAPI, responses
 
 from .api import api
+from .geoapi import api as geoapi
 from .config import conf
 from .registry import registry, ModelRegistry
+from .redis_tools import setup_redis, shutdown_redis, setup_redis_cache
 
 logging.basicConfig(level=conf.gisaf.debugLevel)
 
@@ -28,6 +30,7 @@ class GisafFastAPI(FastAPI):
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     await registry.make_registry(app)
+    await setup_redis(app)
     yield
 
 app = FastAPI(
@@ -37,4 +40,6 @@ app = FastAPI(
     lifespan=lifespan,
     default_response_class=responses.ORJSONResponse,
 )
-app.mount('/v2', api)
\ No newline at end of file
+
+app.mount('/v2', api)
+app.mount('/gj', geoapi)
\ No newline at end of file
diff --git a/src/gisaf/database.py b/src/gisaf/database.py
index 9433ea2..8e62e40 100644
--- a/src/gisaf/database.py
+++ b/src/gisaf/database.py
@@ -1,8 +1,9 @@
 from contextlib import asynccontextmanager
+from typing import Annotated
 
 from sqlalchemy.ext.asyncio import create_async_engine
 from sqlmodel.ext.asyncio.session import AsyncSession
-
+from fastapi import Depends
 import pandas as pd
 
 from .config import conf
@@ -27,4 +28,6 @@ async def db_session() -> AsyncSession:
         yield session
 
 def pandas_query(session, query):
-    return pd.read_sql_query(query, session.connection())
\ No newline at end of file
+    return pd.read_sql_query(query, session.connection())
+
+fastapi_db_session = Annotated[AsyncSession, Depends(get_db_session)]
\ No newline at end of file
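The `fastapi_db_session` annotation above bundles `AsyncSession` with `Depends(get_db_session)`, so endpoints receive a database session per request. A hedged sketch of how an endpoint might declare it; the route, model and query are illustrative and not part of this changeset:

```python
# Minimal sketch, assuming the Category model and a /categories route;
# only the fastapi_db_session import is taken from this changeset.
from fastapi import FastAPI
from sqlmodel import select
from gisaf.database import fastapi_db_session as db_session
from gisaf.models.category import Category

app = FastAPI()

@app.get('/categories')
async def list_categories(session: db_session) -> list[Category]:
    ## The Annotated alias injects an AsyncSession bound to the request
    data = await session.exec(select(Category))
    return data.all()
```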
diff --git a/src/gisaf/geoapi.py b/src/gisaf/geoapi.py
new file mode 100644
index 0000000..fc161e1
--- /dev/null
+++ b/src/gisaf/geoapi.py
@@ -0,0 +1,132 @@
+"""
+Geographical json stores, served under /gj
+Used for displaying features on maps
+"""
+import logging
+from asyncio import CancelledError
+from json import loads
+
+from fastapi import (FastAPI, HTTPException, Request, Response,
+                     WebSocket, WebSocketDisconnect, status, responses)
+
+from .redis_tools import store as redis_store
+# from gisaf.live import live_server
+from .registry import registry
+
+
+logger = logging.getLogger(__name__)
+
+api = FastAPI(
+    default_response_class=responses.ORJSONResponse,
+)
+
+@api.websocket('/live/{store}')
+async def live_layer(store: str, websocket: WebSocket):
+    """
+    Websocket for live layer updates
+    """
+    ## FIXME: ported from the aiohttp version; requires live_server (import above)
+    await websocket.accept()
+    try:
+        while True:
+            raw = await websocket.receive_text()
+            if raw == 'close':
+                await websocket.close()
+                break
+            msg_data = loads(raw)
+            if 'message' in msg_data:
+                if msg_data['message'] == 'subscribeLiveLayer':
+                    live_server.add_subscription(websocket, store)
+                elif msg_data['message'] == 'unsubscribeLiveLayer':
+                    live_server.remove_subscription(websocket, store)
+    except WebSocketDisconnect:
+        pass
+    logger.debug('websocket connection closed')
+
+@api.get('/{store_name}')
+async def get_geojson(store_name: str, request: Request):
+    """
+    Get the geojson representation of a store, for display on maps.
+    These routes are mounted under /gj (see application.py).
+    :param store_name: name of the model
+    :return: json
+    """
+    use_cache = False
+    try:
+        model = registry.stores.loc[store_name].model
+    except KeyError:
+        raise HTTPException(status.HTTP_404_NOT_FOUND)
+
+    if hasattr(model, 'viewable_role') and model.viewable_role:
+        ## FIXME: check_permission is not ported from the aiohttp version yet
+        # await check_permission(request, model.viewable_role)
+        pass
+
+    if await redis_store.has_channel(store_name):
+        ## Live layers
+        data = await redis_store.get_layer_as_json(store_name)
+        return Response(content=data, media_type='application/json')
+
+    # elif not model:
+    #     raise HTTPException(status.HTTP_404_NOT_FOUND)
+
+    ttag = None
+    if model.cache_enabled:
+        ttag = await redis_store.get_ttag(store_name)
+        if ttag and request.headers.get('If-None-Match') == ttag:
+            return Response(status_code=status.HTTP_304_NOT_MODIFIED)
+
+    if hasattr(model, 'get_geojson'):
+        geojson = await model.get_geojson(
+            simplify_tolerance=float(request.headers.get('simplify', 50.0)))
+        ## Store to redis for caching
+        if use_cache:
+            await redis_store.store_json(model, geojson)
+        resp = Response(content=geojson, media_type='application/json')
+
+    elif model.can_get_features_as_df:
+        ## Get the GeoDataframe (gdf) with GeoPandas
+        ## get_popup and get_properties get the gdf as argument and can use vectorised operations
+        try:
+            gdf = await model.get_geo_df(cast=True, with_related=True, filter_columns=True)
+        except CancelledError as err:
+            logger.debug(f'Request for {store_name} cancelled while getting gdf')
+            raise err
+        except Exception as err:
+            logger.exception(err)
+            raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR)
+        ## The query of category defined models gets the status (not sure how and this could be skipped)
+        ## Other models do not have it: just add it manually from the model itself
+        if 'status' not in gdf.columns:
+            gdf['status'] = model.status
+        if 'popup' not in gdf.columns:
+            gdf['popup'] = await model.get_popup(gdf)
+        properties = await model.get_properties(gdf)
+        columns = ['geometry', 'status', 'popup']
+        for property, values in properties.items():
+            columns.append(property)
+            gdf[property] = values
+        geojson = gdf[columns].to_json(separators=(',', ':'), check_circular=False)
+        ## Store to redis for caching
+        if use_cache:
+            await redis_store.store_json(model, geojson)
+        resp = Response(content=geojson, media_type='application/json')
+
+    else:
+        logger.warning(f"{model} doesn't allow using dataframe for generating json!")
+        attrs, features_kwargs = await model.get_features_attrs(
+            float(request.headers.get('simplify', 50.0)))
+        ## Using gino: allows OO model (get_info, etc)
+        try:
+            attrs['features'] = await model.get_features_in_bulk_gino(**features_kwargs)
+        except Exception as err:
+            logger.exception(err)
+            raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR)
+        resp = responses.ORJSONResponse(attrs)
+
+    if model.cache_enabled and ttag:
+        resp.headers['ETag'] = ttag
+    return resp
+
+
+@api.get('/{store_name}/popup/{id}')
+async def gj_popup(store_name: str, id: int):
+    model = registry.geom.get(store_name)
+    if not hasattr(model, 'get_popup_dynamic'):
+        return ''
+    obj = await model.get(id)
+    ## Escape characters for json
+    popup_more = obj.get_popup_dynamic().replace('"', '\\"').replace('\n', '\\n')
+    return {"text": popup_more}
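The `get_geojson` endpoint above honours a `simplify` request header and the `If-None-Match`/`ETag` pair backed by the redis ttags. A hedged client-side sketch of that contract; httpx, the base URL and the store name are assumptions, not part of this changeset:

```python
# Illustrative client for the /gj geojson API; store names depend on the
# deployed registry, and the server address is an assumption.
import httpx

def fetch_layer(store_name: str, etag: str | None = None):
    headers = {'simplify': '50.0'}  # simplification tolerance, read from headers
    if etag:
        headers['If-None-Match'] = etag
    resp = httpx.get(f'http://localhost:8000/gj/{store_name}', headers=headers)
    if resp.status_code == 304:
        return etag, None  # cached geojson is still current
    resp.raise_for_status()
    return resp.headers.get('ETag'), resp.json()
```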
""" - z_index: ClassVar[int] = 450 + z_index: ClassVar[int] = 450 # Field(450, alias='zIndex') """ z-index for the leaflet layer. Should be between 400 and 500. @@ -742,6 +742,13 @@ class GeoModel(Model): def get_attachment_base_dir(cls): return Path(conf.attachments['base_dir'])/cls.get_attachment_dir() +class LiveGeoModel(GeoModel): + store: ClassVar[str] + group: ClassVar[str] ='Live' + custom: ClassVar[bool] = True + is_live: ClassVar[bool] = True + is_db: ClassVar[bool] = False + class Geom(str): pass diff --git a/src/gisaf/redis_tools.py b/src/gisaf/redis_tools.py new file mode 100644 index 0000000..8b1d0c5 --- /dev/null +++ b/src/gisaf/redis_tools.py @@ -0,0 +1,437 @@ +from typing import ClassVar +from uuid import uuid1 +from io import BytesIO +from asyncio import create_task +from json import loads, dumps +from pickle import dump, HIGHEST_PROTOCOL, loads as loads_pickle +from time import time +import logging + +import pandas as pd +import geopandas as gpd +from asyncpg.exceptions import UndefinedTableError, InterfaceError +from redis import asyncio as aioredis +from pydantic import create_model + +from .config import conf +# from gisaf.models.live import LiveModel +from .utils import (SHAPELY_TYPE_TO_MAPBOX_TYPE, DEFAULT_MAPBOX_LAYOUT, + DEFAULT_MAPBOX_PAINT, gisTypeSymbolMap) +from .registry import registry +#from .models.geom import GeomGroup, GeomModel +from .models.geo_models_base import LiveGeoModel + +logger = logging.getLogger(__name__) + +ttag_function = """ +CREATE OR REPLACE FUNCTION gisaf.ttag() RETURNS trigger LANGUAGE plpgsql AS +$$ +BEGIN +PERFORM(select pg_notify('gisaf_ttag', TG_TABLE_SCHEMA || '.' || TG_TABLE_NAME)); +RETURN NULL; +END; +$$ +; +""" + +ttag_drop_trigger = 'DROP TRIGGER IF EXISTS gisaf_ttag ON "{schema}"."{table}";' + +ttag_create_trigger = """ +CREATE TRIGGER gisaf_ttag AFTER INSERT OR UPDATE OR DELETE +ON "{schema}"."{table}" +FOR EACH STATEMENT +EXECUTE FUNCTION gisaf.ttag(); +""" + +## From https://dba.stackexchange.com/questions/121717/get-triggers-table-names-in-postgresql +get_all_triggers = """ +SELECT trg.tgname as tigger_name, + CASE trg.tgtype::INTEGER & 66 + WHEN 2 THEN 'BEFORE' + WHEN 64 THEN 'INSTEAD OF' + ELSE 'AFTER' + END AS trigger_type, + CASE trg.tgtype::INTEGER & cast(28 AS INT2) + WHEN 16 THEN 'UPDATE' + WHEN 8 THEN 'DELETE' + WHEN 4 THEN 'INSERT' + WHEN 20 THEN 'INSERT, UPDATE' + WHEN 28 THEN 'INSERT, UPDATE, DELETE' + WHEN 24 THEN 'UPDATE, DELETE' + WHEN 12 THEN 'INSERT, DELETE' + END AS trigger_event, + ns.nspname||'.'||tbl.relname AS trigger_table, + obj_description(trg.oid) AS remarks, + CASE + WHEN trg.tgenabled='O' THEN 'ENABLED' + ELSE 'DISABLED' + END AS status, + CASE trg.tgtype::INTEGER & 1 + WHEN 1 THEN 'ROW'::TEXT + ELSE 'STATEMENT'::TEXT + END AS trigger_level, + n.nspname || '.' 
|| proc.proname AS function_name +FROM pg_trigger trg + JOIN pg_proc proc ON proc.oid = trg.tgfoid + JOIN pg_catalog.pg_namespace n ON n.oid = proc.pronamespace + JOIN pg_class tbl ON trg.tgrelid = tbl.oid + JOIN pg_namespace ns ON ns.oid = tbl.relnamespace +WHERE + trg.tgname not like 'RI_ConstraintTrigger%' + AND trg.tgname not like 'pg_sync_pg%'; +""" + +class RedisError(Exception): + pass + +class Store: + """ + Store for redis: + - redis: RedisConnection + - pub (/sub) connections + """ + async def setup(self, app): + """ + Setup the live service for the main Gisaf application: + - Create connection for the publishers + - Create connection for redis listeners (websocket service) + """ + self.app = app + app.extra['store'] = self + await self.create_connections() + await self.get_live_layer_defs() + + async def create_connections(self): + """ + Create the connection for the publisher + XXX: this should be renamed to be explicit + """ + self.redis = aioredis.from_url(conf.gisaf_live.redis) + self.pub = self.redis.pubsub() + await self.redis.client_setname(str(uuid1())) + self.uuid = await self.redis.client_getname() + self.uuid = str(uuid1()) + + + def get_json_channel(self, store_name): + """ + Name of the Redis channel for the json representation + """ + return f'{store_name}:json' + + def get_gdf_channel(self, store_name): + """ + Name of the Redis channel for the source gdf, in pickle format + """ + return f'{store_name}:gdf' + + def get_layer_def_channel(self, store_name): + """ + Name of the Redis channel for the layer definition + """ + return f'{store_name}:layer_def' + + def get_mapbox_layout_channel(self, store_name): + """ + Name of the Redis channel for the mapbox layout style definition + """ + return f'{store_name}:mapbox_layout' + + def get_mapbox_paint_channel(self, store_name): + """ + Name of the Redis channel for the mapbox paint style definition + """ + return f'{store_name}:mapbox_paint' + + async def store_json(self, model, geojson, **kwargs): + """ + Store the json representation of the gdf for caching. + """ + ## Save in Redis channel + channel = self.get_json_channel(model.get_store_name()) + await self.redis.set(channel, geojson) + ## XXX: publish to websocket? + #await self.pub.publish(self.get_json_channel(store_name), data) + + async def publish(self, *args, **kwargs): + """ + Wrapper for publishing to the redis pubsub channel + """ + return await self.redis.publish(*args, **kwargs) + + async def publish_gdf(self, live_name, gdf, **kwargs): + """ + Create or update the live layer, store in redis. 
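`LiveGeoModel` declares its metadata as `ClassVar`s so that concrete live models can be synthesized at runtime from a layer definition. A sketch of the dynamic creation pattern used by `update_live_layers()` in the registry diff further down; the layer definition dict here is illustrative:

```python
# Mirrors the create_model call in ModelRegistry.update_live_layers;
# the model_info payload normally comes from the redis layer_def channel.
from typing import ClassVar

from pydantic import create_model
from gisaf.models.geo_models_base import LiveGeoModel

model_info = {'store': 'live:example', 'name': 'Example', 'z_index': 499}
field_definitions = {
    k: (ClassVar[v.__class__], v)  # each value becomes a typed ClassVar default
    for k, v in model_info.items()
}
ExampleLive = create_model('live:example', __base__=LiveGeoModel,
                           **field_definitions)
```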
diff --git a/src/gisaf/redis_tools.py b/src/gisaf/redis_tools.py
new file mode 100644
index 0000000..8b1d0c5
--- /dev/null
+++ b/src/gisaf/redis_tools.py
@@ -0,0 +1,437 @@
+from typing import ClassVar
+from uuid import uuid1
+from io import BytesIO
+from asyncio import create_task
+from json import loads, dumps
+from pickle import dump, HIGHEST_PROTOCOL, loads as loads_pickle
+from time import time
+import logging
+
+import pandas as pd
+import geopandas as gpd
+from asyncpg.exceptions import UndefinedTableError, InterfaceError
+from redis import asyncio as aioredis
+from pydantic import create_model
+
+from .config import conf
+# from gisaf.models.live import LiveModel
+from .utils import (SHAPELY_TYPE_TO_MAPBOX_TYPE, DEFAULT_MAPBOX_LAYOUT,
+                    DEFAULT_MAPBOX_PAINT, gisTypeSymbolMap)
+from .registry import registry
+#from .models.geom import GeomGroup, GeomModel
+from .models.geo_models_base import LiveGeoModel
+
+logger = logging.getLogger(__name__)
+
+ttag_function = """
+CREATE OR REPLACE FUNCTION gisaf.ttag() RETURNS trigger LANGUAGE plpgsql AS
+$$
+BEGIN
+PERFORM(select pg_notify('gisaf_ttag', TG_TABLE_SCHEMA || '.' || TG_TABLE_NAME));
+RETURN NULL;
+END;
+$$
+;
+"""
+
+ttag_drop_trigger = 'DROP TRIGGER IF EXISTS gisaf_ttag ON "{schema}"."{table}";'
+
+ttag_create_trigger = """
+CREATE TRIGGER gisaf_ttag AFTER INSERT OR UPDATE OR DELETE
+ON "{schema}"."{table}"
+FOR EACH STATEMENT
+EXECUTE FUNCTION gisaf.ttag();
+"""
+
+## From https://dba.stackexchange.com/questions/121717/get-triggers-table-names-in-postgresql
+get_all_triggers = """
+SELECT trg.tgname as trigger_name,
+       CASE trg.tgtype::INTEGER & 66
+           WHEN 2 THEN 'BEFORE'
+           WHEN 64 THEN 'INSTEAD OF'
+           ELSE 'AFTER'
+       END AS trigger_type,
+       CASE trg.tgtype::INTEGER & cast(28 AS INT2)
+           WHEN 16 THEN 'UPDATE'
+           WHEN 8 THEN 'DELETE'
+           WHEN 4 THEN 'INSERT'
+           WHEN 20 THEN 'INSERT, UPDATE'
+           WHEN 28 THEN 'INSERT, UPDATE, DELETE'
+           WHEN 24 THEN 'UPDATE, DELETE'
+           WHEN 12 THEN 'INSERT, DELETE'
+       END AS trigger_event,
+       ns.nspname||'.'||tbl.relname AS trigger_table,
+       obj_description(trg.oid) AS remarks,
+       CASE
+           WHEN trg.tgenabled='O' THEN 'ENABLED'
+           ELSE 'DISABLED'
+       END AS status,
+       CASE trg.tgtype::INTEGER & 1
+           WHEN 1 THEN 'ROW'::TEXT
+           ELSE 'STATEMENT'::TEXT
+       END AS trigger_level,
+       n.nspname || '.' || proc.proname AS function_name
+FROM pg_trigger trg
+    JOIN pg_proc proc ON proc.oid = trg.tgfoid
+    JOIN pg_catalog.pg_namespace n ON n.oid = proc.pronamespace
+    JOIN pg_class tbl ON trg.tgrelid = tbl.oid
+    JOIN pg_namespace ns ON ns.oid = tbl.relnamespace
+WHERE
+    trg.tgname not like 'RI_ConstraintTrigger%'
+    AND trg.tgname not like 'pg_sync_pg%';
+"""
+
+class RedisError(Exception):
+    pass
+
+class Store:
+    """
+    Store for redis:
+    - redis: RedisConnection
+    - pub (/sub) connections
+    """
+    async def setup(self, app):
+        """
+        Setup the live service for the main Gisaf application:
+        - Create connection for the publishers
+        - Create connection for redis listeners (websocket service)
+        """
+        self.app = app
+        app.extra['store'] = self
+        await self.create_connections()
+        await self.get_live_layer_defs()
+
+    async def create_connections(self):
+        """
+        Create the connection for the publisher
+        XXX: this should be renamed to be explicit
+        """
+        self.redis = aioredis.from_url(conf.gisaf_live.redis)
+        self.pub = self.redis.pubsub()
+        self.uuid = str(uuid1())
+        await self.redis.client_setname(self.uuid)
+
+    def get_json_channel(self, store_name):
+        """
+        Name of the Redis channel for the json representation
+        """
+        return f'{store_name}:json'
+
+    def get_gdf_channel(self, store_name):
+        """
+        Name of the Redis channel for the source gdf, in pickle format
+        """
+        return f'{store_name}:gdf'
+
+    def get_layer_def_channel(self, store_name):
+        """
+        Name of the Redis channel for the layer definition
+        """
+        return f'{store_name}:layer_def'
+
+    def get_mapbox_layout_channel(self, store_name):
+        """
+        Name of the Redis channel for the mapbox layout style definition
+        """
+        return f'{store_name}:mapbox_layout'
+
+    def get_mapbox_paint_channel(self, store_name):
+        """
+        Name of the Redis channel for the mapbox paint style definition
+        """
+        return f'{store_name}:mapbox_paint'
+
+    async def store_json(self, model, geojson, **kwargs):
+        """
+        Store the json representation of the gdf for caching.
+        """
+        ## Save in Redis channel
+        channel = self.get_json_channel(model.get_store_name())
+        await self.redis.set(channel, geojson)
+        ## XXX: publish to websocket?
+        #await self.pub.publish(self.get_json_channel(store_name), data)
+
+    async def publish(self, *args, **kwargs):
+        """
+        Wrapper for publishing to the redis pubsub channel
+        """
+        return await self.redis.publish(*args, **kwargs)
+
+    async def publish_gdf(self, live_name, gdf, **kwargs):
+        """
+        Create or update the live layer, store in redis.
+        Additionally, publish to the channel for websocket live updates to ws_clients
+        """
+        if gdf is None:
+            gdf = gpd.GeoDataFrame(data={'geom': []}, geometry='geom')
+        if isinstance(gdf.index, pd.core.indexes.multi.MultiIndex):
+            raise ValueError('Gisaf live does not accept dataframes with multi index')
+        return await self._store_live_to_redis(live_name, gdf, **kwargs)
+
+    async def _store_live_to_redis(self, live_name, gdf, properties=None,
+                                   mapbox_paint=None, mapbox_layout=None, info=None,
+                                   viewable_role=None, z_index=499, description='',
+                                   status='L', symbol=None, color=None, attribution=None,
+                                   **kwargs):
+        """
+        Store and publish the live layer data and metadata to redis channels
+        """
+        store_name = f'live:{live_name}'
+        ## Filter empty geometries
+        gdf = gdf[gdf[gdf.geometry.name].notnull()]
+        if 'status' not in gdf.columns:
+            gdf['status'] = status
+        if 'popup' not in gdf.columns:
+            gdf['popup'] = 'Live: ' + live_name + ' #' + gdf.index.astype('U')
+        ## Reproject if needed
+        if len(gdf) > 0:
+            gdf = gdf.to_crs(conf.crs['geojson'])
+            gis_type = gdf.geom_type.iloc[0]
+        else:
+            gis_type = 'Point'  ## FIXME: cannot be inferred from the gdf?
+        mapbox_type = SHAPELY_TYPE_TO_MAPBOX_TYPE.get(gis_type, None)
+        if not mapbox_paint:
+            mapbox_paint = DEFAULT_MAPBOX_PAINT.get(mapbox_type, {})
+            if color:
+                if mapbox_type == 'symbol':
+                    mapbox_paint['text-color'] = color
+        if not mapbox_layout:
+            mapbox_layout = DEFAULT_MAPBOX_LAYOUT.get(mapbox_type, {})
+            if symbol:
+                mapbox_layout['text-field'] = symbol
+        if not symbol:
+            symbol = gisTypeSymbolMap.get(gis_type, '\ue02e')
+        if properties is None:
+            properties = []
+        ## Add a column for json representation
+        columns = {'status', 'popup', gdf.geometry.name, 'store', 'id'}
+        geojson = gdf[list(columns.intersection(gdf.columns).union(properties))].to_json()
+        ## Publish to websocket
+        await self.redis.publish(self.get_json_channel(store_name), geojson)
+        layer_def_data = dumps({
+            'store': store_name,
+            'z_index': z_index,
+            'count': len(gdf),
+            'mapbox_type': mapbox_type,
+            'gis_type': gis_type,
+            'symbol': symbol,
+            'name': live_name,
+            'description': description,
+            'viewable_role': viewable_role,
+            'attribution': attribution,
+            'is_live': True,
+        })
+        ## Pickle the dataframe
+        with BytesIO() as buf:
+            dump(gdf, buf, protocol=HIGHEST_PROTOCOL)
+            buf.seek(0)
+            #df_blob = buf.read()
+            await self.redis.set(self.get_gdf_channel(store_name), buf.read())
+
+        ## Save in Redis channels
+        await self.redis.set(self.get_json_channel(store_name), geojson)
+        await self.redis.set(
+            self.get_mapbox_layout_channel(store_name),
+            dumps(mapbox_layout)
+        )
+        await self.redis.set(self.get_mapbox_paint_channel(store_name), dumps(mapbox_paint))
+        await self.redis.set(self.get_layer_def_channel(store_name), layer_def_data)
+
+        ## Update the layers/stores registry
+        if hasattr(self, 'app'):
+            await self.get_live_layer_defs()
+
+        return geojson
+
+    async def get_listener(self, channel):
+        return await self.pub.psubscribe(channel)
+
+    async def remove_layer(self, store_name):
+        """
+        Remove the layer from Gisaf Live (channel)
+        """
+        await self.redis.delete(self.get_json_channel(store_name))
+        await self.redis.delete(self.get_layer_def_channel(store_name))
+        await self.redis.delete(self.get_gdf_channel(store_name))
+        await self.redis.delete(self.get_mapbox_layout_channel(store_name))
+        await self.redis.delete(self.get_mapbox_paint_channel(store_name))
+
+        ## Update the layers/stores registry
+        if hasattr(self, 'app'):
+            await self.get_live_layer_defs()
+
+    async def has_channel(self, store_name):
+        return len(await self.redis.keys(self.get_json_channel(store_name))) > 0
+
+    async def get_live_layer_def_channels(self):
+        try:
+            return [k.decode() for k in await self.redis.keys('live:*:layer_def')]
+        except aioredis.exceptions.ConnectionError as err:
+            raise RedisError('Cannot use Redis server, please restart Gisaf') from err
+
+    async def get_layer_def(self, store_name):
+        return loads(await self.redis.get(self.get_layer_def_channel(store_name)))
+
+    async def get_live_layer_defs(self) -> None:
+        registry.geom_live_defs = {}
+        for channel in sorted(await self.get_live_layer_def_channels()):
+            model_info = loads(await self.redis.get(channel))
+            registry.geom_live_defs[model_info['store']] = model_info
+        registry.update_live_layers()
+
+    async def get_mapbox_style(self, store_name):
+        """
+        Get the http headers (mapbox style) from the store name (layer_def)
+        """
+        paint = await self.redis.get(self.get_mapbox_paint_channel(store_name))
+        layout = await self.redis.get(self.get_mapbox_layout_channel(store_name))
+        style = {}
+        if paint is not None:
+            style['paint'] = paint.decode()
+        if layout is not None:
+            style['layout'] = layout.decode()
+        return style
+
+    async def get_layer_as_json(self, store_name):
+        """
+        Get the json from the store name (layer_def)
+        """
+        return await self.redis.get(self.get_json_channel(store_name))
+
+    async def get_gdf(self, store_name, reproject=False):
+        raw_data = await self.redis.get(self.get_gdf_channel(store_name))
+        if raw_data is None:
+            raise RedisError(f'Cannot get {store_name}: no data')
+        try:
+            gdf = loads_pickle(raw_data)
+        except Exception as err:
+            logger.exception(err)
+            raise RedisError(f'Cannot get {store_name}: pickle error from redis store: '
+                             f'{err.__class__.__name__}, {err.args[0]}')
+        if len(gdf) == 0:
+            raise RedisError(f'Cannot get {store_name}: empty')
+        if reproject:
+            gdf.to_crs(conf.crs['for_proj'], inplace=True)
+        return gdf
+
+    async def get_feature_info(self, store_name, id):
+        gdf = await self.get_gdf(store_name)
+        ## FIXME: requires the gdf to have an integer index, used as feature['id'] on the map
+        return gdf.loc[int(id)]
+
+    async def set_ttag(self, store_name, now):
+        """
+        Set the ttag for the store as 'now'
+        """
+        #logger.debug(f'ttag {store_name} at {now}')
+        await self.redis.set(f'ttag:{store_name}', now)
+
+    def create_task_store_ttag(self, connection, pid, channel, store_name):
+        """
+        Postgres/asyncpg listener for the trigger on data change.
+        A task is created because this function is not asynchronous.
+        """
+        create_task(self.set_ttag(store_name, time()))
+
+    async def get_ttag(self, store_name):
+        """
+        Get the ttag for the given store.
+        ttag is the time stamp of the last modification of the store.
+        If no ttag is known, create one as now.
+        """
+        ttag = await self.redis.get(f'ttag:{store_name}')
+        if ttag:
+            return ttag.decode()
+        else:
+            ## No ttag: Gisaf doesn't know when the last update happened,
+            ## ie. it was restarted and the ttags are cleared on startup.
+            ## Set a ttag now, using the current epoch time in seconds in hex,
+            ## double quoted, with a W/ prefix as it's basically a weak ETag
+            weak_now_hex = f'W/"{hex(int(time()))[2:]}"'
+            await self.set_ttag(store_name, weak_now_hex)
+            return weak_now_hex
+
+    async def delete_all_ttags(self):
+        """
+        Delete all ttags in redis
+        """
+        ## Equivalent command line: redis-cli del (redis-cli --scan --pattern 'ttag:*')
+        keys = await self.redis.keys('ttag:*')
+        if keys:
+            await self.redis.delete(*keys)
+
+    async def _setup_db_cache_system(self):
+        """
+        Setup the caching system:
+        - clear all Redis store at startup
+        - make sure the triggers and the "change" (insert, update, delete) event emitter
+          function are set up on the database server
+        - listen to the DB event emitter: setup a callback function
+        """
+        ## Setup the function and triggers on tables
+        ## XXX: still uses the aiohttp-style app['db']; to be ported to FastAPI
+        db = self.app['db']
+
+        ## Keep the connection alive: don't use a "with" block
+        ## It needs to be closed correctly: see _close_permanent_db_connection
+        self._permanent_conn = await db.acquire()
+        self._permanent_raw_conn = await self._permanent_conn.get_raw_connection()
+
+        ## Create the function in the database
+        await self._permanent_raw_conn.execute(ttag_function)
+
+        ## Delete all the ttags, for safety
+        ## eg. the database was changed and Gisaf wasn't running, so the redis store wasn't updated
+        await self.delete_all_ttags()
+
+        ## Create DB triggers on the tables of the models
+        all_triggers = await self._permanent_raw_conn.fetch(get_all_triggers)
+        stores_with_trigger = {t['trigger_table'] for t in all_triggers
+                               if t['trigger_name'] == 'gisaf_ttag'}
+        missing_trigger_tables = set(registry.geom).difference(stores_with_trigger)
+        if len(missing_trigger_tables) > 0:
+            logger.info(f'Create Postgres modification triggers for {len(missing_trigger_tables)} tables')
+            for store_name in missing_trigger_tables:
+                model = registry.geom[store_name]
+                try:
+                    await self._permanent_raw_conn.execute(ttag_create_trigger.format(
+                        schema=model.__table__.schema, table=model.__table__.name))
+                except UndefinedTableError:
+                    logger.warning(f'table {store_name} does not exist in the database: skip modification trigger')
+        ## Setup triggers on Category and Qml, for Mapbox layer styling
+        for schema, table in (('gisaf_map', 'qml'), ('gisaf_survey', 'category')):
+            triggers = [t for t in all_triggers
+                        if t['trigger_name'] == 'gisaf_ttag' and t['trigger_table'] == f'{schema}.{table}']
+            if len(triggers) == 0:
+                await self._permanent_raw_conn.execute(ttag_create_trigger.format(schema=schema, table=table))
+
+        ## Listen: define the callback function
+        await self._permanent_raw_conn.add_listener('gisaf_ttag', self.create_task_store_ttag)
+
+    async def _close_permanent_db_connection(self):
+        """
+        Called at server shutdown: remove the listener and close the connections
+        """
+        try:
+            await self._permanent_raw_conn.remove_listener(
+                'gisaf_ttag', self.create_task_store_ttag)
+        except InterfaceError as err:
+            logger.warning(f'Cannot remove asyncpg listener in _close_permanent_db_connection: {err}')
+        await self._permanent_raw_conn.close()
+        await self._permanent_conn.release()
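+
+## Example (a hedged sketch, not part of the API): publishing a live layer
+## through this store from a data provider. It assumes the application has
+## started (store.setup ran); the layer name and dataframe are illustrative.
+##
+##     import geopandas as gpd
+##     from shapely.geometry import Point
+##     from gisaf.redis_tools import store
+##
+##     gdf = gpd.GeoDataFrame({'geom': [Point(12.01, 79.81)]},
+##                            geometry='geom', crs='EPSG:4326')
+##     ## Stores the pickled gdf, geojson and layer_def in redis,
+##     ## and publishes the geojson to the live:vehicle_tracking:json channel
+##     await store.publish_gdf('vehicle_tracking', gdf)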
+
+
+async def setup_redis(app):
+    global store
+    await store.setup(app)
+
+
+async def setup_redis_cache(app):
+    global store
+    await store._setup_db_cache_system()
+
+
+async def shutdown_redis(app):
+    global store
+    await store._close_permanent_db_connection()
+
+
+store = Store()
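The ttag stored per layer doubles as the HTTP cache validator served by the geoapi. A worked example of the weak ETag that `Store.get_ttag` mints when no timestamp is known yet:

```python
# Same expression as in Store.get_ttag: epoch seconds, in hex,
# double quoted, with a W/ prefix (a weak ETag).
from time import time

weak_now_hex = f'W/"{hex(int(time()))[2:]}"'
print(weak_now_hex)  # e.g. W/"657a1b2f"
```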
diff --git a/src/gisaf/registry.py b/src/gisaf/registry.py
index 8d323a1..67a0e0d 100644
--- a/src/gisaf/registry.py
+++ b/src/gisaf/registry.py
@@ -6,6 +6,7 @@ import importlib
 import pkgutil
 from collections import defaultdict
 from importlib.metadata import entry_points
+from typing import Any, ClassVar
 
 from pydantic import create_model
 from sqlalchemy import inspect, text
@@ -19,6 +20,7 @@ from .config import conf
 from .models import (misc, category as category_module, project, reconcile,
                      map_bases, tags)
 from .models.geo_models_base import (
+    LiveGeoModel,
     PlottableModel,
     GeoModel,
     RawSurveyBaseModel,
@@ -68,6 +70,8 @@ class ModelRegistry:
     Maintains registries for all kind of model types, eg. geom, data, values...
     Provides tools to get the models from their names, table names, etc.
     """
+    stores: pd.DataFrame
+
     def __init__(self):
         """
         Get geo models
@@ -75,6 +79,8 @@ class ModelRegistry:
         """
         self.geom_custom = {}
         self.geom_custom_store = {}
+        self.geom_live: dict[str, LiveGeoModel] = {}
+        self.geom_live_defs: dict[str, dict[str, Any]] = {}
         self.values = {}
         self.other = {}
         self.misc = {}
@@ -125,12 +131,11 @@ class ModelRegistry:
             raw_store_name = f'{raw_survey.schema}.RAW_{category.table_name}'
             raw_survey_field_definitions = {
                 ## FIXME: RawSurveyBaseModel.category should be a Category, not category.name
-                'category_name': (str, category.name),
+                'category_name': (ClassVar[str], category.name),
                 ## FIXME: Same for RawSurveyBaseModel.group
-                'group_name': (str, category.category_group.name),
-                'viewable_role': (str, category.viewable_role),
-                'store_name': (str, raw_store_name),
-                # 'icon': (str, ''),
+                'group_name': (ClassVar[str], category.category_group.name),
+                'viewable_role': (ClassVar[str], category.viewable_role),
+                'store_name': (ClassVar[str], raw_store_name),  # 'icon': (str, ''),
             }
 
             ## Raw survey points
@@ -142,14 +147,14 @@ class ModelRegistry:
                         'table': True,
                         'metadata': raw_survey,
                         '__tablename__': category.raw_survey_table_name,
-                        ## FIXME: RawSurveyBaseModel.category should be a Category, not category.name
-                        'category_name': category.name,
-                        ## FIXME: Same for RawSurveyBaseModel.group
-                        'group_name': category.category_group.name,
-                        'viewable_role': category.viewable_role,
-                        'store_name': raw_store_name,
+                        # ## FIXME: RawSurveyBaseModel.category should be a Category, not category.name
+                        # 'category_name': category.name,
+                        # ## FIXME: Same for RawSurveyBaseModel.group
+                        # 'group_name': category.category_group.name,
+                        # 'viewable_role': category.viewable_role,
+                        # 'store_name': raw_store_name,
                     },
-                    # **raw_survey_field_definitions
+                    **raw_survey_field_definitions
                 )
             except Exception as err:
                 logger.exception(err)
@@ -162,11 +167,11 @@ class ModelRegistry:
             try:
                 if model_class:
                     survey_field_definitions = {
-                        'category_name': (str, category.name),
-                        'group_name': (str, category.category_group.name),
-                        'raw_store_name': (str, raw_store_name),
-                        'viewable_role': (str, category.viewable_role),
-                        'symbol': (str, category.symbol),
+                        'category_name': (ClassVar[str], category.name),
+                        'group_name': (ClassVar[str], category.category_group.name),
+                        'raw_store_name': (ClassVar[str], raw_store_name),
+                        'viewable_role': (ClassVar[str], category.viewable_role),
+                        'symbol': (ClassVar[str], category.symbol),
                         #'raw_model': (str, self.raw_survey_models.get(raw_store_name)),
                         # 'icon': (str, f'{survey.schema}-{category.table_name}'),
                     }
@@ -177,13 +182,13 @@ class ModelRegistry:
                             'table': True,
                             'metadata': survey,
                             '__tablename__': category.table_name,
-                            'category_name': category.name,
-                            'group_name': category.category_group.name,
-                            'raw_store_name': raw_store_name,
-                            'viewable_role': category.viewable_role,
-                            'symbol': category.symbol,
+                            # 'category_name': category.name,
+                            # 'group_name': category.category_group.name,
+                            # 'raw_store_name': raw_store_name,
+                            # 'viewable_role': category.viewable_role,
+                            # 'symbol': category.symbol,
                         },
-                        # **survey_field_definitions,
+                        **survey_field_definitions,
                     )
             except Exception as err:
                 logger.warning(err)
@@ -519,13 +524,15 @@ class ModelRegistry:
                 row.model.mapbox_type,  # or None,
                 row.model.base_gis_type,
                 row.model.z_index,
+                row.model.attribution,
             )
 
         # self.stores['icon'],\
         # self.stores['symbol'],\
-        self.stores['mapbox_type_default'],\
-        self.stores['base_gis_type'],\
-        self.stores['z_index']\
+        self.stores['mapbox_type_default'], \
+        self.stores['base_gis_type'], \
+        self.stores['z_index'], \
+        self.stores['attribution'] \
             = zip(*self.stores.apply(fill_columns_from_model, axis=1))
 
         #self.stores['mapbox_type_custom'] = self.stores['mapbox_type_custom'].replace('', np.nan).fillna(np.nan)
@@ -614,27 +621,61 @@ class ModelRegistry:
             # return store_df.gql_object_type.to_list()
 
-    #def update_live_layers(self, live_models: List[GeomModel]):
-        #raise ToMigrate('make_model_gql_object_type')
-    def update_live_layers(self, live_models):
+    def update_live_layers(self):
         """
-        Update the live layers in the registry, using the provided list of GeomModel
+        Update the live layers, using the list of model definitions found in
+        self.geom_live_defs, which is normally updated by the redis store
         """
         ## Remove existing live layers
-        self.stores.drop(self.stores[self.stores.is_live==True].index, inplace=True)
-
-        ## Add provided live layers
-        ## Ideally, should be vectorized
-        for model in live_models:
-            self.stores.loc[model.store] = {
-                'description': model.description,
-                'group': model.group,
-                'name': model.name,
-                'gql_object_type': model,
-                'is_live': True,
-                'is_db': False,
-                'custom': True,
-            }
+        self.geom_live = {}
+        self.stores.drop(self.stores[self.stores.is_live == True].index,  # noqa: E712
+                         inplace=True)
+        df_live = pd.DataFrame.from_dict(self.geom_live_defs.values(),
+                                         orient='columns'
+                                         ).set_index('store')
+        ## Adjust column names
+        ## and add columns, to make sure pandas dtypes are not changed when the
+        ## dataframes are concat
+        ## TODO: standardize names across the whole workflow,
+        ##       then remove the rename below:
+        df_live.rename(
+            columns={
+                'live': 'is_live',
+                'zIndex': 'z_index',
+                'gisType': 'model_type',
+                'type': 'mapbox_type',
+                'viewableRole': 'viewable_role',
+            }, inplace=True
+        )
+        ## Add columns
+        df_live['auto_import'] = False
+        df_live['base_gis_type'] = df_live['model_type']
+        df_live['custom'] = False
+        df_live['group'] = ''
+        df_live['in_menu'] = True
+        df_live['is_db'] = False
+        df_live['is_line_work'] = False
+        df_live['long_name'] = df_live['name']
+        df_live['mapbox_type_custom'] = df_live['mapbox_type']
+        df_live['minor_group_1'] = ''
+        df_live['minor_group_2'] = ''
+        df_live['status'] = 'E'
+        df_live['style'] = None
+        df_live['title'] = df_live['name']
+        self.stores = pd.concat([self.stores, df_live])
+        ## Add the provided live layers in the stores df
+        for store, model_info in self.geom_live_defs.items():
+            # Create the pydantic model
+            # NOTE: Unused at this point, but might be useful
+            field_definitions = {
+                k: (ClassVar[v.__class__], v)
+                for k, v in model_info.items()
+            }
+            self.geom_live[store] = create_model(
+                __model_name=store,
+                __base__=LiveGeoModel,
+                **field_definitions
+            )
 
 # Accessible as global registry:
 registry = ModelRegistry()
diff --git a/tests/__init__.py b/tests/__init__.py
deleted file mode 100644
index e69de29..0000000
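This changeset also drops the empty tests/__init__.py. A hedged sketch of a minimal smoke test for the new mounts; it is an outline only, assuming pytest with an anyio runner, the asgi-lifespan package, and reachable database and redis backends so the lifespan hooks can run:

```python
# Illustrative smoke test for the /v2 and /gj mounts wired up in
# application.py; LifespanManager runs make_registry/setup_redis,
# which httpx's ASGITransport does not do by itself.
import pytest
from asgi_lifespan import LifespanManager
from httpx import ASGITransport, AsyncClient

from gisaf.application import app

@pytest.mark.anyio
async def test_geojson_unknown_store_returns_404():
    async with LifespanManager(app):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url='http://test') as client:
            resp = await client.get('/gj/no_such_store')
            assert resp.status_code == 404
```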