Skip to content

Commit

Permalink
batch with divisions
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed Nov 6, 2023
1 parent 668d34b commit 71d3543
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
16 changes: 11 additions & 5 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,10 +618,8 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True):

# repartition the result to align with object
if self._object.known_divisions:
# self._object.divisions = tuple([None for i in range(self._object.npartitions + 1)])
band_counts.divisions = self._source.divisions
band_counts = band_counts.repartition(divisions=self._object.divisions)
# band_counts = band_counts.repartition(npartitions=self._object.npartitions)
else:
band_counts = band_counts.repartition(npartitions=self._object.npartitions)

Expand All @@ -639,7 +637,6 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True):

# repartition the result to align with object
if self._object.known_divisions and self._source.known_divisions:
# self._object.divisions = tuple([None for i in range(self._object.npartitions + 1)])
counts.divisions = self._source.divisions
counts = counts.repartition(divisions=self._object.divisions)
else:
Expand Down Expand Up @@ -957,6 +954,11 @@ def s2n_inter_quartile_range(flux, err):
meta=meta,
)

# Inherit divisions if known from source and the resulting index is the id
# Groupby on index should always return a subset that adheres to the same divisions criteria
if self._source.known_divisions and batch.index.name == self._id_col:
batch.divisions = self._source.divisions

if compute:
return batch.compute()
else:
Expand Down Expand Up @@ -1687,8 +1689,12 @@ def sf2(self, sf_method="basic", argument_container=None, use_map=True):
self._source.index,
argument_container=argument_container,
)
return result

else:
result = self.batch(calc_sf2, use_map=use_map, argument_container=argument_container)

return result
# Inherit divisions information if known
if self._source.known_divisions and self._object.known_divisions:
result.divisions = self._source.divisions

return result
22 changes: 20 additions & 2 deletions tests/tape_tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,7 @@ def test_bin_sources_two_days(dask_client):
"data_fixture",
[
"parquet_ensemble",
"parquet_ensemble_with_divisions",
"parquet_ensemble_without_client",
],
)
Expand All @@ -1234,9 +1235,15 @@ def test_batch(data_fixture, request, use_map, on):
result = (
parquet_ensemble.prune(10)
.dropna(table="source")
.batch(calc_stetson_J, use_map=use_map, on=on, band_to_calc=None)
.batch(calc_stetson_J, use_map=use_map, on=on, band_to_calc=None, compute=False)
)

# Make sure that divisions information is propagated if known
if parquet_ensemble._source.known_divisions and parquet_ensemble._object.known_divisions:
assert result.known_divisions

result = result.compute()

if on is None:
assert pytest.approx(result.values[0]["g"], 0.001) == -0.04174282
assert pytest.approx(result.values[0]["r"], 0.001) == 0.6075282
Expand Down Expand Up @@ -1289,13 +1296,21 @@ def test_build_index(dask_client):
assert result_ids == target


@pytest.mark.parametrize(
"data_fixture",
[
"parquet_ensemble",
"parquet_ensemble_with_divisions",
],
)
@pytest.mark.parametrize("method", ["size", "length", "loglength"])
@pytest.mark.parametrize("combine", [True, False])
@pytest.mark.parametrize("sthresh", [50, 100])
def test_sf2(parquet_ensemble, method, combine, sthresh, use_map=False):
def test_sf2(data_fixture, request, method, combine, sthresh, use_map=False):
"""
Test calling sf2 from the ensemble
"""
parquet_ensemble = request.getfixturevalue(data_fixture)

arg_container = StructureFunctionArgumentContainer()
arg_container.bin_method = method
Expand All @@ -1305,6 +1320,9 @@ def test_sf2(parquet_ensemble, method, combine, sthresh, use_map=False):
res_sf2 = parquet_ensemble.sf2(argument_container=arg_container, use_map=use_map)
res_batch = parquet_ensemble.batch(calc_sf2, use_map=use_map, argument_container=arg_container)

if parquet_ensemble._source.known_divisions and parquet_ensemble._object.known_divisions:
assert res_sf2.known_divisions

if combine:
assert not res_sf2.equals(res_batch) # output should be different
else:
Expand Down

0 comments on commit 71d3543

Please sign in to comment.