treetrail-backend/src/treetrail/plantekey.py

489 lines
16 KiB
Python
Raw Normal View History

2024-10-23 16:19:51 +02:00
import tarfile
from io import BytesIO
from logging import getLogger
from pathlib import Path
import numpy as np
import pandas as pd
from aiohttp_client_cache import FileBackend
from aiohttp_client_cache.session import CachedSession
from fastapi import Depends, FastAPI, HTTPException, Response, status
from fastapi.responses import ORJSONResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from sqlalchemy import String
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.mutable import MutableDict
from sqlmodel import Field, Relationship, select
from treetrail.config import get_cache_dir
from treetrail.database import db_session, fastapi_db_session
from treetrail.models import BaseModel, User
from treetrail.security import get_current_active_user
from treetrail.utils import read_sql, mkdir
# Module-level logger, named after this module.
logger = getLogger(__name__)
# One day expressed in seconds.
# NOTE(review): in the visible code this is only referenced from the
# commented-out cache logic in get_all().
cache_expiry = 3600 * 24
# Size handed to PIL's Image.thumbnail() in update_thumbnail().
thumbnail_size = (200, 200)
# cache = SQLiteBackend(get_cache_dir() / 'plantekey')
# File-based HTTP cache shared by every CachedSession opened in this module.
cache = FileBackend(get_cache_dir() / "plantekey" / "http")
# FastAPI application carrying the plantekey endpoints; responses use
# ORJSONResponse unless an endpoint returns another Response type.
pek_app = FastAPI(
    default_response_class=ORJSONResponse,
)
def get_plantekey_api_url(id):
    """
    Build the plantekey.com API URL returning the details of one plant.

    The given id is inserted verbatim as the ``name`` query parameter.
    """
    return "https://plantekey.com/api.php?action=plant&name={}".format(id)
def get_storage_root():
    """
    Return the directory holding all locally stored plantekey files:
    the "plantekey" entry of the application's cache directory.
    """
    cache_dir = get_cache_dir()
    return cache_dir / "plantekey"
def get_storage(f):
    """
    Return the path of the entry ``f`` located directly under the
    plantekey storage root.
    """
    root = get_storage_root()
    return root / f
def get_img_root() -> Path:
    """
    Return the directory where full-size plant images are stored
    ("img" under the plantekey storage root).
    """
    return get_storage_root().joinpath("img")
def get_img_path(img: str):
    """
    Return the local path of the full-size image file named ``img``.
    """
    return get_img_root().joinpath(img)
def get_thumbnail_root() -> Path:
    """
    Return the directory where image thumbnails are stored
    ("thumbnails" under the plantekey storage root).
    """
    return get_storage_root().joinpath("thumbnails")
def get_thumbnail_tar_path():
    """
    Return the path of the tar archive of thumbnails
    ("thumbnails.tar" under the plantekey storage root).
    """
    storage_root = get_storage_root()
    return storage_root / "thumbnails.tar"
def get_img_type_root():
    """
    Return the directory where plant type icons are stored
    ("type" under the plantekey storage root).
    """
    storage_root = get_storage_root()
    return storage_root / "type"
def get_img_type_path(type: str):
    """
    Return the local path of the PNG icon for the given plant type:
    the type name with a ".png" suffix, inside the type icon directory.
    """
    file_name = type + ".png"
    return get_img_type_root() / file_name
def setup(ap):
    """
    Make sure the storage directories (images, thumbnails, type icons)
    exist, creating them and their parents when missing.
    Intended to be used at startup. The ``ap`` argument is not used.
    """
    for get_root in (get_img_root, get_thumbnail_root, get_img_type_root):
        get_root().mkdir(parents=True, exist_ok=True)
