From 6b720b362db5e78a2db3bad36618af6faea86386 Mon Sep 17 00:00:00 2001 From: Doug Branton Date: Mon, 25 Mar 2024 16:14:11 -0700 Subject: [PATCH] fix codebase todos --- src/tape/ensemble.py | 16 +++++++++------- src/tape/ensemble_frame.py | 6 ++++++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 0e981204..ea63ed33 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -1176,8 +1176,8 @@ def _batch_apply(df, func, on, *args, **kwargs): ) # Output standardization - batch = self._standardize_batch(batch, on, by_band) + batch = self._standardize_batch(batch, on, by_band) # Inherit divisions if known from source and the resulting index is the id # Groupby on index should always return a subset that adheres to the same divisions criteria if self.source.known_divisions and batch.index.name == self._id_col: @@ -1200,10 +1200,18 @@ def _standardize_batch(self, batch, on, by_band): # make sure the output is separated from the id column if batch.name == self._id_col: batch = batch.rename("result") + + # need to set the index name + set_idx_name = True + else: + set_idx_name = False + res_cols = [batch.name] # grab the series name to use as a column label # convert the series to an EnsembleFrame object batch = EnsembleFrame.from_dask_dataframe(batch.to_frame()) + if set_idx_name and len(on) < 2: + batch.index = batch.index.rename(self._id_col) elif isinstance(batch, EnsembleFrame): # collect output columns @@ -1373,10 +1381,6 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, ** # Now write out the frames to subdirectories for subdir in created_subdirs: - # TODO: Figure this out, peek at the real meta as a stop gap - # TODO: It may be best to make sure batch returns valid index names - idx_name = self.frames[subdir].head(1).index.name - self.frames[subdir].index = self.frames[subdir].index.rename(idx_name) self.frames[subdir].to_parquet(os.path.join(ens_path, subdir), write_metadata_file=True, **kwargs) print(f"Saved to {os.path.join(path, dirname)}") @@ -2138,8 +2142,6 @@ def convert_flux_to_mag(self, zero_point, zp_form="mag", out_col_name=None, flux def _generate_object_table(self): """Generate an empty object table from the source table.""" res = self.source.map_partitions(lambda x: TapeObjectFrame(index=x.index.unique())) - res.label = "object" # TODO: propagation issue with label - return res def _lazy_sync_tables_from_frame(self, frame): diff --git a/src/tape/ensemble_frame.py b/src/tape/ensemble_frame.py index a8629fe1..07894c1c 100644 --- a/src/tape/ensemble_frame.py +++ b/src/tape/ensemble_frame.py @@ -733,6 +733,12 @@ def map_partitions(self, func, *args, **kwargs): if isinstance(result, self.__class__): # If the output of func is another _Frame, let's propagate any metadata. return self._propagate_metadata(result) + elif isinstance(result, ObjectFrame): + result = self._propagate_metadata(result) + result.label = "object" # override the label + return result + elif isinstance(result, SourceFrame): + return self._propagate_metadata(result) return result def compute(self, **kwargs):