"""Async database engine, schema bootstrap and session helpers."""

from contextlib import asynccontextmanager
import sys
from typing import Annotated
from collections.abc import AsyncGenerator
from asyncio import sleep
import logging

from fastapi import Depends
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import SQLModel, select, func, col

from treetrail.config import conf

logger = logging.getLogger(__name__)

# Number of connection attempts made by create_db, one second apart,
# before giving up.
CREATE_DB_TIMEOUT = 30

engine = create_async_engine(
    conf.db.get_sqla_url(),
    echo=conf.db.echo,
    pool_size=conf.db.pool_size,
    max_overflow=conf.db.max_overflow,
)


async def create_db(drop=False):
    """Create all SQLModel tables, retrying while connections are refused.

    After the tables are created, the initial data is inserted with
    ``populate_init_db()`` when ``is_fresh_install()`` reports that the
    user table is empty.

    :param drop: when true, drop every table registered in
        ``SQLModel.metadata`` before recreating them.

    If every one of the ``CREATE_DB_TIMEOUT`` attempts fails with
    ``ConnectionRefusedError``, a warning is logged and the process exits
    with status 1 (``sys.exit(1)``).
    """

    async def try_once():
        # One transaction: optional drop, then create all known tables.
        async with engine.begin() as conn:
            if drop:
                await conn.run_sync(SQLModel.metadata.drop_all)
            await conn.run_sync(SQLModel.metadata.create_all)

    # NOTE(review): conf.db may contain credentials; confirm its string
    # representation masks the password before enabling debug logs.
    logger.debug('Connect to database with config: %s', conf.db)
    for remaining in range(CREATE_DB_TIMEOUT, 0, -1):
        try:
            await try_once()
        except ConnectionRefusedError:
            # Presumably the database server is still starting: wait, retry.
            logger.debug(
                "Cannot connect to database during init (create_db), "
                "waiting %d more seconds",
                remaining,
            )
            await sleep(1)
            continue
        if await is_fresh_install():
            await populate_init_db()
        return

    logger.warning(
        "Cannot connect to database after %d seconds, giving up.",
        CREATE_DB_TIMEOUT,
    )
    sys.exit(1)


async def is_fresh_install() -> bool:
    """Detect if the database is newly created, without data.

    :return: True when no ``User`` row has a username, i.e. the count of
        ``User.username`` is zero.
    """
    # Local import, presumably to avoid a circular import with the models.
    from treetrail.models import User

    async with db_session() as session:
        nb_users = (
            await session.exec(select(func.count(col(User.username))))
        ).one()
    return nb_users == 0


async def populate_init_db():
    """Populate the database for a fresh install.

    Creates an ``admin`` user and an ``admin`` role, links them, then runs
    every SQL statement in ``initials`` in a single session and commits.
    """
    # Local imports, presumably to avoid circular imports at module load.
    from sqlalchemy import text
    from treetrail.security import create_user, add_role, add_user_role

    logger.info("Populating initial database")
    # NOTE(review): hard-coded default credentials (admin/admin); make sure
    # the password is changed after install or taken from configuration.
    user = await create_user(username="admin", password="admin")
    role = await add_role(role_id="admin")
    await add_user_role(user.username, role.name)
    # NOTE(review): if create_user commits on its own (not visible here), a
    # failure below leaves a user without map styles, and is_fresh_install()
    # will then return False on the next start, so styles are never added.
    async with db_session() as session:
        for initial in initials:
            await session.execute(text(initial))
            logger.debug('Added map style %s', initial)
        await session.commit()


## Default styles, to be inserted in the DB
initials: list[str] = [
    """INSERT INTO map_style (layer, paint, layout) values ('trail', '{"line-color": "#cd861a", "line-width": 6, "line-blur": 2, "line-opacity": 0.9 }', '{"line-join": "bevel"}');""",  # noqa: E501
    """INSERT INTO map_style (layer, layout) values ('tree', '{"icon-image":"tree", "icon-size": 0.4}');""",  # noqa: E501
    """INSERT INTO map_style (layer, layout) values ('tree-hl', '{"icon-image":"tree", "icon-size": 0.4}');""",  # noqa: E501
    """INSERT INTO map_style (layer, layout) values ('poi', '{"icon-image":"poi", "icon-size": 0.4}');""",  # noqa: E501
    """INSERT INTO map_style (layer, paint) VALUES ('zone', '{"fill-color": ["match", ["string", ["get", "type"]], "Forest", "#00FF00", "Master Plan", "#EE4455", "#000000"], "fill-opacity": 0.5}');""",  # noqa: E501
]


async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` bound to ``engine``; used as a FastAPI
    dependency (see ``fastapi_db_session``). The session is closed when the
    generator finishes."""
    async with AsyncSession(engine) as session:
        yield session


@asynccontextmanager
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Async context manager yielding an ``AsyncSession`` bound to
    ``engine``, for use outside of FastAPI dependency injection."""
    async with AsyncSession(engine) as session:
        yield session


# Annotated type for FastAPI endpoints: injects a session from get_db_session.
fastapi_db_session = Annotated[AsyncSession, Depends(get_db_session)]