Add CLI, with init-db; tweak db container; misc minor tweaks in code for the init-db
Some checks failed
/ test (push) Failing after 31s
Some checks failed
/ test (push) Failing after 31s
This commit is contained in:
parent
a94d27db0c
commit
98d67f0226
11 changed files with 239 additions and 57 deletions
53
src/gisaf/cli.py
Normal file
53
src/gisaf/cli.py
Normal file
|
@ -0,0 +1,53 @@
|
|||
#!/usr/bin/env python
|
||||
from importlib.metadata import version as importlib_version
|
||||
from sqlalchemy.engine import create
|
||||
from typing_extensions import Annotated
|
||||
import typer
|
||||
|
||||
|
||||
cli = typer.Typer(no_args_is_help=True, help="Gisaf GIS backend")
|
||||
|
||||
|
||||
@cli.command()
|
||||
def create_db():
|
||||
"""Populate the database with a functional empty structure"""
|
||||
from gisaf.application import app
|
||||
from gisaf.database import create_db
|
||||
from asyncio import run
|
||||
|
||||
print(f"Create DB...")
|
||||
run(create_db())
|
||||
|
||||
|
||||
@cli.command()
|
||||
def serve(host: str = "localhost", port: int = 8000):
|
||||
"""
|
||||
Run the uvicorn server.
|
||||
|
||||
Use yaml config files or environment variables for configuration.
|
||||
Note that you can also run gisaf with:
|
||||
uvicorn src.gisaf.application:app
|
||||
"""
|
||||
from uvicorn import run
|
||||
from gisaf.application import app
|
||||
|
||||
run(app, host=host, port=port)
|
||||
|
||||
|
||||
def version_callback(show_version: bool):
|
||||
if show_version:
|
||||
print(importlib_version("gisaf-backend"))
|
||||
raise typer.Exit()
|
||||
|
||||
|
||||
@cli.callback()
|
||||
def main(
|
||||
version: Annotated[
|
||||
bool | None, typer.Option("--version", callback=version_callback)
|
||||
] = None
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
|
@ -1,12 +1,14 @@
|
|||
from contextlib import asynccontextmanager
|
||||
from typing import Annotated, Literal, Any
|
||||
from collections.abc import AsyncGenerator
|
||||
from asyncio import sleep
|
||||
import logging
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import joinedload, QueryableAttribute
|
||||
from sqlalchemy.sql.selectable import Select
|
||||
from sqlmodel import SQLModel, select
|
||||
from sqlmodel import SQLModel, select, func, col
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from fastapi import Depends
|
||||
|
||||
|
@ -16,6 +18,10 @@ import geopandas as gpd # type: ignore
|
|||
|
||||
from gisaf.config import conf
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CREATE_DB_TIMEOUT = 10
|
||||
|
||||
engine = create_async_engine(
|
||||
conf.db.get_sqla_url(),
|
||||
echo=conf.db.echo,
|
||||
|
@ -151,3 +157,68 @@ class BaseModel(SQLModel):
|
|||
|
||||
|
||||
fastapi_db_session = Annotated[AsyncSession, Depends(get_db_session)]
|
||||
|
||||
|
||||
async def create_db(drop=False):
|
||||
attempts = CREATE_DB_TIMEOUT
|
||||
|
||||
async def try_once():
|
||||
async with engine.begin() as conn:
|
||||
if drop:
|
||||
await conn.run_sync(SQLModel.metadata.drop_all)
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
logger.debug(f"Connect to database with config: {conf.db}")
|
||||
while attempts > 0:
|
||||
try:
|
||||
await try_once()
|
||||
except ConnectionRefusedError:
|
||||
logger.debug(
|
||||
f"Cannot connect to database during init (create_db), "
|
||||
f"waiting {attempts} more seconds"
|
||||
)
|
||||
attempts -= 1
|
||||
await sleep(1)
|
||||
else:
|
||||
if await is_fresh_install():
|
||||
await populate_init_db()
|
||||
return
|
||||
else:
|
||||
logger.warning(
|
||||
f"Cannot connect to database after {CREATE_DB_TIMEOUT}, giving up."
|
||||
)
|
||||
exit(1)
|
||||
|
||||
|
||||
async def is_fresh_install() -> bool:
|
||||
"""Detect is the database is newly created, without data"""
|
||||
from gisaf.models.authentication import User
|
||||
|
||||
async with db_session() as session:
|
||||
nb_users = (await session.exec(select(func.count(col(User.username))))).one()
|
||||
return nb_users == 0
|
||||
|
||||
|
||||
async def populate_init_db():
|
||||
"""Populate the database for a fresh install"""
|
||||
from sqlalchemy import text
|
||||
from gisaf.security import create_user # , add_role, add_user_role
|
||||
|
||||
logger.info("Populating initial database")
|
||||
|
||||
async with db_session() as session:
|
||||
user = await create_user(
|
||||
session=session,
|
||||
username="admin",
|
||||
password="admin",
|
||||
full_name="Admin",
|
||||
email="root@localhost.localdomain",
|
||||
active=True,
|
||||
)
|
||||
assert user is not None
|
||||
# role = await add_role(role_id="admin")
|
||||
# await add_user_role(user.username, role.name)
|
||||
# for initial in initials:
|
||||
# await session.execute(text(initial))
|
||||
# logger.debug(f"Added map style {initial}")
|
||||
# await session.commit()
|
||||
|
|
|
@ -7,17 +7,13 @@ from gisaf.models.metadata import gisaf_admin
|
|||
|
||||
|
||||
class UserRoleLink(SQLModel, table=True):
|
||||
__tablename__: str = 'roles_users' # type: ignore
|
||||
__tablename__: str = "roles_users" # type: ignore
|
||||
__table_args__ = gisaf_admin.table_args
|
||||
user_id: int | None = Field(
|
||||
default=None,
|
||||
foreign_key=gisaf_admin.table('user.id'),
|
||||
primary_key=True
|
||||
default=None, foreign_key=gisaf_admin.table("user.id"), primary_key=True
|
||||
)
|
||||
role_id: int | None = Field(
|
||||
default=None,
|
||||
foreign_key=gisaf_admin.table('role.id'),
|
||||
primary_key=True
|
||||
default=None, foreign_key=gisaf_admin.table("role.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
|
@ -34,18 +30,17 @@ class User(UserBase, table=True):
|
|||
username: str = Field(String(255), unique=True, index=True)
|
||||
email: str = Field(sa_type=String(50), unique=True)
|
||||
password: str = Field(sa_type=String(255))
|
||||
active: bool
|
||||
confirmed_at: datetime
|
||||
last_login_at: datetime
|
||||
current_login_at: datetime
|
||||
last_login_ip: str = Field(sa_type=String(255))
|
||||
current_login_ip: str = Field(sa_type=String(255))
|
||||
login_count: int
|
||||
roles: list["Role"] = Relationship(back_populates="users",
|
||||
link_model=UserRoleLink)
|
||||
active: bool = True
|
||||
confirmed_at: datetime | None = None
|
||||
last_login_at: datetime | None = None
|
||||
current_login_at: datetime | None = None
|
||||
last_login_ip: str | None = Field(sa_type=String(255), default=None)
|
||||
current_login_ip: str | None = Field(sa_type=String(255), default=None)
|
||||
login_count: int = 0
|
||||
roles: list["Role"] = Relationship(back_populates="users", link_model=UserRoleLink)
|
||||
|
||||
def can_view(self, model) -> bool:
|
||||
role = getattr(model, 'viewable_role', None)
|
||||
role = getattr(model, "viewable_role", None)
|
||||
if role:
|
||||
return self.has_role(role)
|
||||
else:
|
||||
|
@ -54,22 +49,24 @@ class User(UserBase, table=True):
|
|||
def has_role(self, role: str) -> bool:
|
||||
return role in (role.name for role in self.roles)
|
||||
|
||||
|
||||
class RoleBase(SQLModel):
|
||||
name: str = Field(unique=True)
|
||||
|
||||
|
||||
class RoleWithDescription(RoleBase):
|
||||
description: str | None
|
||||
|
||||
|
||||
class Role(RoleWithDescription, table=True):
|
||||
__table_args__ = gisaf_admin.table_args
|
||||
id: int | None = Field(default=None, primary_key=True)
|
||||
users: list[User] = Relationship(back_populates="roles",
|
||||
link_model=UserRoleLink)
|
||||
users: list[User] = Relationship(back_populates="roles", link_model=UserRoleLink)
|
||||
|
||||
|
||||
class UserReadNoRoles(UserBase):
|
||||
id: int
|
||||
email: str | None # type: ignore
|
||||
email: str | None # type: ignore
|
||||
|
||||
|
||||
class RoleRead(RoleBase):
|
||||
|
@ -83,11 +80,11 @@ class RoleReadNoUsers(RoleBase):
|
|||
|
||||
class UserRead(UserBase):
|
||||
id: int
|
||||
email: str | None # type: ignore
|
||||
email: str | None # type: ignore
|
||||
roles: list[RoleReadNoUsers] = []
|
||||
|
||||
def can_view(self, model) -> bool:
|
||||
role = getattr(model, 'viewable_role', None)
|
||||
role = getattr(model, "viewable_role", None)
|
||||
if role:
|
||||
return self.has_role(role)
|
||||
else:
|
||||
|
@ -99,4 +96,5 @@ class UserRead(UserBase):
|
|||
|
||||
# class ACL(BaseModel):
|
||||
# user_id: int
|
||||
# role_ids: list[int]
|
||||
# role_ids: list[int]
|
||||
|
||||
|
|
|
@ -538,8 +538,7 @@ async def create_tags(features, keys, values):
|
|||
|
||||
return pd.concat(result)
|
||||
|
||||
from gisaf.utils import ToMigrate
|
||||
logger.warning(ToMigrate('plugins.change_feature_status (graphql)'))
|
||||
# TODO: Migrate('plugins.change_feature_status (graphql)'))
|
||||
change_feature_status = ChangeFeatureStatus(
|
||||
name='Change status',
|
||||
stores_by_re=[f"{conf.survey.db_schema}"],
|
||||
|
|
|
@ -48,6 +48,7 @@ expired_exception = HTTPException(
|
|||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def get_password_hash(password: str):
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
@ -55,22 +56,27 @@ def get_password_hash(password: str):
|
|||
async def delete_user(session: AsyncSession, username: str) -> None:
|
||||
user_in_db: User | None = await get_user(session, username)
|
||||
if user_in_db is None:
|
||||
raise SystemExit(f'User {username} does not exist in the database')
|
||||
raise SystemExit(f"User {username} does not exist in the database")
|
||||
await session.delete(user_in_db)
|
||||
|
||||
|
||||
async def enable_user(session: AsyncSession, username: str, enable=True):
|
||||
user_in_db = await get_user(session, username)
|
||||
if user_in_db is None:
|
||||
raise SystemExit(f'User {username} does not exist in the database')
|
||||
user_in_db.disabled = not enable # type: ignore
|
||||
raise SystemExit(f"User {username} does not exist in the database")
|
||||
user_in_db.disabled = not enable # type: ignore
|
||||
session.add(user_in_db)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def create_user(session: AsyncSession, username: str,
|
||||
password: str, full_name: str,
|
||||
email: str, **kwargs):
|
||||
async def create_user(
|
||||
session: AsyncSession,
|
||||
username: str,
|
||||
password: str,
|
||||
full_name: str,
|
||||
email: str,
|
||||
**kwargs,
|
||||
) -> User:
|
||||
user_in_db: User | None = await get_user(session, username)
|
||||
if user_in_db is None:
|
||||
user = User(
|
||||
|
@ -78,20 +84,23 @@ async def create_user(session: AsyncSession, username: str,
|
|||
password=get_password_hash(password),
|
||||
full_name=full_name,
|
||||
email=email,
|
||||
disabled=False
|
||||
disabled=False,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
return user
|
||||
else:
|
||||
user_in_db.full_name = full_name # type: ignore
|
||||
user_in_db.email = email # type: ignore
|
||||
user_in_db.password = get_password_hash(password) # type: ignore
|
||||
await session.commit()
|
||||
user_in_db.full_name = full_name # type: ignore
|
||||
user_in_db.email = email # type: ignore
|
||||
user_in_db.password = get_password_hash(password) # type: ignore
|
||||
await session.commit()
|
||||
return user_in_db
|
||||
|
||||
|
||||
async def get_user(
|
||||
session: AsyncSession,
|
||||
username: str) -> User | None:
|
||||
query = select(User).where(User.username==username).options(selectinload(User.roles))
|
||||
async def get_user(session: AsyncSession, username: str) -> User | None:
|
||||
query = (
|
||||
select(User).where(User.username == username).options(selectinload(User.roles))
|
||||
)
|
||||
data = await session.exec(query)
|
||||
return data.one_or_none()
|
||||
|
||||
|
@ -100,19 +109,21 @@ def verify_password(user: User, plain_password):
|
|||
try:
|
||||
return pwd_context.verify(plain_password, user.password)
|
||||
except UnknownHashError:
|
||||
logger.warning(f'Password not encrypted in DB for {user.username}, assuming it is stored in plain text')
|
||||
logger.warning(
|
||||
f"Password not encrypted in DB for {user.username}, assuming it is stored in plain text"
|
||||
)
|
||||
return plain_password == user.password
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme)) -> User | None:
|
||||
async def get_current_user(token: str = Depends(oauth2_scheme)) -> User | None:
|
||||
if token is None:
|
||||
return None
|
||||
try:
|
||||
payload = jwt.decode(token, conf.crypto.secret,
|
||||
algorithms=[conf.crypto.algorithm])
|
||||
username: str = payload.get("sub", '')
|
||||
if username == '':
|
||||
payload = jwt.decode(
|
||||
token, conf.crypto.secret, algorithms=[conf.crypto.algorithm]
|
||||
)
|
||||
username: str = payload.get("sub", "")
|
||||
if username == "":
|
||||
raise credentials_exception
|
||||
except ExpiredSignatureError:
|
||||
# raise expired_exception
|
||||
|
@ -139,7 +150,8 @@ async def authenticate_user(username: str, password: str):
|
|||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[UserRead, Depends(get_current_user)]):
|
||||
current_user: Annotated[UserRead, Depends(get_current_user)]
|
||||
):
|
||||
if current_user is not None and current_user.disabled:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
return current_user
|
||||
|
@ -149,7 +161,8 @@ def create_access_token(data: dict, expires_delta: timedelta):
|
|||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode,
|
||||
conf.crypto.secret,
|
||||
algorithm=conf.crypto.algorithm)
|
||||
return encoded_jwt
|
||||
encoded_jwt = jwt.encode(
|
||||
to_encode, conf.crypto.secret, algorithm=conf.crypto.algorithm
|
||||
)
|
||||
return encoded_jwt
|
||||
|
||||
|
|
|
@ -16,9 +16,6 @@ from sqlmodel import SQLModel, delete
|
|||
from gisaf.config import conf
|
||||
from gisaf.database import db_session
|
||||
|
||||
class ToMigrate(Exception):
|
||||
pass
|
||||
|
||||
SHAPELY_TYPE_TO_MAPBOX_TYPE = {
|
||||
'Point': 'symbol',
|
||||
'LineString': 'line',
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue