Initial commit

This commit is contained in:
phil 2024-10-23 16:19:51 +02:00
commit f4cf78603a
25 changed files with 2895 additions and 0 deletions

View file

@ -0,0 +1,15 @@
on: [push]
jobs:
  install:
    runs-on: container
    container:
      image: tiptop:5000/treetrail-backend-ci-base
    services:
      treetrail-database:
        image: treetrail-database
    steps:
      - uses: actions/checkout@v4
      - name: Install dependencies
        # "uv install" is not a uv subcommand: "uv sync" installs the
        # project's dependencies into .venv (from the lockfile when present).
        run: uv sync
      - name: Run basic test (bootstrap)
        run: .venv/bin/pytest -s tests/basic.py

13
.gitignore vendored Normal file
View file

@ -0,0 +1,13 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info
# Virtual environments
.venv
# Custom
.python-version

12
Containerfile Normal file
View file

@ -0,0 +1,12 @@
# Runtime image: copies the pre-built virtualenv and data tree from the
# dependency build image (treetrail_backend_deps) instead of installing here.
FROM localhost/trixie_python
WORKDIR /usr/src/treetrail
# Put the venv first on PATH so the "uvicorn" in CMD resolves to the venv's binary
ENV PATH="/usr/src/treetrail/.venv/bin:$PATH"
ENV PYTHONPATH="/usr/src"
COPY --from=localhost/treetrail_backend_deps /usr/src/treetrail/.venv/ /usr/src/treetrail/.venv
COPY --from=localhost/treetrail_backend_deps /usr/local/treetrail/ /usr/local/treetrail
COPY ./treetrail ./pyproject.toml ./README.md .
# Instances should override the prod.yaml file
COPY ./prod.yaml /etc/treetrail/prod.yaml
CMD ["uvicorn", "treetrail.application:app", "--port", "8081", "--log-config", "logging.yaml", "--host", "0.0.0.0"]

View file

@ -0,0 +1,11 @@
# Base Python image (with nodejs/git) for the build images.
FROM debian:trixie-slim
# MAINTAINER is deprecated since Docker 1.13; use a LABEL instead
LABEL maintainer="philo email phil.dev@philome.mooo.com"
# Update the index and install in the same layer so the index is never stale
# relative to the install; clean the lists in the same layer to keep it small.
RUN apt update && \
    apt install --no-install-recommends -y python-is-python3 python3-pip python3-venv nodejs git && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*
RUN pip install --break-system-packages pdm
RUN rm -rf /root/.cache

View file

@ -0,0 +1,36 @@
# Dependency build image: installs the locked production dependencies with pdm
# into /usr/src/treetrail/.venv and pre-creates the runtime data directories.
FROM localhost/trixie_python
# MAINTAINER is deprecated since Docker 1.13; use a LABEL instead
LABEL maintainer="philo email phil.dev@philome.mooo.com"
#ENV PROJ_DIR=/usr
# "ENV key value" is legacy syntax; "ENV key=value" is the supported form
ENV PYTHONDONTWRITEBYTECODE=1
ENV PDM_CHECK_UPDATE=false
#RUN apk add --no-cache make cmake clang gdal-dev geos-dev proj-dev proj-util gcc musl-dev bash
#RUN apk add --no-cache gdal-dev geos-dev proj-dev proj-util gcc musl-dev bash
WORKDIR /usr/src/treetrail
COPY ./pyproject.toml ./README.md ./pdm.lock .
# Cheating pdm with the app version to allow install of dependencies
RUN PDM_BUILD_SCM_VERSION=1.0 pdm install --check --prod --no-editable
## Instances should populate these dirs below
RUN mkdir -p /usr/local/treetrail/osm \
    /usr/local/treetrail/sprite \
    /usr/local/treetrail/cache/plantekey/img \
    /usr/local/treetrail/cache/plantekey/thumbnails \
    /usr/local/treetrail/cache/plantekey/type \
    /usr/local/treetrail/map/sprite \
    /usr/local/treetrail/map/osm \
    /usr/local/treetrail/attachments/tree \
    /usr/local/treetrail/attachments/trail \
    /usr/local/treetrail/attachments/poi
#COPY ./sprite /usr/local/treetrail
#COPY ./osm /usr/local/treetrail
#RUN python -c 'import _version as v;print(v.__version__)' > version.txt
#RUN PDM_BUILD_SCM_VERSION=$(cat version.txt) pdm install --check --prod --no-editable
#
# Clear some space (caches)
#RUN pdm cache clear
#RUN rm -rf .mypy_cache
#RUN rm -rf __pycache__

View file

@ -0,0 +1,11 @@
# Minimal base Python image (no nodejs/git) for the runtime image.
FROM debian:trixie-slim
# MAINTAINER is deprecated since Docker 1.13; use a LABEL instead
LABEL maintainer="philo email phil.dev@philome.mooo.com"
# Update the index and install in the same layer so the index is never stale
# relative to the install; clean the lists in the same layer to keep it small.
RUN apt update && \
    apt install --no-install-recommends -y python-is-python3 python3-pip python3-venv && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*
RUN pip install --break-system-packages pdm
RUN rm -rf /root/.cache

3
README.md Normal file
View file

@ -0,0 +1,3 @@
*Tree Trail* is a fun and pedagogic tool to discover the trails and trees around you.
This is the server (back-end), written in Python.

79
pyproject.toml Normal file
View file

@ -0,0 +1,79 @@
[project]
name = "treetrail-srv"
# The version is dynamic (see [tool.hatch.version] below). PEP 621 forbids
# declaring "version" statically while it is also listed in "dynamic", so the
# static `version = "0.1.0"` line was removed.
dynamic = ["version"]
description = "A fun and pedagogic tool to discover the trails and trees around"
authors = [
    { name = "Philippe May", email = "phil.treetrail@philome.mooo.com" }
]
dependencies = [
    "aiofiles",
    "aiohttp-client-cache",
    "aiosqlite",
    "asyncpg",
    "fastapi",
    "geoalchemy2",
    "geopandas",
    "httptools>=0.6.1",
    "orjson",
    "pandas",
    "passlib[bcrypt]",
    "pillow",
    "psycopg2-binary",
    "pyarrow",
    "pydantic-settings",
    "python-jose[cryptography]",
    "python-multipart",
    "requests",
    "sqlalchemy[asyncio]",
    "sqlmodel",
    "uvicorn[standard]",
    "uvloop",
]
requires-python = ">=3.11"
readme = "README.md"
license = {text = "MIT"}
classifiers = [
    "Development Status :: 3 - Alpha",
    "Framework :: FastAPI",
    "Environment :: Web Environment",
    "Intended Audience :: Developers",
    # Aligned with the "license" field above (the previous GPL classifier
    # contradicted license = MIT)
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3",
    "Operating System :: MacOS :: MacOS X",
    "Operating System :: POSIX",
    "Programming Language :: Python",
]
#[project.scripts]
#treetrail-srv = "treetrail_srv:main"
#[tool.pdm.build]
#includes = ["src/"]
#
#[tool.pdm.version]
#source = "scm"
#write_to = "treetrail/_version.py"
#write_template = "__version__ = '{}'"
#
#[tool.pdm.dev-dependencies]
#dev = [
#    "ipdb",
#    "pandas-stubs",
#    "types-Pillow",
#    "types-PyYAML",
#    "types-aiofiles",
#    "types-passlib",
#    "types-python-jose",
#    "types-requests",
#]
#test = [
#    "pytest>=8.3.3",
#    "httpx>=0.27.2",
#]

# Hatchling needs a version source when "version" is dynamic; _version.py
# carries the SCM-generated version string (see src/treetrail/_version.py).
[tool.hatch.version]
path = "src/treetrail/_version.py"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

View file

@ -0,0 +1 @@
__version__ = '2024.4.dev3+g0527f08.d20241021'

487
src/treetrail/api_v1.py Normal file
View file

@ -0,0 +1,487 @@
import logging
import mimetypes
from pathlib import Path
from datetime import timedelta
import tarfile
from typing import Optional
from base64 import standard_b64decode
import re
from typing import Tuple
from json import loads
from uuid import UUID
from fastapi import (FastAPI, Response, HTTPException,
File, UploadFile, Request, Form, responses,
Depends, status)
from fastapi.staticfiles import StaticFiles
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import select
from sqlalchemy import or_
import geopandas as gpd # type: ignore
import pandas as pd
import aiofiles
import aiofiles.os
from PIL import Image
from treetrail.utils import (get_attachment_poi_root, get_attachment_root,
get_attachment_trail_root, get_attachment_tree_root, mkdir)
from treetrail.security import (
Token,
authenticate_user, create_access_token,
get_current_active_user, get_current_user, get_current_roles,
)
from treetrail.database import fastapi_db_session as db_session
from treetrail.models import (BaseMapStyles, User, Role, Bootstrap,
MapStyle, Tree, Trail,
TreeTrail, POI, UserWithRoles, Zone,
VersionedComponent)
from treetrail.config import conf, get_cache_dir, __version__
from treetrail.plantekey import get_local_details
from treetrail.tiles import registry as tilesRegistry
logger = logging.getLogger(__name__)

# The v1 API sub-application, mounted under {base_href}/v1 by application.py
api_app = FastAPI(
    debug=False,
    title=conf.app.title,
    version=conf.version,
    # lifespan=lifespan,
    default_response_class=responses.ORJSONResponse,
)

# Matches the mimetype of a data URL, eg. "data:image/jpeg;base64,....".
# Raw string: in the original non-raw literal "\S" was an invalid string
# escape (a SyntaxWarning on modern Python).
re_findmimetype = re.compile(r'^data:(\S+);')

# Feature models that can receive attachments, keyed by URL path segment
attachment_types: dict[str, type[Tree] | type[Trail] | type[POI]] = {
    'tree': Tree,
    'trail': Trail,
    'poi': POI
}
# Fields whose uploads are stored as thumbnails rather than originals
attachment_thumbnailable_fields = {
    'photo'
}
# Maximum thumbnail dimensions (PIL Image.thumbnail preserves aspect ratio)
thumbnail_size = (200, 200)
@api_app.get('/bootstrap')
async def get_bootstrap(
        user: UserWithRoles = Depends(get_current_user)
        ) -> Bootstrap:
    """Return the client's initial state: server/client versions, app and map
    configuration, the current user, and the available base map styles."""
    # XXX: hide password - issue with SQLModel
    base_map_styles = BaseMapStyles(
        embedded=list(tilesRegistry.mbtiles.keys()),
        external=conf.mapStyles,
    )
    return Bootstrap(
        server=VersionedComponent(version=__version__),
        client=VersionedComponent(version=__version__),
        app=conf.app,
        user=user,
        map=conf.map,
        baseMapStyles=base_map_styles,
    )
@api_app.post("/token", response_model=Token)
async def login_for_access_token(
        form_data: OAuth2PasswordRequestForm = Depends()
        ):
    """Issue a JWT bearer token for valid username/password credentials."""
    user = await authenticate_user(form_data.username, form_data.password)
    if not user:
        # Do not reveal whether the username or the password was wrong
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    expires = timedelta(minutes=conf.security.access_token_expire_minutes)
    token = create_access_token(data={"sub": user.username},
                                expires_delta=expires)
    return {"access_token": token, "token_type": "bearer"}
@api_app.post("/upload/{type}/{field}/{id}")
async def upload(request: Request, type: str, field: str, id: str,
                 db_session: db_session,
                 file: UploadFile = File(...),
                 user: User = Depends(get_current_active_user)
                 ):
    """Attach an uploaded file to a feature (tree, trail or poi).

    Saves the file under the feature's attachment directory and records the
    file name in the feature's ``field`` column.  Fields listed in
    ``attachment_thumbnailable_fields`` are stored as thumbnails, not
    originals.  Raises 417 for unknown type, field or record id.

    NOTE(review): the literal "(unknown)" in the log messages below looks
    like a redacted placeholder — confirm against upstream history.
    """
    if type not in attachment_types:
        raise HTTPException(status_code=status.HTTP_417_EXPECTATION_FAILED,
                            detail=f"No such type: {type}")
    model = attachment_types[type]
    if field not in model.model_fields:
        raise HTTPException(status_code=status.HTTP_417_EXPECTATION_FAILED,
                            detail=f"No such field for {type}: {field}")
    base_dir = get_attachment_root(type) / id
    if not base_dir.is_dir():
        await aiofiles.os.mkdir(base_dir)
    filename = base_dir / file.filename  # type: ignore
    if field in attachment_thumbnailable_fields:
        try:
            # TODO: async save
            # Thumbnail replaces the original: only the reduced image is kept
            image = Image.open(file.file)
            image.thumbnail(thumbnail_size)
            image.save(filename)
            logger.info(f'Saved thumbnail (unknown)')
        except Exception as error:
            # Best-effort: a failed thumbnail does not abort the request,
            # but then nothing is written to disk for this upload
            logger.warning('Cannot create thumbnail for ' +
                           f'{type} {field} {id} ((unknown)): {error}')
    else:
        # Non-image field: store the raw bytes as-is
        async with aiofiles.open(filename, 'wb') as f:
            await f.write(file.file.read())
        logger.info(f'Saved file (unknown)')
    # Record the bare file name on the feature record
    rec = await db_session.get(model, int(id))
    if rec is None:
        raise HTTPException(status_code=status.HTTP_417_EXPECTATION_FAILED,
                            detail=f'No such {type} id {id}')
    setattr(rec, field, file.filename)
    await db_session.commit()
    return {
        "message": "Successfully uploaded",
        "filename": file.filename,
    }
@api_app.get("/makeAttachmentsTarFile")
async def makeAttachmentsTarFile(
        db_session: db_session,
        current_user: User = Depends(get_current_active_user)
        ):
    """
    Create a tar file with all photos, used to feed clients' caches
    for offline use
    """
    logger.info('Generating thumbnails and tar file')
    tarfile_path = get_cache_dir() / 'attachments.tar'
    # NOTE(review): tarfile has no async API — this blocks the event loop
    # while the archive is being written
    with tarfile.open(str(tarfile_path), 'w') as tar:
        for type, model in attachment_types.items():
            data = await db_session.exec(select(model.id, model.photo))
            # recs: list[Tree | Trail | POI]
            recs = data.all()
            for rec in recs:
                photo: str = rec.photo  # type: ignore
                id: str = rec.id  # type: ignore
                if photo:
                    # Only archive photos whose file actually exists on disk
                    file = get_attachment_root(type) / str(id) / photo
                    if file.is_file():
                        tar.add(file)
    logger.info(f'Generation of thumbnails and tar file ({tarfile_path}) finished')
    return {
        "message": "Successfully made attachments tar file",
    }
