treetrail-backend/src/treetrail/database.py

105 lines
3.5 KiB
Python
Raw Normal View History

2024-10-23 16:19:51 +02:00
from contextlib import asynccontextmanager
import sys
from typing import Annotated
from collections.abc import AsyncGenerator
from asyncio import sleep
import logging
from fastapi import Depends
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel import SQLModel, select, func, col
from treetrail.config import conf
logger = logging.getLogger(__name__)

# Number of connection attempts create_db() makes before giving up. It waits
# one second between refused attempts, so this is roughly a timeout in seconds.
CREATE_DB_TIMEOUT = 30

# Module-wide async engine. create_db() runs DDL through it, and every session
# handed out by get_db_session() and db_session() is bound to it. The URL, SQL
# echo flag and pool sizing all come from conf.db.
engine = create_async_engine(
    conf.db.get_sqla_url(),
    echo=conf.db.echo,
    pool_size=conf.db.pool_size,
    max_overflow=conf.db.max_overflow,
)
async def create_db(drop=False):
    """Create all SQLModel tables, retrying while the database refuses connections.

    Up to ``CREATE_DB_TIMEOUT`` attempts are made, one second apart. Only
    ``ConnectionRefusedError`` triggers a retry; any other exception
    propagates to the caller. Once the tables exist, the database is seeded
    with ``populate_init_db()`` when ``is_fresh_install()`` reports no users.

    Args:
        drop: if true, drop every table registered on ``SQLModel.metadata``
            before recreating them, inside the same ``engine.begin()`` block.

    If every attempt is refused, an error is logged and the process is
    terminated with ``sys.exit(1)``.
    """
    attempts_left = CREATE_DB_TIMEOUT

    async def try_once():
        # engine.begin() yields a connection inside a transaction that is
        # committed when the block exits without error.
        async with engine.begin() as conn:
            if drop:
                await conn.run_sync(SQLModel.metadata.drop_all)
            await conn.run_sync(SQLModel.metadata.create_all)

    # NOTE(review): depending on how conf.db renders, this may log the database
    # password; confirm before enabling debug logging in production.
    logger.debug(f'Connect to database with config: {conf.db}')
    while attempts_left > 0:
        try:
            await try_once()
        except ConnectionRefusedError:
            attempts_left -= 1
            if attempts_left == 0:
                # Budget exhausted: give up right away instead of sleeping
                # before a retry that will never happen.
                break
            logger.debug(
                f"Cannot connect to database during init (create_db), "
                f"waiting {attempts_left} more seconds"
            )
            await sleep(1)
        else:
            if await is_fresh_install():
                await populate_init_db()
            return
    logger.error(
        f"Cannot connect to database after {CREATE_DB_TIMEOUT} attempts, "
        "giving up."
    )
    sys.exit(1)
async def is_fresh_install() -> bool:
    """Tell whether the database is newly created and still holds no data.

    Counts the non-null ``User.username`` values. A count of zero is taken
    to mean a fresh install that still needs its initial data.
    """
    # Imported here rather than at module level, presumably to avoid an
    # import cycle with treetrail.models -- TODO confirm.
    from treetrail.models import User

    count_users = select(func.count(col(User.username)))
    async with db_session() as session:
        result = await session.exec(count_users)
        user_count = result.one()
    return user_count == 0
async def populate_init_db():
    """Seed a freshly created database.

    Creates an ``admin`` user, creates an ``admin`` role and links the two.
    Then it executes every SQL statement in ``initials`` (the default map
    styles) in one session, which is committed once at the end.
    """
    from sqlalchemy import text
    from treetrail.security import add_role, add_user_role, create_user

    logger.info("Populating initial database")

    # NOTE(review): hard-coded default credentials (admin/admin); make sure
    # operators are told to change this password after installation.
    admin_user = await create_user(username="admin", password="admin")
    admin_role = await add_role(role_id="admin")
    await add_user_role(admin_user.username, admin_role.name)

    async with db_session() as session:
        for statement in initials:
            await session.execute(text(statement))
            logger.debug(f'Added map style {statement}')
        await session.commit()
# Default map styles, executed as raw SQL by populate_init_db() on a fresh
# install to fill the map_style table.
# NOTE: each statement goes through sqlalchemy.text(), which treats ":word"
# as a bind parameter, so inside the JSON literals never put a letter, digit
# or underscore directly after a ":" (keep a space or a quote there).
# The noqa markers are per line: a single one after "]" only silences that
# closing line, not the long string lines above it.
initials: list[str] = [
    """INSERT INTO map_style (layer, paint, layout) values ('trail', '{"line-color": "#cd861a", "line-width": 6, "line-blur": 2, "line-opacity": 0.9 }', '{"line-join": "bevel"}');""",  # noqa: E501
    """INSERT INTO map_style (layer, layout) values ('tree', '{"icon-image":"tree", "icon-size": 0.4}');""",  # noqa: E501
    """INSERT INTO map_style (layer, layout) values ('tree-hl', '{"icon-image":"tree", "icon-size": 0.4}');""",  # noqa: E501
    """INSERT INTO map_style (layer, layout) values ('poi', '{"icon-image":"poi", "icon-size": 0.4}');""",  # noqa: E501
    """INSERT INTO map_style (layer, paint) VALUES ('zone', '{"fill-color": ["match", ["string", ["get", "type"]], "Forest", "#00FF00", "Master Plan", "#EE4455", "#000000"], "fill-opacity": 0.5}');""",  # noqa: E501
]
async def get_db_session() -> AsyncGenerator[AsyncSession]:
    """Yield one AsyncSession bound to the module-level ``engine``.

    Used as a FastAPI dependency through ``fastapi_db_session``. The session
    is closed by its ``async with`` block when the generator resumes after
    the yield.
    """
    async with AsyncSession(bind=engine) as session:
        yield session
@asynccontextmanager
async def db_session() -> AsyncGenerator[AsyncSession]:
    """Async context manager providing an AsyncSession bound to ``engine``.

    Typical use: ``async with db_session() as session: ...``. The session is
    closed when that ``async with`` block exits.
    """
    async with AsyncSession(bind=engine) as session:
        yield session
# Annotated parameter type for FastAPI routes. Declaring
# ``session: fastapi_db_session`` makes FastAPI inject a session obtained
# from get_db_session().
fastapi_db_session = Annotated[
    AsyncSession,
    Depends(dependency=get_db_session),
]