diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py
index 839d39a7..1aa6cd7b 100644
--- a/src/tape/ensemble.py
+++ b/src/tape/ensemble.py
@@ -13,9 +13,12 @@
 from .analysis.structure_function import SF_METHODS
 from .analysis.structurefunction2 import calc_sf2
+from .ensemble_frame import EnsembleFrame, TapeFrame
 from .timeseries import TimeSeries
 from .utils import ColumnMapper
 
+SOURCE_FRAME_LABEL = "source"
+OBJECT_FRAME_LABEL = "object"
 
 class Ensemble:
     """Ensemble object is a collection of light curve ids"""
 
@@ -26,6 +29,12 @@ def __init__(self, client=None, **kwargs):
         self._source = None  # Source Table
         self._object = None  # Object Table
 
+        self.frames = {}  # Frames managed by this Ensemble, keyed by label
+
+        # TODO(wbeebe@uw.edu) Replace self._source and self._object with these
+        self.source = None  # Source Table EnsembleFrame
+        self.object = None  # Object Table EnsembleFrame
+
         self._source_dirty = False  # Source Dirty Flag
         self._object_dirty = False  # Object Dirty Flag
 
@@ -67,6 +76,152 @@ def __del__(self):
             self.client.close()
         return self
 
+    def add_frame(self, frame, label):
+        """Adds a new frame for the Ensemble to track.
+
+        Parameters
+        ----------
+        frame: `tape.ensemble.EnsembleFrame`
+            The frame object for the Ensemble to track.
+        label: `str`
+            The label for the Ensemble to use to track the frame.
+
+        Returns
+        -------
+        self: `Ensemble`
+
+        Raises
+        ------
+        ValueError if the label is "source", "object", or already tracked by the Ensemble.
+        """
+        if label == SOURCE_FRAME_LABEL or label == OBJECT_FRAME_LABEL:
+            raise ValueError(
+                f"Unable to add frame with reserved label '{label}'"
+            )
+        if label in self.frames:
+            raise ValueError(
+                f"Unable to add frame: a frame with label '{label}' is in the Ensemble."
+            )
+        # Assign the frame to the requested tracking label.
+        frame.label = label
+        # Update the ensemble to track this labeled frame.
+        self.update_frame(frame)
+        return self
+
+    def update_frame(self, frame):
+        """Updates a frame tracked by the Ensemble or otherwise adds it to the Ensemble.
+        The frame is tracked by its `EnsembleFrame.label` field.
+
+        Parameters
+        ----------
+        frame: `tape.ensemble.EnsembleFrame`
+            The frame for the Ensemble to update. If not already tracked, it is added.
+
+        Returns
+        -------
+        self: `Ensemble`
+
+        Raises
+        ------
+        ValueError if the `frame.label` is unpopulated, "source", or "object".
+        """
+        if frame.label is None:
+            raise ValueError(
+                "Unable to update frame with no populated `EnsembleFrame.label`."
+            )
+        if frame.label == SOURCE_FRAME_LABEL or frame.label == OBJECT_FRAME_LABEL:
+            raise ValueError(
+                f"Unable to update frame with reserved label '{frame.label}'"
+            )
+        # Ensure this frame is assigned to this Ensemble.
+        frame.ensemble = self
+        self.frames[frame.label] = frame
+        return self
+
+    def drop_frame(self, label):
+        """Drops a frame tracked by the Ensemble.
+
+        Parameters
+        ----------
+        label: `str`
+            The label of the frame to be dropped by the Ensemble.
+
+        Returns
+        -------
+        self: `Ensemble`
+
+        Raises
+        ------
+        ValueError if the label is "source", or "object".
+        KeyError if the label is not tracked by the Ensemble.
+        """
+        if label == SOURCE_FRAME_LABEL or label == OBJECT_FRAME_LABEL:
+            raise ValueError(
+                f"Unable to drop frame with reserved label '{label}'"
+            )
+        if label not in self.frames:
+            raise KeyError(
+                f"Unable to drop frame: no frame with label '{label}' is in the Ensemble."
+            )
+        del self.frames[label]
+        return self
+
+    def select_frame(self, label):
+        """Selects and returns frame tracked by the Ensemble.
+
+        Parameters
+        ----------
+        label: `str`
+            The label of a frame tracked by the Ensemble to be selected.
+
+        Returns
+        -------
+        result: `tape.ensemble.EnsembleFrame`
+
+        Raises
+        ------
+        KeyError if the label is not tracked by the Ensemble.
+        """
+        if label not in self.frames:
+            raise KeyError(
+                f"Unable to select frame: no frame with label '{label}' is in the Ensemble."
+            )
+        return self.frames[label]
+
+    def frame_info(self, labels=None, verbose=True, memory_usage=True, **kwargs):
+        """Wrapper for calling dask.dataframe.DataFrame.info() on frames tracked by the Ensemble.
+
+        Parameters
+        ----------
+        labels: `list`, optional
+            A list of labels for Ensemble frames to summarize.
+            If None, info is printed for all tracked frames.
+        verbose: `bool`, optional
+            Whether to print the whole summary
+        memory_usage: `bool`, optional
+            Specifies whether total memory usage of the DataFrame elements
+            (including the index) should be displayed.
+        **kwargs:
+            keyword arguments passed along to
+            `dask.dataframe.DataFrame.info()`
+        Returns
+        -------
+        None
+
+        Raises
+        ------
+        KeyError if a label in labels is not tracked by the Ensemble.
+        """
+        if labels is None:
+            labels = self.frames.keys()
+        for label in labels:
+            if label not in self.frames:
+                raise KeyError(
+                    f"Unable to get frame info: no frame with label '{label}' is in the Ensemble."
+                )
+            print(label, "Frame")
+            self.frames[label].info(verbose=verbose, memory_usage=memory_usage, **kwargs)
+
     def insert_sources(
         self,
         obj_ids,
@@ -174,7 +329,7 @@ def client_info(self):
         return self.client  # Prints Dask dashboard to screen
 
     def info(self, verbose=True, memory_usage=True, **kwargs):
-        """Wrapper for dask.dataframe.DataFrame.info()
+        """Wrapper for dask.dataframe.DataFrame.info() for the Source and Object tables
 
         Parameters
         ----------
@@ -185,8 +340,7 @@ def info(self, verbose=True, memory_usage=True, **kwargs):
             (including the index) should be displayed.
         Returns
         ----------
-        counts: `pandas.series`
-            A series of counts by object
+        None
         """
         # Sync tables if user wants to retrieve their information
         self._lazy_sync_tables(table="all")
diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py
index 49f92238..e29c89b9 100644
--- a/tests/tape_tests/test_ensemble.py
+++ b/tests/tape_tests/test_ensemble.py
@@ -6,7 +6,7 @@
 import pandas as pd
 import pytest
 
-from tape import Ensemble
+from tape import Ensemble, EnsembleFrame, TapeFrame
 from tape.analysis.stetsonj import calc_stetson_J
 from tape.analysis.structure_function.base_argument_container import StructureFunctionArgumentContainer
 from tape.analysis.structurefunction2 import calc_sf2
@@ -78,6 +78,97 @@ def test_available_datasets(dask_client):
     assert isinstance(datasets, dict)
     assert len(datasets) > 0  # Find at least one
 
+@pytest.mark.parametrize(
+    "data_fixture",
+    [
+        "ensemble_from_source_dict",
+    ],
+)
+def test_frame_tracking(data_fixture, request):
+    """
+    Tests a workflow of adding and removing the frames tracked by the Ensemble.
+    """
+    ens, data = request.getfixturevalue(data_fixture)
+
+    # Construct frames for the Ensemble to track. For this test, the underlying data is irrelevant.
+    ens_frame1 = EnsembleFrame.from_dict(data, npartitions=1)
+    ens_frame2 = EnsembleFrame.from_dict(data, npartitions=1)
+    ens_frame3 = EnsembleFrame.from_dict(data, npartitions=1)
+    ens_frame4 = EnsembleFrame.from_dict(data, npartitions=1)
+
+    # Labels to give the EnsembleFrames
+    label1, label2, label3, label4 = "frame1", "frame2", "frame3", "frame4"
+
+    assert not ens.frames
+
+    # TODO(wbeebe@uw.edu) Remove once Ensemble.source and Ensemble.object are populated by loaders
+    ens.source = EnsembleFrame.from_tapeframe(
+        TapeFrame(ens._source), label="source", npartitions=1)
+    ens.object = EnsembleFrame.from_tapeframe(
+        TapeFrame(ens._object), label="object", npartitions=1)
+    ens.frames["source"] = ens.source
+    ens.frames["object"] = ens.object
+
+    # Check that we can select source and object frames
+    assert len(ens.frames) == 2
+    assert ens.select_frame("source") is ens.source
+    assert ens.select_frame("object") is ens.object
+
+    # Validate that new source and object frames can't be added or updated.
+    with pytest.raises(ValueError):
+        ens.add_frame(ens_frame1, "source")
+    with pytest.raises(ValueError):
+        ens.add_frame(ens_frame1, "object")
+    with pytest.raises(ValueError):
+        ens.update_frame(ens.source)
+    with pytest.raises(ValueError):
+        ens.update_frame(ens.object)
+
+    # Test that we can add and select a new ensemble frame
+    assert ens.add_frame(ens_frame1, label1).select_frame(label1) is ens_frame1
+    assert len(ens.frames) == 3
+
+    # Validate that we can't add a new frame that uses an existing label
+    with pytest.raises(ValueError):
+        ens.add_frame(ens_frame2, label1)
+
+    # We add two more frames to track
+    ens.add_frame(ens_frame2, label2).add_frame(ens_frame3, label3)
+    assert ens.select_frame(label2) is ens_frame2
+    assert ens.select_frame(label3) is ens_frame3
+    assert len(ens.frames) == 5
+
+    # Now we begin dropping frames. First verify that we can't drop object or source.
+    with pytest.raises(ValueError):
+        ens.drop_frame("source")
+    with pytest.raises(ValueError):
+        ens.drop_frame("object")
+
+    # And verify that we can't call drop with an unknown label.
+    with pytest.raises(KeyError):
+        ens.drop_frame("nonsense")
+
+    # Drop an existing frame and verify that it can no longer be selected.
+    ens.drop_frame(label3)
+    assert len(ens.frames) == 4
+    with pytest.raises(KeyError):
+        ens.select_frame(label3)
+
+    # Update the ensemble with the dropped frame, and then select the frame
+    assert ens.update_frame(ens_frame3).select_frame(label3) is ens_frame3
+    assert len(ens.frames) == 5
+
+    # Update the ensemble with a new frame, verifying a missing label generates an error.
+    with pytest.raises(ValueError):
+        ens.update_frame(ens_frame4)
+    ens_frame4.label = label4
+    assert ens.update_frame(ens_frame4).select_frame(label4) is ens_frame4
+    assert len(ens.frames) == 6
+
+    # Change the label of the 4th ensemble frame to verify update overrides an existing frame
+    ens_frame4.label = label3
+    assert ens.update_frame(ens_frame4).select_frame(label3) is ens_frame4
+    assert len(ens.frames) == 6
 
 def test_from_rrl_dataset(dask_client):
     """
@@ -291,6 +382,9 @@ def test_core_wrappers(parquet_ensemble):
     # Just test if these execute successfully
     parquet_ensemble.client_info()
     parquet_ensemble.info()
+    parquet_ensemble.frame_info()
+    with pytest.raises(KeyError):
+        parquet_ensemble.frame_info(labels=["source", "invalid_label"])
     parquet_ensemble.columns()
     parquet_ensemble.head(n=5)
     parquet_ensemble.tail(n=5)