Fix custom plugin downloaders
Change get_as_dataframe call signature
This commit is contained in:
parent
08c53cf894
commit
52e1d2135b
9 changed files with 156 additions and 105 deletions
|
@ -1 +1 @@
|
|||
__version__: str = '2023.4.dev56+g775030d.d20240325'
|
||||
__version__: str = '2023.4.dev62+g08c53cf.d20240405'
|
|
@ -9,11 +9,13 @@ from sqlalchemy.orm import selectinload, joinedload
|
|||
|
||||
from gisaf.database import pandas_query, fastapi_db_session as db_session
|
||||
from gisaf.models.geo_models_base import GeoModel, PlottableModel
|
||||
from gisaf.models.info import Downloader
|
||||
from gisaf.security import (
|
||||
Token, authenticate_user, get_current_active_user, create_access_token,
|
||||
)
|
||||
from gisaf.models.authentication import (User, UserRead, Role, RoleRead)
|
||||
from gisaf.registry import registry, NotInRegistry
|
||||
from gisaf.plugins import DownloadResponse, manager as plugin_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -47,7 +49,8 @@ async def download_csv(
|
|||
if custom_getter:
|
||||
df = await custom_getter(model_id)
|
||||
else:
|
||||
df = await values_model.get_as_dataframe(model_id=model_id, with_only_columns=[value])
|
||||
item = await db_session.get(model, model_id)
|
||||
df = await values_model.get_as_dataframe(item=item, with_only_columns=[value])
|
||||
if len(df) == 0:
|
||||
raise HTTPException(status.HTTP_204_NO_CONTENT)
|
||||
if resample and resample != '0':
|
||||
|
@ -75,3 +78,49 @@ async def download_csv(
|
|||
'Content-Disposition': f"attachment; filename={filename}"
|
||||
})
|
||||
return response
|
||||
|
||||
|
||||
@api.get('/plugin/{name}/{store}/{id}')
|
||||
async def execute_action(
|
||||
name: str,
|
||||
store: str,
|
||||
id: int,
|
||||
db_session: db_session,
|
||||
user: Annotated[UserRead, Depends(get_current_active_user)]
|
||||
):
|
||||
"""
|
||||
Download the result of an action
|
||||
"""
|
||||
## TODO: implement permissions for actions
|
||||
#await check_permission(info.context['request'], 'action')
|
||||
try:
|
||||
store_record = registry.stores.loc[store]
|
||||
model: type[GeoModel] = store_record.model
|
||||
values_model = registry.values_for_model[model][0]
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||||
item = await db_session.get(model, id)
|
||||
dls: list[Downloader] = [dl for dl in plugin_manager.downloaders_stores[store]
|
||||
if dl.name == name]
|
||||
if len(dls) == 0:
|
||||
raise HTTPException(status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail=f'No downloader {name} for {store}')
|
||||
elif len(dls) > 1:
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f'Too many downloaders ({len(dls)}) {name} for {store}')
|
||||
downloader = dls[0]
|
||||
result: DownloadResponse
|
||||
try:
|
||||
result = await downloader._plugin.execute(model, item)
|
||||
except Exception as err:
|
||||
logging.exception(err)
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f'Error in action: {err.args[0]}'
|
||||
)
|
||||
return StreamingResponse(
|
||||
iter([result.content]),
|
||||
headers = {
|
||||
'Content-Disposition': f'attachment; filename="{result.file_name}"'
|
||||
},
|
||||
media_type=result.content_type,
|
||||
)
|
|
@ -178,11 +178,14 @@ async def get_model_list(
|
|||
return resp
|
||||
|
||||
@api.get('/{store_name}/values/{value}')
|
||||
async def get_model_values(store_name: str, value: str,
|
||||
response: Response,
|
||||
where: str,
|
||||
resample: str | None = None,
|
||||
):
|
||||
async def get_model_values(
|
||||
db_session: db_session,
|
||||
store_name: str,
|
||||
value: str,
|
||||
response: Response,
|
||||
where: str,
|
||||
resample: str | None = None,
|
||||
):
|
||||
"""
|
||||
Get values
|
||||
"""
|
||||
|
@ -202,7 +205,8 @@ async def get_model_values(store_name: str, value: str,
|
|||
if getter:
|
||||
df = await getter(model_id)
|
||||
else:
|
||||
df = await values_model.get_as_dataframe(model_id=model_id,
|
||||
item = await db_session.get(model, model_id)
|
||||
df = await values_model.get_as_dataframe(item=item,
|
||||
with_only_columns=[value])
|
||||
|
||||
if len(df) == 0:
|
||||
|
|
|
@ -15,7 +15,7 @@ from gisaf.api.admin import api as admin_api
|
|||
from gisaf.api.dashboard import api as dashboard_api
|
||||
from gisaf.api.map import api as map_api
|
||||
from gisaf.api.download import api as download_api
|
||||
from gisaf.plugins import manager as plugin_manger
|
||||
from gisaf.plugins import manager as plugin_manager
|
||||
|
||||
logging.basicConfig(level=conf.gisaf.debugLevel)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -27,7 +27,7 @@ async def lifespan(app: FastAPI):
|
|||
await setup_redis()
|
||||
await setup_redis_cache()
|
||||
await setup_live()
|
||||
await plugin_manger.scan_plugins()
|
||||
await plugin_manager.scan_plugins()
|
||||
await admin_manager.setup_admin()
|
||||
await map_tile_registry.setup()
|
||||
yield
|
||||
|
|
|
@ -1130,44 +1130,31 @@ class PlottableModel(Model):
|
|||
values: ClassVar[list[dict[str, str]]] = []
|
||||
|
||||
@classmethod
|
||||
async def get_as_dataframe(cls, model_id=None, where=None, **kwargs):
|
||||
async def get_as_dataframe(cls, item, where=None, **kwargs):
|
||||
"""
|
||||
Get a dataframe for the data.
|
||||
It's quite generic, so subclasses might want to subclass this.
|
||||
"""
|
||||
if where is None:
|
||||
if model_id is None:
|
||||
where_ = None
|
||||
else:
|
||||
where_ = cls.ref_id == model_id
|
||||
else:
|
||||
if model_id is None:
|
||||
where_ = where
|
||||
else:
|
||||
where_ = and_(cls.ref_id == model_id, where)
|
||||
|
||||
if where_ is not None:
|
||||
df = await cls.get_df(where=where_, **kwargs)
|
||||
else:
|
||||
df = await cls.get_df(**kwargs)
|
||||
|
||||
return df
|
||||
where_ = cls.ref_id == item.id
|
||||
if where is not None:
|
||||
where_ = and_(where_, where)
|
||||
return await cls.get_df(where=where_, **kwargs)
|
||||
|
||||
|
||||
class TimePlottableModel(PlottableModel):
|
||||
time: datetime
|
||||
|
||||
@classmethod
|
||||
async def get_as_dataframe(cls, model_id=None, with_only_columns=None, **kwargs):
|
||||
async def get_as_dataframe(cls, item, with_only_columns=None, **kwargs):
|
||||
"""
|
||||
Get the data as a time-indexed dataframe
|
||||
"""
|
||||
if with_only_columns == None:
|
||||
if with_only_columns is None:
|
||||
with_only_columns = [val['name'] for val in cls.values]
|
||||
if 'time' not in with_only_columns:
|
||||
with_only_columns.insert(0, 'time')
|
||||
|
||||
df = await super().get_as_dataframe(model_id=model_id,
|
||||
df = await super().get_as_dataframe(item=item,
|
||||
with_only_columns=with_only_columns,
|
||||
**kwargs)
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
from gisaf.models.info_item import Tag, InfoItem
|
||||
from gisaf.models.tags import Tags
|
||||
|
@ -112,11 +112,15 @@ class ActionAction(BaseModel):
|
|||
|
||||
|
||||
class Downloader(BaseModel):
|
||||
# plugin: str
|
||||
# downloader: str
|
||||
_plugin: Any = PrivateAttr() # DownloadPlugin
|
||||
roles: list[str] = []
|
||||
name: str
|
||||
icon: str | None = None
|
||||
icon: str | None
|
||||
|
||||
def __init__(self, _plugin, **data):
|
||||
super().__init__(**data)
|
||||
# We generate the value for our private attribute
|
||||
self._plugin = _plugin
|
||||
|
||||
|
||||
class LegendItem(BaseModel):
|
||||
|
|
|
@ -12,9 +12,9 @@ from fastapi import HTTPException, status
|
|||
from sqlalchemy import or_, and_
|
||||
# from geoalchemy2.shape import to_shape, from_shape
|
||||
# from graphene import ObjectType, String, List, Boolean, Field, Float, InputObjectType
|
||||
|
||||
import pandas as pd
|
||||
import shapely # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
from gisaf.config import conf
|
||||
from gisaf.models.store import Store # noqa: F401
|
||||
|
@ -106,6 +106,12 @@ class TagPlugin:
|
|||
return self.key
|
||||
|
||||
|
||||
class DownloadResponse(BaseModel):
|
||||
file_name: str
|
||||
content_type: str
|
||||
content: str
|
||||
|
||||
|
||||
class DownloadPlugin:
|
||||
"""
|
||||
Base class for all download plugins.
|
||||
|
@ -119,26 +125,26 @@ class DownloadPlugin:
|
|||
self.roles = roles or []
|
||||
self.icon = icon
|
||||
|
||||
async def execute(self, model, item, request):
|
||||
async def execute(self, model, item) -> DownloadResponse:
|
||||
raise NotImplementedError(f'Missing execute in downloader {self.name}')
|
||||
|
||||
|
||||
class DownloadCSVPlugin(DownloadPlugin):
|
||||
async def execute(self, model, item, request):
|
||||
async def execute(self, model, item) -> DownloadResponse:
|
||||
try:
|
||||
values_models = registry.values_for_model[model]
|
||||
except KeyError:
|
||||
raise NotInRegistry
|
||||
for value_model in values_models:
|
||||
df = await value_model.get_as_dataframe(model_id=item.id)
|
||||
df = await value_model.get_as_dataframe(item=item)
|
||||
csv = df.to_csv(date_format='%d/%m/%Y %H:%M', float_format=value_model.float_format)
|
||||
## TODO: implement multiple values for a model (search for values_for_model)
|
||||
break
|
||||
return {
|
||||
'file_name': '{:s}.csv'.format(item.caption),
|
||||
'content_type': 'text/csv',
|
||||
'content': csv
|
||||
}
|
||||
return DownloadResponse(
|
||||
file_name=f'{model.__name__}-id-{item.id}.csv',
|
||||
content_type='text/csv',
|
||||
content=csv
|
||||
)
|
||||
|
||||
|
||||
class PluginManager:
|
||||
|
@ -171,7 +177,7 @@ class PluginManager:
|
|||
self.actions_stores: dict[str, dict[str, list[ActionAction]]] = {}
|
||||
self.executors = defaultdict(list)
|
||||
self.downloaders = defaultdict(list)
|
||||
self.downloaders_stores = defaultdict(list)
|
||||
self.downloaders_stores: dict[str, list[Downloader]] = defaultdict(list)
|
||||
|
||||
registered_models = registry.geom
|
||||
registered_stores = registered_models.keys()
|
||||
|
@ -268,13 +274,13 @@ class PluginManager:
|
|||
logger.warn(f'Downloader plugin {entry_point.name}: skip model {store}'
|
||||
', which is not found in registry')
|
||||
continue
|
||||
self.downloaders_stores[store].append(
|
||||
Downloader(
|
||||
name=downloader.name,
|
||||
roles=downloader.roles,
|
||||
icon=downloader.icon,
|
||||
)
|
||||
_store = Downloader(
|
||||
_plugin=downloader,
|
||||
name=downloader.name,
|
||||
roles=downloader.roles,
|
||||
icon=downloader.icon,
|
||||
)
|
||||
self.downloaders_stores[store].append(_store)
|
||||
logger.info(f'Added downloader plugin {entry_point.name}')
|
||||
|
||||
self.tagsStores = TagsStores(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue