Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for running batch on a single lightcurve #420

Merged
merged 7 commits into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion docs/tutorials/batch_showcase.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,43 @@
"res1.compute() # Compute to see the result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default `Ensemble.batch` will apply your function across all light curves. However with the `single_lc` parameter you can test out your function on only a single object."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The same batch call as above but now only on the final lightcurve.\n",
"\n",
"lc_id = 109 # id of the final lightcurve in the data\n",
"lc_res = ens.batch(my_mean, \"flux\", single_lc=lc_id)\n",
"lc_res.compute()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also run your batch function on a single random lightcurve with `single_lc=True`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rand_lc = ens.batch(my_mean, \"flux\", single_lc=True)\n",
"rand_lc.compute()"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down Expand Up @@ -586,7 +623,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.14"
},
"vscode": {
"interpreter": {
Expand Down
49 changes: 39 additions & 10 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,7 @@
use_map=True,
on=None,
label="",
single_lc=False,
**kwargs,
):
"""Run a function from tape.TimeSeries on the available ids
Expand Down Expand Up @@ -1107,6 +1108,11 @@
source or object tables. If not specified, then the id column is
used by default. For TAPE and `light-curve` functions this is
populated automatically.
single_lc: `boolean` or `int`, optional
If a `boolean` and True, batch will only execute on a single randomly selected
lightcurve. If False, batch will execute across all lightcurves as normal. To
specify a specific lightcurve, the integer value of the lightcurve's id can be
provided. Default is False.
label: 'str', optional
If provided the ensemble will use this label to track the result
dataframe. If not provided, a label of the from "result_{x}" where x
Expand All @@ -1133,6 +1139,12 @@
from light_curve import EtaE
ens.batch(EtaE(), band_to_calc='g')

To run a TAPE function on a single lightcurve:
from tape.analysis.stetsonj import calc_stetson_J
ens = Ensemble().from_dataset('rrlyr82')
lc_id = 4378437892 # The lightcurve id
ensemble.batch(calc_stetson_J, band_to_calc='i', single_lc=lc_id)

Run a custom function on the ensemble::

def s2n_inter_quartile_range(flux, err):
Expand Down Expand Up @@ -1160,6 +1172,22 @@
if meta is None:
meta = (self._id_col, float) # return a series of ids, default assume a float is returned

src_to_batch = self.source
obj_to_batch = self.object

# Check if we only want to apply the batch function to a single lightcurve
if not isinstance(single_lc, bool) and not isinstance(single_lc, int):
raise ValueError("single_lc must be a boolean or an integer")
elif single_lc is True:
# Select the ID of a random lightcurve
rand_lc_id = self.select_random_timeseries(id_only=True)
src_to_batch = src_to_batch.loc[rand_lc_id]
obj_to_batch = obj_to_batch.loc[rand_lc_id]
elif single_lc is not False:
# The user provided the id of a specific lightcurve
src_to_batch = src_to_batch.loc[single_lc]
obj_to_batch = obj_to_batch.loc[single_lc]

# Translate the meta into an appropriate TapeFrame or TapeSeries. This ensures that the
# batch result will be an EnsembleFrame or EnsembleSeries.
meta = self._translate_meta(meta)
Expand All @@ -1178,15 +1206,13 @@
on[-1] = self._band_col

# Handle object columns to group on
source_cols = list(self.source.columns)
object_cols = list(self.object.columns)
source_cols = list(src_to_batch.columns)
object_cols = list(obj_to_batch.columns)
object_group_cols = [col for col in on if (col in object_cols) and (col not in source_cols)]

if len(object_group_cols) > 0:
object_col_dd = self.object[object_group_cols]
source_to_batch = self.source.merge(object_col_dd, how="left")
else:
source_to_batch = self.source # Can directly use the source table
obj_to_batch = obj_to_batch[object_group_cols]
src_to_batch = src_to_batch.merge(obj_to_batch, how="left")

Check warning on line 1215 in src/tape/ensemble.py

View check run for this annotation

Codecov / codecov/patch

src/tape/ensemble.py#L1214-L1215

Added lines #L1214 - L1215 were not covered by tests

id_col = self._id_col # pre-compute needed for dask in lambda function

Expand All @@ -1211,11 +1237,11 @@

id_col = self._id_col # need to grab this before mapping

batch = source_to_batch.map_partitions(_batch_apply, func, on, *args, **kwargs, meta=meta)
batch = src_to_batch.map_partitions(_batch_apply, func, on, *args, **kwargs, meta=meta)

else: # use groupby
# don't use _batch_apply as meta must be specified in the apply call
batch = source_to_batch.groupby(on, group_keys=True, sort=False).apply(
batch = src_to_batch.groupby(on, group_keys=True, sort=False).apply(
_apply_func_to_lc,
func,
*args,
Expand Down Expand Up @@ -2290,14 +2316,17 @@
self.object.set_dirty(False)
return self

def select_random_timeseries(self, seed=None):
def select_random_timeseries(self, seed=None, id_only=False):
"""Selects a random lightcurve from a random partition of the Ensemble.

Parameters
----------
seed: int, or None
Sets a seed to return the same object id on successive runs. `None`
by default, in which case a seed is not set for the operation.
id_only: bool, optional
If True, returns only a random object id. If False, returns the
full timeseries for the object. Default is False.

Returns
-------
Expand Down Expand Up @@ -2336,7 +2365,7 @@
if i >= len(partitions):
raise IndexError("Found no object IDs in the Object Table.")

return self.to_timeseries(lcid)
return self.to_timeseries(lcid) if not id_only else lcid

def to_timeseries(
self,
Expand Down
58 changes: 51 additions & 7 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -2127,6 +2127,46 @@ def my_bounds(flux):
assert all([col in res.columns for col in res.compute().columns])


@pytest.mark.parametrize("data_fixture", ["parquet_ensemble", "parquet_ensemble_with_divisions"])
def test_batch_single_lc(data_fixture, request):
"""
Test that ensemble.batch() can run a function on a single light curve.
"""
parquet_ensemble = request.getfixturevalue(data_fixture)

# Perform batch only on this specific lightcurve.
lc = 88472935274829959

# Check that we raise an error if single_lc is neither a bool nor an integer
with pytest.raises(ValueError):
parquet_ensemble.batch(calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc="foo")

lc_res = parquet_ensemble.prune(10).batch(
calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=lc
)
assert len(lc_res) == 1

# Now ensure that we got the same result when we ran the function on the entire ensemble.
full_res = parquet_ensemble.prune(10).batch(calc_stetson_J, use_map=True, on=None, band_to_calc=None)
assert full_res.compute().loc[lc].stetsonJ == lc_res.compute().iloc[0].stetsonJ

# Check that when single_lc is True we will perform batch on a random lightcurve and still get only one result.
rand_lc = parquet_ensemble.prune(10).batch(
calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=True
)
assert len(rand_lc) == 1

# Now compare that result to what was computed when doing the full batch result
rand_lc_id = rand_lc.index.compute().values[0]
assert full_res.compute().loc[rand_lc_id].stetsonJ == rand_lc.compute().iloc[0].stetsonJ

# Check that when single_lc is False we get the same # of results as the full batch
no_lc = parquet_ensemble.prune(10).batch(
calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=False
)
assert len(full_res) == len(no_lc)


def test_batch_labels(parquet_ensemble):
"""
Test that ensemble.batch() generates unique labels for result frames when none are provided.
Expand Down Expand Up @@ -2236,22 +2276,26 @@ def test_batch_with_custom_frame_meta(parquet_ensemble, custom_meta):

@pytest.mark.parametrize("repartition", [False, True])
@pytest.mark.parametrize("seed", [None, 42])
def test_select_random_timeseries(parquet_ensemble, repartition, seed):
@pytest.mark.parametrize("id_only", [True, False])
def test_select_random_timeseries(parquet_ensemble, repartition, seed, id_only):
"""Test the behavior of ensemble.select_random_timeseries"""

ens = parquet_ensemble

if repartition:
ens.object = ens.object.repartition(npartitions=3)

ts = ens.select_random_timeseries(seed=seed)
ts = ens.select_random_timeseries(seed=seed, id_only=id_only)

assert isinstance(ts, TimeSeries)
if not id_only:
assert isinstance(ts, TimeSeries)

if seed == 42 and not repartition:
assert ts.meta["id"] == 88472935274829959
elif seed == 42 and repartition:
assert ts.meta["id"] == 88480001333818899
if seed == 42:
expected_id = 88480001333818899 if repartition else 88472935274829959
if id_only:
assert ts == expected_id
else:
assert ts.meta["id"] == expected_id


@pytest.mark.parametrize("all_empty", [False, True])
Expand Down