This repository has been archived by the owner on Jan 14, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #217 from lincc-frameworks/tape_ensemble_refactor_…
…working Ensembles can now track a group of labeled frames
- Loading branch information
Showing
2 changed files
with
252 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,9 +13,12 @@ | |
|
||
from .analysis.structure_function import SF_METHODS | ||
from .analysis.structurefunction2 import calc_sf2 | ||
from .ensemble_frame import EnsembleFrame, TapeFrame | ||
from .timeseries import TimeSeries | ||
from .utils import ColumnMapper | ||
|
||
SOURCE_FRAME_LABEL = "source" | ||
OBJECT_FRAME_LABEL = "object" | ||
|
||
class Ensemble: | ||
"""Ensemble object is a collection of light curve ids""" | ||
|
@@ -26,6 +29,12 @@ def __init__(self, client=None, **kwargs): | |
self._source = None # Source Table | ||
self._object = None # Object Table | ||
|
||
self.frames = {} # Frames managed by this Ensemble, keyed by label | ||
|
||
# TODO([email protected]) Replace self._source and self._object with these | ||
self.source = None # Source Table EnsembleFrame | ||
self.object = None # Object Table EnsembleFrame | ||
|
||
self._source_dirty = False # Source Dirty Flag | ||
self._object_dirty = False # Object Dirty Flag | ||
|
||
|
@@ -67,6 +76,152 @@ def __del__(self): | |
self.client.close() | ||
return self | ||
|
||
def add_frame(self, frame, label): | ||
"""Adds a new frame for the Ensemble to track. | ||
Parameters | ||
---------- | ||
frame: `tape.ensemble.EnsembleFrame` | ||
The frame object for the Ensemble to track. | ||
label: `str` | ||
| The label for the Ensemble to use to track the frame. | ||
Returns | ||
------- | ||
self: `Ensemble` | ||
Raises | ||
------ | ||
ValueError if the label is "source", "object", or already tracked by the Ensemble. | ||
""" | ||
if label == SOURCE_FRAME_LABEL or label == OBJECT_FRAME_LABEL: | ||
raise ValueError( | ||
f"Unable to add frame with reserved label " f"'{label}'" | ||
) | ||
if label in self.frames: | ||
raise ValueError( | ||
f"Unable to add frame: a frame with label " f"'{label}'" f"is in the Ensemble." | ||
) | ||
# Assign the frame to the requested tracking label. | ||
frame.label = label | ||
# Update the ensemble to track this labeled frame. | ||
self.update_frame(frame) | ||
return self | ||
|
||
def update_frame(self, frame): | ||
"""Updates a frame tracked by the Ensemble or otherwise adds it to the Ensemble. | ||
The frame is tracked by its `EnsembleFrame.label` field. | ||
Parameters | ||
---------- | ||
frame: `tape.ensemble.EnsembleFrame` | ||
The frame for the Ensemble to update. If not already tracked, it is added. | ||
Returns | ||
------- | ||
self: `Ensemble` | ||
Raises | ||
------ | ||
ValueError if the `frame.label` is unpopulated, "source", or "object". | ||
""" | ||
if frame.label is None: | ||
raise ValueError( | ||
f"Unable to update frame with no populated `EnsembleFrame.label`." | ||
) | ||
if frame.label == SOURCE_FRAME_LABEL or frame.label == OBJECT_FRAME_LABEL: | ||
raise ValueError( | ||
f"Unable to update frame with reserved label " f"'{frame.label}'" | ||
) | ||
# Ensure this frame is assigned to this Ensemble. | ||
frame.ensemble = self | ||
self.frames[frame.label] = frame | ||
return self | ||
|
||
def drop_frame(self, label): | ||
"""Drops a frame tracked by the Ensemble. | ||
Parameters | ||
---------- | ||
label: `str` | ||
| The label of the frame to be dropped by the Ensemble. | ||
Returns | ||
------- | ||
self: `Ensemble` | ||
Raises | ||
------ | ||
ValueError if the label is "source", or "object". | ||
KeyError if the label is not tracked by the Ensemble. | ||
""" | ||
if label == SOURCE_FRAME_LABEL or label == OBJECT_FRAME_LABEL: | ||
raise ValueError( | ||
f"Unable to drop frame with reserved label " f"'{label}'" | ||
) | ||
if label not in self.frames: | ||
raise KeyError( | ||
f"Unable to drop frame: no frame with label " f"'{label}'" f"is in the Ensemble." | ||
) | ||
del self.frames[label] | ||
return self | ||
|
||
def select_frame(self, label): | ||
"""Selects and returns frame tracked by the Ensemble. | ||
Parameters | ||
---------- | ||
label: `str` | ||
| The label of a frame tracked by the Ensemble to be selected. | ||
Returns | ||
------- | ||
result: `tape.ensemeble.EnsembleFrame` | ||
Raises | ||
------ | ||
KeyError if the label is not tracked by the Ensemble. | ||
""" | ||
if label not in self.frames: | ||
raise KeyError( | ||
f"Unable to select frame: no frame with label" f"'{label}'" f" is in the Ensemble." | ||
) | ||
return self.frames[label] | ||
|
||
def frame_info(self, labels=None, verbose=True, memory_usage=True, **kwargs): | ||
"""Wrapper for calling dask.dataframe.DataFrame.info() on frames tracked by the Ensemble. | ||
Parameters | ||
---------- | ||
labels: `list`, optional | ||
A list of labels for Ensemble frames to summarize. | ||
If None, info is printed for all tracked frames. | ||
verbose: `bool`, optional | ||
Whether to print the whole summary | ||
memory_usage: `bool`, optional | ||
Specifies whether total memory usage of the DataFrame elements | ||
(including the index) should be displayed. | ||
**kwargs: | ||
keyword arguments passed along to | ||
`dask.dataframe.DataFrame.info()` | ||
Returns | ||
------- | ||
None | ||
Raises | ||
------ | ||
KeyError if a label in labels is not tracked by the Ensemble. | ||
""" | ||
if labels is None: | ||
labels = self.frames.keys() | ||
for label in labels: | ||
if label not in self.frames: | ||
raise KeyError( | ||
f"Unable to get frame info: no frame with label " f"'{label}'" f" is in the Ensemble." | ||
) | ||
print(label, "Frame") | ||
print(self.frames[label].info(verbose=verbose, memory_usage=memory_usage, **kwargs)) | ||
|
||
def insert_sources( | ||
self, | ||
obj_ids, | ||
|
@@ -174,7 +329,7 @@ def client_info(self): | |
return self.client # Prints Dask dashboard to screen | ||
|
||
def info(self, verbose=True, memory_usage=True, **kwargs): | ||
"""Wrapper for dask.dataframe.DataFrame.info() | ||
"""Wrapper for dask.dataframe.DataFrame.info() for the Source and Object tables | ||
Parameters | ||
---------- | ||
|
@@ -185,8 +340,7 @@ def info(self, verbose=True, memory_usage=True, **kwargs): | |
(including the index) should be displayed. | ||
Returns | ||
---------- | ||
counts: `pandas.series` | ||
A series of counts by object | ||
None | ||
""" | ||
# Sync tables if user wants to retrieve their information | ||
self._lazy_sync_tables(table="all") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
import pandas as pd | ||
import pytest | ||
|
||
from tape import Ensemble | ||
from tape import Ensemble, EnsembleFrame, TapeFrame | ||
from tape.analysis.stetsonj import calc_stetson_J | ||
from tape.analysis.structure_function.base_argument_container import StructureFunctionArgumentContainer | ||
from tape.analysis.structurefunction2 import calc_sf2 | ||
|
@@ -78,6 +78,97 @@ def test_available_datasets(dask_client): | |
assert isinstance(datasets, dict) | ||
assert len(datasets) > 0 # Find at least one | ||
|
||
@pytest.mark.parametrize( | ||
"data_fixture", | ||
[ | ||
"ensemble_from_source_dict", | ||
], | ||
) | ||
def test_frame_tracking(data_fixture, request): | ||
""" | ||
Tests a workflow of adding and removing the frames tracked by the Ensemble. | ||
""" | ||
ens, data = request.getfixturevalue(data_fixture) | ||
|
||
# Construct frames for the Ensemble to track. For this test, the underlying data is irrelevant. | ||
ens_frame1 = EnsembleFrame.from_dict(data, npartitions=1) | ||
ens_frame2 = EnsembleFrame.from_dict(data, npartitions=1) | ||
ens_frame3 = EnsembleFrame.from_dict(data, npartitions=1) | ||
ens_frame4 = EnsembleFrame.from_dict(data, npartitions=1) | ||
|
||
# Labels to give the EnsembleFrames | ||
label1, label2, label3, label4 = "frame1", "frame2", "frame3", "frame4" | ||
|
||
assert not ens.frames | ||
|
||
# TODO([email protected]) Remove once Ensemble.source and Ensemble.object are populated by loaders | ||
ens.source = EnsembleFrame.from_tapeframe( | ||
TapeFrame(ens._source), label="source", npartitions=1) | ||
ens.object = EnsembleFrame.from_tapeframe( | ||
TapeFrame(ens._source), label="object", npartitions=1) | ||
ens.frames["source"] = ens.source | ||
ens.frames["object"] = ens.object | ||
|
||
# Check that we can select source and object frames | ||
assert len(ens.frames) == 2 | ||
assert ens.select_frame("source") is ens.source | ||
assert ens.select_frame("object") is ens.object | ||
|
||
# Validate that new source and object frames can't be added or updated. | ||
with pytest.raises(ValueError): | ||
ens.add_frame(ens_frame1, "source") | ||
with pytest.raises(ValueError): | ||
ens.add_frame(ens_frame1, "object") | ||
with pytest.raises(ValueError): | ||
ens.update_frame(ens.source) | ||
with pytest.raises(ValueError): | ||
ens.update_frame(ens.object) | ||
|
||
# Test that we can add and select a new ensemble frame | ||
assert ens.add_frame(ens_frame1, label1).select_frame(label1) is ens_frame1 | ||
assert len(ens.frames) == 3 | ||
|
||
# Validate that we can't add a new frame that uses an exisiting label | ||
with pytest.raises(ValueError): | ||
ens.add_frame(ens_frame2, label1) | ||
|
||
# We add two more frames to track | ||
ens.add_frame(ens_frame2, label2).add_frame(ens_frame3, label3) | ||
assert ens.select_frame(label2) is ens_frame2 | ||
assert ens.select_frame(label3) is ens_frame3 | ||
assert len(ens.frames) == 5 | ||
|
||
# Now we begin dropping frames. First verifyt that we can't drop object or source. | ||
with pytest.raises(ValueError): | ||
ens.drop_frame("source") | ||
with pytest.raises(ValueError): | ||
ens.drop_frame("object") | ||
|
||
# And verify that we can't call drop with an unknown label. | ||
with pytest.raises(KeyError): | ||
ens.drop_frame("nonsense") | ||
|
||
# Drop an existing frame and that it can no longer be selected. | ||
ens.drop_frame(label3) | ||
assert len(ens.frames) == 4 | ||
with pytest.raises(KeyError): | ||
ens.select_frame(label3) | ||
|
||
# Update the ensemble with the dropped frame, and then select the frame | ||
assert ens.update_frame(ens_frame3).select_frame(label3) is ens_frame3 | ||
assert len(ens.frames) == 5 | ||
|
||
# Update the ensemble with a new frame, verifying a missing label generates an error. | ||
with pytest.raises(ValueError): | ||
ens.update_frame(ens_frame4) | ||
ens_frame4.label = label4 | ||
assert ens.update_frame(ens_frame4).select_frame(label4) is ens_frame4 | ||
assert len(ens.frames) == 6 | ||
|
||
# Change the label of the 4th ensemble frame to verify update overrides an existing frame | ||
ens_frame4.label = label3 | ||
assert ens.update_frame(ens_frame4).select_frame(label3) is ens_frame4 | ||
assert len(ens.frames) == 6 | ||
|
||
def test_from_rrl_dataset(dask_client): | ||
""" | ||
|
@@ -291,6 +382,9 @@ def test_core_wrappers(parquet_ensemble): | |
# Just test if these execute successfully | ||
parquet_ensemble.client_info() | ||
parquet_ensemble.info() | ||
parquet_ensemble.frame_info() | ||
with pytest.raises(KeyError): | ||
parquet_ensemble.frame_info(labels=["source", "invalid_label"]) | ||
parquet_ensemble.columns() | ||
parquet_ensemble.head(n=5) | ||
parquet_ensemble.tail(n=5) | ||
|