import tarfile
from io import BytesIO
from logging import getLogger
from pathlib import Path

import numpy as np
import pandas as pd
from aiohttp_client_cache import FileBackend
from aiohttp_client_cache.session import CachedSession
from fastapi import Depends, FastAPI, HTTPException, Response, status
from fastapi.responses import ORJSONResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from sqlalchemy import String
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.mutable import MutableDict
from sqlmodel import Field, Relationship, select

from treetrail.config import get_cache_dir
from treetrail.database import db_session, fastapi_db_session
from treetrail.models import BaseModel, User
from treetrail.security import get_current_active_user
from treetrail.utils import read_sql, mkdir

logger = getLogger(__name__)

# One day; only referenced by the commented-out cache check in get_all.
cache_expiry = 3600 * 24
# Maximum (width, height) passed to PIL's Image.thumbnail
thumbnail_size = (200, 200)

# cache = SQLiteBackend(get_cache_dir() / 'plantekey')
# On-disk cache for the HTTP requests made to plantekey.com
cache = FileBackend(get_cache_dir() / "plantekey" / "http")

pek_app = FastAPI(
    default_response_class=ORJSONResponse,
)


def get_plantekey_api_url(id):
    """Return the plantekey.com API URL giving the details of plant ``id``."""
    return f"https://plantekey.com/api.php?action=plant&name={id}"


def get_storage_root():
    """Return the local directory under which all plantekey data is stored."""
    return get_cache_dir() / "plantekey"


def get_storage(f):
    """Return the path of ``f`` inside the plantekey storage root."""
    return get_storage_root() / f


def get_img_root() -> Path:
    """Return the directory holding the full-size plant images."""
    return get_storage_root() / "img"


def get_img_path(img: str):
    """Return the local path of the full-size image file ``img``."""
    return get_img_root() / img


def get_thumbnail_root() -> Path:
    """Return the directory holding the plant image thumbnails."""
    return get_storage_root() / "thumbnails"


def get_thumbnail_tar_path():
    """Return the path of the tar archive bundling the thumbnails."""
    return get_storage_root() / "thumbnails.tar"


def get_img_type_root():
    """Return the directory holding the plant type icons."""
    return get_storage_root() / "type"


def get_img_type_path(type: str):
    """Return the local path of the PNG icon for plant type ``type``."""
    return get_img_type_root() / (type + ".png")


def setup(ap):
    """
    Create empty directories if needed.
    Intended to be used at startup.

    ``ap`` is not used here (presumably the application object passed by
    the startup hook — TODO confirm).
    """
    get_img_root().mkdir(parents=True, exist_ok=True)
    get_thumbnail_root().mkdir(parents=True, exist_ok=True)
    get_img_type_root().mkdir(parents=True, exist_ok=True)


class Plant(BaseModel, table=True):
    """
    Record of a Plantekey plant
    """

    __tablename__: str = "plant"  # type: ignore

    def __str__(self):
        return str(self.id)

    def __repr__(self):
        return f"treetrail.database.Plant: {self.id}"

    id: str | None = Field(primary_key=True, default=None)
    ID: int
    family: str = Field(sa_type=String(100))  # type: ignore
    name: str = Field(sa_type=String(200))  # type: ignore
    description: str | None
    habit: str | None
    landscape: str | None
    uses: str | None
    planting: str | None
    propagation: str | None
    type: str = Field(sa_type=String(30))  # type: ignore
    img: str = Field(sa_type=String(100))  # type: ignore
    element: str = Field(sa_type=String(30))  # type: ignore
    isOnMap: str | None = Field(sa_type=String(10))  # type: ignore
    english: str | None = Field(sa_type=String(100))  # type: ignore
    hindi: str | None = Field(sa_type=String(100))  # type: ignore
    tamil: str | None = Field(sa_type=String(100))  # type: ignore
    spiritual: str | None = Field(sa_type=String(150))  # type: ignore
    woody: bool
    latex: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_style: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_type: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_arrangement: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_aroma: bool | None
    leaf_length: float | None
    leaf_width: float | None
    flower_color: str | None = Field(sa_type=String(20))  # type: ignore
    flower_size: float | None
    flower_aroma: bool | None
    fruit_color: str | None = Field(sa_type=String(20))  # type: ignore
    fruit_size: float | None
    fruit_type: str | None = Field(sa_type=String(20))  # type: ignore
    thorny: str | None = Field(sa_type=String(20))  # type: ignore
    images: dict = Field(
        sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
        default_factory=dict,
    )  # type: ignore
    # CREATE EXTENSION hstore;
    # ALTER TABLE tree ADD COLUMN data JSONB;
    data: dict = Field(
        sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
        default_factory=dict,
    )  # type: ignore

    # __mapper_args__ = {"eager_defaults": True}


