Commit

metadata updates, overwrite saves

dougbrn committed Jan 11, 2024
1 parent 6730d47 commit aad1978

Showing 4 changed files with 99 additions and 72 deletions.
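
The changes to save_ensemble and from_ensemble below move the on-disk metadata from a column_mapper.npy file to a single metadata.json, and make repeated saves into the same directory clean up subdirectories left over from a previous save. A minimal round-trip sketch, mirroring the calls exercised in the new tests (the ens and loaded_ens Ensemble objects and save_path are assumed to already exist; the directory name is illustrative):

    # Save all frames of an existing Ensemble; metadata.json is written alongside them.
    ens.save_ensemble(save_path, dirname="ensemble", additional_frames=True)

    # Loading reads metadata.json to recover the saved subdirectories, whether
    # divisions were known for each frame, and the stored column mapping.
    loaded_ens.from_ensemble(os.path.join(save_path, "ensemble"), additional_frames=True)
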
123 changes: 69 additions & 54 deletions src/tape/ensemble.py
@@ -1,5 +1,7 @@
import glob
import os
import json
import shutil
import warnings
import requests
import dask.dataframe as dd
@@ -1284,6 +1286,18 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, **
# Determine the path
ens_path = os.path.join(path, dirname)

# First look for an existing metadata.json file in the path
try:
with open(os.path.join(ens_path, "metadata.json"), "r") as oldfile:
# Reading from json file
old_metadata = json.load(oldfile)
old_subdirs = old_metadata["subdirs"]
# Delete any old subdirectories
for subdir in old_subdirs:
shutil.rmtree(os.path.join(ens_path, subdir))
except FileNotFoundError:
pass

# Compile frame list
if additional_frames is True:
frames_to_save = list(self.frames.keys()) # save all frames
@@ -1307,23 +1321,37 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, **
raise ValueError("Invalid input to `additional_frames`, must be boolean or list-like")

# Save the frame list to disk
created_subdirs = [] # track the list of created subdirectories
divisions_known = [] # log whether divisions were known for each frame
for frame_label in frames_to_save:
# grab the dataframe from the frame label
frame = self.frames[frame_label]

# Object can have no columns, which parquet doesn't handle
# In this case, we'll avoid saving to parquet
if frame_label == "object":
if len(frame.columns) == 0:
print("The Object Frame was not saved as no columns were present.")
continue
# When the frame has no columns, avoid the save as parquet doesn't handle it
# Most commonly this applies to the object table when it's built from source
if len(frame.columns) == 0:
print(f"Frame: {frame_label} was not saved as no columns were present.")
continue

# creates a subdirectory for the frame partition files
frame.to_parquet(os.path.join(ens_path, frame_label), **kwargs)
frame.to_parquet(os.path.join(ens_path, frame_label), write_metadata_file=True, **kwargs)
created_subdirs.append(frame_label)
divisions_known.append(frame.known_divisions)

# Save a metadata file
col_map = self.make_column_map() # grab the current column_mapper

metadata = {
"subdirs": created_subdirs,
"known_divisions": divisions_known,
"column_mapper": col_map.map,
}
json_metadata = json.dumps(metadata, indent=4)

# Save a ColumnMapper file
col_map = self.make_column_map()
np.save(os.path.join(ens_path, "column_mapper.npy"), col_map.map)
with open(os.path.join(ens_path, "metadata.json"), "w") as outfile:
outfile.write(json_metadata)

# np.save(os.path.join(ens_path, "column_mapper.npy"), col_map.map)

print(f"Saved to {os.path.join(path, dirname)}")

@@ -1334,10 +1362,6 @@ def from_ensemble(
dirpath,
additional_frames=True,
column_mapper=None,
additional_cols=True,
partition_size=None,
sorted=False,
sort=False,
**kwargs,
):
"""Load an ensemble from an on-disk ensemble.
@@ -1358,60 +1382,52 @@ def from_ensemble(
Supplies a ColumnMapper to the Ensemble, if None (default) searches
for a column_mapper.npy file in the directory, which should be
created when the ensemble is saved.
additional_cols: 'bool', optional
Boolean to indicate whether to carry in columns beyond the
critical columns, true will, while false will only load the columns
containing the critical quantities (id,time,flux,err,band)
partition_size: `int`, optional
If specified, attempts to repartition the ensemble to partitions
of size `partition_size`.
sorted: bool, optional
If the index column is already sorted in increasing order.
Defaults to False
sort: `bool`, optional
If True, sorts the DataFrame by the id column. Otherwise set the
index on the individual existing partitions. Defaults to False.
Returns
----------
ensemble: `tape.ensemble.Ensemble`
The ensemble object.
"""

# First grab the column_mapper if not specified
if column_mapper is None:
map_dict = np.load(os.path.join(dirpath, "column_mapper.npy"), allow_pickle="TRUE").item()
column_mapper = ColumnMapper()
column_mapper.map = map_dict
# Read in the metadata.json file
with open(os.path.join(dirpath, "metadata.json"), "r") as metadatafile:
# Reading from json file
metadata = json.load(metadatafile)

# Load in the metadata
subdirs = metadata["subdirs"]
frame_known_divisions = metadata["known_divisions"]
if column_mapper is None:
column_mapper = ColumnMapper()
column_mapper.map = metadata["column_mapper"]

# Load Object and Source
obj_path = os.path.join(dirpath, "object")
src_path = os.path.join(dirpath, "source")

# Check for whether or not object is present, it's not saved when no columns are present
if "object" in os.listdir(dirpath):
if "object" in subdirs:
# divisions should be known for both tables to use the sorted kwarg
use_sorted = (
frame_known_divisions[subdirs.index("object")]
and frame_known_divisions[subdirs.index("source")]
)

self.from_parquet(
src_path,
obj_path,
os.path.join(dirpath, "source"),
os.path.join(dirpath, "object"),
column_mapper=column_mapper,
additional_cols=additional_cols,
sorted=sorted,
sort=sort,
sorted=use_sorted,
sort=False,
sync_tables=False, # a sync should always be performed just before saving
npartitions=None, # disabled, as this would be applied to all frames
partition_size=partition_size,
**kwargs,
)
else:
use_sorted = frame_known_divisions[subdirs.index("source")]
self.from_parquet(
src_path,
os.path.join(dirpath, "source"),
column_mapper=column_mapper,
additional_cols=additional_cols,
sorted=sorted,
sort=sort,
sorted=use_sorted,
sort=False,
sync_tables=False, # a sync should always be performed just before saving
npartitions=None, # disabled, as this would be applied to all frames
partition_size=partition_size,
**kwargs,
)

@@ -1421,11 +1437,7 @@ def from_ensemble(
else:
if additional_frames is True:
# Grab all subdirectory paths in the top-level folder, filter out any files
frames_to_load = [
os.path.join(dirpath, f)
for f in os.listdir(dirpath)
if not os.path.isfile(os.path.join(dirpath, f))
]
frames_to_load = [os.path.join(dirpath, f) for f in subdirs]
elif isinstance(additional_frames, Iterable):
frames_to_load = [os.path.join(dirpath, frame) for frame in additional_frames]
else:
Expand All @@ -1438,7 +1450,10 @@ def from_ensemble(
if len(frames_to_load) > 0:
for frame in frames_to_load:
label = os.path.split(frame)[1]
ddf = EnsembleFrame.from_parquet(frame, label=label, **kwargs)
use_divisions = frame_known_divisions[subdirs.index(label)]
ddf = EnsembleFrame.from_parquet(
frame, label=label, calculate_divisions=use_divisions, **kwargs
)
self.add_frame(ddf, label)

return self
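
For reference, a sketch of the dictionary that save_ensemble dumps to metadata.json and that from_ensemble reads back. The top-level keys follow the metadata dict assembled above; the subdirectory labels, column names, and ColumnMapper map keys shown here are assumptions for illustration:

    # Illustrative metadata.json contents for a save that wrote the object and source
    # frames plus one result frame labeled "mean". The column_mapper entries depend on
    # the ColumnMapper.map attribute and are made up here.
    metadata = {
        "subdirs": ["object", "source", "mean"],
        "known_divisions": [True, True, False],
        "column_mapper": {
            "id_col": "ps1_objid",
            "time_col": "midPointTai",
            "flux_col": "psFlux",
            "err_col": "psFluxErr",
            "band_col": "filterName",
        },
    }
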
15 changes: 2 additions & 13 deletions src/tape/ensemble_frame.py
@@ -844,14 +844,7 @@ def convert_flux_to_mag(
return result

@classmethod
def from_parquet(
cl,
path,
index=None,
columns=None,
label=None,
ensemble=None,
):
def from_parquet(cl, path, index=None, columns=None, label=None, ensemble=None, **kwargs):
"""Returns an EnsembleFrame constructed from loading a parquet file.
Parameters
----------
@@ -879,11 +872,7 @@ def from_parquet(
# Read the parquet file with an engine that will assume the meta is a TapeFrame which Dask will
# instantiate as EnsembleFrame via its dispatcher.
result = dd.read_parquet(
path,
index=index,
columns=columns,
split_row_groups=True,
engine=TapeArrowEngine,
path, index=index, columns=columns, split_row_groups=True, engine=TapeArrowEngine, **kwargs
)
result.label = label
result.ensemble = ensemble
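
With the new **kwargs passthrough, callers can forward dask.dataframe.read_parquet options through EnsembleFrame.from_parquet. A small sketch of how from_ensemble uses this to request division calculation (the path and label are illustrative):

    # Forward a dask read_parquet option through the new **kwargs; calculate_divisions
    # asks dask to recover sorted divisions from the parquet _metadata file written
    # by save_ensemble (via write_metadata_file=True).
    frame = EnsembleFrame.from_parquet(
        "ensemble/mean", label="mean", calculate_divisions=True
    )
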
4 changes: 0 additions & 4 deletions src/tape/ensemble_readers.py
@@ -17,8 +17,6 @@ def read_ensemble(
dask_client=True,
additional_cols=True,
partition_size=None,
sorted=False,
sort=False,
**kwargs,
):
"""Load an ensemble from an on-disk ensemble.
@@ -72,8 +70,6 @@
column_mapper=column_mapper,
additional_cols=additional_cols,
partition_size=partition_size,
sorted=sorted,
sort=sort,
**kwargs,
)

29 changes: 28 additions & 1 deletion tests/tape_tests/test_ensemble.py
@@ -483,6 +483,7 @@ def test_read_source_dict(dask_client):
@pytest.mark.parametrize("obj_nocols", [True, False])
@pytest.mark.parametrize("use_reader", [False, True])
def test_save_and_load_ensemble(dask_client, tmp_path, add_frames, obj_nocols, use_reader):
"""Test the save and load ensemble loop"""
# Setup a temporary directory for files
save_path = tmp_path / "."

@@ -532,7 +533,7 @@ def test_save_and_load_ensemble(dask_client, tmp_path, add_frames, obj_nocols, u
dircontents = os.listdir(os.path.join(save_path, "ensemble"))

assert "source" in dircontents # Source should always be there
assert "column_mapper.npy" in dircontents # should make a column_mapper file
assert "metadata.json" in dircontents # should make a metadata file
if obj_nocols: # object shouldn't if it was empty
assert "object" not in dircontents
else: # otherwise it should be present
@@ -586,6 +587,32 @@ def test_save_and_load_ensemble(dask_client, tmp_path, add_frames, obj_nocols, u
loaded_ens.from_ensemble(os.path.join(save_path, "ensemble"), additional_frames=3)


def test_save_overwrite(parquet_ensemble, tmp_path):
"""Test that successive saves produce the correct behavior"""
# Setup a temporary directory for files
save_path = tmp_path / "."

ens = parquet_ensemble

# Add a few result frames
ens.batch(np.mean, "psFlux", label="mean")
ens.batch(np.max, "psFlux", label="max")

# Write first with all frames
ens.save_ensemble(save_path, dirname="ensemble", additional_frames=True)

# Inspect the save directory
dircontents = os.listdir(os.path.join(save_path, "ensemble"))
assert "max" in dircontents # "max" should have been added

# Then write again with "max" not included
ens.save_ensemble(save_path, dirname="ensemble", additional_frames=["mean"])

# Inspect the save directory
dircontents = os.listdir(os.path.join(save_path, "ensemble"))
assert "max" not in dircontents # "max" should have been removed


def test_insert(parquet_ensemble):
num_partitions = parquet_ensemble.source.npartitions
(old_object, old_source) = parquet_ensemble.compute()