@api_app.get("/logout")
def logout(response: Response):
    """Log the user out by deleting the ``token`` cookie."""
    response.delete_cookie(key='token')
    # Returning the injected Response sends an empty body carrying the
    # cookie-deletion header
    return response
@api_app.get('/trail')
async def get_trails(
        roles: list[Role] = Depends(get_current_roles),
        ):
    """
    Get all trails visible to the current user's roles, as GeoJSON.
    Trails with no viewable_role_id are public.
    """
    gdf = await Trail.get_gdf(
        where=or_(Trail.viewable_role_id.in_([role.name for role in roles]),  # type: ignore
                  Trail.viewable_role_id == None))  # type: ignore # noqa: E711
    if len(gdf) == 0:
        # An empty GeoDataFrame needs an explicit (empty) geometry column
        # for to_json() to work
        gdf.set_geometry([], inplace=True)
    # Get only file name of the photo URL
    else:
        photos_path_df = gdf['photo'].str.rpartition('/')  # type: ignore
        # rpartition yields 3 columns only when at least one value was split
        if 2 in photos_path_df.columns:
            gdf['photo'] = photos_path_df[2]
        # Timestamps are not JSON-serializable as-is; stringify them
        gdf['create_date'] = gdf['create_date'].astype(str)  # type: ignore
    return Response(content=gdf.to_json(),
                    media_type="application/json")  # type: ignore
@api_app.get('/trail/details')
async def get_trail_all_details(
        db_session: db_session,
        ):
    """
    Get details (id, name, description, photo file name) of all trails.
    """
    trails = await db_session.exec(select(
        Trail.id,
        Trail.name,
        Trail.description,
        Trail.photo,
    ))
    df = pd.DataFrame(trails.all())
    # With no trails the DataFrame has no columns at all, so df['photo']
    # would raise KeyError: only strip the photo paths when the column exists
    if 'photo' in df.columns:
        # Get only file name of the photo URL
        photos_path_df = df['photo'].str.rpartition('/')
        if 2 in photos_path_df.columns:
            df['photo'] = photos_path_df[2]
    return Response(content=df.to_json(orient='records'),
                    media_type="application/json")
@api_app.get('/tree-trail')
async def get_tree_trail(
        db_session: db_session,
        ) -> list[TreeTrail]:
    """Return every tree/trail association.

    These records are not permission-filtered, as the relation itself
    carries no really valuable information.
    """
    result = await db_session.exec(select(TreeTrail))
    return result.all()  # type: ignore
@api_app.get('/tree')
async def get_trees(
        roles: list[Role] = Depends(get_current_roles),
        ):
    """
    Get all trees visible to the current user's roles, as GeoJSON.
    Plantekey details are merged in when available, and each tree gets a map
    symbol; the photo URL is reduced to the bare file name.
    """
    gdf = await Tree.get_gdf(
        where=or_(Tree.viewable_role_id.in_([role.name for role in roles]),  # type: ignore
                  Tree.viewable_role_id == None))  # type: ignore # noqa: E711
    if len(gdf) > 0:
        gdf['plantekey_id'] = gdf['plantekey_id'].fillna('')
        tree_trail_details = await get_local_details()
        if len(tree_trail_details) > 0:
            gdf = gdf.merge(tree_trail_details, left_on='plantekey_id',
                            right_index=True, how='left')
            # Plain assignment instead of fillna(..., inplace=True) on the
            # column: chained inplace fillna is deprecated in pandas 2.x and
            # does not reliably write back under copy-on-write
            gdf['symbol'] = gdf['symbol'].fillna('\uE034')
        else:
            gdf['symbol'] = '\uE034'
    else:
        # An empty GeoDataFrame needs an explicit (empty) geometry column
        gdf.set_geometry([], inplace=True)
    # Get only file name of the photo URL
    if len(gdf) > 0:
        photos_path_df = gdf['photo'].str.rpartition('/')  # type: ignore
        if 2 in photos_path_df.columns:
            gdf['photo'] = photos_path_df[2]
        ## TODO: format create_date in proper json
        gdf['create_date'] = gdf['create_date'].astype(str)  # type: ignore
        gdf['id'] = gdf.index.astype(str)  # type: ignore
    return Response(content=gdf.to_json(),
                    media_type="application/json")
def get_attachment_path(uuid, extension, feature_type, feature_id) -> Tuple[str, Path]:
    """Return (file name, full path) for a new attachment, creating the
    per-feature directory under the attachment root when needed."""
    storage_root = Path(conf.storage.root_attachment_path)
    file_name = str(uuid) + extension
    target_dir: Path = storage_root / feature_type / str(feature_id)
    target_dir.mkdir(parents=True, exist_ok=True)
    return file_name, target_dir / file_name
@api_app.post('/tree')
async def addTree(
        request: Request,
        db_session: db_session,
        user: User = Depends(get_current_active_user),
        plantekey_id: str = Form(),
        picture: Optional[str] = Form(None),
        trail_ids: str | None = Form(None),
        lng: str = Form(),
        lat: str = Form(),
        uuid1: Optional[str] = Form(None),
        details: str | None = Form(None)
        ):
    """Create a tree from a form post.

    Optionally links it to trails (``trail_ids`` is a comma-separated list)
    and stores an inline base64 data-URL ``picture`` as an attachment file.
    Returns the new tree id and, when saved, the picture file name.
    """
    tree = Tree(**Tree.get_tree_insert_params(
        plantekey_id,
        lng, lat,
        user.username,
        loads(details) if details else {},
    ))
    if trail_ids is not None:
        for trail_id in trail_ids.split(','):
            tree_trail = TreeTrail(
                tree_id=tree.id,
                trail_id=int(trail_id)
            )
            db_session.add(tree_trail)
    ## Save files
    resp:dict[str, UUID | str | None] = {'id': tree.id}
    if picture is not None:
        # picture is a data URL ("data:<mime>;base64,<payload>"); the file
        # extension is derived from the declared mimetype
        re_mimetype = re_findmimetype.search(picture)
        if re_mimetype:
            mimetype: str = re_mimetype.group(1)
            picture_file, full_path = get_attachment_path(
                uuid1, mimetypes.guess_extension(mimetype),
                'tree', tree.id)
            # NOTE(review): synchronous file write inside an async endpoint
            with open(full_path, 'wb') as file_:
                ## Feels i'm missing something as it's quite ugly:
                # print(full_path)
                # Decode everything after the first comma (the b64 payload)
                decoded = standard_b64decode(picture[picture.find(',')+1:])
                file_.write(decoded)
            resp['picture'] = picture_file
            tree.photo = picture_file
        else:
            # Best-effort: a malformed data URL is logged, the tree is
            # still created without a photo
            logger.warning('Bad picture data: cannot find mimetype')
    db_session.add(tree)
    await db_session.commit()
    return resp
@api_app.get('/poi')
async def get_pois(
        db_session: db_session,
        roles: list[Role] = Depends(get_current_roles),
        ) -> list[POI]:
    """
    Get all POI as GeoJSON.

    NOTE(review): unlike trees/trails/zones, POIs are not permission-filtered
    (``roles`` and ``db_session`` are injected but unused) — confirm intended.
    The declared return type is list[POI] but a GeoJSON Response is returned.
    """
    gdf = await POI.get_gdf()  # type: ignore
    if len(gdf) > 0:
        gdf.set_index('id', inplace=True)
        # Geometry comes back as WKB; decode it into the geometry column
        gdf.set_geometry(gpd.GeoSeries.from_wkb(gdf.wkb), inplace=True)
        gdf.drop('wkb', axis=1, inplace=True)
        gdf['symbol'] = '\uE001'
    else:
        gdf.set_geometry([], inplace=True)
    gdf['id'] = gdf.index.astype('str')
    # Also remove create_date, not really required and would need to be
    # prepared to be serialized
    gdf.drop(columns='create_date', inplace=True)
    return Response(content=gdf.to_json(),
                    media_type="application/json")  # type: ignore
@api_app.get('/zone')
async def get_zones(
        db_session: db_session,
        roles: list[Role] = Depends(get_current_roles),
        ) -> list[Zone]:
    """
    Get all Zones visible to the current user's roles, as GeoJSON.
    Zones with no viewable_role_id are public.
    """
    gdf = await Zone.get_gdf(
        where=or_(Zone.viewable_role_id.in_([role.name for role in roles]),  # type: ignore
                  Zone.viewable_role_id == None))  # type: ignore # noqa: E711
    # Sort by area, a simple workaround for selecting smaller areas on the map
    gdf['area'] = gdf.area
    gdf.sort_values('area', ascending=False, inplace=True)
    gdf.drop(columns='area', inplace=True)
    # Also remove create_date, not really required and would need to be
    # prepared to be serialized
    gdf.drop(columns='create_date', inplace=True)
    return Response(content=gdf.to_json(),
                    media_type="application/json")  # type: ignore
@api_app.get('/style')
async def get_styles(
        db_session: db_session,
        ) -> list[MapStyle]:
    """Return every map style stored in the database."""
    styles = await db_session.exec(select(MapStyle))
    return styles.all()  # type: ignore
@api_app.put("/trail/photo/{id}/{file_name}")
async def upload_trail_photo(request: Request,
                             db_session: db_session,
                             id: str, file_name: str,
                             file: UploadFile | None = None):
    """
    Store a photo for a trail and record it in the trail's photo field.

    This was tested with QGis, provided the properties for the trail layer
    have been defined correctly.
    This includes: in "Attributes Form", field "photo", "Widget Type"
    is set as WebDav storage, with store URL set correcly with a URL like:
    * 'http://localhost:4200/v1/trail/photo/' || "id" || '/' || file_name(@selected_file_path)
    * 'https://treetrail.avcsr.org/v1/trail/' || "id" || '/' || file_name(@selected_file_path)
    ## XXX: probably broken info as paths have changed
    """  # noqa: E501
    base_dir = get_attachment_trail_root() / id
    if not base_dir.is_dir():
        await aiofiles.os.mkdir(base_dir)
    if not file:
        # WebDAV: the photo is the raw request body
        contents = await request.body()
        if len(contents) > 0:
            # Save the file
            async with aiofiles.open(base_dir / file_name, 'wb') as f:
                await f.write(contents)
            # Update the trail record
            # With QGis this gets overwritten when it is saved
            trail = await db_session.get(Trail, int(id))
            if trail is None:
                raise HTTPException(status_code=status.HTTP_417_EXPECTATION_FAILED,
                                    detail=f'No such trail id {id}')
            trail.photo = file_name
            await db_session.commit()
            # Bug fix: the function previously fell through to the final
            # return, which dereferences file.filename while file is None
            return {"message": f"Successfully uploaded {file_name} for id {id}"}
        else:
            return {"message": "No file found in the request"}
    else:
        # Multipart form - not tested
        try:
            contents = file.file.read()
            # Bug fix: write into the target file; the original passed the
            # directory path itself to aiofiles.open()
            async with aiofiles.open(base_dir / file_name, 'wb') as f:
                await f.write(contents)
        except Exception:
            return {"message": "There was an error uploading the file"}
        finally:
            file.file.close()
    return {"message": f"Successfully uploaded {file.filename} for id {id}"}
@api_app.put("/tree/photo/{id}/{file_name}")
async def upload_tree_photo(request: Request,
                            db_session: db_session,
                            id: str, file_name: str,
                            file: UploadFile | None = None):
    """
    Store a photo for a tree and record it in the tree's photo field.

    This was tested with QGis, provided the properties for the tree layer
    have been defined correctly.
    This includes: in "Attributes Form", field "photo", "Widget Type"
    is set as WebDav storage, with store URL set correcly with a URL like:
    * 'http://localhost:4200/v1/tree/photo/' || "id" || '/' || file_name(@selected_file_path)
    * 'https://treetrail.avcsr.org/v1/tree/' || "id" || '/' || file_name(@selected_file_path)
    ## XXX: probably broken info as paths have changed
    """  # noqa: E501
    base_dir = get_attachment_tree_root() / id
    if not base_dir.is_dir():
        await aiofiles.os.mkdir(base_dir)
    if not file:
        # WebDAV: the photo is the raw request body
        contents = await request.body()
        if len(contents) > 0:
            # Save the file
            async with aiofiles.open(base_dir / file_name, 'wb') as f:
                await f.write(contents)
            # Update the tree record
            # With QGis this gets overwritten when it is saved
            tree = await db_session.get(Tree, int(id))
            if tree is None:
                raise HTTPException(status_code=status.HTTP_417_EXPECTATION_FAILED,
                                    detail=f'No such tree id {id}')
            tree.photo = file_name
            await db_session.commit()
            # Bug fix: the function previously fell through to the final
            # return, which dereferences file.filename while file is None
            return {"message": f"Successfully uploaded {file_name} for id {id}"}
        else:
            return {'message': 'No file found in the request'}
    else:
        # Multipart form - not tested
        try:
            contents = file.file.read()
            # Bug fix: write into the target file; the original passed the
            # directory path itself to aiofiles.open()
            async with aiofiles.open(base_dir / file_name, 'wb') as f:
                await f.write(contents)
        except Exception:
            return {"message": "There was an error uploading the file"}
        finally:
            file.file.close()
    return {"message": f"Successfully uploaded {file.filename} for id {id}"}
# => Below =>
# Serve the images
# The URLs are better served by a reverse proxy front-end, like Nginx
# (mkdir ensures each attachment directory exists before StaticFiles checks it)
api_app.mount('/tree', StaticFiles(directory=mkdir(get_attachment_tree_root())), name='tree_attachments')
api_app.mount('/trail', StaticFiles(directory=mkdir(get_attachment_trail_root())), name='trail_attachments')
api_app.mount('/poi', StaticFiles(directory=mkdir(get_attachment_poi_root())), name='poi_attachments')

249
src/treetrail/application.py Executable file
View file