class PlantImage(BaseModel, table=True):
    """
    Image of a plant, as listed by the plantekey.com details API
    """

    __tablename__: str = "plant_image"  # type: ignore

    id: int | None = Field(primary_key=True, default=None)
    plant_id: str = Field(foreign_key="plant.id")
    plant: Plant = Relationship()
    caption: str = Field(sa_type=String(50))  # type: ignore
    IsDefault: bool
    src: str = Field(sa_type=String(100))  # type: ignore

    def get_thumbnail_path(self) -> Path:
        """Return the local path of this image's thumbnail."""
        return get_thumbnail_root() / self.src  # type: ignore


class Plantekey(BaseModel, table=True):
    """
    Details for the plantekey data, like symbols for the map
    """

    ## CREATE TABLE plantekey (id VARCHAR(100) PRIMARY KEY NOT NULL, symbol CHAR(1));
    ## GRANT ALL on TABLE plantekey to treetrail ;
    __tablename__: str = "plantekey"  # type: ignore

    id: str | None = Field(primary_key=True, default=None)
    symbol: str = Field(sa_type=String(1))  # type: ignore
    iso: str = Field(sa_type=String(100))  # type: ignore


def _as_plant_image(rec) -> PlantImage:
    """
    Return the PlantImage held by a query result item.

    SQLModel's ``session.exec(select(PlantImage))`` yields model instances
    directly, while a Row-based result wraps them in a 1-tuple; accept both.
    """
    return rec if isinstance(rec, PlantImage) else rec[0]


async def fetch_browse():
    """
    Fetch the list of plants from plantekey.com's "browse" API.

    Returns a DataFrame with the columns listed below, plus ``id`` derived
    from ``name`` (lower-cased, spaces replaced by dashes). If the response
    body cannot be decoded as JSON, the error is logged and an empty frame
    with the same columns is returned.
    """
    logger.info("Fetching list of plants (browse) from plantekey.com...")
    plantekey_url = "https://www.plantekey.com/api.php?action=browse"
    async with CachedSession(cache=cache) as session:
        async with session.get(plantekey_url) as response:
            try:
                # content_type=None: decode whatever the server labels it as
                content = await response.json(content_type=None)
            except Exception as err:
                logger.warning("Error browsing plantekey")
                logger.exception(err)
                content = {}
    df = pd.DataFrame(
        data=content,
        columns=[
            "ID",
            "english",
            "family",
            "hindi",
            "img",
            "name",
            "spiritual",
            "tamil",
            "type",
        ],
    )
    df["id"] = df["name"].str.replace(" ", "-").str.lower()
    df["ID"] = df["ID"].astype(int)
    # async with db_session() as session:
    #     array = df.apply(lambda item: Plant(**item), axis=1)
    #     await session.exec(delete(Plant))
    #     session.add_all(array)
    #     await session.commit()
    return df


async def get_all():
    """
    Return the list of plants stored in the local database, as a DataFrame.

    The local cache mentioned by the TODO below is not implemented yet.
    """
    ## TODO: implement cache mechanism
    ## Store in db?
    # path = get_storage('all')
    # if path.stat().st_ctime - time() > cache_expiry:
    #     return fetch_browse()
    # return pd.read_feather(path)
    async with db_session() as session:
        result = await session.exec(select(Plant))
        df = pd.DataFrame(result.all())
    return df


async def get_local_details() -> pd.DataFrame:
    """
    Return a dataframe of the plantekey table, containing extra information
    like symbols for the map.

    The frame has a ``symbol`` column and is indexed by plant ``id`` when
    the table is not empty.
    """
    async with db_session() as session:
        data = await session.exec(select(Plantekey.id, Plantekey.symbol))
        df = pd.DataFrame(data.all())
    if len(df) > 0:
        df.set_index("id", inplace=True)
    return df


async def get_details():
    """
    Return the details of plants, as stored on the local server.

    When the plantekey table has entries, their ``symbol`` is merged in by
    plant id (left join), and plants without a symbol get the default
    symbol U+E034.
    """
    local_details = await get_local_details()
    async with db_session() as session:
        con = await session.connection()
        df = await con.run_sync(read_sql, select(Plant))
    # df.set_index('id', inplace=True)
    if len(local_details) > 0:
        ndf = df.merge(local_details, left_on="id", right_index=True, how="left")
        # Assign back instead of fillna(inplace=True) on a column selection:
        # chained in-place fillna is deprecated in pandas 2.x and does not
        # modify ndf under copy-on-write.
        ndf["symbol"] = ndf["symbol"].fillna("\ue034")
        return ndf
    else:
        return df


