Skip to content
This repository has been archived by the owner on Jan 14, 2025. It is now read-only.

Commit

Permalink
Merge pull request #217 from lincc-frameworks/tape_ensemble_refactor_…
Browse files Browse the repository at this point in the history
…working

Ensembles can now track a group of labeled frames
  • Loading branch information
wilsonbb authored Aug 31, 2023
2 parents 17f9cc1 + 72b8629 commit 1cd049e
Show file tree
Hide file tree
Showing 2 changed files with 252 additions and 4 deletions.
160 changes: 157 additions & 3 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@

from .analysis.structure_function import SF_METHODS
from .analysis.structurefunction2 import calc_sf2
from .ensemble_frame import EnsembleFrame, TapeFrame
from .timeseries import TimeSeries
from .utils import ColumnMapper

# Labels reserved for the Ensemble's source and object frames. `add_frame`,
# `update_frame` and `drop_frame` raise ValueError when given either of them.
SOURCE_FRAME_LABEL = "source"
OBJECT_FRAME_LABEL = "object"

class Ensemble:
"""Ensemble object is a collection of light curve ids"""
Expand All @@ -26,6 +29,12 @@ def __init__(self, client=None, **kwargs):
self._source = None # Source Table
self._object = None # Object Table

self.frames = {} # Frames managed by this Ensemble, keyed by label

# TODO([email protected]) Replace self._source and self._object with these
self.source = None # Source Table EnsembleFrame
self.object = None # Object Table EnsembleFrame

self._source_dirty = False # Source Dirty Flag
self._object_dirty = False # Object Dirty Flag

Expand Down Expand Up @@ -67,6 +76,152 @@ def __del__(self):
self.client.close()
return self

def add_frame(self, frame, label):
    """Adds a new frame for the Ensemble to track.

    Parameters
    ----------
    frame: `tape.ensemble.EnsembleFrame`
        The frame object for the Ensemble to track.
    label: `str`
        The label for the Ensemble to use to track the frame.

    Returns
    -------
    self: `Ensemble`

    Raises
    ------
    ValueError if the label is "source", "object", or already tracked by the Ensemble.
    """
    if label == SOURCE_FRAME_LABEL or label == OBJECT_FRAME_LABEL:
        raise ValueError(f"Unable to add frame with reserved label '{label}'")
    if label in self.frames:
        raise ValueError(f"Unable to add frame: a frame with label '{label}' is in the Ensemble.")
    # Assign the frame to the requested tracking label.
    frame.label = label
    # Delegate to update_frame, which binds the frame to this Ensemble and
    # stores it under `frame.label`.
    self.update_frame(frame)
    return self

def update_frame(self, frame):
    """Registers `frame` with the Ensemble under its `EnsembleFrame.label`.

    A frame already tracked under that label is replaced; otherwise the frame
    is added as a new entry.

    Parameters
    ----------
    frame: `tape.ensemble.EnsembleFrame`
        The frame for the Ensemble to update. If not already tracked, it is added.

    Returns
    -------
    self: `Ensemble`

    Raises
    ------
    ValueError if the `frame.label` is unpopulated, "source", or "object".
    """
    frame_label = frame.label
    if frame_label is None:
        raise ValueError("Unable to update frame with no populated `EnsembleFrame.label`.")
    if frame_label in (SOURCE_FRAME_LABEL, OBJECT_FRAME_LABEL):
        raise ValueError(f"Unable to update frame with reserved label '{frame_label}'")
    # Point the frame back at this Ensemble before recording it.
    frame.ensemble = self
    self.frames[frame_label] = frame
    return self

def drop_frame(self, label):
    """Drops a frame tracked by the Ensemble.

    Parameters
    ----------
    label: `str`
        The label of the frame to be dropped by the Ensemble.

    Returns
    -------
    self: `Ensemble`

    Raises
    ------
    ValueError if the label is "source" or "object".
    KeyError if the label is not tracked by the Ensemble.
    """
    if label == SOURCE_FRAME_LABEL or label == OBJECT_FRAME_LABEL:
        raise ValueError(f"Unable to drop frame with reserved label '{label}'")
    if label not in self.frames:
        raise KeyError(f"Unable to drop frame: no frame with label '{label}' is in the Ensemble.")
    del self.frames[label]
    return self

