104 lines
3.4 KiB
Python
104 lines
3.4 KiB
Python
|
from contextlib import asynccontextmanager
|
||
|
import sys
|
||
|
from typing import Annotated
|
||
|
from collections.abc import AsyncGenerator
|
||
|
from asyncio import sleep
|
||
|
import logging
|
||
|
|
||
|
from fastapi import Depends
|
||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
||
|
from sqlmodel import SQLModel, select, func, col
|
||
|
|
||
|
from treetrail.config import conf
|
||
|
|
||
|
logger = logging.getLogger(__name__)

# Number of connection attempts create_db() makes before giving up; it pauses
# one second after each failed attempt.
CREATE_DB_TIMEOUT = 30


def _make_engine():
    """Build the module's async engine from the ``db`` section of ``conf``.

    The connection URL, SQL echo flag and pool sizing (``pool_size``,
    ``max_overflow``) are all taken from the application configuration.
    """
    return create_async_engine(
        conf.db.get_sqla_url(),
        echo=conf.db.echo,
        pool_size=conf.db.pool_size,
        max_overflow=conf.db.max_overflow,
    )


# Shared engine, created once at import time and used by every session below.
engine = _make_engine()
|
||
|
|
||
|
|
||
|
async def create_db(drop=False):
    """Create all SQLModel tables, retrying while the database is unreachable.

    Up to ``CREATE_DB_TIMEOUT`` attempts are made, pausing one second after
    each failed connection (but not after the last one). Once the tables
    exist, an empty database (see ``is_fresh_install``) is seeded through
    ``populate_init_db``. If every attempt fails, the error is logged and the
    process exits with status 1.

    :param drop: when true, drop every table in ``SQLModel.metadata`` before
        re-creating them (existing data in those tables is lost).
    """

    async def try_once():
        # engine.begin() runs the DDL inside one transaction that is
        # committed when the block exits without error.
        async with engine.begin() as conn:
            if drop:
                await conn.run_sync(SQLModel.metadata.drop_all)
            await conn.run_sync(SQLModel.metadata.create_all)

    for attempt in range(1, CREATE_DB_TIMEOUT + 1):
        try:
            await try_once()
        except OSError as err:
            # OSError covers ConnectionRefusedError (server not listening yet)
            # and other transient connect-time failures such as a hostname
            # that does not resolve yet or a connect timeout.
            remaining = CREATE_DB_TIMEOUT - attempt
            logger.debug(
                "Cannot connect to database during init (create_db): %s; "
                "%d attempts left",
                err,
                remaining,
            )
            if remaining:
                # No point waiting after the final attempt.
                await sleep(1)
            continue
        if await is_fresh_install():
            await populate_init_db()
        return

    logger.error(
        "Cannot connect to database after %d attempts, giving up.",
        CREATE_DB_TIMEOUT,
    )
    sys.exit(1)
|
||
|
|
||
|
|
||
|
async def is_fresh_install() -> bool:
    """Return True when the database contains no users yet.

    The check runs ``SELECT count(username)`` on the ``User`` model. An empty
    result (zero) is taken to mean the database was just created and has not
    been populated.
    """
    # Local import, kept inside the function as in the original module layout.
    from treetrail.models import User

    count_query = select(func.count(col(User.username)))
    async with db_session() as session:
        result = await session.exec(count_query)
        user_count = result.one()
    return user_count == 0
|
||
|
|
||
|
|
||
|
async def populate_init_db():
    """Seed a freshly created database with its initial content.

    Creates an ``admin`` user and an ``admin`` role and links the two. It then
    executes every raw SQL statement in the module-level ``initials`` list
    (default map styles) in one session, committing once at the end.

    NOTE(review): the seeded account is created with the hard-coded password
    "admin"; it should be changed after installation.
    """
    from sqlalchemy import text
    from treetrail.security import create_user, add_role, add_user_role

    async def seed_admin_account():
        # Admin user + admin role, then the association between them.
        account = await create_user(username="admin", password="admin")
        admin_role = await add_role(role_id="admin")
        await add_user_role(account.username, admin_role.name)

    logger.info("Populating initial database")
    await seed_admin_account()

    async with db_session() as session:
        for statement in initials:
            await session.execute(text(statement))
            logger.debug(f'Added map style {statement}')
        # A single commit covers all the style inserts.
        await session.commit()
|
||
|
|
||
|
|
||
|
# Default map styles, inserted into the ``map_style`` table by
# populate_init_db() on a fresh install. Each entry is one raw SQL statement.
# The paint/layout columns receive JSON text; the keys look like MapLibre /
# Mapbox GL style-spec properties (NOTE(review): confirm against the client).
# `noqa: E501` must sit on each long line: linters apply it only to the
# physical line it ends, so a single marker on the closing bracket has no effect.
initials: list[str] = [
    """INSERT INTO map_style (layer, paint, layout) values ('trail', '{"line-color": "#cd861a", "line-width": 6, "line-blur": 2, "line-opacity": 0.9 }', '{"line-join": "bevel"}');""",  # noqa: E501
    """INSERT INTO map_style (layer, layout) values ('tree', '{"icon-image":"tree", "icon-size": 0.4}');""",  # noqa: E501
    """INSERT INTO map_style (layer, layout) values ('tree-hl', '{"icon-image":"tree", "icon-size": 0.4}');""",  # noqa: E501
    """INSERT INTO map_style (layer, layout) values ('poi', '{"icon-image":"poi", "icon-size": 0.4}');""",  # noqa: E501
    """INSERT INTO map_style (layer, paint) VALUES ('zone', '{"fill-color": ["match", ["string", ["get", "type"]], "Forest", "#00FF00", "Master Plan", "#EE4455", "#000000"], "fill-opacity": 0.5}');""",  # noqa: E501
]
|
||
|
|
||
|
|
||
|
async def get_db_session() -> AsyncGenerator[AsyncSession]:
    """Yield a single ``AsyncSession`` bound to the module engine.

    This is the plain async-generator form consumed by FastAPI through
    ``Depends`` (see ``fastapi_db_session``). The session is closed when the
    generator is finalized.
    """
    session = AsyncSession(engine)
    async with session:
        yield session
|
||
|
|
||
|
|
||
|
@asynccontextmanager
async def db_session() -> AsyncGenerator[AsyncSession]:
    """Open an ``AsyncSession`` on the module engine for ``async with`` use.

    Intended for code outside request handling (e.g. ``is_fresh_install``).
    The session is closed when the ``async with`` block exits.
    """
    new_session = AsyncSession(engine)
    async with new_session:
        yield new_session
|
||
|
|
||
|
|
||
|
# Annotated type for FastAPI route parameters: declaring a parameter as
# ``session: fastapi_db_session`` makes FastAPI inject a session produced by
# get_db_session().
_db_session_dependency = Depends(get_db_session)
fastapi_db_session = Annotated[AsyncSession, _db_session_dependency]
|