WIP: attempt to fix issue with get_df (column names)

Cleanup & cosmetic
This commit is contained in:
phil 2024-03-03 23:50:48 +05:30
parent bcfda603be
commit d2c2e6cc69
4 changed files with 52 additions and 22 deletions

View file

@ -8,8 +8,7 @@ from typing import Annotated
from asyncio import CancelledError
from fastapi import (Depends, APIRouter, HTTPException, Response, Header,
WebSocket, WebSocketDisconnect,
status, responses)
WebSocket, WebSocketDisconnect, status)
from gisaf.models.authentication import User
from gisaf.redis_tools import store as redis_store

View file

@ -92,7 +92,10 @@ async def get_base_style(request: Request, name: str,
async def get_layer_style(request: Request, store: str,
response: Response,
) -> MaplibreStyle | None:
store_record = registry.stores.loc[store]
try:
store_record = registry.stores.loc[store]
except KeyError:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
if store_record.is_live:
## No ttag for live layers' style (could be added?)
## Get layer_defs from live redis and give symbol

View file

@ -5,12 +5,13 @@ from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import joinedload, QueryableAttribute
from sqlalchemy.sql.selectable import Select
from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi import Depends
# from geoalchemy2.functions import ST_SimplifyPreserveTopology
import pandas as pd
import geopandas as gpd
import geopandas as gpd # type: ignore
from gisaf.config import conf
@ -58,9 +59,7 @@ class BaseModel(SQLModel):
@classmethod
async def get_gdf(cls, **kwargs) -> gpd.GeoDataFrame:
return await cls._get_df(geopandas_query,
model=cls,
**kwargs)
return await cls._get_df(geopandas_query, model=cls, **kwargs) # type: ignore
@classmethod
async def _get_df(cls, method, *,
@ -73,14 +72,31 @@ class BaseModel(SQLModel):
query = select(cls)
else:
columns = set(with_only_columns)
columns.add(*(col.name for col in cls.__table__.primary_key.columns))
columns.add(*(col.name for col in cls.__table__.primary_key.columns)) # type: ignore
query = select(*(getattr(cls, col) for col in columns))
if where is not None:
query = query.where(where)
## Get the joined tables
joined_tables = cls.selectinload()
if with_related:
joined_tables = cls.selectinload()
# Rename columns of the primary table
col_prefix = f'{cls.__table__.schema}_{cls.__table__.name}_' # type: ignore
col_prefix_length = len(col_prefix)
# for col in query.columns:
# if col.name.startswith(col_prefix):
# breakpoint()
# col.label(col.name[col_prefix_length:])
# breakpoint()
if with_related and len(joined_tables) > 0:
query = query.options(*(joinedload(jt) for jt in joined_tables))
## Get the column details of the final query
# ffrom = query.get_final_froms()[0]
# for col in ffrom.columns:
# if col.table == cls.__table__:
# continue
# print(col, col._proxies[0])
# pass
query_with_full_labels = query.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
df = await session.run_sync(method, query, **kwargs)
if method is geopandas_query and simplify_tolerance is not None:
df['geom'] = df['geom'].simplify(
@ -89,24 +105,36 @@ class BaseModel(SQLModel):
if preserve_topology is None
else preserve_topology)
)
# col_renames = {}
# col_name: str
# for col_name in df.columns:
# if col_name.startswith(col_prefix):
# col_renames[col_name] = col_name[col_prefix_length:]
# df.rename(columns=col_renames, inplace=True)
## Chamge column names to reflect the joined tables
## Leave the first columns unchanged, as their names come straight
## from the model's fields
joined_columns = list(df.columns[len(cls.model_fields):])
renames: dict[str, str] = {}
## Match colum names with the joined tables
## Important: this assumes that orders of the joined tables
## and their columns is preserved by pandas' read_sql
for joined_table in joined_tables:
target = joined_table.property.target # type: ignore
target_name = target.name
for col in target.columns:
## Pop the column from the colujmn list and make a new name
renames[joined_columns.pop(0)] = f'{target.schema}_{target_name}_{col.name}'
df.rename(columns=renames, inplace=True)
# ## Match colum names with the joined tables
# ## Important: this assumes that orders of the joined tables
# ## and their columns is preserved by pandas' read_sql
# for joined_table in joined_tables:
# target = joined_table.property.target # type: ignore
# for col in target.columns:
# ## Pop the column from the colujmn list and make a new name
# renames[joined_columns.pop(0)] = f'{target.schema}_{target.name}_{col.name}'
# df.rename(columns=renames, inplace=True)
## Finally, set the index of the df as the index of cls
df.set_index([c.name for c in cls.__table__.primary_key.columns], # type: ignore
inplace=True)
inplace=True)
ffroms = query.get_final_froms()[0]
ffroms1 = query_with_full_labels.get_final_froms()[0]
col = [col for col in ffroms.columns][-2]
col1 = [col for col in ffroms1.columns][-2]
all_cols = [cc for cc in query.froms[0].exported_columns]
breakpoint()
return df

View file

@ -12,7 +12,7 @@ from pydantic import create_model
from sqlalchemy import text
from sqlalchemy.orm import selectinload, joinedload
from sqlalchemy.exc import NoResultFound
from sqlmodel import SQLModel, col, select, inspect, Relationship
from sqlmodel import SQLModel, select, inspect, Relationship
import pandas as pd
import numpy as np
@ -36,7 +36,7 @@ from gisaf.utils import ToMigrate
from gisaf.models.category import Category, CategoryGroup
from gisaf.database import db_session
from gisaf import models
from gisaf.models.metadata import gisaf_survey, raw_survey, survey
from gisaf.models.metadata import raw_survey, survey
from gisaf.models.to_migrate import FeatureInfo, InfoCategory
logger = logging.getLogger(__name__)