Add CLI, with init-db; tweak db container; misc minor tweaks in code for the init-db
Some checks failed
/ test (push) Failing after 31s

This commit is contained in:
phil 2024-12-15 19:19:35 +01:00
parent a94d27db0c
commit 98d67f0226
11 changed files with 239 additions and 57 deletions

53
src/gisaf/cli.py Normal file
View file

@ -0,0 +1,53 @@
#!/usr/bin/env python
from importlib.metadata import version as importlib_version
from sqlalchemy.engine import create
from typing_extensions import Annotated
import typer
cli = typer.Typer(no_args_is_help=True, help="Gisaf GIS backend")
@cli.command()
def create_db():
"""Populate the database with a functional empty structure"""
from gisaf.application import app
from gisaf.database import create_db
from asyncio import run
print(f"Create DB...")
run(create_db())
@cli.command()
def serve(host: str = "localhost", port: int = 8000):
"""
Run the uvicorn server.
Use yaml config files or environment variables for configuration.
Note that you can also run gisaf with:
uvicorn src.gisaf.application:app
"""
from uvicorn import run
from gisaf.application import app
run(app, host=host, port=port)
def version_callback(show_version: bool):
if show_version:
print(importlib_version("gisaf-backend"))
raise typer.Exit()
@cli.callback()
def main(
version: Annotated[
bool | None, typer.Option("--version", callback=version_callback)
] = None
):
pass
if __name__ == "__main__":
cli()

View file

