Skip to content

Commit

Permalink
Merge pull request #325 from lincc-frameworks/licu_all_bands
Browse files Browse the repository at this point in the history
Allow ignoring the band with FeatureExtractor
  • Loading branch information
hombit authored Dec 13, 2023
2 parents 64b6f99 + 0c3e069 commit d6dde0a
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/tutorials/working_with_the_ensemble.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,7 @@
"import light_curve as licu\n",
"\n",
"extractor = licu.Extractor(licu.Amplitude(), licu.AndersonDarlingNormal(), licu.StetsonK())\n",
"# band_to_calc=None will ignore the band column and use all sources for each object\n",
"res = ens.batch(extractor, compute=True, band_to_calc=\"g\")\n",
"res"
]
Expand Down
11 changes: 7 additions & 4 deletions src/tape/analysis/feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,10 @@ def __call__(self, time, flux, err, band, *, band_to_calc: str, **kwargs) -> pd.
Errors for "flux"
band : `numpy.ndarray`
Passband names.
band_to_calc : `str`
Name of the passband to calculate features for.
band_to_calc : `str` or `int` or `None`
Name of the passband to calculate features for, usually a string
like "g" or "r", or an integer. If None, then features are
calculated for all sources - band is ignored.
**kwargs : `dict`
Additional keyword arguments to pass to the feature extractor.
Expand All @@ -74,8 +76,9 @@ def __call__(self, time, flux, err, band, *, band_to_calc: str, **kwargs) -> pd.
"""

# Select passband to calculate
band_mask = band == band_to_calc
time, flux, err = (a[band_mask] for a in (time, flux, err))
if band_to_calc is not None:
band_mask = band == band_to_calc
time, flux, err = (a[band_mask] for a in (time, flux, err))

# Sort inputs by time if not already sorted
if not kwargs.get("sorted", False):
Expand Down
41 changes: 38 additions & 3 deletions tests/tape_tests/test_feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_stetsonk():
assert_array_equal(result.dtypes, np.float64)


def test_stetsonk_with_ensemble(dask_client):
def test_multiple_features_with_ensemble(dask_client):
n = 5

object1 = {
Expand All @@ -47,12 +47,47 @@ def test_stetsonk_with_ensemble(dask_client):
cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band")
ens = Ensemble(client=dask_client).from_source_dict(rows, cmap)

stetson_k = licu.Extractor(licu.AndersonDarlingNormal(), licu.InterPercentileRange(0.25), licu.StetsonK())
extractor = licu.Extractor(licu.AndersonDarlingNormal(), licu.InterPercentileRange(0.25), licu.StetsonK())
result = ens.batch(
stetson_k,
extractor,
band_to_calc="g",
)

assert result.shape == (2, 3)
assert_array_equal(result.columns, ["anderson_darling_normal", "inter_percentile_range_25", "stetson_K"])
assert_allclose(result, [[0.114875, 0.625, 0.848528]] * 2, atol=1e-5)


def test_otsu_with_ensemble_all_bands(dask_client):
    """Check that ``band_to_calc=None`` makes the extractor use every source.

    Two objects share identical time, flux and error arrays. The first one is
    observed only in "g", the second one is split evenly between "g" and "r".
    With the band ignored, both objects must yield the same OtsuSplit features.
    """
    n_obs = 10
    # The light curve is built from two equally sized halves.
    assert n_obs % 2 == 0
    half = n_obs // 2

    single_band = {
        "id": np.full(n_obs, 1),
        "time": np.arange(n_obs, dtype=np.float64),
        # Step function: first half at 0, second half at 1.
        "flux": np.concatenate([np.zeros(half), np.ones(half)]),
        "err": np.full(n_obs, 0.1),
        "band": np.full(n_obs, "g"),
    }
    # Same measurements, but a different id and a band that changes halfway.
    two_bands = {
        **single_band,
        "id": np.full(n_obs, 2),
        "band": np.concatenate([np.full(half, "g"), np.full(half, "r")]),
    }
    rows = {
        name: np.concatenate([single_band[name], two_bands[name]])
        for name in single_band
    }

    column_mapper = ColumnMapper(
        id_col="id",
        time_col="time",
        flux_col="flux",
        err_col="err",
        band_col="band",
    )
    ensemble = Ensemble(client=dask_client).from_source_dict(rows, column_mapper)

    features = ensemble.batch(licu.OtsuSplit(), band_to_calc=None)

    expected_columns = [
        "otsu_mean_diff",
        "otsu_std_lower",
        "otsu_std_upper",
        "otsu_lower_to_all_ratio",
    ]
    expected_row = [1.0, 0.0, 0.0, 0.5]

    assert features.shape == (2, 4)
    assert_array_equal(features.columns, expected_columns)
    # Both objects must produce identical features since the band is ignored.
    assert_allclose(features, [expected_row, expected_row], atol=1e-5)

0 comments on commit d6dde0a

Please sign in to comment.