def select_frame(self, label):
    """Selects and returns a frame tracked by the Ensemble.

    Parameters
    ----------
    label: `str`
        The label of a frame tracked by the Ensemble to be selected.

    Returns
    -------
    result: `tape.ensemble.EnsembleFrame`

    Raises
    ------
    KeyError if the label is not tracked by the Ensemble.
    """
    if label not in self.frames:
        raise KeyError(f"Unable to select frame: no frame with label '{label}' is in the Ensemble.")
    return self.frames[label]

def frame_info(self, labels=None, verbose=True, memory_usage=True, **kwargs):
    """Wrapper for calling dask.dataframe.DataFrame.info() on frames tracked by the Ensemble.

    Parameters
    ----------
    labels: `list`, optional
        A list of labels for Ensemble frames to summarize.
        If None, info is printed for all tracked frames.
    verbose: `bool`, optional
        Whether to print the whole summary
    memory_usage: `bool`, optional
        Specifies whether total memory usage of the DataFrame elements
        (including the index) should be displayed.
    **kwargs:
        keyword arguments passed along to
        `dask.dataframe.DataFrame.info()`

    Returns
    -------
    None

    Raises
    ------
    KeyError if a label in labels is not tracked by the Ensemble. All labels
    are checked before anything is printed.
    """
    # Materialize once so one-shot iterables survive the validation pass.
    labels = list(self.frames) if labels is None else list(labels)

    # Validate everything up front so a bad label never leaves partial output.
    for label in labels:
        if label not in self.frames:
            raise KeyError(f"Unable to get frame info: no frame with label '{label}' is in the Ensemble.")

    for label in labels:
        print(label, "Frame")
        summary = self.frames[label].info(verbose=verbose, memory_usage=memory_usage, **kwargs)
        # NOTE(review): dask's `info()` is documented to write its summary itself
        # and return None; only echo a return value when one is actually given,
        # so a stray "None" line is not printed.
        if summary is not None:
            print(summary)