class Plant(BaseModel, table=True):
    """
    Record of a Plantekey plant

    Rows of this table are rebuilt by update_details() from the data
    returned by fetch_detail(); the column types declared here drive the
    clean-up (numeric, boolean and string coercion) performed there.
    """

    __tablename__: str = "plant"  # type: ignore

    def __str__(self) -> str:
        """Return the plant id as a string."""
        return str(self.id)

    def __repr__(self) -> str:
        """Return a short debug representation containing the plant id."""
        # NOTE(review): the text mentions "treetrail.database" although the
        # file header names this module treetrail/plantekey.py.
        return f"treetrail.database.Plant: {self.id}"

    # Primary key. fetch_browse() derives the value from the plant name
    # (spaces replaced by dashes, lowercased).
    id: str | None = Field(primary_key=True, default=None)
    # Integer identifier read from the plantekey data (cast with astype(int)).
    ID: int
    family: str = Field(sa_type=String(100))  # type: ignore
    name: str = Field(sa_type=String(200))  # type: ignore
    # Nullable text columns without an explicit SQL type.
    description: str | None
    habit: str | None
    landscape: str | None
    uses: str | None
    planting: str | None
    propagation: str | None
    # Plant type; fetch_types() downloads one icon per distinct value.
    type: str = Field(sa_type=String(30))  # type: ignore
    img: str = Field(sa_type=String(100))  # type: ignore
    element: str = Field(sa_type=String(30))  # type: ignore
    isOnMap: str | None = Field(sa_type=String(10))  # type: ignore
    # Names in other languages / contexts, as listed by the browse API.
    english: str | None = Field(sa_type=String(100))  # type: ignore
    hindi: str | None = Field(sa_type=String(100))  # type: ignore
    tamil: str | None = Field(sa_type=String(100))  # type: ignore
    spiritual: str | None = Field(sa_type=String(150))  # type: ignore
    # Boolean columns (woody, leaf_aroma, flower_aroma) are converted from
    # "Yes"/"No" values in update_details().
    woody: bool
    latex: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_style: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_type: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_arrangement: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_aroma: bool | None
    # Float columns (leaf_length, leaf_width, flower_size, fruit_size) are
    # coerced with pd.to_numeric(errors="coerce") in update_details().
    leaf_length: float | None
    leaf_width: float | None
    flower_color: str | None = Field(sa_type=String(20))  # type: ignore
    flower_size: float | None
    flower_aroma: bool | None
    fruit_color: str | None = Field(sa_type=String(20))  # type: ignore
    fruit_size: float | None
    fruit_type: str | None = Field(sa_type=String(20))  # type: ignore
    thorny: str | None = Field(sa_type=String(20))  # type: ignore
    # JSONB column wrapped in MutableDict; defaults to an empty dict.
    images: dict = Field(
        sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
        default_factory=dict,
    )  # type: ignore
    # CREATE EXTENSION hstore;
    # ALTER TABLE tree ADD COLUMN data JSONB;
    # Second JSONB column wrapped in MutableDict; defaults to an empty dict.
    data: dict = Field(
        sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
        default_factory=dict,
    )  # type: ignore
    # __mapper_args__ = {"eager_defaults": True}
class PlantImage(BaseModel, table=True):
    """
    Record of one image of a plant.

    Rows are rebuilt by update_details() from the "images" part of the
    plantekey details; the image files themselves are downloaded by
    fetch_image().
    """

    __tablename__: str = "plant_image"  # type: ignore
    id: int | None = Field(primary_key=True, default=None)
    # Foreign key to the owning plant (plant.id).
    plant_id: str = Field(foreign_key="plant.id")
    plant: Plant = Relationship()
    caption: str = Field(sa_type=String(50))  # type: ignore
    # update_details() sets this to False when the source value is "0" and to
    # True otherwise; create_thumbnail_archive() only archives images having
    # this flag set.
    IsDefault: bool
    # File name of the image: used to build the download URL in fetch_image()
    # and the local image and thumbnail paths.
    src: str = Field(sa_type=String(100))  # type: ignore

    def get_thumbnail_path(self) -> Path:
        """Return the local path of this image's thumbnail file."""
        return get_thumbnail_root() / self.src  # type: ignore
class Plantekey(BaseModel, table=True):
    """
    Details for the plantekey data, like symbols for the map

    get_local_details() reads the id and symbol columns of this table and
    get_details() merges them into the plant details on the plant id.
    """

    ## CREATE TABLE plantekey (id VARCHAR(100) PRIMARY KEY NOT NULL, symbol CHAR(1));
    ## GRANT ALL on TABLE plantekey to treetrail ;
    __tablename__: str = "plantekey"  # type: ignore
    # Primary key; matched against Plant.id in get_details().
    id: str | None = Field(primary_key=True, default=None)
    # Single character used as the map symbol of the plant.
    symbol: str = Field(sa_type=String(1))  # type: ignore
    # NOTE(review): meaning of this column is not visible in this file, and it
    # is absent from the CREATE TABLE statement above - confirm.
    iso: str = Field(sa_type=String(100))  # type: ignore
