489 lines
16 KiB
Python
489 lines
16 KiB
Python
|
import tarfile
|
||
|
from io import BytesIO
|
||
|
from logging import getLogger
|
||
|
from pathlib import Path
|
||
|
|
||
|
import numpy as np
|
||
|
import pandas as pd
|
||
|
from aiohttp_client_cache import FileBackend
|
||
|
from aiohttp_client_cache.session import CachedSession
|
||
|
from fastapi import Depends, FastAPI, HTTPException, Response, status
|
||
|
from fastapi.responses import ORJSONResponse, PlainTextResponse
|
||
|
from fastapi.staticfiles import StaticFiles
|
||
|
from PIL import Image
|
||
|
from sqlalchemy import String
|
||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||
|
from sqlalchemy.ext.mutable import MutableDict
|
||
|
from sqlmodel import Field, Relationship, select
|
||
|
|
||
|
from treetrail.config import get_cache_dir
|
||
|
from treetrail.database import db_session, fastapi_db_session
|
||
|
from treetrail.models import BaseModel, User
|
||
|
from treetrail.security import get_current_active_user
|
||
|
from treetrail.utils import read_sql, mkdir
|
||
|
|
||
|
logger = getLogger(__name__)

# Lifetime of locally cached data, in seconds (one day).
# NOTE(review): in this file it is only referenced from commented-out code
# in get_all, so it currently has no effect.
cache_expiry = 3600 * 24

# Maximum (width, height) of generated thumbnails, passed to PIL's
# Image.thumbnail in update_thumbnail.
thumbnail_size = (200, 200)

# HTTP response cache shared by every request made to plantekey.com,
# stored on disk under <cache dir>/plantekey/http.
# Alternative backend used previously:
# cache = SQLiteBackend(get_cache_dir() / 'plantekey')
cache = FileBackend(get_cache_dir() / "plantekey" / "http")

# Sub-application holding the plantekey endpoints below; responses are
# serialized with ORJSON unless an endpoint returns its own Response.
pek_app = FastAPI(
    default_response_class=ORJSONResponse,
)
|
||
|
|
||
|
|
||
|
def get_plantekey_api_url(id):
    """
    Return the plantekey.com API URL that serves the details of plant *id*.
    """
    api_url = f"https://plantekey.com/api.php?action=plant&name={id}"
    return api_url
|
||
|
|
||
|
|
||
|
def get_storage_root():
    """
    Return the directory holding all locally stored plantekey data.
    """
    cache_root = get_cache_dir()
    return cache_root / "plantekey"
|
||
|
|
||
|
|
||
|
def get_storage(f):
    """
    Return the path of entry *f* inside the plantekey storage directory.
    """
    storage_root = get_storage_root()
    return storage_root / f
|
||
|
|
||
|
|
||
|
def get_img_root() -> Path:
    """
    Return the directory where full-size plant images are stored.
    """
    storage_root = get_storage_root()
    return storage_root / "img"
|
||
|
|
||
|
|
||
|
def get_img_path(img: str):
    """
    Return the local path of the full-size image file named *img*.
    """
    image_dir = get_img_root()
    return image_dir / img
|
||
|
|
||
|
|
||
|
def get_thumbnail_root() -> Path:
    """
    Return the directory where plant image thumbnails are stored.
    """
    storage_root = get_storage_root()
    return storage_root / "thumbnails"
|
||
|
|
||
|
|
||
|
def get_thumbnail_tar_path():
    """
    Return the path of the tar archive bundling the default-image thumbnails.
    """
    storage_root = get_storage_root()
    return storage_root / "thumbnails.tar"
|
||
|
|
||
|
|
||
|
def get_img_type_root():
    """
    Return the directory where plant-type icons are stored.
    """
    storage_root = get_storage_root()
    return storage_root / "type"
|
||
|
|
||
|
|
||
|
def get_img_type_path(type: str):
    """
    Return the local path of the PNG icon for the plant type *type*.
    """
    icons_dir = get_img_type_root()
    return icons_dir / (type + ".png")
|
||
|
|
||
|
|
||
|
def setup(ap):
    """
    Make sure the local storage directories exist (images, thumbnails and
    type icons), creating missing parents as needed.

    Meant to be called once at startup; *ap* is accepted but not used.
    """
    for directory in (get_img_root(), get_thumbnail_root(), get_img_type_root()):
        directory.mkdir(parents=True, exist_ok=True)
