treetrail-backend/src/treetrail/database.py

104 lines
3.5 KiB
Python

from contextlib import asynccontextmanager
import sys
from typing import Annotated
from collections.abc import AsyncGenerator
from asyncio import sleep
import logging
from fastapi import Depends
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import SQLModel, select, func, col
from treetrail.config import conf
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Number of connection attempts made by create_db(); it sleeps one second
# after each refused attempt, so this is roughly a timeout in seconds.
CREATE_DB_TIMEOUT = 30

# Async SQLAlchemy engine used by create_db() and by every session opened in
# this module; URL, echo flag and pool settings are all read from conf.db.
engine = create_async_engine(
    conf.db.get_sqla_url(),
    echo=conf.db.echo,
    pool_size=conf.db.pool_size,
    max_overflow=conf.db.max_overflow,
)
async def create_db(drop=False):
    """Create the tables declared on ``SQLModel.metadata``, retrying on refusal.

    When ``drop`` is true, all tables are dropped before being re-created.
    A refused connection is retried up to ``CREATE_DB_TIMEOUT`` times, with a
    one-second pause after each failure. Once the schema is in place, the
    initial data is inserted if ``is_fresh_install()`` reports an empty
    database. If every attempt is refused, a warning is logged and the
    process exits with status 1.
    """
    logger.debug(f'Connect to database with config: {conf.db}')

    # Count down so the log message shows the number of attempts still left.
    for remaining in range(CREATE_DB_TIMEOUT, 0, -1):
        try:
            async with engine.begin() as connection:
                if drop:
                    await connection.run_sync(SQLModel.metadata.drop_all)
                await connection.run_sync(SQLModel.metadata.create_all)
        except ConnectionRefusedError:
            logger.debug(
                f"Cannot connect to database during init (create_db), "
                f"waiting {remaining} more seconds"
            )
            await sleep(1)
            continue

        # Schema is ready: seed the data only when the database is empty.
        if await is_fresh_install():
            await populate_init_db()
        return

    # Only reached when every attempt was refused.
    logger.warning(
        f"Cannot connect to database after {CREATE_DB_TIMEOUT}, giving up."
    )
    sys.exit(1)
async def is_fresh_install() -> bool:
    """Return True if the database is newly created, i.e. has no user yet.

    The check counts the usernames stored for the ``User`` model and
    reports a fresh install when that count is zero.
    """
    # Imported here rather than at module level, as in the original code.
    from treetrail.models import User

    count_query = select(func.count(col(User.username)))
    async with db_session() as session:
        result = await session.exec(count_query)
        user_count = result.one()
    return user_count == 0
async def populate_init_db():
    """Seed a fresh database with an admin account and the default map styles.

    Creates the "admin" user and the "admin" role, passes the user's
    username and the role's name to ``add_user_role``, then executes every
    SQL statement of the module-level ``initials`` list in one session and
    commits once at the end.
    """
    from sqlalchemy import text
    from treetrail.security import add_role, add_user_role, create_user

    logger.info("Populating initial database")

    # NOTE(review): the bootstrap account uses a fixed, well-known password;
    # presumably operators are expected to change it after install - confirm.
    admin_user = await create_user(username="admin", password="admin")
    admin_role = await add_role(role_id="admin")
    await add_user_role(admin_user.username, admin_role.name)

    async with db_session() as session:
        for statement in initials:
            await session.execute(text(statement))
            logger.debug(f'Added map style {statement}')
        # Single commit after all the statements have been executed.
        await session.commit()
# Default map styles, to be inserted in the DB: raw SQL INSERT statements,
# each adding one row to the map_style table. populate_init_db() executes
# them in list order on a fresh install.
initials: list[str] = [
    """INSERT INTO map_style (layer, paint, layout) values ('trail', '{"line-color": "#cd861a", "line-width": 6, "line-blur": 2, "line-opacity": 0.9 }', '{"line-join": "bevel"}');""",
    """INSERT INTO map_style (layer, layout) values ('tree', '{"icon-image":"tree", "icon-size": 0.4}');""",
    """INSERT INTO map_style (layer, layout) values ('tree-hl', '{"icon-image":"tree", "icon-size": 0.4}');""",
    """INSERT INTO map_style (layer, layout) values ('poi', '{"icon-image":"poi", "icon-size": 0.4}');""",
    """INSERT INTO map_style (layer, paint) VALUES ('zone', '{"fill-color": ["match", ["string", ["get", "type"]], "Forest", "#00FF00", "Master Plan", "#EE4455", "#000000"], "fill-opacity": 0.5}');""",
] # noqa: E501
async def get_db_session() -> AsyncGenerator[AsyncSession]:
    """Yield an ``AsyncSession`` bound to the module engine.

    Plain async generator (no context-manager decorator), which is the
    callable wrapped by ``Depends`` in ``fastapi_db_session``. The session's
    async context is exited once the generator is resumed or closed.
    """
    session = AsyncSession(engine)
    async with session:
        yield session
@asynccontextmanager
async def db_session() -> AsyncGenerator[AsyncSession]:
    """Async context manager giving an ``AsyncSession`` on the module engine.

    Intended for ``async with db_session() as session:`` usage; the
    session's own async context is exited when the ``async with`` block
    ends.
    """
    session = AsyncSession(engine)
    async with session:
        yield session
# Type alias for FastAPI handlers: annotating a parameter with it declares
# a dependency on get_db_session(), so the parameter receives an AsyncSession.
fastapi_db_session = Annotated[AsyncSession, Depends(get_db_session)]