async def fetch_browse():
    """
    Download the list of plants ("browse" action) from plantekey.com.

    Returns a DataFrame with the columns ID, english, family, hindi, img,
    name, spiritual, tamil and type, plus an "id" column derived from the
    name (spaces replaced by dashes, lowercased). ID is cast to int.
    If the response cannot be decoded as JSON, the error is logged and an
    empty listing is used instead.
    """
    logger.info("Fetching list of plants (browse) from plantekey.com...")
    browse_url = "https://www.plantekey.com/api.php?action=browse"
    browse_columns = [
        "ID",
        "english",
        "family",
        "hindi",
        "img",
        "name",
        "spiritual",
        "tamil",
        "type",
    ]
    async with CachedSession(cache=cache) as session, session.get(
        browse_url
    ) as response:
        try:
            # content_type=None: decode whatever content type is announced
            listing = await response.json(content_type=None)
        except Exception as exc:
            logger.warning("Error browsing plantekey")
            logger.exception(exc)
            listing = {}
    plants = pd.DataFrame(data=listing, columns=browse_columns)
    dashed_names = plants["name"].str.replace(" ", "-")
    plants["id"] = dashed_names.str.lower()
    plants["ID"] = plants["ID"].astype(int)
    # NOTE: an earlier approach that stored this listing directly as Plant
    # rows is disabled; the records are written by update_details().
    return plants
async def get_all():
    """
    Return the plants stored in the local database, as a DataFrame built
    from the result of selecting all Plant records.
    """
    ## TODO: implement a cache mechanism (store in db?), e.g. re-fetch the
    ## browse list once a locally stored copy is older than cache_expiry.
    async with db_session() as session:
        query_result = await session.exec(select(Plant))
        return pd.DataFrame(query_result.all())
async def get_local_details() -> pd.DataFrame:
    """
    Return the id and symbol columns of the plantekey table (extra
    information like symbols for the map) as a DataFrame.

    When at least one row exists the frame is indexed by "id"; an empty
    frame is returned untouched.
    """
    async with db_session() as session:
        query_result = await session.exec(select(Plantekey.id, Plantekey.symbol))
        details = pd.DataFrame(query_result.all())
    if len(details) == 0:
        return details
    return details.set_index("id")
async def get_details():
    """
    Return the details of plants, as stored on the local server.

    The plant table is read into a DataFrame. When the plantekey table has
    rows, its "symbol" column is left-joined on the plant id and plants
    without a symbol get the default symbol "\\ue034".

    Returns:
        A DataFrame with one row per plant, with an extra "symbol" column
        when local details exist.
    """
    local_details = await get_local_details()
    async with db_session() as session:
        con = await session.connection()
        df = await con.run_sync(read_sql, select(Plant))
    if len(local_details) == 0:
        return df
    ndf = df.merge(local_details, left_on="id", right_index=True, how="left")
    # Assign the filled column back explicitly: calling fillna(inplace=True)
    # on the ndf["symbol"] selection is a chained in-place operation, which
    # warns in recent pandas and does not update ndf under Copy-on-Write.
    ndf["symbol"] = ndf["symbol"].fillna("\ue034")
    return ndf
async def fetch_image(img: PlantImage):
    """
    Download one plant image from plantekey.com, store it locally and
    refresh its thumbnail.

    This is best effort: when the download fails, or when the payload does
    not look like an image, a warning is logged and nothing is written.

    Args:
        img: the image record; its ``src`` gives the remote file name and
            the local file name.
    """
    url = f"https://www.plantekey.com/admin/images/plants/{img.src}"
    logger.info(f"Fetching image at {url}")
    async with CachedSession(cache=cache) as session:
        try:
            # The context manager releases the response once it is read
            async with session.get(url) as resp:
                img_data = await resp.content.read()
        except Exception as err:
            logger.warning(f"Cannot get image for {img.plant_id} ({url}): {err}")
            return
    # Same check as before (first byte must be 0xFF), but an empty body is
    # now reported as "not an image" instead of raising IndexError.
    if not img_data.startswith(b"\xff"):
        logger.warning(f"Image for {img.plant_id} at {url} is not an image")
        return
    with open(get_img_path(img.src), "bw") as f:  # type: ignore
        f.write(img_data)
    update_thumbnail(BytesIO(img_data), img)
async def fetch_images():
    """
    Fetch all the images from the plantekey server

    Every PlantImage record of the database is passed to fetch_image(),
    one after the other.
    """
    logger.info("Fetching all images from plantekey server")
    async with db_session() as session:
        result = await session.exec(select(PlantImage))
        for record in result.all():
            # Accept both result shapes: model instances (as update_details()
            # gets from the same kind of query) or one-element rows.
            img = record if isinstance(record, PlantImage) else record[0]
            await fetch_image(img)