|
||
|
|
||
|
|
||
|
class Plant(BaseModel, table=True):
    """
    Record of a Plantekey plant, stored in the ``plant`` table.

    Rows are rebuilt from the plantekey.com API by update_details: the plant
    record and its "characteristics" are merged into one dict per plant
    (see fetch_detail) and passed as keyword arguments to this model.
    """

    __tablename__: str = "plant"  # type: ignore

    def __str__(self):
        return str(self.id)

    def __repr__(self):
        # NOTE(review): the module path in this text does not match where the
        # class is defined; kept unchanged because it is runtime output.
        return f"treetrail.database.Plant: {self.id}"

    # Primary key: slug of the plant name (lowercased, spaces replaced by
    # "-"), as computed in fetch_browse.
    id: str | None = Field(primary_key=True, default=None)
    # Numeric identifier from the plantekey API (cast to int on import).
    ID: int
    family: str = Field(sa_type=String(100))  # type: ignore
    name: str = Field(sa_type=String(200))  # type: ignore
    description: str | None
    habit: str | None
    landscape: str | None
    uses: str | None
    planting: str | None
    propagation: str | None
    # Plant type; also names the icon fetched by fetch_type (<type>.png).
    type: str = Field(sa_type=String(30))  # type: ignore
    img: str = Field(sa_type=String(100))  # type: ignore
    element: str = Field(sa_type=String(30))  # type: ignore
    isOnMap: str | None = Field(sa_type=String(10))  # type: ignore
    # Common names (the english/hindi/tamil columns are also requested by
    # fetch_browse).
    english: str | None = Field(sa_type=String(100))  # type: ignore
    hindi: str | None = Field(sa_type=String(100))  # type: ignore
    tamil: str | None = Field(sa_type=String(100))  # type: ignore
    spiritual: str | None = Field(sa_type=String(150))  # type: ignore
    # Converted from the API's "Yes"/"No" strings in update_details.
    woody: bool
    latex: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_style: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_type: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_arrangement: str | None = Field(sa_type=String(20))  # type: ignore
    # Converted from "Yes"/"No" strings in update_details.
    leaf_aroma: bool | None
    # Numeric columns parsed with pd.to_numeric in update_details
    # (units not visible here).
    leaf_length: float | None
    leaf_width: float | None
    flower_color: str | None = Field(sa_type=String(20))  # type: ignore
    flower_size: float | None
    # Converted from "Yes"/"No" strings in update_details.
    flower_aroma: bool | None
    fruit_color: str | None = Field(sa_type=String(20))  # type: ignore
    fruit_size: float | None
    fruit_type: str | None = Field(sa_type=String(20))  # type: ignore
    thorny: str | None = Field(sa_type=String(20))  # type: ignore
    # Free-form JSONB dict; MutableDict tracks in-place changes so they are
    # persisted. NOTE(review): how it is populated is not visible here
    # (image records live in the separate plant_image table) — confirm.
    images: dict = Field(
        sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
        default_factory=dict,
    )  # type: ignore

    # Migration hints kept from when the column was added.
    # NOTE(review): they mention hstore and a ``tree`` table although this
    # column is JSONB on ``plant`` — confirm they are still relevant.
    # CREATE EXTENSION hstore;
    # ALTER TABLE tree ADD COLUMN data JSONB;
    # Extra free-form JSONB data attached to the plant.
    data: dict = Field(
        sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
        default_factory=dict,
    )  # type: ignore

    # __mapper_args__ = {"eager_defaults": True}
|
||
|
|
||
|
|
||
|
class PlantImage(BaseModel, table=True):
    """
    Image of a plant, built from the "images" list of the plantekey API
    response (see fetch_detail and update_details).
    """

    __tablename__: str = "plant_image"  # type: ignore

    # Primary key (defaults to None).
    id: int | None = Field(primary_key=True, default=None)
    # Id of the Plant this image belongs to.
    plant_id: str = Field(foreign_key="plant.id")
    plant: Plant = Relationship()
    caption: str = Field(sa_type=String(50))  # type: ignore
    # update_details maps the API value "0" to False and anything else to True.
    # Only default images are put in the thumbnail archive.
    IsDefault: bool
    # Image file name: appended to plantekey's /admin/images/plants/ URL and
    # used as the local file name of the image and of its thumbnail.
    src: str = Field(sa_type=String(100))  # type: ignore

    def get_thumbnail_path(self) -> Path:
        """Return the local path of this image's thumbnail file."""
        return get_thumbnail_root() / self.src  # type: ignore