@ -0,0 +1,249 @@
#!/usr/bin/env python
import logging
import sys
from contextlib import asynccontextmanager
try:
import coloredlogs # type: ignore
except ImportError:
pass
else:
coloredlogs.install()
from fastapi import FastAPI, responses
from fastapi.staticfiles import StaticFiles
from treetrail.config import conf, get_cache_dir, create_dirs
from treetrail.plantekey import setup as setup_plantekey, pek_app
from treetrail.api_v1 import api_app
from treetrail.tiles import tiles_app, registry as tiles_registry
from treetrail.attachments import attachment_app
from treetrail.database import create_db
from treetrail.utils import mkdir
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: on startup create the data directories, set up
    plantekey, create the database tables and start the tiles registry;
    shut the tiles registry down on exit."""
    create_dirs()
    setup_plantekey(app)
    await create_db()
    await tiles_registry.setup(app)
    yield
    await tiles_registry.shutdown(app)
# Root ASGI application; sub-applications are mounted under conf.base_href
app = FastAPI(
    title=conf.app.title,
    lifespan=lifespan,
    version=conf.version,
    default_response_class=responses.ORJSONResponse,
)
# The plantekey proxy lives under the v1 API
api_app.mount('/plantekey', pek_app)
app.mount(f'{conf.base_href}/v1', api_app)
app.mount(f'{conf.base_href}/tiles', tiles_app)
app.mount(f'{conf.base_href}/attachment', attachment_app)
# Generated files (eg. the attachments tar archive) served statically
app.mount(
    f'{conf.base_href}/static/cache',
    StaticFiles(directory=mkdir(get_cache_dir())),
    name='static_generated'
)
def _main(argv=None):
    """Command-line entry point: administrative commands (database creation,
    user/role management, data imports).  When no command matches, print how
    to launch the ASGI server with uvicorn.

    NOTE(review): ``argv``, ``--path``, ``--hostname`` and ``--port`` are
    parsed but unused here — the server itself is started by uvicorn.
    """
    from argparse import ArgumentParser
    arg_parser = ArgumentParser(
        description="fastapi Application server",
        prog="fastapi"
    )
    arg_parser.add_argument(
        '--path',
        help='Path of socket file',
    )
    arg_parser.add_argument(
        "-H", "--hostname",
        help="TCP/IP hostname to serve on (default: %(default)r)",
        default="localhost"
    )
    arg_parser.add_argument(
        "-P", "--port",
        help="TCP/IP port to serve on",
        type=int,
    )
    arg_parser.add_argument(
        "-c", "--create-db",
        help="Create tables in database",
        action="store_true"
    )
    arg_parser.add_argument(
        "--username",
        help="Create or update a user in database",
        type=str,
    )
    arg_parser.add_argument(
        "--add-role",
        help="Add the role",
        type=str,
    )
    arg_parser.add_argument(
        "--add-user-role",
        help="Add the role to the user",
        type=str,
    )
    arg_parser.add_argument(
        "--password",
        help="Set the password for a user in database",
        type=str,
    )
    arg_parser.add_argument(
        "--full-name",
        help="Set the full name for a user in database",
        type=str,
    )
    arg_parser.add_argument(
        "--email",
        help="Set the email for a user in database",
        type=str,
    )
    arg_parser.add_argument(
        "--enable",
        help="Enable user",
        action="store_true"
    )
    arg_parser.add_argument(
        "--disable",
        help="Disable user",
        action="store_true"
    )
    arg_parser.add_argument(
        "--delete-user",
        help="Delete user",
        action="store_true"
    )
    arg_parser.add_argument(
        "--import-trees",
        help="Import trees (eg. gpkg file). Images can be imported.",
        type=str,
    )
    arg_parser.add_argument(
        "--layers",
        help="Layers to import.",
        nargs='*',
        type=str,
    )
    arg_parser.add_argument(
        "--import-zones",
        help="Import zones (eg. gpkg file).",
        type=str,
    )
    arg_parser.add_argument(
        "--import-plantekey-trees-to-trail",
        help="Import trees from plantekey web site. Provide the trail id.",
        type=str,
    )
    arg_parser.add_argument(
        "--import-plantekey-plants",
        help="Import plants from plantekey web site",
        action="store_true"
    )
    arg_parser.add_argument(
        "--list-layers",
        help="List layers in the geodata file",
        type=str,
    )
    # Bug fix: '-d' was passed twice ("-d", "--debug", '-d'), registering the
    # same short option twice on the action
    arg_parser.add_argument(
        "-d", "--debug",
        help="Set debug logging",
        action="store_true"
    )
    args = arg_parser.parse_args()
    if args.debug:
        logging.root.setLevel(logging.DEBUG)
        ## For ipdb:
        logging.getLogger('parso').setLevel(logging.WARNING)
    # Each command below imports lazily (avoids pulling the whole app for
    # simple CLI operations), runs, and exits the process.
    if args.create_db:
        from treetrail.database import create_db
        import asyncio
        asyncio.run(create_db())
        sys.exit(0)
    if args.enable:
        from treetrail.security import enable_user
        import asyncio
        asyncio.run(enable_user(args.username))
        sys.exit(0)
    if args.disable:
        from treetrail.security import enable_user
        import asyncio
        asyncio.run(enable_user(args.username, False))
        sys.exit(0)
    if args.add_role:
        from treetrail.security import add_role
        import asyncio
        asyncio.run(add_role(args.add_role))
        sys.exit(0)
    if args.add_user_role:
        from treetrail.security import add_user_role
        import asyncio
        if not args.username:
            print('Please provide username')
            sys.exit(1)
        asyncio.run(add_user_role(args.username, args.add_user_role))
        sys.exit(0)
    if args.delete_user:
        from treetrail.security import delete_user
        import asyncio
        asyncio.run(delete_user(args.username))
        sys.exit(0)
    if args.list_layers:
        from treetrail.import_cli import list_layers
        list_layers(args.list_layers)
        sys.exit(0)
    if args.import_trees:
        from treetrail.import_cli import import_trees
        import_trees(args)
        sys.exit(0)
    if args.import_zones:
        from treetrail.import_cli import import_zones
        import_zones(args)
        sys.exit(0)
    if args.import_plantekey_plants:
        from treetrail.import_cli import import_plantekey_plants
        import asyncio
        asyncio.run(import_plantekey_plants(args))
        sys.exit(0)
    if args.import_plantekey_trees_to_trail:
        from treetrail.import_cli import import_plantekey_trees
        import_plantekey_trees(args)
        sys.exit(0)
    if args.username:
        from treetrail.security import create_user
        import asyncio
        asyncio.run(create_user(**vars(args)))
        sys.exit(0)
    print(
        'This application needs to be run with an asgi server like uvicorn.',
        'For example:',
        'uvicorn application:app',
        'or:',
        'uvicorn application:app --port 5002',
        'or (for development):',
        'uvicorn --reload application:app --uds /var/run/treetrail.socket',
        'or (for production):',
        'uvicorn --loop uvloop application:app --port 5002',
        sep='\n'
    )
if __name__ == '__main__':
_main()

View file

@ -0,0 +1,23 @@
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from treetrail.api_v1 import (get_attachment_tree_root,
get_attachment_trail_root, get_attachment_poi_root)
from treetrail.plantekey import get_thumbnail_root, get_img_root, get_img_type_root
from treetrail.utils import mkdir
# Static file sub-application serving all attachment/image directories;
# mounted under {base_href}/attachment by application.py.
# mkdir ensures each directory exists before StaticFiles validates it.
attachment_app = FastAPI()
attachment_app.mount("/plantekey/img", StaticFiles(directory=mkdir(get_img_root())),
                     name="plantekey_img")
attachment_app.mount("/plantekey/thumb", StaticFiles(directory=mkdir(get_thumbnail_root())),
                     name="plantekey_thumb")
attachment_app.mount("/plantekey/type", StaticFiles(directory=mkdir(get_img_type_root())),
                     name="plantekey_type")
attachment_app.mount("/trail", StaticFiles(directory=mkdir(get_attachment_trail_root())),
                     name="trail")
attachment_app.mount("/tree", StaticFiles(directory=mkdir(get_attachment_tree_root())),
                     name="tree")
attachment_app.mount("/poi", StaticFiles(directory=mkdir(get_attachment_poi_root())),
                     name="poi")

167
src/treetrail/config.py Normal file
View file

@ -0,0 +1,167 @@
from os import environ
from pathlib import Path
from secrets import token_hex
from typing import Any, Type, Tuple
from yaml import safe_load
import logging
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
)
from pydantic.v1.utils import deep_update
from treetrail._version import __version__
logger = logging.getLogger(__name__)
ENV = environ.get("env", "prod")
config_files = [
Path(Path.cwd().root) / "etc" / "treetrail" / ENV,
Path.home() / ".local" / "treetrail" / ENV,
]
def config_file_settings() -> dict[str, Any]:
    """Read and merge the treetrail YAML config files.

    Later files in ``config_files`` override earlier ones (system-wide
    /etc/treetrail/<env> first, then the per-user file).
    """
    config: dict[str, Any] = {}
    for p in config_files:
        # A tuple, not a set: set iteration order is unspecified, which made
        # the merge order between foo.yaml and foo.yml nondeterministic when
        # both exist
        for suffix in (".yaml", ".yml"):
            path = p.with_suffix(suffix)
            if not path.is_file():
                logger.debug(f"No file found at `{path.resolve()}`")
                continue
            logger.info(f"Reading config file `{path.resolve()}`")
            if path.suffix in {".yaml", ".yml"}:
                config = deep_update(config, load_yaml(path))
            else:
                logger.info(f"Unknown config file extension `{path.suffix}`")
    return config
def load_yaml(path: Path) -> dict[str, Any]:
    """Parse *path* as YAML and return its top-level mapping.

    Raises TypeError when the document's top level is not a mapping.
    """
    with Path(path).open("r") as stream:
        content = safe_load(stream)
    if not isinstance(content, dict):
        raise TypeError(f"Config file has no top-level mapping: {path}")
    return content
def create_dirs():
    """
    Create the directories needed for a proper functioning of the app
    """
    ## Avoid circular imports
    from treetrail.api_v1 import attachment_types
    # One attachment directory per feature type (tree, trail, poi)
    for type in attachment_types:
        base_dir = Path(conf.storage.root_attachment_path) / type
        base_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Cache dir: {get_cache_dir()}")
    get_cache_dir().mkdir(parents=True, exist_ok=True)
def get_cache_dir() -> Path:
    """Root directory for generated/cached files (eg. the attachments tar)."""
    return Path(conf.storage.root_cache_path)
class MyBaseSettings(BaseSettings):
    """Base for all settings sections: values can be overridden by
    environment variables prefixed with ``treetrail_``, using ``_`` as the
    nesting delimiter (eg. treetrail_db_host)."""
    model_config = SettingsConfigDict(
        env_prefix='treetrail_',
        env_nested_delimiter="_",
    )
class DB(MyBaseSettings):
    """PostgreSQL connection settings."""
    # uri: str
    host: str = "treetrail-database"
    port: int = 5432
    user: str = "treetrail"
    db: str = "treetrail"
    password: str = "treetrail"
    debug: bool = False
    info: bool = False
    # SQLAlchemy connection pool tuning
    pool_size: int = 10
    max_overflow: int = 10
    # Echo SQL statements (SQLAlchemy engine echo)
    echo: bool = False

    def get_sqla_url(self):
        """URL for the async SQLAlchemy engine (asyncpg driver)."""
        return f"postgresql+asyncpg://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"

    def get_pg_url(self):
        """Plain libpq-style URL (eg. for psycopg2 or external tools)."""
        return f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db}"
class App(MyBaseSettings):
    # Application name, used for the FastAPI title and sent to clients
    title: str = "Tree Trail"
class Storage(MyBaseSettings):
    # Where uploaded attachments (photos) are stored, one subdir per feature type
    root_attachment_path: str = "/var/lib/treetrail/attachments"
    # Where generated files (eg. the attachments tar archive) are cached
    root_cache_path: str = "/var/lib/treetrail/cache"
class Tiles(MyBaseSettings):
    """Map tile (mbtiles) and sprite serving settings."""
    baseDir: str = "/var/lib/treetrail/mbtiles_files"
    # presumably: build tile URLs from the incoming request URL — TODO confirm
    # against treetrail.tiles
    useRequestUrl: bool = True
    spriteBaseDir: str = "/var/lib/treetrail/mbtiles_sprites"
    spriteUrl: str = "/tiles/sprite/sprite"
    spriteBaseUrl: str = "https://treetrail.example.org"
    osmBaseDir: str = "/var/lib/treetrail/osm"
class Map(MyBaseSettings):
    """Initial map viewport and background style sent to clients."""
    zoom: float = 14.0
    pitch: float = 0.0
    # Default center coordinates (presumably near Auroville, India — confirm)
    lat: float = 12.0000
    lng: float = 79.8106
    bearing: float = 0
    background: str = "OpenFreeMap"
class Geo(MyBaseSettings):
    """Geometry simplification settings."""
    # NOTE(review): units/semantics not visible here — confirm against usage
    simplify_geom_factor: int = 10000000
    simplify_preserve_topology: bool = False
class Security(MyBaseSettings):
    """
    JWT security configuration
    """
    # Random default: tokens will not survive a restart unless a fixed
    # key is configured
    secret_key: str = token_hex(32)
    '''Generate with eg.: "openssl rand -hex 32"'''
    access_token_expire_minutes: float = 30
class ExternalMapStyle(MyBaseSettings):
    """A named external base map style (URL)."""
    name: str
    url: str
class Config(MyBaseSettings):
    """Root application settings, aggregating all sections."""
    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> Tuple[PydanticBaseSettingsSource, ...]:
        # Source priority: environment variables, then constructor kwargs,
        # secret files, and finally the YAML config files
        return (env_settings, init_settings, file_secret_settings, config_file_settings)  # type: ignore

    app: App = App()
    # postgres: dict
    storage: Storage = Storage()
    map: Map = Map()
    # Mapping of style name -> external base map style URL
    mapStyles: dict[str, str] = {}
    tiles: Tiles = Tiles()
    security: Security = Security()
    geo: Geo = Geo()
    # Application version, passed in at creation from treetrail._version
    version: str
    db: DB = DB()
    # URL prefix under which the whole app is mounted
    base_href: str = '/treetrail'


# Singleton settings instance used throughout the app
conf = Config(version=__version__)  # type: ignore

103
src/treetrail/database.py Normal file
View file

@ -0,0 +1,103 @@
from contextlib import asynccontextmanager
import sys
from typing import Annotated
from collections.abc import AsyncGenerator
from asyncio import sleep
import logging
from fastapi import Depends
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import SQLModel, select, func, col
from treetrail.config import conf
logger = logging.getLogger(__name__)

# Seconds to keep retrying the initial database connection in create_db()
CREATE_DB_TIMEOUT = 30

# Single async engine shared by all sessions
engine = create_async_engine(
    conf.db.get_sqla_url(),
    echo=conf.db.echo,
    pool_size=conf.db.pool_size,
    max_overflow=conf.db.max_overflow,
)
async def create_db(drop=False):
    """Create all SQLModel tables, retrying while the database comes up.

    Retries once per second for up to CREATE_DB_TIMEOUT seconds (useful when
    the database container starts alongside the app).  On a fresh, empty
    database the initial data is also populated.  Exits the process when the
    database never becomes reachable.

    :param drop: drop all tables first (destructive)
    """
    attempts = CREATE_DB_TIMEOUT

    async def try_once():
        async with engine.begin() as conn:
            if drop:
                await conn.run_sync(SQLModel.metadata.drop_all)
            await conn.run_sync(SQLModel.metadata.create_all)

    while attempts > 0:
        try:
            await try_once()
        except ConnectionRefusedError:
            logger.debug(
                f"Cannot connect to database during init (create_db), "
                f"waiting {attempts} more seconds"
            )
            attempts -= 1
            await sleep(1)
        else:
            if await is_fresh_install():
                await populate_init_db()
            return
    else:
        # The message previously read "after 30, giving up" — the unit was
        # missing
        logger.warning(
            f"Cannot connect to database after {CREATE_DB_TIMEOUT} seconds, giving up."
        )
        sys.exit(1)
async def is_fresh_install() -> bool:
    """Detect if the database is newly created, without data (no users)."""
    ## Avoid circular imports
    from treetrail.models import User
    async with db_session() as session:
        nb_users = (await session.exec(select(func.count(col(User.username))))).one()
    return nb_users == 0
async def populate_init_db():
    """Populate the database for a fresh install"""
    from sqlalchemy import text
    from treetrail.security import create_user, add_role, add_user_role
    logger.info("Populating initial database")
    # Default credentials — meant to be changed right after first login
    user = await create_user(username="admin", password="admin")
    role = await add_role(role_id="admin")
    await add_user_role(user.username, role.name)
    # Insert the default map styles (raw SQL, see `initials` below)
    async with db_session() as session:
        for initial in initials:
            await session.execute(text(initial))
            logger.debug(f'Added map style {initial}')
        await session.commit()
## Default styles, to be inserted in the DB
# Raw SQL statements executed once by populate_init_db() on a fresh install;
# each row defines the MapLibre paint/layout JSON for one map layer.
initials: list[str] = [
    """INSERT INTO map_style (layer, paint, layout) values ('trail', '{"line-color": "#cd861a", "line-width": 6, "line-blur": 2, "line-opacity": 0.9 }', '{"line-join": "bevel"}');""",
    """INSERT INTO map_style (layer, layout) values ('tree', '{"icon-image":"tree", "icon-size": 0.4}');""",
    """INSERT INTO map_style (layer, layout) values ('tree-hl', '{"icon-image":"tree", "icon-size": 0.4}');""",
    """INSERT INTO map_style (layer, layout) values ('poi', '{"icon-image":"poi", "icon-size": 0.4}');""",
    """INSERT INTO map_style (layer, paint) VALUES ('zone', '{"fill-color": ["match", ["string", ["get", "type"]], "Forest", "#00FF00", "Master Plan", "#EE4455", "#000000"], "fill-opacity": 0.5}');""",
] # noqa: E501
async def get_db_session() -> AsyncGenerator[AsyncSession]:
    """FastAPI dependency: yield a fresh AsyncSession, closed after use."""
    session = AsyncSession(engine)
    try:
        yield session
    finally:
        await session.close()
@asynccontextmanager
async def db_session() -> AsyncGenerator[AsyncSession]:
    """Async context manager yielding an AsyncSession (for non-FastAPI code)."""
    session = AsyncSession(engine)
    try:
        yield session
    finally:
        await session.close()
# Annotated alias used as a FastAPI-injected session parameter in endpoints
fastapi_db_session = Annotated[AsyncSession, Depends(get_db_session)]

View file

@ -0,0 +1,37 @@
app:
title: Tree Trail
db:
database: treetrail
user: treetrail
password: treetrail!secret
host: localhost
port: 5432
minsize: 1
maxsize: 5
storage:
root_attachment_path: /var/lib/treetrail/attachments
root_cache_path: /var/lib/treetrail/cache
map:
tiles:
baseDir: /var/lib/treetrail/mbtiles_files
useRequestUrl: true
spriteBaseDir: /var/lib/treetrail/mbtiles_sprites
spriteUrl: /tiles/sprite/sprite
spriteBaseUrl: https://treetrail.example.org
osmBaseDir: /var/lib/treetrail/osm
zoom: 14
pitch: 45
lat: 45.8822
lng: 6.1781
bearing: 0
background: OpenFreeMap
mapStyles:
OpenFreeMap: https://tiles.openfreemap.org/styles/liberty
security:
secret_key: '993e39ce154ca95d0908384a4eedc9bd26147b34995be96cc722b654616d0c28'
access_token_expire_minutes: 30

19
src/treetrail/gisaf.py Normal file
View file

@ -0,0 +1,19 @@
from datetime import datetime
from typing import Annotated
from sqlalchemy import String
from geoalchemy2 import Geometry, WKBElement
from sqlmodel import Field
from treetrail.models import BaseModel
class GisafTree(BaseModel, table=True):
    """Tree record in the Gisaf-compatible table (`gisaf_tree`)."""
    __tablename__: str = "gisaf_tree"  # type: ignore
    # References the plantekey table; also serves as primary key
    plantekey_id: int = Field(foreign_key='plantekey.id', primary_key=True)
    # NOTE(review): an int field named 'data' is surprising — confirm intent
    data: int
    create_date: datetime = Field(default_factory=datetime.now)
    # 3D point (lon, lat, elevation), WGS84
    geom: Annotated[str, WKBElement] = Field(
        sa_type=Geometry('POINTZ', srid=4326, dimension=3))
    photo: str = Field(sa_type=String(250))  # type: ignore

276
src/treetrail/import_cli.py Normal file
View file

@ -0,0 +1,276 @@
import sys
from pathlib import Path
from json import dumps
import logging
from shutil import copy
from datetime import datetime
import requests
from uuid import UUID
from sqlalchemy import create_engine
from sqlmodel import select, Session, delete, create_engine as sqlmodel_create_engine
import geopandas as gpd # type: ignore
import pandas as pd
from treetrail.config import conf
from treetrail.utils import get_attachment_tree_root
from treetrail.models import Tree, Trail, TreeTrail, User, Zone
from treetrail.plantekey import Plant, fetch_browse, update_details
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
column_mapper = {
'pic_full': 'photo',
'Comment': 'comment',
}
base_tree_attachment_dir = get_attachment_tree_root()
def list_layers(file):
    """Print the quoted names of all layers found in a geo data file."""
    from fiona import listlayers
    quoted = [f"'{layer}'" for layer in listlayers(file)]
    print(' '.join(quoted))
def copy_image(record, base_dir):
    '''Copy the file to the proper location for attachments'''
    # Records without a photo are skipped silently
    if pd.isna(record.photo):
        return
    source = base_dir / record.photo
    target_dir = base_tree_attachment_dir / record.id
    target_dir.mkdir(exist_ok=True)
    copy(source, target_dir)
def import_trees(args) -> None:
    """Import trees from a file containing geo data.

    The geopackage file layout is expected to be strict (column names
    like 'UUID', 'Date Edited', 'pic_stem'); a description should be
    given in the documentation. Trees already present in the database
    (same id) are skipped; new trees are attached to the nearest trail
    and their photos copied into the attachment storage.
    """
    contributor_id = args.username
    if contributor_id is None:
        raise Exception('A user name is required to identify the contributor')
    file_to_import = Path(args.import_trees).expanduser()
    sync_engine = create_engine(conf.db.get_pg_url())
    session = Session(sync_engine)
    # Read and format the data in the file
    gdf_trees = gpd.read_file(file_to_import, layer=args.layers or None)
    gdf_trees.rename_geometry('geom', inplace=True)
    gdf_trees.to_crs(4326, inplace=True)
    gdf_trees.rename(columns=column_mapper, inplace=True)
    # Photos: take only the file name
    # gdf_trees['photo'] = gdf_trees['photo'].str.split('/', expand=True)[1]
    gdf_trees['pic_stem'] = gdf_trees['pic_stem'].str.split('/', expand=True)[1]
    gdf_trees['contributor_id'] = contributor_id
    gdf_trees['create_date'] = pd.to_datetime(gdf_trees['Date Edited'])
    # The source ids look like '{uuid}': strip the braces
    gdf_trees['id'] = gdf_trees['UUID'].str.strip('{').str.strip('}')
    gdf_trees.drop(columns='UUID', inplace=True)
    ## Determine which columns are in the database
    # ... and store the remaining in a dict datastructure to store in JSON "data" column
    gdf_existing_trees: gpd.GeoDataFrame
    gdf_existing_trees = gpd.read_postgis('select * from tree', sync_engine)  # type: ignore
    unknown_columns = {col for col in gdf_trees if col not in gdf_existing_trees.columns}
    known_columns = {col for col in gdf_trees if col in gdf_existing_trees.columns}
    left_columns = {col for col in gdf_existing_trees.columns if col not in gdf_trees}
    logger.debug(f'Known columns: {known_columns}')
    logger.debug(f'Unknown left: {unknown_columns}')
    logger.debug(f'Columns left: {left_columns}')
    # Remove empty extra fields
    new_trees_data_raw = gdf_trees[list(unknown_columns)].to_dict(orient='records')
    new_trees_data = []
    for data in new_trees_data_raw:
        new_trees_data.append(
            {k: v for k, v in data.items()
             if not pd.isna(v)
             }
        )
    gdf_trees['data'] = [dumps(d) for d in new_trees_data]
    gdf_trees.drop(columns=unknown_columns, inplace=True)
    gdf_trees.reset_index(inplace=True)
    gdf_trees.drop(columns='index', inplace=True)
    # Find the trails
    gdf_trails: gpd.GeoDataFrame
    gdf_trails = gpd.read_postgis(select(Trail), sync_engine, index_col='id')  # type: ignore
    # Assign trails to the new trees: a tree belongs to a trail when it falls
    # in a 150 m buffer around the trail line (buffer computed in EPSG:3857)
    gdf_trails['zone'] = gdf_trails.to_crs(3857).buffer(150).to_crs(4326)  # type: ignore
    gdf_trails.set_geometry('zone', inplace=True)
    gdf_trees[['trail', 'viewable_role_id']] = gdf_trees.sjoin(gdf_trails, how='left')[['index_right', 'viewable_role_id']]
    # Save trees to the database
    ## Remove the trees already in the DB from the dataframe to insert
    gdf_trees.set_index('id', inplace=True)
    gdf_new_trees = gdf_trees.loc[~gdf_trees.index.isin(gdf_existing_trees['id'].astype(str))].reset_index()  # type:ignore
    gdf_new_trees.drop(columns='trail').to_postgis(Tree.__tablename__, sync_engine, if_exists='append')
    # Copy the images to the treetrail storage dir
    gdf_new_trees.apply(copy_image, axis=1, base_dir=file_to_import.parent)
    # for file in import_image_dir.iterdir():
    #     id = file.stem.split('_')[-1]
    # gdf_trees.photo.str.split('/', expand=True)1]
    # Compute the tree<->trail relations that are not yet in the database
    df_tt_existing = pd.read_sql(select(TreeTrail), sync_engine)
    df_tt_existing.rename(columns={'tree_id': 'id', 'trail_id': 'trail'}, inplace=True)
    df_tt_existing['id'] = df_tt_existing['id'].astype(str)
    df_tt_new = gdf_trees['trail'].reset_index()
    df_tt_to_insert = pd.concat([df_tt_new, df_tt_existing]).drop_duplicates(keep=False)  # type: ignore

    def get_tt_rel(tree):
        # Build one tree<->trail association record
        return TreeTrail(tree_id=tree.id, trail_id=tree.trail)

    tree_trails = df_tt_to_insert.reset_index().apply(get_tt_rel, axis=1)  # type: ignore
    with Session(sync_engine) as session:
        for tt in tree_trails:
            session.add(tt)
        session.commit()
    logger.info(f'Imported on behalf of {args.username} {len(gdf_new_trees)} trees')
def import_zones(args) -> None:
    """Import a geopackage with zones.

    The format of the input file is strict. Each requested layer is
    appended to the zone table; columns unknown to the table are packed
    into the JSON 'data' column.
    """
    if args.layers is None:
        print('Provide layer names from:')
        list_layers(args.import_zones)
        sys.exit(1)
    file_to_import = Path(args.import_zones).expanduser()
    # Mapping of source column names to zone table column names
    fields_map = {
        'Area': 'name',
    }
    # Source columns dropped on import
    fields_ignored = [
        'fid',
        'Area Area',
        'Area type',
    ]
    sync_engine = create_engine(conf.db.get_pg_url())
    gdf_existing_zones: gpd.GeoDataFrame
    gdf_existing_zones = gpd.read_postgis('select * from zone', sync_engine)  # type: ignore
    now = datetime.now()
    for layer in args.layers:
        print(layer)
        gdf = gpd.read_file(file_to_import, layer=layer)
        gdf.rename(columns=fields_map, inplace=True)
        gdf.drop(columns=fields_ignored, inplace=True, errors='ignore')
        # Columns not in the zone table go to the JSON 'data' column
        unknown_columns = {col for col in gdf
                           if col not in gdf_existing_zones.columns}
        unknown_columns = unknown_columns - {'geometry'}
        all_data_raw = gdf[list(unknown_columns)].to_dict(orient='records')
        all_data = []
        for data in all_data_raw:
            # Drop empty (NaN) values from the extra data
            all_data.append(
                {k: v for k, v in data.items()
                 if not pd.isna(v)
                 }
            )
        gdf['data'] = [dumps(d) for d in all_data]
        gdf.drop(columns=unknown_columns, inplace=True)
        gdf.reset_index(inplace=True)
        gdf.drop(columns='index', inplace=True)
        # The layer name becomes the zone type
        gdf['type'] = layer
        gdf['create_date'] = now
        gdf.to_crs("EPSG:4326", inplace=True)
        gdf.rename_geometry('geom', inplace=True)
        # The zone table requires a name: default to '?'
        if 'name' not in gdf.columns:
            gdf['name'] = '?'
        else:
            gdf['name'].fillna('?', inplace=True)
        gdf.to_postgis(Zone.__tablename__, sync_engine,
                       if_exists='append', index=False)
def import_plantekey_trees(args, contributor_id='plantekey'):
    """Import all trees from plantekey web site.

    Fetches the markers from the plantekey API, matches them with the
    local plant table, replaces all trees previously imported under the
    'plantekey' contributor, and attaches the new trees to the trail
    given by args.import_plantekey_trees_to_trail.
    """
    now = datetime.now()
    sync_engine = sqlmodel_create_engine(conf.db.get_pg_url())
    trail_id = int(args.import_plantekey_trees_to_trail)
    ## Harmless check that the 'plantekey' contributor exists
    with Session(sync_engine) as session:
        contributor = session.get(User, contributor_id)
        if contributor is None:
            raise UserWarning('User plantekey not found')
    ## Get the raw data from the plantekey web site
    plantekey_trees_raw = requests.get('https://plantekey.com/api.php?action=markers').json()['markers']
    ## Convert that raw data into a nice dataframe
    df_tree_plantekey = pd.DataFrame(plantekey_trees_raw).drop(columns=['0', '1', '2', '3'])
    df_tree_plantekey['MasterID'] = df_tree_plantekey['MasterID'].astype(int)
    df_tree_plantekey.sort_values('MasterID', inplace=True)
    df_tree_plantekey.reset_index(drop=True, inplace=True)
    print(f'Found {len(df_tree_plantekey)} trees in Plantekey web site')
    ## The MasterID is probably the plant id, so just dropping it
    df_tree_plantekey.drop(columns=['MasterID'], inplace=True)
    ## Get the existing plants in the database
    df_plants = pd.read_sql(select(Plant), sync_engine)
    ## Merge those trees with the plants
    df_tree = df_tree_plantekey.merge(df_plants[['name', 'id']], left_on='name', right_on='name', how='left')
    df_tree.rename(columns={'id': 'plantekey_id'}, inplace=True)
    ## Generate a primary key (custom UUID), which is predictable: it marks the source as plantekey, and tracks the plantkey tree id
    base_fields = (0x10000001, 0x0001, 0x0001, 0x00, 0x01)

    def id_to_uuid(tree) -> UUID:
        # Fixed UUID prefix + the tree's positional index as the node field
        return UUID(fields=base_fields + (tree['index'], ))

    df_tree['id'] = df_tree.reset_index().apply(lambda _: id_to_uuid(_), axis=1)  # type: ignore
    #gdf_tree.drop(columns=['plantekey_tree_id'], inplace=True)
    df_tree.set_index('id', inplace=True)
    ## Detect missing plants
    missing_plants = df_tree.loc[df_tree.plantekey_id.isna()]['name'].unique()
    if len(missing_plants) > 0:
        print(f'* Warning: {len(missing_plants)} plants are missing in Treetrail, please update it!')
        print('* Missing plants:')
        for mp in missing_plants:
            print(f'  {mp}')
        df_tree = df_tree.loc[~df_tree.plantekey_id.isna()]
        print(f'* Importing only {len(df_tree)} trees.')
    ## Build the geodataframe
    gdf_tree = gpd.GeoDataFrame(
        df_tree,
        geometry=gpd.points_from_xy(df_tree["Longitude"], df_tree["Latitude"]),
        crs="EPSG:4326",
    )  # type: ignore
    gdf_tree.drop(columns=['Latitude', 'Longitude', 'name'], inplace=True)
    #gdf_tree['data'] = gdf_tree.plantekey_tree_id.apply(lambda _: dumps({'Plantekey tree id': int(_)}))
    #gdf_tree['data'] = dumps({'Source': 'Plantekey (Botanical Graden)'})
    gdf_tree.rename(columns={'geometry': 'geom'}, inplace=True)
    gdf_tree.set_geometry('geom', inplace=True)
    # Save to the database
    ## Prepare the geodataframe for saving
    gdf_tree['create_date'] = now
    gdf_tree['contributor_id'] = contributor.username
    gdf_tree['data'] = dumps({})
    ## Remove all trees with the contributor "plantekey" before adding the new ones
    with Session(sync_engine) as session:
        to_delete_trees = session.exec(select(Tree).where(Tree.contributor_id==contributor_id)).all()
        to_delete_tree_ids = [tree.id for tree in to_delete_trees]
        print(f'{len(to_delete_trees)} trees existing in the database from plantekey, deleting all of them')
        ## Also delete their relationships to that trail
        session.exec(delete(TreeTrail).where(TreeTrail.tree_id.in_(to_delete_tree_ids)))  # type: ignore
        session.exec(delete(Tree).where(Tree.id.in_(to_delete_tree_ids)))  # type: ignore
        session.commit()
    ## Finally insert to the database
    gdf_tree.to_postgis(Tree.__tablename__, sync_engine, if_exists='append', index=True)
    ## And add those to the trail
    with Session(sync_engine) as session:
        for tree_id in gdf_tree.index:
            session.add(TreeTrail(tree_id=tree_id, trail_id=trail_id))
        session.commit()
    print(f'Added {len(gdf_tree)} trees.')
    print('Import done.')
async def import_plantekey_plants(args):
    """Fetch the plant list from plantekey.com and refresh their details."""
    browse_df = await fetch_browse()
    await update_details(browse_df)

View file

@ -0,0 +1,22 @@
version: 1
disable_existing_loggers: false
formatters:
standard:
format: "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
handlers:
console:
class: logging.StreamHandler
formatter: standard
stream: ext://sys.stdout
loggers:
uvicorn:
error:
propagate: true
root:
level: INFO
handlers: [console]
propagate: no

271
src/treetrail/models.py Normal file
View file

@ -0,0 +1,271 @@
from typing import Annotated, Any, Literal
from datetime import datetime
import uuid
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import joinedload, QueryableAttribute
from geoalchemy2 import Geometry, WKBElement # type: ignore
from sqlmodel import (SQLModel, Field, String, Relationship, JSON,
select)
import pandas as pd
import geopandas as gpd # type: ignore
from treetrail.utils import pandas_query, geopandas_query
from treetrail.config import Map, conf, App
from treetrail.database import db_session
class BaseModel(SQLModel):
    """Base for all treetrail models: adds (geo)pandas dataframe accessors."""

    @classmethod
    def selectinload(cls) -> list[Literal['*'] | QueryableAttribute[Any]]:
        # Relationships to eager-load in _get_df; subclasses may override
        return []

    @classmethod
    async def get_df(cls, **kwargs) -> pd.DataFrame:
        """Return the table content as a pandas DataFrame."""
        return await cls._get_df(pandas_query, **kwargs)

    @classmethod
    async def get_gdf(cls, **kwargs) -> gpd.GeoDataFrame:
        """Return the table content as a geopandas GeoDataFrame."""
        return await cls._get_df(geopandas_query, model=cls, **kwargs)  # type: ignore

    @classmethod
    async def _get_df(cls, method, *,
                      where=None, with_related=True, with_only_columns=[],
                      simplify_tolerance: float | None=None,
                      preserve_topology: bool | None=None,
                      **kwargs) -> pd.DataFrame | gpd.GeoDataFrame:
        """Run a select on this model and return it as a dataframe.

        :param method: pandas_query or geopandas_query (sync runner)
        :param where: optional SQLAlchemy where clause
        :param with_related: eager-load the relationships from selectinload()
        :param with_only_columns: restrict to these columns (+ primary key)
        :param simplify_tolerance: geometry simplification (geo queries only)
        :param preserve_topology: override conf.geo.simplify_preserve_topology
        """
        async with db_session() as session:
            if len(with_only_columns) == 0:
                query = select(cls)
            else:
                # Always include the primary key columns in the selection
                columns = set(with_only_columns)
                # TODO: user SQLModel model_fields instead of __table__
                columns.add(*(col.name for col in cls.__table__.primary_key.columns))  # type: ignore
                query = select(*(getattr(cls, col) for col in columns))
            if where is not None:
                query = query.where(where)
            ## Get the joined tables
            joined_tables = cls.selectinload()
            if with_related and len(joined_tables) > 0:
                query = query.options(*(joinedload(jt) for jt in joined_tables))
            df = await session.run_sync(method, query, **kwargs)
            if method is geopandas_query and simplify_tolerance is not None:
                df['geom'] = df['geom'].simplify(
                    simplify_tolerance / conf.geo.simplify_geom_factor,
                    preserve_topology=(conf.geo.simplify_preserve_topology
                                       if preserve_topology is None
                                       else preserve_topology)
                )
            ## Change column names to reflect the joined tables
            ## Leave the first columns unchanged, as their names come straight
            ## from the model's fields
            joined_columns = list(df.columns[len(cls.model_fields):])
            renames: dict[str, str] = {}
            ## Match column names with the joined tables
            ## Important: this assumes that orders of the joined tables
            ## and their columns is preserved by pandas' read_sql
            for joined_table in joined_tables:
                target = joined_table.property.target  # type: ignore
                target_name = target.name
                for col in target.columns:
                    ## Pop the column from the column list and make a new name
                    renames[joined_columns.pop(0)] = f'{target.schema}_{target_name}_{col.name}'
            df.rename(columns=renames, inplace=True)
            ## Finally, set the index of the df as the index of cls
            df.set_index([c.name for c in cls.__table__.primary_key.columns],  # type: ignore
                         inplace=True)
            return df
class TreeTrail(BaseModel, table=True):
    """Association table linking trees to trails (many-to-many)."""
    __tablename__: str = 'tree_trail'  # type: ignore
    tree_id: uuid.UUID | None = Field(
        default=None,
        foreign_key='tree.id',
        primary_key=True
    )
    trail_id: int | None = Field(
        default=None,
        foreign_key='trail.id',
        primary_key=True
    )
class Trail(BaseModel, table=True):
    """A trail: a named line on the map, with the trees along it."""
    __tablename__: str = "trail"  # type: ignore
    id: int = Field(primary_key=True)
    name: str
    description: str
    create_date: datetime = Field(default_factory=datetime.now)
    # 2D line, WGS84
    geom: Annotated[str, WKBElement] = Field(
        sa_type=Geometry('LINESTRING', srid=4326, dimension=2),
    )
    photo: str = Field(sa_type=String(250))  # type: ignore
    trees: list['Tree'] = Relationship(
        link_model=TreeTrail,
        back_populates="trails")
    # Optional restriction: only users holding this role can view the trail
    viewable_role_id: str | None = Field(foreign_key='role.name', index=True)
    viewable_role: 'Role' = Relationship(back_populates='viewable_trails')
    # __mapper_args__ = {"eager_defaults": True}
# Allowed tree life stage codes (young / mature adult / mature / old mature / ancient
# — presumably; confirm the exact meaning of each code with the domain owner)
life_stages = ('Y', 'MA', 'M', 'OM', 'A')
class Tree(BaseModel, table=True):
    """A tree: a located plant individual with photo and free-form attributes."""
    __tablename__: str = "tree"  # type: ignore
    id: uuid.UUID | None = Field(
        default_factory=uuid.uuid1,
        primary_key=True,
        index=True,
        nullable=False,
    )
    create_date: datetime = Field(default_factory=datetime.now)
    # ALTER TABLE tree ADD CONSTRAINT tree_plant_id_fkey FOREIGN KEY (plantekey_id) REFERENCES plant(id); # noqa: E501
    plantekey_id: str = Field(foreign_key='plant.id')
    # 2D point, WGS84
    geom: Annotated[str, WKBElement] = Field(
        sa_type=Geometry('POINT', srid=4326, dimension=2))
    photo: str | None = Field(sa_type=String(250))  # type: ignore
    height: float | None
    comments: str | None
    # ALTER TABLE public.tree ADD contributor_id varchar(50) NULL;
    # ALTER TABLE public.tree ADD CONSTRAINT contributor_fk FOREIGN KEY (contributor_id) REFERENCES public."user"(username);
    contributor_id: str = Field(foreign_key='user.username', index=True)
    contributor: 'User' = Relationship()
    # Optional restriction: only users holding this role can view the tree
    viewable_role_id: str | None = Field(foreign_key='role.name', index=True)
    viewable_role: 'Role' = Relationship(back_populates='viewable_trees')
    # CREATE EXTENSION hstore;
    # ALTER TABLE tree ADD COLUMN data JSONB;
    data: dict = Field(sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
                       default_factory=dict)  # type: ignore
    trails: list[Trail] = Relationship(
        link_model=TreeTrail,
        back_populates="trees")
    __mapper_args__ = {"eager_defaults": True}

    @classmethod
    def get_tree_insert_params(
            cls,
            plantekey_id: str,
            lng, lat,
            username,
            details: dict,
    ) -> dict:
        """Build the parameter dict for inserting a new tree.

        Known detail keys ('comments', 'height') are popped into their own
        columns; the remaining truthy details go to the JSON 'data' column.
        Note: `details` is mutated (popped) by this call.
        """
        params = {
            'plantekey_id': plantekey_id,
            'geom': f'POINT({lng} {lat})',
            'contributor_id': username
        }
        ## Consume some details in their respective field...
        if p := details.pop('comments', None):
            params['comments'] = p
        if p := details.pop('height', None):
            params['height'] = p
        # ... and store the rest in data
        params['data'] = {k: v for k, v in details.items() if v}
        return params
class UserRoleLink(SQLModel, table=True):
    """Association table linking users to roles (many-to-many)."""
    __tablename__: str = 'roles_users'  # type: ignore
    user_id: str | None = Field(
        default=None,
        foreign_key='user.username',
        primary_key=True
    )
    role_id: str | None = Field(
        default=None,
        foreign_key='role.name',
        primary_key=True
    )
class UserBase(BaseModel):
    """Public user fields, shared by the table model and API schemas."""
    username: str = Field(sa_type=String(50), primary_key=True)  # type: ignore
    full_name: str | None = None
    email: str | None = None
class User(UserBase, table=True):
    """User account, with hashed password and role memberships."""
    __tablename__: str = "user"  # type: ignore
    roles: list["Role"] = Relationship(back_populates="users",
                                       link_model=UserRoleLink)
    password: str
    disabled: bool = False
class UserWithRoles(UserBase):
    """API schema: public user fields plus the resolved roles."""
    roles: list['Role']
class Role(BaseModel, table=True):
    """Role, used for authorization (restricts visibility of map objects)."""
    __tablename__: str = "role"  # type: ignore
    name: str = Field(sa_type=String(50), primary_key=True)  # type: ignore
    users: list[User] = Relationship(back_populates="roles",
                                     link_model=UserRoleLink)
    viewable_trees: list[Tree] = Relationship(back_populates='viewable_role')
    viewable_zones: list['Zone'] = Relationship(back_populates='viewable_role')
    viewable_trails: list[Trail] = Relationship(back_populates='viewable_role')
class POI(BaseModel, table=True):
    """Point of interest on the map (anything other than a tree)."""
    __tablename__: str = "poi"  # type: ignore
    id: int = Field(primary_key=True)
    # Fixed: sa_type, not sa_column — Field(sa_column=...) expects a complete
    # sqlalchemy Column object, not a type; every sibling model uses sa_type.
    name: str = Field(sa_type=String(200))  # type: ignore
    description: str | None = None
    create_date: datetime = Field(default_factory=datetime.now)
    # 3D point (lon, lat, elevation), WGS84
    geom: Annotated[str, WKBElement] = Field(
        sa_type=Geometry('POINTZ', srid=4326, dimension=3))
    photo: str = Field(sa_type=String(250))  # type: ignore
    type: str = Field(sa_type=String(25))  # type: ignore
    # Free-form extra attributes, stored as JSONB
    data: dict = Field(sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
                       default_factory=dict)  # type: ignore
class Zone(BaseModel, table=True):
    """A named area on the map (e.g. forest, master-plan zone)."""
    __tablename__: str = "zone"  # type: ignore
    id: int = Field(primary_key=True)
    name: str = Field(sa_type=String(200))  # type:ignore
    description: str
    create_date: datetime = Field(default_factory=datetime.now)
    geom: Annotated[str, WKBElement] = Field(
        sa_type=Geometry('MULTIPOLYGON', srid=4326))
    photo: str | None = Field(sa_type=String(250))  # type:ignore
    # Zone category; set from the source layer name on import (see import_cli)
    type: str = Field(sa_type=String(30))  # type:ignore
    # Free-form extra attributes, stored as JSONB
    data: dict | None = Field(sa_type=MutableDict.as_mutable(JSONB),  # type:ignore
                              default_factory=dict)  # type:ignore
    # Optional restriction: only users holding this role can view the zone
    viewable_role_id: str | None = Field(foreign_key='role.name', index=True)
    viewable_role: 'Role' = Relationship(back_populates='viewable_zones')
class MapStyle(BaseModel, table=True):
    """Per-layer map style: MapLibre 'paint' and 'layout' JSON specs."""
    __tablename__: str = "map_style"  # type: ignore
    id: int = Field(primary_key=True)
    layer: str = Field(sa_type=String(100), nullable=False)  # type:ignore
    paint: dict[str, Any] | None = Field(sa_type=JSON(none_as_null=True))  # type:ignore
    layout: dict[str, Any] | None = Field(sa_type=JSON(none_as_null=True))  # type:ignore
class VersionedComponent(BaseModel):
    """API schema: version string of one component (client or server)."""
    version: str
class BaseMapStyles(BaseModel):
    """API schema: available base map styles (embedded names, external URLs)."""
    embedded: list[str]
    external: dict[str, str]
class Bootstrap(BaseModel):
    """API schema: everything the client needs at startup."""
    client: VersionedComponent
    server: VersionedComponent
    app: App
    # The authenticated user, or None for anonymous sessions
    user: UserWithRoles | None  # type:ignore
    map: Map
    baseMapStyles: BaseMapStyles

488
src/treetrail/plantekey.py Normal file
View file

@ -0,0 +1,488 @@
import tarfile
from io import BytesIO
from logging import getLogger
from pathlib import Path
import numpy as np
import pandas as pd
from aiohttp_client_cache import FileBackend
from aiohttp_client_cache.session import CachedSession
from fastapi import Depends, FastAPI, HTTPException, Response, status
from fastapi.responses import ORJSONResponse, PlainTextResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
from sqlalchemy import String
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.mutable import MutableDict
from sqlmodel import Field, Relationship, select
from treetrail.config import get_cache_dir
from treetrail.database import db_session, fastapi_db_session
from treetrail.models import BaseModel, User
from treetrail.security import get_current_active_user
from treetrail.utils import read_sql, mkdir
logger = getLogger(__name__)
cache_expiry = 3600 * 24
thumbnail_size = (200, 200)
# cache = SQLiteBackend(get_cache_dir() / 'plantekey')
cache = FileBackend(get_cache_dir() / "plantekey" / "http")
pek_app = FastAPI(
default_response_class=ORJSONResponse,
)
def get_plantekey_api_url(id):
    """Return the plantekey.com API URL giving the details of one plant."""
    base = "https://plantekey.com/api.php?action=plant&name="
    return base + str(id)
def get_storage_root():
    """Root directory of the local plantekey storage (under the cache dir)."""
    return get_cache_dir() / "plantekey"
def get_storage(f):
    """Path of an entry inside the local plantekey storage."""
    root = get_storage_root()
    return root / f
def get_img_root() -> Path:
    """Directory holding the full-size plant images."""
    return get_storage_root() / "img"
def get_img_path(img: str):
    """Path of one full-size plant image file."""
    root = get_img_root()
    return root / img
def get_thumbnail_root() -> Path:
    """Directory holding the generated thumbnails."""
    return get_storage_root() / "thumbnails"
def get_thumbnail_tar_path():
    """Path of the tar archive bundling all default thumbnails."""
    return get_storage_root() / "thumbnails.tar"
def get_img_type_root():
    """Directory holding the plant type icons."""
    return get_storage_root() / "type"
def get_img_type_path(type: str):
    """Path of the PNG icon for one plant type."""
    filename = type + ".png"
    return get_img_type_root() / filename
def setup(ap):
    """
    Create empty directories if needed.
    Intended to be used at startup
    """
    for directory in (get_img_root(), get_thumbnail_root(), get_img_type_root()):
        directory.mkdir(parents=True, exist_ok=True)
class Plant(BaseModel, table=True):
    """
    Record of a Plantekey plant
    """
    __tablename__: str = "plant"  # type: ignore

    def __str__(self):
        return str(self.id)

    def __repr__(self):
        return f"treetrail.database.Plant: {self.id}"

    # Slug built from the name (lower case, spaces->dashes); see fetch_browse()
    id: str | None = Field(primary_key=True, default=None)
    # Numeric id on the plantekey.com side
    ID: int
    family: str = Field(sa_type=String(100))  # type: ignore
    name: str = Field(sa_type=String(200))  # type: ignore
    description: str | None
    habit: str | None
    landscape: str | None
    uses: str | None
    planting: str | None
    propagation: str | None
    type: str = Field(sa_type=String(30))  # type: ignore
    img: str = Field(sa_type=String(100))  # type: ignore
    element: str = Field(sa_type=String(30))  # type: ignore
    isOnMap: str | None = Field(sa_type=String(10))  # type: ignore
    # Vernacular / alternate names
    english: str | None = Field(sa_type=String(100))  # type: ignore
    hindi: str | None = Field(sa_type=String(100))  # type: ignore
    tamil: str | None = Field(sa_type=String(100))  # type: ignore
    spiritual: str | None = Field(sa_type=String(150))  # type: ignore
    # Botanical characteristics (cleaned up in update_details())
    woody: bool
    latex: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_style: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_type: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_arrangement: str | None = Field(sa_type=String(20))  # type: ignore
    leaf_aroma: bool | None
    leaf_length: float | None
    leaf_width: float | None
    flower_color: str | None = Field(sa_type=String(20))  # type: ignore
    flower_size: float | None
    flower_aroma: bool | None
    fruit_color: str | None = Field(sa_type=String(20))  # type: ignore
    fruit_size: float | None
    fruit_type: str | None = Field(sa_type=String(20))  # type: ignore
    thorny: str | None = Field(sa_type=String(20))  # type: ignore
    images: dict = Field(
        sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
        default_factory=dict,
    )  # type: ignore
    # CREATE EXTENSION hstore;
    # ALTER TABLE tree ADD COLUMN data JSONB;
    data: dict = Field(
        sa_type=MutableDict.as_mutable(JSONB),  # type: ignore
        default_factory=dict,
    )  # type: ignore
    # __mapper_args__ = {"eager_defaults": True}
class PlantImage(BaseModel, table=True):
    """One image of a plant, as listed by the plantekey API."""
    __tablename__: str = "plant_image"  # type: ignore
    id: int | None = Field(primary_key=True, default=None)
    plant_id: str = Field(foreign_key="plant.id")
    plant: Plant = Relationship()
    caption: str = Field(sa_type=String(50))  # type: ignore
    # Whether this is the plant's main image (normalized from "0"/"1"
    # in update_details())
    IsDefault: bool
    # Image file name on the plantekey server; also the local file name
    src: str = Field(sa_type=String(100))  # type: ignore

    def get_thumbnail_path(self) -> Path:
        """Local path of this image's thumbnail file."""
        return get_thumbnail_root() / self.src  # type: ignore
