Add live (redis and websockets)
Add modernised ipynb_tools
Add scheduler
Fix crs in settings
Lots of small fixes

parent 461c31fb6f
commit 47df53f4d1
15 changed files with 1614 additions and 61 deletions
@@ -1 +1 @@
__version__ = '2023.4.dev4+g049b8c9.d20231213'
__version__ = '2023.4.dev7+g461c31f.d20231218'
@@ -12,9 +12,9 @@ from .geoapi import api as geoapi
from .config import conf
from .registry import registry, ModelRegistry
from .redis_tools import setup_redis, shutdown_redis, setup_redis_cache
from .live import setup_live

logging.basicConfig(level=conf.gisaf.debugLevel)

logger = logging.getLogger(__name__)

## Subclass FastAPI to add attributes to be used globally, ie. registry

@@ -29,9 +29,11 @@ class GisafFastAPI(FastAPI):

@asynccontextmanager
async def lifespan(app: FastAPI):
    await registry.make_registry(app)
    await setup_redis(app)
    await registry.make_registry()
    await setup_redis()
    await setup_live()
    yield
    await shutdown_redis()

app = FastAPI(
    debug=False,
@@ -221,6 +221,16 @@ class Job(BaseSettings):
    minutes: int | None = 0
    seconds: int | None = 0

class Crs(BaseSettings):
    '''
    Handy definitions for crs-es
    '''
    db: str
    geojson: str
    for_proj: str
    survey: str
    web_mercator: str

class Config(BaseSettings):
    model_config = SettingsConfigDict(
        #env_prefix='gisaf_',

@@ -238,9 +248,20 @@ class Config(BaseSettings):
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        return env_settings, init_settings, file_secret_settings, config_file_settings

    # def __init__(self, **kwargs):
    #     super().__init__(**kwargs)
    #     self.crs = {
    #         'db': f'epsg:{conf.srid}',
    #         'geojson': f'epsg:{conf.geojson_srid}',
    #         'for_proj': f'epsg:{conf.srid_for_proj}',
    #         'survey': f'epsg:{conf.raw_survey_srid}',
    #         'web_mercator': 'epsg:3857',
    #     }

    admin: Admin
    attachments: Attachments
    basket: BasketOldDef
    # crs: Crs
    crypto: Crypto
    dashboard: Dashboard
    db: DB

@@ -261,6 +282,15 @@ class Config(BaseSettings):
    #engine: AsyncEngine
    #session_maker: sessionmaker

    @property
    def crs(self) -> Crs:
        return Crs(
            db=f'epsg:{self.geo.srid}',
            geojson=f'epsg:{self.geo.srid}',
            for_proj=f'epsg:{self.geo.srid_for_proj}',
            survey=f'epsg:{self.geo.raw_survey.srid}',
            web_mercator='epsg:3857',
        )

def config_file_settings() -> dict[str, Any]:
    config: dict[str, Any] = {}
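For context, a short sketch (not part of the commit) of how the new Config.crs property is consumed; attribute access replaces the old dict lookup, as the redis_tools hunk further down shows. The function name and variable names here are illustrative only.

# Illustrative only: reproject a GeoDataFrame using the new Crs settings object.
import geopandas as gpd
from gisaf.config import conf

def reproject_for_geojson(gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
    # conf.crs.geojson is an 'epsg:<srid>' string built by the crs property above
    return gdf.to_crs(conf.crs.geojson)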
@@ -2,14 +2,16 @@
Geographical json stores, served under /gj
Used for displaying features on maps
"""
from json import JSONDecodeError
import logging
from typing import Annotated
from asyncio import CancelledError

from fastapi import FastAPI, HTTPException, Response, status, responses, Header
from fastapi import (FastAPI, HTTPException, Response, Header, WebSocket, WebSocketDisconnect,
                     status, responses)

from .redis_tools import store as redis_store
# from gisaf.live import live_server
from .live import live_server
from .registry import registry


@@ -19,28 +21,54 @@ api = FastAPI(
    default_response_class=responses.ORJSONResponse,
)

@api.get('/live/{store}')
async def live_layer(store: str):
class ConnectionManager:
    active_connections: list[WebSocket]
    def __init__(self):
        self.active_connections = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket):
        self.active_connections.remove(websocket)

    async def send_personal_message(self, message: str, websocket: WebSocket):
        await websocket.send_text(message)

    async def broadcast(self, message: str):
        for connection in self.active_connections:
            await connection.send_text(message)

manager = ConnectionManager()

@api.websocket('/live/{store}')
async def live_layer(store: str, websocket: WebSocket):
    """
    Websocket for live layer updates
    """
    ws = web.WebSocketResponse()
    await ws.prepare(request)
    async for msg in ws:
        if msg.type == WSMsgType.TEXT:
            if msg.data == 'close':
                await ws.close()
    await websocket.accept()
    try:
        while True:
            try:
                msg_data = await websocket.receive_json()
            except JSONDecodeError:
                msg_text = await websocket.receive_text()
                if msg_text == 'close':
                    await websocket.close()
                continue
            # else:
            if 'message' in msg_data:
                if msg_data['message'] == 'subscribeLiveLayer':
                    live_server.add_subscription(websocket, store)
                elif msg_data['message'] == 'unsubscribeLiveLayer':
                    live_server.remove_subscription(websocket, store)
            else:
                msg_data = msg.json()
                if 'message' in msg_data:
                    if msg_data['message'] == 'subscribeLiveLayer':
                        live_server.add_subscription(ws, store)
                    elif msg_data['message'] == 'unsubscribeLiveLayer':
                        live_server.remove_subscription(ws, store)
        elif msg.type == WSMsgType.ERROR:
            logger.exception(ws.exception())
    logger.debug('websocket connection closed')
    return ws
                logger.warning(f'Got websocket message with no message field: {msg_data}')
    except WebSocketDisconnect:
        logger.debug('Websocket disconnected')

    # logger.debug('websocket connection closed')

@api.get('/{store_name}')
async def get_geojson(store_name,
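A minimal client sketch (not part of the commit) of the subscribe handshake this websocket endpoint expects. The /gj prefix and the message fields come from the code above; the host, port and the choice of the third-party websockets client library are assumptions.

# Illustrative client for the live layer websocket, assuming the geoapi is mounted under /gj.
import asyncio
import json
import websockets  # any websocket client library would do

async def watch_live_layer(store: str):
    async with websockets.connect(f'ws://localhost:8000/gj/live/{store}') as ws:
        # The endpoint expects a JSON message with a 'message' field
        await ws.send(json.dumps({'message': 'subscribeLiveLayer'}))
        while True:
            print(await ws.recv())  # GeoJSON updates fanned out from Redis

asyncio.run(watch_live_layer('my_live_store'))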
src/gisaf/ipynb_tools.py (new file, 357 lines)
@@ -0,0 +1,357 @@
"""
Utility functions for Jupyter/iPython notebooks
Usage from a notebook:
from gisaf.ipynb_tools import registry
"""

import logging
from urllib.error import URLError
from datetime import datetime
from io import BytesIO
from pickle import dump, HIGHEST_PROTOCOL
# from aiohttp import ClientSession, MultipartWriter

import pandas as pd
import geopandas as gpd

from geoalchemy2 import WKTElement
# from geoalchemy2.shape import from_shape

from sqlalchemy import create_engine

# from shapely import wkb

from .config import conf
from .redis_tools import store as redis_store
from .live import live_server
from .registry import registry

## For base maps: contextily
try:
    import contextily as ctx
except ImportError:
    ctx = None

logger = logging.getLogger('Gisaf tools')


class Notebook:
    """
    Proof of concept? Gisaf could control notebook execution.
    """
    def __init__(self, path: str):
        self.path = path


class Gisaf:
    """
    Gisaf tool for ipython/Jupyter notebooks
    """
    def __init__(self):
        # self.db = db
        self.conf = conf
        self.store = redis_store
        self.live_server = live_server
        if ctx:
            ## Contextily newer version deprecated ctx.sources
            self.basemaps = ctx.providers
        else:
            self.basemaps = None

    async def setup(self, with_mqtt=False):
        await self.store.create_connections()
        if with_mqtt:
            logger.warning('Gisaf live_server does not support with_mqtt anymore: ignoring')
        try:
            await self.live_server.setup()
        except Exception as err:
            logger.warn(f'Cannot setup live_server: {err}')
            logger.exception(err)

    async def make_models(self, **kwargs):
        """
        Populate the model registry.
        By default, all models will be added, including those defined in categories (full registry).
        Set with_categories=False to skip them and speed up the registry initialization.
        :return:
        """
        await registry.make_registry()
        if 'with_categories' in kwargs:
            logger.warning(f'{self.__class__}.make_models() does not support argument with_categories anymore')
        self.registry = registry
        ## TODO: Compatibility: mark "models" deprecated, replaced by "registry"
        # self.models = registry

    def get_layer_list(self):
        """
        Get a list of the names of all layers (ie. models with a geometry).
        See get_all_geo for fetching data for a layer.
        :return: list of strings
        """
        return self.registry.geom.keys()

    async def get_query(self, query):
        """
        Return a dataframe for the query
        """
        async with query.bind.raw_pool.acquire() as conn:
            compiled = query.compile()
            columns = [a.name for a in compiled.statement.columns]
            stmt = await conn.prepare(compiled.string)
            data = await stmt.fetch(*[compiled.params.get(param) for param in compiled.positiontup])
            return pd.DataFrame(data, columns=columns)

    async def get_all(self, model, **kwargs):
        """
        Return a dataframe with all records for the model
        """
        return await self.get_query(model.query)

    async def set_dashboard(self, name, group,
                            notebook=None,
                            description=None,
                            html=None,
                            plot=None,
                            df=None,
                            attached=None,
                            expanded_panes=None,
                            sections=None):
        """
        Add or update a dashboard page in Gisaf
        :param name: name of the dashboard page
        :param group: name of the group (level directory)
        :param notebook: name of the notebook, to be registered for future use
        :param description:
        :param attached: a matplotlib/pyplot plot, etc
        :param sections: a list of DashboardPageSection
        :return:
        """
        from gisaf.models.dashboard import DashboardPage, DashboardPageSection

        expanded_panes = expanded_panes or []
        sections = sections or []
        now = datetime.now()
        if not description:
            description = 'Dashboard {}/{}'.format(group, name)
        if df is not None:
            with BytesIO() as buf:
                ## Don't use df.to_pickle as it closes the buffer (as per pandas==0.25.1)
                dump(df, buf, protocol=HIGHEST_PROTOCOL)
                buf.seek(0)
                df_blob = buf.read()
        else:
            df_blob = None

        if plot is not None:
            with BytesIO() as buf:
                dump(plot, buf)
                buf.seek(0)
                plot_blob = buf.read()
        else:
            plot_blob = None

        page = await DashboardPage.query.where((DashboardPage.name==name) & (DashboardPage.group==group)).gino.first()
        if not page:
            page = DashboardPage(
                name=name,
                group=group,
                description=description,
                notebook=notebook,
                time=now,
                df=df_blob,
                plot=plot_blob,
                html=html,
                expanded_panes=','.join(expanded_panes)
            )
            if attached:
                page.attachment = page.save_attachment(attached, name=name)
            await page.create()
        else:
            if attached:
                page.attachment = page.save_attachment(attached)
            await page.update(
                description=description,
                notebook=notebook,
                html=html,
                attachment=page.attachment,
                time=now,
                df=df_blob,
                plot=plot_blob,
                expanded_panes=','.join(expanded_panes)
            ).apply()

        for section in sections:
            #print(section)
            section.page = page
            ## Replace section.plot (matplotlib plot or figure)
            ## by the name of the rendered pic in the filesystem
            section.plot = section.save_plot(section.plot)
            section_record = await DashboardPageSection.query.where(
                (DashboardPageSection.dashboard_page_id==page.id) & (DashboardPageSection.name==section.name)
            ).gino.first()
            if not section_record:
                section.dashboard_page_id = page.id
                await section.create()
            else:
                logger.warn('TODO: set_dashboard section update')
        logger.warn('TODO: set_dashboard section remove')


    async def set_widget(self, name, title, subtitle, content, notebook=None):
        """
        Create a web widget, that is served by /embed/<name>.
        """
        from gisaf.models.dashboard import Widget
        now = datetime.now()
        widget = await Widget.query.where(Widget.name==name).gino.first()
        kwargs = dict(
            title=title,
            subtitle=subtitle,
            content=content,
            notebook=notebook,
            time=now,
        )
        if widget:
            await widget.update(**kwargs).apply()
        else:
            await Widget(name=name, **kwargs).create()

    async def to_live_layer(self, gdf, channel, mapbox_paint=None, mapbox_layout=None, properties=None):
        """
        Send a geodataframe to a gisaf server with an HTTP POST request for live map display
        """
        with BytesIO() as buf:
            dump(gdf, buf, protocol=HIGHEST_PROTOCOL)
            buf.seek(0)

            async with ClientSession() as session:
                with MultipartWriter('mixed') as mpwriter:
                    mpwriter.append(buf)
                    if mapbox_paint != None:
                        mpwriter.append_json(mapbox_paint, {'name': 'mapbox_paint'})
                    if mapbox_layout != None:
                        mpwriter.append_json(mapbox_layout, {'name': 'mapbox_layout'})
                    if properties != None:
                        mpwriter.append_json(properties, {'name': 'properties'})
                async with session.post('{}://{}:{}/api/live/{}'.format(
                        self.conf.gisaf_live['scheme'],
                        self.conf.gisaf_live['hostname'],
                        self.conf.gisaf_live['port'],
                        channel,
                        ), data=mpwriter) as resp:
                    return await resp.text()

    async def remove_live_layer(self, channel):
        """
        Remove the channel from Gisaf Live
        """
        async with ClientSession() as session:
            async with session.get('{}://{}:{}/api/remove-live/{}'.format(
                    self.conf.gisaf_live['scheme'],
                    self.conf.gisaf_live['hostname'],
                    self.conf.gisaf_live['port'],
                    channel
                    )) as resp:
                return await resp.text()

    def to_layer(self, gdf: gpd.GeoDataFrame, model, project_id=None,
                 skip_columns=None, replace_all=True,
                 chunksize=100):
        """
        Save the geodataframe gdf to the Gisaf model, using pandas' to_sql dataframes' method.
        Note that it's NOT an async call. Explanations:
        * to_sql doesn't seem to work with gino/asyncpg
        * using Gisaf models is a few orders of magnitude slower
        (the async code using this technique is left commented out, for reference)
        """
        if skip_columns == None:
            skip_columns = []

        ## Filter empty geometries, and reproject
        _gdf: gpd.GeoDataFrame = gdf[~gdf.geometry.is_empty].to_crs(self.conf.crs['geojson'])

        ## Remove the empty geometries
        _gdf.dropna(inplace=True, subset=['geometry'])
        #_gdf['geom'] = _gdf.geom1.apply(lambda geom: from_shape(geom, srid=self.conf.srid))

        for col in skip_columns:
            if col in _gdf.columns:
                _gdf.drop(columns=[col], inplace=True)

        _gdf['geom'] = _gdf['geometry'].apply(lambda geom: WKTElement(geom.wkt, srid=self.conf.srid))
        _gdf.drop(columns=['geometry'], inplace=True)

        engine = create_engine(self.conf.db['uri'], echo=False)

        ## Drop existing
        if replace_all:
            engine.execute('DELETE FROM "{}"."{}"'.format(model.__table_args__['schema'], model.__tablename__))
        else:
            raise NotImplementedError('ipynb_tools.Gisaf.to_layer does not support updates yet')

        ## See https://stackoverflow.com/questions/38361336/write-geodataframe-into-sql-database
        # Use 'dtype' to specify column's type
        _gdf.to_sql(
            name=model.__tablename__,
            con=engine,
            schema=model.__table_args__['schema'],
            if_exists='append',
            index=False,
            dtype={
                'geom': model.geom.type,
            },
            method='multi',
            chunksize=chunksize,
        )

        #async with self.db.transaction() as tx:
        # if replace_all:
        # await model.delete.gino.status()
        # else:
        # raise NotImplementedError('ipynb_tools.Gisaf.to_layer does not support updates yet')
        # if not skip_columns:
        # skip_columns = ['x', 'y', 'z', 'coords']

        # ## Reproject
        # ggdf = gdf.to_crs(self.conf.crs['geojson'])

        # ## Remove the empty geometries
        # ggdf.dropna(inplace=True)
        # #ggdf['geom'] = ggdf.geom1.apply(lambda geom: from_shape(geom, srid=self.conf.srid))

        # for col in skip_columns:
        # if col in ggdf.columns:
        # ggdf.drop(columns=[col], inplace=True)

        # #ggdf.set_geometry('geom', inplace=True)

        # if project_id:
        # ggdf['project_id'] = project_id
        # ## XXX: index?
        # gdf_dict = ggdf.to_dict(orient='records')

        # gdf_dict_2 = []
        # for row in gdf_dict:
        # geometry = row.pop('geometry')
        # if not geometry.is_empty:
        # row['geom'] = str(from_shape(geometry, srid=self.conf.srid))
        # gdf_dict_2.append(row)

        # result = await model.insert().gino.all(*gdf_dict_2)

        # return

        # for row in gdf_dict:
        # if 'id' in row:
        # ## TODO: Existing id: can use merge
        # ex_item = await model.get(item['id'])
        # await ex_item.update(**row)
        # else:
        # geometry = row.pop('geometry')
        # if not geometry.is_empty:
        # feature = model(**row)
        # feature.geom = from_shape(geometry, srid=self.conf.srid)
        # await feature.create()
        # #db.session.commit()

gisaf = Gisaf()
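A hedged sketch of how the gisaf singleton defined above is typically used from a notebook cell (top-level await works in Jupyter/IPython); the layer output is whatever the registry contains, the names here are not from the commit.

# Illustrative notebook usage of the gisaf singleton
from gisaf.ipynb_tools import gisaf

await gisaf.setup()           # connect to Redis and set up the live server
await gisaf.make_models()     # populate the model registry
print(list(gisaf.get_layer_list()))  # names of all geo layers (models with a geometry)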
src/gisaf/live.py (new file, 81 lines)
@@ -0,0 +1,81 @@
import asyncio
import logging
from collections import defaultdict

from fastapi import FastAPI, WebSocket, WebSocketDisconnect

# from .config import conf
from .redis_tools import store

logger = logging.getLogger(__name__)

class LiveServer:
    def __init__(self):
        self.ws_clients = defaultdict(set)

    async def setup(self, listen_to_redis=False, with_mqtt=False):
        """
        Setup for the live server
        """
        if with_mqtt:
            logger.warning('Gisaf LiveServer does not support with_mqtt: ignoring')
        if listen_to_redis:
            self.pub = store.redis.pubsub()
            await self.pub.psubscribe('live:*:json')
            asyncio.create_task(self._listen_to_redis())

    async def _listen_to_redis(self):
        """
        Subscribe the redis sub channel to all data ("live:*:json"),
        and send the messages to websockets
        """
        async for msg in self.pub.listen():
            if msg['type'] == 'pmessage':
                await self._send_to_ws_clients(msg['channel'].decode(),
                                               msg['data'].decode())

    async def _send_to_ws_clients(self, store_name, json_data):
        """
        Send the json_data to the websockets which have subscribed
        to that channel (store_name)
        """
        if len(self.ws_clients[store_name]) > 0:
            logger.debug(f'WS channel {store_name} got {len(json_data)} bytes to send to:'
                         f' {", ".join([str(id(ws)) for ws in self.ws_clients[store_name]])}')
            for ws in self.ws_clients[store_name]:
                if ws.client_state.name != 'CONNECTED':
                    logger.debug(f'Cannot send {store_name} for WS {id(ws)}, state: {ws.client_state.name}')
                    continue
                try:
                    await ws.send_text(json_data)
                    logger.debug(f'Sent live update for WS {id(ws)}: {len(json_data)}')
                except RuntimeError as err:
                    ## The ws is probably closed, remove it from the clients
                    logger.debug(f'Cannot send live update for {store_name}: {err}')
                    del self.ws_clients[store_name]
        else:
            pass
            #logger.debug(f'WS channel {store_name} has no clients')

    def add_subscription(self, ws, store_name):
        """
        Add the websocket subscription to the layer
        """
        channel = store.get_json_channel(store_name)
        logger.debug(f'WS {id(ws)} subscribed to {channel}')
        self.ws_clients[channel].add(ws)

    def remove_subscription(self, ws, store_name):
        """
        Remove the websocket subscription to the layer
        """
        channel = store.get_json_channel(store_name)
        if ws in self.ws_clients[channel]:
            self.ws_clients[channel].remove(ws)


async def setup_live():
    global live_server
    await live_server.setup(listen_to_redis=True)

live_server = LiveServer()
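For illustration only: anything published on a Redis channel matching the live:*:json pattern above is fanned out by LiveServer._listen_to_redis to the subscribed websockets. The channel name below assumes the live:<store>:json convention; the exact name produced by store.get_json_channel is not shown in this commit.

# Illustrative publisher reaching LiveServer through Redis pub/sub
import asyncio
import json
from redis import asyncio as aioredis

async def push_update(geojson: dict):
    redis = aioredis.from_url('redis://localhost')
    # Channel name assumed to follow the live:<store>:json pattern used above
    await redis.publish('live:my_sensor:json', json.dumps(geojson))

asyncio.run(push_update({'type': 'FeatureCollection', 'features': []}))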
src/gisaf/models/live.py (new file, 31 lines)
@@ -0,0 +1,31 @@
# from pydantic import BaseModel, Field
# from .geo_models_base import GeoModel

# class LiveModel(GeoModel):
#     attribution: str | None = None
#     # auto_import:
#     category: str | None = None
#     count: int
#     custom: bool = False
#     description: str
#     gisType: str
#     group: str
#     icon: str | None = None
#     is_db: bool = True
#     is_live: bool
#     name: str
#     rawSurveyStore: str | None = None
#     store: str
#     style: str | None = None
#     symbol: str
#     tagPlugins: list[str] = []
#     type: str
#     viewableRole: str | None = None
#     z_index: int = Field(..., alias='zIndex')


# class GeomGroup(BaseModel):
#     name: str
#     title: str
#     description: str
#     models: list[GeoModel]
@@ -90,14 +90,12 @@ class Store:
    - redis: RedisConnection
    - pub (/sub) connections
    """
    async def setup(self, app):
    async def setup(self):
        """
        Setup the live service for the main Gisaf application:
        - Create connection for the publishers
        - Create connection for redis listeners (websocket service)
        """
        self.app = app
        app.extra['store'] = self
        await self.create_connections()
        await self.get_live_layer_defs()

@@ -187,7 +185,7 @@ class Store:
        if 'popup' not in gdf.columns:
            gdf['popup'] = 'Live: ' + live_name + ' #' + gdf.index.astype('U')
        if len(gdf) > 0:
            gdf = gdf.to_crs(conf.crs['geojson'])
            gdf = gdf.to_crs(conf.crs.geojson)
            gis_type = gdf.geom_type.iloc[0]
        else:
            gis_type = 'Point' ## FIXME: cannot be inferred from the gdf?

@@ -240,8 +238,7 @@ class Store:
        await self.redis.set(self.get_layer_def_channel(store_name), layer_def_data)

        ## Update the layers/stores registry
        if hasattr(self, 'app'):
            await self.get_live_layer_defs()
        await self.get_live_layer_defs()

        return geojson

@@ -259,8 +256,7 @@ class Store:
        await self.redis.delete(self.get_mapbox_paint_channel(store_name))

        ## Update the layers/stores registry
        if hasattr(self, 'app'):
            await self.get_live_layer_defs()
        await self.get_live_layer_defs()

    async def has_channel(self, store_name):
        return len(await self.redis.keys(self.get_json_channel(store_name))) > 0

@@ -274,7 +270,7 @@ class Store:
    async def get_layer_def(self, store_name):
        return loads(await self.redis.get(self.get_layer_def_channel(store_name)))

    async def get_live_layer_defs(self) -> list[LiveGeoModel]:
    async def get_live_layer_defs(self): # -> list[LiveGeoModel]:
        registry.geom_live_defs = {}
        for channel in sorted(await self.get_live_layer_def_channels()):
            model_info = loads(await self.redis.get(channel))

@@ -370,8 +366,6 @@ class Store:
        - listen to the DB event emitter: setup a callback function
        """
        ## Setup the function and triggers on tables
        db = self.app['db']

        ## Keep the connection alive: don't use a "with" block
        ## It needs to be closed correctly: see _close_permanant_db_connection
        self._permanent_conn = await db.acquire()

@@ -419,17 +413,17 @@ class Store:
        await self._permanent_conn.release()


async def setup_redis(app):
async def setup_redis():
    global store
    await store.setup(app)
    await store.setup()


async def setup_redis_cache(app):
async def setup_redis_cache():
    global store
    await store._setup_db_cache_system()


async def shutdown_redis(app):
async def shutdown_redis():
    global store
    await store._close_permanant_db_connection()
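A minimal sketch of how the now argument-less setup/shutdown functions are wired into a FastAPI lifespan; it mirrors the application hunk near the top of this commit and adds nothing new.

from contextlib import asynccontextmanager
from fastapi import FastAPI
from gisaf.redis_tools import setup_redis, shutdown_redis
from gisaf.live import setup_live

@asynccontextmanager
async def lifespan(app: FastAPI):
    await setup_redis()    # module-level Store singleton, no app needed
    await setup_live()     # LiveServer subscribes to live:*:json
    yield
    await shutdown_redis()

app = FastAPI(lifespan=lifespan)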
@@ -11,9 +11,7 @@ from typing import Any, ClassVar
from pydantic import create_model
from sqlalchemy import inspect, text
from sqlalchemy.orm import selectinload
from sqlmodel import select

import numpy as np
from sqlmodel import SQLModel, select
import pandas as pd

from .config import conf

@@ -23,6 +21,7 @@ from .models.geo_models_base import (
    LiveGeoModel,
    PlottableModel,
    GeoModel,
    SurveyModel,
    RawSurveyBaseModel,
    LineWorkSurveyModel,
    GeoPointSurveyModel,

@@ -32,6 +31,7 @@ from .models.geo_models_base import (
from .utils import ToMigrate
from .models.category import Category, CategoryGroup
from .database import db_session
from . import models
from .models.metadata import survey, raw_survey

logger = logging.getLogger(__name__)

@@ -71,23 +71,32 @@ class ModelRegistry:
    Provides tools to get the models from their names, table names, etc.
    """
    stores: pd.DataFrame
    values: dict[str, PlottableModel]
    geom_live: dict[str, LiveGeoModel]
    geom_live_defs: dict[str, dict[str, Any]]
    geom_custom: dict[str, GeoModel]
    geom_custom_store: dict[str, Any]
    other: dict[str, SQLModel]
    misc: dict[str, SQLModel]
    raw_survey_models: dict[str, RawSurveyBaseModel]
    survey_models: dict[str, SurveyModel]

    def __init__(self):
    def __init__(self) -> None:
        """
        Get geo models
        :return: None
        """
        self.geom_custom = {}
        self.geom_custom_store = {}
        self.geom_live: dict[str, LiveGeoModel] = {}
        self.geom_live_defs: dict[str, dict[str, Any]] = {}
        self.geom_live = {}
        self.geom_live_defs = {}
        self.values = {}
        self.other = {}
        self.misc = {}
        self.raw_survey_models = {}
        self.survey_models = {}

    async def make_registry(self, app=None):
    async def make_registry(self):
        """
        Make (or refresh) the registry of models.
        :return:

@@ -98,10 +107,7 @@ class ModelRegistry:
        await self.build()
        ## If ogcapi is in app (i.e. not with scheduler):
        ## Now that the models are refreshed, tells the ogcapi to (re)build
        if app:
            #app.extra['registry'] = self
            if 'ogcapi' in app.extra:
                await app.extra['ogcapi'].build()
        #await app.extra['ogcapi'].build()

    async def make_category_models(self):
        """

@@ -190,14 +196,12 @@ class ModelRegistry:
        which are defined by categories), and store them for reference.
        """
        logger.debug('scan')
        from . import models # nocheck

        ## Scan the models defined in modules
        for module_name, module in import_submodules(models).items():
            if module_name in (
                'src.gisaf.models.geo_models_base',
                'src.gisaf.models.models_base',

            if module_name.rsplit('.', 1)[-1] in (
                'geo_models_base',
                'models_base',
            ):
                continue
            for name in dir(module):

@@ -630,13 +634,13 @@ class ModelRegistry:
            'live': 'is_live',
            'zIndex': 'z_index',
            'gisType': 'model_type',
            'type': 'mapbox_type',
            # 'type': 'mapbox_type',
            'viewableRole': 'viewable_role',
            }, inplace=True
        )
        ## Add columns
        df_live['auto_import'] = False
        df_live['base_gis_type'] = df_live['model_type']
        df_live['base_gis_type'] = df_live['gis_type']
        df_live['custom'] = False
        df_live['group'] = ''
        df_live['in_menu'] = True
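A hedged sketch of populating and inspecting the registry after this change: make_registry() no longer takes the app, and the ogcapi rebuild is handled elsewhere. The registry.geom mapping is the one get_layer_list() iterates over in ipynb_tools; the function name here is illustrative.

from gisaf.registry import registry

async def refresh_models() -> list[str]:
    await registry.make_registry()
    # Geometry (layer) models are keyed by store name
    return list(registry.geom.keys())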
src/gisaf/scheduler.py (new executable file, 310 lines)
@@ -0,0 +1,310 @@
#!/usr/bin/env python
"""
Gisaf task scheduler, orchestrating the background tasks
like remote device data collection, etc.
"""
import os
import logging
import sys
import asyncio
from json import dumps
from datetime import datetime
from importlib.metadata import entry_points
from typing import Any, Mapping, List

from fastapi import FastAPI
from pydantic_settings import BaseSettings, SettingsConfigDict

# from apscheduler import SchedulerStarted
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from apscheduler.triggers.date import DateTrigger

from .ipynb_tools import Gisaf

formatter = logging.Formatter(
    "%(asctime)s:%(levelname)s:%(name)s:%(message)s",
    "%Y-%m-%d %H:%M:%S"
)
for handler in logging.root.handlers:
    handler.setFormatter(formatter)

logger = logging.getLogger('gisaf.scheduler')


class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_prefix='gisaf_scheduler_')
    app_name: str = 'Gisaf scheduler'
    job_names: List[str] = []
    exclude_job_names: List[str] = []
    list: bool = False


class JobBaseClass:
    """
    Base class for all the jobs.
    """
    task_id = None
    interval = None
    cron = None
    enabled = True
    type = ''  ## interval, cron or longrun
    sched_params = ''
    name = '<unnammed task>'
    features = None
    def __init__(self):
        self.last_run = None
        self.current_run = None

    async def get_feature_ids(self):
        """
        Subclasses might define a get_features function to inform the
        front-ends about the map features it works on.
        The scheduler runs this on startup.
        """
        return []

    async def run(self):
        """
        Subclasses should define a run async function to run
        """
        logger.info(f'Noop defined for {self.name}')


class JobScheduler:
    gs: Gisaf
    jobs: dict[str, Any]
    tasks: dict[str, Any]
    wss: dict[str, Any]
    subscribers: set[Any]
    scheduler: AsyncIOScheduler
    def __init__(self):
        #self.redis_store = gs.app['store']
        self.jobs = {}
        self.tasks = {}
        self.wss = {}
        self.subscribers = set()
        self.scheduler = AsyncIOScheduler()

    def start(self):
        self.scheduler.start()

    def scheduler_event_listener(self, event):
        asyncio.create_task(self.scheduler_event_alistener(event))

    async def scheduler_event_alistener(self, event):
        if isinstance(event, SchedulerStarted):
            pid = os.getpid()
            logger.debug(f'Scheduler started, pid={pid}')
            #await self.gs.app['store'].pub.set('_scheduler/pid', pid)

    async def job_event_added(self, event):
        task = await self.scheduler.data_store.get_task(event.task_id)
        schedules = [ss for ss in await self.scheduler.get_schedules()
                     if ss.task_id == event.task_id]
        if len(schedules) > 1:
            logger.warning(f'More than 1 schedule matching task {event.task_id}')
            return
        else:
            schedule = schedules[0]

    async def job_acquired(self, event):
        pass

    async def job_cancelled(self, event):
        pass

    async def job_released(self, event):
        pass

        # task = self.tasks.get(event.job_id)
        # if not task:
        #     breakpoint()
        #     logger.warning(f'Got an event {event} for unregistered task {event.task_id}')
        #     return
        # if isinstance(event, apscheduler.JobCancelled): #events.EVENT_JOB_ERROR:
        #     msg = f'"{task.name}" cancelled ({task.task_id})'
        #     task.last_run = event
        #     task.current_run = None
        #     logger.warning(msg)
        #     ## TODO: try to restart the task
        # elif isinstance(event, apscheduler.JobAcquired): #events.EVENT_JOB_SUBMITTED:
        #     ## XXX: should be task.last_run = None
        #     task.last_run = event
        #     task.current_run = event
        #     msg = f'"{task.name}" started ({task.task_id})'
        # elif isinstance(event, apscheduler.JobReleased): #events.EVENT_JOB_EXECUTED:
        #     task.last_run = event
        #     task.current_run = None
        #     msg = f'"{task.name}" worked ({task.task_id})'
        # else:
        #     logger.info(f'*********** Unhandled event: {event}')
        #     pass
        # #await self.send_to_redis_store(task, event, msg)

        # ## Send to notification graphql websockets subscribers
        # for queue in self.subscribers:
        #     queue.put_nowait((task, event))

        # ## Send raw messages through websockets
        # await self.send_to_websockets(task, event, msg)

    async def send_to_redis_store(self, job, event, msg):
        """
        Send to Redis store
        """
        try:
            self.gs.app['store'].pub.publish(
                'admin:scheduler:json',
                dumps({'msg': msg})
            )
        except Exception as err:
            logger.warning(f'Cannot publish updates for "{job.name}" to Redis: {err}')
            logger.exception(err)

    async def send_to_websockets(self, job, event, msg):
        """
        Send to all connected websockets
        """
        for ws in self.wss.values():
            asyncio.create_task(
                ws.send_json({
                    'msg': msg
                })
            )

    def add_subscription(self, ws):
        self.wss[id(ws)] = ws

    def delete_subscription(self, ws):
        del self.wss[id(ws)]

    def get_available_jobs(self):
        return [
            entry_point.name
            for entry_point in entry_points().select(group='gisaf_jobs')
        ]

    async def setup(self, job_names=None, exclude_job_names=None):
        if job_names is None:
            job_names = []
        if exclude_job_names is None:
            exclude_job_names = []

        ## Go through entry points and define the tasks
        for entry_point in entry_points().select(group='gisaf_jobs'):
            ## Possibly skip the task, according to the command line arguments
            if (entry_point.name in exclude_job_names) \
                    or ((len(job_names) > 0) and entry_point.name not in job_names):
                logger.info(f'Skip task {entry_point.name}')
                continue

            try:
                task_class = entry_point.load()
            except Exception as err:
                logger.error(f'Task {entry_point.name} skipped, cannot be loaded: {err}')
                continue

            ## Create the task instance
            try:
                task = task_class(self.gs)
            except Exception as err:
                logger.error(f'Task {entry_point.name} cannot be instanciated: {err}')
                continue
            task.name = entry_point.name

            if not task.enabled:
                logger.debug(f'Job "{entry_point.name}" disabled')
                continue

            logger.debug(f'Add task "{entry_point.name}"')
            if not hasattr(task, 'run'):
                logger.error(f'Task {entry_point.name} skipped: no run method')
                continue
            task.features = await task.get_feature_ids()
            kwargs: dict[str: Any] = {
                # 'tags': [entry_point.name],
            }

            if isinstance(task.interval, dict):
                kwargs['trigger'] = IntervalTrigger(**task.interval)
                task.type = 'interval'
                ## TODO: format user friendly text for interval
                task.sched_params = get_pretty_format_interval(task.interval)
            elif isinstance(task.cron, dict):
                ## FIXME: CronTrigger
                kwargs['trigger'] = CronTrigger(**task.cron)
                kwargs.update(task.cron)
                task.type = 'cron'
                ## TODO: format user friendly text for cron
                task.sched_params = get_pretty_format_cron(task.cron)
            else:
                task.type = 'longrun'
                task.sched_params = 'always running'
                kwargs['trigger'] = DateTrigger(datetime.now())
                # task.task_id = await self.scheduler.add_job(task.run, **kwargs)
                # self.tasks[task.task_id] = task
                # continue

            ## Create the APScheduler task
            try:
                task.task_id = await self.scheduler.add_schedule(task.run, **kwargs)
            except Exception as err:
                logger.warning(f'Cannot add task {entry_point.name}: {err}')
                logger.exception(err)
            else:
                logger.info(f'Job "{entry_point.name}" added ({task.task_id})')
                self.tasks[task.task_id] = task

        ## Subscribe to all events
        # self.scheduler.subscribe(self.job_acquired, JobAcquired)
        # self.scheduler.subscribe(self.job_cancelled, JobCancelled)
        # self.scheduler.subscribe(self.job_released, JobReleased)
        # self.scheduler.subscribe(self.job_event_added, JobAdded)
        # self.scheduler.subscribe(self.scheduler_event_listener, SchedulerEvent)


class GSFastAPI(FastAPI):
    js: JobScheduler


allowed_interval_params = set(('seconds', 'minutes', 'hours', 'days', 'weeks'))
def get_pretty_format_interval(params):
    """
    Return a format for describing interval
    """
    return str({
        k: v for k, v in params.items()
        if k in allowed_interval_params
    })

def get_pretty_format_cron(params):
    """
    Return a format for describing cron
    """
    return str(params)


async def startup(settings):
    if settings.list:
        ## Just print available jobs and exit
        jobs = js.get_available_jobs()
        print(' '.join(jobs))
        sys.exit(0)
    # try:
    #     await js.gs.setup()
    #     await js.gs.make_models()
    # except Exception as err:
    #     logger.error('Cannot setup Gisaf')
    #     logger.exception(err)
    #     sys.exit(1)
    try:
        await js.setup(job_names=settings.job_names,
                       exclude_job_names=settings.exclude_job_names)
    except Exception as err:
        logger.error('Cannot setup scheduler')
        logger.exception(err)
        sys.exit(1)

js = JobScheduler()
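A hedged example of a job as JobScheduler.setup() expects it: a class published through the gisaf_jobs entry point group, constructed with the Gisaf facade, carrying an interval or cron dict (or neither for a long-running job) and an async run method. The module, class and entry point names are hypothetical, and the pyproject.toml snippet is an assumption about how a deployment would register it.

# Hypothetical job module (e.g. my_jobs/weather.py)
from gisaf.scheduler import JobBaseClass

class WeatherCollector(JobBaseClass):
    name = 'weather_collector'
    interval = {'minutes': 15}   # becomes IntervalTrigger(minutes=15)

    def __init__(self, gs):
        super().__init__()
        self.gs = gs             # the Gisaf notebook/tools facade passed by setup()

    async def run(self):
        ...  # fetch and store the data

# Registered in the package metadata, e.g. in pyproject.toml (illustrative):
# [project.entry-points.gisaf_jobs]
# weather_collector = "my_jobs.weather:WeatherCollector"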
src/gisaf/scheduler_application.py (new executable file, 57 lines)
@@ -0,0 +1,57 @@
#!/usr/bin/env python
"""
Gisaf job scheduler, orchestrating the background tasks
like remote device data collection, etc.
"""
import logging
from contextlib import asynccontextmanager

from starlette.routing import Mount
from fastapi.middleware.cors import CORSMiddleware
from apscheduler.schedulers.asyncio import AsyncIOScheduler

from .config import conf
from .ipynb_tools import gisaf
from .scheduler import GSFastAPI, js, startup, Settings
from .scheduler_web import app as sched_app


formatter = logging.Formatter(
    "%(asctime)s:%(levelname)s:%(name)s:%(message)s",
    "%Y-%m-%d %H:%M:%S"
)
for handler in logging.root.handlers:
    handler.setFormatter(formatter)

logging.basicConfig(level=conf.gisaf.debugLevel)
logger = logging.getLogger('gisaf.scheduler_application')


@asynccontextmanager
async def lifespan(app: GSFastAPI):
    '''
    Handle startup and shutdown: setup scheduler, etc
    '''
    ## Startup
    await gisaf.setup()
    await startup(settings)
    js.start()
    yield
    ## Shutdown
    pass


settings = Settings()
app = GSFastAPI(
    title=settings.app_name,
    lifespan=lifespan,
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.mount('/_sched', sched_app)
app.mount('/sched', sched_app)
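How the scheduler application might be started is not part of the commit; the sketch below assumes uvicorn is available, and the port is arbitrary. Environment variables with the gisaf_scheduler_ prefix (e.g. gisaf_scheduler_job_names) feed the Settings class.

# Illustrative launcher, assumptions noted above
import uvicorn
from gisaf.scheduler_application import app

uvicorn.run(app, host='0.0.0.0', port=8001)  # port choice is an assumption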
src/gisaf/scheduler_web.py (new file, 169 lines)
@@ -0,0 +1,169 @@
"""
The web API for Gisaf scheduler
"""
import logging
import asyncio
from datetime import datetime
from uuid import UUID
from typing import List

from fastapi import Request, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from redis import asyncio as aioredis
from pandas import DataFrame

from gisaf.live import live_server
from gisaf.scheduler import GSFastAPI


logger = logging.getLogger(__name__)

app = GSFastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class Subscriber:
    ## See https://gist.github.com/appeltel/fd3ddeeed6c330c7208502462639d2c9
    def __init__(self, hub):
        self.hub = hub
        self.queue = asyncio.Queue()

    def __enter__(self):
        self.hub.subscribers.add(self.queue)
        return self.queue

    def __exit__(self, type, value, traceback):
        self.hub.subscribers.remove(self.queue)


class JobEvent(BaseModel):
    jobId: str | UUID
    time: datetime | None
    status: str
    msg: str
    nextRunTime: datetime | None


class Feature(BaseModel):
    store: str
    ## XXX: Using "id" gives very strange issue with apollo client
    id_: str


class Job_(BaseModel):
    id: str
    name: str
    type: str
    schedParams: str
    nextRunTime: datetime | None
    lastRun: JobEvent | None
    features: List[Feature]


class Task_(BaseModel):
    id: str
    name: str
    type: str
    schedParams: str
    nextRunTime: datetime | None
    lastRunTime: datetime | None
    features: list[Feature]


def df_as_ObjectTypes(df):
    """
    Utility function that returns List(Feature) graphql from a dataframe.
    The dataframe must contain a 'store' column and the feature ids as index.
    """
    if not isinstance(df, DataFrame):
        return []
    if 'store' not in df.columns:
        # logger.warning(f'no store in get_feature_ids() for job "{job.name}"')
        return []
    return [
        Feature(id_=str(f[0]), store=f.store)
        for f in df.itertuples(index=True)
    ]


@app.websocket('/events')
async def scheduler_ws(
        ws: WebSocket,
    ):
    """
    Websocket for scheduler updates
    """
    #session = await get_session(request)
    #js = request.app.js
    await ws.accept()
    while True:
        # msg_text = await ws.receive_text()
        msg_data = await ws.receive_json()
        #await websocket.send_text(f"Message text was: {data}")
        if 'message' in msg_data:
            if msg_data['message'] == 'subscribe':
                live_server.add_subscription(ws, 'admin:scheduler')
                ws.app.js.add_subscription(ws)
            elif msg_data['message'] == 'unsubscribe':
                live_server.remove_subscription(ws, 'admin:scheduler')
                ws.app.js.delete_subscription(ws)


@app.websocket('/subscriptions')
async def subscriptions(ws: WebSocket):
    await ws.accept()
    while True:
        msg_data = await ws.receive_json()
        if msg.type == WSMsgType.TEXT:
            if msg.data == 'close':
                await ws.close()
            else:
                await ws.send_str(msg.data + '/answer')
        elif msg.type == WSMsgType.ERROR:
            print('ws connection closed with exception %s' % ws.exception())


@app.get('/time')
async def get_time():
    return datetime.now()


@app.get('/jobs')
async def get_jobs(request: Request) -> list[Task_]:
    app: GSFastAPI = request.app
    tasks = {task.id: task for task in await app.js.scheduler.data_store.get_tasks()}
    tasks_ = []
    for schedule in await app.js.scheduler.data_store.get_schedules():
        task = tasks[schedule.task_id]
        task_ = app.js.tasks[schedule.id]
        tasks_.append(
            Task_(
                id=task.id,
                name=task.id,
                type=schedule.trigger.__class__.__name__,
                schedParams='',
                lastRunTime=schedule.last_fire_time,
                nextRunTime=schedule.next_fire_time,
                features=df_as_ObjectTypes(task_.features),
            )
        )
    return tasks_


# async def setup_app_session(app):
#     """
#     Setup a redis pool for session management
#     Not related to the redis connection used by Gisaf
#     """
#     redis = aioredis.from_url('redis://localhost')
#     redis_storage = RedisStorage(redis)
#     session_identity_policy = SessionIdentityPolicy()
#     setup_session(app, redis_storage)
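For reference, the /events websocket above expects a plain JSON subscribe message; a hedged example of the payload follows. The mount path depends on scheduler_application.py (either /sched or /_sched), and the host is an assumption.

# Illustrative subscribe payload for the scheduler /events websocket
import json
subscribe_msg = json.dumps({'message': 'subscribe'})
# ... send subscribe_msg over a websocket connected to ws://<host>/sched/events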