Code cosmetic

This commit is contained in:
phil 2024-05-11 09:52:34 +02:00
parent 4be1ae4a0d
commit ae19ba9f27
3 changed files with 89 additions and 62 deletions

View file

@ -2,13 +2,22 @@
Geographical json stores, served under /gj Geographical json stores, served under /gj
Used for displaying features on maps Used for displaying features on maps
""" """
from json import JSONDecodeError from json import JSONDecodeError
import logging import logging
from typing import Annotated from typing import Annotated
from asyncio import CancelledError from asyncio import CancelledError
from fastapi import (Depends, APIRouter, HTTPException, Response, Header, from fastapi import (
WebSocket, WebSocketDisconnect, status) Depends,
APIRouter,
HTTPException,
Response,
Header,
WebSocket,
WebSocketDisconnect,
status,
)
from gisaf.models.authentication import User from gisaf.models.authentication import User
from gisaf.redis_tools import store as redis_store from gisaf.redis_tools import store as redis_store
@ -25,8 +34,10 @@ api = APIRouter(
responses={404: {"description": "Not found"}}, responses={404: {"description": "Not found"}},
) )
class ConnectionManager: class ConnectionManager:
active_connections: list[WebSocket] active_connections: list[WebSocket]
def __init__(self): def __init__(self):
self.active_connections = [] self.active_connections = []
@ -44,9 +55,11 @@ class ConnectionManager:
for connection in self.active_connections: for connection in self.active_connections:
await connection.send_text(message) await connection.send_text(message)
manager = ConnectionManager() manager = ConnectionManager()
@api.websocket('/live/{store}')
@api.websocket("/live/{store}")
async def live_layer(store: str, websocket: WebSocket): async def live_layer(store: str, websocket: WebSocket):
""" """
Websocket for live layer updates Websocket for live layer updates
@ -58,29 +71,33 @@ async def live_layer(store: str, websocket: WebSocket):
msg_data = await websocket.receive_json() msg_data = await websocket.receive_json()
except JSONDecodeError: except JSONDecodeError:
msg_text = await websocket.receive_text() msg_text = await websocket.receive_text()
if msg_text == 'close': if msg_text == "close":
await websocket.close() await websocket.close()
continue continue
# else: # else:
if 'message' in msg_data: if "message" in msg_data:
if msg_data['message'] == 'subscribeLiveLayer': if msg_data["message"] == "subscribeLiveLayer":
live_server.add_subscription(websocket, store) live_server.add_subscription(websocket, store)
elif msg_data['message'] == 'unsubscribeLiveLayer': elif msg_data["message"] == "unsubscribeLiveLayer":
live_server.remove_subscription(websocket, store) live_server.remove_subscription(websocket, store)
else: else:
logger.warning(f'Got websocket message with no message field: {msg_data}') logger.warning(
f"Got websocket message with no message field: {msg_data}"
)
except WebSocketDisconnect: except WebSocketDisconnect:
logger.debug('Websocket disconnected') logger.debug("Websocket disconnected")
# logger.debug('websocket connection closed') # logger.debug('websocket connection closed')
@api.get('/{store_name}')
async def get_geojson(store_name, @api.get("/{store_name}")
async def get_geojson(
store_name,
user: User = Depends(get_current_active_user), user: User = Depends(get_current_active_user),
If_None_Match: Annotated[str | None, Header()] = None, If_None_Match: Annotated[str | None, Header()] = None,
simplify: Annotated[float | None, Header()] = None, simplify: Annotated[float | None, Header()] = None,
preserveTopology: Annotated[bool|None, Header()] = None, preserveTopology: Annotated[bool | None, Header()] = None,
): ):
""" """
Some REST stores coded manually (route prefixed with "gj": geojson). Some REST stores coded manually (route prefixed with "gj": geojson).
:param store_name: name of the model :param store_name: name of the model
@ -91,23 +108,23 @@ async def get_geojson(store_name,
model = registry.stores.loc[store_name].model model = registry.stores.loc[store_name].model
except KeyError: except KeyError:
raise HTTPException(status.HTTP_404_NOT_FOUND) raise HTTPException(status.HTTP_404_NOT_FOUND)
if getattr(model, 'viewable_role', None): if getattr(model, "viewable_role", None):
if not(user and user.can_view(model)): if not (user and user.can_view(model)):
username = user.username if user else "Anonymous" username = user.username if user else "Anonymous"
logger.info(f'{username} tried to access {model}') logger.info(f"{username} tried to access {model}")
raise HTTPException(status.HTTP_401_UNAUTHORIZED) raise HTTPException(status.HTTP_401_UNAUTHORIZED)
if await redis_store.has_channel(store_name): if await redis_store.has_channel(store_name):
## Live layers ## Live layers
data = await redis_store.get_layer_as_json(store_name) data = await redis_store.get_layer_as_json(store_name)
return Response(content=data.decode(), return Response(content=data.decode(), media_type="application/json")
media_type="application/json")
if model.cache_enabled: if model.cache_enabled:
ttag = await redis_store.get_ttag(store_name) ttag = await redis_store.get_ttag(store_name)
if ttag and If_None_Match == ttag: if ttag and If_None_Match == ttag:
raise HTTPException(status.HTTP_304_NOT_MODIFIED) raise HTTPException(status.HTTP_304_NOT_MODIFIED)
if hasattr(model, 'get_geojson'): if hasattr(model, "get_geojson"):
geojson = await model.get_geojson(simplify_tolerance=simplify, geojson = await model.get_geojson(
preserve_topology=preserveTopology) simplify_tolerance=simplify, preserve_topology=preserveTopology
)
## Store to redis for caching ## Store to redis for caching
if use_cache: if use_cache:
await redis_store.store_json(model, geojson) await redis_store.store_json(model, geojson)
@ -117,12 +134,15 @@ async def get_geojson(store_name,
## get_popup and get_propertites get the gdf as argument ## get_popup and get_propertites get the gdf as argument
## and can use vectorised operations ## and can use vectorised operations
try: try:
gdf = await model.get_gdf(cast=True, with_related=True, gdf = await model.get_gdf(
cast=True,
with_related=True,
# filter_columns=True, # filter_columns=True,
preserve_topology=preserveTopology, preserve_topology=preserveTopology,
simplify_tolerance=simplify) simplify_tolerance=simplify,
)
except CancelledError as err: except CancelledError as err:
logger.debug(f'Getting {store_name} cancelled while getting gdf') logger.debug(f"Getting {store_name} cancelled while getting gdf")
raise err raise err
except Exception as err: except Exception as err:
logger.exception(err) logger.exception(err)
@ -130,25 +150,26 @@ async def get_geojson(store_name,
## The query of category defined models gets the status ## The query of category defined models gets the status
## (not sure how and this could be skipped) ## (not sure how and this could be skipped)
## Other models do not have: just add it manually from the model itself ## Other models do not have: just add it manually from the model itself
if 'status' not in gdf.columns: if "status" not in gdf.columns:
gdf['status'] = model.status gdf["status"] = model.status
if 'popup' not in gdf.columns: if "popup" not in gdf.columns:
gdf['popup'] = await model.get_popup(gdf) gdf["popup"] = await model.get_popup(gdf)
# Add properties # Add properties
properties = await model.get_properties(gdf) properties = await model.get_properties(gdf)
columns = ['geom', 'status', 'popup'] columns = ["geom", "status", "popup"]
for property, values in properties.items(): for property, values in properties.items():
columns.append(property) columns.append(property)
gdf[property] = values gdf[property] = values
geojson = gdf[columns].to_json(separators=(',', ':'), geojson = gdf[columns].to_json(separators=(",", ":"), check_circular=False)
check_circular=False)
## Store to redis for caching ## Store to redis for caching
if use_cache: if use_cache:
await redis_store.store_json(model, geojson) await redis_store.store_json(model, geojson)
resp = geojson resp = geojson
else: else:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, raise HTTPException(
detail='Gino is for: Gino Is No Option') status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Gino is for: Gino Is No Option",
)
# logger.warn(f"{model} doesn't allow using dataframe for generating json!") # logger.warn(f"{model} doesn't allow using dataframe for generating json!")
# attrs, features_kwargs = await model.get_features_attrs(simplify) # attrs, features_kwargs = await model.get_features_attrs(simplify)
# ## Using gino: allows OO model (get_info, etc) # ## Using gino: allows OO model (get_info, etc)
@ -161,9 +182,8 @@ async def get_geojson(store_name,
headers = {} headers = {}
if model.cache_enabled and ttag: if model.cache_enabled and ttag:
headers['ETag'] = ttag headers["ETag"] = ttag
return Response(content=resp, return Response(content=resp, media_type="application/json", headers=headers)
media_type="application/json", headers=headers)
# @api.get('/gj/{store_name}/popup/{id}') # @api.get('/gj/{store_name}/popup/{id}')

View file

@ -9,6 +9,7 @@ from gisaf.redis_tools import store
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LiveServer: class LiveServer:
def __init__(self): def __init__(self):
self.ws_clients = defaultdict(set) self.ws_clients = defaultdict(set)
@ -18,10 +19,10 @@ class LiveServer:
Setup for the live server Setup for the live server
""" """
if with_mqtt: if with_mqtt:
logger.warning('Gisaf LiveServer does not support with_mqtt: ignoring') logger.warning("Gisaf LiveServer does not support with_mqtt: ignoring")
if listen_to_redis: if listen_to_redis:
self.pub = store.redis.pubsub() self.pub = store.redis.pubsub()
await self.pub.psubscribe('live:*:json') await self.pub.psubscribe("live:*:json")
asyncio.create_task(self._listen_to_redis()) asyncio.create_task(self._listen_to_redis())
async def _listen_to_redis(self): async def _listen_to_redis(self):
@ -30,9 +31,10 @@ class LiveServer:
and send the messages to websockets and send the messages to websockets
""" """
async for msg in self.pub.listen(): async for msg in self.pub.listen():
if msg['type'] == 'pmessage': if msg["type"] == "pmessage":
await self._send_to_ws_clients(msg['channel'].decode(), await self._send_to_ws_clients(
msg['data'].decode()) msg["channel"].decode(), msg["data"].decode()
)
async def _send_to_ws_clients(self, store_name, json_data): async def _send_to_ws_clients(self, store_name, json_data):
""" """
@ -40,32 +42,36 @@ class LiveServer:
to that channel (store_name) to that channel (store_name)
""" """
if len(self.ws_clients[store_name]) > 0: if len(self.ws_clients[store_name]) > 0:
logger.debug(f'WS channel {store_name} got {len(json_data)} bytes to send to:' logger.debug(
f' {", ".join([str(id(ws)) for ws in self.ws_clients[store_name]])}') f"WS channel {store_name} got {len(json_data)} bytes to send to:"
f" {', '.join([str(id(ws)) for ws in self.ws_clients[store_name]])}"
)
for ws in self.ws_clients[store_name]: for ws in self.ws_clients[store_name]:
if ws.client_state.name != 'CONNECTED': if ws.client_state.name != "CONNECTED":
logger.debug(f'Cannot send {store_name} for WS {id(ws)}, state: {ws.client_state.name}') logger.debug(
f"Cannot send {store_name} for WS {id(ws)}, state: {ws.client_state.name}"
)
continue continue
try: try:
await ws.send_text(json_data) await ws.send_text(json_data)
logger.debug(f'Sent live update for WS {id(ws)}: {len(json_data)}') logger.debug(f"Sent live update for WS {id(ws)}: {len(json_data)}")
except RuntimeError as err: except RuntimeError as err:
## The ws is probably closed, remove it from the clients ## The ws is probably closed, remove it from the clients
logger.debug(f'Cannot send live update for {store_name}: {err}') logger.debug(f"Cannot send live update for {store_name}: {err}")
del self.ws_clients[store_name] del self.ws_clients[store_name]
else: else:
pass pass
#logger.debug(f'WS channel {store_name} has no clients') # logger.debug(f'WS channel {store_name} has no clients')
def add_subscription(self, ws, store_name): def add_subscription(self, ws: WebSocket, store_name: str):
""" """
Add the websocket subscription to the layer Add the websocket subscription to the layer
""" """
channel = store.get_json_channel(store_name) channel = store.get_json_channel(store_name)
logger.debug(f'WS {id(ws)} subscribed to {channel}') logger.debug(f"WS {id(ws)} subscribed to {channel}")
self.ws_clients[channel].add(ws) self.ws_clients[channel].add(ws)
def remove_subscription(self, ws, store_name): def remove_subscription(self, ws: WebSocket, store_name: str):
""" """
Remove the websocket subscription to the layer Remove the websocket subscription to the layer
""" """
@ -78,4 +84,5 @@ async def setup_live():
global live_server global live_server
await live_server.setup(listen_to_redis=True) await live_server.setup(listen_to_redis=True)
live_server = LiveServer() live_server = LiveServer()

View file

@ -125,37 +125,37 @@ class Store:
self.uuid = await self.redis.client_getname() self.uuid = await self.redis.client_getname()
self.uuid = str(uuid1()) self.uuid = str(uuid1())
def get_json_channel(self, store_name): def get_json_channel(self, store_name) -> str:
""" """
Name of the Redis channel for the json representation Name of the Redis channel for the json representation
""" """
return f"{store_name}:json" return f"{store_name}:json"
def get_gdf_channel(self, store_name): def get_gdf_channel(self, store_name) -> str:
""" """
Name of the Redis channel for the source gdf, in pickle format Name of the Redis channel for the source gdf, in pickle format
""" """
return f"{store_name}:gdf" return f"{store_name}:gdf"
def get_layer_def_channel(self, store_name): def get_layer_def_channel(self, store_name) -> str:
""" """
Name of the Redis channel for the layer definition Name of the Redis channel for the layer definition
""" """
return f"{store_name}:layer_def" return f"{store_name}:layer_def"
def get_mapbox_layout_channel(self, store_name): def get_mapbox_layout_channel(self, store_name) -> str:
""" """
Name of the Redis channel for the mapbox layout style definition Name of the Redis channel for the mapbox layout style definition
""" """
return f"{store_name}:mapbox_layout" return f"{store_name}:mapbox_layout"
def get_mapbox_paint_channel(self, store_name): def get_mapbox_paint_channel(self, store_name) -> str:
""" """
Name of the Redis channel for the mapbox paint style definition Name of the Redis channel for the mapbox paint style definition
""" """
return f"{store_name}:mapbox_paint" return f"{store_name}:mapbox_paint"
async def store_json(self, model, geojson, **kwargs): async def store_json(self, model, geojson, **kwargs) -> None:
""" """
Store the json representation of the gdf for caching. Store the json representation of the gdf for caching.
""" """