class Plantekey(BaseModel, table=True):
    """
    Details for the plantekey data, like symbols for the map
    """
    ## CREATE TABLE plantekey (id VARCHAR(100) PRIMARY KEY NOT NULL, symbol CHAR(1));
    ## GRANT ALL on TABLE plantekey to treetrail ;
    __tablename__: str = "plantekey"  # type: ignore
    id: str | None = Field(primary_key=True, default=None)
    # Single glyph shown on the map for this plant
    symbol: str = Field(sa_type=String(1))  # type: ignore
    iso: str = Field(sa_type=String(100))  # type: ignore
async def fetch_browse():
    """Fetch the full plant list from plantekey.com.

    Returns a dataframe with one row per plant; the 'id' column is a
    slug derived from the name (lower case, spaces replaced by dashes).
    On a fetch/parse error, logs and returns an empty dataframe.
    """
    logger.info("Fetching list of plants (browse) from plantekey.com...")
    plantekey_url = "https://www.plantekey.com/api.php?action=browse"
    async with CachedSession(cache=cache) as session:
        async with session.get(plantekey_url) as response:
            try:
                # content_type=None: the server does not send application/json
                content = await response.json(content_type=None)
            except Exception as err:
                logger.warning("Error browsing plantekey")
                logger.exception(err)
                content = {}
    df = pd.DataFrame(
        data=content,
        columns=[
            "ID",
            "english",
            "family",
            "hindi",
            "img",
            "name",
            "spiritual",
            "tamil",
            "type",
        ],
    )
    df["id"] = df["name"].str.replace(" ", "-").str.lower()
    df["ID"] = df["ID"].astype(int)
    # async with db_session() as session:
    #     array = df.apply(lambda item: Plant(**item), axis=1)
    #     await session.exec(delete(Plant))
    #     session.add_all(array)
    #     await session.commit()
    return df
