diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 2deb5186..90bf1a7c 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -34,6 +34,8 @@ DEFAULT_FRAME_LABEL = "result" # A base default label for an Ensemble's result frames. +METADATA_FILENAME = "ensemble_metadata.json" + class Ensemble: """Ensemble object is a collection of light curve ids""" @@ -1286,9 +1288,9 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, ** # Determine the path ens_path = os.path.join(path, dirname) - # First look for an existing metadata.json file in the path + # First look for an existing metadata file in the path try: - with open(os.path.join(ens_path, "metadata.json"), "r") as oldfile: + with open(os.path.join(ens_path, METADATA_FILENAME), "r") as oldfile: # Reading from json file old_metadata = json.load(oldfile) old_subdirs = old_metadata["subdirs"] @@ -1302,7 +1304,7 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, ** if additional_frames is True: frames_to_save = list(self.frames.keys()) # save all frames elif additional_frames is False: - frames_to_save = ["object", "source"] # save just object and source + frames_to_save = [OBJECT_FRAME_LABEL, SOURCE_FRAME_LABEL] # save just object and source elif isinstance(additional_frames, Iterable): frames_to_save = set(additional_frames) invalid_frames = frames_to_save.difference(set(self.frames.keys())) @@ -1314,14 +1316,14 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, ** frames_to_save = list(frames_to_save) # Make sure object and source are in the frame list - if "object" not in frames_to_save: - frames_to_save.append("object") - if "source" not in frames_to_save: - frames_to_save.append("source") + if OBJECT_FRAME_LABEL not in frames_to_save: + frames_to_save.append(OBJECT_FRAME_LABEL) + if SOURCE_FRAME_LABEL not in frames_to_save: + frames_to_save.append(SOURCE_FRAME_LABEL) else: raise ValueError("Invalid input to `additional_frames`, must be boolean or list-like") - # Save the frame list to disk + # Generate the metadata first created_subdirs = [] # track the list of created subdirectories divisions_known = [] # log whether divisions were known for each frame for frame_label in frames_to_save: @@ -1331,17 +1333,14 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, ** # When the frame has no columns, avoid the save as parquet doesn't handle it # Most commonly this applies to the object table when it's built from source if len(frame.columns) == 0: - print(f"Frame: {frame_label} was not saved as no columns were present.") + print(f"Frame: {frame_label} will not be saved as no columns are present.") continue - # creates a subdirectory for the frame partition files - frame.to_parquet(os.path.join(ens_path, frame_label), write_metadata_file=True, **kwargs) created_subdirs.append(frame_label) divisions_known.append(frame.known_divisions) # Save a metadata file col_map = self.make_column_map() # grab the current column_mapper - metadata = { "subdirs": created_subdirs, "known_divisions": divisions_known, @@ -1349,10 +1348,14 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, ** } json_metadata = json.dumps(metadata, indent=4) - with open(os.path.join(ens_path, "metadata.json"), "w") as outfile: + # Make the directory if it doesn't already exist + os.makedirs(ens_path, exist_ok=True) + with open(os.path.join(ens_path, METADATA_FILENAME), "w") as outfile: outfile.write(json_metadata) - # np.save(os.path.join(ens_path, "column_mapper.npy"), col_map.map) + # Now write out the frames to subdirectories + for subdir in created_subdirs: + self.frames[subdir].to_parquet(os.path.join(ens_path, subdir), write_metadata_file=True, **kwargs) print(f"Saved to {os.path.join(path, dirname)}") @@ -1390,8 +1393,8 @@ def from_ensemble( The ensemble object. """ - # Read in the metadata.json file - with open(os.path.join(dirpath, "metadata.json"), "r") as metadatafile: + # Read in the metadata file + with open(os.path.join(dirpath, METADATA_FILENAME), "r") as metadatafile: # Reading from json file metadata = json.load(metadatafile) @@ -1405,16 +1408,16 @@ def from_ensemble( # Load Object and Source # Check for whether or not object is present, it's not saved when no columns are present - if "object" in subdirs: + if OBJECT_FRAME_LABEL in subdirs: # divisions should be known for both tables to use the sorted kwarg use_sorted = ( - frame_known_divisions[subdirs.index("object")] - and frame_known_divisions[subdirs.index("source")] + frame_known_divisions[subdirs.index(OBJECT_FRAME_LABEL)] + and frame_known_divisions[subdirs.index(SOURCE_FRAME_LABEL)] ) self.from_parquet( - os.path.join(dirpath, "source"), - os.path.join(dirpath, "object"), + os.path.join(dirpath, SOURCE_FRAME_LABEL), + os.path.join(dirpath, OBJECT_FRAME_LABEL), column_mapper=column_mapper, sorted=use_sorted, sort=False, @@ -1422,9 +1425,9 @@ def from_ensemble( **kwargs, ) else: - use_sorted = frame_known_divisions[subdirs.index("source")] + use_sorted = frame_known_divisions[subdirs.index(SOURCE_FRAME_LABEL)] self.from_parquet( - os.path.join(dirpath, "source"), + os.path.join(dirpath, SOURCE_FRAME_LABEL), column_mapper=column_mapper, sorted=use_sorted, sort=False, @@ -1446,7 +1449,9 @@ def from_ensemble( # Filter out object and source from additional frames frames_to_load = [ - frame for frame in frames_to_load if os.path.split(frame)[1] not in ["object", "source"] + frame + for frame in frames_to_load + if os.path.split(frame)[1] not in [OBJECT_FRAME_LABEL, SOURCE_FRAME_LABEL] ] if len(frames_to_load) > 0: for frame in frames_to_load: diff --git a/src/tape/ensemble_readers.py b/src/tape/ensemble_readers.py index 55b93e26..91f9add2 100644 --- a/src/tape/ensemble_readers.py +++ b/src/tape/ensemble_readers.py @@ -15,8 +15,6 @@ def read_ensemble( additional_frames=True, column_mapper=None, dask_client=True, - additional_cols=True, - partition_size=None, **kwargs, ): """Load an ensemble from an on-disk ensemble. @@ -37,19 +35,6 @@ def read_ensemble( Supplies a ColumnMapper to the Ensemble, if None (default) searches for a column_mapper.npy file in the directory, which should be created when the ensemble is saved. - additional_cols: 'bool', optional - Boolean to indicate whether to carry in columns beyond the - critical columns, true will, while false will only load the columns - containing the critical quantities (id,time,flux,err,band) - partition_size: `int`, optional - If specified, attempts to repartition the ensemble to partitions - of size `partition_size`. - sorted: bool, optional - If the index column is already sorted in increasing order. - Defaults to False - sort: `bool`, optional - If True, sorts the DataFrame by the id column. Otherwise set the - index on the individual existing partitions. Defaults to False. dask_client: `dask.distributed.client` or `bool`, optional Accepts an existing `dask.distributed.Client`, or creates one if `client=True`, passing any additional kwargs to a @@ -68,8 +53,6 @@ def read_ensemble( dirpath, additional_frames=additional_frames, column_mapper=column_mapper, - additional_cols=additional_cols, - partition_size=partition_size, **kwargs, ) diff --git a/tests/tape_tests/conftest.py b/tests/tape_tests/conftest.py index c0af84c3..ec6b4521 100644 --- a/tests/tape_tests/conftest.py +++ b/tests/tape_tests/conftest.py @@ -233,6 +233,7 @@ def parquet_ensemble_without_client(): return ens + @pytest.fixture def parquet_files_and_ensemble_without_client(): """Create an Ensemble from parquet data without a dask client.""" @@ -246,12 +247,10 @@ def parquet_files_and_ensemble_without_client(): err_col="psFluxErr", band_col="filterName", ) - ens = ens.from_parquet( - source_file, - object_file, - column_mapper=colmap) + ens = ens.from_parquet(source_file, object_file, column_mapper=colmap) return ens, source_file, object_file, colmap + # pylint: disable=redefined-outer-name @pytest.fixture def parquet_ensemble(dask_client): @@ -270,6 +269,25 @@ def parquet_ensemble(dask_client): return ens +# pylint: disable=redefined-outer-name +@pytest.fixture +def parquet_ensemble_partition_size(dask_client): + """Create an Ensemble from parquet data.""" + ens = Ensemble(client=dask_client) + ens.from_parquet( + "tests/tape_tests/data/source/test_source.parquet", + "tests/tape_tests/data/object/test_object.parquet", + id_col="ps1_objid", + time_col="midPointTai", + band_col="filterName", + flux_col="psFlux", + err_col="psFluxErr", + partition_size="1MB", + ) + + return ens + + # pylint: disable=redefined-outer-name @pytest.fixture def parquet_ensemble_with_divisions(dask_client): @@ -386,6 +404,34 @@ def dask_dataframe_ensemble(dask_client): return ens +# pylint: disable=redefined-outer-name +@pytest.fixture +def dask_dataframe_ensemble_partition_size(dask_client): + """Create an Ensemble from parquet data.""" + ens = Ensemble(client=dask_client) + + num_points = 1000 + all_bands = np.array(["r", "g", "b", "i"]) + rows = { + "id": 8000 + (np.arange(num_points) % 5), + "time": np.arange(num_points), + "flux": np.arange(num_points) % len(all_bands), + "band": np.repeat(all_bands, num_points / len(all_bands)), + "err": 0.1 * (np.arange(num_points) % 10), + "count": np.arange(num_points), + "something_else": np.full(num_points, None), + } + cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band") + + ens.from_dask_dataframe( + source_frame=dd.from_dict(rows, npartitions=1), + column_mapper=cmap, + partition_size="1MB", + ) + + return ens + + # pylint: disable=redefined-outer-name @pytest.fixture def dask_dataframe_with_object_ensemble(dask_client): @@ -490,6 +536,7 @@ def pandas_with_object_ensemble(dask_client): return ens + # pylint: disable=redefined-outer-name @pytest.fixture def ensemble_from_source_dict(dask_client): @@ -511,4 +558,4 @@ def ensemble_from_source_dict(dask_client): cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="error", band_col="band") ens.from_source_dict(source_dict, column_mapper=cmap) - return ens, source_dict \ No newline at end of file + return ens, source_dict diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py index 7fc52c2e..c2a8945c 100644 --- a/tests/tape_tests/test_ensemble.py +++ b/tests/tape_tests/test_ensemble.py @@ -50,6 +50,7 @@ def test_with_client(): "parquet_ensemble_from_hipscat", "parquet_ensemble_with_column_mapper", "parquet_ensemble_with_known_column_mapper", + "parquet_ensemble_partition_size", "read_parquet_ensemble", "read_parquet_ensemble_without_client", "read_parquet_ensemble_from_source", @@ -102,6 +103,7 @@ def test_parquet_construction(data_fixture, request): "data_fixture", [ "dask_dataframe_ensemble", + "dask_dataframe_ensemble_partition_size", "dask_dataframe_with_object_ensemble", "pandas_ensemble", "pandas_with_object_ensemble", @@ -533,7 +535,7 @@ def test_save_and_load_ensemble(dask_client, tmp_path, add_frames, obj_nocols, u dircontents = os.listdir(os.path.join(save_path, "ensemble")) assert "source" in dircontents # Source should always be there - assert "metadata.json" in dircontents # should make a metadata file + assert "ensemble_metadata.json" in dircontents # should make a metadata file if obj_nocols: # object shouldn't if it was empty assert "object" not in dircontents else: # otherwise it should be present