|
||
|
|
||
|
|
||
|
class Plantekey(BaseModel, table=True):
    """
    Details for the plantekey data, like symbols for the map.

    get_local_details reads (id, symbol) from this table, and get_details
    merges the symbol into the plant list by plant id.
    """

    # Manual setup SQL kept for reference.
    # NOTE(review): the CREATE TABLE below has no ``iso`` column although the
    # model declares one — confirm the schema.
    ## CREATE TABLE plantekey (id VARCHAR(100) PRIMARY KEY NOT NULL, symbol CHAR(1));
    ## GRANT ALL on TABLE plantekey to treetrail ;
    __tablename__: str = "plantekey"  # type: ignore

    # Plant id, matched against Plant.id.
    id: str | None = Field(primary_key=True, default=None)
    # Single-character map symbol.
    symbol: str = Field(sa_type=String(1))  # type: ignore
    # NOTE(review): meaning of ``iso`` is not visible in this file.
    iso: str = Field(sa_type=String(100))  # type: ignore
|
||
|
|
||
|
|
||
|
async def fetch_browse():
    """
    Download the list of plants ("browse" action) from plantekey.com.

    Returns a DataFrame with the browse columns plus an ``id`` column: the
    plant name lowercased, with spaces replaced by "-". If the response
    cannot be decoded as JSON, the error is logged and an empty DataFrame
    (same columns) is returned.
    """
    logger.info("Fetching list of plants (browse) from plantekey.com...")
    plantekey_url = "https://www.plantekey.com/api.php?action=browse"
    wanted_columns = [
        "ID",
        "english",
        "family",
        "hindi",
        "img",
        "name",
        "spiritual",
        "tamil",
        "type",
    ]
    async with CachedSession(cache=cache) as http:
        async with http.get(plantekey_url) as resp:
            try:
                payload = await resp.json(content_type=None)
            except Exception as err:
                logger.warning("Error browsing plantekey")
                logger.exception(err)
                payload = {}
    plants = pd.DataFrame(data=payload, columns=wanted_columns)
    slugs = plants["name"].str.replace(" ", "-").str.lower()
    plants["id"] = slugs
    plants["ID"] = plants["ID"].astype(int)
    return plants
|
||
|
|
||
|
|
||
|
async def get_all():
    """
    Return the list of plants stored in the local database, as a DataFrame
    with one column per ``plant`` table column.

    The table is read through ``read_sql`` on the session's connection, the
    same way get_details does. Building the frame with
    ``pd.DataFrame(list_of_models)`` does not work: pandas treats each SQLModel
    instance as an iterable of (name, value) pairs. The result then has no
    named columns, and callers such as fetch_types (``df.type``) fail.

    TODO: implement a local cache mechanism (``cache_expiry``).
    """
    async with db_session() as session:
        con = await session.connection()
        df = await con.run_sync(read_sql, select(Plant))
    return df
|
||
|
|
||
|
|
||
|
async def get_local_details() -> pd.DataFrame:
    """
    Read the (id, symbol) pairs of the plantekey table, which holds local
    extras such as map symbols.

    The result is indexed by ``id``. When the table is empty, the (empty,
    column-less) frame is returned as is.
    """
    async with db_session() as session:
        rows = await session.exec(select(Plantekey.id, Plantekey.symbol))
        symbols = pd.DataFrame(rows.all())
    if len(symbols) == 0:
        return symbols
    return symbols.set_index("id")
|
||
|
|
||
|
|
||
|
async def get_details():
    """
    Return the details of plants, as stored on the local server.

    All rows of the ``plant`` table are returned. When the plantekey table
    has entries, their ``symbol`` column is merged in by plant id (left
    join). Plants without a symbol then get the default symbol U+E034.
    """
    local_details = await get_local_details()
    async with db_session() as session:
        con = await session.connection()
        df = await con.run_sync(read_sql, select(Plant))
    if len(local_details) == 0:
        return df
    ndf = df.merge(local_details, left_on="id", right_index=True, how="left")
    # Assign the result back: ``ndf["symbol"].fillna(..., inplace=True)`` is
    # chained in-place assignment. It is deprecated and has no effect under
    # pandas Copy-on-Write.
    ndf["symbol"] = ndf["symbol"].fillna("\ue034")
    return ndf
