Revert breaking changes in get_df

This commit is contained in:
phil 2024-03-04 16:06:35 +05:30
parent d2c2e6cc69
commit f1534dfed7

View file

@ -5,7 +5,6 @@ from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import joinedload, QueryableAttribute from sqlalchemy.orm import joinedload, QueryableAttribute
from sqlalchemy.sql.selectable import Select from sqlalchemy.sql.selectable import Select
from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
from sqlmodel import SQLModel, select from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi import Depends from fastapi import Depends
@ -79,14 +78,6 @@ class BaseModel(SQLModel):
## Get the joined tables ## Get the joined tables
if with_related: if with_related:
joined_tables = cls.selectinload() joined_tables = cls.selectinload()
# Rename columns of the primary table
col_prefix = f'{cls.__table__.schema}_{cls.__table__.name}_' # type: ignore
col_prefix_length = len(col_prefix)
# for col in query.columns:
# if col.name.startswith(col_prefix):
# breakpoint()
# col.label(col.name[col_prefix_length:])
# breakpoint()
if with_related and len(joined_tables) > 0: if with_related and len(joined_tables) > 0:
query = query.options(*(joinedload(jt) for jt in joined_tables)) query = query.options(*(joinedload(jt) for jt in joined_tables))
## Get the column details of the final query ## Get the column details of the final query
@ -96,7 +87,6 @@ class BaseModel(SQLModel):
# continue # continue
# print(col, col._proxies[0]) # print(col, col._proxies[0])
# pass # pass
query_with_full_labels = query.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
df = await session.run_sync(method, query, **kwargs) df = await session.run_sync(method, query, **kwargs)
if method is geopandas_query and simplify_tolerance is not None: if method is geopandas_query and simplify_tolerance is not None:
df['geom'] = df['geom'].simplify( df['geom'] = df['geom'].simplify(
@ -105,36 +95,26 @@ class BaseModel(SQLModel):
if preserve_topology is None if preserve_topology is None
else preserve_topology) else preserve_topology)
) )
# col_renames = {}
# col_name: str
# for col_name in df.columns:
# if col_name.startswith(col_prefix):
# col_renames[col_name] = col_name[col_prefix_length:]
# df.rename(columns=col_renames, inplace=True)
## Chamge column names to reflect the joined tables ## Chamge column names to reflect the joined tables
## Leave the first columns unchanged, as their names come straight ## Leave the first columns unchanged, as their names come straight
## from the model's fields ## from the model's fields
joined_columns = list(df.columns[len(cls.model_fields):]) joined_columns = list(df.columns[len(cls.model_fields):])
renames: dict[str, str] = {} renames: dict[str, str] = {}
# ## Match colum names with the joined tables # Match colum names with the joined tables
# ## Important: this assumes that orders of the joined tables # This uses the fact that orders of the joined tables
# ## and their columns is preserved by pandas' read_sql # and their columns is preserved by sqlalchemy query options (joinedlaod),
# for joined_table in joined_tables: # and pandas' read_sql
# target = joined_table.property.target # type: ignore # Important: apparently, the order defined in cls.selectinload has to match the
# for col in target.columns: # order of the relationship fields defined in the model class definition
# ## Pop the column from the colujmn list and make a new name for joined_table in joined_tables:
# renames[joined_columns.pop(0)] = f'{target.schema}_{target.name}_{col.name}' target = joined_table.property.target # type: ignore
# df.rename(columns=renames, inplace=True) for col in target.columns:
## Pop the column from the colujmn list and make a new name
renames[joined_columns.pop(0)] = f'{target.schema}_{target.name}_{col.name}'
df.rename(columns=renames, inplace=True)
## Finally, set the index of the df as the index of cls ## Finally, set the index of the df as the index of cls
df.set_index([c.name for c in cls.__table__.primary_key.columns], # type: ignore df.set_index([c.name for c in cls.__table__.primary_key.columns], # type: ignore
inplace=True) inplace=True)
ffroms = query.get_final_froms()[0]
ffroms1 = query_with_full_labels.get_final_froms()[0]
col = [col for col in ffroms.columns][-2]
col1 = [col for col in ffroms1.columns][-2]
all_cols = [cc for cc in query.froms[0].exported_columns]
breakpoint()
return df return df