Skip to content
This repository has been archived by the owner on Jan 14, 2025. It is now read-only.

Ensembles can now track a group of labeled frames #217

Merged
merged 1 commit into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 157 additions & 3 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@

from .analysis.structure_function import SF_METHODS
from .analysis.structurefunction2 import calc_sf2
from .ensemble_frame import EnsembleFrame, TapeFrame
from .timeseries import TimeSeries
from .utils import ColumnMapper

# Reserved frame labels. `Ensemble.add_frame`, `Ensemble.update_frame` and
# `Ensemble.drop_frame` raise a ValueError when asked to operate on either one.
SOURCE_FRAME_LABEL = "source"
OBJECT_FRAME_LABEL = "object"

class Ensemble:
"""Ensemble object is a collection of light curve ids"""
Expand All @@ -26,6 +29,12 @@ def __init__(self, client=None, **kwargs):
self._source = None # Source Table
self._object = None # Object Table

self.frames = {} # Frames managed by this Ensemble, keyed by label

# TODO([email protected]) Replace self._source and self._object with these
self.source = None # Source Table EnsembleFrame
self.object = None # Object Table EnsembleFrame

self._source_dirty = False # Source Dirty Flag
self._object_dirty = False # Object Dirty Flag

Expand Down Expand Up @@ -67,6 +76,152 @@ def __del__(self):
self.client.close()
return self

def add_frame(self, frame, label):
    """Adds a new frame for the Ensemble to track.

    The frame's `label` attribute is overwritten with the requested label
    before the frame is registered through `update_frame`.

    Parameters
    ----------
    frame: `EnsembleFrame`
        The frame object for the Ensemble to track.
    label: `str`
        The label for the Ensemble to use to track the frame.

    Returns
    -------
    self: `Ensemble`
        Returned so that calls can be chained.

    Raises
    ------
    ValueError if the label is "source", "object", or already tracked by the Ensemble.
    """
    if label in (SOURCE_FRAME_LABEL, OBJECT_FRAME_LABEL):
        raise ValueError(f"Unable to add frame with reserved label '{label}'")
    if label in self.frames:
        # A single f-string keeps the spacing around the quoted label correct.
        raise ValueError(f"Unable to add frame: a frame with label '{label}' is in the Ensemble.")
    # Assign the frame to the requested tracking label.
    frame.label = label
    # Update the ensemble to track this labeled frame.
    self.update_frame(frame)
    return self

def update_frame(self, frame):
    """Registers `frame` with the Ensemble under its `EnsembleFrame.label`.

    A frame already tracked under the same label is replaced; otherwise the
    frame is added as a new entry.

    Parameters
    ----------
    frame: `EnsembleFrame`
        The frame for the Ensemble to update. If not already tracked, it is added.

    Returns
    -------
    self: `Ensemble`
        Returned so that calls can be chained.

    Raises
    ------
    ValueError if the `frame.label` is unpopulated, "source", or "object".
    """
    frame_label = frame.label
    if frame_label is None:
        raise ValueError("Unable to update frame with no populated `EnsembleFrame.label`.")
    if frame_label in (SOURCE_FRAME_LABEL, OBJECT_FRAME_LABEL):
        # TODO: (from code review) in the future this should probably check
        # whether the frame being updated is an instance of an ObjectFrame or
        # SourceFrame; users may be allowed to update in that case.
        raise ValueError(f"Unable to update frame with reserved label '{frame_label}'")
    # Ensure this frame is assigned to this Ensemble.
    frame.ensemble = self
    self.frames[frame_label] = frame
    return self

def drop_frame(self, label):
    """Drops a frame tracked by the Ensemble.

    Parameters
    ----------
    label: `str`
        The label of the frame to be dropped by the Ensemble.

    Returns
    -------
    self: `Ensemble`
        Returned so that calls can be chained.

    Raises
    ------
    ValueError if the label is "source", or "object".
    KeyError if the label is not tracked by the Ensemble.
    """
    if label in (SOURCE_FRAME_LABEL, OBJECT_FRAME_LABEL):
        raise ValueError(f"Unable to drop frame with reserved label '{label}'")
    if label not in self.frames:
        # A single f-string keeps the spacing around the quoted label correct.
        raise KeyError(f"Unable to drop frame: no frame with label '{label}' is in the Ensemble.")
    del self.frames[label]
    return self

