From 6c8f1b40acccdbf17d56a465ee306f8288f70fcf Mon Sep 17 00:00:00 2001 From: Wilson Beebe Date: Wed, 24 Apr 2024 15:43:29 -0700 Subject: [PATCH] Pin dask-expr to < 1.0.11 and revert "Add support for running batch on a single lightcurve" (#433) * Revert "Add support for running batch on a single lightcurve (#420)" This reverts commit 89aa116943564c6add1281d049f62ff9c2531fb9. * Pin dask-expr to old version * Pin to last known good dask-expr * Try pinning dask-expr to <1.0.10 * Move pin back to 1.0.10 --- docs/tutorials/batch_showcase.ipynb | 39 +------------------ pyproject.toml | 1 + src/tape/ensemble.py | 49 +++++------------------- tests/tape_tests/test_ensemble.py | 58 ++++------------------------- 4 files changed, 19 insertions(+), 128 deletions(-) diff --git a/docs/tutorials/batch_showcase.ipynb b/docs/tutorials/batch_showcase.ipynb index 42b4d77b..4050405d 100644 --- a/docs/tutorials/batch_showcase.ipynb +++ b/docs/tutorials/batch_showcase.ipynb @@ -123,43 +123,6 @@ "res1.compute() # Compute to see the result" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "By default `Ensemble.batch` will apply your function across all light curves. However with the `single_lc` parameter you can test out your function on only a single object." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# The same batch call as above but now only on the final lightcurve.\n", - "\n", - "lc_id = 109 # id of the final lightcurve in the data\n", - "lc_res = ens.batch(my_mean, \"flux\", single_lc=lc_id)\n", - "lc_res.compute()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can also run your batch function on a single random lightcurve with `single_lc=True`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "rand_lc = ens.batch(my_mean, \"flux\", single_lc=True)\n", - "rand_lc.compute()" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -623,7 +586,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.14" + "version": "3.10.11" }, "vscode": { "interpreter": { diff --git a/pyproject.toml b/pyproject.toml index a37321eb..635fa2b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ 'pandas', 'numpy', 'dask>=2024.3.0', + 'dask-expr<1.0.11', # Temporary pin due to compatibility bug 'dask[distributed]>=2024.3.0', 'pyarrow', 'pyvo', diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index ea13b494..9a6ef7f9 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -1059,7 +1059,6 @@ def batch( use_map=True, on=None, label="", - single_lc=False, **kwargs, ): """Run a function from tape.TimeSeries on the available ids @@ -1108,11 +1107,6 @@ def batch( source or object tables. If not specified, then the id column is used by default. For TAPE and `light-curve` functions this is populated automatically. - single_lc: `boolean` or `int`, optional - If a `boolean` and True, batch will only execute on a single randomly selected - lightcurve. If False, batch will execute across all lightcurves as normal. 
To - specify a specific lightcurve, the integer value of the lightcurve's id can be - provided. Default is False. label: 'str', optional If provided the ensemble will use this label to track the result dataframe. If not provided, a label of the from "result_{x}" where x @@ -1139,12 +1133,6 @@ def batch( from light_curve import EtaE ens.batch(EtaE(), band_to_calc='g') - To run a TAPE function on a single lightcurve: - from tape.analysis.stetsonj import calc_stetson_J - ens = Ensemble().from_dataset('rrlyr82') - lc_id = 4378437892 # The lightcurve id - ensemble.batch(calc_stetson_J, band_to_calc='i', single_lc=lc_id) - Run a custom function on the ensemble:: def s2n_inter_quartile_range(flux, err): @@ -1172,22 +1160,6 @@ def s2n_inter_quartile_range(flux, err): if meta is None: meta = (self._id_col, float) # return a series of ids, default assume a float is returned - src_to_batch = self.source - obj_to_batch = self.object - - # Check if we only want to apply the batch function to a single lightcurve - if not isinstance(single_lc, bool) and not isinstance(single_lc, int): - raise ValueError("single_lc must be a boolean or an integer") - elif single_lc is True: - # Select the ID of a random lightcurve - rand_lc_id = self.select_random_timeseries(id_only=True) - src_to_batch = src_to_batch.loc[rand_lc_id] - obj_to_batch = obj_to_batch.loc[rand_lc_id] - elif single_lc is not False: - # The user provided the id of a specific lightcurve - src_to_batch = src_to_batch.loc[single_lc] - obj_to_batch = obj_to_batch.loc[single_lc] - # Translate the meta into an appropriate TapeFrame or TapeSeries. This ensures that the # batch result will be an EnsembleFrame or EnsembleSeries. 
meta = self._translate_meta(meta) @@ -1206,13 +1178,15 @@ def s2n_inter_quartile_range(flux, err): on[-1] = self._band_col # Handle object columns to group on - source_cols = list(src_to_batch.columns) - object_cols = list(obj_to_batch.columns) + source_cols = list(self.source.columns) + object_cols = list(self.object.columns) object_group_cols = [col for col in on if (col in object_cols) and (col not in source_cols)] if len(object_group_cols) > 0: - obj_to_batch = obj_to_batch[object_group_cols] - src_to_batch = src_to_batch.merge(obj_to_batch, how="left") + object_col_dd = self.object[object_group_cols] + source_to_batch = self.source.merge(object_col_dd, how="left") + else: + source_to_batch = self.source # Can directly use the source table id_col = self._id_col # pre-compute needed for dask in lambda function @@ -1237,11 +1211,11 @@ def _batch_apply(df, func, on, *args, **kwargs): id_col = self._id_col # need to grab this before mapping - batch = src_to_batch.map_partitions(_batch_apply, func, on, *args, **kwargs, meta=meta) + batch = source_to_batch.map_partitions(_batch_apply, func, on, *args, **kwargs, meta=meta) else: # use groupby # don't use _batch_apply as meta must be specified in the apply call - batch = src_to_batch.groupby(on, group_keys=True, sort=False).apply( + batch = source_to_batch.groupby(on, group_keys=True, sort=False).apply( _apply_func_to_lc, func, *args, @@ -2316,7 +2290,7 @@ def _sync_tables(self): self.object.set_dirty(False) return self - def select_random_timeseries(self, seed=None, id_only=False): + def select_random_timeseries(self, seed=None): """Selects a random lightcurve from a random partition of the Ensemble. Parameters @@ -2324,9 +2298,6 @@ def select_random_timeseries(self, seed=None, id_only=False): seed: int, or None Sets a seed to return the same object id on successive runs. `None` by default, in which case a seed is not set for the operation. - id_only: bool, optional - If True, returns only a random object id. 
If False, returns the - full timeseries for the object. Default is False. Returns ------- @@ -2365,7 +2336,7 @@ def select_random_timeseries(self, seed=None, id_only=False): if i >= len(partitions): raise IndexError("Found no object IDs in the Object Table.") - return self.to_timeseries(lcid) if not id_only else lcid + return self.to_timeseries(lcid) def to_timeseries( self, diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py index 8dee5192..147d14b5 100644 --- a/tests/tape_tests/test_ensemble.py +++ b/tests/tape_tests/test_ensemble.py @@ -2127,46 +2127,6 @@ def my_bounds(flux): assert all([col in res.columns for col in res.compute().columns]) -@pytest.mark.parametrize("data_fixture", ["parquet_ensemble", "parquet_ensemble_with_divisions"]) -def test_batch_single_lc(data_fixture, request): - """ - Test that ensemble.batch() can run a function on a single light curve. - """ - parquet_ensemble = request.getfixturevalue(data_fixture) - - # Perform batch only on this specific lightcurve. - lc = 88472935274829959 - - # Check that we raise an error if single_lc is neither a bool nor an integer - with pytest.raises(ValueError): - parquet_ensemble.batch(calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc="foo") - - lc_res = parquet_ensemble.prune(10).batch( - calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=lc - ) - assert len(lc_res) == 1 - - # Now ensure that we got the same result when we ran the function on the entire ensemble. - full_res = parquet_ensemble.prune(10).batch(calc_stetson_J, use_map=True, on=None, band_to_calc=None) - assert full_res.compute().loc[lc].stetsonJ == lc_res.compute().iloc[0].stetsonJ - - # Check that when single_lc is True we will perform batch on a random lightcurve and still get only one result. 
- rand_lc = parquet_ensemble.prune(10).batch( - calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=True - ) - assert len(rand_lc) == 1 - - # Now compare that result to what was computed when doing the full batch result - rand_lc_id = rand_lc.index.compute().values[0] - assert full_res.compute().loc[rand_lc_id].stetsonJ == rand_lc.compute().iloc[0].stetsonJ - - # Check that when single_lc is False we get the same # of results as the full batch - no_lc = parquet_ensemble.prune(10).batch( - calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=False - ) - assert len(full_res) == len(no_lc) - - def test_batch_labels(parquet_ensemble): """ Test that ensemble.batch() generates unique labels for result frames when none are provided. @@ -2276,8 +2236,7 @@ def test_batch_with_custom_frame_meta(parquet_ensemble, custom_meta): @pytest.mark.parametrize("repartition", [False, True]) @pytest.mark.parametrize("seed", [None, 42]) -@pytest.mark.parametrize("id_only", [True, False]) -def test_select_random_timeseries(parquet_ensemble, repartition, seed, id_only): +def test_select_random_timeseries(parquet_ensemble, repartition, seed): """Test the behavior of ensemble.select_random_timeseries""" ens = parquet_ensemble @@ -2285,17 +2244,14 @@ def test_select_random_timeseries(parquet_ensemble, repartition, seed, id_only): if repartition: ens.object = ens.object.repartition(npartitions=3) - ts = ens.select_random_timeseries(seed=seed, id_only=id_only) + ts = ens.select_random_timeseries(seed=seed) - if not id_only: - assert isinstance(ts, TimeSeries) + assert isinstance(ts, TimeSeries) - if seed == 42: - expected_id = 88480001333818899 if repartition else 88472935274829959 - if id_only: - assert ts == expected_id - else: - assert ts.meta["id"] == expected_id + if seed == 42 and not repartition: + assert ts.meta["id"] == 88472935274829959 + elif seed == 42 and repartition: + assert ts.meta["id"] == 88480001333818899 @pytest.mark.parametrize("all_empty", 
[False, True])