WIP: attempt to fix issue with get_df (column names)

Cleanup & cosmetic
This commit is contained in:
phil 2024-03-03 23:50:48 +05:30
parent bcfda603be
commit d2c2e6cc69
4 changed files with 52 additions and 22 deletions

View file

@ -8,8 +8,7 @@ from typing import Annotated
from asyncio import CancelledError from asyncio import CancelledError
from fastapi import (Depends, APIRouter, HTTPException, Response, Header, from fastapi import (Depends, APIRouter, HTTPException, Response, Header,
WebSocket, WebSocketDisconnect, WebSocket, WebSocketDisconnect, status)
status, responses)
from gisaf.models.authentication import User from gisaf.models.authentication import User
from gisaf.redis_tools import store as redis_store from gisaf.redis_tools import store as redis_store

View file

@ -92,7 +92,10 @@ async def get_base_style(request: Request, name: str,
async def get_layer_style(request: Request, store: str, async def get_layer_style(request: Request, store: str,
response: Response, response: Response,
) -> MaplibreStyle | None: ) -> MaplibreStyle | None:
store_record = registry.stores.loc[store] try:
store_record = registry.stores.loc[store]
except KeyError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if store_record.is_live: if store_record.is_live:
## No ttag for live layers' style (could be added?) ## No ttag for live layers' style (could be added?)
## Get layer_defs from live redis and give symbol ## Get layer_defs from live redis and give symbol

View file

@ -5,12 +5,13 @@ from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import joinedload, QueryableAttribute from sqlalchemy.orm import joinedload, QueryableAttribute
from sqlalchemy.sql.selectable import Select from sqlalchemy.sql.selectable import Select
from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
from sqlmodel import SQLModel, select from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi import Depends from fastapi import Depends
# from geoalchemy2.functions import ST_SimplifyPreserveTopology # from geoalchemy2.functions import ST_SimplifyPreserveTopology
import pandas as pd import pandas as pd
import geopandas as gpd import geopandas as gpd # type: ignore
from gisaf.config import conf from gisaf.config import conf
@ -58,9 +59,7 @@ class BaseModel(SQLModel):
@classmethod @classmethod
async def get_gdf(cls, **kwargs) -> gpd.GeoDataFrame: async def get_gdf(cls, **kwargs) -> gpd.GeoDataFrame:
return await cls._get_df(geopandas_query, return await cls._get_df(geopandas_query, model=cls, **kwargs) # type: ignore
model=cls,
**kwargs)
@classmethod @classmethod
async def _get_df(cls, method, *, async def _get_df(cls, method, *,
@ -73,14 +72,31 @@ class BaseModel(SQLModel):
query = select(cls) query = select(cls)
else: else:
columns = set(with_only_columns) columns = set(with_only_columns)
columns.add(*(col.name for col in cls.__table__.primary_key.columns)) columns.add(*(col.name for col in cls.__table__.primary_key.columns)) # type: ignore
query = select(*(getattr(cls, col) for col in columns)) query = select(*(getattr(cls, col) for col in columns))
if where is not None: if where is not None:
query = query.where(where) query = query.where(where)
## Get the joined tables ## Get the joined tables
joined_tables = cls.selectinload() if with_related:
joined_tables = cls.selectinload()
# Rename columns of the primary table
col_prefix = f'{cls.__table__.schema}_{cls.__table__.name}_' # type: ignore
col_prefix_length = len(col_prefix)
# for col in query.columns:
# if col.name.startswith(col_prefix):
# breakpoint()
# col.label(col.name[col_prefix_length:])
# breakpoint()
if with_related and len(joined_tables) > 0: if with_related and len(joined_tables) > 0:
query = query.options(*(joinedload(jt) for jt in joined_tables)) query = query.options(*(joinedload(jt) for jt in joined_tables))
## Get the column details of the final query
# ffrom = query.get_final_froms()[0]
# for col in ffrom.columns:
# if col.table == cls.__table__:
# continue
# print(col, col._proxies[0])
# pass
query_with_full_labels = query.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
df = await session.run_sync(method, query, **kwargs) df = await session.run_sync(method, query, **kwargs)
if method is geopandas_query and simplify_tolerance is not None: if method is geopandas_query and simplify_tolerance is not None:
df['geom'] = df['geom'].simplify( df['geom'] = df['geom'].simplify(
@ -89,24 +105,36 @@ class BaseModel(SQLModel):
if preserve_topology is None if preserve_topology is None
else preserve_topology) else preserve_topology)
) )
# col_renames = {}
# col_name: str
# for col_name in df.columns:
# if col_name.startswith(col_prefix):
# col_renames[col_name] = col_name[col_prefix_length:]
# df.rename(columns=col_renames, inplace=True)
## Chamge column names to reflect the joined tables ## Chamge column names to reflect the joined tables
## Leave the first columns unchanged, as their names come straight ## Leave the first columns unchanged, as their names come straight
## from the model's fields ## from the model's fields
joined_columns = list(df.columns[len(cls.model_fields):]) joined_columns = list(df.columns[len(cls.model_fields):])
renames: dict[str, str] = {} renames: dict[str, str] = {}
## Match colum names with the joined tables # ## Match colum names with the joined tables
## Important: this assumes that orders of the joined tables # ## Important: this assumes that orders of the joined tables
## and their columns is preserved by pandas' read_sql # ## and their columns is preserved by pandas' read_sql
for joined_table in joined_tables: # for joined_table in joined_tables:
target = joined_table.property.target # type: ignore # target = joined_table.property.target # type: ignore
target_name = target.name # for col in target.columns:
for col in target.columns: # ## Pop the column from the colujmn list and make a new name
## Pop the column from the colujmn list and make a new name # renames[joined_columns.pop(0)] = f'{target.schema}_{target.name}_{col.name}'
renames[joined_columns.pop(0)] = f'{target.schema}_{target_name}_{col.name}' # df.rename(columns=renames, inplace=True)
df.rename(columns=renames, inplace=True)
## Finally, set the index of the df as the index of cls ## Finally, set the index of the df as the index of cls
df.set_index([c.name for c in cls.__table__.primary_key.columns], # type: ignore df.set_index([c.name for c in cls.__table__.primary_key.columns], # type: ignore
inplace=True) inplace=True)
ffroms = query.get_final_froms()[0]
ffroms1 = query_with_full_labels.get_final_froms()[0]
col = [col for col in ffroms.columns][-2]
col1 = [col for col in ffroms1.columns][-2]
all_cols = [cc for cc in query.froms[0].exported_columns]
breakpoint()
return df return df

View file

@ -12,7 +12,7 @@ from pydantic import create_model
from sqlalchemy import text from sqlalchemy import text
from sqlalchemy.orm import selectinload, joinedload from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy.exc import NoResultFound from sqlalchemy.exc import NoResultFound
from sqlmodel import SQLModel, col, select, inspect, Relationship from sqlmodel import SQLModel, select, inspect, Relationship
import pandas as pd import pandas as pd
import numpy as np import numpy as np
@ -36,7 +36,7 @@ from gisaf.utils import ToMigrate
from gisaf.models.category import Category, CategoryGroup from gisaf.models.category import Category, CategoryGroup
from gisaf.database import db_session from gisaf.database import db_session
from gisaf import models from gisaf import models
from gisaf.models.metadata import gisaf_survey, raw_survey, survey from gisaf.models.metadata import raw_survey, survey
from gisaf.models.to_migrate import FeatureInfo, InfoCategory from gisaf.models.to_migrate import FeatureInfo, InfoCategory
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)