Skip to content

Commit

Permalink
Pin dask-expr to < 1.10.11 and revert "Add support for running batch …
Browse files Browse the repository at this point in the history
…on a single lightcurve" (#433)

* Revert "Add support for running batch on a single lightcurve (#420)"

This reverts commit 89aa116.

* Pin dask-expr to old version

* Pin to last known good dask-expr

* Try pinning dask-expr to <1.0.10

* Move pin back to 1.0.10
  • Loading branch information
wilsonbb authored Apr 24, 2024
1 parent 89aa116 commit 6c8f1b4
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 128 deletions.
39 changes: 1 addition & 38 deletions docs/tutorials/batch_showcase.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -123,43 +123,6 @@
"res1.compute() # Compute to see the result"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default `Ensemble.batch` will apply your function across all light curves. However with the `single_lc` parameter you can test out your function on only a single object."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The same batch call as above but now only on the final lightcurve.\n",
"\n",
"lc_id = 109 # id of the final lightcurve in the data\n",
"lc_res = ens.batch(my_mean, \"flux\", single_lc=lc_id)\n",
"lc_res.compute()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can also run your batch function on a single random lightcurve with `single_lc=True`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rand_lc = ens.batch(my_mean, \"flux\", single_lc=True)\n",
"rand_lc.compute()"
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down Expand Up @@ -623,7 +586,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.10.11"
},
"vscode": {
"interpreter": {
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
'pandas',
'numpy',
'dask>=2024.3.0',
'dask-expr<1.0.11', # Temporary pin due to compatibility bug
'dask[distributed]>=2024.3.0',
'pyarrow',
'pyvo',
Expand Down
49 changes: 10 additions & 39 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,7 +1059,6 @@ def batch(
use_map=True,
on=None,
label="",
single_lc=False,
**kwargs,
):
"""Run a function from tape.TimeSeries on the available ids
Expand Down Expand Up @@ -1108,11 +1107,6 @@ def batch(
source or object tables. If not specified, then the id column is
used by default. For TAPE and `light-curve` functions this is
populated automatically.
single_lc: `boolean` or `int`, optional
If a `boolean` and True, batch will only execute on a single randomly selected
lightcurve. If False, batch will execute across all lightcurves as normal. To
specify a specific lightcurve, the integer value of the lightcurve's id can be
provided. Default is False.
label: 'str', optional
If provided the ensemble will use this label to track the result
dataframe. If not provided, a label of the from "result_{x}" where x
Expand All @@ -1139,12 +1133,6 @@ def batch(
from light_curve import EtaE
ens.batch(EtaE(), band_to_calc='g')
To run a TAPE function on a single lightcurve:
from tape.analysis.stetsonj import calc_stetson_J
ens = Ensemble().from_dataset('rrlyr82')
lc_id = 4378437892 # The lightcurve id
ensemble.batch(calc_stetson_J, band_to_calc='i', single_lc=lc_id)
Run a custom function on the ensemble::
def s2n_inter_quartile_range(flux, err):
Expand Down Expand Up @@ -1172,22 +1160,6 @@ def s2n_inter_quartile_range(flux, err):
if meta is None:
meta = (self._id_col, float) # return a series of ids, default assume a float is returned

src_to_batch = self.source
obj_to_batch = self.object

# Check if we only want to apply the batch function to a single lightcurve
if not isinstance(single_lc, bool) and not isinstance(single_lc, int):
raise ValueError("single_lc must be a boolean or an integer")
elif single_lc is True:
# Select the ID of a random lightcurve
rand_lc_id = self.select_random_timeseries(id_only=True)
src_to_batch = src_to_batch.loc[rand_lc_id]
obj_to_batch = obj_to_batch.loc[rand_lc_id]
elif single_lc is not False:
# The user provided the id of a specific lightcurve
src_to_batch = src_to_batch.loc[single_lc]
obj_to_batch = obj_to_batch.loc[single_lc]

# Translate the meta into an appropriate TapeFrame or TapeSeries. This ensures that the
# batch result will be an EnsembleFrame or EnsembleSeries.
meta = self._translate_meta(meta)
Expand All @@ -1206,13 +1178,15 @@ def s2n_inter_quartile_range(flux, err):
on[-1] = self._band_col

# Handle object columns to group on
source_cols = list(src_to_batch.columns)
object_cols = list(obj_to_batch.columns)
source_cols = list(self.source.columns)
object_cols = list(self.object.columns)
object_group_cols = [col for col in on if (col in object_cols) and (col not in source_cols)]

if len(object_group_cols) > 0:
obj_to_batch = obj_to_batch[object_group_cols]
src_to_batch = src_to_batch.merge(obj_to_batch, how="left")
object_col_dd = self.object[object_group_cols]
source_to_batch = self.source.merge(object_col_dd, how="left")
else:
source_to_batch = self.source # Can directly use the source table

id_col = self._id_col # pre-compute needed for dask in lambda function

Expand All @@ -1237,11 +1211,11 @@ def _batch_apply(df, func, on, *args, **kwargs):

id_col = self._id_col # need to grab this before mapping

batch = src_to_batch.map_partitions(_batch_apply, func, on, *args, **kwargs, meta=meta)
batch = source_to_batch.map_partitions(_batch_apply, func, on, *args, **kwargs, meta=meta)

else: # use groupby
# don't use _batch_apply as meta must be specified in the apply call
batch = src_to_batch.groupby(on, group_keys=True, sort=False).apply(
batch = source_to_batch.groupby(on, group_keys=True, sort=False).apply(
_apply_func_to_lc,
func,
*args,
Expand Down Expand Up @@ -2316,17 +2290,14 @@ def _sync_tables(self):
self.object.set_dirty(False)
return self

def select_random_timeseries(self, seed=None, id_only=False):
def select_random_timeseries(self, seed=None):
"""Selects a random lightcurve from a random partition of the Ensemble.
Parameters
----------
seed: int, or None
Sets a seed to return the same object id on successive runs. `None`
by default, in which case a seed is not set for the operation.
id_only: bool, optional
If True, returns only a random object id. If False, returns the
full timeseries for the object. Default is False.
Returns
-------
Expand Down Expand Up @@ -2365,7 +2336,7 @@ def select_random_timeseries(self, seed=None, id_only=False):
if i >= len(partitions):
raise IndexError("Found no object IDs in the Object Table.")

return self.to_timeseries(lcid) if not id_only else lcid
return self.to_timeseries(lcid)

def to_timeseries(
self,
Expand Down
58 changes: 7 additions & 51 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -2127,46 +2127,6 @@ def my_bounds(flux):
assert all([col in res.columns for col in res.compute().columns])


@pytest.mark.parametrize("data_fixture", ["parquet_ensemble", "parquet_ensemble_with_divisions"])
def test_batch_single_lc(data_fixture, request):
"""
Test that ensemble.batch() can run a function on a single light curve.
"""
parquet_ensemble = request.getfixturevalue(data_fixture)

# Perform batch only on this specific lightcurve.
lc = 88472935274829959

# Check that we raise an error if single_lc is neither a bool nor an integer
with pytest.raises(ValueError):
parquet_ensemble.batch(calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc="foo")

lc_res = parquet_ensemble.prune(10).batch(
calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=lc
)
assert len(lc_res) == 1

# Now ensure that we got the same result when we ran the function on the entire ensemble.
full_res = parquet_ensemble.prune(10).batch(calc_stetson_J, use_map=True, on=None, band_to_calc=None)
assert full_res.compute().loc[lc].stetsonJ == lc_res.compute().iloc[0].stetsonJ

# Check that when single_lc is True we will perform batch on a random lightcurve and still get only one result.
rand_lc = parquet_ensemble.prune(10).batch(
calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=True
)
assert len(rand_lc) == 1

# Now compare that result to what was computed when doing the full batch result
rand_lc_id = rand_lc.index.compute().values[0]
assert full_res.compute().loc[rand_lc_id].stetsonJ == rand_lc.compute().iloc[0].stetsonJ

# Check that when single_lc is False we get the same # of results as the full batch
no_lc = parquet_ensemble.prune(10).batch(
calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=False
)
assert len(full_res) == len(no_lc)


def test_batch_labels(parquet_ensemble):
"""
Test that ensemble.batch() generates unique labels for result frames when none are provided.
Expand Down Expand Up @@ -2276,26 +2236,22 @@ def test_batch_with_custom_frame_meta(parquet_ensemble, custom_meta):

@pytest.mark.parametrize("repartition", [False, True])
@pytest.mark.parametrize("seed", [None, 42])
@pytest.mark.parametrize("id_only", [True, False])
def test_select_random_timeseries(parquet_ensemble, repartition, seed, id_only):
def test_select_random_timeseries(parquet_ensemble, repartition, seed):
"""Test the behavior of ensemble.select_random_timeseries"""

ens = parquet_ensemble

if repartition:
ens.object = ens.object.repartition(npartitions=3)

ts = ens.select_random_timeseries(seed=seed, id_only=id_only)
ts = ens.select_random_timeseries(seed=seed)

if not id_only:
assert isinstance(ts, TimeSeries)
assert isinstance(ts, TimeSeries)

if seed == 42:
expected_id = 88480001333818899 if repartition else 88472935274829959
if id_only:
assert ts == expected_id
else:
assert ts.meta["id"] == expected_id
if seed == 42 and not repartition:
assert ts.meta["id"] == 88472935274829959
elif seed == 42 and repartition:
assert ts.meta["id"] == 88480001333818899


@pytest.mark.parametrize("all_empty", [False, True])
Expand Down

0 comments on commit 6c8f1b4

Please sign in to comment.