Fix reactor
Clean up some confusion in redis_tools with mqtt
parent 46b524636b
commit 1fd347d8df

5 changed files with 54 additions and 32 deletions
@@ -1 +1 @@
-__version__: str = '0.1.dev85+g41e92fa.d20240509'
+__version__: str = '2023.4.dev95+g46b5246.d20240520'
@@ -9,6 +9,7 @@ from sqlalchemy.sql.selectable import Select
 from sqlmodel import SQLModel, select
 from sqlmodel.ext.asyncio.session import AsyncSession
+from fastapi import Depends
 
 # from geoalchemy2.functions import ST_SimplifyPreserveTopology
 import pandas as pd
 import geopandas as gpd  # type: ignore
@@ -29,22 +30,31 @@ sync_engine = create_engine(
     max_overflow=conf.db.max_overflow,
 )
 
 
+async def get_db_session() -> AsyncGenerator[AsyncSession]:
+    async with AsyncSession(engine) as session:
+        yield session
+
+
 @asynccontextmanager
 async def db_session() -> AsyncGenerator[AsyncSession]:
     async with AsyncSession(engine) as session:
         yield session
 
 
-def pandas_query(session, query):
+def pandas_query(session, query, cast=False):
     return pd.read_sql_query(query, session.connection())
 
 
-def geopandas_query(session, query: Select, model, *,
-                    # simplify_tolerance: float|None=None,
-                    crs=None, cast=True,
-                    ):
+def geopandas_query(
+    session,
+    query: Select,
+    model,
+    *,
+    # simplify_tolerance: float|None=None,
+    crs=None,
+    cast=True,
+):
     ## XXX: I could not get the add_columns work without creating a subquery,
     ## so moving the simplification to geopandas - see in _get_df
     # if simplify_tolerance is not None:
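Note on the new helpers: get_db_session is shaped for FastAPI dependency injection (which the added Depends import suggests), while the decorated db_session exposes the same generator as a context manager for code running outside a request. A minimal usage sketch, assuming the helpers above are importable; the app and endpoint are invented for illustration:

    from fastapi import Depends, FastAPI
    from sqlmodel.ext.asyncio.session import AsyncSession

    app = FastAPI()

    @app.get("/health")  # hypothetical endpoint
    async def health(session: AsyncSession = Depends(get_db_session)):
        # session is a per-request AsyncSession, closed when the generator resumes
        return {"ok": True}

    async def standalone_job():
        # outside a request, the context-manager flavour yields the same session
        async with db_session() as session:
            ...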
@@ -55,9 +65,10 @@ def geopandas_query(session, query: Select, model, *,
     # query = query.add_columns(new_column)
     return gpd.GeoDataFrame.from_postgis(query, session.connection(), crs=crs)
 
+
 class BaseModel(SQLModel):
     @classmethod
-    def selectinload(cls) -> list[Literal['*'] | QueryableAttribute[Any]]:
+    def selectinload(cls) -> list[Literal["*"] | QueryableAttribute[Any]]:
        return []
 
     @classmethod
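The selectinload() hook touched here declares which relationships to eager-load; the BaseModel default is an empty list. A hedged sketch of an override, with invented Team/Hero models in the usual SQLModel style:

    from sqlmodel import Field, Relationship

    class Team(BaseModel, table=True):  # hypothetical subclass
        id: int | None = Field(default=None, primary_key=True)
        heroes: list["Hero"] = Relationship(back_populates="team")

        @classmethod
        def selectinload(cls):
            # eager-load Team.heroes whenever teams are fetched
            return [cls.heroes]

    class Hero(BaseModel, table=True):  # hypothetical subclass
        id: int | None = Field(default=None, primary_key=True)
        team_id: int | None = Field(default=None, foreign_key="team.id")
        team: Team | None = Relationship(back_populates="heroes")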
@@ -69,11 +80,17 @@ class BaseModel(SQLModel):
         return await cls._get_df(geopandas_query, model=cls, **kwargs)  # type: ignore
 
     @classmethod
-    async def _get_df(cls, method, *,
-                      where=None, with_related=True, with_only_columns=[],
-                      simplify_tolerance: float | None=None,
-                      preserve_topology: bool | None=None,
-                      **kwargs) -> pd.DataFrame | gpd.GeoDataFrame:
+    async def _get_df(
+        cls,
+        method,
+        *,
+        where=None,
+        with_related=True,
+        with_only_columns=[],
+        simplify_tolerance: float | None = None,
+        preserve_topology: bool | None = None,
+        **kwargs,
+    ) -> pd.DataFrame | gpd.GeoDataFrame:
         async with db_session() as session:
             if len(with_only_columns) == 0:
                 query = select(cls)
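Everything after the bare * in the exploded signature is keyword-only, so call sites must name each option. A sketch of a call, with SomeModel standing in for any BaseModel subclass:

    async def example():  # hypothetical call site
        df = await SomeModel._get_df(
            pandas_query,
            with_related=False,       # skip the joined-column renaming below
            simplify_tolerance=None,  # only meaningful with geopandas_query
        )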
@@ -97,17 +114,19 @@ class BaseModel(SQLModel):
                 # pass
             df = await session.run_sync(method, query, **kwargs)
             if method is geopandas_query and simplify_tolerance is not None:
-                df['geom'] = df['geom'].simplify(
+                df["geom"] = df["geom"].simplify(
                     simplify_tolerance / conf.geo.simplify_geom_factor,
-                    preserve_topology=(conf.geo.simplify_preserve_topology
-                                       if preserve_topology is None
-                                       else preserve_topology)
+                    preserve_topology=(
+                        conf.geo.simplify_preserve_topology
+                        if preserve_topology is None
+                        else preserve_topology
+                    ),
                 )
             ## Change column names to reflect the joined tables
             ## Leave the first columns unchanged, as their names come straight
             ## from the model's fields
             if with_related:
-                joined_columns = list(df.columns[len(cls.model_fields):])
+                joined_columns = list(df.columns[len(cls.model_fields) :])
                 renames: dict[str, str] = {}
                 # Match column names with the joined tables
                 # This uses the fact that orders of the joined tables
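For context on the reflowed call: GeoSeries.simplify in geopandas takes a tolerance expressed in the units of the geometry's CRS, and preserve_topology=True selects the slower algorithm that avoids producing invalid geometries. A self-contained illustration:

    import geopandas as gpd
    from shapely.geometry import Polygon

    square = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])
    gdf = gpd.GeoDataFrame({"geom": [square]}, geometry="geom")
    # tolerance is in CRS units; preserve_topology=True keeps geometries valid
    gdf["geom"] = gdf["geom"].simplify(0.1, preserve_topology=True)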
@@ -119,11 +138,15 @@ class BaseModel(SQLModel):
                     target = joined_table.property.target  # type: ignore
                     for col in target.columns:
                         ## Pop the column from the column list and make a new name
-                        renames[joined_columns.pop(0)] = f'{target.schema}_{target.name}_{col.name}'
+                        renames[joined_columns.pop(0)] = (
+                            f"{target.schema}_{target.name}_{col.name}"
+                        )
                 df.rename(columns=renames, inplace=True)
             ## Finally, set the index of the df as the index of cls
-            df.set_index([c.name for c in cls.__table__.primary_key.columns],  # type: ignore
-                         inplace=True)
+            df.set_index(
+                [c.name for c in cls.__table__.primary_key.columns],  # type: ignore
+                inplace=True,
+            )
             return df
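The rename loop pairs the leftover dataframe columns positionally with the joined tables' columns and builds schema-qualified names. A worked example with invented values:

    # a joined table "other.tags" contributing a column "name"
    joined_columns = ["name_1"]  # leftover positional column from the join
    renames: dict[str, str] = {}
    schema, table, col = "other", "tags", "name"
    renames[joined_columns.pop(0)] = f"{schema}_{table}_{col}"
    assert renames == {"name_1": "other_tags_name"}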
@@ -2,7 +2,7 @@ import asyncio
 import logging
 from collections import defaultdict
 
-from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi import WebSocket
 
 # from .config import conf
 from gisaf.redis_tools import store
@@ -14,12 +14,11 @@ class LiveServer:
     def __init__(self):
         self.ws_clients = defaultdict(set)
 
-    async def setup(self, listen_to_redis=False, with_mqtt=False):
+    async def setup(self, listen_to_redis=False):
         """
         Setup for the live server
         """
-        if with_mqtt:
-            logger.warning("Gisaf LiveServer does not support with_mqtt: ignoring")
         await store.setup(with_registry=False)
         if listen_to_redis:
             self.pub = store.redis.pubsub()
             await self.pub.psubscribe("live:*:json")
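For orientation: psubscribe is redis' pattern subscription, so the live server receives every channel matching live:*:json. A standalone sketch of the same pattern with redis-py's asyncio client, independent of the store object and assuming a local redis:

    import asyncio
    import redis.asyncio as redis

    async def listen():
        client = redis.Redis()                  # assumed local instance
        pubsub = client.pubsub()
        await pubsub.psubscribe("live:*:json")  # same pattern as above
        async for message in pubsub.listen():
            if message["type"] == "pmessage":
                print(message["channel"], message["data"])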
@@ -14,6 +14,7 @@ from collections import OrderedDict
 from aiomqtt import Client, Message
 
 from gisaf.config import conf
+from gisaf.live import live_server
 
 logger = logging.getLogger("gisaf.reactor")
 
@@ -23,7 +24,7 @@ class Reactor:
         self.processors = {}
 
     async def setup(self, exclude_processor_names=None):
-        if exclude_processor_names == None:
+        if exclude_processor_names is None:
             exclude_processor_names = []
         for entry_point in entry_points().select(group="gisaf_message_processors"):
             logger.debug(f"Processing message processor module {entry_point.name}")
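The is None fix sits in the plugin discovery path: message processors are looked up in installed package metadata under the gisaf_message_processors entry-point group. A standalone sketch of that mechanism (Python 3.10+ importlib.metadata API):

    from importlib.metadata import entry_points

    for ep in entry_points().select(group="gisaf_message_processors"):
        processor_class = ep.load()  # imports whatever the plugin registered
        print(ep.name, processor_class)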
@@ -53,6 +54,7 @@ class Reactor:
             await message_processor.setup()
             self.add_processor(message_processor)
             logger.info(f'Added message processor "{entry_point.name}"')
+        await live_server.setup()
 
     def get_available_processors(self):
         return [
@@ -139,6 +141,7 @@ async def cancel_tasks(tasks):
 async def main(list=None, exclude_processor_names=None) -> None:
     if list:
         reactor = Reactor()
+        await reactor.setup()
         jobs = reactor.get_available_processors()
         print(" ".join(jobs))
         sys.exit(0)
@@ -23,8 +23,6 @@ from apscheduler.triggers.cron import CronTrigger
 from apscheduler.triggers.interval import IntervalTrigger
 from apscheduler.triggers.date import DateTrigger
 
-# from gisaf.ipynb_tools import Gisaf
-
 formatter = logging.Formatter(
     "%(asctime)s:%(levelname)s:%(name)s:%(message)s", "%Y-%m-%d %H:%M:%S"
 )
@@ -76,7 +74,6 @@ class JobBaseClass:
 
 
 class JobScheduler:
-    # gs: Gisaf
     jobs: dict[str, Any]
     # tasks: dict[str, Any]
     wss: dict[str, Any]