Add CLI, with init-db; tweak db container; misc minor tweaks in code for the init-db
Some checks failed
/ test (push) Failing after 31s

This commit is contained in:
phil 2024-12-15 19:19:35 +01:00
parent a94d27db0c
commit 98d67f0226
11 changed files with 239 additions and 57 deletions

View file

@ -30,5 +30,8 @@ jobs:
- name: Install app with 'uv pip install' - name: Install app with 'uv pip install'
run: uv pip install --python=$UV_PROJECT_ENVIRONMENT --no-deps . run: uv pip install --python=$UV_PROJECT_ENVIRONMENT --no-deps .
- name: Initialize database
run: GISAF__DB__HOST=gisaf-database gisaf create-db
- name: Run tests (API call) - name: Run tests (API call)
run: GISAF__DB__HOST=gisaf-database pytest -s tests/basic.py run: GISAF__DB__HOST=gisaf-database pytest -s tests/basic.py

View file

@ -2,3 +2,6 @@ FROM docker.io/postgis/postgis:17-3.5-alpine
ENV POSTGRES_USER gisaf ENV POSTGRES_USER gisaf
ENV POSTGRES_PASSWORD secret ENV POSTGRES_PASSWORD secret
# Overwrite standard postgis entrypoint
COPY ./database-container-entrypoint-postgis.sh /docker-entrypoint-initdb.d/10_postgis.sh

View file

@ -0,0 +1,29 @@
#!/bin/bash
set -e
# Perform all actions as $POSTGRES_USER
export PGUSER="$POSTGRES_USER"
# Create the 'template_postgis' template db
"${psql[@]}" <<-'EOSQL'
CREATE DATABASE template_postgis IS_TEMPLATE true;
EOSQL
# Load PostGIS into both template_database and $POSTGRES_DB
for DB in template_postgis "$POSTGRES_DB"; do
echo "Loading PostGIS extensions into $DB"
"${psql[@]}" --dbname="$DB" <<-'EOSQL'
CREATE EXTENSION IF NOT EXISTS postgis;
EOSQL
done
"${psql[@]}" --dbname="$DB" <<-'EOSQL'
CREATE EXTENSION IF NOT EXISTS hstore;
CREATE SCHEMA gisaf;
CREATE SCHEMA gisaf_admin;
CREATE SCHEMA gisaf_map;
CREATE SCHEMA gisaf_survey;
CREATE SCHEMA raw_survey;
CREATE SCHEMA survey;
EOSQL

View file

@ -31,12 +31,13 @@ dependencies = [
"uvicorn>=0.23.2", "uvicorn>=0.23.2",
"websockets>=12.0", "websockets>=12.0",
"pyxdg>=0.28", "pyxdg>=0.28",
"typer-slim>=0.15.1",
] ]
requires-python = ">=3.12" requires-python = ">=3.12"
readme = "README.md" readme = "README.md"
[project.scripts] [project.scripts]
gisaf-backend = "gisaf_backend:main" gisaf = "gisaf.cli:cli"
[project.optional-dependencies] [project.optional-dependencies]
contextily = ["contextily>=1.4.0"] contextily = ["contextily>=1.4.0"]

53
src/gisaf/cli.py Normal file
View file

@ -0,0 +1,53 @@
#!/usr/bin/env python
from importlib.metadata import version as importlib_version
from sqlalchemy.engine import create
from typing_extensions import Annotated
import typer
cli = typer.Typer(no_args_is_help=True, help="Gisaf GIS backend")
@cli.command()
def create_db():
"""Populate the database with a functional empty structure"""
from gisaf.application import app
from gisaf.database import create_db
from asyncio import run
print(f"Create DB...")
run(create_db())
@cli.command()
def serve(host: str = "localhost", port: int = 8000):
"""
Run the uvicorn server.
Use yaml config files or environment variables for configuration.
Note that you can also run gisaf with:
uvicorn src.gisaf.application:app
"""
from uvicorn import run
from gisaf.application import app
run(app, host=host, port=port)
def version_callback(show_version: bool):
if show_version:
print(importlib_version("gisaf-backend"))
raise typer.Exit()
@cli.callback()
def main(
version: Annotated[
bool | None, typer.Option("--version", callback=version_callback)
] = None
):
pass
if __name__ == "__main__":
cli()

View file