async def get_all():
    """
    Return the list of plants, with a local cache
    """
    ## TODO: implement cache mechanism
    ## Store in db?
    # path = get_storage('all')
    # if path.stat().st_ctime - time() > cache_expiry:
    #     return fetch_browse()
    # return pd.read_feather(path)
    async with db_session() as session:
        # Renamed local (was 'all', shadowing the builtin)
        plants = await session.exec(select(Plant))
        return pd.DataFrame(plants.all())
async def get_local_details() -> pd.DataFrame:
    """
    Return a dataframe of the plantekey table, containing extra information
    like symbols for the map
    """
    async with db_session() as session:
        records = await session.exec(select(Plantekey.id, Plantekey.symbol))
        df = pd.DataFrame(records.all())
    if len(df) > 0:
        df.set_index("id", inplace=True)
    return df
async def get_details():
    """
    Return the details of plants, as stored on the local server.

    Merges the plant table with the local plantekey extras (map symbols);
    plants without a symbol get the default glyph.
    """
    local_details = await get_local_details()
    async with db_session() as session:
        con = await session.connection()
        df = await con.run_sync(read_sql, select(Plant))
    # df.set_index('id', inplace=True)
    if len(local_details) > 0:
        ndf = df.merge(local_details, left_on="id", right_index=True, how="left")
        # Plain assignment instead of chained `fillna(..., inplace=True)`:
        # the chained form is deprecated in pandas 2.x and may act on a copy
        ndf["symbol"] = ndf["symbol"].fillna("\ue034")
        return ndf
    else:
        return df
