diff --git a/docs/tutorials/batch_showcase.ipynb b/docs/tutorials/batch_showcase.ipynb index 4050405d..42b4d77b 100644 --- a/docs/tutorials/batch_showcase.ipynb +++ b/docs/tutorials/batch_showcase.ipynb @@ -123,6 +123,43 @@ "res1.compute() # Compute to see the result" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By default `Ensemble.batch` will apply your function across all light curves. However with the `single_lc` parameter you can test out your function on only a single object." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The same batch call as above but now only on the final lightcurve.\n", + "\n", + "lc_id = 109 # id of the final lightcurve in the data\n", + "lc_res = ens.batch(my_mean, \"flux\", single_lc=lc_id)\n", + "lc_res.compute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also run your batch function on a single random lightcurve with `single_lc=True`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rand_lc = ens.batch(my_mean, \"flux\", single_lc=True)\n", + "rand_lc.compute()" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -586,7 +623,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.14" }, "vscode": { "interpreter": { diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 9a6ef7f9..ea13b494 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -1059,6 +1059,7 @@ def batch( use_map=True, on=None, label="", + single_lc=False, **kwargs, ): """Run a function from tape.TimeSeries on the available ids @@ -1107,6 +1108,11 @@ def batch( source or object tables. If not specified, then the id column is used by default. For TAPE and `light-curve` functions this is populated automatically. + single_lc: `boolean` or `int`, optional + If a `boolean` and True, batch will only execute on a single randomly selected + lightcurve. If False, batch will execute across all lightcurves as normal. To + specify a specific lightcurve, the integer value of the lightcurve's id can be + provided. Default is False. label: 'str', optional If provided the ensemble will use this label to track the result dataframe. If not provided, a label of the from "result_{x}" where x @@ -1133,6 +1139,12 @@ def batch( from light_curve import EtaE ens.batch(EtaE(), band_to_calc='g') + To run a TAPE function on a single lightcurve: + from tape.analysis.stetsonj import calc_stetson_J + ens = Ensemble().from_dataset('rrlyr82') + lc_id = 4378437892 # The lightcurve id + ensemble.batch(calc_stetson_J, band_to_calc='i', single_lc=lc_id) + Run a custom function on the ensemble:: def s2n_inter_quartile_range(flux, err): @@ -1160,6 +1172,22 @@ def s2n_inter_quartile_range(flux, err): if meta is None: meta = (self._id_col, float) # return a series of ids, default assume a float is returned + src_to_batch = self.source + obj_to_batch = self.object + + # Check if we only want to apply the batch function to a single lightcurve + if not isinstance(single_lc, bool) and not isinstance(single_lc, int): + raise ValueError("single_lc must be a boolean or an integer") + elif single_lc is True: + # Select the ID of a random lightcurve + rand_lc_id = self.select_random_timeseries(id_only=True) + src_to_batch = src_to_batch.loc[rand_lc_id] + obj_to_batch = obj_to_batch.loc[rand_lc_id] + elif single_lc is not False: + # The user provided the id of a specific lightcurve + src_to_batch = src_to_batch.loc[single_lc] + obj_to_batch = obj_to_batch.loc[single_lc] + # Translate the meta into an appropriate TapeFrame or TapeSeries. This ensures that the # batch result will be an EnsembleFrame or EnsembleSeries. meta = self._translate_meta(meta) @@ -1178,15 +1206,13 @@ def s2n_inter_quartile_range(flux, err): on[-1] = self._band_col # Handle object columns to group on - source_cols = list(self.source.columns) - object_cols = list(self.object.columns) + source_cols = list(src_to_batch.columns) + object_cols = list(obj_to_batch.columns) object_group_cols = [col for col in on if (col in object_cols) and (col not in source_cols)] if len(object_group_cols) > 0: - object_col_dd = self.object[object_group_cols] - source_to_batch = self.source.merge(object_col_dd, how="left") - else: - source_to_batch = self.source # Can directly use the source table + obj_to_batch = obj_to_batch[object_group_cols] + src_to_batch = src_to_batch.merge(obj_to_batch, how="left") id_col = self._id_col # pre-compute needed for dask in lambda function @@ -1211,11 +1237,11 @@ def _batch_apply(df, func, on, *args, **kwargs): id_col = self._id_col # need to grab this before mapping - batch = source_to_batch.map_partitions(_batch_apply, func, on, *args, **kwargs, meta=meta) + batch = src_to_batch.map_partitions(_batch_apply, func, on, *args, **kwargs, meta=meta) else: # use groupby # don't use _batch_apply as meta must be specified in the apply call - batch = source_to_batch.groupby(on, group_keys=True, sort=False).apply( + batch = src_to_batch.groupby(on, group_keys=True, sort=False).apply( _apply_func_to_lc, func, *args, @@ -2290,7 +2316,7 @@ def _sync_tables(self): self.object.set_dirty(False) return self - def select_random_timeseries(self, seed=None): + def select_random_timeseries(self, seed=None, id_only=False): """Selects a random lightcurve from a random partition of the Ensemble. Parameters @@ -2298,6 +2324,9 @@ def select_random_timeseries(self, seed=None): seed: int, or None Sets a seed to return the same object id on successive runs. `None` by default, in which case a seed is not set for the operation. + id_only: bool, optional + If True, returns only a random object id. If False, returns the + full timeseries for the object. Default is False. Returns ------- @@ -2336,7 +2365,7 @@ def select_random_timeseries(self, seed=None): if i >= len(partitions): raise IndexError("Found no object IDs in the Object Table.") - return self.to_timeseries(lcid) + return self.to_timeseries(lcid) if not id_only else lcid def to_timeseries( self, diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py index 147d14b5..8dee5192 100644 --- a/tests/tape_tests/test_ensemble.py +++ b/tests/tape_tests/test_ensemble.py @@ -2127,6 +2127,46 @@ def my_bounds(flux): assert all([col in res.columns for col in res.compute().columns]) +@pytest.mark.parametrize("data_fixture", ["parquet_ensemble", "parquet_ensemble_with_divisions"]) +def test_batch_single_lc(data_fixture, request): + """ + Test that ensemble.batch() can run a function on a single light curve. + """ + parquet_ensemble = request.getfixturevalue(data_fixture) + + # Perform batch only on this specific lightcurve. + lc = 88472935274829959 + + # Check that we raise an error if single_lc is neither a bool nor an integer + with pytest.raises(ValueError): + parquet_ensemble.batch(calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc="foo") + + lc_res = parquet_ensemble.prune(10).batch( + calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=lc + ) + assert len(lc_res) == 1 + + # Now ensure that we got the same result when we ran the function on the entire ensemble. + full_res = parquet_ensemble.prune(10).batch(calc_stetson_J, use_map=True, on=None, band_to_calc=None) + assert full_res.compute().loc[lc].stetsonJ == lc_res.compute().iloc[0].stetsonJ + + # Check that when single_lc is True we will perform batch on a random lightcurve and still get only one result. + rand_lc = parquet_ensemble.prune(10).batch( + calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=True + ) + assert len(rand_lc) == 1 + + # Now compare that result to what was computed when doing the full batch result + rand_lc_id = rand_lc.index.compute().values[0] + assert full_res.compute().loc[rand_lc_id].stetsonJ == rand_lc.compute().iloc[0].stetsonJ + + # Check that when single_lc is False we get the same # of results as the full batch + no_lc = parquet_ensemble.prune(10).batch( + calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=False + ) + assert len(full_res) == len(no_lc) + + def test_batch_labels(parquet_ensemble): """ Test that ensemble.batch() generates unique labels for result frames when none are provided. @@ -2236,7 +2276,8 @@ def test_batch_with_custom_frame_meta(parquet_ensemble, custom_meta): @pytest.mark.parametrize("repartition", [False, True]) @pytest.mark.parametrize("seed", [None, 42]) -def test_select_random_timeseries(parquet_ensemble, repartition, seed): +@pytest.mark.parametrize("id_only", [True, False]) +def test_select_random_timeseries(parquet_ensemble, repartition, seed, id_only): """Test the behavior of ensemble.select_random_timeseries""" ens = parquet_ensemble @@ -2244,14 +2285,17 @@ def test_select_random_timeseries(parquet_ensemble, repartition, seed): if repartition: ens.object = ens.object.repartition(npartitions=3) - ts = ens.select_random_timeseries(seed=seed) + ts = ens.select_random_timeseries(seed=seed, id_only=id_only) - assert isinstance(ts, TimeSeries) + if not id_only: + assert isinstance(ts, TimeSeries) - if seed == 42 and not repartition: - assert ts.meta["id"] == 88472935274829959 - elif seed == 42 and repartition: - assert ts.meta["id"] == 88480001333818899 + if seed == 42: + expected_id = 88480001333818899 if repartition else 88472935274829959 + if id_only: + assert ts == expected_id + else: + assert ts.meta["id"] == expected_id @pytest.mark.parametrize("all_empty", [False, True])