From eb0f2b7b64f2059c1eacaf33831f3faa435359bf Mon Sep 17 00:00:00 2001
From: Doug Branton
Date: Tue, 9 Jan 2024 13:45:31 -0800
Subject: [PATCH] handle no-column object table; WIP unit test

---
 src/tape/ensemble.py              | 16 ++++++++++++--
 tests/tape_tests/test_ensemble.py | 36 ++++++++++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+), 2 deletions(-)

diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py
index d427d9de..6d286f66 100644
--- a/src/tape/ensemble.py
+++ b/src/tape/ensemble.py
@@ -1297,6 +1297,13 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, **
             # grab the dataframe from the frame label
             frame = self.frames[frame_label]
 
+            # Object can have no columns, which parquet doesn't handle
+            # In this case, we'll avoid saving to parquet
+            if frame_label == "object":
+                if len(frame.columns) == 0:
+                    print("The Object Frame was not saved as no columns were present.")
+                    continue
+
             # creates a subdirectory for the frame partition files
             frame.to_parquet(os.path.join(ens_path, frame_label), **kwargs)
 
@@ -1342,8 +1349,13 @@ def from_ensemble(self, dirpath, additional_frames=True, column_mapper=None, **k
 
         # Load Object and Source
         obj_path = os.path.join(dirpath, "object")
-        src_path = os.path.join(dirpath, "object")
-        self.from_parquet(src_path, obj_path, column_mapper=column_mapper, **kwargs)
+        src_path = os.path.join(dirpath, "source")
+
+        # Check for whether or not object is present, it's not saved when no columns are present
+        if "object" in os.listdir(dirpath):
+            self.from_parquet(src_path, obj_path, column_mapper=column_mapper, **kwargs)
+        else:
+            self.from_parquet(src_path, column_mapper=column_mapper, **kwargs)
 
         # Load all remaining frames
         if additional_frames is False:
diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py
index 995b1e82..18604775 100644
--- a/tests/tape_tests/test_ensemble.py
+++ b/tests/tape_tests/test_ensemble.py
@@ -478,6 +478,42 @@ def test_read_source_dict(dask_client):
     assert 8002 in obj_table.index
 
 
+def test_save_and_load_ensemble(dask_client):
+    # Set a seed for reproducibility
+    np.random.seed(1)
+
+    # Create some toy data
+    obj_ids = np.array([])
+    mjds = np.array([])
+    for i in range(10, 110):
+        obj_ids = np.append(obj_ids, np.array([i] * 1250))
+        mjds = np.append(mjds, np.arange(0.0, 1250.0, 1.0))
+    obj_ids = np.array(obj_ids)
+    flux = 10 * np.random.random(125000)
+    err = flux / 10
+    band = np.random.choice(["g", "r"], 125000)
+
+    # Store it in a dictionary
+    source_dict = {"id": obj_ids, "mjd": mjds, "flux": flux, "err": err, "band": band}
+
+    # Create an Ensemble
+    ens = Ensemble()
+    ens.from_source_dict(
+        source_dict,
+        column_mapper=ColumnMapper(
+            id_col="id", time_col="mjd", flux_col="flux", err_col="err", band_col="band"
+        ),
+    )
+
+    # Make a column for the object table
+    ens.calc_nobs(temporary=False)
+    # Add a few result frames
+    ens.batch(np.mean, "flux", label="mean")
+    ens.batch(np.max, "flux", label="max")
+
+    ens.save_ensemble("./ensemble")
+
+
 def test_insert(parquet_ensemble):
     num_partitions = parquet_ensemble.source.npartitions
     (old_object, old_source) = parquet_ensemble.compute()
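
For reference, a minimal round-trip sketch of the behavior this patch targets: an Ensemble whose Object table has no columns yet is saved without an "object" subdirectory and is then reloaded from Source alone. The sketch assumes the TAPE API exercised in the diff (Ensemble, ColumnMapper, from_source_dict, save_ensemble, from_ensemble), that Ensemble and ColumnMapper import from the top-level tape package as in the test suite, and that save_ensemble writes its output under path/dirname. The toy data, the directory name "no_col_ensemble", and re-passing the column mapper on reload are illustrative choices, not taken from the commit.

import numpy as np
from tape import Ensemble, ColumnMapper

# Toy source data for two objects (ids 1 and 2, five epochs each); the Object
# table derived from it starts with no columns.
flux = 10 * np.random.random(10)
source_dict = {
    "id": np.repeat([1, 2], 5),
    "mjd": np.tile(np.arange(5.0), 2),
    "flux": flux,
    "err": flux / 10.0,
    "band": np.random.choice(["g", "r"], 10),
}
col_map = ColumnMapper(
    id_col="id", time_col="mjd", flux_col="flux", err_col="err", band_col="band"
)

ens = Ensemble()  # may start a local Dask client
ens.from_source_dict(source_dict, column_mapper=col_map)

# With no Object columns, save_ensemble() skips the Object frame rather than
# attempting a zero-column parquet write.
ens.save_ensemble(path=".", dirname="no_col_ensemble")  # illustrative dirname

# from_ensemble() sees no "object" subdirectory and loads from Source alone.
new_ens = Ensemble()
new_ens.from_ensemble("./no_col_ensemble", column_mapper=col_map)

Re-passing the column mapper on reload is a conservative choice: from_ensemble defaults column_mapper to None and simply forwards it to from_parquet, and this patch does not establish whether the saved parquet files alone carry enough metadata to omit it.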