def update_thumbnail(data, img: PlantImage):
    """
    Create (or overwrite) the thumbnail file of an image.

    ``data`` is opened with PIL, shrunk to ``thumbnail_size`` and saved at
    the thumbnail path of ``img``. Any failure is logged as a warning and
    otherwise ignored.
    """
    try:
        thumb = Image.open(data)
        thumb.thumbnail(thumbnail_size)
        destination = img.get_thumbnail_path()
        thumb.save(destination)
    except Exception as exc:
        logger.warning(
            f"Cannot create thumbnail for {img.plant_id} ({img.src}): {exc}"
        )
async def fetch_types():
    """
    Download the icon of every distinct plant type found in the locally
    stored plants, one after the other.
    """
    plants = await get_all()
    distinct_types = plants.type.unique()
    for plant_type in distinct_types:
        await fetch_type(plant_type)
async def fetch_type(type):
    """
    Download the PNG icon of one plant type from plantekey.com and write
    it to the local type icon directory.
    """
    icon_url = f"https://plantekey.com/img/icons/plant_types/{type}.png"
    async with CachedSession(cache=cache) as session, session.get(
        icon_url
    ) as reply:
        icon_bytes = await reply.content.read()
    get_img_type_path(type).write_bytes(icon_bytes)
async def update_details(df: pd.DataFrame) -> None:
    """
    Update the server database from plantekey details

    For every id of ``df["id"]`` the details are fetched with fetch_detail();
    the results are cleaned up to match the column types of Plant, then all
    existing Plant and PlantImage rows are replaced with the fetched ones
    in one database session.

    Args:
        df: a DataFrame with an "id" column, like the one returned by
            fetch_browse().
    """
    # NOTE(review): the local names `all` and `id` shadow Python builtins.
    all = {}
    images = {}
    for id in df["id"]:
        try:
            all[id], images[id] = await fetch_detail(id)
        except Exception as err:
            # Best effort: a plant whose details cannot be fetched is skipped
            logger.warning(f"Error fetching details for {id}: {err}")
    # One row per plant, indexed by the plant id
    # NOTE(review): if no detail could be fetched, the frame has no "ID"
    # column and the cast below raises KeyError.
    df_details = pd.DataFrame.from_dict(all, orient="index")
    df_details.index.name = "id"
    df_details["ID"] = df_details["ID"].astype(int)
    ## Cleanup according to DB data types
    # Values which are not numeric become NaN (errors="coerce")
    for float_col_name in ("leaf_width", "leaf_length", "flower_size", "fruit_size"):
        df_details[float_col_name] = pd.to_numeric(
            df_details[float_col_name], errors="coerce", downcast="float"
        )
    # "Yes"/"No" are mapped to booleans, then everything is cast to bool.
    # NOTE(review): presumably a missing value (NaN) ends up as True after
    # the plain bool cast - confirm this is intended for the nullable
    # leaf_aroma / flower_aroma columns.
    for bool_col_name in ("woody", "leaf_aroma", "flower_aroma"):
        df_details[bool_col_name] = (
            df_details[bool_col_name].replace({"Yes": True, "No": False}).astype(bool)
        )
    # TODO: migrate __table__ to SQLModel, use model_fields
    # In every String column of Plant except the primary key, NaN is
    # replaced by None
    for str_col_name in [
        c.name
        for c in Plant.__table__.columns  # type: ignore
        if isinstance(c.type, String) and not c.primary_key
    ]:  # type: ignore
        df_details[str_col_name] = df_details[str_col_name].replace([np.nan], [None])
    async with db_session() as session:
        # reset_index() turns the "id" index back into a column, so that each
        # row carries all the keyword arguments for Plant
        plants_array = df_details.reset_index().apply(
            lambda item: Plant(**item), axis=1
        )  # type: ignore
        # Replace all the plants: delete the existing ones, add the new ones
        existing_plants = await session.exec(select(Plant))
        for existing_plant in existing_plants.all():
            await session.delete(existing_plant)
        session.add_all(plants_array)
        await session.flush()
        ## Images
        # Same replacement for the image records
        existing_plant_images = await session.exec(select(PlantImage))
        for existing_plant_image in existing_plant_images.all():
            await session.delete(existing_plant_image)
        images_array: list[PlantImage] = []
        for plant_id, plant_images in images.items():
            images_array.extend(
                [PlantImage(plant_id=plant_id, **image) for image in plant_images]
            )
        # IsDefault is compared with the string "0": only that value means False
        for image in images_array:
            image.IsDefault = False if image.IsDefault == "0" else True  # type: ignore
        session.add_all(images_array)
        await session.commit()