async def fetch_image(img: PlantImage):
    """Download one plant image, save it locally and build its thumbnail.

    Failures (network error, non-image payload) are logged and skipped.
    """
    url = f"https://www.plantekey.com/admin/images/plants/{img.src}"
    logger.info(f"Fetching image at {url}")
    async with CachedSession(cache=cache) as session:
        try:
            resp = await session.get(url)
        except Exception as err:
            # logger.warn is deprecated: use logger.warning
            logger.warning(f"Cannot get image for {img.plant_id} ({url}): {err}")
            return
        img_data = await resp.content.read()
        # 0xFF is the first byte of a JPEG (SOI marker) — presumably a cheap
        # sanity check that the server returned an image, not an HTML error
        # page. Also guard against an empty body (would raise IndexError).
        if not img_data or img_data[0] != 255:
            logger.warning(f"Image for {img.plant_id} at {url} is not an image")
            return
        with open(get_img_path(img.src), "bw") as f:  # type: ignore
            f.write(img_data)
        update_thumbnail(BytesIO(img_data), img)
async def fetch_images():
    """
    Fetch all the images from the plantekey server
    """
    logger.info("Fetching all images from plantekey server")
    async with db_session() as session:
        images = await session.exec(select(PlantImage))
        # NOTE(review): iteration yields row tuples here, with the PlantImage
        # at index 0 — confirm against the session.exec result type
        for img_rec in images:
            await fetch_image(img_rec[0])
def update_thumbnail(data, img: PlantImage):
    """Generate and save the thumbnail of one image; failures are logged."""
    try:
        thumbnail = Image.open(data)
        thumbnail.thumbnail(thumbnail_size)
        thumbnail.save(img.get_thumbnail_path())
    except Exception as error:
        logger.warning(
            f"Cannot create thumbnail for {img.plant_id} ({img.src}): {error}"
        )
async def fetch_types():
    """Download the icon of every plant type present in the plant list."""
    plants = await get_all()
    for plant_type in plants.type.unique():
        await fetch_type(plant_type)
async def fetch_type(type):
    """Download the map icon of one plant type and store it locally."""
    plantekey_url = f"https://plantekey.com/img/icons/plant_types/{type}.png"
    async with CachedSession(cache=cache) as session:
        async with session.get(plantekey_url) as data:
            img_data = await data.content.read()
            with open(get_img_type_path(type), "bw") as icon_file:
                icon_file.write(img_data)