@ -1,12 +1,14 @@
from contextlib import asynccontextmanager
from typing import Annotated, Literal, Any
from collections.abc import AsyncGenerator
from asyncio import sleep
import logging
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import create_engine
from sqlalchemy.orm import joinedload, QueryableAttribute
from sqlalchemy.sql.selectable import Select
from sqlmodel import SQLModel, select
from sqlmodel import SQLModel, select, func, col
from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi import Depends
@ -16,6 +18,10 @@ import geopandas as gpd # type: ignore
from gisaf.config import conf
logger = logging.getLogger(__name__)
CREATE_DB_TIMEOUT = 10
engine = create_async_engine(
conf.db.get_sqla_url(),
echo=conf.db.echo,
@ -151,3 +157,68 @@ class BaseModel(SQLModel):
fastapi_db_session = Annotated[AsyncSession, Depends(get_db_session)]
async def create_db(drop=False):
attempts = CREATE_DB_TIMEOUT
async def try_once():
async with engine.begin() as conn:
if drop:
await conn.run_sync(SQLModel.metadata.drop_all)
await conn.run_sync(SQLModel.metadata.create_all)
logger.debug(f"Connect to database with config: {conf.db}")
while attempts > 0:
try:
await try_once()
except ConnectionRefusedError:
logger.debug(
f"Cannot connect to database during init (create_db), "
f"waiting {attempts} more seconds"
)
attempts -= 1
await sleep(1)
else:
if await is_fresh_install():
await populate_init_db()
return
else:
logger.warning(
f"Cannot connect to database after {CREATE_DB_TIMEOUT}, giving up."
)
exit(1)
async def is_fresh_install() -> bool:
"""Detect is the database is newly created, without data"""
from gisaf.models.authentication import User
async with db_session() as session:
nb_users = (await session.exec(select(func.count(col(User.username))))).one()
return nb_users == 0
async def populate_init_db():
"""Populate the database for a fresh install"""
from sqlalchemy import text
from gisaf.security import create_user # , add_role, add_user_role
logger.info("Populating initial database")
async with db_session() as session:
user = await create_user(
session=session,
username="admin",
password="admin",
full_name="Admin",
email="root@localhost.localdomain",
active=True,
)
assert user is not None
# role = await add_role(role_id="admin")
# await add_user_role(user.username, role.name)
# for initial in initials:
# await session.execute(text(initial))
# logger.debug(f"Added map style {initial}")
# await session.commit()

View file

@ -7,17 +7,13 @@ from gisaf.models.metadata import gisaf_admin
class UserRoleLink(SQLModel, table=True):
__tablename__: str = 'roles_users' # type: ignore
__tablename__: str = "roles_users" # type: ignore
__table_args__ = gisaf_admin.table_args
user_id: int | None = Field(
default=None,
foreign_key=gisaf_admin.table('user.id'),
primary_key=True
default=None, foreign_key=gisaf_admin.table("user.id"), primary_key=True
)
role_id: int | None = Field(
default=None,
foreign_key=gisaf_admin.table('role.id'),
primary_key=True
default=None, foreign_key=gisaf_admin.table("role.id"), primary_key=True
)
@ -34,18 +30,17 @@ class User(UserBase, table=True):
username: str = Field(String(255), unique=True, index=True)
email: str = Field(sa_type=String(50), unique=True)
password: str = Field(sa_type=String(255))
active: bool
confirmed_at: datetime
last_login_at: datetime
current_login_at: datetime
last_login_ip: str = Field(sa_type=String(255))
current_login_ip: str = Field(sa_type=String(255))
login_count: int
roles: list["Role"] = Relationship(back_populates="users",
link_model=UserRoleLink)
active: bool = True
confirmed_at: datetime | None = None
last_login_at: datetime | None = None
current_login_at: datetime | None = None
last_login_ip: str | None = Field(sa_type=String(255), default=None)
current_login_ip: str | None = Field(sa_type=String(255), default=None)
login_count: int = 0
roles: list["Role"] = Relationship(back_populates="users", link_model=UserRoleLink)
def can_view(self, model) -> bool:
role = getattr(model, 'viewable_role', None)
role = getattr(model, "viewable_role", None)
if role:
return self.has_role(role)
else:
@ -54,22 +49,24 @@ class User(UserBase, table=True):
def has_role(self, role: str) -> bool:
return role in (role.name for role in self.roles)
class RoleBase(SQLModel):
name: str = Field(unique=True)
class RoleWithDescription(RoleBase):
description: str | None
class Role(RoleWithDescription, table=True):
__table_args__ = gisaf_admin.table_args
id: int | None = Field(default=None, primary_key=True)
users: list[User] = Relationship(back_populates="roles",
link_model=UserRoleLink)
users: list[User] = Relationship(back_populates="roles", link_model=UserRoleLink)
class UserReadNoRoles(UserBase):
id: int
email: str | None # type: ignore
email: str | None # type: ignore
class RoleRead(RoleBase):
@ -83,11 +80,11 @@ class RoleReadNoUsers(RoleBase):
class UserRead(UserBase):
id: int
email: str | None # type: ignore
email: str | None # type: ignore
roles: list[RoleReadNoUsers] = []
def can_view(self, model) -> bool:
role = getattr(model, 'viewable_role', None)
role = getattr(model, "viewable_role", None)
if role:
return self.has_role(role)
else:
@ -99,4 +96,5 @@ class UserRead(UserBase):
# class ACL(BaseModel):
# user_id: int
# role_ids: list[int]
# role_ids: list[int]

View file

@ -538,8 +538,7 @@ async def create_tags(features, keys, values):
return pd.concat(result)
from gisaf.utils import ToMigrate
logger.warning(ToMigrate('plugins.change_feature_status (graphql)'))
# TODO: Migrate('plugins.change_feature_status (graphql)'))
change_feature_status = ChangeFeatureStatus(
name='Change status',
stores_by_re=[f"{conf.survey.db_schema}"],

View file

@ -48,6 +48,7 @@ expired_exception = HTTPException(
headers={"WWW-Authenticate": "Bearer"},
)
def get_password_hash(password: str):
return pwd_context.hash(password)
@ -55,22 +56,27 @@ def get_password_hash(password: str):
async def delete_user(session: AsyncSession, username: str) -> None:
user_in_db: User | None = await get_user(session, username)
if user_in_db is None:
raise SystemExit(f'User {username} does not exist in the database')
raise SystemExit(f"User {username} does not exist in the database")
await session.delete(user_in_db)
async def enable_user(session: AsyncSession, username: str, enable=True):
user_in_db = await get_user(session, username)
if user_in_db is None:
raise SystemExit(f'User {username} does not exist in the database')
user_in_db.disabled = not enable # type: ignore
raise SystemExit(f"User {username} does not exist in the database")
user_in_db.disabled = not enable # type: ignore
session.add(user_in_db)
await session.commit()
async def create_user(session: AsyncSession, username: str,
password: str, full_name: str,
email: str, **kwargs):
async def create_user(
session: AsyncSession,
username: str,
password: str,
full_name: str,
email: str,
**kwargs,
) -> User:
user_in_db: User | None = await get_user(session, username)
if user_in_db is None:
user = User(
@ -78,20 +84,23 @@ async def create_user(session: AsyncSession, username: str,
password=get_password_hash(password),
full_name=full_name,
email=email,
disabled=False
disabled=False,
)
session.add(user)
await session.commit()
return user
else:
user_in_db.full_name = full_name # type: ignore
user_in_db.email = email # type: ignore
user_in_db.password = get_password_hash(password) # type: ignore
await session.commit()
user_in_db.full_name = full_name # type: ignore
user_in_db.email = email # type: ignore
user_in_db.password = get_password_hash(password) # type: ignore
await session.commit()
return user_in_db
async def get_user(
session: AsyncSession,
username: str) -> User | None:
query = select(User).where(User.username==username).options(selectinload(User.roles))
async def get_user(session: AsyncSession, username: str) -> User | None:
query = (
select(User).where(User.username == username).options(selectinload(User.roles))
)
data = await session.exec(query)
return data.one_or_none()
@ -100,19 +109,21 @@ def verify_password(user: User, plain_password):
try:
return pwd_context.verify(plain_password, user.password)
except UnknownHashError:
logger.warning(f'Password not encrypted in DB for {user.username}, assuming it is stored in plain text')
logger.warning(
f"Password not encrypted in DB for {user.username}, assuming it is stored in plain text"
)
return plain_password == user.password
async def get_current_user(
token: str = Depends(oauth2_scheme)) -> User | None:
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User | None:
if token is None:
return None
try:
payload = jwt.decode(token, conf.crypto.secret,
algorithms=[conf.crypto.algorithm])
username: str = payload.get("sub", '')
if username == '':
payload = jwt.decode(
token, conf.crypto.secret, algorithms=[conf.crypto.algorithm]
)
username: str = payload.get("sub", "")
if username == "":
raise credentials_exception
except ExpiredSignatureError:
# raise expired_exception
@ -139,7 +150,8 @@ async def authenticate_user(username: str, password: str):
async def get_current_active_user(
current_user: Annotated[UserRead, Depends(get_current_user)]):
current_user: Annotated[UserRead, Depends(get_current_user)]
):
if current_user is not None and current_user.disabled:
raise HTTPException(status_code=400, detail="Inactive user")
return current_user
@ -149,7 +161,8 @@ def create_access_token(data: dict, expires_delta: timedelta):
to_encode = data.copy()
expire = datetime.utcnow() + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode,
conf.crypto.secret,
algorithm=conf.crypto.algorithm)
return encoded_jwt
encoded_jwt = jwt.encode(
to_encode, conf.crypto.secret, algorithm=conf.crypto.algorithm
)
return encoded_jwt

View file

@ -16,9 +16,6 @@ from sqlmodel import SQLModel, delete
from gisaf.config import conf
from gisaf.database import db_session
class ToMigrate(Exception):
pass
SHAPELY_TYPE_TO_MAPBOX_TYPE = {
'Point': 'symbol',
'LineString': 'line',