|
||
|
|
||
|
|
||
|
async def fetch_image(img: PlantImage):
    """
    Download one plant image from plantekey.com, save it locally and build
    its thumbnail.

    Best effort: a failed download or a response that does not look like an
    image is logged as a warning and the image is skipped (returns None).
    """
    url = f"https://www.plantekey.com/admin/images/plants/{img.src}"
    logger.info(f"Fetching image at {url}")
    async with CachedSession(cache=cache) as session:
        try:
            # ``async with`` releases the response (and its connection) once
            # the body has been read.
            async with session.get(url) as resp:
                img_data = await resp.content.read()
        except Exception as err:
            logger.warning(f"Cannot get image for {img.plant_id} ({url}): {err}")
            return
    # Only data whose first byte is 0xFF (the first byte of a JPEG file) is
    # accepted; an empty body is rejected too, instead of raising IndexError.
    if not img_data or img_data[0] != 255:
        logger.warning(f"Image for {img.plant_id} at {url} is not an image")
        return
    with open(get_img_path(img.src), "bw") as f:  # type: ignore
        f.write(img_data)
    update_thumbnail(BytesIO(img_data), img)
|
||
|
|
||
|
|
||
|
async def fetch_images():
    """
    Fetch all the images from the plantekey server.

    Every PlantImage recorded in the database is downloaded in turn with
    fetch_image, which also regenerates its thumbnail.
    """
    logger.info("Fetching all images from plantekey server")
    async with db_session() as session:
        images = await session.exec(select(PlantImage))
        for img_rec in images:
            # SQLModel's ``session.exec(select(Model))`` yields model instances,
            # which are not subscriptable. A plain SQLAlchemy result yields
            # 1-tuples instead. Accept both.
            img = img_rec if isinstance(img_rec, PlantImage) else img_rec[0]
            await fetch_image(img)
|
||
|
|
||
|
|
||
|
def update_thumbnail(data, img: PlantImage):
    """
    Write a thumbnail of the image read from *data* (a path or binary
    file-like object accepted by PIL) to ``img.get_thumbnail_path()``.

    The thumbnail fits within ``thumbnail_size`` and keeps the aspect ratio.
    Best effort: any failure is logged as a warning and swallowed.
    """
    try:
        # The context manager closes the PIL image deterministically instead
        # of leaving it to the garbage collector.
        with Image.open(data) as tn_image:
            tn_image.thumbnail(thumbnail_size)
            tn_image.save(img.get_thumbnail_path())
    except Exception as error:
        logger.warning(
            f"Cannot create thumbnail for {img.plant_id} ({img.src}): {error}"
        )
|
||
|
|
||
|
|
||
|
async def fetch_types():
    """
    Download the icon of every distinct plant type found in the local plant
    list, one type after the other.
    """
    plants = await get_all()
    distinct_types = plants.type.unique()
    for plant_type in distinct_types:
        await fetch_type(plant_type)
|
||
|
|
||
|
|
||
|
async def fetch_type(type):
    """
    Download the PNG icon of plant type *type* from plantekey.com and store
    it at ``get_img_type_path(type)``.

    The response body is written as is; its status is not checked.
    """
    icon_url = f"https://plantekey.com/img/icons/plant_types/{type}.png"
    async with CachedSession(cache=cache) as http:
        async with http.get(icon_url) as resp:
            icon_bytes = await resp.content.read()
            with open(get_img_type_path(type), "bw") as out:
                out.write(icon_bytes)
|
||
|
|
||
|
|
||
|
def _is_missing(value) -> bool:
    """Return True for None or a float NaN (pandas' marker for absent values)."""
    return value is None or (isinstance(value, float) and value != value)


def _yes_no_to_bool(value, nullable: bool):
    """Convert a plantekey "Yes"/"No" value to a boolean.

    Missing values become None for nullable columns. Otherwise "No" gives
    False and any other value its truthiness, as before.
    """
    if nullable and _is_missing(value):
        return None
    return False if value == "No" else bool(value)


def _to_nullable_float(series: pd.Series) -> pd.Series:
    """Parse *series* as float64 numbers; unparsable or missing values become None."""
    numbers = pd.to_numeric(series, errors="coerce")
    # Object dtype lets None (SQL NULL) live next to Python floats.
    return numbers.astype(object).where(numbers.notna(), None)


