Fix custom plugin downloaders

Change get_as_dataframe call signature
This commit is contained in:
phil 2024-04-06 13:11:38 +05:30
parent 08c53cf894
commit 52e1d2135b
9 changed files with 156 additions and 105 deletions

View file

@ -1 +1 @@
__version__: str = '2023.4.dev56+g775030d.d20240325'
__version__: str = '2023.4.dev62+g08c53cf.d20240405'

View file

@ -9,11 +9,13 @@ from sqlalchemy.orm import selectinload, joinedload
from gisaf.database import pandas_query, fastapi_db_session as db_session
from gisaf.models.geo_models_base import GeoModel, PlottableModel
from gisaf.models.info import Downloader
from gisaf.security import (
Token, authenticate_user, get_current_active_user, create_access_token,
)
from gisaf.models.authentication import (User, UserRead, Role, RoleRead)
from gisaf.registry import registry, NotInRegistry
from gisaf.plugins import DownloadResponse, manager as plugin_manager
logger = logging.getLogger(__name__)
@ -47,7 +49,8 @@ async def download_csv(
if custom_getter:
df = await custom_getter(model_id)
else:
df = await values_model.get_as_dataframe(model_id=model_id, with_only_columns=[value])
item = await db_session.get(model, model_id)
df = await values_model.get_as_dataframe(item=item, with_only_columns=[value])
if len(df) == 0:
raise HTTPException(status.HTTP_204_NO_CONTENT)
if resample and resample != '0':
@ -75,3 +78,49 @@ async def download_csv(
'Content-Disposition': f"attachment; filename={filename}"
})
return response
@api.get('/plugin/{name}/{store}/{id}')
async def execute_action(
name: str,
store: str,
id: int,
db_session: db_session,
user: Annotated[UserRead, Depends(get_current_active_user)]
):
"""
Download the result of an action
"""
## TODO: implement permissions for actions
#await check_permission(info.context['request'], 'action')
try:
store_record = registry.stores.loc[store]
model: type[GeoModel] = store_record.model
values_model = registry.values_for_model[model][0]
except KeyError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
item = await db_session.get(model, id)
dls: list[Downloader] = [dl for dl in plugin_manager.downloaders_stores[store]
if dl.name == name]
if len(dls) == 0:
raise HTTPException(status.HTTP_501_NOT_IMPLEMENTED,
detail=f'No downloader {name} for {store}')
elif len(dls) > 1:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f'Too many downloaders ({len(dls)}) {name} for {store}')
downloader = dls[0]
result: DownloadResponse
try:
result = await downloader._plugin.execute(model, item)
except Exception as err:
logging.exception(err)
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f'Error in action: {err.args[0]}'
)
return StreamingResponse(
iter([result.content]),
headers = {
'Content-Disposition': f'attachment; filename="{result.file_name}"'
},
media_type=result.content_type,
)

View file