async def fetch_image(img: PlantImage):
    """
    Download one plant image from plantekey.com, save it under the image
    root and regenerate its thumbnail.

    Download errors and non-image responses are logged and the image is
    skipped, so one bad image does not abort a bulk update.
    """
    url = f"https://www.plantekey.com/admin/images/plants/{img.src}"
    logger.info(f"Fetching image at {url}")
    async with CachedSession(cache=cache) as session:
        try:
            # Use the response as a context manager so the connection is
            # released even if reading the body fails
            async with session.get(url) as resp:
                img_data = await resp.content.read()
        except Exception as err:
            logger.warning(f"Cannot get image for {img.plant_id} ({url}): {err}")
            return
    # Only the first byte is checked: 0xFF, as in a JPEG start-of-image
    # marker. An empty body is rejected too instead of raising IndexError.
    if not img_data or img_data[0] != 0xFF:
        logger.warning(f"Image for {img.plant_id} at {url} is not an image")
        return
    with open(get_img_path(img.src), "bw") as f:  # type: ignore
        f.write(img_data)
    update_thumbnail(BytesIO(img_data), img)


async def fetch_images():
    """
    Fetch all the images from the plantekey server, one per PlantImage row
    """
    logger.info("Fetching all images from plantekey server")
    async with db_session() as session:
        images = await session.exec(select(PlantImage))
        for img_rec in images:
            await fetch_image(_as_plant_image(img_rec))


def update_thumbnail(data, img: PlantImage):
    """
    Write the thumbnail of ``img`` from the image bytes in ``data``.

    ``data`` is a file-like object readable by PIL. Failures are logged, not
    raised.
    """
    try:
        tn_image = Image.open(data)
        tn_image.thumbnail(thumbnail_size)
        tn_image.save(img.get_thumbnail_path())
    except Exception as error:
        logger.warning(
            f"Cannot create thumbnail for {img.plant_id} ({img.src}): {error}"
        )


async def fetch_types():
    """Download the icon of every plant type found in the local database."""
    df = await get_all()
    for type in df.type.unique():
        await fetch_type(type)


async def fetch_type(type):
    """Download the PNG icon of plant type ``type`` and store it locally."""
    plantekey_url = f"https://plantekey.com/img/icons/plant_types/{type}.png"
    async with CachedSession(cache=cache) as session:
        async with session.get(plantekey_url) as data:
            img_data = await data.content.read()
    with open(get_img_type_path(type), "bw") as f:
        f.write(img_data)


async def update_details(df: pd.DataFrame):
    """
    Update the server database from plantekey details.

    For every ``id`` in ``df``, fetch the plant details and image list from
    plantekey.com. Plants whose fetch fails are logged and skipped. The
    values are coerced to the column types of the ``plant`` table. Then all
    existing ``plant`` and ``plant_image`` rows are deleted and replaced
    with the fetched ones, committed once at the end.
    """
    details = {}
    images = {}
    for id in df["id"]:
        try:
            details[id], images[id] = await fetch_detail(id)
        except Exception as err:
            logger.warning(f"Error fetching details for {id}: {err}")
    df_details = pd.DataFrame.from_dict(details, orient="index")
    df_details.index.name = "id"
    df_details["ID"] = df_details["ID"].astype(int)

    ## Cleanup according to DB data types
    for float_col_name in ("leaf_width", "leaf_length", "flower_size", "fruit_size"):
        df_details[float_col_name] = pd.to_numeric(
            df_details[float_col_name], errors="coerce", downcast="float"
        )
    for bool_col_name in ("woody", "leaf_aroma", "flower_aroma"):
        # NOTE(review): values other than "Yes"/"No" (including missing ones)
        # become True after astype(bool) — confirm this is intended.
        df_details[bool_col_name] = (
            df_details[bool_col_name].replace({"Yes": True, "No": False}).astype(bool)
        )
    # TODO: migrate __table__ to SQLModel, use model_fields
    for str_col_name in [
        c.name
        for c in Plant.__table__.columns  # type: ignore
        if isinstance(c.type, String) and not c.primary_key
    ]:  # type: ignore
        # NaN -> None so that missing strings are stored as NULL
        df_details[str_col_name] = df_details[str_col_name].replace([np.nan], [None])

    async with db_session() as session:
        plants_array = df_details.reset_index().apply(
            lambda item: Plant(**item), axis=1
        )  # type: ignore
        existing_plants = await session.exec(select(Plant))
        for existing_plant in existing_plants.all():
            await session.delete(existing_plant)
        session.add_all(plants_array)
        await session.flush()

        ## Images
        existing_plant_images = await session.exec(select(PlantImage))
        for existing_plant_image in existing_plant_images.all():
            await session.delete(existing_plant_image)
        images_array: list[PlantImage] = []
        for plant_id, plant_images in images.items():
            images_array.extend(
                [PlantImage(plant_id=plant_id, **image) for image in plant_images]
            )
        for image in images_array:
            # The API sends IsDefault as a string; only "0" means False
            image.IsDefault = image.IsDefault != "0"  # type: ignore
        session.add_all(images_array)
        await session.commit()