async def _fetch_all_details(ids) -> tuple[dict, dict]:
    """Fetch details and images for each plant id; failures are logged and skipped."""
    details: dict = {}
    images: dict = {}
    for plant_id in ids:
        try:
            details[plant_id], images[plant_id] = await fetch_detail(plant_id)
        except Exception as err:
            logger.warning(f"Error fetching details for {plant_id}: {err}")
    return details, images


def _clean_details(details: dict) -> pd.DataFrame:
    """Build the plant DataFrame (indexed by id) with values matching the DB column types."""
    df_details = pd.DataFrame.from_dict(details, orient="index")
    df_details.index.name = "id"
    df_details["ID"] = df_details["ID"].astype(int)
    # Keep full float64 precision (a float32 downcast would store rounded
    # values) and store missing numbers as NULL rather than NaN.
    for float_col_name in ("leaf_width", "leaf_length", "flower_size", "fruit_size"):
        df_details[float_col_name] = _to_nullable_float(df_details[float_col_name])
    # woody is declared non-nullable; the aroma columns accept None.
    for bool_col_name, nullable in (
        ("woody", False),
        ("leaf_aroma", True),
        ("flower_aroma", True),
    ):
        df_details[bool_col_name] = df_details[bool_col_name].map(
            lambda value, nullable=nullable: _yes_no_to_bool(value, nullable)
        )
    # TODO: migrate __table__ to SQLModel, use model_fields
    for str_col_name in [
        c.name
        for c in Plant.__table__.columns  # type: ignore
        if isinstance(c.type, String) and not c.primary_key
    ]:  # type: ignore
        df_details[str_col_name] = df_details[str_col_name].replace([np.nan], [None])
    return df_details


def _build_plant_images(images: dict) -> list:
    """Create the PlantImage records from the per-plant image dicts."""
    images_array = []
    for plant_id, plant_images in images.items():
        images_array.extend(
            [PlantImage(plant_id=plant_id, **image) for image in plant_images]
        )
    for image in images_array:
        image.IsDefault = False if image.IsDefault == "0" else True  # type: ignore
    return images_array


async def update_details(df: pd.DataFrame):
    """
    Update the server database from plantekey details.

    Details are fetched for every plant id in ``df["id"]``. All existing
    plants and plant images are then replaced by the fetched ones in a
    single transaction.

    Raises:
        ValueError: if no plant details at all could be fetched (the
            database is left untouched).
    """
    details, images = await _fetch_all_details(df["id"])
    if not details:
        raise ValueError("No plant details could be fetched from plantekey")
    df_details = _clean_details(details)
    plants_array = df_details.reset_index().apply(
        lambda item: Plant(**item), axis=1
    )  # type: ignore
    images_array = _build_plant_images(images)

    async with db_session() as session:
        # Delete the images first: plant_image.plant_id references plant.id,
        # so deleting a plant that still has images would violate the
        # foreign key.
        existing_plant_images = await session.exec(select(PlantImage))
        for existing_plant_image in existing_plant_images.all():
            await session.delete(existing_plant_image)
        await session.flush()
        existing_plants = await session.exec(select(Plant))
        for existing_plant in existing_plants.all():
            await session.delete(existing_plant)
        session.add_all(plants_array)
        # The new plants must exist before images referencing them are inserted.
        await session.flush()
        session.add_all(images_array)
        await session.commit()
|
||
|
|
||
|
|
||
|
async def fetch_detail(id):
    """
    Fetch the details of one plant from the plantekey API.

    Returns a ``(detail, images)`` tuple:
      - ``detail``: the "plant" record merged with its "characteristics"
        (characteristics win on key clashes);
      - ``images``: the list of image dicts (empty if the API gives none).
    Keys that are decimal strings are dropped from both. Presumably the API
    returns each record with positional and named keys — TODO confirm.

    Raises whatever the HTTP/JSON layer raises, or KeyError when the
    response has no "plant" record.
    """
    logger.info(f"Fetching details from plantekey.com for {id}...")
    async with CachedSession(cache=cache) as session:
        async with session.get(get_plantekey_api_url(id)) as response:
            result = await response.json(content_type=None)
    ## Sanitize
    plant = {k: v for k, v in result["plant"].items() if not k.isdecimal()}
    # A null or missing "images" or "characteristics" means "none"; failing
    # here would drop the whole plant from the update.
    images = [
        {k: v for k, v in image.items() if not k.isdecimal()}
        for image in (result.get("images") or [])
    ]
    detail = plant | (result.get("characteristics") or {})
    return detail, images