async def update_details(df: pd.DataFrame):
    """
    Update the server database from plantekey details.

    Fetches the detail record of every plant in *df*, cleans the values
    to match the DB column types, then replaces the whole content of the
    plant and plant_image tables in one transaction.
    """
    all = {}
    images = {}
    for id in df["id"]:
        try:
            all[id], images[id] = await fetch_detail(id)
        except Exception as err:
            # Best effort: a single failing plant does not abort the update
            logger.warning(f"Error fetching details for {id}: {err}")
    df_details = pd.DataFrame.from_dict(all, orient="index")
    df_details.index.name = "id"
    df_details["ID"] = df_details["ID"].astype(int)
    ## Cleanup according to DB data types
    for float_col_name in ("leaf_width", "leaf_length", "flower_size", "fruit_size"):
        df_details[float_col_name] = pd.to_numeric(
            df_details[float_col_name], errors="coerce", downcast="float"
        )
    for bool_col_name in ("woody", "leaf_aroma", "flower_aroma"):
        df_details[bool_col_name] = (
            df_details[bool_col_name].replace({"Yes": True, "No": False}).astype(bool)
        )
    # TODO: migrate __table__ to SQLModel, use model_fields
    for str_col_name in [
        c.name
        for c in Plant.__table__.columns  # type: ignore
        if isinstance(c.type, String) and not c.primary_key
    ]:  # type: ignore
        # NaN is not a valid value for nullable string columns: use None
        df_details[str_col_name] = df_details[str_col_name].replace([np.nan], [None])
    async with db_session() as session:
        plants_array = df_details.reset_index().apply(
            lambda item: Plant(**item), axis=1
        )  # type: ignore
        # Full replace: delete all existing plants, then insert the new ones
        existing_plants = await session.exec(select(Plant))
        for existing_plant in existing_plants.all():
            await session.delete(existing_plant)
        session.add_all(plants_array)
        await session.flush()
        ## Images
        existing_plant_images = await session.exec(select(PlantImage))
        for existing_plant_image in existing_plant_images.all():
            await session.delete(existing_plant_image)
        images_array: list[PlantImage] = []
        for plant_id, plant_images in images.items():
            images_array.extend(
                [PlantImage(plant_id=plant_id, **image) for image in plant_images]
            )
        # Normalize the API's "0"/"1" strings to booleans
        for image in images_array:
            image.IsDefault = False if image.IsDefault == "0" else True  # type: ignore
        session.add_all(images_array)
        await session.commit()
async def fetch_detail(id):
    """Fetch one plant's detail record and image list from plantekey.com.

    Returns a (detail, images) tuple; *detail* merges the "plant" and
    "characteristics" sections of the API response.
    """
    logger.info(f"Fetching details from plantekey.com for {id}...")
    async with CachedSession(cache=cache) as session:
        async with session.get(get_plantekey_api_url(id)) as response:
            result = await response.json(content_type=None)
    ## Sanitize: drop the duplicate decimal keys the API ships alongside names
    plant = {k: v for k, v in result["plant"].items() if not k.isdecimal()}
    image_list = [
        {k: v for k, v in image.items() if not k.isdecimal()}
        for image in result["images"]
    ]
    ## Merge dicts, Python 3.9
    detail = plant | (result["characteristics"] or {})
    return detail, image_list
async def create_thumbnail_archive():
    """
    Create a tar file with all thumbnails of plants
    """
    logger.info("Generating thumbnails and tar file")
    async with db_session() as session:
        image_records = await session.exec(select(PlantImage))
        with tarfile.open(str(get_thumbnail_tar_path()), "w") as tar:
            for record in image_records:
                plant_image: PlantImage = record[0]
                thumbnail_path = plant_image.get_thumbnail_path()
                # Only default images get a thumbnail in the archive
                if plant_image.IsDefault and thumbnail_path.is_file():
                    tar.add(thumbnail_path)
    logger.info(
        "Generation of thumbnails and tar file "
        + f"({get_thumbnail_tar_path()}) finished"
    )
@pek_app.get("/updateData")
async def updateData(user: User = Depends(get_current_active_user)):
    """
    Get list and details of all plants from plantekey

    On failure, logs the full traceback and returns a 500 with the
    exception message.
    """
    try:
        df = await fetch_browse()
        await update_details(df)
    except Exception as error:
        # logger.exception already records the message and the traceback;
        # the extra logger.error call only duplicated the log line
        logger.exception(error)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            # Fix: most exceptions have no .code attribute, so the old
            # f"...{error.code}" raised AttributeError while building the
            # 500 response; interpolating the exception itself is always safe
            detail=f"Update failed: {error}",
        )
    else:
        logger.info("updateData finished")
        return {"status": 0, "message": "Server updated"}
@pek_app.get("/updateImages")
async def updateImages(user: User = Depends(get_current_active_user)):
    """
    Get images from plantekey, using the list of plants
    fetched with updateData
    """
    # Run both steps in order; the first failure aborts with a 500
    steps = (
        (fetch_images, "Update of images failed"),
        (create_thumbnail_archive,
         "Update of images failed while creating thumbnails"),
    )
    for step, failure_message in steps:
        try:
            await step()
        except Exception as error:
            logger.exception(error)
            logger.error(error)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"{failure_message}: {error}",
            )
    return {"status": 0, "message": "Server updated"}
@pek_app.get("/plant/info/{id}")
async def get_plantekey(
    id: str,
    db_session: fastapi_db_session,
) -> Plant | None:  # type: ignore
    """
    Get details of a specific plant
    """
    plant = await db_session.get(Plant, id)
    if plant is not None:
        return plant
    # Unknown id: respond 404 rather than returning null
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
    )
@pek_app.get("/details")
async def get_plantekey_all_details():
    """
    Get all plants
    """
    details = await get_details()
    return Response(
        content=details.to_json(orient="records"),
        media_type="application/json",
    )
@pek_app.get("/details/csv", response_class=PlainTextResponse)
async def get_plantekey_all_details_csv():
    """
    Get all plants, return CSV
    """
    details = await get_details()
    return Response(content=details.to_csv(), media_type="text/csv")
# Static file mounts: original images, generated thumbnails and type icons.
# mkdir() ensures each target directory exists before StaticFiles checks it.
pek_app.mount("/img", StaticFiles(directory=mkdir(get_img_root())), name="plantekey_img")
pek_app.mount("/thumb", StaticFiles(directory=mkdir(get_thumbnail_root())), name="plantekey_thumbnail")
pek_app.mount("/type", StaticFiles(directory=mkdir(get_img_type_root())), name="plantekey_type")

163
src/treetrail/security.py Normal file
View file

@ -0,0 +1,163 @@
from datetime import datetime, timedelta
from passlib.context import CryptContext
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.orm import joinedload
from pydantic import BaseModel
from jose import JWTError, jwt
from sqlmodel import select
from treetrail.config import conf
from treetrail.models import User, Role, UserRoleLink
from treetrail.database import db_session
# openssl rand -hex 32
# import secrets
# SECRET_KEY = secrets.token_hex(32)
ALGORITHM: str = "HS256"
class Token(BaseModel):
    """OAuth2 token response: the encoded JWT and its type (e.g. "bearer")."""
    access_token: str
    token_type: str
class TokenData(BaseModel):
    """Claims extracted from a decoded JWT (the "sub" claim)."""
    username: str | None = None
# class User(BaseModel):
# username: str
# email: str | None = None
# full_name: str | None = None
# disabled: bool = False
# auto_error=False: missing credentials yield token=None instead of an
# immediate 401, so anonymous access can be handled by the dependencies below
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
# sha256_crypt preferred for new hashes; bcrypt kept to verify existing ones
pwd_context = CryptContext(schemes=["sha256_crypt", "bcrypt"], deprecated="auto")
# Shared 401 response raised wherever token validation fails
credentials_exception = HTTPException(
    status_code=status.HTTP_401_UNAUTHORIZED,
    detail="Could not validate credentials",
    headers={"WWW-Authenticate": "Bearer"},
)
def get_password_hash(password: str) -> str:
    """Hash a clear-text password with the configured passlib context."""
    return pwd_context.hash(password)
async def delete_user(username) -> None:
    """Delete *username* from the database; exit if it does not exist."""
    async with db_session() as session:
        user_in_db = await get_user(username)
        if user_in_db is None:
            raise SystemExit(f'User {username} does not exist in the database')
        await session.delete(user_in_db)
        # Commit explicitly, like enable_user/create_user do; without it the
        # deletion is silently lost unless the session context auto-commits
        await session.commit()
async def enable_user(username, enable=True) -> None:
    """Set the *disabled* flag of an existing user (enable=False disables)."""
    async with db_session() as session:
        account = await get_user(username)
        if account is None:
            raise SystemExit(f'User {username} does not exist in the database')
        account.disabled = not enable  # type: ignore
        session.add(account)
        await session.commit()
async def create_user(username: str, password: str,
                      full_name: str | None = None,
                      email: str | None = None) -> User:
    """Create a user, or update details and password if the username exists.

    Returns the refreshed, persisted User instance.
    """
    async with db_session() as session:
        existing = await get_user(username)
        if existing is not None:
            # Update in place: refresh the details and re-hash the password
            existing.full_name = full_name  # type: ignore
            existing.email = email  # type: ignore
            existing.password = get_password_hash(password)  # type: ignore
            await session.commit()
            await session.refresh(existing)
            return existing
        new_user = User(
            username=username,
            password=get_password_hash(password),
            full_name=full_name,
            email=email,
            disabled=False,
        )
        session.add(new_user)
        await session.commit()
        await session.refresh(new_user)
        return new_user