def insert_sources(
self,
obj_ids,
Expand Down Expand Up @@ -174,7 +329,7 @@ def client_info(self):
return self.client # Prints Dask dashboard to screen

def info(self, verbose=True, memory_usage=True, **kwargs):
"""Wrapper for dask.dataframe.DataFrame.info()
"""Wrapper for dask.dataframe.DataFrame.info() for the Source and Object tables
Parameters
----------
Expand All @@ -185,8 +340,7 @@ def info(self, verbose=True, memory_usage=True, **kwargs):
(including the index) should be displayed.
Returns
----------
counts: `pandas.series`
A series of counts by object
None
"""
# Sync tables if user wants to retrieve their information
self._lazy_sync_tables(table="all")
Expand Down
96 changes: 95 additions & 1 deletion tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
import pytest

from tape import Ensemble
from tape import Ensemble, EnsembleFrame, TapeFrame
from tape.analysis.stetsonj import calc_stetson_J
from tape.analysis.structure_function.base_argument_container import StructureFunctionArgumentContainer
from tape.analysis.structurefunction2 import calc_sf2
Expand Down Expand Up @@ -78,6 +78,97 @@ def test_available_datasets(dask_client):
assert isinstance(datasets, dict)
assert len(datasets) > 0 # Find at least one

@pytest.mark.parametrize(
    "data_fixture",
    [
        "ensemble_from_source_dict",
    ],
)
def test_frame_tracking(data_fixture, request):
    """
    Tests a workflow of adding and removing the frames tracked by the Ensemble.

    Covers reserved-label rejection for "source"/"object", duplicate-label
    rejection in `add_frame`, dropping and re-adding frames, and overriding an
    existing label through `update_frame`.
    """
    ens, data = request.getfixturevalue(data_fixture)

    # Construct frames for the Ensemble to track. For this test, the underlying data is irrelevant.
    ens_frame1 = EnsembleFrame.from_dict(data, npartitions=1)
    ens_frame2 = EnsembleFrame.from_dict(data, npartitions=1)
    ens_frame3 = EnsembleFrame.from_dict(data, npartitions=1)
    ens_frame4 = EnsembleFrame.from_dict(data, npartitions=1)

    # Labels to give the EnsembleFrames
    label1, label2, label3, label4 = "frame1", "frame2", "frame3", "frame4"

    assert not ens.frames

    # TODO([email protected]) Remove once Ensemble.source and Ensemble.object are populated by loaders
    ens.source = EnsembleFrame.from_tapeframe(TapeFrame(ens._source), label="source", npartitions=1)
    # Build the object frame from the object table rather than reusing the source table.
    ens.object = EnsembleFrame.from_tapeframe(TapeFrame(ens._object), label="object", npartitions=1)
    ens.frames["source"] = ens.source
    ens.frames["object"] = ens.object

    # Check that we can select source and object frames
    assert len(ens.frames) == 2
    assert ens.select_frame("source") is ens.source
    assert ens.select_frame("object") is ens.object

    # Validate that new source and object frames can't be added or updated.
    with pytest.raises(ValueError):
        ens.add_frame(ens_frame1, "source")
    with pytest.raises(ValueError):
        ens.add_frame(ens_frame1, "object")
    with pytest.raises(ValueError):
        ens.update_frame(ens.source)
    with pytest.raises(ValueError):
        ens.update_frame(ens.object)

    # Test that we can add and select a new ensemble frame
    assert ens.add_frame(ens_frame1, label1).select_frame(label1) is ens_frame1
    assert len(ens.frames) == 3

    # Validate that we can't add a new frame that uses an existing label
    with pytest.raises(ValueError):
        ens.add_frame(ens_frame2, label1)

    # We add two more frames to track
    ens.add_frame(ens_frame2, label2).add_frame(ens_frame3, label3)
    assert ens.select_frame(label2) is ens_frame2
    assert ens.select_frame(label3) is ens_frame3
    assert len(ens.frames) == 5

    # Now we begin dropping frames. First verify that we can't drop object or source.
    with pytest.raises(ValueError):
        ens.drop_frame("source")
    with pytest.raises(ValueError):
        ens.drop_frame("object")

    # And verify that we can't call drop with an unknown label.
    with pytest.raises(KeyError):
        ens.drop_frame("nonsense")

    # Drop an existing frame and verify that it can no longer be selected.
    ens.drop_frame(label3)
    assert len(ens.frames) == 4
    with pytest.raises(KeyError):
        ens.select_frame(label3)

    # Update the ensemble with the dropped frame, and then select the frame
    assert ens.update_frame(ens_frame3).select_frame(label3) is ens_frame3
    assert len(ens.frames) == 5

    # Update the ensemble with a new frame, verifying a missing label generates an error.
    with pytest.raises(ValueError):
        ens.update_frame(ens_frame4)
    ens_frame4.label = label4
    assert ens.update_frame(ens_frame4).select_frame(label4) is ens_frame4
    assert len(ens.frames) == 6

    # Change the label of the 4th ensemble frame to verify update overrides an existing frame.
    # The frame stays registered under its old label too, so the count is unchanged.
    ens_frame4.label = label3
    assert ens.update_frame(ens_frame4).select_frame(label3) is ens_frame4
    assert len(ens.frames) == 6

def test_from_rrl_dataset(dask_client):
"""
Expand Down Expand Up @@ -291,6 +382,9 @@ def test_core_wrappers(parquet_ensemble):
# Just test if these execute successfully
parquet_ensemble.client_info()
parquet_ensemble.info()
parquet_ensemble.frame_info()
with pytest.raises(KeyError):
parquet_ensemble.frame_info(labels=["source", "invalid_label"])
parquet_ensemble.columns()
parquet_ensemble.head(n=5)
parquet_ensemble.tail(n=5)
Expand Down

0 comments on commit 1cd049e

Please sign in to comment.