|
||
|
|
||
|
|
||
|
async def create_thumbnail_archive():
    """
    Create a tar file with all thumbnails of plants.

    Only thumbnails of default images (``IsDefault``) that exist on disk are
    added. They are stored under their filesystem path as member name (as
    chosen by ``tarfile.add``). The archive is written to
    ``get_thumbnail_tar_path()``, replacing any previous one.
    """
    logger.info("Generating thumbnails and tar file")
    async with db_session() as session:
        images = await session.exec(select(PlantImage))
        with tarfile.open(str(get_thumbnail_tar_path()), "w") as tar:
            for img_rec in images:
                # SQLModel's ``session.exec(select(Model))`` yields model
                # instances (not subscriptable); SQLAlchemy rows are 1-tuples.
                img: PlantImage = (
                    img_rec if isinstance(img_rec, PlantImage) else img_rec[0]
                )
                path = img.get_thumbnail_path()
                if img.IsDefault and path.is_file():
                    tar.add(path)
    logger.info(
        "Generation of thumbnails and tar file "
        + f"({get_thumbnail_tar_path()}) finished"
    )
|
||
|
|
||
|
|
||
|
@pek_app.get("/updateData")
async def updateData(user: User = Depends(get_current_active_user)):
    """
    Get list and details of all plants from plantekey.

    Requires an active user. Any failure is logged and reported as HTTP 500.
    """
    try:
        df = await fetch_browse()
        await update_details(df)
    except Exception as error:
        logger.exception(error)
        logger.error(error)
        # Format the exception itself: most exceptions have no ``code``
        # attribute, and reading it raised AttributeError, hiding the error.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Update failed: {error}",
        ) from error
    else:
        logger.info("updateData finished")
        return {"status": 0, "message": "Server updated"}
|
||
|
|
||
|
|
||
|
@pek_app.get("/updateImages")
async def updateImages(user: User = Depends(get_current_active_user)):
    """
    Download the plant images from plantekey (for the plants fetched with
    updateData), then rebuild the thumbnail archive.

    Requires an active user. A failure in either step is logged and
    reported as HTTP 500, with a message telling which step failed.
    """
    try:
        await fetch_images()
    except Exception as fetch_error:
        logger.exception(fetch_error)
        logger.error(fetch_error)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Update of images failed: {fetch_error}",
        )

    try:
        await create_thumbnail_archive()
    except Exception as archive_error:
        logger.exception(archive_error)
        logger.error(archive_error)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Update of images failed while creating thumbnails: {archive_error}",
        )

    return {"status": 0, "message": "Server updated"}
|
||
|
|
||
|
|
||
|
@pek_app.get("/plant/info/{id}")
async def get_plantekey(
    id: str,
    db_session: fastapi_db_session,
) -> Plant | None:  # type: ignore
    """
    Get details of a specific plant, or answer 404 when the id is unknown.
    """
    found = await db_session.get(Plant, id)
    if found is not None:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
|
||
|
|
||
|
|
||
|
@pek_app.get("/details")
async def get_plantekey_all_details():
    """
    Get all plants, serialized by pandas as a JSON array of records.
    """
    details = await get_details()
    return Response(
        content=details.to_json(orient="records"),
        media_type="application/json",
    )
|
||
|
|
||
|
|
||
|
@pek_app.get("/details/csv", response_class=PlainTextResponse)
async def get_plantekey_all_details_csv():
    """
    Get all plants as CSV text (the DataFrame index is included).
    """
    details = await get_details()
    return Response(content=details.to_csv(), media_type="text/csv")
|
||
|
|
||
|
|
||
|
# Serve the locally stored files as static content: full-size images under
# /img, thumbnails under /thumb and plant-type icons under /type.
# Each directory goes through treetrail.utils.mkdir first — presumably it
# creates the directory and returns its path, so StaticFiles finds an
# existing directory at import time (TODO confirm mkdir's contract).
pek_app.mount("/img", StaticFiles(directory=mkdir(get_img_root())), name="plantekey_img")
pek_app.mount("/thumb", StaticFiles(directory=mkdir(get_thumbnail_root())), name="plantekey_thumbnail")
pek_app.mount("/type", StaticFiles(directory=mkdir(get_img_type_root())), name="plantekey_type")
|