def select_frame(self, label):
    """Selects and returns a frame tracked by the Ensemble.

    Parameters
    ----------
    label: `str`
        The label of a frame tracked by the Ensemble to be selected.

    Returns
    -------
    result: `EnsembleFrame`
        The frame stored in `self.frames` under `label`.

    Raises
    ------
    KeyError if the label is not tracked by the Ensemble.
    """
    if label not in self.frames:
        # A single f-string keeps the spacing around the quoted label correct.
        raise KeyError(f"Unable to select frame: no frame with label '{label}' is in the Ensemble.")
    return self.frames[label]

def frame_info(self, labels=None, verbose=True, memory_usage=True, **kwargs):
    """Wrapper for calling dask.dataframe.DataFrame.info() on frames tracked by the Ensemble.

    Parameters
    ----------
    labels: `list`, optional
        A list of labels for Ensemble frames to summarize.
        If None, info is printed for all tracked frames.
    verbose: `bool`, optional
        Whether to print the whole summary
    memory_usage: `bool`, optional
        Specifies whether total memory usage of the DataFrame elements
        (including the index) should be displayed.
    **kwargs:
        keyword arguments passed along to
        `dask.dataframe.DataFrame.info()`

    Returns
    -------
    None

    Raises
    ------
    KeyError if a label in labels is not tracked by the Ensemble.
    """
    labels = list(self.frames) if labels is None else list(labels)
    # Validate every label up front so nothing is printed when the request is invalid.
    for label in labels:
        if label not in self.frames:
            raise KeyError(f"Unable to get frame info: no frame with label '{label}' is in the Ensemble.")
    for label in labels:
        print(label, "Frame")
        # `info()` writes its own summary and returns None, so its result is
        # deliberately not passed to print() (that would emit a stray "None").
        self.frames[label].info(verbose=verbose, memory_usage=memory_usage, **kwargs)