async def fetch_detail(id):
    """
    Fetch the details of one plant from plantekey.com.

    Returns a tuple (detail, images): ``detail`` is the "plant" dict merged
    with the "characteristics" dict (when there is one), ``images`` is the
    list of image dicts. Entries whose key is a decimal string are removed
    from the plant dict and from every image dict.
    """
    logger.info(f"Fetching details from plantekey.com for {id}...")
    async with CachedSession(cache=cache) as session, session.get(
        get_plantekey_api_url(id)
    ) as response:
        payload = await response.json(content_type=None)

    def without_numeric_keys(record):
        # Sanitize: keep only the entries with a non-decimal key
        return {key: value for key, value in record.items() if not key.isdecimal()}

    plant = without_numeric_keys(payload["plant"])
    images = [without_numeric_keys(image) for image in payload["images"]]
    characteristics = payload["characteristics"] or {}
    ## Merge dicts, Python 3.9
    return plant | characteristics, images
async def create_thumbnail_archive():
    """
    Create a tar file with all thumbnails of plants

    Only the images flagged as default (IsDefault) whose thumbnail file
    exists are added. The archive is written at get_thumbnail_tar_path(),
    replacing any previous one.
    """
    logger.info("Generating thumbnails and tar file")
    async with db_session() as session:
        result = await session.exec(select(PlantImage))
        with tarfile.open(str(get_thumbnail_tar_path()), "w") as tar:
            for record in result.all():
                # Accept both result shapes: model instances (as
                # update_details() gets from the same kind of query) or
                # one-element rows.
                img: PlantImage = (
                    record if isinstance(record, PlantImage) else record[0]
                )
                path = img.get_thumbnail_path()
                if img.IsDefault and path.is_file():
                    tar.add(path)
    logger.info(
        "Generation of thumbnails and tar file "
        + f"({get_thumbnail_tar_path()}) finished"
    )
@pek_app.get("/updateData")
async def updateData(user: User = Depends(get_current_active_user)):
    """
    Get list and details of all plants from plantekey

    Requires an active user. The browse list is fetched, then the details
    of every plant are fetched and stored in the database.

    Returns:
        ``{"status": 0, "message": "Server updated"}`` on success.

    Raises:
        HTTPException: status 500 when fetching or storing fails.
    """
    try:
        df = await fetch_browse()
        await update_details(df)
    except Exception as error:
        logger.exception(error)
        logger.error(error)
        # Most exceptions have no "code" attribute: reading error.code
        # unconditionally raised AttributeError in this handler and hid the
        # real failure. Keep the code when there is one, else use the error.
        reason = getattr(error, "code", None) or error
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Update failed: {reason}",
        ) from error
    else:
        logger.info("updateData finished")
        return {"status": 0, "message": "Server updated"}
@pek_app.get("/updateImages")
async def updateImages(user: User = Depends(get_current_active_user)):
    """
    Get images from plantekey, using the list of plants
    fetched with updateData, then rebuild the thumbnail archive.

    Requires an active user. A failure of either stage is logged and
    reported as an HTTP 500 error naming the stage.
    """
    stages = (
        (fetch_images, "Update of images failed: "),
        (
            create_thumbnail_archive,
            "Update of images failed while creating thumbnails: ",
        ),
    )
    for run_stage, detail_prefix in stages:
        try:
            await run_stage()
        except Exception as exc:
            logger.exception(exc)
            logger.error(exc)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"{detail_prefix}{exc}",
            )
    return {"status": 0, "message": "Server updated"}
@pek_app.get("/plant/info/{id}")
async def get_plantekey(
    id: str,
    db_session: fastapi_db_session,
) -> Plant | None:  # type: ignore
    """
    Get details of a specific plant

    Responds with HTTP 404 when no plant has the given id.
    """
    found = await db_session.get(Plant, id)
    if found is not None:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
    )
@pek_app.get("/details")
async def get_plantekey_all_details():
    """
    Get all plants, as a JSON array with one object per plant
    """
    details = await get_details()
    return Response(
        content=details.to_json(orient="records"),
        media_type="application/json",
    )
@pek_app.get("/details/csv", response_class=PlainTextResponse)
async def get_plantekey_all_details_csv():
    """
    Get all plants, return CSV
    """
    details = await get_details()
    return Response(content=details.to_csv(), media_type="text/csv")
# Serve the locally stored files (full-size images, thumbnails, type icons)
# as static files. Each directory is passed through mkdir() at import time,
# before being handed to StaticFiles.
pek_app.mount("/img", StaticFiles(directory=mkdir(get_img_root())), name="plantekey_img")
pek_app.mount("/thumb", StaticFiles(directory=mkdir(get_thumbnail_root())), name="plantekey_thumbnail")
pek_app.mount("/type", StaticFiles(directory=mkdir(get_img_type_root())), name="plantekey_type")