From c9b477a003f337e6d7aebd6382477954a14e05b8 Mon Sep 17 00:00:00 2001 From: Wilson Beebe Date: Thu, 28 Mar 2024 17:20:52 -0700 Subject: [PATCH 1/5] Add support for batch on a single lightcurve --- docs/tutorials/batch_showcase.ipynb | 337 +++++++++++++++++++++++++++- src/tape/ensemble.py | 31 ++- tests/tape_tests/test_ensemble.py | 19 ++ 3 files changed, 367 insertions(+), 20 deletions(-) diff --git a/docs/tutorials/batch_showcase.ipynb b/docs/tutorials/batch_showcase.ipynb index b25b78fa..e0f60243 100644 --- a/docs/tutorials/batch_showcase.ipynb +++ b/docs/tutorials/batch_showcase.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -57,9 +57,30 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/wilsonbb/lincc/tape/test_batch/tape/tape/src/tape/ensemble.py:1676: UserWarning: Divisions for object are not set, certain downstream dask operations may fail as a result. We recommend setting the `sort` or `sorted` flags when loading data to establish division information.\n", + " warnings.warn(\n", + "/Users/wilsonbb/lincc/tape/test_batch/tape/tape/src/tape/ensemble.py:1676: UserWarning: Divisions for source are not set, certain downstream dask operations may fail as a result. We recommend setting the `sort` or `sorted` flags when loading data to establish division information.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Load the data into an Ensemble\n", "ens = Ensemble()\n", @@ -70,6 +91,104 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
mjdfluxerrband
id
1091245.08.3602530.836025g
1091246.03.8924040.389240g
1091247.00.7516970.075170r
1091248.04.0829420.408294g
1091249.06.1228060.612281g
\n", + "
" + ], + "text/plain": [ + " mjd flux err band\n", + "id \n", + "109 1245.0 8.360253 0.836025 g\n", + "109 1246.0 3.892404 0.389240 g\n", + "109 1247.0 0.751697 0.075170 r\n", + "109 1248.0 4.082942 0.408294 g\n", + "109 1249.0 6.122806 0.612281 g" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ens.source.tail(5)" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -88,9 +207,20 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3.0" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Case 1: Simple\n", "def my_mean(flux):\n", @@ -110,9 +240,117 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using generated label, result_1, for a batch result.\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
result
id
105.020567
115.041076
124.916761
135.033744
145.084872
......
1055.113830
1065.097340
1075.013111
1084.962582
1095.143362
\n", + "

100 rows × 1 columns

\n", + "
" + ], + "text/plain": [ + " result\n", + "id \n", + "10 5.020567\n", + "11 5.041076\n", + "12 4.916761\n", + "13 5.033744\n", + "14 5.084872\n", + ".. ...\n", + "105 5.113830\n", + "106 5.097340\n", + "107 5.013111\n", + "108 4.962582\n", + "109 5.143362\n", + "\n", + "[100 rows x 1 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# Default batch\n", "res1 = ens.batch(\n", @@ -121,6 +359,81 @@ "res1.compute() # Compute to see the result" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By default `Ensemble.batch` will apply your function across all light curves. However with the `single_lc` parameter you can test out your function only on a single object." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using generated label, result_7, for a batch result.\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
result
id
1095.143362
\n", + "
" + ], + "text/plain": [ + " result\n", + "id \n", + "109 5.143362" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# The same batch call as above but now only on the final lightcurve.\n", + "\n", + "lc_id = 109\n", + "lc_res = ens.batch(my_mean, \"flux\", single_lc=lc_id)\n", + "lc_res.compute()" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -584,7 +897,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.14" }, "vscode": { "interpreter": { diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 4d35fd8d..9efdb0ce 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -1059,6 +1059,7 @@ def batch( use_map=True, on=None, label="", + single_lc=None, **kwargs, ): """Run a function from tape.TimeSeries on the available ids @@ -1107,6 +1108,9 @@ def batch( source or object tables. If not specified, then the id column is used by default. For TAPE and `light-curve` functions this is populated automatically. + single_lc: `int`, optional + If provided, only the lightcurve with the specified id will be + used in batch. Default is None. label: 'str', optional If provided the ensemble will use this label to track the result dataframe. If not provided, a label of the from "result_{x}" where x @@ -1133,6 +1137,12 @@ def batch( from light_curve import EtaE ens.batch(EtaE(), band_to_calc='g') + To run a TAPE function on a single lightcurve: + from tape.analysis.stetsonj import calc_stetson_J + ens = Ensemble().from_dataset('rrlyr82') + lc_id = 4378437892 # The lightcurve id + ensemble.batch(calc_stetson_J, band_to_calc='i', single_lc=lc_id) + Run a custom function on the ensemble:: def s2n_inter_quartile_range(flux, err): @@ -1160,6 +1170,13 @@ def s2n_inter_quartile_range(flux, err): if meta is None: meta = (self._id_col, float) # return a series of ids, default assume a float is returned + src_to_batch = self.source + obj_to_batch = self.object + + if single_lc is not None: + src_to_batch = src_to_batch.loc[single_lc] + obj_to_batch = obj_to_batch.loc[single_lc] + # Translate the meta into an appropriate TapeFrame or TapeSeries. This ensures that the # batch result will be an EnsembleFrame or EnsembleSeries. meta = self._translate_meta(meta) @@ -1178,15 +1195,13 @@ def s2n_inter_quartile_range(flux, err): on[-1] = self._band_col # Handle object columns to group on - source_cols = list(self.source.columns) - object_cols = list(self.object.columns) + source_cols = list(src_to_batch.columns) + object_cols = list(obj_to_batch.columns) object_group_cols = [col for col in on if (col in object_cols) and (col not in source_cols)] if len(object_group_cols) > 0: - object_col_dd = self.object[object_group_cols] - source_to_batch = self.source.merge(object_col_dd, how="left") - else: - source_to_batch = self.source # Can directly use the source table + obj_to_batch = obj_to_batch[object_group_cols] + src_to_batch = src_to_batch.merge(obj_to_batch, how="left") id_col = self._id_col # pre-compute needed for dask in lambda function @@ -1211,11 +1226,11 @@ def _batch_apply(df, func, on, *args, **kwargs): id_col = self._id_col # need to grab this before mapping - batch = source_to_batch.map_partitions(_batch_apply, func, on, *args, **kwargs, meta=meta) + batch = src_to_batch.map_partitions(_batch_apply, func, on, *args, **kwargs, meta=meta) else: # use groupby # don't use _batch_apply as meta must be specified in the apply call - batch = source_to_batch.groupby(on, group_keys=True, sort=False).apply( + batch = src_to_batch.groupby(on, group_keys=True, sort=False).apply( _apply_func_to_lc, func, *args, diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py index 147d14b5..ba8505b5 100644 --- a/tests/tape_tests/test_ensemble.py +++ b/tests/tape_tests/test_ensemble.py @@ -2127,6 +2127,25 @@ def my_bounds(flux): assert all([col in res.columns for col in res.compute().columns]) +@pytest.mark.parametrize("data_fixture", ["parquet_ensemble", "parquet_ensemble_with_divisions"]) +def test_batch_single_lc(data_fixture, request): + """ + Test that ensemble.batch() can run a function on a single light curve. + """ + parquet_ensemble = request.getfixturevalue(data_fixture) + + lc = 88472935274829959 + + lc_res = parquet_ensemble.prune(10).batch( + calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=lc + ) + assert len(lc_res) == 1 + + # Now ensure that we got the same result when we ran the function on the entire ensemble. + full_res = parquet_ensemble.prune(10).batch(calc_stetson_J, use_map=True, on=None, band_to_calc=None) + assert full_res.compute().loc[lc].stetsonJ == lc_res.compute().iloc[0].stetsonJ + + def test_batch_labels(parquet_ensemble): """ Test that ensemble.batch() generates unique labels for result frames when none are provided. From a0d3fbd1c291c2c42c0cc35aa365aaf3f83ed75f Mon Sep 17 00:00:00 2001 From: Wilson Beebe Date: Thu, 28 Mar 2024 17:31:43 -0700 Subject: [PATCH 2/5] Clear notebook output --- docs/tutorials/batch_showcase.ipynb | 327 +++------------------------- 1 file changed, 25 insertions(+), 302 deletions(-) diff --git a/docs/tutorials/batch_showcase.ipynb b/docs/tutorials/batch_showcase.ipynb index 9f28e16e..3c860889 100644 --- a/docs/tutorials/batch_showcase.ipynb +++ b/docs/tutorials/batch_showcase.ipynb @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -57,30 +57,9 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/wilsonbb/lincc/tape/test_batch/tape/tape/src/tape/ensemble.py:1676: UserWarning: Divisions for object are not set, certain downstream dask operations may fail as a result. We recommend setting the `sort` or `sorted` flags when loading data to establish division information.\n", - " warnings.warn(\n", - "/Users/wilsonbb/lincc/tape/test_batch/tape/tape/src/tape/ensemble.py:1676: UserWarning: Divisions for source are not set, certain downstream dask operations may fail as a result. We recommend setting the `sort` or `sorted` flags when loading data to establish division information.\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# Load the data into an Ensemble\n", "ens = Ensemble()\n", @@ -94,98 +73,9 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
mjdfluxerrband
id
1091245.08.3602530.836025g
1091246.03.8924040.389240g
1091247.00.7516970.075170r
1091248.04.0829420.408294g
1091249.06.1228060.612281g
\n", - "
" - ], - "text/plain": [ - " mjd flux err band\n", - "id \n", - "109 1245.0 8.360253 0.836025 g\n", - "109 1246.0 3.892404 0.389240 g\n", - "109 1247.0 0.751697 0.075170 r\n", - "109 1248.0 4.082942 0.408294 g\n", - "109 1249.0 6.122806 0.612281 g" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "ens.source.tail(5)" ] @@ -208,20 +98,9 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "3.0" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# Case 1: Simple\n", "def my_mean(flux):\n", @@ -241,117 +120,9 @@ }, { "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using generated label, result_1, for a batch result.\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
result
id
105.020567
115.041076
124.916761
135.033744
145.084872
......
1055.113830
1065.097340
1075.013111
1084.962582
1095.143362
\n", - "

100 rows × 1 columns

\n", - "
" - ], - "text/plain": [ - " result\n", - "id \n", - "10 5.020567\n", - "11 5.041076\n", - "12 4.916761\n", - "13 5.033744\n", - "14 5.084872\n", - ".. ...\n", - "105 5.113830\n", - "106 5.097340\n", - "107 5.013111\n", - "108 4.962582\n", - "109 5.143362\n", - "\n", - "[100 rows x 1 columns]" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# Default batch\n", "res1 = ens.batch(\n", @@ -360,6 +131,13 @@ "res1.compute() # Compute to see the result" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "metadata": {}, @@ -369,68 +147,13 @@ }, { "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using generated label, result_7, for a batch result.\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
result
id
1095.143362
\n", - "
" - ], - "text/plain": [ - " result\n", - "id \n", - "109 5.143362" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "# The same batch call as above but now only on the final lightcurve.\n", "\n", - "lc_id = 109\n", + "lc_id = 109 # id of the final lightcurve in the data\n", "lc_res = ens.batch(my_mean, \"flux\", single_lc=lc_id)\n", "lc_res.compute()" ] From 12b192434925590ca0864f99eb7a5b054b244df7 Mon Sep 17 00:00:00 2001 From: Wilson Beebe Date: Thu, 28 Mar 2024 17:37:26 -0700 Subject: [PATCH 3/5] Remove extra cells --- docs/tutorials/batch_showcase.ipynb | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/docs/tutorials/batch_showcase.ipynb b/docs/tutorials/batch_showcase.ipynb index 3c860889..047d8ddc 100644 --- a/docs/tutorials/batch_showcase.ipynb +++ b/docs/tutorials/batch_showcase.ipynb @@ -71,15 +71,6 @@ ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ens.source.tail(5)" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -131,18 +122,11 @@ "res1.compute() # Compute to see the result" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "metadata": {}, "source": [ - "By default `Ensemble.batch` will apply your function across all light curves. However with the `single_lc` parameter you can test out your function only on a single object." + "By default `Ensemble.batch` will apply your function across all light curves. However with the `single_lc` parameter you can test out your function on only a single object." ] }, { From ab1038256a752cb2bc2452552c440d73c1211280 Mon Sep 17 00:00:00 2001 From: Wilson Beebe Date: Fri, 5 Apr 2024 12:37:30 -0700 Subject: [PATCH 4/5] Support batch on random lightcurve --- src/tape/ensemble.py | 28 ++++++++++++++++------ tests/tape_tests/test_ensemble.py | 39 +++++++++++++++++++++++++------ 2 files changed, 53 insertions(+), 14 deletions(-) diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 1704e7c2..ea13b494 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -1059,7 +1059,7 @@ def batch( use_map=True, on=None, label="", - single_lc=None, + single_lc=False, **kwargs, ): """Run a function from tape.TimeSeries on the available ids @@ -1108,9 +1108,11 @@ def batch( source or object tables. If not specified, then the id column is used by default. For TAPE and `light-curve` functions this is populated automatically. - single_lc: `int`, optional - If provided, only the lightcurve with the specified id will be - used in batch. Default is None. + single_lc: `boolean` or `int`, optional + If a `boolean` and True, batch will only execute on a single randomly selected + lightcurve. If False, batch will execute across all lightcurves as normal. To + specify a specific lightcurve, the integer value of the lightcurve's id can be + provided. Default is False. label: 'str', optional If provided the ensemble will use this label to track the result dataframe. If not provided, a label of the from "result_{x}" where x @@ -1173,7 +1175,16 @@ def s2n_inter_quartile_range(flux, err): src_to_batch = self.source obj_to_batch = self.object - if single_lc is not None: + # Check if we only want to apply the batch function to a single lightcurve + if not isinstance(single_lc, bool) and not isinstance(single_lc, int): + raise ValueError("single_lc must be a boolean or an integer") + elif single_lc is True: + # Select the ID of a random lightcurve + rand_lc_id = self.select_random_timeseries(id_only=True) + src_to_batch = src_to_batch.loc[rand_lc_id] + obj_to_batch = obj_to_batch.loc[rand_lc_id] + elif single_lc is not False: + # The user provided the id of a specific lightcurve src_to_batch = src_to_batch.loc[single_lc] obj_to_batch = obj_to_batch.loc[single_lc] @@ -2305,7 +2316,7 @@ def _sync_tables(self): self.object.set_dirty(False) return self - def select_random_timeseries(self, seed=None): + def select_random_timeseries(self, seed=None, id_only=False): """Selects a random lightcurve from a random partition of the Ensemble. Parameters @@ -2313,6 +2324,9 @@ def select_random_timeseries(self, seed=None): seed: int, or None Sets a seed to return the same object id on successive runs. `None` by default, in which case a seed is not set for the operation. + id_only: bool, optional + If True, returns only a random object id. If False, returns the + full timeseries for the object. Default is False. Returns ------- @@ -2351,7 +2365,7 @@ def select_random_timeseries(self, seed=None): if i >= len(partitions): raise IndexError("Found no object IDs in the Object Table.") - return self.to_timeseries(lcid) + return self.to_timeseries(lcid) if not id_only else lcid def to_timeseries( self, diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py index ba8505b5..8dee5192 100644 --- a/tests/tape_tests/test_ensemble.py +++ b/tests/tape_tests/test_ensemble.py @@ -2134,8 +2134,13 @@ def test_batch_single_lc(data_fixture, request): """ parquet_ensemble = request.getfixturevalue(data_fixture) + # Perform batch only on this specific lightcurve. lc = 88472935274829959 + # Check that we raise an error if single_lc is neither a bool nor an integer + with pytest.raises(ValueError): + parquet_ensemble.batch(calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc="foo") + lc_res = parquet_ensemble.prune(10).batch( calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=lc ) @@ -2145,6 +2150,22 @@ def test_batch_single_lc(data_fixture, request): full_res = parquet_ensemble.prune(10).batch(calc_stetson_J, use_map=True, on=None, band_to_calc=None) assert full_res.compute().loc[lc].stetsonJ == lc_res.compute().iloc[0].stetsonJ + # Check that when single_lc is True we will perform batch on a random lightcurve and still get only one result. + rand_lc = parquet_ensemble.prune(10).batch( + calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=True + ) + assert len(rand_lc) == 1 + + # Now compare that result to what was computed when doing the full batch result + rand_lc_id = rand_lc.index.compute().values[0] + assert full_res.compute().loc[rand_lc_id].stetsonJ == rand_lc.compute().iloc[0].stetsonJ + + # Check that when single_lc is False we get the same # of results as the full batch + no_lc = parquet_ensemble.prune(10).batch( + calc_stetson_J, use_map=True, on=None, band_to_calc=None, single_lc=False + ) + assert len(full_res) == len(no_lc) + def test_batch_labels(parquet_ensemble): """ @@ -2255,7 +2276,8 @@ def test_batch_with_custom_frame_meta(parquet_ensemble, custom_meta): @pytest.mark.parametrize("repartition", [False, True]) @pytest.mark.parametrize("seed", [None, 42]) -def test_select_random_timeseries(parquet_ensemble, repartition, seed): +@pytest.mark.parametrize("id_only", [True, False]) +def test_select_random_timeseries(parquet_ensemble, repartition, seed, id_only): """Test the behavior of ensemble.select_random_timeseries""" ens = parquet_ensemble @@ -2263,14 +2285,17 @@ def test_select_random_timeseries(parquet_ensemble, repartition, seed): if repartition: ens.object = ens.object.repartition(npartitions=3) - ts = ens.select_random_timeseries(seed=seed) + ts = ens.select_random_timeseries(seed=seed, id_only=id_only) - assert isinstance(ts, TimeSeries) + if not id_only: + assert isinstance(ts, TimeSeries) - if seed == 42 and not repartition: - assert ts.meta["id"] == 88472935274829959 - elif seed == 42 and repartition: - assert ts.meta["id"] == 88480001333818899 + if seed == 42: + expected_id = 88480001333818899 if repartition else 88472935274829959 + if id_only: + assert ts == expected_id + else: + assert ts.meta["id"] == expected_id @pytest.mark.parametrize("all_empty", [False, True]) From c454bbab1a3849623d4ed950dfa5c5bda4df4f62 Mon Sep 17 00:00:00 2001 From: Wilson Beebe Date: Fri, 5 Apr 2024 12:42:22 -0700 Subject: [PATCH 5/5] Update batch showcase for random lightcurve --- docs/tutorials/batch_showcase.ipynb | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/tutorials/batch_showcase.ipynb b/docs/tutorials/batch_showcase.ipynb index f0bb7b2b..42b4d77b 100644 --- a/docs/tutorials/batch_showcase.ipynb +++ b/docs/tutorials/batch_showcase.ipynb @@ -143,6 +143,23 @@ "lc_res.compute()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also run your batch function on a single random lightcurve with `single_lc=True`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rand_lc = ens.batch(my_mean, \"flux\", single_lc=True)\n", + "rand_lc.compute()" + ] + }, { "attachments": {}, "cell_type": "markdown",