@ -1,12 +1,14 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Annotated, Literal, Any from typing import Annotated, Literal, Any
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from asyncio import sleep
import logging
from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy.orm import joinedload, QueryableAttribute from sqlalchemy.orm import joinedload, QueryableAttribute
from sqlalchemy.sql.selectable import Select from sqlalchemy.sql.selectable import Select
from sqlmodel import SQLModel, select from sqlmodel import SQLModel, select, func, col
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi import Depends from fastapi import Depends
@ -16,6 +18,10 @@ import geopandas as gpd # type: ignore
from gisaf.config import conf from gisaf.config import conf
logger = logging.getLogger(__name__)
CREATE_DB_TIMEOUT = 10
engine = create_async_engine( engine = create_async_engine(
conf.db.get_sqla_url(), conf.db.get_sqla_url(),
echo=conf.db.echo, echo=conf.db.echo,
@ -151,3 +157,68 @@ class BaseModel(SQLModel):
fastapi_db_session = Annotated[AsyncSession, Depends(get_db_session)] fastapi_db_session = Annotated[AsyncSession, Depends(get_db_session)]
async def create_db(drop=False):
attempts = CREATE_DB_TIMEOUT
async def try_once():
async with engine.begin() as conn:
if drop:
await conn.run_sync(SQLModel.metadata.drop_all)
await conn.run_sync(SQLModel.metadata.create_all)
logger.debug(f"Connect to database with config: {conf.db}")
while attempts > 0:
try:
await try_once()
except ConnectionRefusedError:
logger.debug(
f"Cannot connect to database during init (create_db), "
f"waiting {attempts} more seconds"
)
attempts -= 1
await sleep(1)
else:
if await is_fresh_install():
await populate_init_db()
return
else:
logger.warning(
f"Cannot connect to database after {CREATE_DB_TIMEOUT}, giving up."
)
exit(1)
async def is_fresh_install() -> bool:
"""Detect is the database is newly created, without data"""
from gisaf.models.authentication import User
async with db_session() as session:
nb_users = (await session.exec(select(func.count(col(User.username))))).one()
return nb_users == 0
async def populate_init_db():
"""Populate the database for a fresh install"""
from sqlalchemy import text
from gisaf.security import create_user # , add_role, add_user_role
logger.info("Populating initial database")
async with db_session() as session:
user = await create_user(
session=session,
username="admin",
password="admin",
full_name="Admin",
email="root@localhost.localdomain",
active=True,
)
assert user is not None
# role = await add_role(role_id="admin")
# await add_user_role(user.username, role.name)
# for initial in initials:
# await session.execute(text(initial))
# logger.debug(f"Added map style {initial}")
# await session.commit()

View file

@ -7,17 +7,13 @@ from gisaf.models.metadata import gisaf_admin
class UserRoleLink(SQLModel, table=True): class UserRoleLink(SQLModel, table=True):
__tablename__: str = 'roles_users' # type: ignore __tablename__: str = "roles_users" # type: ignore
__table_args__ = gisaf_admin.table_args __table_args__ = gisaf_admin.table_args
user_id: int | None = Field( user_id: int | None = Field(
default=None, default=None, foreign_key=gisaf_admin.table("user.id"), primary_key=True
foreign_key=gisaf_admin.table('user.id'),
primary_key=True
) )
role_id: int | None = Field( role_id: int | None = Field(
default=None, default=None, foreign_key=gisaf_admin.table("role.id"), primary_key=True
foreign_key=gisaf_admin.table('role.id'),
primary_key=True
) )
@ -34,18 +30,17 @@ class User(UserBase, table=True):
username: str = Field(String(255), unique=True, index=True) username: str = Field(String(255), unique=True, index=True)
email: str = Field(sa_type=String(50), unique=True) email: str = Field(sa_type=String(50), unique=True)
password: str = Field(sa_type=String(255)) password: str = Field(sa_type=String(255))
active: bool active: bool = True
confirmed_at: datetime confirmed_at: datetime | None = None
last_login_at: datetime last_login_at: datetime | None = None
current_login_at: datetime current_login_at: datetime | None = None
last_login_ip: str = Field(sa_type=String(255)) last_login_ip: str | None = Field(sa_type=String(255), default=None)
current_login_ip: str = Field(sa_type=String(255)) current_login_ip: str | None = Field(sa_type=String(255), default=None)
login_count: int login_count: int = 0
roles: list["Role"] = Relationship(back_populates="users", roles: list["Role"] = Relationship(back_populates="users", link_model=UserRoleLink)
link_model=UserRoleLink)
def can_view(self, model) -> bool: def can_view(self, model) -> bool:
role = getattr(model, 'viewable_role', None) role = getattr(model, "viewable_role", None)
if role: if role:
return self.has_role(role) return self.has_role(role)
else: else:
@ -54,17 +49,19 @@ class User(UserBase, table=True):
def has_role(self, role: str) -> bool: def has_role(self, role: str) -> bool:
return role in (role.name for role in self.roles) return role in (role.name for role in self.roles)
class RoleBase(SQLModel): class RoleBase(SQLModel):
name: str = Field(unique=True) name: str = Field(unique=True)
class RoleWithDescription(RoleBase): class RoleWithDescription(RoleBase):
description: str | None description: str | None
class Role(RoleWithDescription, table=True): class Role(RoleWithDescription, table=True):
__table_args__ = gisaf_admin.table_args __table_args__ = gisaf_admin.table_args
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
users: list[User] = Relationship(back_populates="roles", users: list[User] = Relationship(back_populates="roles", link_model=UserRoleLink)
link_model=UserRoleLink)
class UserReadNoRoles(UserBase): class UserReadNoRoles(UserBase):
@ -87,7 +84,7 @@ class UserRead(UserBase):
roles: list[RoleReadNoUsers] = [] roles: list[RoleReadNoUsers] = []
def can_view(self, model) -> bool: def can_view(self, model) -> bool:
role = getattr(model, 'viewable_role', None) role = getattr(model, "viewable_role", None)
if role: if role:
return self.has_role(role) return self.has_role(role)
else: else:
@ -100,3 +97,4 @@ class UserRead(UserBase):
# class ACL(BaseModel): # class ACL(BaseModel):
# user_id: int # user_id: int
# role_ids: list[int] # role_ids: list[int]

View file

@ -538,8 +538,7 @@ async def create_tags(features, keys, values):
return pd.concat(result) return pd.concat(result)
from gisaf.utils import ToMigrate # TODO: Migrate('plugins.change_feature_status (graphql)'))
logger.warning(ToMigrate('plugins.change_feature_status (graphql)'))
change_feature_status = ChangeFeatureStatus( change_feature_status = ChangeFeatureStatus(
name='Change status', name='Change status',
stores_by_re=[f"{conf.survey.db_schema}"], stores_by_re=[f"{conf.survey.db_schema}"],

View file

@ -48,6 +48,7 @@ expired_exception = HTTPException(
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
def get_password_hash(password: str): def get_password_hash(password: str):
return pwd_context.hash(password) return pwd_context.hash(password)
@ -55,22 +56,27 @@ def get_password_hash(password: str):
async def delete_user(session: AsyncSession, username: str) -> None: async def delete_user(session: AsyncSession, username: str) -> None:
user_in_db: User | None = await get_user(session, username) user_in_db: User | None = await get_user(session, username)
if user_in_db is None: if user_in_db is None:
raise SystemExit(f'User {username} does not exist in the database') raise SystemExit(f"User {username} does not exist in the database")
await session.delete(user_in_db) await session.delete(user_in_db)
async def enable_user(session: AsyncSession, username: str, enable=True): async def enable_user(session: AsyncSession, username: str, enable=True):
user_in_db = await get_user(session, username) user_in_db = await get_user(session, username)
if user_in_db is None: if user_in_db is None:
raise SystemExit(f'User {username} does not exist in the database') raise SystemExit(f"User {username} does not exist in the database")
user_in_db.disabled = not enable # type: ignore user_in_db.disabled = not enable # type: ignore
session.add(user_in_db) session.add(user_in_db)
await session.commit() await session.commit()
async def create_user(session: AsyncSession, username: str, async def create_user(
password: str, full_name: str, session: AsyncSession,
email: str, **kwargs): username: str,
password: str,
full_name: str,
email: str,
**kwargs,
) -> User:
user_in_db: User | None = await get_user(session, username) user_in_db: User | None = await get_user(session, username)
if user_in_db is None: if user_in_db is None:
user = User( user = User(
@ -78,20 +84,23 @@ async def create_user(session: AsyncSession, username: str,
password=get_password_hash(password), password=get_password_hash(password),
full_name=full_name, full_name=full_name,
email=email, email=email,
disabled=False disabled=False,
) )
session.add(user) session.add(user)
await session.commit()
return user
else: else:
user_in_db.full_name = full_name # type: ignore user_in_db.full_name = full_name # type: ignore
user_in_db.email = email # type: ignore user_in_db.email = email # type: ignore
user_in_db.password = get_password_hash(password) # type: ignore user_in_db.password = get_password_hash(password) # type: ignore
await session.commit() await session.commit()
return user_in_db
async def get_user( async def get_user(session: AsyncSession, username: str) -> User | None:
session: AsyncSession, query = (
username: str) -> User | None: select(User).where(User.username == username).options(selectinload(User.roles))
query = select(User).where(User.username==username).options(selectinload(User.roles)) )
data = await session.exec(query) data = await session.exec(query)
return data.one_or_none() return data.one_or_none()
@ -100,19 +109,21 @@ def verify_password(user: User, plain_password):
try: try:
return pwd_context.verify(plain_password, user.password) return pwd_context.verify(plain_password, user.password)
except UnknownHashError: except UnknownHashError:
logger.warning(f'Password not encrypted in DB for {user.username}, assuming it is stored in plain text') logger.warning(
f"Password not encrypted in DB for {user.username}, assuming it is stored in plain text"
)
return plain_password == user.password return plain_password == user.password
async def get_current_user( async def get_current_user(token: str = Depends(oauth2_scheme)) -> User | None:
token: str = Depends(oauth2_scheme)) -> User | None:
if token is None: if token is None:
return None return None
try: try:
payload = jwt.decode(token, conf.crypto.secret, payload = jwt.decode(
algorithms=[conf.crypto.algorithm]) token, conf.crypto.secret, algorithms=[conf.crypto.algorithm]
username: str = payload.get("sub", '') )
if username == '': username: str = payload.get("sub", "")
if username == "":
raise credentials_exception raise credentials_exception
except ExpiredSignatureError: except ExpiredSignatureError:
# raise expired_exception # raise expired_exception
@ -139,7 +150,8 @@ async def authenticate_user(username: str, password: str):
async def get_current_active_user( async def get_current_active_user(
current_user: Annotated[UserRead, Depends(get_current_user)]): current_user: Annotated[UserRead, Depends(get_current_user)]
):
if current_user is not None and current_user.disabled: if current_user is not None and current_user.disabled:
raise HTTPException(status_code=400, detail="Inactive user") raise HTTPException(status_code=400, detail="Inactive user")
return current_user return current_user
@ -149,7 +161,8 @@ def create_access_token(data: dict, expires_delta: timedelta):
to_encode = data.copy() to_encode = data.copy()
expire = datetime.utcnow() + expires_delta expire = datetime.utcnow() + expires_delta
to_encode.update({"exp": expire}) to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, encoded_jwt = jwt.encode(
conf.crypto.secret, to_encode, conf.crypto.secret, algorithm=conf.crypto.algorithm
algorithm=conf.crypto.algorithm) )
return encoded_jwt return encoded_jwt

View file

@ -16,9 +16,6 @@ from sqlmodel import SQLModel, delete
from gisaf.config import conf from gisaf.config import conf
from gisaf.database import db_session from gisaf.database import db_session
class ToMigrate(Exception):
pass
SHAPELY_TYPE_TO_MAPBOX_TYPE = { SHAPELY_TYPE_TO_MAPBOX_TYPE = {
'Point': 'symbol', 'Point': 'symbol',
'LineString': 'line', 'LineString': 'line',

15
uv.lock generated
View file

@ -566,6 +566,7 @@ dependencies = [
{ name = "redis" }, { name = "redis" },
{ name = "sqlalchemy", extra = ["asyncio"] }, { name = "sqlalchemy", extra = ["asyncio"] },
{ name = "sqlmodel" }, { name = "sqlmodel" },
{ name = "typer-slim" },
{ name = "uvicorn" }, { name = "uvicorn" },
{ name = "websockets" }, { name = "websockets" },
] ]
@ -627,6 +628,7 @@ requires-dist = [
{ name = "redis", specifier = ">=5.0.1" }, { name = "redis", specifier = ">=5.0.1" },
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.23" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.23" },
{ name = "sqlmodel", specifier = ">=0.0.18" }, { name = "sqlmodel", specifier = ">=0.0.18" },
{ name = "typer-slim", specifier = ">=0.15.1" },
{ name = "uvicorn", specifier = ">=0.23.2" }, { name = "uvicorn", specifier = ">=0.23.2" },
{ name = "websockets", specifier = ">=12.0" }, { name = "websockets", specifier = ">=12.0" },
] ]
@ -1691,6 +1693,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359 }, { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359 },
] ]
[[package]]
name = "typer-slim"
version = "0.15.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f8/7d/f8e0a2678a44573b2bb1e20abecb10f937a7101ce2b8e07f4eab4c721a3d/typer_slim-0.15.1.tar.gz", hash = "sha256:b8ce8fd2a3c7d52f0d0c1318776e7f2bf897fa203daf899f3863514aa926c725", size = 99874 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/1f/7b/032ecd581e2170513bb6dc3cdb2581e20fdb94a272bae70fe93f2bca580b/typer_slim-0.15.1-py3-none-any.whl", hash = "sha256:20233cb89938ea3cca633afee10b906a1b0e7c5330f31ed8c55f4f0779efe6df", size = 44968 },
]
[[package]] [[package]]
name = "types-passlib" name = "types-passlib"
version = "1.7.7.20240819" version = "1.7.7.20240819"