diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py
index 8d0b65dc..406a5881 100644
--- a/src/tape/ensemble.py
+++ b/src/tape/ensemble.py
@@ -611,6 +611,7 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True):
                 self._source.groupby([self._id_col])[self._band_col]  # group by each object
                 .value_counts()  # count occurence of each band
                 .to_frame()  # convert series to dataframe
+                .rename(columns={self._band_col: "counts"})  # rename column
                 .reset_index()  # break up the multiindex
                 .categorize(columns=[self._band_col])  # retype the band labels as categories
                 .pivot_table(values=self._band_col, index=self._id_col, columns=self._band_col, aggfunc="sum")
@@ -618,8 +619,8 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True):

             # repartition the result to align with object
             if self._object.known_divisions:
-                self._object.divisions = tuple([None for i in range(self._object.npartitions + 1)])
-                band_counts = band_counts.repartition(npartitions=self._object.npartitions)
+                band_counts.divisions = self._source.divisions
+                band_counts = band_counts.repartition(divisions=self._object.divisions)
             else:
                 band_counts = band_counts.repartition(npartitions=self._object.npartitions)

@@ -636,9 +637,9 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True):
             counts = self._source.groupby([self._id_col])[[self._band_col]].aggregate("count")

             # repartition the result to align with object
-            if self._object.known_divisions:
-                self._object.divisions = tuple([None for i in range(self._object.npartitions + 1)])
-                counts = counts.repartition(npartitions=self._object.npartitions)
+            if self._object.known_divisions and self._source.known_divisions:
+                counts.divisions = self._source.divisions
+                counts = counts.repartition(divisions=self._object.divisions)
             else:
                 counts = counts.repartition(npartitions=self._object.npartitions)

@@ -677,8 +678,7 @@ def prune(self, threshold=50, col_name=None):
             col_name = "nobs_total"

         # Mask on object table
-        mask = self._object[col_name] >= threshold
-        self._object = self._object[mask]
+        self = self.query(f"{col_name} >= {threshold}", table="object")

         self._object_dirty = True  # Object Table is now dirty

@@ -953,6 +953,11 @@ def s2n_inter_quartile_range(flux, err):
                 meta=meta,
             )

+        # Inherit divisions if they are known from source and the resulting index is the id column
+        # A groupby on the index should always return a subset that adheres to the same divisions criteria
+        if self._source.known_divisions and batch.index.name == self._id_col:
+            batch.divisions = self._source.divisions
+
         if compute:
             return batch.compute()
         else:
@@ -1078,6 +1083,12 @@ def from_dask_dataframe(
         elif partition_size:
             self._source = self._source.repartition(partition_size=partition_size)

+        # Check that divisions are established; warn if not.
+        for name, table in [("object", self._object), ("source", self._source)]:
+            if not table.known_divisions:
+                warnings.warn(
+                    f"Divisions for {name} are not set, certain downstream dask operations may fail as a result. We recommend setting the `sort` or `sorted` flags when loading data to establish division information."
+                )
         return self

     def from_hipscat(self, dir, source_subdir="source", object_subdir="object", column_mapper=None, **kwargs):
@@ -1272,7 +1283,10 @@ def from_parquet(
             columns.append(self._provenance_col)

         # Read in the source parquet file(s)
-        source = dd.read_parquet(source_file, index=self._id_col, columns=columns, split_row_groups=True)
+        # Index is set to False so that we can set it with a future set_index call
+        # This has the advantage of letting Dask set partition boundaries based
+        # on the divisions between the sources of different objects.
+        source = dd.read_parquet(source_file, index=False, columns=columns, split_row_groups=True)

         # Generate a provenance column if not provided
         if self._provenance_col is None:
@@ -1282,7 +1296,9 @@ def from_parquet(
         object = None
         if object_file:
             # Read in the object file(s)
-            object = dd.read_parquet(object_file, index=self._id_col, split_row_groups=True)
+            # Index is set to False so that we can set it with a future set_index call
+            # More meaningful for source than for object, but parity seems good here
+            object = dd.read_parquet(object_file, index=False, split_row_groups=True)
         return self.from_dask_dataframe(
             source_frame=source,
             object_frame=object,
@@ -1468,9 +1484,7 @@ def convert_flux_to_mag(self, zero_point, zp_form="mag", out_col_name=None, flux

     def _generate_object_table(self):
         """Generate an empty object table from the source table."""
-        sor_idx = self._source.index.unique()
-        obj_df = pd.DataFrame(index=sor_idx)
-        res = dd.from_pandas(obj_df, npartitions=int(np.ceil(self._source.npartitions / 100)))
+        res = self._source.map_partitions(lambda x: pd.DataFrame(index=x.index.unique()))

         return res

@@ -1504,9 +1518,18 @@ def _sync_tables(self):

         if self._object_dirty:
             # Sync Object to Source; remove any missing objects from source
-            obj_idx = list(self._object.index.compute())
-            self._source = self._source.map_partitions(lambda x: x[x.index.isin(obj_idx)])
-            self._source = self._source.persist()  # persist the source frame
+
+            if self._object.known_divisions and self._source.known_divisions:
+                # Lazily create an empty object table (just the index) for joining
+                empty_obj = self._object.map_partitions(lambda x: pd.DataFrame(index=x.index))
+
+                # Join source onto the empty object table to align
+                self._source = empty_obj.join(self._source)
+            else:
+                warnings.warn("Divisions are not known, syncing using a non-lazy method.")
+                obj_idx = list(self._object.index.compute())
+                self._source = self._source.map_partitions(lambda x: x[x.index.isin(obj_idx)])
+                self._source = self._source.persist()  # persist the source frame

             # Drop Temporary Source Columns on Sync
             if len(self._source_temp):
@@ -1516,10 +1539,19 @@ def _sync_tables(self):

         if self._source_dirty:  # not elif
             if not self.keep_empty_objects:
-                # Sync Source to Object; remove any objects that do not have sources
-                sor_idx = list(self._source.index.unique().compute())
-                self._object = self._object.map_partitions(lambda x: x[x.index.isin(sor_idx)])
-                self._object = self._object.persist()  # persist the object frame
+                if self._object.known_divisions and self._source.known_divisions:
+                    # Lazily create an empty source table (just the unique indices) for joining
+                    empty_src = self._source.map_partitions(lambda x: pd.DataFrame(index=x.index.unique()))
+
+                    # Join object onto the empty unique source table to align
+                    self._object = empty_src.join(self._object)
+
+                else:
+                    warnings.warn("Divisions are not known, syncing using a non-lazy method.")
+                    # Sync Source to Object; remove any objects that do not have sources
+                    sor_idx = list(self._source.index.unique().compute())
+                    self._object = self._object.map_partitions(lambda x: x[x.index.isin(sor_idx)])
+                    self._object = self._object.persist()  # persist the object frame

             # Drop Temporary Object Columns on Sync
             if len(self._object_temp):
@@ -1619,7 +1651,7 @@ def _build_index(self, obj_id, band):
         index = pd.MultiIndex.from_tuples(tuples, names=["object_id", "band", "index"])
         return index

-    def sf2(self, sf_method="basic", argument_container=None, use_map=True):
+    def sf2(self, sf_method="basic", argument_container=None, use_map=True, compute=True):
         """Wrapper interface for calling structurefunction2 on the ensemble

         Parameters
@@ -1661,8 +1693,14 @@
                 self._source.index,
                 argument_container=argument_container,
             )
-            return result
+

         else:
-            result = self.batch(calc_sf2, use_map=use_map, argument_container=argument_container)
+            result = self.batch(
+                calc_sf2, use_map=use_map, argument_container=argument_container, compute=compute
+            )
+
+            # Inherit divisions information if known
+            if self._source.known_divisions and self._object.known_divisions:
+                result.divisions = self._source.divisions
-        return result
+        return result
diff --git a/tests/tape_tests/conftest.py b/tests/tape_tests/conftest.py
index 63eef1e4..06040690 100644
--- a/tests/tape_tests/conftest.py
+++ b/tests/tape_tests/conftest.py
@@ -252,6 +252,25 @@ def parquet_ensemble(dask_client):
     return ens


+# pylint: disable=redefined-outer-name
+@pytest.fixture
+def parquet_ensemble_with_divisions(dask_client):
+    """Create an Ensemble from parquet data, with divisions set."""
+    ens = Ensemble(client=dask_client)
+    ens.from_parquet(
+        "tests/tape_tests/data/source/test_source.parquet",
+        "tests/tape_tests/data/object/test_object.parquet",
+        id_col="ps1_objid",
+        time_col="midPointTai",
+        band_col="filterName",
+        flux_col="psFlux",
+        err_col="psFluxErr",
+        sort=True,
+    )
+
+    return ens
+
+
 # pylint: disable=redefined-outer-name
 @pytest.fixture
 def parquet_ensemble_from_source(dask_client):
diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py
index 1db32e2b..9fbf1d53 100644
--- a/tests/tape_tests/test_ensemble.py
+++ b/tests/tape_tests/test_ensemble.py
@@ -32,6 +32,7 @@ def test_with_client():
     "data_fixture",
     [
         "parquet_ensemble",
+        "parquet_ensemble_with_divisions",
         "parquet_ensemble_without_client",
         "parquet_ensemble_from_source",
         "parquet_ensemble_from_hipscat",
@@ -55,6 +56,11 @@ def test_parquet_construction(data_fixture, request):
     assert parquet_ensemble._source is not None
     assert parquet_ensemble._object is not None

+    # Make sure divisions are set
+    if data_fixture == "parquet_ensemble_with_divisions":
+        assert parquet_ensemble._source.known_divisions
+        assert parquet_ensemble._object.known_divisions
+
     # Check that the data is not empty.
     obj, source = parquet_ensemble.compute()
     assert len(source) == 2000
@@ -566,11 +572,20 @@ def test_update_column_map(dask_client):
     assert cmap_2.map["provenance_col"] == "p"


-def test_sync_tables(parquet_ensemble):
+@pytest.mark.parametrize(
+    "data_fixture",
+    [
+        "parquet_ensemble",
+        "parquet_ensemble_with_divisions",
+    ],
+)
+def test_sync_tables(data_fixture, request):
     """
     Test that _sync_tables works as expected
     """

+    parquet_ensemble = request.getfixturevalue(data_fixture)
+
     assert len(parquet_ensemble.compute("object")) == 15
     assert len(parquet_ensemble.compute("source")) == 2000
@@ -593,6 +608,11 @@
     assert not parquet_ensemble._object_dirty
     assert not parquet_ensemble._source_dirty

+    # Make sure that divisions are preserved
+    if data_fixture == "parquet_ensemble_with_divisions":
+        assert parquet_ensemble._source.known_divisions
+        assert parquet_ensemble._object.known_divisions
+

 def test_lazy_sync_tables(parquet_ensemble):
     """
@@ -691,7 +711,16 @@ def test_temporary_cols(parquet_ensemble):
     assert "f2" not in ens._source.columns


-def test_dropna(parquet_ensemble):
+@pytest.mark.parametrize(
+    "data_fixture",
+    [
+        "parquet_ensemble",
+        "parquet_ensemble_with_divisions",
+    ],
+)
+def test_dropna(data_fixture, request):
+    parquet_ensemble = request.getfixturevalue(data_fixture)
+
     # Try passing in an unrecognized 'table' parameter and verify an exception is thrown
     with pytest.raises(ValueError):
         parquet_ensemble.dropna(table="banana")
@@ -703,7 +732,7 @@
     # Try dropping NaNs from source and confirm nothing is dropped (there are no NaNs).
     parquet_ensemble.dropna(table="source")
-    assert len(parquet_ensemble._source.compute().index) == source_length
+    assert len(parquet_ensemble._source) == source_length

     # Get a valid ID to use and count its occurrences.
     valid_source_id = source_pdf.index.values[1]
@@ -719,11 +748,12 @@
     parquet_ensemble.dropna(table="source")
     assert len(parquet_ensemble._source.compute().index) == source_length - occurrences_source

-    # Sync the table and check that the number of objects decreased.
-    # parquet_ensemble._sync_tables()
+    if data_fixture == "parquet_ensemble_with_divisions":
+        # divisions should be preserved
+        assert parquet_ensemble._source.known_divisions

     # Now test dropping na from 'object' table
-    #
+
     object_pdf = parquet_ensemble._object.compute()
     object_length = len(object_pdf.index)
@@ -731,10 +761,8 @@
     parquet_ensemble.dropna(table="object")
     assert len(parquet_ensemble._object.compute().index) == object_length

-    # get a valid object id and set at least two occurences of that id in the object table
+    # select an id from the object table
     valid_object_id = object_pdf.index.values[1]
-    object_pdf.index.values[0] = valid_object_id
-    occurrences_object = len(object_pdf.loc[valid_object_id].values)

     # Set the nobs_g values for one object to NaN so we can drop it.
     # We do this on the instantiated object (pdf) and convert it back into a
@@ -742,12 +770,16 @@
     object_pdf.loc[valid_object_id, parquet_ensemble._object.columns[0]] = pd.NA
     parquet_ensemble._object = dd.from_pandas(object_pdf, npartitions=1)

-    # Try dropping NaNs from object and confirm that we did.
+    # Try dropping NaNs from object and confirm that we dropped a row
     parquet_ensemble.dropna(table="object")
-    assert len(parquet_ensemble._object.compute().index) == object_length - occurrences_object
+    assert len(parquet_ensemble._object.compute().index) == object_length - 1
+
+    if data_fixture == "parquet_ensemble_with_divisions":
+        # divisions should be preserved
+        assert parquet_ensemble._object.known_divisions

     new_objects_pdf = parquet_ensemble._object.compute()
-    assert len(new_objects_pdf.index) == len(object_pdf.index) - occurrences_object
+    assert len(new_objects_pdf.index) == len(object_pdf.index) - 1

     # Assert the filtered ID is no longer in the objects.
     assert valid_source_id not in new_objects_pdf.index.values
@@ -783,18 +815,25 @@ def test_keep_zeros(parquet_ensemble):
     assert parquet_ensemble._object.npartitions == prev_npartitions


+@pytest.mark.parametrize(
+    "data_fixture",
+    [
+        "parquet_ensemble",
+        "parquet_ensemble_with_divisions",
+    ],
+)
 @pytest.mark.parametrize("by_band", [True, False])
-@pytest.mark.parametrize("know_divisions", [True, False])
-def test_calc_nobs(parquet_ensemble, by_band, know_divisions):
-    ens = parquet_ensemble
-    ens._object = ens._object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1)
+def test_calc_nobs(data_fixture, request, by_band):
+    # Get the Ensemble from a fixture
+    ens = request.getfixturevalue(data_fixture)

-    if know_divisions:
-        ens._object = ens._object.reset_index().set_index(ens._id_col)
-        assert ens._object.known_divisions
+    # Drop the existing nobs columns
+    ens._object = ens._object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1)

+    # Calculate nobs
     ens.calc_nobs(by_band)

+    # Check that things turned out as we expect
     lc = ens._object.loc[88472935274829959].compute()

     if by_band:
@@ -805,16 +844,46 @@
         assert "nobs_total" in ens._object.columns
         assert lc["nobs_total"].values[0] == 499

+    # Make sure that if divisions were set previously, they are preserved
+    if data_fixture == "parquet_ensemble_with_divisions":
+        assert ens._object.known_divisions
+        assert ens._source.known_divisions
+

-def test_prune(parquet_ensemble):
+@pytest.mark.parametrize(
+    "data_fixture",
+    [
+        "parquet_ensemble",
+        "parquet_ensemble_with_divisions",
+    ],
+)
+@pytest.mark.parametrize("generate_nobs", [False, True])
+def test_prune(data_fixture, request, generate_nobs):
     """
     Test that ensemble.prune() appropriately filters the dataframe
     """
+
+    # Get the Ensemble from a fixture
+    parquet_ensemble = request.getfixturevalue(data_fixture)
+
     threshold = 10
-    parquet_ensemble.prune(threshold)
+    # Generate the nobs cols from within prune
+    if generate_nobs:
+        # Drop the existing nobs columns
+        parquet_ensemble._object = parquet_ensemble._object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1)
+        parquet_ensemble.prune(threshold)
+
+    # Use an existing column
+    else:
+        parquet_ensemble.prune(threshold, col_name="nobs_total")

     assert not np.any(parquet_ensemble._object["nobs_total"].values < threshold)

+    # Make sure that if divisions were set previously, they are preserved
+    if data_fixture == "parquet_ensemble_with_divisions":
+        assert parquet_ensemble._source.known_divisions
+        assert parquet_ensemble._object.known_divisions
+

 def test_query(dask_client):
     ens = Ensemble(client=dask_client)
@@ -1156,6 +1225,7 @@ def test_bin_sources_two_days(dask_client):
     "data_fixture",
     [
         "parquet_ensemble",
+        "parquet_ensemble_with_divisions",
         "parquet_ensemble_without_client",
     ],
 )
@@ -1171,9 +1241,15 @@ def test_batch(data_fixture, request, use_map, on):
     result = (
         parquet_ensemble.prune(10)
         .dropna(table="source")
-        .batch(calc_stetson_J, use_map=use_map, on=on, band_to_calc=None)
+        .batch(calc_stetson_J, use_map=use_map, on=on, band_to_calc=None, compute=False)
     )

+    # Make sure that divisions information is propagated if known
+    if parquet_ensemble._source.known_divisions and parquet_ensemble._object.known_divisions:
+        assert result.known_divisions
+
+    result = result.compute()
+
     if on is None:
         assert pytest.approx(result.values[0]["g"], 0.001) == -0.04174282
         assert pytest.approx(result.values[0]["r"], 0.001) == 0.6075282
@@ -1226,25 +1302,41 @@ def test_build_index(dask_client):
     assert result_ids == target


+@pytest.mark.parametrize(
+    "data_fixture",
+    [
+        "parquet_ensemble",
+        "parquet_ensemble_with_divisions",
+    ],
+)
 @pytest.mark.parametrize("method", ["size", "length", "loglength"])
 @pytest.mark.parametrize("combine", [True, False])
 @pytest.mark.parametrize("sthresh", [50, 100])
-def test_sf2(parquet_ensemble, method, combine, sthresh, use_map=False):
+def test_sf2(data_fixture, request, method, combine, sthresh, use_map=False):
     """
     Test calling sf2 from the ensemble
     """

+    parquet_ensemble = request.getfixturevalue(data_fixture)

     arg_container = StructureFunctionArgumentContainer()
     arg_container.bin_method = method
     arg_container.combine = combine
     arg_container.bin_count_target = sthresh

-    res_sf2 = parquet_ensemble.sf2(argument_container=arg_container, use_map=use_map)
+    if not combine:
+        res_sf2 = parquet_ensemble.sf2(argument_container=arg_container, use_map=use_map, compute=False)
+    else:
+        res_sf2 = parquet_ensemble.sf2(argument_container=arg_container, use_map=use_map)
     res_batch = parquet_ensemble.batch(calc_sf2, use_map=use_map, argument_container=arg_container)

+    if parquet_ensemble._source.known_divisions and parquet_ensemble._object.known_divisions:
+        if not combine:
+            assert res_sf2.known_divisions
+
     if combine:
         assert not res_sf2.equals(res_batch)  # output should be different
     else:
+        res_sf2 = res_sf2.compute()
         assert res_sf2.equals(res_batch)  # output should be identical
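A minimal usage sketch of the workflow this patch targets, mirroring the new parquet_ensemble_with_divisions fixture and the updated test_batch: load with sort=True so Dask can establish divisions, then confirm that filtering and a lazy batch call preserve them. The import paths, the Client setup, and the file paths are assumptions for illustration and are not taken from this patch.

# Illustrative sketch only; import locations are assumed, not shown in this patch.
from dask.distributed import Client
from tape import Ensemble
from tape.analysis.stetsonj import calc_stetson_J

ens = Ensemble(client=Client())
ens.from_parquet(
    "tests/tape_tests/data/source/test_source.parquet",
    "tests/tape_tests/data/object/test_object.parquet",
    id_col="ps1_objid",
    time_col="midPointTai",
    band_col="filterName",
    flux_col="psFlux",
    err_col="psFluxErr",
    sort=True,  # sorting on the id column lets Dask compute divisions
)

# With divisions known on both tables, downstream filters and lazy batch
# results should carry that division information instead of dropping it.
assert ens._source.known_divisions and ens._object.known_divisions

result = (
    ens.prune(10)  # prune returns the ensemble, so calls chain
    .dropna(table="source")
    .batch(calc_stetson_J, band_to_calc=None, compute=False)
)
assert result.known_divisions  # divisions propagated to the lazy result
result = result.compute()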