async def get_user(username: str) -> User | None:  # type: ignore
    """Return the user with roles eagerly loaded, or None if not found."""
    async with db_session() as session:
        result = await session.exec(
            select(User)
            .where(User.username == username)
            .options(joinedload(User.roles))  # type: ignore
        )
        return result.first()
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a clear-text password against its stored passlib hash."""
    return pwd_context.verify(plain_password, hashed_password)
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User | None:  # type: ignore
    """Resolve the bearer token to a user; None for absent/invalid tokens."""
    # The JS client may literally send "null" when it has no token
    if token is None or token == 'null':
        return None
    try:
        payload = jwt.decode(token, conf.security.secret_key, algorithms=[ALGORITHM])
        subject: str = payload.get("sub", '')
        if subject == '':
            raise credentials_exception
        token_data = TokenData(username=subject)
    except JWTError:
        # Bad signature / expired token: treat as anonymous
        return None
    return await get_user(username=token_data.username)  # type: ignore
async def authenticate_user(username: str, password: str):
    """Return the user when username and password match, else False."""
    user = await get_user(username)
    if user and verify_password(password, user.password):
        return user
    return False
async def get_current_active_user(
        current_user: User | None = Depends(get_current_user)) -> User:  # type: ignore
    """Like get_current_user, but reject anonymous and disabled users (400)."""
    if current_user is None or current_user.disabled:
        detail = "Not authenticated" if current_user is None else "Inactive user"
        raise HTTPException(status_code=400, detail=detail)
    return current_user
async def get_current_roles(user: User | None = Depends(get_current_user)) -> list[Role]:  # type: ignore
    """Roles of the authenticated user; an empty list when anonymous."""
    return [] if user is None else user.roles
def create_access_token(data: dict, expires_delta: timedelta):
    """Encode *data* as a signed JWT that expires after *expires_delta*.

    The claims dict is copied so the caller's argument is not mutated.
    """
    # Local import keeps this fix self-contained (timezone is not imported
    # at module level)
    from datetime import timezone
    to_encode = data.copy()
    # datetime.utcnow() is naive and deprecated since 3.12; use an aware
    # UTC timestamp so the "exp" claim is unambiguous
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode,
                             conf.security.secret_key,
                             algorithm=ALGORITHM)
    return encoded_jwt
async def add_user_role(username: str, role_id: str):
    """Link an existing user to a role; exit if the user is missing."""
    async with db_session() as session:
        user = await get_user(username)
        if user is None:
            raise SystemExit(f'User {username} does not exist in the database')
        session.add(UserRoleLink(user_id=user.username, role_id=role_id))
        await session.commit()
async def add_role(role_id: str) -> Role:
    """Create a role named *role_id* and return the persisted instance."""
    async with db_session() as session:
        new_role = Role(name=role_id)
        session.add(new_role)
        await session.commit()
        await session.refresh(new_role)
        return new_role

307
src/treetrail/tiles.py Normal file
View file

@ -0,0 +1,307 @@
"""
MBTiles server
Instructions (example):
cd map ## Matches tilesBaseDir in config
curl http://download.geofabrik.de/asia/india/southern-zone-latest.osm.pbf -o osm.pbf
TILEMAKER_SRC=/home/phil/gisaf_misc/tilemaker
# Or, for fish
set TILEMAKER_SRC /home/phil/gisaf_misc/tilemaker
cp $TILEMAKER_SRC/resources/config-openmaptiles.json .
cp $TILEMAKER_SRC/resources/process-openmaptiles.lua .
## Edit config-openmaptiles.json, eg add in "settings":
# "bounding_box":[79.76777,11.96541,79.86909,12.04497]
vi config-openmaptiles.json
## Generate mbtile database:
tilemaker \
--config config-openmaptiles.json \
--process process-openmaptiles.lua \
--input osm.pbf \
--output osm.mbtiles
## Generate static tiles files
mkdir osm
tilemaker \
--config config-openmaptiles.json \
--process process-openmaptiles.lua \
--input osm.pbf \
--output osm
----
Get the style from https://github.com/openmaptiles, eg.
curl -o osm-bright-full.json https://raw.githubusercontent.com/openmaptiles/osm-bright-gl-style/master/style.json
## Minify json:
python -c 'import json, sys;json.dump(json.load(sys.stdin), sys.stdout)' < osm-bright-full.json > osm-bright.json
----
Get the sprites from openmaptiles:
cd tiles ## Matches tilesSpriteBaseDir in config
curl -O 'https://openmaptiles.github.io/osm-bright-gl-style/sprite.png'
curl -O 'https://openmaptiles.github.io/osm-bright-gl-style/sprite.json'
curl -O 'https://openmaptiles.github.io/osm-bright-gl-style/sprite@2x.png'
curl -O 'https://openmaptiles.github.io/osm-bright-gl-style/sprite@2x.json'
""" # noqa: E501
import logging
import tarfile
from pathlib import Path
from json import loads, dumps
from io import BytesIO
from fastapi import FastAPI, Response, HTTPException, Request
from fastapi.staticfiles import StaticFiles
import aiosqlite
from treetrail.config import conf
from treetrail.models import BaseMapStyles
from treetrail.utils import mkdir
# Dedicated logger and sub-application for the embedded tile server
logger = logging.getLogger('treetrail tile server')
tiles_app = FastAPI()
def get_tiles_tar_path(style):
    """Path of the static-tiles tar archive for *style* in the app tree."""
    ## FIXME: use conf
    app_dir = Path(__file__).parent.parent / 'treetrail-app'
    return app_dir / 'src' / 'data' / 'tiles' / f'{style}.tar'
# Attribution HTML required by the OpenStreetMap license, embedded in styles
OSM_ATTRIBUTION = '<a href=\"http://www.openstreetmap.org/about/" target="_blank">' \
    '&copy; OpenStreetMap contributors</a>'
class MBTiles:
    """One .mbtiles (SQLite) tile set: metadata, tiles, and its GL style."""

    def __init__(self, file_path, style_name):
        self.file_path = file_path
        self.name = style_name
        # Tile row order scheme advertised in the generated style
        self.scheme = 'tms'
        # Weak ETag derived from the file's mtime (not served yet, see get_tile)
        self.etag = f'W/"{hex(int(file_path.stat().st_mtime))[2:]}"'
        self.style_layers: list[dict]
        ## FIXME: use conf
        # Load the layer definitions from the app's style.json, if present
        try:
            with open(Path(__file__).parent.parent / 'treetrail-app' / 'src' /
                      'assets' / 'map' / 'style.json') as f:
                style = loads(f.read())
            self.style_layers = style['layers']
        except FileNotFoundError:
            self.style_layers = []
        # Point every layer at the single vector source defined in get_style
        for layer in self.style_layers:
            if 'source' in layer:
                layer['source'] = 'treeTrailTiles'

    async def connect(self):
        """Open the SQLite file and read its metadata table into a dict."""
        self.db = await aiosqlite.connect(self.file_path)
        self.metadata = {}
        try:
            async with self.db.execute('select name, value from metadata') as cursor:
                async for row in cursor:
                    self.metadata[row[0]] = row[1]
        except aiosqlite.DatabaseError as err:
            logger.warning(f'Cannot read {self.file_path}, will not be able'
                           f' to serve tiles (error: {err.args[0]})')
        ## Fix types
        # NOTE(review): maxzoom/minzoom are converted unconditionally below;
        # if the metadata table could not be read this raises KeyError —
        # confirm whether such files should be skipped instead
        if 'bounds' in self.metadata:
            self.metadata['bounds'] = [float(v)
                                       for v in self.metadata['bounds'].split(',')]
        self.metadata['maxzoom'] = int(self.metadata['maxzoom'])
        self.metadata['minzoom'] = int(self.metadata['minzoom'])
        logger.info(f'Serving tiles in {self.file_path}')

    async def get_style(self, request: Request):
        """
        Generate on the fly the style
        """
        # The base URL either mirrors the incoming request or comes from conf
        if conf.tiles.useRequestUrl:
            base_url = str(request.base_url).removesuffix("/")
        else:
            base_url = conf.tiles.spriteBaseUrl
        base_tiles_url = f"{base_url}/tiles/{self.name}"
        scheme = self.scheme
        resp = {
            'basename': self.file_path.stem,
            #'center': self.center,
            'description': f'Extract of {self.file_path.stem} from OSM by Gisaf',
            'format': self.metadata['format'],
            'id': f'gisaftiles_{self.name}',
            'maskLevel': 5,
            'name': self.name,
            #'pixel_scale': 256,
            #'planettime': '1499040000000',
            'tilejson': '2.0.0',
            'version': 8,
            'glyphs': "/assets/fonts/glyphs/{fontstack}/{range}.pbf",
            'sprite': f"{base_url}{conf.tiles.spriteUrl}",
            'sources': {
                # Single vector source; __init__ rewrites all layers to use it
                'treeTrailTiles': {
                    'type': 'vector',
                    'tiles': [
                        f'{base_tiles_url}/{{z}}/{{x}}/{{y}}.pbf',
                    ],
                    'maxzoom': self.metadata['maxzoom'],
                    'minzoom': self.metadata['minzoom'],
                    'bounds': self.metadata['bounds'],
                    'scheme': scheme,
                    'attribution': OSM_ATTRIBUTION,
                    'version': self.metadata['version'],
                }
            },
            'layers': self.style_layers,
        }
        return resp

    async def get_tile(self, z, x, y):
        """Return the raw tile blob for (z, x, y), or None when absent."""
        async with self.db.execute(
                'select tile_data from tiles where zoom_level=? ' \
                'and tile_column=? and tile_row=?', (z, x, y)) as cursor:
            async for row in cursor:
                return row[0]

    async def get_all_tiles_tar(self, style, request):
        """Build an in-memory tar of every tile + the style + the sprites."""
        s = 0  # total tile payload bytes (for logging)
        n = 0  # number of tiles added (for logging)
        buf = BytesIO()
        with tarfile.open(fileobj=buf, mode='w') as tar:
            ## Add tiles
            async with self.db.execute('select zoom_level, tile_column, ' \
                    'tile_row, tile_data from tiles') as cursor:
                async for row in cursor:
                    z, x, y, tile = row
                    tar_info = tarfile.TarInfo()
                    tar_info.path = f'{style}/{z}/{x}/{y}.pbf'
                    tar_info.size = len(tile)
                    tar.addfile(tar_info, BytesIO(tile))
                    logger.debug(f'Added {style}/{z}/{x}/{y} ({len(tile)})')
                    n += 1
                    s += len(tile)
            logger.info(f'Added {n} files ({s} bytes)')
            ## Add style
            tar_info = tarfile.TarInfo()
            tar_info.path = f'style/{style}'
            style_definition = await self.get_style(request)
            style_data = dumps(style_definition, check_circular=False).encode('utf-8')
            tar_info.size = len(style_data)
            tar.addfile(tar_info, BytesIO(style_data))
            ## Add sprites ex. /tiles/sprite/sprite.json and /tiles/sprite/sprite.png
            tar.add(conf.tiles.spriteBaseDir, 'sprite')
        ## Extract
        buf.seek(0)
        ## XXX: Could write to file:
        #file_path = get_tiles_tar_path(style)
        return buf.read()
class MBTilesRegistry:
    """Keeps one connected MBTiles instance per .mbtiles file found on disk."""
    mbtiles: dict[str, MBTiles]

    async def setup(self, app):
        """
        Read all mbtiles, construct styles
        """
        self.mbtiles = {}
        for mbtiles_path in Path(conf.tiles.baseDir).glob('*.mbtiles'):
            tile_set = MBTiles(mbtiles_path, mbtiles_path.stem)
            self.mbtiles[mbtiles_path.stem] = tile_set
            await tile_set.connect()

    async def shutdown(self, app):
        """
        Tear down the connection to the mbtiles files
        """
        for tile_set in self.mbtiles.values():
            await tile_set.db.close()
# Tile responses: the blobs stored in the mbtiles file are already gzipped
gzip_headers = {
    'Content-Encoding': 'gzip',
    'Content-Type': 'application/octet-stream',
}
# Headers for the all-tiles tar download
tar_headers = {
    'Content-Type': 'application/x-tar',
}
@tiles_app.get('/styles')
async def get_styles() -> BaseMapStyles:
    """Styles for the map background. There are 2 types:
    - found on the embedded tiles server, that can be used offline
    - external providers, defined in the config with a simple url
    """
    embedded_names = list(registry.mbtiles.keys())
    return BaseMapStyles(external=conf.mapStyles, embedded=embedded_names)
@tiles_app.get('/{style_name}/{z}/{x}/{y}.pbf')
async def get_tile(style_name: str, z: int, x: int, y: int):
    """
    Return the specific tile
    """
    ## TODO: implement etag
    #if request.headers.get('If-None-Match') == mbtiles.etag:
    #    request.not_modified = True
    #    return web.Response(body=None)
    #request.response_etag = mbtiles.etag
    if style_name not in registry.mbtiles:
        raise HTTPException(status_code=404)
    mbtiles = registry.mbtiles[style_name]
    try:
        tile = await mbtiles.get_tile(z, x, y)
    except Exception as err:
        logger.info(f'Cannot get tile {z}, {x}, {y}')
        logger.exception(err)
        raise HTTPException(status_code=404)
    return Response(content=tile,
                    media_type="application/json",
                    headers=gzip_headers)
@tiles_app.get('/{style_name}/all.tar')
async def get_tiles_tar(style_name, request: Request):
    """
    Get a tar file with all the tiles. Typically, used to feed into
    the browser's cache for offline use.
    """
    ## Fix: respond 404 for an unknown style instead of crashing with a
    ## KeyError (consistent with get_tile and get_style)
    if style_name not in registry.mbtiles:
        raise HTTPException(status_code=404)
    mbtiles: MBTiles = registry.mbtiles[style_name]
    tar = await mbtiles.get_all_tiles_tar(style_name, request)
    return Response(content=tar, media_type="application/x-tar", headers=tar_headers)
#@tiles_app.get('/sprite/{name:\S+}')
#async def get_sprite(request):
@tiles_app.get('/style/{style_name}')
async def get_style(style_name: str, request: Request):
    """
    Return the base style.
    """
    try:
        mbtiles = registry.mbtiles[style_name]
    except KeyError:
        # Unknown style: 404, same as get_tile
        raise HTTPException(status_code=404)
    return await mbtiles.get_style(request)
# Module-level registry; populated by MBTilesRegistry.setup() at startup
registry = MBTilesRegistry()
# Sprite and pre-rendered OSM tile directories, created on demand by mkdir()
tiles_app.mount("/sprite",
                StaticFiles(directory=mkdir(conf.tiles.spriteBaseDir)),
                name="tiles_sprites")
tiles_app.mount('/osm',
                StaticFiles(directory=mkdir(conf.tiles.osmBaseDir)),
                name='tiles_osm')

86
src/treetrail/utils.py Normal file
View file

@ -0,0 +1,86 @@
import asyncio
import json
from pathlib import Path
import logging
import pandas as pd
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.engine.row import Row
from sqlalchemy.sql.selectable import Select
import geopandas as gpd # type: ignore
from treetrail.config import conf
logger = logging.getLogger(__name__)
class AlchemyEncoder(json.JSONEncoder):
    """JSON encoder that serializes SQLAlchemy declarative objects and Rows."""

    def default(self, obj):
        if isinstance(obj.__class__, DeclarativeMeta):
            # an SQLAlchemy class
            fields = {}
            for field in [x for x in dir(obj)
                          if not x.startswith('_') and x != 'metadata']:
                data = obj.__getattribute__(field)
                try:
                    # this will fail on non-encodable values, like other classes
                    json.dumps(data)
                    fields[field] = data
                except TypeError:
                    # Non-serializable attributes are nulled, not dropped
                    fields[field] = None
            # a json-encodable dict
            return fields
        if isinstance(obj, Row):
            # NOTE(review): dict(Row) may need row._mapping on SQLAlchemy 2.x
            # — confirm against the pinned SQLAlchemy version
            return dict(obj)
        return json.JSONEncoder.default(self, obj)
async def read_sql_async(stmt, con):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, pd.read_sql, stmt, con)
def read_sql(con, stmt):
    """Synchronous pandas read into a DataFrame.

    Note the (con, stmt) argument order, the reverse of pandas' own API.
    See https://stackoverflow.com/questions/70848256/how-can-i-use-pandas-read-sql-on-an-async-connection
    """
    frame = pd.read_sql_query(stmt, con)
    return frame
def get_attachment_root(type: str) -> Path:
    """Root directory for attachments of the given kind ('tree', 'trail', ...)."""
    return Path(conf.storage.root_attachment_path) / type
def get_attachment_tree_root() -> Path:
    """Attachment directory for trees."""
    return get_attachment_root('tree')
def get_attachment_trail_root() -> Path:
    """Attachment directory for trails."""
    return get_attachment_root('trail')
def get_attachment_poi_root() -> Path:
    """Attachment directory for points of interest."""
    return get_attachment_root('poi')
def pandas_query(session, query):
    """Run *query* through the session's sync connection into a DataFrame."""
    return pd.read_sql_query(query, session.connection())
def geopandas_query(session, query: Select, model, *,
                    # simplify_tolerance: float|None=None,
                    crs=None, cast=True,
                    ):
    """Run a PostGIS *query* into a GeoDataFrame through the sync connection."""
    ## XXX: I could not get the add_columns work without creating a subquery,
    ## so moving the simplification to geopandas - see in _get_df
    # if simplify_tolerance is not None:
    #     query = query.with_only_columns(*(col for col in query.columns
    #                                       if col.name != 'geom'))
    #     new_column = model.__table__.columns['geom'].ST_SimplifyPreserveTopology(
    #         simplify_tolerance).label('geom')
    #     query = query.add_columns(new_column)
    # NOTE(review): *model* and *cast* are currently unused in this body
    return gpd.GeoDataFrame.from_postgis(query, session.connection(), crs=crs)
def mkdir(dir: Path | str) -> Path:
    """Ensure *dir* exists (creating parents as needed) and return it as a Path."""
    target = Path(dir)
    if target.is_dir():
        return target
    logger.info(f'Create directory {target}')
    target.mkdir(parents=True, exist_ok=True)
    return target

0
tests/__init__.py Normal file
View file

16
tests/basic.py Normal file
View file

@ -0,0 +1,16 @@
from fastapi.testclient import TestClient
from treetrail.application import app
client = TestClient(app)
def test_read_main():
    """Smoke test: the bootstrap endpoint answers with the expected payload shape."""
    # Use the client as a context manager so startup/shutdown events run
    with TestClient(app) as client:
        response = client.get("/treetrail/v1/bootstrap")
        assert response.status_code == 200
        json = response.json()
        # Top-level payload keys
        assert set(json) == {'client', 'server', 'user', 'map', 'baseMapStyles', 'app'}
        # Anonymous request: no authenticated user
        assert json['user'] is None
        assert set(json['map']) == {'bearing', 'lat', 'background', 'lng', 'pitch', 'zoom'}
        assert set(json['baseMapStyles']) == {'external', 'embedded'}