From 0c3e06951defc4f8c1e396898217d72d8f020bbb Mon Sep 17 00:00:00 2001 From: Konstantin Malanchev Date: Wed, 13 Dec 2023 15:12:19 -0500 Subject: [PATCH] Allow to ignore band with FeatureExtractor --- .../tutorials/working_with_the_ensemble.ipynb | 1 + src/tape/analysis/feature_extractor.py | 11 +++-- tests/tape_tests/test_feature_extraction.py | 41 +++++++++++++++++-- 3 files changed, 46 insertions(+), 7 deletions(-) diff --git a/docs/tutorials/working_with_the_ensemble.ipynb b/docs/tutorials/working_with_the_ensemble.ipynb index 2d2eb993..5908fa58 100644 --- a/docs/tutorials/working_with_the_ensemble.ipynb +++ b/docs/tutorials/working_with_the_ensemble.ipynb @@ -750,6 +750,7 @@ "import light_curve as licu\n", "\n", "extractor = licu.Extractor(licu.Amplitude(), licu.AndersonDarlingNormal(), licu.StetsonK())\n", + "# band_to_calc=None will ignore the band column and use all sources for each object\n", "res = ens.batch(extractor, compute=True, band_to_calc=\"g\")\n", "res" ] diff --git a/src/tape/analysis/feature_extractor.py b/src/tape/analysis/feature_extractor.py index 58560d9e..d8841759 100644 --- a/src/tape/analysis/feature_extractor.py +++ b/src/tape/analysis/feature_extractor.py @@ -62,8 +62,10 @@ def __call__(self, time, flux, err, band, *, band_to_calc: str, **kwargs) -> pd. Errors for "flux" band : `numpy.ndarray` Passband names. - band_to_calc : `str` - Name of the passband to calculate features for. + band_to_calc : `str` or `int` or `None` + Name of the passband to calculate features for, usually a string + like "g" or "r", or an integer. If None, then features are + calculated for all sources - band is ignored. **kwargs : `dict` Additional keyword arguments to pass to the feature extractor. @@ -74,8 +76,9 @@ def __call__(self, time, flux, err, band, *, band_to_calc: str, **kwargs) -> pd. """ # Select passband to calculate - band_mask = band == band_to_calc - time, flux, err = (a[band_mask] for a in (time, flux, err)) + if band_to_calc is not None: + band_mask = band == band_to_calc + time, flux, err = (a[band_mask] for a in (time, flux, err)) # Sort inputs by time if not already sorted if not kwargs.get("sorted", False): diff --git a/tests/tape_tests/test_feature_extraction.py b/tests/tape_tests/test_feature_extraction.py index d5ee9db7..6fdd394b 100644 --- a/tests/tape_tests/test_feature_extraction.py +++ b/tests/tape_tests/test_feature_extraction.py @@ -25,7 +25,7 @@ def test_stetsonk(): assert_array_equal(result.dtypes, np.float64) -def test_stetsonk_with_ensemble(dask_client): +def test_multiple_features_with_ensemble(dask_client): n = 5 object1 = { @@ -47,12 +47,47 @@ def test_stetsonk_with_ensemble(dask_client): cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band") ens = Ensemble(client=dask_client).from_source_dict(rows, cmap) - stetson_k = licu.Extractor(licu.AndersonDarlingNormal(), licu.InterPercentileRange(0.25), licu.StetsonK()) + extractor = licu.Extractor(licu.AndersonDarlingNormal(), licu.InterPercentileRange(0.25), licu.StetsonK()) result = ens.batch( - stetson_k, + extractor, band_to_calc="g", ) assert result.shape == (2, 3) assert_array_equal(result.columns, ["anderson_darling_normal", "inter_percentile_range_25", "stetson_K"]) assert_allclose(result, [[0.114875, 0.625, 0.848528]] * 2, atol=1e-5) + + +def test_otsu_with_ensemble_all_bands(dask_client): + n = 10 + assert n % 2 == 0 + + object1 = { + "id": np.full(n, 1), + "time": np.arange(n, dtype=np.float64), + "flux": np.r_[np.zeros(n // 2), np.ones(n // 2)], + "err": np.full(n, 0.1), + "band": np.full(n, "g"), + } + object2 = { + "id": np.full(n, 2), + "time": object1["time"], + "flux": object1["flux"], + "err": object1["err"], + "band": np.r_[np.full(n // 2, "g"), np.full(n // 2, "r")], + } + rows = {column: np.concatenate([object1[column], object2[column]]) for column in object1} + + cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band") + ens = Ensemble(client=dask_client).from_source_dict(rows, cmap) + + result = ens.batch( + licu.OtsuSplit(), + band_to_calc=None, + ) + + assert result.shape == (2, 4) + assert_array_equal( + result.columns, ["otsu_mean_diff", "otsu_std_lower", "otsu_std_upper", "otsu_lower_to_all_ratio"] + ) + assert_allclose(result, [[1.0, 0.0, 0.0, 0.5]] * 2, atol=1e-5)