From f1534dfed79b354a7dd758cc0a8347b1dd37e09b Mon Sep 17 00:00:00 2001 From: phil Date: Mon, 4 Mar 2024 16:06:35 +0530 Subject: [PATCH] Revert breaking changes in get_df --- src/gisaf/database.py | 44 ++++++++++++------------------------------- 1 file changed, 12 insertions(+), 32 deletions(-) diff --git a/src/gisaf/database.py b/src/gisaf/database.py index f4a621f..1976b7e 100644 --- a/src/gisaf/database.py +++ b/src/gisaf/database.py @@ -5,7 +5,6 @@ from collections.abc import AsyncGenerator from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.orm import joinedload, QueryableAttribute from sqlalchemy.sql.selectable import Select -from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL from sqlmodel import SQLModel, select from sqlmodel.ext.asyncio.session import AsyncSession from fastapi import Depends @@ -79,14 +78,6 @@ class BaseModel(SQLModel): ## Get the joined tables if with_related: joined_tables = cls.selectinload() - # Rename columns of the primary table - col_prefix = f'{cls.__table__.schema}_{cls.__table__.name}_' # type: ignore - col_prefix_length = len(col_prefix) - # for col in query.columns: - # if col.name.startswith(col_prefix): - # breakpoint() - # col.label(col.name[col_prefix_length:]) - # breakpoint() if with_related and len(joined_tables) > 0: query = query.options(*(joinedload(jt) for jt in joined_tables)) ## Get the column details of the final query @@ -96,7 +87,6 @@ class BaseModel(SQLModel): # continue # print(col, col._proxies[0]) # pass - query_with_full_labels = query.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL) df = await session.run_sync(method, query, **kwargs) if method is geopandas_query and simplify_tolerance is not None: df['geom'] = df['geom'].simplify( @@ -105,36 +95,26 @@ class BaseModel(SQLModel): if preserve_topology is None else preserve_topology) ) - # col_renames = {} - # col_name: str - # for col_name in df.columns: - # if col_name.startswith(col_prefix): - # col_renames[col_name] = col_name[col_prefix_length:] - # df.rename(columns=col_renames, inplace=True) - ## Chamge column names to reflect the joined tables ## Leave the first columns unchanged, as their names come straight ## from the model's fields joined_columns = list(df.columns[len(cls.model_fields):]) renames: dict[str, str] = {} - # ## Match colum names with the joined tables - # ## Important: this assumes that orders of the joined tables - # ## and their columns is preserved by pandas' read_sql - # for joined_table in joined_tables: - # target = joined_table.property.target # type: ignore - # for col in target.columns: - # ## Pop the column from the colujmn list and make a new name - # renames[joined_columns.pop(0)] = f'{target.schema}_{target.name}_{col.name}' - # df.rename(columns=renames, inplace=True) + # Match colum names with the joined tables + # This uses the fact that orders of the joined tables + # and their columns is preserved by sqlalchemy query options (joinedlaod), + # and pandas' read_sql + # Important: apparently, the order defined in cls.selectinload has to match the + # order of the relationship fields defined in the model class definition + for joined_table in joined_tables: + target = joined_table.property.target # type: ignore + for col in target.columns: + ## Pop the column from the colujmn list and make a new name + renames[joined_columns.pop(0)] = f'{target.schema}_{target.name}_{col.name}' + df.rename(columns=renames, inplace=True) ## Finally, set the index of the df as the index of cls df.set_index([c.name for c in cls.__table__.primary_key.columns], # type: ignore inplace=True) - ffroms = query.get_final_froms()[0] - ffroms1 = query_with_full_labels.get_final_froms()[0] - col = [col for col in ffroms.columns][-2] - col1 = [col for col in ffroms1.columns][-2] - all_cols = [cc for cc in query.froms[0].exported_columns] - breakpoint() return df