def insert_sources(
self,
obj_ids,
Expand Down Expand Up @@ -174,7 +329,7 @@ def client_info(self):
return self.client # Prints Dask dashboard to screen

def info(self, verbose=True, memory_usage=True, **kwargs):
"""Wrapper for dask.dataframe.DataFrame.info()
"""Wrapper for dask.dataframe.DataFrame.info() for the Source and Object tables

Parameters
----------
Expand All @@ -185,8 +340,7 @@ def info(self, verbose=True, memory_usage=True, **kwargs):
(including the index) should be displayed.
Returns
----------
counts: `pandas.series`
A series of counts by object
None
"""
# Sync tables if user wants to retrieve their information
self._lazy_sync_tables(table="all")
Expand Down
96 changes: 95 additions & 1 deletion tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
import pytest

from tape import Ensemble
from tape import Ensemble, EnsembleFrame, TapeFrame
from tape.analysis.stetsonj import calc_stetson_J
from tape.analysis.structure_function.base_argument_container import StructureFunctionArgumentContainer
from tape.analysis.structurefunction2 import calc_sf2
Expand Down Expand Up @@ -78,6 +78,97 @@ def test_available_datasets(dask_client):
assert isinstance(datasets, dict)
assert len(datasets) > 0 # Find at least one

@pytest.mark.parametrize(
    "data_fixture",
    [
        "ensemble_from_source_dict",
    ],
)
def test_frame_tracking(data_fixture, request):
    """
    Tests a workflow of adding and removing the frames tracked by the Ensemble.

    Exercises `add_frame`, `update_frame`, `drop_frame` and `select_frame`,
    checking both the number of tracked frames after each step and the errors
    raised for reserved, duplicate, unknown and missing labels.
    """
    ens, data = request.getfixturevalue(data_fixture)

    # Construct frames for the Ensemble to track. For this test, the underlying data is irrelevant.
    ens_frame1 = EnsembleFrame.from_dict(data, npartitions=1)
    ens_frame2 = EnsembleFrame.from_dict(data, npartitions=1)
    ens_frame3 = EnsembleFrame.from_dict(data, npartitions=1)
    ens_frame4 = EnsembleFrame.from_dict(data, npartitions=1)

    # Labels to give the EnsembleFrames
    label1, label2, label3, label4 = "frame1", "frame2", "frame3", "frame4"

    # A fresh Ensemble tracks no frames yet.
    assert not ens.frames

    # TODO([email protected]) Remove once Ensemble.source and Ensemble.object are populated by loaders
    # NOTE(review): the "object" frame below is also built from `ens._source`;
    # presumably acceptable because the data is irrelevant here, but confirm
    # that `ens._object` was not intended.
    ens.source = EnsembleFrame.from_tapeframe(
        TapeFrame(ens._source), label="source", npartitions=1)
    ens.object = EnsembleFrame.from_tapeframe(
        TapeFrame(ens._source), label="object", npartitions=1)
    # Registered directly in the dict because add_frame/update_frame reject reserved labels.
    ens.frames["source"] = ens.source
    ens.frames["object"] = ens.object

    # Check that we can select source and object frames
    assert len(ens.frames) == 2
    assert ens.select_frame("source") is ens.source
    assert ens.select_frame("object") is ens.object

    # Validate that new source and object frames can't be added or updated.
    with pytest.raises(ValueError):
        ens.add_frame(ens_frame1, "source")
    with pytest.raises(ValueError):
        ens.add_frame(ens_frame1, "object")
    with pytest.raises(ValueError):
        ens.update_frame(ens.source)
    with pytest.raises(ValueError):
        ens.update_frame(ens.object)

    # Test that we can add and select a new ensemble frame
    assert ens.add_frame(ens_frame1, label1).select_frame(label1) is ens_frame1
    assert len(ens.frames) == 3

    # Validate that we can't add a new frame that uses an existing label
    with pytest.raises(ValueError):
        ens.add_frame(ens_frame2, label1)

    # We add two more frames to track
    ens.add_frame(ens_frame2, label2).add_frame(ens_frame3, label3)
    assert ens.select_frame(label2) is ens_frame2
    assert ens.select_frame(label3) is ens_frame3
    assert len(ens.frames) == 5

    # Now we begin dropping frames. First verify that we can't drop object or source.
    with pytest.raises(ValueError):
        ens.drop_frame("source")
    with pytest.raises(ValueError):
        ens.drop_frame("object")

    # And verify that we can't call drop with an unknown label.
    with pytest.raises(KeyError):
        ens.drop_frame("nonsense")

    # Drop an existing frame and verify that it can no longer be selected.
    ens.drop_frame(label3)
    assert len(ens.frames) == 4
    with pytest.raises(KeyError):
        ens.select_frame(label3)

    # Update the ensemble with the dropped frame, and then select the frame
    # (ens_frame3 still carries label3 from the earlier add_frame call).
    assert ens.update_frame(ens_frame3).select_frame(label3) is ens_frame3
    assert len(ens.frames) == 5

    # Update the ensemble with a new frame, verifying a missing label generates an error.
    with pytest.raises(ValueError):
        ens.update_frame(ens_frame4)
    ens_frame4.label = label4
    assert ens.update_frame(ens_frame4).select_frame(label4) is ens_frame4
    assert len(ens.frames) == 6

    # Change the label of the 4th ensemble frame to verify update overrides an existing frame
    # (the count stays at 6 because label3 is replaced rather than added).
    ens_frame4.label = label3
    assert ens.update_frame(ens_frame4).select_frame(label3) is ens_frame4
    assert len(ens.frames) == 6

def test_from_rrl_dataset(dask_client):
"""
Expand Down Expand Up @@ -291,6 +382,9 @@ def test_core_wrappers(parquet_ensemble):
# Just test if these execute successfully
parquet_ensemble.client_info()
parquet_ensemble.info()
parquet_ensemble.frame_info()
with pytest.raises(KeyError):
parquet_ensemble.frame_info(labels=["source", "invalid_label"])
parquet_ensemble.columns()
parquet_ensemble.head(n=5)
parquet_ensemble.tail(n=5)
Expand Down