Revert breaking changes in get_df

This commit is contained in:
phil 2024-03-04 16:06:35 +05:30
parent d2c2e6cc69
commit f1534dfed7

View file

@ -5,7 +5,6 @@ from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import joinedload, QueryableAttribute
from sqlalchemy.sql.selectable import Select
from sqlalchemy import LABEL_STYLE_TABLENAME_PLUS_COL
from sqlmodel import SQLModel, select
from sqlmodel.ext.asyncio.session import AsyncSession
from fastapi import Depends
@ -79,14 +78,6 @@ class BaseModel(SQLModel):
## Get the joined tables
if with_related:
joined_tables = cls.selectinload()
# Rename columns of the primary table
col_prefix = f'{cls.__table__.schema}_{cls.__table__.name}_' # type: ignore
col_prefix_length = len(col_prefix)
# for col in query.columns:
# if col.name.startswith(col_prefix):
# breakpoint()
# col.label(col.name[col_prefix_length:])
# breakpoint()
if with_related and len(joined_tables) > 0:
query = query.options(*(joinedload(jt) for jt in joined_tables))
## Get the column details of the final query
@ -96,7 +87,6 @@ class BaseModel(SQLModel):
# continue
# print(col, col._proxies[0])
# pass
query_with_full_labels = query.set_label_style(LABEL_STYLE_TABLENAME_PLUS_COL)
df = await session.run_sync(method, query, **kwargs)
if method is geopandas_query and simplify_tolerance is not None:
df['geom'] = df['geom'].simplify(
@ -105,36 +95,26 @@ class BaseModel(SQLModel):
if preserve_topology is None
else preserve_topology)
)
# col_renames = {}
# col_name: str
# for col_name in df.columns:
# if col_name.startswith(col_prefix):
# col_renames[col_name] = col_name[col_prefix_length:]
# df.rename(columns=col_renames, inplace=True)
## Chamge column names to reflect the joined tables
## Leave the first columns unchanged, as their names come straight
## from the model's fields
joined_columns = list(df.columns[len(cls.model_fields):])
renames: dict[str, str] = {}
# ## Match colum names with the joined tables
# ## Important: this assumes that orders of the joined tables
# ## and their columns is preserved by pandas' read_sql
# for joined_table in joined_tables:
# target = joined_table.property.target # type: ignore
# for col in target.columns:
# ## Pop the column from the colujmn list and make a new name
# renames[joined_columns.pop(0)] = f'{target.schema}_{target.name}_{col.name}'
# df.rename(columns=renames, inplace=True)
# Match colum names with the joined tables
# This uses the fact that orders of the joined tables
# and their columns is preserved by sqlalchemy query options (joinedlaod),
# and pandas' read_sql
# Important: apparently, the order defined in cls.selectinload has to match the
# order of the relationship fields defined in the model class definition
for joined_table in joined_tables:
target = joined_table.property.target # type: ignore
for col in target.columns:
## Pop the column from the colujmn list and make a new name
renames[joined_columns.pop(0)] = f'{target.schema}_{target.name}_{col.name}'
df.rename(columns=renames, inplace=True)
## Finally, set the index of the df as the index of cls
df.set_index([c.name for c in cls.__table__.primary_key.columns], # type: ignore
inplace=True)
ffroms = query.get_final_froms()[0]
ffroms1 = query_with_full_labels.get_final_froms()[0]
col = [col for col in ffroms.columns][-2]
col1 = [col for col in ffroms1.columns][-2]
all_cols = [cc for cc in query.froms[0].exported_columns]
breakpoint()
return df