async def fetch_detail(id):
    """
    Fetch the details of plant ``id`` from the plantekey.com API.

    Returns a ``(detail, images)`` tuple. ``detail`` is the API's "plant"
    dict merged with its "characteristics" dict (if any). ``images`` is its
    list of image dicts. Keys made only of digits are dropped from both.
    """
    logger.info(f"Fetching details from plantekey.com for {id}...")
    async with CachedSession(cache=cache) as session:
        async with session.get(get_plantekey_api_url(id)) as response:
            result = await response.json(content_type=None)
    ## Sanitize
    result["plant"] = {
        k: v for k, v in result["plant"].items() if not k.isdecimal()
    }
    result["images"] = [
        {k: v for k, v in image.items() if not k.isdecimal()}
        for image in result["images"]
    ]
    ## Merge dicts, Python 3.9
    detail = result["plant"] | (result["characteristics"] or {})
    return detail, result["images"]


async def create_thumbnail_archive():
    """
    Create a tar file with the existing thumbnails of the default images
    of plants
    """
    logger.info("Generating thumbnails and tar file")
    async with db_session() as session:
        images = await session.exec(select(PlantImage))
    with tarfile.open(str(get_thumbnail_tar_path()), "w") as tar:
        for img_rec in images:
            img: PlantImage = _as_plant_image(img_rec)
            path = img.get_thumbnail_path()
            if img.IsDefault:
                if path.is_file():
                    tar.add(path)
    logger.info(
        "Generation of thumbnails and tar file "
        + f"({get_thumbnail_tar_path()}) finished"
    )


@pek_app.get("/updateData")
async def updateData(user: User = Depends(get_current_active_user)):
    """
    Get list and details of all plants from plantekey.

    ``user`` is not used in the body; the dependency presumably restricts
    access to active users — TODO confirm.
    Raises HTTPException (500) if the update fails.
    """
    try:
        df = await fetch_browse()
        await update_details(df)
    except Exception as error:
        logger.exception(error)
        logger.error(error)
        # Most exceptions have no ``code`` attribute: reading it directly
        # raised AttributeError here and masked the intended HTTP 500.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Update failed: {getattr(error, 'code', error)}",
        ) from error
    else:
        logger.info("updateData finished")
        return {"status": 0, "message": "Server updated"}


@pek_app.get("/updateImages")
async def updateImages(user: User = Depends(get_current_active_user)):
    """
    Get images from plantekey, using the list of plants fetched with
    updateData, then rebuild the thumbnail archive.

    Raises HTTPException (500) if either step fails.
    """
    try:
        await fetch_images()
    except Exception as error:
        logger.exception(error)
        logger.error(error)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Update of images failed: {error}",
        )
    try:
        await create_thumbnail_archive()
    except Exception as error:
        logger.exception(error)
        logger.error(error)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Update of images failed while creating thumbnails: {error}",
        )
    return {"status": 0, "message": "Server updated"}


@pek_app.get("/plant/info/{id}")
async def get_plantekey(
    id: str,
    db_session: fastapi_db_session,
) -> Plant | None:  # type: ignore
    """
    Get details of a specific plant; 404 if ``id`` is unknown
    """
    plant = await db_session.get(Plant, id)
    if plant is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
        )
    return plant


@pek_app.get("/details")
async def get_plantekey_all_details():
    """
    Get all plants, as a JSON array of records
    """
    df = await get_details()
    content = df.to_json(orient="records")
    return Response(content=content, media_type="application/json")


@pek_app.get("/details/csv", response_class=PlainTextResponse)
async def get_plantekey_all_details_csv():
    """
    Get all plants, return CSV
    """
    df = await get_details()
    content = df.to_csv()
    return Response(content=content, media_type="text/csv")


pek_app.mount(
    "/img", StaticFiles(directory=mkdir(get_img_root())), name="plantekey_img"
)
pek_app.mount(
    "/thumb",
    StaticFiles(directory=mkdir(get_thumbnail_root())),
    name="plantekey_thumbnail",
)
pek_app.mount(
    "/type",
    StaticFiles(directory=mkdir(get_img_type_root())),
    name="plantekey_type",
)