Skip to content

Commit

Permalink
adds the select_random_timeseries function
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed Dec 21, 2023
1 parent 97a2bfc commit 4df0dd2
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
30 changes: 30 additions & 0 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,36 @@ def _sync_tables(self):
self.object.set_dirty(False)
return self

def select_random_timeseries(self, seed=None):
    """Select a random lightcurve from the Ensemble.

    A partition of the object table is chosen at random and an object id is
    then drawn from that partition, which avoids sampling from the full
    index space. Partitions that hold no objects are skipped.

    Parameters
    ----------
    seed: int, or None
        Sets a seed to return the same object id on successive runs. `None`
        by default, in which case a seed is not set for the operation.
        Note that a non-`None` seed reseeds NumPy's global random state.

    Returns
    -------
    ts: `TimeSeries`
        Timeseries for a single object

    Raises
    ------
    ValueError
        If no partition of the object table contains any object ids.
    """

    if seed is not None:
        np.random.seed(seed)

    # Visit the partitions in a random order instead of drawing one partition
    # number: every partition (including the last one) is reachable, and an
    # empty partition simply falls through to the next candidate.
    for partition_num in np.random.permutation(self.object.npartitions):
        partition_num = int(partition_num)

        # Coerce to a concrete array (the ids may be lazily evaluated) so the
        # emptiness check below is reliable.
        partition_ids = np.asarray(self.object.get_partition(partition_num).index.values)
        if partition_ids.size == 0:
            continue

        lcid = np.random.choice(partition_ids)
        print(f"Selected Object {lcid} from Partition {partition_num}")
        return self.to_timeseries(lcid)

    raise ValueError("Found no object ids in the object table to select from.")

def to_timeseries(
self,
target,
Expand Down
21 changes: 21 additions & 0 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TapeSeries,
TapeObjectFrame,
TapeSourceFrame,
TimeSeries,
)
from tape.analysis.stetsonj import calc_stetson_J
from tape.analysis.structure_function.base_argument_container import StructureFunctionArgumentContainer
Expand Down Expand Up @@ -1829,6 +1830,26 @@ def test_batch_with_custom_frame_meta(parquet_ensemble, custom_meta):
assert isinstance(parquet_ensemble.select_frame("sf2_result"), EnsembleFrame)


@pytest.mark.parametrize("repartition", [False, True])
@pytest.mark.parametrize("seed", [None, 42])
def test_select_random_timeseries(parquet_ensemble, repartition, seed):
    """Test the behavior of ensemble.select_random_timeseries

    Checks that a `TimeSeries` is returned for an object that exists in the
    object table, with and without repartitioning, and that a fixed seed
    reproduces the same selection.
    """

    ens = parquet_ensemble

    if repartition:
        ens.object = ens.object.repartition(3)

    ts = ens.select_random_timeseries(seed=seed)

    assert isinstance(ts, TimeSeries)

    # Whatever was drawn must be an object that is actually in the table;
    # this also gives the unseeded cases a meaningful check.
    assert ts.meta["id"] in ens.object.index.compute()

    if seed is not None:
        # Assert reproducibility directly instead of pinning literal ids,
        # which depend on the exact random draw sequence and partition layout.
        repeat = ens.select_random_timeseries(seed=seed)
        assert repeat.meta["id"] == ts.meta["id"]


def test_to_timeseries(parquet_ensemble):
"""
Test that ensemble.to_timeseries() runs and assigns the correct metadata
Expand Down

0 comments on commit 4df0dd2

Please sign in to comment.