@ -178,11 +178,14 @@ async def get_model_list(
return resp
@api.get('/{store_name}/values/{value}')
async def get_model_values(store_name: str, value: str,
response: Response,
where: str,
resample: str | None = None,
):
async def get_model_values(
db_session: db_session,
store_name: str,
value: str,
response: Response,
where: str,
resample: str | None = None,
):
"""
Get values
"""
@ -202,7 +205,8 @@ async def get_model_values(store_name: str, value: str,
if getter:
df = await getter(model_id)
else:
df = await values_model.get_as_dataframe(model_id=model_id,
item = await db_session.get(model, model_id)
df = await values_model.get_as_dataframe(item=item,
with_only_columns=[value])
if len(df) == 0:

View file

@ -15,7 +15,7 @@ from gisaf.api.admin import api as admin_api
from gisaf.api.dashboard import api as dashboard_api
from gisaf.api.map import api as map_api
from gisaf.api.download import api as download_api
from gisaf.plugins import manager as plugin_manger
from gisaf.plugins import manager as plugin_manager
logging.basicConfig(level=conf.gisaf.debugLevel)
logger = logging.getLogger(__name__)
@ -27,7 +27,7 @@ async def lifespan(app: FastAPI):
await setup_redis()
await setup_redis_cache()
await setup_live()
await plugin_manger.scan_plugins()
await plugin_manager.scan_plugins()
await admin_manager.setup_admin()
await map_tile_registry.setup()
yield

View file

@ -1130,44 +1130,31 @@ class PlottableModel(Model):
values: ClassVar[list[dict[str, str]]] = []
@classmethod
async def get_as_dataframe(cls, model_id=None, where=None, **kwargs):
async def get_as_dataframe(cls, item, where=None, **kwargs):
"""
Get a dataframe for the data.
It's quite generic, so subclasses might want to subclass this.
"""
if where is None:
if model_id is None:
where_ = None
else:
where_ = cls.ref_id == model_id
else:
if model_id is None:
where_ = where
else:
where_ = and_(cls.ref_id == model_id, where)
if where_ is not None:
df = await cls.get_df(where=where_, **kwargs)
else:
df = await cls.get_df(**kwargs)
return df
where_ = cls.ref_id == item.id
if where is not None:
where_ = and_(where_, where)
return await cls.get_df(where=where_, **kwargs)
class TimePlottableModel(PlottableModel):
time: datetime
@classmethod
async def get_as_dataframe(cls, model_id=None, with_only_columns=None, **kwargs):
async def get_as_dataframe(cls, item, with_only_columns=None, **kwargs):
"""
Get the data as a time-indexed dataframe
"""
if with_only_columns == None:
if with_only_columns is None:
with_only_columns = [val['name'] for val in cls.values]
if 'time' not in with_only_columns:
with_only_columns.insert(0, 'time')
df = await super().get_as_dataframe(model_id=model_id,
df = await super().get_as_dataframe(item=item,
with_only_columns=with_only_columns,
**kwargs)

View file

@ -1,6 +1,6 @@
from typing import Any
from pydantic import BaseModel
from pydantic import BaseModel, PrivateAttr
from gisaf.models.info_item import Tag, InfoItem
from gisaf.models.tags import Tags
@ -112,11 +112,15 @@ class ActionAction(BaseModel):
class Downloader(BaseModel):
# plugin: str
# downloader: str
_plugin: Any = PrivateAttr() # DownloadPlugin
roles: list[str] = []
name: str
icon: str | None = None
icon: str | None
def __init__(self, _plugin, **data):
super().__init__(**data)
# We generate the value for our private attribute
self._plugin = _plugin
class LegendItem(BaseModel):

View file

@ -12,9 +12,9 @@ from fastapi import HTTPException, status
from sqlalchemy import or_, and_
# from geoalchemy2.shape import to_shape, from_shape
# from graphene import ObjectType, String, List, Boolean, Field, Float, InputObjectType
import pandas as pd
import shapely # type: ignore
from pydantic import BaseModel
from gisaf.config import conf
from gisaf.models.store import Store # noqa: F401
@ -106,6 +106,12 @@ class TagPlugin:
return self.key
class DownloadResponse(BaseModel):
file_name: str
content_type: str
content: str
class DownloadPlugin:
"""
Base class for all download plugins.
@ -119,26 +125,26 @@ class DownloadPlugin:
self.roles = roles or []
self.icon = icon
async def execute(self, model, item, request):
async def execute(self, model, item) -> DownloadResponse:
raise NotImplementedError(f'Missing execute in downloader {self.name}')
class DownloadCSVPlugin(DownloadPlugin):
async def execute(self, model, item, request):
async def execute(self, model, item) -> DownloadResponse:
try:
values_models = registry.values_for_model[model]
except KeyError:
raise NotInRegistry
for value_model in values_models:
df = await value_model.get_as_dataframe(model_id=item.id)
df = await value_model.get_as_dataframe(item=item)
csv = df.to_csv(date_format='%d/%m/%Y %H:%M', float_format=value_model.float_format)
## TODO: implement multiple values for a model (search for values_for_model)
break
return {
'file_name': '{:s}.csv'.format(item.caption),
'content_type': 'text/csv',
'content': csv
}
return DownloadResponse(
file_name=f'{model.__name__}-id-{item.id}.csv',
content_type='text/csv',
content=csv
)
class PluginManager:
@ -171,7 +177,7 @@ class PluginManager:
self.actions_stores: dict[str, dict[str, list[ActionAction]]] = {}
self.executors = defaultdict(list)
self.downloaders = defaultdict(list)
self.downloaders_stores = defaultdict(list)
self.downloaders_stores: dict[str, list[Downloader]] = defaultdict(list)
registered_models = registry.geom
registered_stores = registered_models.keys()
@ -268,13 +274,13 @@ class PluginManager:
logger.warn(f'Downloader plugin {entry_point.name}: skip model {store}'
', which is not found in registry')
continue
self.downloaders_stores[store].append(
Downloader(
name=downloader.name,
roles=downloader.roles,
icon=downloader.icon,
)
_store = Downloader(
_plugin=downloader,
name=downloader.name,
roles=downloader.roles,
icon=downloader.icon,
)
self.downloaders_stores[store].append(_store)
logger.info(f'Added downloader plugin {entry_point.name}')
self.tagsStores = TagsStores(