Skip to content

Commit

Permalink
handle no-column object table; WIP unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed Jan 9, 2024
1 parent 3c5ccbb commit eb0f2b7
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,13 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, **
# grab the dataframe from the frame label
frame = self.frames[frame_label]

# Object can have no columns, which parquet doesn't handle
# In this case, we'll avoid saving to parquet
if frame_label == "object":
if len(frame.columns) == 0:
print("The Object Frame was not saved as no columns were present.")
continue

Check warning on line 1305 in src/tape/ensemble.py

View check run for this annotation

Codecov / codecov/patch

src/tape/ensemble.py#L1304-L1305

Added lines #L1304 - L1305 were not covered by tests

# creates a subdirectory for the frame partition files
frame.to_parquet(os.path.join(ens_path, frame_label), **kwargs)

Expand Down Expand Up @@ -1342,8 +1349,13 @@ def from_ensemble(self, dirpath, additional_frames=True, column_mapper=None, **k

# Load Object and Source
obj_path = os.path.join(dirpath, "object")
src_path = os.path.join(dirpath, "object")
self.from_parquet(src_path, obj_path, column_mapper=column_mapper, **kwargs)
src_path = os.path.join(dirpath, "source")

Check warning on line 1352 in src/tape/ensemble.py

View check run for this annotation

Codecov / codecov/patch

src/tape/ensemble.py#L1351-L1352

Added lines #L1351 - L1352 were not covered by tests

# Check for whether or not object is present, it's not saved when no columns are present
if "object" in os.listdir(dirpath):
self.from_parquet(src_path, obj_path, column_mapper=column_mapper, **kwargs)

Check warning on line 1356 in src/tape/ensemble.py

View check run for this annotation

Codecov / codecov/patch

src/tape/ensemble.py#L1355-L1356

Added lines #L1355 - L1356 were not covered by tests
else:
self.from_parquet(src_path, column_mapper=column_mapper, **kwargs)

Check warning on line 1358 in src/tape/ensemble.py

View check run for this annotation

Codecov / codecov/patch

src/tape/ensemble.py#L1358

Added line #L1358 was not covered by tests

# Load all remaining frames
if additional_frames is False:
Expand Down
36 changes: 36 additions & 0 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,42 @@ def test_read_source_dict(dask_client):
assert 8002 in obj_table.index


def test_save_and_load_ensemble(dask_client):
# Set a seed for reproducibility
np.random.seed(1)

# Create some toy data
obj_ids = np.array([])
mjds = np.array([])
for i in range(10, 110):
obj_ids = np.append(obj_ids, np.array([i] * 1250))
mjds = np.append(mjds, np.arange(0.0, 1250.0, 1.0))
obj_ids = np.array(obj_ids)
flux = 10 * np.random.random(125000)
err = flux / 10
band = np.random.choice(["g", "r"], 125000)

# Store it in a dictionary
source_dict = {"id": obj_ids, "mjd": mjds, "flux": flux, "err": err, "band": band}

# Create an Ensemble
ens = Ensemble()
ens.from_source_dict(
source_dict,
column_mapper=ColumnMapper(
id_col="id", time_col="mjd", flux_col="flux", err_col="err", band_col="band"
),
)

# Make a column for the object table
ens.calc_nobs(temporary=False)
# Add a few result frames
ens.batch(np.mean, "flux", label="mean")
ens.batch(np.max, "flux", label="max")

ens.save_ensemble("./ensemble")


def test_insert(parquet_ensemble):
num_partitions = parquet_ensemble.source.npartitions
(old_object, old_source) = parquet_ensemble.compute()
Expand Down

0 comments on commit eb0f2b7

Please sign in to comment.