Basic registry, with survey stores
Move to the standard src/ directory layout
Versions: official SQLModel release, Pydantic v2
etc...
phil 2023-12-13 01:25:00 +05:30
parent 5494f6085f
commit 049b8c9927
31 changed files with 670 additions and 526 deletions

0
src/gisaf/__init__.py Normal file

1
src/gisaf/_version.py Normal file

@@ -0,0 +1 @@
__version__ = '2023.4.dev3+g5494f60.d20231212'

113
src/gisaf/api.py Normal file

@@ -0,0 +1,113 @@
import logging
from datetime import timedelta
from typing import Annotated
from fastapi import Depends, FastAPI, HTTPException, status, responses
from sqlalchemy.orm import selectinload
from fastapi.security import OAuth2PasswordRequestForm
from sqlmodel import select
from sqlmodel.ext.asyncio.session import AsyncSession
from .models.authentication import (
User, UserRead,
Role, RoleRead,
)
from .models.category import Category, CategoryRead
from .config import conf
from .models.bootstrap import BootstrapData
from .models.store import Store
from .database import get_db_session, pandas_query
from .security import (
Token,
authenticate_user, get_current_user, create_access_token,
)
from .registry import registry
logger = logging.getLogger(__name__)
api = FastAPI(
default_response_class=responses.ORJSONResponse,
)
#api.add_middleware(SessionMiddleware, secret_key=conf.crypto.secret)
db_session = Annotated[AsyncSession, Depends(get_db_session)]
@api.get('/bootstrap')
async def bootstrap(
user: Annotated[UserRead, Depends(get_current_user)]) -> BootstrapData:
return BootstrapData(user=user)
@api.post("/token")
async def login_for_access_token(
db_session: db_session,
form_data: OAuth2PasswordRequestForm = Depends()
) -> Token:
user = await authenticate_user(form_data.username, form_data.password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token = create_access_token(
data={"sub": user.username},
expires_delta=timedelta(seconds=conf.crypto.expire))
return {"access_token": access_token, "token_type": "bearer"}
@api.get("/list")
async def list_data_providers():
"""
Return a list of data providers, for use with the api (graphs, etc)
:return:
"""
return [{'name': m.__name__, 'store': m.get_store_name()}
for m in registry.values_for_model]
@api.get("/users")
async def get_users(
db_session: db_session,
) -> list[UserRead]:
query = select(User).options(selectinload(User.roles))
data = await db_session.exec(query)
return data.all()
@api.get("/roles")
async def get_roles(
db_session: db_session,
) -> list[RoleRead]:
query = select(Role).options(selectinload(Role.users))
data = await db_session.exec(query)
return data.all()
@api.get("/categories")
async def get_categories(
db_session: db_session,
) -> list[CategoryRead]:
query = select(Category)
data = await db_session.exec(query)
return data.all()
@api.get("/categories_pandas")
async def get_categories_p(
db_session: db_session,
) -> list[CategoryRead]:
query = select(Category)
df = await db_session.run_sync(pandas_query, query)
return df.to_dict(orient="records")
@api.get("/stores")
async def get_stores() -> list[Store]:
df = registry.stores.reset_index().drop(columns=['model', 'raw_model'])
return df.to_dict(orient="records")
# @api.get("/user-role")
# async def get_user_role_relation(
# *, db_session: AsyncSession = Depends(get_db_session)
# ) -> list[UserRoleLink]:
# roles = await db_session.exec(select(UserRoleLink))
# return roles.all()

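A minimal client sketch for the endpoints above, assuming the API is mounted under /v2 (see application.py below), a running server, and a valid Gisaf configuration; httpx, the port and the credentials are placeholders:
import httpx

with httpx.Client(base_url='http://localhost:8000/v2') as client:
    ## OAuth2 password flow: form-encoded credentials, bearer token back
    token = client.post('/token', data={
        'username': 'alice', 'password': 'secret'}).json()['access_token']
    headers = {'Authorization': f'Bearer {token}'}
    ## /bootstrap requires the token; /stores is open
    print(client.get('/bootstrap', headers=headers).json())
    print(client.get('/stores').json())
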
40
src/gisaf/application.py Normal file

@@ -0,0 +1,40 @@
from contextlib import asynccontextmanager
import logging
from typing import Any
#import colorama
#colorama.init()
from fastapi import FastAPI, responses
from .api import api
from .config import conf
from .registry import registry, ModelRegistry
logging.basicConfig(level=conf.gisaf.debugLevel)
logger = logging.getLogger(__name__)
## Subclass FastAPI to add attributes to be used globally, ie. registry
class GisafExtra:
registry: ModelRegistry
#raw_survey_models: dict[str, Any] = {}
#survey_models: dict[str, Any] = {}
class GisafFastAPI(FastAPI):
gisaf_extra: GisafExtra
@asynccontextmanager
async def lifespan(app: FastAPI):
await registry.make_registry(app)
yield
app = FastAPI(
debug=False,
title=conf.gisaf.title,
version=conf.version,
lifespan=lifespan,
default_response_class=responses.ORJSONResponse,
)
app.mount('/v2', api)

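A hypothetical launcher for the application above; uvicorn is assumed to be installed, and host and port are placeholders:
import uvicorn
from gisaf.application import app

if __name__ == '__main__':
    ## The lifespan context populates the model registry before serving
    uvicorn.run(app, host='127.0.0.1', port=8000)
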
331
src/gisaf/config.py Normal file

@@ -0,0 +1,331 @@
from os import environ
import logging
from pathlib import Path
from typing import Any, Type, Tuple
from pydantic_settings import (BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict)
from pydantic import ConfigDict
from pydantic.v1.utils import deep_update
from yaml import safe_load
from ._version import __version__
#from sqlalchemy.ext.asyncio.engine import AsyncEngine
#from sqlalchemy.orm.session import sessionmaker
logger = logging.getLogger(__name__)
ENV = environ.get('env', 'prod')
config_files = [
Path(Path.cwd().root) / 'etc' / 'gisaf' / ENV,
Path.home() / '.local' / 'gisaf' / ENV
]
class DashboardHome(BaseSettings):
title: str
content_file: str
footer_file: str
class GisafConfig(BaseSettings):
title: str
windowTitle: str
debugLevel: str
dashboard_home: DashboardHome
redirect: str = ''
use_pretty_errors: bool = False
class SpatialSysRef(BaseSettings):
author: str
ellps: str
k: int
lat_0: float
lon_0: float
no_defs: bool
proj: str
towgs84: str
units: str
x_0: float
y_0: float
class RawSurvey(BaseSettings):
spatial_sys_ref: SpatialSysRef
srid: int
class Geo(BaseSettings):
raw_survey: RawSurvey
simplify_geom_factor: int
srid: int
srid_for_proj: int
class Flask(BaseSettings):
secret_key: str
debug: int
class MQTT(BaseSettings):
broker: str = 'localhost'
port: int = 1883
class GisafLive(BaseSettings):
hostname: str
port: int
scheme: str
redis: str
mqtt: MQTT
class DefaultSurvey(BaseSettings):
surveyor_id: int
equipment_id: int
class Survey(BaseSettings):
model_config = ConfigDict(extra='ignore')
db_schema_raw: str
db_schema: str
default: DefaultSurvey
class Crypto(BaseSettings):
secret: str
algorithm: str
expire: float
class DB(BaseSettings):
uri: str
host: str
user: str
db: str
password: str
debug: bool
info: bool
pool_size: int = 10
max_overflow: int = 10
class Log(BaseSettings):
level: str
class OGCAPILicense(BaseSettings):
name: str
url: str
class OGCAPIProvider(BaseSettings):
name: str
url: str
class OGCAPIServerContact(BaseSettings):
name: str
address: str
city: str
stateorprovince: str
postalcode: int
country: str
email: str
class OGCAPIIdentification(BaseSettings):
title: str
description: str
keywords: list[str]
keywords_type: str
terms_of_service: str
url: str
class OGCAPIMetadata(BaseSettings):
identification: OGCAPIIdentification
license: OGCAPILicense
provider: OGCAPIProvider
contact: OGCAPIServerContact
class ServerBind(BaseSettings):
host: str
port: int
class OGCAPIServerMap(BaseSettings):
url: str
attribution: str
class OGCAPIServer(BaseSettings):
bind: ServerBind
url: str
mimetype: str
encoding: str
language: str
pretty_print: bool
limit: int
map: OGCAPIServerMap
class OGCAPI(BaseSettings):
base_url: str
bbox: list[float]
log: Log
metadata: OGCAPIMetadata
server: OGCAPIServer
class TileServer(BaseSettings):
baseDir: str
useRequestUrl: bool = False
spriteBaseDir: str
spriteUrl: str
spriteBaseUrl: str
openMapTilesKey: str | None = None
class Map(BaseSettings):
tileServer: TileServer | None = None
zoom: int
pitch: int
lat: float
lng: float
bearing: float
style: str
opacity: float
attribution: str
status: list[str]
defaultStatus: list[str] # FIXME: should be str
tagKeys: list[str]
class Measures(BaseSettings):
defaultStore: str
class BasketDefault(BaseSettings):
surveyor: str
equipment: str
project: str | None
status: str
store: str | None
class BasketOldDef(BaseSettings):
base_dir: str
class Basket(BaseSettings):
base_dir: str
default: BasketDefault
class Plot(BaseSettings):
maxDataSize: int
class Dashboard(BaseSettings):
base_source_url: str
base_storage_dir: str
base_storage_url: str
class Widgets(BaseSettings):
footer: str
class Admin(BaseSettings):
basket: Basket
class Attachments(BaseSettings):
base_dir: str
class Job(BaseSettings):
id: str
func: str
trigger: str
minutes: int | None = 0
seconds: int | None = 0
class Config(BaseSettings):
model_config = SettingsConfigDict(
#env_prefix='gisaf_',
env_nested_delimiter='__',
)
@classmethod
def settings_customise_sources(
cls,
settings_cls: Type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> Tuple[PydanticBaseSettingsSource, ...]:
return env_settings, init_settings, file_secret_settings, config_file_settings
admin: Admin
attachments: Attachments
basket: BasketOldDef
crypto: Crypto
dashboard: Dashboard
db: DB
flask: Flask
geo: Geo
gisaf: GisafConfig
gisaf_live: GisafLive
jobs: list[Job]
map: Map
measures: Measures
ogcapi: OGCAPI
plot: Plot
plugins: dict[str, dict[str, Any]]
survey: Survey
version: str
weather_station: dict[str, dict[str, Any]]
widgets: Widgets
#engine: AsyncEngine
#session_maker: sessionmaker
def config_file_settings() -> dict[str, Any]:
config: dict[str, Any] = {}
for p in config_files:
for suffix in {".yaml", ".yml"}:
path = p.with_suffix(suffix)
if not path.is_file():
logger.info(f"No file found at `{path.resolve()}`")
continue
logger.debug(f"Reading config file `{path.resolve()}`")
if path.suffix in {".yaml", ".yml"}:
config = deep_update(config, load_yaml(path))
else:
logger.info(f"Unknown config file extension `{path.suffix}`")
return config
def load_yaml(path: Path) -> dict[str, Any]:
with Path(path).open("r") as f:
config = safe_load(f)
if not isinstance(config, dict):
raise TypeError(
f"Config file has no top-level mapping: {path}"
)
return config
conf = Config(version=__version__)
# def set_app_config(app) -> None:
# raw_configs = []
# with open(Path(__file__).parent / 'defaults.yml') as cf:
# raw_configs.append(cf.read())
# for cf_path in (
# Path(Path.cwd().root) / 'etc' / 'gisaf' / ENV,
# Path.home() / '.local' / 'gisaf' / ENV
# ):
# try:
# with open(cf_path.with_suffix('.yml')) as cf:
# raw_configs.append(cf.read())
# except FileNotFoundError:
# pass
# yaml_config = safe_load('\n'.join(raw_configs))
# conf.app = yaml_config['app']
# conf.postgres = yaml_config['postgres']
# conf.storage = yaml_config['storage']
# conf.map = yaml_config['map']
# conf.security = yaml_config['security']
# # create_dirs()
# def create_dirs():
# """
# Create the directories needed for a proper functioning of the app
# """
# ## Avoid circular imports
# from treetrail.api_v1 import attachment_types
# for type in attachment_types:
# base_dir = Path(conf.storage['root_attachment_path']) / type
# base_dir.mkdir(parents=True, exist_ok=True)
# logger.info(f'Cache dir: {get_cache_dir()}')
# get_cache_dir().mkdir(parents=True, exist_ok=True)
# def get_cache_dir() -> Path:
# return Path(conf.storage['root_cache_path'])

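A small standalone sketch of the merge done by config_file_settings: later files (under ~/.local/gisaf) override earlier ones (under /etc/gisaf) key by key, thanks to deep_update; the values are placeholders:
from pydantic.v1.utils import deep_update

system_cfg = {'db': {'host': 'db.example.org', 'pool_size': 10}}
user_cfg = {'db': {'pool_size': 2}}
## Nested mappings are merged key by key, not replaced wholesale
assert deep_update(system_cfg, user_cfg) == {
    'db': {'host': 'db.example.org', 'pool_size': 2}}
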
30
src/gisaf/database.py Normal file

@@ -0,0 +1,30 @@
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import create_async_engine
from sqlmodel.ext.asyncio.session import AsyncSession
import pandas as pd
from .config import conf
echo = False
pg_url = "postgresql+asyncpg://avgis@localhost/avgis"
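## FIXME: hardcoded connection URL; conf.db (uri, host, user, db, password)
## is available and would normally be used to build it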
engine = create_async_engine(
pg_url,
echo=echo,
pool_size=conf.db.pool_size,
max_overflow=conf.db.max_overflow,
)
async def get_db_session() -> AsyncSession:
async with AsyncSession(engine) as session:
yield session
@asynccontextmanager
async def db_session() -> AsyncSession:
async with AsyncSession(engine) as session:
yield session
def pandas_query(session, query):
return pd.read_sql_query(query, session.connection())

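A usage sketch for the helpers above (hypothetical script; assumes a reachable database, a valid Gisaf configuration, and the User model defined below):
import asyncio
from sqlmodel import select
from gisaf.database import db_session, pandas_query
from gisaf.models.authentication import User

async def main():
    async with db_session() as session:
        ## ORM-style result set
        users = (await session.exec(select(User))).all()
        ## The same query as a pandas DataFrame, via the sync bridge
        df = await session.run_sync(pandas_query, select(User))
        print(len(users), list(df.columns))

asyncio.run(main())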

59
src/gisaf/models/authentication.py Normal file

@@ -0,0 +1,59 @@
from sqlmodel import Field, SQLModel, MetaData, Relationship
from .metadata import gisaf_admin
class UserRoleLink(SQLModel, table=True):
metadata = gisaf_admin
__tablename__: str = 'roles_users'
user_id: int | None = Field(
default=None, foreign_key="user.id", primary_key=True
)
role_id: int | None = Field(
default=None, foreign_key="role.id", primary_key=True
)
class UserBase(SQLModel):
username: str
email: str
class User(UserBase, table=True):
metadata = gisaf_admin
id: int | None = Field(default=None, primary_key=True)
roles: list["Role"] = Relationship(back_populates="users",
link_model=UserRoleLink)
password: str | None = None
class RoleBase(SQLModel):
name: str = Field(unique=True)
class RoleWithDescription(RoleBase):
description: str | None
class Role(RoleWithDescription, table=True):
metadata = gisaf_admin
id: int | None = Field(default=None, primary_key=True)
users: list[User] = Relationship(back_populates="roles",
link_model=UserRoleLink)
class UserReadNoRoles(UserBase):
id: int
email: str | None
class RoleRead(RoleBase):
id: int
users: list[UserReadNoRoles] = []
class RoleReadNoUsers(RoleBase):
id: int
class UserRead(UserBase):
id: int
email: str | None
roles: list[RoleReadNoUsers] = []

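A quick sketch of why the paired read models exist: UserRead embeds RoleReadNoUsers and RoleRead embeds UserReadNoRoles, which breaks the User/Role cycle when serializing; the values are made up:
from gisaf.models.authentication import RoleReadNoUsers, UserRead

user = UserRead(id=1, username='alice', email='alice@example.org',
                roles=[RoleReadNoUsers(id=1, name='admin')])
## Serializes cleanly: roles carry no nested user list
print(user.model_dump())
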
18
src/gisaf/models/bootstrap.py Normal file

@@ -0,0 +1,18 @@
from pydantic import BaseModel
from ..config import conf, Map, Measures, Geo
from .authentication import UserRead
class Proj(BaseModel):
srid: str
srid_for_proj: str
class BootstrapData(BaseModel):
version: str = conf.version
title: str = conf.gisaf.title
windowTitle: str = conf.gisaf.windowTitle
map: Map = conf.map
geo: Geo = conf.geo
measures: Measures = conf.measures
redirect: str = conf.gisaf.redirect
user: UserRead | None = None

121
src/gisaf/models/category.py Normal file

@@ -0,0 +1,121 @@
from typing import Any, ClassVar
from pydantic import computed_field, ConfigDict
from sqlmodel import Field, Relationship, SQLModel, JSON, TEXT, Column, select
from .metadata import gisaf_survey
from ..database import db_session, pandas_query
mapbox_type_mapping = {
'Point': 'symbol',
'Line': 'line',
'Polygon': 'fill',
}
class BaseModel(SQLModel):
@classmethod
async def get_df(cls):
async with db_session() as session:
query = select(cls)
return await session.run_sync(pandas_query, query)
class CategoryGroup(BaseModel, table=True):
metadata = gisaf_survey
__tablename__ = 'category_group'
name: str | None = Field(min_length=4, max_length=4,
default=None, primary_key=True)
major: str
long_name: str
categories: list['Category'] = Relationship(back_populates='category_group')
class Admin:
menu = 'Other'
flask_admin_model_view = 'CategoryGroupModelView'
class CategoryModelType(BaseModel, table=True):
metadata = gisaf_survey
__tablename__ = 'category_model_type'
name: str = Field(default=None, primary_key=True)
class Admin:
menu = 'Other'
flask_admin_model_view = 'MyModelViewWithPrimaryKey'
class CategoryBase(BaseModel):
model_config = ConfigDict(protected_namespaces=())
class Admin:
menu = 'Other'
flask_admin_model_view = 'CategoryModelView'
name: str | None = Field(default=None, primary_key=True)
domain: ClassVar[str] = 'V'
description: str | None
group: str = Field(min_length=4, max_length=4,
foreign_key="category_group.name", index=True)
minor_group_1: str = Field(min_length=4, max_length=4, default='----')
minor_group_2: str = Field(min_length=4, max_length=4, default='----')
status: str = Field(min_length=1, max_length=1)
custom: bool | None
auto_import: bool = True
model_type: str = Field(max_length=50,
foreign_key='category_model_type.name',
default='Point')
long_name: str | None = Field(max_length=50)
style: str | None = Field(sa_type=TEXT)
symbol: str | None = Field(max_length=1)
mapbox_type_custom: str | None = Field(max_length=32)
mapbox_paint: dict[str, Any] | None = Field(sa_type=JSON(none_as_null=True))
mapbox_layout: dict[str, Any] | None = Field(sa_type=JSON(none_as_null=True))
viewable_role: str | None
extra: dict[str, Any] | None = Field(sa_type=JSON(none_as_null=True))
@computed_field
@property
def layer_name(self) -> str:
"""
ISO compliant layer name (see ISO 13567)
:return: str
"""
return '{self.domain}-{self.group:4s}-{self.minor_group_1:4s}-{self.minor_group_2:4s}-{self.status:1s}'.format(self=self)
@computed_field
@property
def table_name(self) -> str:
"""
Table name
:return:
"""
if self.minor_group_2 == '----':
return '{self.domain}_{self.group:4s}_{self.minor_group_1:4s}'.format(self=self)
else:
return '{self.domain}_{self.group:4s}_{self.minor_group_1:4s}_{self.minor_group_2:4s}'.format(self=self)
@computed_field
@property
def raw_survey_table_name(self) -> str:
"""
Table name
:return:
"""
if self.minor_group_2 == '----':
return 'RAW_{self.domain}_{self.group:4s}_{self.minor_group_1:4s}'.format(self=self)
else:
return 'RAW_{self.domain}_{self.group:4s}_{self.minor_group_1:4s}_{self.minor_group_2:4s}'.format(self=self)
@computed_field
@property
def mapbox_type(self) -> str:
return self.mapbox_type_custom or mapbox_type_mapping[self.model_type]
class Category(CategoryBase, table=True):
metadata = gisaf_survey
name: str = Field(default=None, primary_key=True)
category_group: CategoryGroup = Relationship(back_populates="categories")
class CategoryRead(CategoryBase):
name: str

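A sketch of the computed names defined above, with made-up category values (the required-but-nullable fields are passed explicitly as None):
from gisaf.models.category import Category

cat = Category(
    name='TREE', description='Single trees', group='VEGT',
    minor_group_1='SNGL', status='E', custom=False,
    long_name='Single tree', style=None, symbol=None,
    mapbox_type_custom=None, mapbox_paint=None, mapbox_layout=None,
    viewable_role=None, extra=None,
)
assert cat.layer_name == 'V-VEGT-SNGL------E'  ## domain 'V', default minor_group_2
assert cat.table_name == 'V_VEGT_SNGL'
assert cat.raw_survey_table_name == 'RAW_V_VEGT_SNGL'
assert cat.mapbox_type == 'symbol'             ## falls back on model_type 'Point'
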
src/gisaf/models/geo_models_base.py Normal file

File diff suppressed because it is too large

66
src/gisaf/models/map_bases.py Normal file

@@ -0,0 +1,66 @@
from typing import Any
from sqlmodel import Field, String, JSON, Column
from .models_base import Model
from .metadata import gisaf_map
class BaseStyle(Model):
metadata = gisaf_map
__tablename__ = 'map_base_style'
class Admin:
menu = 'Other'
flask_admin_model_view = 'MapBaseStyleModelView'
id: int = Field(primary_key=True)
name: str
style: dict[str, Any] | None = Field(sa_type=JSON(none_as_null=True))
mbtiles: str = Field(sa_type=String(50))
static_tiles_url: str
enabled: bool = True
def __repr__(self):
return '<models.BaseStyle {self.name:s}>'.format(self=self)
class BaseMap(Model):
metadata = gisaf_map
__tablename__ = 'base_map'
class Admin:
menu = 'Other'
id: int = Field(primary_key=True)
name: str
def __repr__(self):
return '<models.BaseMap {self.name:s}>'.format(self=self)
def __str__(self):
return self.name
class BaseMapLayer(Model):
metadata = gisaf_map
__tablename__ = 'base_map_layer'
class Admin:
menu = 'Other'
id: int = Field(primary_key=True)
base_map_id: int = Field(foreign_key='base_map.id', index=True)
store: str = Field(sa_type=String(100))
@classmethod
def dyn_join_with(cls):
return {
'base_map': BaseMap,
}
def __repr__(self):
return f"<models.BaseMapLayer {self.store or '':s}>"
def __str__(self):
return f"{self.store or '':s}"

10
src/gisaf/models/metadata.py Normal file

@@ -0,0 +1,10 @@
from sqlmodel import MetaData
from ..config import conf
gisaf = MetaData(schema='gisaf')
gisaf_survey = MetaData(schema='gisaf_survey')
gisaf_admin = MetaData(schema='gisaf_admin')
gisaf_map = MetaData(schema='gisaf_map')
raw_survey = MetaData(schema=conf.survey.db_schema_raw)
survey = MetaData(schema=conf.survey.db_schema)

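A sketch of how these MetaData objects route tables to PostgreSQL schemas; the model below is hypothetical:
from sqlmodel import Field, SQLModel
from gisaf.models.metadata import gisaf_admin

class Widget(SQLModel, table=True):   ## hypothetical example model
    metadata = gisaf_admin
    id: int | None = Field(default=None, primary_key=True)

assert Widget.__table__.schema == 'gisaf_admin'
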
37
src/gisaf/models/misc.py Normal file

@@ -0,0 +1,37 @@
import logging
from typing import Any
from pydantic import ConfigDict
from sqlmodel import Field, JSON, Column
from .models_base import Model
from .metadata import gisaf_map
logger = logging.getLogger(__name__)
class NotADataframeError(Exception):
pass
class Qml(Model):
"""
Model for storing qml (QGis style)
"""
model_config = ConfigDict(protected_namespaces=())
metadata = gisaf_map
class Admin:
menu = 'Other'
flask_admin_model_view = 'QmlModelView'
model_name: str = Field(default=None, primary_key=True)
qml: str
attr: str
style: str
mapbox_paint: dict[str, Any] | None = Field(sa_type=JSON(none_as_null=True))
mapbox_layout: dict[str, Any] | None = Field(sa_type=JSON(none_as_null=True))
def __repr__(self):
return '<models.Qml {self.model_name:s}>'.format(self=self)

118
src/gisaf/models/models_base.py Normal file

@@ -0,0 +1,118 @@
from typing import Any
import logging
from sqlmodel import Field, SQLModel, MetaData, JSON, TEXT, Relationship, Column
from pydantic import computed_field
import numpy as np
import pandas as pd
import geopandas as gpd
import shapely
from sqlalchemy.sql import sqltypes
from geoalchemy2.types import Geometry
pandas_cast_map = {
sqltypes.Integer: 'Int64',
sqltypes.Float: 'float64',
}
logger = logging.getLogger('model_base_base')
class Model(SQLModel):
"""
Base mixin class for models that can be converted to a Pandas dataframe with get_df
"""
class Meta:
filtered_columns_on_map: list[str] = []
@classmethod
def get_store_name(cls):
return "{}.{}".format(cls.metadata.schema, cls.__tablename__)
@classmethod
def get_table_name_prefix(cls):
return "{}_{}".format(cls.metadata.schema, cls.__tablename__)
@classmethod
async def get_df(cls, where=None,
with_related=None, recursive=True,
cast=True,
with_only_columns=None,
geom_as_ewkt=False,
**kwargs):
"""
Return a Pandas dataframe of all records
Optional arguments:
* an SQLAlchemy where clause
* with_related: automatically get data from related columns, following the foreign keys in the model definitions
* cast: automatically transform various data in their best python types (eg. with date, time...)
* with_only_columns: fetch only these columns (list of column names)
* geom_as_ewkt: convert geometry columns to EWKB (handy for eg. using upsert_df)
:return:
"""
query = cls.query
if with_related is not False:
if with_related or getattr(cls, 'get_gdf_with_related', False):
joins = get_join_with(cls, recursive)
model_loader = cls.load(**joins)
query = _get_query_with_table_names(model_loader)
if where is not None:
query.append_whereclause(where)
if with_only_columns:
query = query.with_only_columns([getattr(cls, colname) for colname in with_only_columns])
## Got idea from https://github.com/MagicStack/asyncpg/issues/173.
async with query.bind.raw_pool.acquire() as conn:
## Convert hstore fields to dict
await conn.set_builtin_type_codec('hstore', codec_name='pg_contrib.hstore')
compiled = query.compile()
stmt = await conn.prepare(compiled.string)
columns = [a.name for a in stmt.get_attributes()]
data = await stmt.fetch(*[compiled.params.get(param) for param in compiled.positiontup])
df = pd.DataFrame(data, columns=columns)
## Convert primary key columns to Int64:
## allows NaN, fixing type conversion to float with merge
for pk in [c.name for c in cls.__table__.primary_key.columns]:
if pk in df.columns and df[pk].dtype=='int64':
df[pk] = df[pk].astype('Int64')
if cast:
## Cast the type for known types (datetime, ...)
for column_name in df.columns:
col = getattr(query.columns, column_name, None)
if col is None:
logger.debug(f'Cannot get column {column_name} in query for model {cls.__name__}')
continue
column_type = getattr(query.columns, column_name).type
## XXX: Needs refinment, eg. nullable -> Int64 ...
if column_type.__class__ in pandas_cast_map:
df[column_name] = df[column_name].astype(pandas_cast_map[column_type.__class__])
elif isinstance(column_type, (sqltypes.Date, sqltypes.DateTime)):
## Dates, times
df[column_name] = pd.to_datetime(df[column_name])
#elif isinstance(column_type, (sqltypes.Integer, sqltypes.Float)):
# ## Numeric
# df[column_name] = pd.to_numeric(df[column_name], errors='coerce')
## XXX: keeping this note about the "char" SQL type, but the fix of #9694 makes it unnecessary
#elif isinstance(column_type, sqltypes.CHAR) or (isinstance(column_type, sqltypes.String) and column_type.length == 1):
# ## Workaround for bytes being used for string of length 1 (not sure - why???)
# df[column_name] = df[column_name].str.decode('utf-8')
## Rename the columns, removing the schema_table prefix for the columns in that model
prefix = cls.get_table_name_prefix()
prefix_length = len(prefix) + 1
rename_map = {colname: colname[prefix_length:] for colname in df.columns if colname.startswith(prefix)}
df.rename(columns=rename_map, inplace=True)
## Eventually convert geometry columns to EWKB
if geom_as_ewkt:
geometry_columns = [col.name for col in cls.__table__.columns if isinstance(col.type, Geometry)]
for column in geometry_columns:
df[column] = shapely.to_wkb(shapely.from_wkb(df[column]), hex=True, include_srid=True)
return df

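The prefix-stripping step near the end of get_df can be tried in isolation; the prefix below is a made-up example of get_table_name_prefix() output:
import pandas as pd

prefix = 'gisaf_admin_project'   ## '<schema>_<tablename>' (example value)
df = pd.DataFrame(columns=[f'{prefix}_id', f'{prefix}_name', 'other_id'])
prefix_length = len(prefix) + 1
rename_map = {col: col[prefix_length:]
              for col in df.columns if col.startswith(prefix)}
assert list(df.rename(columns=rename_map).columns) == ['id', 'name', 'other_id']
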
225
src/gisaf/models/project.py Normal file

@@ -0,0 +1,225 @@
from datetime import datetime
from csv import writer
from collections import defaultdict
from io import BytesIO, StringIO
from sqlmodel import Field, SQLModel, MetaData, JSON, TEXT, Relationship, Column
import pyproj
from shapely.geometry import Point
from ..config import conf
from .models_base import Model
from .metadata import gisaf_admin
class Project(Model):
metadata = gisaf_admin
class Admin:
menu = 'Other'
flask_admin_model_view = 'ProjectModelView'
id: int = Field(default=None, primary_key=True)
name: str
contact_person: str
site: str
date_approved: datetime
start_date_planned: datetime
start_date_effective: datetime
end_date_planned: datetime
end_date_effective: datetime
def __str__(self):
return '{self.name:s}'.format(self=self)
def __repr__(self):
return '<models.Project {self.name:s}>'.format(self=self)
async def auto_import(self, registry):
"""
Import the points of the given project to the GIS DB
in their appropriate models.raw_survey_models
:return: dict of result (stats)
"""
from .category import Category
from .geo_models_base import GeoPointSurveyModel
result = defaultdict(int)
categories = {cat.table_name: cat for cat in await Category.query.gino.all()}
## Define projections
survey_proj = pyproj.Proj(**conf.geo.raw_survey.spatial_sys_ref.model_dump())
target_proj = pyproj.Proj(f'epsg:{conf.geo.srid:d}')
def reproject(x, y, z):
return pyproj.transform(survey_proj, target_proj, x, y, z)
## TODO: Gino session
for survey_model_name, raw_survey_model in registry.raw_survey_models.items():
category = categories[survey_model_name]
if not category.auto_import:
continue
survey_model = registry.geom_auto.get(survey_model_name)
if not survey_model:
continue
if not issubclass(survey_model, GeoPointSurveyModel):
continue
raw_survey_items = await raw_survey_model.query.where(raw_survey_model.project_id == self.id).gino.all()
for item in raw_survey_items:
if not item:
continue
new_point = survey_model(
id=item.id,
date=item.date,
accur_id=item.accur_id,
equip_id=item.equip_id,
srvyr_id=item.srvyr_id,
orig_id=item.orig_id,
status=item.status,
project_id=self.id,
)
geom = Point(*reproject(item.easting, item.northing, item.elevation))
new_point.geom = 'SRID={:d};{:s}'.format(conf.geo.srid, geom.wkb_hex)
## TODO: merge with Gino
#session.merge(new_point)
result[survey_model_name] += 1
#session.commit()
return result
async def download_raw_survey_data(self):
from .raw_survey import RawSurveyModel
## FIXME: old query style
breakpoint()
raw_survey_items = await RawSurveyModel.query.where(RawSurveyModel.project_id == self.id).gino.all()
csv_file = StringIO()
csv_writer = writer(csv_file)
for item in raw_survey_items:
csv_writer.writerow(item.to_row())
now = '{:%Y-%m-%d_%H:%M}'.format(datetime.now())
## XXX: not tested (aiohttp)
#return send_file(BytesIO(bytes(csv_file.getvalue(), 'utf-8')),
# attachment_filename='{:s}-{:s}.csv'.format(self.name, now),
# mimetype='text/csv',
# as_attachment=True)
headers = {
'Content-Disposition': 'attachment; filename="{}"'.format('{:s}-{:s}.csv'.format(self.name, now))
}
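## FIXME: 'web' is aiohttp's namespace and is not imported in this module;
## this legacy response (and the one below) needs porting to FastAPI/Starlette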
return web.Response(
status=200,
headers=headers,
content_type='text/csv',
body=BytesIO(bytes(csv_file.getvalue(), 'utf-8'))
)
async def download_reconciled_raw_survey_data(self, registry):
csv_file = StringIO()
csv_writer = writer(csv_file)
for model_name, model in registry.raw_survey_models.items():
survey_items = await model.query.where(model.project_id == self.id).gino.all()
for item in survey_items:
csv_writer.writerow(item.to_row())
now = '{:%Y-%m-%d_%H:%M}'.format(datetime.now())
## XXX: not tested (aiohttp)
#return send_file(BytesIO(bytes(csv_file.getvalue(), 'utf-8')),
# attachment_filename='{:s}-{:s}-reconciled.csv'.format(self.name, now),
# mimetype='text/csv',
# as_attachment=True)
headers = {
'Content-Disposition': 'attachment; filename="{}"'.format('{:s}-{:s}.csv'.format(self.name, now))
}
return web.Response(
status=200,
headers=headers,
content_type='text/csv',
body=BytesIO(bytes(csv_file.getvalue(), 'utf-8'))
)
async def reconcile(self, registry):
from gisaf.models.reconcile import Reconciliation
result = {}
all_reconciliations = await Reconciliation.query.gino.all()
point_ids_to_reconcile = {p.id: registry.raw_survey_models[p.target]
for p in all_reconciliations
if p.target in registry.raw_survey_models}
result['bad target'] = set([p.target for p in all_reconciliations
if p.target not in registry.raw_survey_models])
result['from'] = defaultdict(int)
result['to'] = defaultdict(int)
result['unchanged'] = defaultdict(int)
## TODO: Gino session
for model_name, model in registry.raw_survey_models.items():
points_to_reconcile = await model.query.\
where(model.project_id==self.id).\
where(model.id.in_(point_ids_to_reconcile.keys())).gino.all()
for point in points_to_reconcile:
new_model = point_ids_to_reconcile[point.id]
if new_model == model:
result['unchanged'][model] += 1
continue
new_point = new_model(
id=point.id,
accur_id=point.accur_id,
srvyr_id=point.accur_id,
project_id=point.project_id,
status=point.status,
orig_id=point.orig_id,
equip_id=point.equip_id,
geom=point.geom,
date=point.date
)
## TODO: Gino add and delete
#session.add(new_point)
#session.delete(point)
result['from'][point.__class__] += 1
result['to'][new_point.__class__] += 1
return result
# def download_raw_survey_data(self, session=None):
# from gisaf.models.raw_survey_models import RawSurvey
# from gisaf.registry import registry
# if not session:
# session = db.session
# raw_survey_items = session.query(RawSurvey).filter(RawSurvey.project_id == self.id).all()
# csv_file = StringIO()
# csv_writer = writer(csv_file)
#
# SURVEY_PROJ = pyproj.Proj(**conf.raw_survey['spatial_sys_ref'])
# TARGET_PROJ = pyproj.Proj(init='epsg:{:d}'.format(conf.srid))
#
# def reproject(x, y, z):
# return pyproj.transform(SURVEY_PROJ, TARGET_PROJ, x, y, z)
#
# for item in raw_survey_items:
# csv_writer.writerow(item.to_row())
#
# ## Add import of points, incl. reprojection, to registry.raw_survey_models:
# new_coords = reproject(item.easting, item.northing, item.elevation)
# geom = Point(*new_coords)
# ## TODO: from here
# model = registry.raw_survey_models
# new_point = model(
# id=item.id,
# category=item.category,
# date=item.date,
# accur_id=item.accur_id,
# equip_id=item.equip_id,
# srvyr_id=item.srvyr_id,
# orig_id=item.original_id,
# )
# new_point.geom = 'SRID={:d};{:s}'.format(conf.srid, geom.wkb_hex)
# session.merge(new_point)
# result[item.category_info] += 1
#
# now = '{:%Y-%m-%d_%H:%M}'.format(datetime.now())
#
# return send_file(BytesIO(bytes(csv_file.getvalue(), 'utf-8')),
# attachment_filename='{:s}-{:s}.csv'.format(self.name, now),
# mimetype='text/csv',
# as_attachment=True)

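The reprojection above still uses the deprecated pyproj.Proj/transform API. For reference, an equivalent standalone sketch with the current Transformer API; the CRS codes and coordinates are placeholders:
from pyproj import Transformer

## Build the transformer once and reuse it for every point
## (always_xy keeps the easting/northing axis order)
transformer = Transformer.from_crs('epsg:32644', 'epsg:4326', always_xy=True)
lon, lat = transformer.transform(380000.0, 1320000.0)
print(lon, lat)
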
97
src/gisaf/models/raw_survey.py Normal file

@@ -0,0 +1,97 @@
from typing import ClassVar
from sqlmodel import Field, BigInteger
from .models_base import Model
from .geo_models_base import GeoPointMModel, BaseSurveyModel
from .project import Project
from .category import Category
from .metadata import gisaf_survey
class RawSurveyModel(BaseSurveyModel, GeoPointMModel):
metadata = gisaf_survey
__tablename__ = 'raw_survey'
hidden: ClassVar[bool] = True
id: int = Field(default=None, primary_key=True)
project_id: int | None = Field(foreign_key='project.id')
category: str = Field(foreign_key='category.name')
in_menu: bool = False
@classmethod
def dyn_join_with(cls):
return {
'project': Project.on(cls.project_id == Project.id),
'category_info': Category.on(cls.category == Category.name),
}
#id = db.Column(db.BigInteger, primary_key=True)
## XXX: Can remove the rest since it is in the GeoPointSurveyModel class?
#geom = db.Column(Geometry('POINTZ', srid=conf.raw_survey_srid))
#date = db.Column(db.Date)
#orig_id = db.Column(db.String)
#status = db.Column(db.String(1))
def __str__(self):
return 'Raw Survey point id {:d}'.format(self.id)
def to_row(self):
"""
Get a list of attributes, typically used for exporting in CSV
:return: list of attributes
"""
return [
self.id,
self.easting,
self.northing,
self.elevation,
self.category,
self.surveyor,
self.equipment,
self.date.isoformat(),
self.accuracy.name,
self.category_info.status,
self.project.name,
self.orig_id
]
def auto_import(self, session, model=None, status=None):
"""
Automatically feed the matching raw survey model, as given by get_raw_survey_model_mapping
:return:
"""
if model is None:
# XXX: move as module import?
from gisaf.registry import registry
model = registry.get_raw_survey_model_mapping().get(self.category)
new_point = model(
id=self.id,
geom=self.geom,
date=self.date,
project_id=self.project_id,
equip_id=self.equip_id,
srvyr_id=self.srvyr_id,
accur_id=self.accur_id,
orig_id=self.orig_id,
status=status,
)
session.merge(new_point)
class OriginRawPoint(Model):
"""
Store information of the raw survey point used in the line work
for each line and polygon shape.
Filled when importing shapefiles.
"""
metadata = gisaf_survey
__tablename__ = 'origin_raw_point'
id: int = Field(default=None, primary_key=True)
shape_table: str = Field(index=True)
shape_id: int = Field(index=True)
raw_point_id: int = Field(sa_type=BigInteger())
def __repr__(self):
return f'<models.OriginRawPoint {self.id:d} {self.shape_table:s} ' \
f'{self.shape_id:d} {self.raw_point_id:d}>'

43
src/gisaf/models/reconcile.py Normal file

@@ -0,0 +1,43 @@
from datetime import datetime
from sqlalchemy import BigInteger
from sqlmodel import Field, SQLModel, MetaData, JSON, TEXT, Relationship, Column, String
from .models_base import Model
from .metadata import gisaf_admin
class Reconciliation(Model):
metadata = gisaf_admin
class Admin:
menu = 'Other'
flask_admin_model_view = 'ReconciliationModelView'
id: int = Field(primary_key=True, sa_type=BigInteger,
sa_column_kwargs={'autoincrement': False})
target: str = Field(sa_type=String(50))
source: str = Field(sa_type=String(50))
class StatusChange(Model):
metadata = gisaf_admin
__tablename__ = 'status_change'
id: int = Field(primary_key=True, sa_type=BigInteger,
sa_column_kwargs={'autoincrement': False})
store: str = Field(sa_type=String(50))
ref_id: int = Field(sa_type=BigInteger())
original: str = Field(sa_type=String(1))
new: str = Field(sa_type=String(1))
time: datetime
class FeatureDeletion(Model):
metadata = gisaf_admin
__tablename__ = 'feature_deletion'
id: int = Field(primary_key=True, sa_type=BigInteger,
sa_column_kwargs={'autoincrement': False})
store: str = Field(sa_type=String(50))
ref_id: int = Field(sa_type=BigInteger())
time: datetime

43
src/gisaf/models/store.py Normal file

@@ -0,0 +1,43 @@
from typing import Any
from pydantic import BaseModel
from .geo_models_base import GeoModel, RawSurveyBaseModel, GeoPointSurveyModel
class MapLibreStyle(BaseModel):
...
class Store(BaseModel):
auto_import: bool
base_gis_type: str
count: int
custom: bool
description: str
#extra: dict[str, Any] | None
group: str
#icon: str
in_menu: bool
is_db: bool
is_line_work: bool
is_live: bool
long_name: str | None
#mapbox_layout: dict[str, Any] | None
#mapbox_paint: dict[str, Any] | None
#mapbox_type: str
mapbox_type_custom: str | None
#mapbox_type_default: str
minor_group_1: str
minor_group_2: str
#model: GeoModel
model_type: str
name: str
#name_letter: str
#name_number: int
#raw_model: GeoPointSurveyModel
#raw_model_store_name: str
status: str
store: str
style: str | None
symbol: str | None
title: str
viewable_role: str | None
z_index: int

84
src/gisaf/models/survey.py Normal file

@@ -0,0 +1,84 @@
from enum import Enum
from sqlmodel import Field, SQLModel
from .models_base import Model
from .metadata import gisaf_survey
class Accuracy(Model):
metadata = gisaf_survey
class Admin:
menu = 'Other'
flask_admin_model_view = 'MyModelViewWithPrimaryKey'
id: int = Field(default=None, primary_key=True)
name: str
accuracy: float
def __str__(self):
return f'{self.name} {self.accuracy}'
def __repr__(self):
return f'<models.Accuracy {self.name}>'
class Surveyor(Model):
metadata = gisaf_survey
class Admin:
menu = 'Other'
flask_admin_model_view = 'MyModelViewWithPrimaryKey'
id: int = Field(default=None, primary_key=True)
name: str
def __str__(self):
return self.name
def __repr__(self):
return f'<models.Surveyor {self.name}>'
class Equipment(Model):
metadata = gisaf_survey
class Admin:
menu = 'Other'
flask_admin_model_view = 'MyModelViewWithPrimaryKey'
id: int = Field(default=None, primary_key=True)
name: str
def __str__(self):
return self.name
def __repr__(self):
return f'<models.Equipment {self.name}>'
class GeometryType(str, Enum):
point = 'Point'
line_work = 'Line_work'
class AccuracyEquimentSurveyorMapping(Model):
metadata = gisaf_survey
__tablename__ = 'accuracy_equiment_surveyor_mapping'
class Admin:
menu = 'Other'
id: int = Field(default=None, primary_key=True)
surveyor_id: int = Field(foreign_key='surveyor.id', index=True)
equipment_id: int = Field(foreign_key='equipment.id', index=True)
geometry_type: GeometryType = Field(default='Point', index=True)
accuracy_id: int = Field(foreign_key='accuracy.id')
@classmethod
def dyn_join_with(cls):
return {
'surveyor': Surveyor,
'equipment': Equipment,
'accuracy': Accuracy,
}

45
src/gisaf/models/tags.py Normal file

@@ -0,0 +1,45 @@
from typing import Any, ClassVar
from sqlalchemy import BigInteger
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.dialects.postgresql import HSTORE
from sqlmodel import Field, SQLModel, MetaData, JSON, TEXT, Relationship, Column
from pydantic import computed_field
from .metadata import gisaf
from .geo_models_base import GeoPointModel
class Tags(GeoPointModel, table=True):
metadata = gisaf
hidden: ClassVar[bool] = True
class Admin:
menu = 'Other'
flask_admin_model_view = 'TagModelView'
id: int | None = Field(primary_key=True)
store: str = Field(index=True)
ref_id: int = Field(index=True, sa_type=BigInteger)
tags: dict = Field(sa_type=MutableDict.as_mutable(HSTORE))
def __str__(self):
return '{self.store:s} {self.ref_id}: {self.tags}'.format(self=self)
def __repr__(self):
return '<models.Tag {self.store:s} {self.ref_id}: {self.tags}>'.format(self=self)
class TagKey(SQLModel, table=True):
metadata = gisaf
## CREATE TABLE gisaf.tagkey (key VARCHAR(255) primary key);
class Admin:
menu = 'Other'
flask_admin_model_view = 'TagKeyModelView'
key: str | None = Field(primary_key=True)
def __str__(self):
return self.key
def __repr__(self):
return '<models.TagKey {self.key}>'.format(self=self)

664
src/gisaf/registry.py Normal file

@@ -0,0 +1,664 @@
"""
Define the models for the ORM
"""
import logging
import importlib
import pkgutil
from collections import defaultdict
from importlib.metadata import entry_points
from pydantic import create_model
from sqlalchemy import inspect, text
from sqlalchemy.orm import selectinload
from sqlmodel import select
import numpy as np
import pandas as pd
from .config import conf
from .models import (misc, category as category_module,
project, reconcile, map_bases, tags)
from .models.geo_models_base import (
PlottableModel,
GeoModel,
RawSurveyBaseModel,
LineWorkSurveyModel,
GeoPointSurveyModel,
GeoLineSurveyModel,
GeoPolygonSurveyModel,
)
from .utils import ToMigrate
from .models.category import Category, CategoryGroup
from .database import db_session
from .models.metadata import survey, raw_survey
logger = logging.getLogger(__name__)
category_model_mapper = {
'Point': GeoPointSurveyModel,
'Line': GeoLineSurveyModel,
'Polygon': GeoPolygonSurveyModel,
}
class NotInRegistry(Exception):
pass
def import_submodules(package, recursive=True):
""" Import all submodules of a module, recursively, including subpackages
:param package: package (name or actual module)
:type package: str | module
:param recursive: scan package recursively
:rtype: dict[str, types.ModuleType]
"""
if isinstance(package, str):
package = importlib.import_module(package)
results = {}
for loader, name, is_pkg in pkgutil.walk_packages(package.__path__):
full_name = package.__name__ + '.' + name
results[full_name] = importlib.import_module(full_name)
if recursive and is_pkg:
results.update(import_submodules(full_name))
return results
class ModelRegistry:
"""
Collect, categorize, and initialize the SQLAlchemy data models.
Maintains registries for all kinds of model types, eg. geom, data, values...
Provides tools to get the models from their names, table names, etc.
"""
def __init__(self):
"""
Get geo models
:return: None
"""
self.geom_custom = {}
self.geom_custom_store = {}
self.values = {}
self.other = {}
self.misc = {}
self.raw_survey_models = {}
self.survey_models = {}
async def make_registry(self, app=None):
"""
Make (or refresh) the registry of models.
:return:
"""
logger.debug('make_registry')
await self.make_category_models()
self.scan()
await self.build()
## If ogcapi is in app (i.e. not with scheduler):
## Now that the models are refreshed, tell the ogcapi to (re)build
if app:
#app.extra['registry'] = self
if 'ogcapi' in app.extra:
await app.extra['ogcapi'].build()
async def make_category_models(self):
"""
Make geom models from the category model
and update raw_survey_models and survey_models
Important notes:
- the db must be bound before running this function
- the db must be rebound after running this function,
so that the models created are actually bound to the db connection
:return:
"""
logger.debug('make_category_models')
async with db_session() as session:
query = select(Category).order_by(Category.long_name).options(selectinload(Category.category_group))
data = await session.exec(query)
categories: list[Category] = data.all()
for category in categories:
## Several statuses can coexist for the same model, so
## consider only the ones with the 'E' (existing) status
## The other statuses are defined only for import (?)
if getattr(category, 'status', 'E') != 'E':
continue
## Use pydantic create_model, supported by SQLModel
## See https://github.com/tiangolo/sqlmodel/issues/377
store_name = f'{survey.schema}.{category.table_name}'
raw_store_name = f'{raw_survey.schema}.RAW_{category.table_name}'
raw_survey_field_definitions = {
## FIXME: RawSurveyBaseModel.category should be a Category, not category.name
'category_name': (str, category.name),
## FIXME: Same for RawSurveyBaseModel.group
'group_name': (str, category.category_group.name),
'viewable_role': (str, category.viewable_role),
'store_name': (str, raw_store_name),
# 'icon': (str, ''),
}
## Raw survey points
try:
self.raw_survey_models[store_name] = create_model(
__base__=RawSurveyBaseModel,
__model_name=category.raw_survey_table_name,
__cls_kwargs__={
'table': True,
'metadata': raw_survey,
'__tablename__': category.raw_survey_table_name,
## FIXME: RawSurveyBaseModel.category should be a Category, not category.name
'category_name': category.name,
## FIXME: Same for RawSurveyBaseModel.group
'group_name': category.category_group.name,
'viewable_role': category.viewable_role,
'store_name': raw_store_name,
},
# **raw_survey_field_definitions
)
except Exception as err:
logger.exception(err)
else:
logger.debug('Discovered {:s}'.format(category.raw_survey_table_name))
model_class = category_model_mapper.get(category.model_type)
## Final geometries
try:
if model_class:
survey_field_definitions = {
'category_name': (str, category.name),
'group_name': (str, category.category_group.name),
'raw_store_name': (str, raw_store_name),
'viewable_role': (str, category.viewable_role),
'symbol': (str, category.symbol),
#'raw_model': (str, self.raw_survey_models.get(raw_store_name)),
# 'icon': (str, f'{survey.schema}-{category.table_name}'),
}
self.survey_models[store_name] = create_model(
__base__= model_class,
__model_name=category.table_name,
__cls_kwargs__={
'table': True,
'metadata': survey,
'__tablename__': category.table_name,
'category_name': category.name,
'group_name': category.category_group.name,
'raw_store_name': raw_store_name,
'viewable_role': category.viewable_role,
'symbol': category.symbol,
},
# **survey_field_definitions,
)
except Exception as err:
logger.warning(err)
else:
logger.debug('Discovered {:s}'.format(category.table_name))
logger.info('Discovered {:d} models'.format(len(categories)))
def scan(self):
"""
Scan all models defined explicitly (not the survey ones,
which are defined by categories), and store them for reference.
"""
logger.debug('scan')
from . import models # nocheck
## Scan the models defined in modules
for module_name, module in import_submodules(models).items():
if module_name in (
'gisaf.models.geo_models_base',
'gisaf.models.models_base',
):
continue
for name in dir(module):
obj = getattr(module, name)
if hasattr(obj, '__module__') and obj.__module__.startswith(module.__name__)\
and hasattr(obj, '__tablename__') and hasattr(obj, 'get_store_name'):
model_type = self.add_model(obj)
logger.debug(f'Model {obj.get_store_name()} added in the registry from gisaf source tree as {model_type}')
## Scan the models defined in plugins (setuptools' entry points)
for module_name, model in self.scan_entry_points(name='gisaf_extras.models').items():
model_type = self.add_model(model)
logger.debug(f'Model {model.get_store_name()} added in the registry from {module_name} entry point as {model_type}')
for module_name, store in self.scan_entry_points(name='gisaf_extras.stores').items():
self.add_store(store)
logger.debug(f'Store {store} added in the registry from {module_name} gisaf_extras.stores entry point')
## Add misc models
for module in misc, category_module, project, reconcile, map_bases, tags:
for name in dir(module):
obj = getattr(module, name)
if hasattr(obj, '__module__') and hasattr(obj, '__tablename__'):
self.misc[name] = obj
async def build(self):
"""
Build the registry: organize all models in a common reference point.
This should be executed after the discovery of survey models (categories)
and the scan of custom/module defined models.
"""
logger.debug('build')
## Combine all geom models (auto and custom)
self.geom = {**self.survey_models, **self.geom_custom}
await self.make_stores()
## Some lists of table, by usage
## XXX: Gino: doesn't set __tablename__ and __table__ , or engine not started???
## So, hack the table names of auto_geom
#self.geom_tables = [model.__tablename__
#self.geom_tables = [getattr(model, "__tablename__", None)
# for model in sorted(list(self.geom.values()),
# key=lambda a: a.z_index)]
values_tables = [model.__tablename__ for model in self.values.values()]
other_tables = [model.__tablename__ for model in self.other.values()]
self.data_tables = values_tables + other_tables
## Build a dict for quick access to the values from a model
logger.warning(ToMigrate('get_geom_model_from_table_name, only used for values_for_model'))
self.values_for_model = {}
for model_value in self.values.values():
for constraint in inspect(model_value).foreign_key_constraints:
model = self.get_geom_model_from_table_name(constraint.referred_table.name)
self.values_for_model[model] = model_value
self.make_menu()
def scan_entry_points(self, name):
"""
Get the entry points in gisaf_extras.models, and return their models
:return: dict of name: models
"""
named_objects = {}
for entry_point in entry_points().select(group=name):
try:
named_objects.update({entry_point.name: entry_point.load()})
except ModuleNotFoundError as err:
logger.warning(err)
return named_objects
def add_model(self, model):
"""
Add the model
:return: Model type (one of {'GeoModel', 'PlottableModel', 'Other model'})
"""
# if not hasattr(model, 'get_store_name'):
# raise NotInRegistry()
table_name = model.get_store_name()
if issubclass(model, GeoModel) and not issubclass(model, RawSurveyBaseModel) and not model.hidden:
self.geom_custom[table_name] = model
return 'GeoModel'
elif issubclass(model, PlottableModel):
self.values[table_name] = model
return 'PlottableModel'
else:
self.other[table_name] = model
return 'Other model'
def add_store(self, store):
self.geom_custom_store[store.name] = store
def make_menu(self):
"""
Build the Admin menu
:return:
"""
self.menu = defaultdict(list)
for name, model in self.stores.model.items():
if hasattr(model, 'Admin'):
self.menu[model.Admin.menu].append(model)
# def get_raw_survey_model_mapping(self):
# """
# Get a mapping of category_name -> model for categories
# :return: dict of name -> model (class)
# """
# ## TODO: add option to pass a single item
# ## Local imports, avoiding cyclic dependencies
# ## FIXME: Gino
# categories = db.session.query(Category)
# return {category.name: self.raw_survey_models[category.table_name]
# for category in categories
# if self.raw_survey_models.get(category.table_name)}
async def get_model_id_params(self, model, id):
"""
Return the parameters for this item (table name, id), displayed in info pane
"""
if not model:
return {}
item = await model.load(**model.get_join_with()).query.where(model.id==id).gino.first()
if not item:
return {}
resp = {}
resp['itemName'] = item.caption
resp['geoInfoItems'] = await item.get_geo_info()
resp['surveyInfoItems'] = await item.get_survey_info()
resp['infoItems'] = await item.get_info()
resp['tags'] = await item.get_tags()
if hasattr(item, 'get_categorized_info'):
resp['categorized_info_items'] = await item.get_categorized_info()
if hasattr(item, 'get_graph'):
resp['graph'] = item.get_graph()
if hasattr(item, 'Attachments'):
if hasattr(item.Attachments, 'files'):
resp['files'] = await item.Attachments.files(item)
if hasattr(item.Attachments, 'images'):
resp['images'] = await item.Attachments.images(item)
if hasattr(item, 'get_external_record_url'):
resp['externalRecordUrl'] = item.get_external_record_url()
return resp
def get_geom_model_from_table_name(self, table_name):
"""
Utility func to get a geom model from a table name
:param table_name: str
:return: model or None
"""
for model in self.geom.values():
if model.__tablename__ == table_name:
return model
def get_other_model_from_table_name(self, table_name):
"""
Utility func to get a non-geom model from a table name
:param table_name: str
:return: model or None
"""
for model in self.other.values():
if model.__tablename__ == table_name:
return model
for model in self.values.values():
if model.__tablename__ == table_name:
return model
async def make_stores(self):
"""
Make registry for primary groups, categories and survey stores using Pandas dataframes.
Used in GraphQL queries.
"""
## Utility functions used with apply method (dataframes)
def fill_columns_from_custom_models(row):
return (
## FIXME: Like: 'AVESHTEquipment'
row.model.__qualname__, ## Name of the class
row.model.description,
## FIXME: Like: 'other_aves'
row.model.__table__.schema
)
def fill_columns_from_custom_stores(row):
return (
row.model.description,
row.model.description,
None ## Schema
)
def get_store_name(category):
fragments = ['V', category.group, category.minor_group_1]
if category.minor_group_2 != '----':
fragments.append(category.minor_group_2)
return '.'.join([
survey.schema,
'_'.join(fragments)
])
self.categories = await Category.get_df()
self.categories['title'] = self.categories.long_name.fillna(self.categories.description)
self.categories['store'] = self.categories.apply(get_store_name, axis=1)
self.categories['count'] = pd.Series(dtype=pd.Int64Dtype())
self.categories.set_index('name', inplace=True)
df_models = pd.DataFrame(self.geom.items(),
columns=['store', 'model']
).set_index('store')
df_raw_models = pd.DataFrame(self.raw_survey_models.items(),
columns=('store', 'raw_model')
).set_index('store')
self.categories = self.categories.merge(df_models, left_on='store', right_index=True)
self.categories = self.categories.merge(df_raw_models, left_on='store', right_index=True)
self.categories['custom'] = False
self.categories['is_db'] = True
self.categories.sort_index(inplace=True)
# self.categories['name_letter'] = self.categories.index.str.slice(0, 1)
# self.categories['name_number'] = self.categories.index.str.slice(1).astype('int64')
# self.categories.sort_values(['name_letter', 'name_number'], inplace=True)
## Set in the stores dataframe some useful properties, from the model class
## Maybe at some point it makes sense to get away from class-based definitions
if len(self.categories) > 0:
## XXX: redundant self.categories['store_name'] with self.categories['store']
#self.categories['store_name'] = self.categories.apply(
# lambda row: row.model.get_store_name(),
# axis=1
#)
#self.categories['raw_model_store_name'] = self.categories.apply(
# lambda row: row.raw_model.store_name,
# axis=1
#)
self.categories['is_line_work'] = self.categories.apply(
lambda row: issubclass(row.model, LineWorkSurveyModel),
axis=1
)
else:
self.categories['store_name'] = None
self.categories['raw_model_store_name'] = None
self.categories['is_line_work'] = None
self.categories['raw_survey_model'] = None
## Custom models (Misc)
self.custom_models = pd.DataFrame(
self.geom_custom.items(),
columns=['store', 'model']
).set_index('store')
self.custom_models['group'] = 'Misc'
self.custom_models['custom'] = True
self.custom_models['is_db'] = True
self.custom_models['raw_model_store_name'] = ''
self.custom_models['in_menu'] = self.custom_models.apply(
lambda row: getattr(row.model, 'in_menu', True),
axis=1
)
self.custom_models = self.custom_models.loc[self.custom_models.in_menu]
self.custom_models['auto_import'] = False
self.custom_models['is_line_work'] = False
if len(self.custom_models) > 0:
self.custom_models['long_name'],\
self.custom_models['custom_description'],\
self.custom_models['db_schema'],\
= zip(*self.custom_models.apply(fill_columns_from_custom_models, axis=1))
## Try to give a meaningful description, eg. including the source (db_schema)
self.custom_models['description'] = self.custom_models['custom_description'].fillna(self.custom_models['long_name'] + '-' + self.custom_models['db_schema'])
self.custom_models['title'] = self.custom_models['long_name']
## Custom stores (Community)
self.custom_stores = pd.DataFrame(
self.geom_custom_store.items(),
columns=['store', 'model']
).set_index('store')
self.custom_stores['group'] = 'Community'
self.custom_stores['custom'] = True
self.custom_stores['is_db'] = False
if len(self.custom_stores) == 0:
self.custom_stores['in_menu'] = False
else:
self.custom_stores['in_menu'] = self.custom_stores.apply(
lambda row: getattr(row.model, 'in_menu', True),
axis=1
)
self.custom_stores = self.custom_stores.loc[self.custom_stores.in_menu]
self.custom_stores['auto_import'] = False
self.custom_stores['is_line_work'] = False
if len(self.custom_stores) > 0:
self.custom_stores['long_name'],\
self.custom_stores['description'],\
self.custom_stores['db_schema'],\
= zip(*self.custom_stores.apply(fill_columns_from_custom_stores, axis=1))
self.custom_stores['title'] = self.custom_stores['long_name']
## Combine Misc (custom) and survey (auto) stores
## Retain only one status per category (defaultStatus, 'E'/existing by default)
self.stores = pd.concat([
self.categories[self.categories.status==conf.map.defaultStatus[0]].reset_index().set_index('store').sort_values('title'),
self.custom_models,
self.custom_stores
])#.drop(columns=['store_name'])
self.stores['in_menu'] = self.stores['in_menu'].astype(bool)
## Set in the stores dataframe some useful properties, from the model class
## Maybe at some point it makes sense to get away from class-based definitions
def fill_columns_from_model(row):
return (
# row.model.icon,
# row.model.symbol,
row.model.mapbox_type, # or None,
row.model.base_gis_type,
row.model.z_index,
)
# self.stores['icon'],\
# self.stores['symbol'],\
self.stores['mapbox_type_default'],\
self.stores['base_gis_type'],\
self.stores['z_index']\
= zip(*self.stores.apply(fill_columns_from_model, axis=1))
#self.stores['mapbox_type_custom'] = self.stores['mapbox_type_custom'].replace('', np.nan).fillna(np.nan)
self.stores['mapbox_type'] = self.stores['mapbox_type_custom'].fillna(
self.stores['mapbox_type_default']
)
self.stores['viewable_role'] = self.stores.apply(
lambda row: getattr(row.model, 'viewable_role', None),
axis=1,
)
self.stores['viewable_role'].replace('', None, inplace=True)
#self.stores['gql_object_type'] = self.stores.apply(make_model_gql_object_type, axis=1)
self.stores['is_live'] = False
self.stores['description'].fillna('', inplace=True)
## Layer groups: Misc, survey's primary groups, Live
self.primary_groups = await CategoryGroup.get_df()
self.primary_groups.sort_values('name', inplace=True)
self.primary_groups['title'] = self.primary_groups['long_name']
## Add Misc and Live
self.primary_groups.loc[-1] = (
'Misc',
False,
'Misc and old layers (not coming from our survey; they will be organized, '
'eventually as the surveys get more complete)',
'Misc',
)
self.primary_groups.index = self.primary_groups.index + 1
self.primary_groups.loc[len(self.primary_groups)] = (
'Live',
False,
'Layers from data processing, sensors, etc.; updated automatically',
'Live',
)
self.primary_groups.loc[len(self.primary_groups)] = (
'Community',
False,
'Layers from community',
'Community',
)
self.primary_groups.sort_index(inplace=True)
#def make_group(group):
# return GeomGroup(
# name=group['name'],
# title=group['title'],
# description=group['long_name']
# )
#self.primary_groups['gql_object_type'] = self.primary_groups.apply(make_group, axis=1)
await self.update_stores_counts()
async def get_stores(self):
"""
Get information about the available stores
"""
raise DeprecationWarning('get_stores was for graphql')
async def update_stores_counts(self):
"""
Update the counts of the stores from the DB
"""
query = "SELECT schemaname, relname, n_live_tup FROM pg_stat_user_tables"
# async with db.acquire(reuse=False) as connection:
async with db_session() as session:
rows = await session.exec(text(query))
all_tables_count = pd.DataFrame(rows, columns=['schema', 'table', 'count'])
all_tables_count['store'] = all_tables_count['schema'] + '.' + all_tables_count['table']
all_tables_count.set_index(['store'], inplace=True)
## TODO: a DB VACUUM can be triggered if all counts are 0?
## Update the count in registry's stores
self.stores.loc[:, 'count'] = all_tables_count['count']
# ## FIXME: count for custom stores
# store_df = self.stores.loc[(self.stores['count'] != 0) | (self.stores['is_live'])]
# def set_count(row):
# row.gql_object_type.count = row['count']
# store_df[store_df.is_db].apply(set_count, axis=1)
# return store_df.gql_object_type.to_list()
#def update_live_layers(self, live_models: List[GeomModel]):
#raise ToMigrate('make_model_gql_object_type')
def update_live_layers(self, live_models):
"""
Update the live layers in the registry, using the provided list of GeomModel instances
"""
## Remove existing live layers
self.stores.drop(self.stores[self.stores.is_live==True].index, inplace=True)
## Add provided live layers
## Ideally, should be vectorized
for model in live_models:
self.stores.loc[model.store] = {
'description': model.description,
'group': model.group,
'name': model.name,
'gql_object_type': model,
'is_live': True,
'is_db': False,
'custom': True,
}
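## Usage sketch (hypothetical live model; any object exposing the
## attributes read above would do):
# class SensorLayer:
#     store = 'live.sensors'
#     name = 'Sensors'
#     description = 'Live sensor readings'
#     group = 'Live'
# registry.update_live_layers([SensorLayer()])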
# Accessible as global
registry: ModelRegistry = ModelRegistry()
## Below, some unused code, maybe to be used later for displaying layers in a tree structure
## Some magic for making a tree from enumerables,
## https://gist.github.com/hrldcpr/2012250
#Tree = lambda: defaultdict(Tree)
#
#
#def add(t, path):
# for node in path:
# t = t[node]
#
#
#dicts = lambda t: {k: dicts(t[k]) for k in t}
#
#
#def get_geom_models_tree():
# tree = Tree()
# for model in models.geom_custom:
# full_name = model.__module__[len('gisaf.models')+1:]
# add(tree, full_name.split('.'))
# add(tree, full_name.split('.') + [model])
# return dicts(tree)

154
src/gisaf/security.py Normal file
View file

@ -0,0 +1,154 @@
from datetime import datetime, timedelta
import logging
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from passlib.context import CryptContext
from passlib.exc import UnknownHashError
from pydantic import BaseModel
from sqlmodel.ext.asyncio.session import AsyncSession
from jose import JWTError, jwt, ExpiredSignatureError
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from .config import conf
from .database import db_session
from .models.authentication import User, UserRead
logger = logging.getLogger(__name__)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class Token(BaseModel):
access_token: str
token_type: str
class TokenData(BaseModel):
username: str | None = None
# class User(BaseModel):
# username: str
# email: str | None = None
# full_name: str | None = None
# disabled: bool | None = None
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
def get_password_hash(password: str):
return pwd_context.hash(password)
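## Round-trip sketch (illustrative only):
# hashed = get_password_hash('s3cret')
# assert pwd_context.verify('s3cret', hashed)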
async def delete_user(session: AsyncSession, username: str) -> None:
user_in_db: User | None = await get_user(session, username)
if user_in_db is None:
raise SystemExit(f'User {username} does not exist in the database')
await session.delete(user_in_db)
await session.commit()
async def enable_user(session: AsyncSession, username: str, enable=True):
user_in_db: User | None = await get_user(session, username)
if user_in_db is None:
raise SystemExit(f'User {username} does not exist in the database')
user_in_db.disabled = not enable # type: ignore
session.add(user_in_db)
await session.commit()
async def create_user(session: AsyncSession, username: str, password: str, full_name: str,
email: str, **kwargs):
user_in_db: User | None = await get_user(session, username)
if user_in_db is None:
user = User(
username=username,
password=get_password_hash(password),
full_name=full_name,
email=email,
disabled=False
)
session.add(user)
else:
user_in_db.full_name = full_name # type: ignore
user_in_db.email = email # type: ignore
user_in_db.password = get_password_hash(password) # type: ignore
await session.commit()
async def get_user(
session: AsyncSession,
username: str) -> (User | None):
query = select(User).where(User.username==username).options(selectinload(User.roles))
data = await session.exec(query)
return data.scalar()
def verify_password(user: User, plain_password):
try:
return pwd_context.verify(plain_password, user.password)
except UnknownHashError:
logger.warning(f'Password not hashed in DB for {user.username}, assuming it is stored in plain text')
return plain_password == user.password
async def get_current_user(
token: str = Depends(oauth2_scheme)) -> UserRead | None:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
expired_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token has expired",
headers={"WWW-Authenticate": "Bearer"},
)
if token is None:
return None
try:
payload = jwt.decode(token, conf.crypto.secret,
algorithms=[conf.crypto.algorithm])
username: str = payload.get("sub", '')
if username == '':
raise credentials_exception
token_data = TokenData(username=username)
except ExpiredSignatureError:
raise expired_exception
except JWTError:
raise credentials_exception
async with db_session() as session:
user = await get_user(session, username=token_data.username)
if user is None:
raise credentials_exception
return user
async def authenticate_user(username: str, password: str):
async with db_session() as session:
user = await get_user(session, username)
if not user:
return False
if not verify_password(user, password):
return False
return user
# async def get_current_active_user(
# current_user: Annotated[UserRead, Depends(get_current_user)]):
# if current_user.disabled:
# raise HTTPException(status_code=400, detail="Inactive user")
# return current_user
def create_access_token(data: dict, expires_delta: timedelta):
to_encode = data.copy()
expire = datetime.utcnow() + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode,
conf.crypto.secret,
algorithm=conf.crypto.algorithm)
return encoded_jwt
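## Token round-trip sketch (assumes conf.crypto.secret and
## conf.crypto.algorithm are configured):
# token = create_access_token({'sub': 'alice'}, timedelta(minutes=30))
# payload = jwt.decode(token, conf.crypto.secret, algorithms=[conf.crypto.algorithm])
# assert payload['sub'] == 'alice'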

400
src/gisaf/utils.py Normal file
View file

@ -0,0 +1,400 @@
import logging
import asyncio
from functools import wraps
from json import dumps, JSONEncoder
from math import isnan
from time import time
import datetime
import pyproj
from numpy import ndarray
import pandas as pd
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.sql.expression import delete
# from graphene import ObjectType
from .config import conf
class ToMigrate(Exception):
pass
SHAPELY_TYPE_TO_MAPBOX_TYPE = {
'Point': 'symbol',
'LineString': 'line',
'Polygon': 'fill',
'MultiPolygon': 'fill',
}
DEFAULT_MAPBOX_LAYOUT = {
'symbol': {
'text-line-height': 1,
'text-padding': 0,
'text-allow-overlap': True,
'text-field': '\ue32b',
'icon-optional': True,
'text-font': ['GisafSymbols'],
'text-size': 24,
}
}
DEFAULT_MAPBOX_PAINT = {
'symbol': {
'text-translate-anchor': 'viewport',
'text-color': '#000000',
},
'line': {
'line-color': 'red',
'line-opacity': 0.70,
'line-width': 2,
'line-blur': 0.5,
},
'fill': {
'fill-color': 'blue',
'fill-opacity': 0.50,
}
}
MAPBOX_COLOR_ATTRIBUTE_NAME = {
'symbol': 'text-color',
'line': 'line-color',
'fill': 'fill-color',
}
MAPBOX_OPACITY_ATTRIBUTE_NAME = {
'symbol': 'text-opacity',
'line': 'line-opacity',
'fill': 'fill-opacity',
}
gisTypeSymbolMap = {
'Point': '\ue32b',
'Line': '\ue32c',
'Polygon': '\ue32d',
'MultiPolygon': '\ue32d',
}
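## Combination sketch: deriving a layer's default Mapbox style from its
## Shapely geometry type (hypothetical helper, not part of this module):
# def default_style(shapely_type: str) -> tuple[dict, dict]:
#     mapbox_type = SHAPELY_TYPE_TO_MAPBOX_TYPE[shapely_type]  # eg. 'Point' -> 'symbol'
#     return (DEFAULT_MAPBOX_LAYOUT.get(mapbox_type, {}),
#             DEFAULT_MAPBOX_PAINT.get(mapbox_type, {}))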
# survey_to_db_project_func = pyproj.Transformer.from_crs(
# conf.geo.raw_survey.spatial_sys_ref,
# conf.geo.srid,
# always_xy=True
# ).transform
class NumpyEncoder(JSONEncoder):
"""
Encoder that can serialize numpy arrays and datetime objects
"""
def default(self, obj):
if isinstance(obj, datetime.datetime):
return obj.isoformat()
if isinstance(obj, datetime.date):
return obj.isoformat()
if isinstance(obj, datetime.timedelta):
return (datetime.datetime.min + obj).time().isoformat()
if isinstance(obj, ndarray):
#return obj.tolist()
## TODO: convert nat to None
return [None if isinstance(rr, float) and isnan(rr) else rr for rr in obj]
if isinstance(obj, float) and isnan(obj):
return None
if isinstance(obj, bytes):
return obj.decode()
return JSONEncoder.default(self, obj)
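## Usage sketch:
# import numpy as np
# dumps({'when': datetime.date(2023, 12, 13), 'xs': np.array([1.0, float('nan')])},
#       cls=NumpyEncoder)
## -> '{"when": "2023-12-13", "xs": [1.0, null]}'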
# class GraphQlObjectTypeEncoder(JSONEncoder):
# """
# Encoder that can serialize basic Graphene ObjectTypes
# """
# def default(self, obj):
# if isinstance(obj, datetime.datetime):
# return obj.isoformat()
# if isinstance(obj, datetime.date):
# return obj.isoformat()
# if isinstance(obj, ObjectType):
# return obj.__dict__
# def json_response(data, body=None, status=200,
# reason=None, headers=None, content_type='application/json', check_circular=True,
# **kwargs):
# text = dumps(data, cls=NumpyEncoder, separators=(',', ':'), check_circular=check_circular)
# return web.Response(text=text, body=body, status=status, reason=reason,
# headers=headers, content_type=content_type, **kwargs)
def get_join_with(cls, recursive=True):
"""
Helper function for loading related tables with a Gino loader (left outer join)
Should work recursively...
Eg:
cls.load(**get_join_with(cls)).query.gino.all()
:param cls:
:return:
"""
if hasattr(cls, 'dyn_join_with'):
joins = cls.dyn_join_with()
else:
joins = {}
if hasattr(cls, '_join_with'):
joins.update(cls._join_with)
if not recursive:
return joins
recursive_joins = {}
for name, join in joins.items():
more_joins = get_join_with(join)
if more_joins:
aliased = {name: join.alias() for name, join in more_joins.items()}
recursive_joins[name] = join.load(**aliased)
else:
recursive_joins[name] = join
return recursive_joins
def get_joined_query(cls):
"""
Helper function to get a query from a model with all the related tables loaded
:param cls:
:return:
"""
return cls.load(**get_join_with(cls)).query
def timeit(f):
"""
Decorator for timing *non async* methods (development tool for performance analysis)
"""
@wraps(f)
def wrap(*args, **kw):
ts = time()
result = f(*args, **kw)
te = time()
logging.debug('func:{} args:{}, {} took: {:2.4f} sec'.format(f.__name__, args, kw, te-ts))
return result
return wrap
def atimeit(func):
"""
Decorator for timing *async* methods (development tool for performance analysis)
"""
async def process(func, *args, **params):
if asyncio.iscoroutinefunction(func):
#logging.debug('this function is a coroutine: {}'.format(func.__name__))
return await func(*args, **params)
else:
#logging.debug('this is not a coroutine')
return func(*args, **params)
async def helper(*args, **params):
#logging.debug('{}.time'.format(func.__name__))
start = time()
result = await process(func, *args, **params)
# Test normal function route...
# result = await process(lambda *a, **p: print(*a, **p), *args, **params)
logging.debug("{} {}".format(func.__name__, time() - start))
return result
return helper
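## Usage sketch for both decorators (timings are logged at DEBUG level):
# @timeit
# def crunch(n):
#     return sum(range(n))
#
# @atimeit
# async def pause():
#     await asyncio.sleep(0.1)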
async def delete_df(df, model):
"""
Delete all data in the model's table in the database
that matches data in the pandas dataframe.
"""
table = model.__table__
ids = df.reset_index()['id'].values
delete_stmt = delete(table).where(model.id.in_(ids))
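## Note: `db` below (also in upsert_df) is the legacy Gino binding; it is
## not imported in this module, pending migration to the new async engine.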
async with db.bind.raw_pool.acquire() as conn:
async with conn.transaction():
await conn.execute(str(delete_stmt), *ids)
async def upsert_df(df, model):
"""
Insert or update all data in the model's table in the database
that's present in the pandas dataframe.
Use postgres insert ... on conflict update...
with a series of inserts, one row at a time.
For GeoDataFrame: the "geometry" column (df._geometry_column_name) is not honored
(yet). It's the caller's responsibility to have a proper column name
(typically "geom" in Gisaf models) with a EWKT or EWKB representation of the geometry.
"""
## See: https://stackoverflow.com/questions/33307250/postgresql-on-conflict-in-sqlalchemy
if len(df) == 0:
return df
table = model.__table__
## Generate the 'upsert' statement, using fake values but defining columns
columns = {c.name for c in table.columns}
values = {col: None for col in df.columns if col in columns}
insrt_stmnt = insert(table, inline=True, values=values, returning=table.primary_key.columns)
df_columns = set(df.columns)
do_update_stmt = insrt_stmnt.on_conflict_do_update(
constraint=table.primary_key,
set_={
k.name: getattr(insrt_stmnt.excluded, k.name)
for k in insrt_stmnt.excluded
if k.name in df_columns and
k.name not in [c.name for c in table.primary_key.columns]
}
)
## Filter and reorder the df columns
## in order to match the order of columns in the insert statement
df = df[[col for col in do_update_stmt.compile().positiontup
if col in df_columns]].copy()
def convert_to_object(value):
"""
Quick (but slow) and dirty: clean up values (nan, nat) for inserting to postgres via asyncpg
"""
if isinstance(value, float) and isnan(value):
return None
elif pd.isna(value):
return None
else:
return value
# def encode_geometry(geometry):
# if not hasattr(geometry, '__geo_interface__'):
# raise TypeError('{g} does not conform to '
# 'the geo interface'.format(g=geometry))
# shape = shapely.geometry.asShape(geometry)
# return shapely.wkb.dumps(shape)
# def decode_geometry(wkb):
# return shapely.wkb.loads(wkb)
## pks: list of dicts of primary keys
pks = {pk.name: [] for pk in table.primary_key.columns}
async with db.bind.raw_pool.acquire() as conn:
## Set standard encoder for HSTORE, geometry
await conn.set_builtin_type_codec('hstore', codec_name='pg_contrib.hstore')
#await conn.set_type_codec(
# 'geometry', # also works for 'geography'
# encoder=encode_geometry,
# decoder=decode_geometry,
# format='binary',
#)
#await conn.set_type_codec(
# 'json',
# encoder=json.dumps,
# decoder=json.loads,
# schema='pg_catalog'
#)
## For a sequence of inserts:
insrt_stmnt_single = await conn.prepare(str(do_update_stmt))
async with conn.transaction():
for row in df.itertuples(index=False):
converted_row = [convert_to_object(v) for v in row]
returned = await insrt_stmnt_single.fetch(*converted_row)
for returned_single in returned:
for pk, value in returned_single.items():
pks[pk].append(value)
## Return a copy of the original df, with actual DB columns, data and the primary keys
for pk, values in pks.items():
df[pk] = values
return df
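## Usage sketch (within an async context, with a hypothetical SQLModel
## model whose primary key is 'id'):
# df = pd.DataFrame([{'id': 1, 'name': 'A'}, {'id': 2, 'name': 'B'}])
# df_with_pks = await upsert_df(df, SomeModel)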
#async def upsert_df(df, model):
# """
# Experiment with port of pandas.io.sql port for asyncpg: sql_async
# """
# from .sql_async import SQLDatabase, SQLTable
#
# table = model.__table__
#
# async with db.bind.raw_pool.acquire() as conn:
# sql_db = SQLDatabase(conn)
# result = await sql_db.to_sql(df, table.name, if_exists='replace', index=False)
# return f'{len(df)} records imported (create or update)'
#async def upsert_df_bulk(df, model):
# """
# Insert or update all data in the pandas dataframe to the model's table in the database.
# Use postgres insert ... on conflict update...
# in a bulk insert with all data in one request.
# """
# raise NotImplementedError('Needs fix, use upsert_df instead')
# ## See: https://stackoverflow.com/questions/33307250/postgresql-on-conflict-in-sqlalchemy
# insrt_vals = df.to_dict(orient='records')
#
# insrt_stmnt = insert(model.__table__).values(insrt_vals)
# do_update_stmt = insrt_stmnt.on_conflict_do_update(
# constraint=model.__table__.primary_key,
# set_={
# k.name: getattr(insrt_stmnt.excluded, k.name)
# for k in insrt_stmnt.excluded
# if k.name not in [c.name for c in model.__table__.primary_key.columns]
# }
# )
# async with db.bind.raw_pool.acquire() as conn:
# ## For a sequence of inserts:
# insrt_stmnt_single = await conn.prepare(str(insert(model.__table__)))
# async with conn.transaction():
# ## TODO: flatten the insrt_vals so that they match the request's $n placeholders
# await conn.execute(do_update_stmt, insrt_vals)
#def encode_geometry(geometry):
# if not hasattr(geometry, '__geo_interface__'):
# raise TypeError('{g} does not conform to '
# 'the geo interface'.format(g=geometry))
# shape = shapely.geometry.asShape(geometry)
# geos.lgeos.GEOSSetSRID(shape._geom, conf.raw_survey['srid'])
# return shapely.wkb.dumps(shape, include_srid=True)
#def decode_geometry(wkb):
# return shapely.wkb.loads(wkb)
## XXX: dev notes
## What's the best way to save a dataframe to the DB?
## 1/ df.to_sql might have been an easy solution, but it doesn't support async operations
#
## 2/ Experiment with COPY (copy_records_to_table, see below): it doesn't update records.
#async with db.bind.raw_pool.acquire() as conn:
# await conn.set_type_codec(
# 'geometry', # also works for 'geography'
# encoder=encode_geometry,
# decoder=decode_geometry,
# format='binary',
# )
# async with conn.transaction():
# ## See https://github.com/MagicStack/asyncpg/issues/245
# s = await conn.copy_records_to_table(
# model.__table__.name,
# schema_name=model.__table__.schema,
# records=[tuple(x) for x in gdf_for_db.values],
# columns=list(gdf_for_db.columns),
# timeout=10
# )
#
## 3/ SqlAlchemy/Asyncpg multiple inserts, then updates
### Build SQL statements
#insert = db.insert(model.__table__).compile()
#update = db.update(model.__table__).compile()
### Reorder the columns of the dataframe
#gdf_for_db = gdf_for_db[insert.positiontup]
### Find the records whose id already present in the DB, and segregate the df
#existing_records = await model.get_df(with_only_columns=['id'])
#gdf_insert = gdf_for_db[~gdf_for_db.id.isin(existing_records.id)]
#gdf_update = gdf_for_db[gdf_for_db.id.isin(existing_records.id)]
#async with db.bind.raw_pool.acquire() as conn:
# await conn.executemany(insert.string, [tuple(x) for x in gdf_insert.values])
# await conn.executemany(update.string, [tuple(x) for x in gdf_update.values])
##
## 4/ Fall back to gino. Bad luck, there's no equivalent to "merge", so the strategy is:
## - get all records ids in DB
## - build the set of records that needs update, and other that needs insert
## - do these operations (possibly in bulk)
#
## 5/ Make a utility lib for other use cases...