treetrail-backend/src/treetrail/utils.py

87 lines
2.8 KiB
Python
Raw Normal View History

2024-10-23 16:19:51 +02:00
import asyncio
import json
from pathlib import Path
import logging
import pandas as pd
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.engine.row import Row
from sqlalchemy.sql.selectable import Select
import geopandas as gpd # type: ignore
from treetrail.config import conf
logger = logging.getLogger(__name__)
class AlchemyEncoder(json.JSONEncoder):
    """JSON encoder that also understands SQLAlchemy objects and result rows.

    Typical use: ``json.dumps(value, cls=AlchemyEncoder)``.
    """

    def default(self, obj):
        """Return a JSON-encodable stand-in for *obj*.

        * Instances of declarative classes (classes whose metaclass is
          ``DeclarativeMeta``) become a dict of their public attributes,
          ``metadata`` excluded. An attribute value that ``json.dumps``
          rejects with ``TypeError`` is replaced by ``None``.
        * ``Row`` objects become a plain dict built from the row's
          mapping view.
        * Anything else is delegated to ``json.JSONEncoder.default``.
        """
        if isinstance(obj.__class__, DeclarativeMeta):
            # An instance of an SQLAlchemy declarative class
            fields = {}
            for field in dir(obj):
                if field.startswith('_') or field == 'metadata':
                    continue
                data = getattr(obj, field)
                try:
                    # Probe only: this fails on non-encodable values,
                    # like other classes
                    json.dumps(data)
                except TypeError:
                    fields[field] = None
                else:
                    fields[field] = data
            # A json-encodable dict
            return fields
        if isinstance(obj, Row):
            # A Row behaves like a named tuple, so dict(obj) is not a
            # reliable key -> value conversion; ``_mapping`` is its
            # dict-like view.
            return dict(obj._mapping)
        return super().default(obj)
async def read_sql_async(stmt, con):
    """Run ``pd.read_sql(stmt, con)`` in the loop's default executor.

    The call is handed to ``run_in_executor`` so that it does not run
    inside the coroutine itself; the awaited result of ``pd.read_sql`` is
    returned unchanged.

    :param stmt: statement passed as first argument to ``pd.read_sql``
    :param con: connection passed as second argument to ``pd.read_sql``
    """
    # Inside a coroutine there is always a running loop;
    # get_running_loop() is the call asyncio recommends here instead of
    # get_event_loop().
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, pd.read_sql, stmt, con)
def read_sql(con, stmt):
    """Return the result of ``pd.read_sql_query(stmt, con)``.

    Note the argument order: connection first, statement second.
    NOTE(review): presumably this order exists so the function can be
    handed to an async connection's sync-runner - confirm against callers.
    See https://stackoverflow.com/questions/70848256/how-can-i-use-pandas-read-sql-on-an-async-connection
    """
    frame = pd.read_sql_query(stmt, con)
    return frame
def get_attachment_root(type: str):
    """Return the attachment directory for the given attachment *type*.

    The result is ``conf.storage.root_attachment_path`` (as a ``Path``)
    joined with *type*.
    """
    base = Path(conf.storage.root_attachment_path)
    return base / type
def get_attachment_tree_root():
    """Return the attachment directory for the ``'tree'`` type."""
    return get_attachment_root(type='tree')
def get_attachment_trail_root():
    """Return the attachment directory for the ``'trail'`` type."""
    return get_attachment_root(type='trail')
def get_attachment_poi_root():
    """Return the attachment directory for the ``'poi'`` type."""
    return get_attachment_root(type='poi')
def pandas_query(session, query):
    """Run *query* on the session's connection and return a DataFrame.

    :param session: object whose ``connection()`` result is given to pandas
    :param query: query passed to ``pd.read_sql_query``
    """
    connection = session.connection()
    return pd.read_sql_query(query, connection)
def geopandas_query(session, query: Select, model, *, crs=None, cast=True):
    """Run *query* on the session's connection and return a GeoDataFrame.

    The frame is built with ``gpd.GeoDataFrame.from_postgis``.

    :param session: object whose ``connection()`` result is given to geopandas
    :param query: the select statement to execute
    :param model: accepted but not used by this function
    :param crs: forwarded as ``crs`` to ``from_postgis``
    :param cast: accepted but not used by this function
    """
    # NOTE: geometry simplification is deliberately not done in SQL here.
    # An earlier attempt (a ``simplify_tolerance`` keyword that replaced the
    # 'geom' column through with_only_columns()/add_columns() by
    # geom.ST_SimplifyPreserveTopology(tolerance).label('geom')) could not
    # be made to work without creating a subquery, so the simplification
    # was moved to geopandas - see _get_df.
    connection = session.connection()
    return gpd.GeoDataFrame.from_postgis(query, connection, crs=crs)
def mkdir(dir: Path | str) -> Path:
path = Path(dir)
if not path.is_dir():
logger.info(f'Create directory {path}')
path.mkdir(parents=True, exist_ok=True)
return path