diff --git a/src/corrgi/correlation/angular_correlation.py b/src/corrgi/correlation/angular_correlation.py index 70024fb..3a8319e 100644 --- a/src/corrgi/correlation/angular_correlation.py +++ b/src/corrgi/correlation/angular_correlation.py @@ -11,18 +11,21 @@ class AngularCorrelation(Correlation): """The angular correlation utilities.""" def _get_auto_method(self): - return cff.mod.th_A_wg if self.use_weights else cff.mod.th_A_naiveway + return cff.mod.th_A_wg_naiveway if self.use_weights else cff.mod.th_A_naiveway def _get_cross_method(self) -> Callable: - return cff.mod.th_C_wg if self.use_weights else cff.mod.th_C_naiveway + return cff.mod.th_C_wg_naiveway if self.use_weights else cff.mod.th_C_naiveway def _construct_auto_args(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> list: - return [ - len(df), # number of particles + args = [ + len(df), *self.get_coords(df, catalog_info), # cartesian coordinates self.params.nsept, # number of angular separation bins self.bins, # bins in angular separation [deg] ] + if self.use_weights: + args = [args[0], df[self.weight_column].to_numpy(), *args[1:]] + return args def _construct_cross_args( self, @@ -31,7 +34,7 @@ def _construct_cross_args( left_catalog_info: CatalogInfo, right_catalog_info: CatalogInfo, ) -> list: - return [ + args = [ len(left_df), # number of particles of the left partition *self.get_coords(left_df, left_catalog_info), # X,Y,Z coordinates of particles len(right_df), # number of particles of the right partition @@ -39,3 +42,12 @@ def _construct_cross_args( self.params.nsept, # number of angular separation bins self.bins, # bins in angular separation [deg] ] + if self.use_weights: + args = [ + args[0], + left_df[self.weight_column].to_numpy(), + *args[1:5], + right_df[self.weight_column].to_numpy(), + *args[5:], + ] + return args diff --git a/src/corrgi/correlation/correlation.py b/src/corrgi/correlation/correlation.py index aba41fd..9e27e1c 100644 --- a/src/corrgi/correlation/correlation.py +++ b/src/corrgi/correlation/correlation.py @@ -18,10 +18,12 @@ def __init__( self, bins: np.ndarray, params: Munch, + weight_column: str = "wei", use_weights: bool = False, ): self.bins = bins self.params = params + self.weight_column = weight_column self.use_weights = use_weights def count_auto_pairs(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> np.ndarray: diff --git a/src/corrgi/corrgi.py b/src/corrgi/corrgi.py index 2dc62bd..d8ad563 100644 --- a/src/corrgi/corrgi.py +++ b/src/corrgi/corrgi.py @@ -12,6 +12,8 @@ def compute_autocorrelation( catalog: Catalog, random: Catalog, params: Munch, + weight_column: str = "wei", + use_weights: bool = False, ) -> np.ndarray: """Calculates the auto-correlation for a catalog. @@ -21,13 +23,17 @@ def compute_autocorrelation( catalog (Catalog): The catalog. random (Catalog): A random samples catalog. params (Munch): The parameters dictionary to run gundam with. + weight_column (str): The weights column name. Defaults to "wei". + use_weights (bool): Whether to use weights or not. Defaults to False. Returns: A numpy array with the result of the auto-correlation, using the natural estimator. """ + if use_weights and weight_column not in catalog.columns: + raise ValueError(f"Weight column {weight_column} does not exist") num_galaxies = catalog.hc_structure.catalog_info.total_rows num_random = random.hc_structure.catalog_info.total_rows - counts_dd, counts_rr = compute_autocorrelation_counts(corr_type, catalog, random, params) + counts_dd, counts_rr = compute_autocorrelation_counts(corr_type, catalog, random, params, use_weights) return calculate_natural_estimate(counts_dd, counts_rr, num_galaxies, num_random) diff --git a/src/corrgi/dask.py b/src/corrgi/dask.py index bf87851..97a089a 100644 --- a/src/corrgi/dask.py +++ b/src/corrgi/dask.py @@ -15,7 +15,12 @@ def compute_autocorrelation_counts( - corr_type: type[Correlation], catalog: Catalog, random: Catalog, params: Munch + corr_type: type[Correlation], + catalog: Catalog, + random: Catalog, + params: Munch, + weight_column: str = "wei", + use_weights: bool = False, ) -> np.ndarray: """Computes the auto-correlation counts for a catalog. @@ -24,6 +29,8 @@ def compute_autocorrelation_counts( catalog (Catalog): The catalog with galaxy samples. random (Catalog): The catalog with random samples. params (dict): The gundam parameters for the Fortran subroutine. + weight_column (str): The weights column name. Defaults to "wei". + use_weights (bool): Whether to use weights or not. Defaults to False. Returns: The histogram counts to calculate the auto-correlation. @@ -31,7 +38,7 @@ def compute_autocorrelation_counts( # Calculate the angular separation bins bins, _ = gundam.makebins(params.nsept, params.septmin, params.dsept, params.logsept) # Create correlation with bins and params - correlation = corr_type(bins, params) + correlation = corr_type(bins, params, weight_column, use_weights) # Generate the histograms with counts for each catalog counts_dd = perform_auto_counts(catalog, correlation) counts_rr = perform_auto_counts(random, correlation) diff --git a/tests/corrgi/conftest.py b/tests/corrgi/conftest.py index 7df4758..6a1eca0 100644 --- a/tests/corrgi/conftest.py +++ b/tests/corrgi/conftest.py @@ -50,6 +50,16 @@ def rand_catalog_dir(hipscat_catalogs_dir): return hipscat_catalogs_dir / "RAND" +@pytest.fixture +def acf_gals_weight_dir(hipscat_catalogs_dir): + return hipscat_catalogs_dir / "acf_gals_weight" + + +@pytest.fixture +def acf_rans_weight_dir(hipscat_catalogs_dir): + return hipscat_catalogs_dir / "acf_rans_weight" + + @pytest.fixture def acf_bins_left_edges(acf_expected_results): return np.load(acf_expected_results / "l_binedges_acf.npy") @@ -70,6 +80,16 @@ def acf_rr_counts(acf_expected_results): return np.load(acf_expected_results / "rr_acf.npy") +@pytest.fixture +def acf_dd_counts_with_weights(acf_expected_results): + return np.load(acf_expected_results / "dd_acf_weight.npy") + + +@pytest.fixture +def acf_rr_counts_with_weights(acf_expected_results): + return np.load(acf_expected_results / "rr_acf_weight.npy") + + @pytest.fixture def acf_nat_estimate(acf_expected_results): return np.load(acf_expected_results / "w_acf_nat.npy") diff --git a/tests/corrgi/test_acf.py b/tests/corrgi/test_acf.py index 8f4249c..eee9a5a 100644 --- a/tests/corrgi/test_acf.py +++ b/tests/corrgi/test_acf.py @@ -2,6 +2,7 @@ import lsdb import numpy as np import numpy.testing as npt +import pytest from gundam import gundam from corrgi.correlation.angular_correlation import AngularCorrelation @@ -70,3 +71,41 @@ def test_acf_e2e( AngularCorrelation, galaxy_catalog, random_catalog, autocorr_params ) npt.assert_allclose(estimate, acf_nat_estimate, rtol=1e-7) + + +def test_acf_counts_with_weights_are_correct( + dask_client, + acf_gals_weight_dir, + acf_rans_weight_dir, + acf_dd_counts_with_weights, + acf_rr_counts_with_weights, + autocorr_params, +): + galaxy_catalog = lsdb.read_hipscat(acf_gals_weight_dir) + random_catalog = lsdb.read_hipscat(acf_rans_weight_dir) + assert isinstance(galaxy_catalog, lsdb.Catalog) + assert isinstance(random_catalog, lsdb.Catalog) + counts_dd, counts_rr = compute_autocorrelation_counts( + AngularCorrelation, + galaxy_catalog, + random_catalog, + autocorr_params, + use_weights=True, + ) + npt.assert_allclose(counts_dd, acf_dd_counts_with_weights, rtol=1e-3) + npt.assert_allclose(counts_rr, acf_rr_counts_with_weights, rtol=2e-3) + + +def test_acf_weights_not_provided(data_catalog_dir, rand_catalog_dir, autocorr_params): + galaxy_catalog = lsdb.read_hipscat(data_catalog_dir) + random_catalog = lsdb.read_hipscat(rand_catalog_dir) + assert isinstance(galaxy_catalog, lsdb.Catalog) + assert isinstance(random_catalog, lsdb.Catalog) + with pytest.raises(ValueError, match="does not exist"): + compute_autocorrelation( + AngularCorrelation, + galaxy_catalog, + random_catalog, + autocorr_params, + use_weights=True, + ) diff --git a/tests/data/expected_results/acf/dd_acf_weight.npy b/tests/data/expected_results/acf/dd_acf_weight.npy new file mode 100644 index 0000000..6015037 Binary files /dev/null and b/tests/data/expected_results/acf/dd_acf_weight.npy differ diff --git a/tests/data/expected_results/acf/rr_acf_weight.npy b/tests/data/expected_results/acf/rr_acf_weight.npy new file mode 100644 index 0000000..466e802 Binary files /dev/null and b/tests/data/expected_results/acf/rr_acf_weight.npy differ diff --git a/tests/data/generate_data.ipynb b/tests/data/generate_data.ipynb index e59165d..0cd1441 100644 --- a/tests/data/generate_data.ipynb +++ b/tests/data/generate_data.ipynb @@ -47,8 +47,12 @@ "metadata": {}, "outputs": [], "source": [ + "# Without weights\n", "generate_catalog(\"DATA\")\n", - "generate_catalog(\"DR7-lrg\")" + "generate_catalog(\"DR7-lrg\")\n", + "\n", + "# With weights\n", + "generate_catalog(\"acf_gals_weight\")" ] }, { @@ -64,8 +68,12 @@ "metadata": {}, "outputs": [], "source": [ + "# Without weights\n", "generate_catalog(\"RAND\")\n", - "generate_catalog(\"DR7-lrg-rand\")" + "generate_catalog(\"DR7-lrg-rand\")\n", + "\n", + "# With weights\n", + "generate_catalog(\"acf_rans_weight\")" ] } ], diff --git a/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=1.parquet b/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=1.parquet new file mode 100644 index 0000000..f7dba27 Binary files /dev/null and b/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=1.parquet differ diff --git a/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=10.parquet b/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=10.parquet new file mode 100644 index 0000000..ff4a1d9 Binary files /dev/null and b/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=10.parquet differ diff --git a/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=2.parquet b/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=2.parquet new file mode 100644 index 0000000..d624959 Binary files /dev/null and b/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=2.parquet differ diff --git a/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=5.parquet b/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=5.parquet new file mode 100644 index 0000000..cd3b60a Binary files /dev/null and b/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=5.parquet differ diff --git a/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=6.parquet b/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=6.parquet new file mode 100644 index 0000000..ca4d06e Binary files /dev/null and b/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=6.parquet differ diff --git a/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=7.parquet b/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=7.parquet new file mode 100644 index 0000000..370f758 Binary files /dev/null and b/tests/data/hipscat/acf_gals_weight/Norder=0/Dir=0/Npix=7.parquet differ diff --git a/tests/data/hipscat/acf_gals_weight/_common_metadata b/tests/data/hipscat/acf_gals_weight/_common_metadata new file mode 100644 index 0000000..61d4b22 Binary files /dev/null and b/tests/data/hipscat/acf_gals_weight/_common_metadata differ diff --git a/tests/data/hipscat/acf_gals_weight/_metadata b/tests/data/hipscat/acf_gals_weight/_metadata new file mode 100644 index 0000000..f15f9e2 Binary files /dev/null and b/tests/data/hipscat/acf_gals_weight/_metadata differ diff --git a/tests/data/hipscat/acf_gals_weight/catalog_info.json b/tests/data/hipscat/acf_gals_weight/catalog_info.json new file mode 100644 index 0000000..278bcd8 --- /dev/null +++ b/tests/data/hipscat/acf_gals_weight/catalog_info.json @@ -0,0 +1,8 @@ +{ + "catalog_name": "acf_gals_weight", + "catalog_type": "object", + "total_rows": 8000, + "epoch": "J2000", + "ra_column": "ra", + "dec_column": "dec" +} diff --git a/tests/data/hipscat/acf_gals_weight/partition_info.csv b/tests/data/hipscat/acf_gals_weight/partition_info.csv new file mode 100644 index 0000000..9d1b151 --- /dev/null +++ b/tests/data/hipscat/acf_gals_weight/partition_info.csv @@ -0,0 +1,7 @@ +Norder,Dir,Npix,num_rows +0,0,1,2169 +0,0,2,3372 +0,0,5,71 +0,0,6,2048 +0,0,7,330 +0,0,10,10 diff --git a/tests/data/hipscat/acf_gals_weight/provenance_info.json b/tests/data/hipscat/acf_gals_weight/provenance_info.json new file mode 100644 index 0000000..1b9be44 --- /dev/null +++ b/tests/data/hipscat/acf_gals_weight/provenance_info.json @@ -0,0 +1,25 @@ +{ + "catalog_name": "acf_gals_weight", + "catalog_type": "object", + "total_rows": 8000, + "epoch": "J2000", + "ra_column": "ra", + "dec_column": "dec", + "version": "0.3.4", + "generation_date": "2024.07.02", + "tool_args": { + "tool_name": "lsdb", + "version": "0.2.5", + "runtime_args": { + "catalog_name": "acf_gals_weight", + "output_path": "hipscat/acf_gals_weight", + "output_catalog_name": "acf_gals_weight", + "catalog_path": "hipscat/acf_gals_weight", + "catalog_type": "object", + "total_rows": 8000, + "epoch": "J2000", + "ra_column": "ra", + "dec_column": "dec" + } + } +} diff --git a/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=1.parquet b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=1.parquet new file mode 100644 index 0000000..ca12f91 Binary files /dev/null and b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=1.parquet differ diff --git a/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=10.parquet b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=10.parquet new file mode 100644 index 0000000..aea184c Binary files /dev/null and b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=10.parquet differ diff --git a/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=2.parquet b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=2.parquet new file mode 100644 index 0000000..d714def Binary files /dev/null and b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=2.parquet differ diff --git a/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=5.parquet b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=5.parquet new file mode 100644 index 0000000..ea176ee Binary files /dev/null and b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=5.parquet differ diff --git a/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=6.parquet b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=6.parquet new file mode 100644 index 0000000..7ac6715 Binary files /dev/null and b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=6.parquet differ diff --git a/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=7.parquet b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=7.parquet new file mode 100644 index 0000000..0167a1e Binary files /dev/null and b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=7.parquet differ diff --git a/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=9.parquet b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=9.parquet new file mode 100644 index 0000000..665dc66 Binary files /dev/null and b/tests/data/hipscat/acf_rans_weight/Norder=0/Dir=0/Npix=9.parquet differ diff --git a/tests/data/hipscat/acf_rans_weight/_common_metadata b/tests/data/hipscat/acf_rans_weight/_common_metadata new file mode 100644 index 0000000..61d4b22 Binary files /dev/null and b/tests/data/hipscat/acf_rans_weight/_common_metadata differ diff --git a/tests/data/hipscat/acf_rans_weight/_metadata b/tests/data/hipscat/acf_rans_weight/_metadata new file mode 100644 index 0000000..a079e03 Binary files /dev/null and b/tests/data/hipscat/acf_rans_weight/_metadata differ diff --git a/tests/data/hipscat/acf_rans_weight/catalog_info.json b/tests/data/hipscat/acf_rans_weight/catalog_info.json new file mode 100644 index 0000000..5494a68 --- /dev/null +++ b/tests/data/hipscat/acf_rans_weight/catalog_info.json @@ -0,0 +1,8 @@ +{ + "catalog_name": "acf_rans_weight", + "catalog_type": "object", + "total_rows": 22000, + "epoch": "J2000", + "ra_column": "ra", + "dec_column": "dec" +} diff --git a/tests/data/hipscat/acf_rans_weight/partition_info.csv b/tests/data/hipscat/acf_rans_weight/partition_info.csv new file mode 100644 index 0000000..d6bd6c5 --- /dev/null +++ b/tests/data/hipscat/acf_rans_weight/partition_info.csv @@ -0,0 +1,8 @@ +Norder,Dir,Npix,num_rows +0,0,1,7052 +0,0,2,8024 +0,0,5,188 +0,0,6,5904 +0,0,7,809 +0,0,9,1 +0,0,10,22 diff --git a/tests/data/hipscat/acf_rans_weight/provenance_info.json b/tests/data/hipscat/acf_rans_weight/provenance_info.json new file mode 100644 index 0000000..436a7e5 --- /dev/null +++ b/tests/data/hipscat/acf_rans_weight/provenance_info.json @@ -0,0 +1,25 @@ +{ + "catalog_name": "acf_rans_weight", + "catalog_type": "object", + "total_rows": 22000, + "epoch": "J2000", + "ra_column": "ra", + "dec_column": "dec", + "version": "0.3.4", + "generation_date": "2024.07.02", + "tool_args": { + "tool_name": "lsdb", + "version": "0.2.5", + "runtime_args": { + "catalog_name": "acf_rans_weight", + "output_path": "hipscat/acf_rans_weight", + "output_catalog_name": "acf_rans_weight", + "catalog_path": "hipscat/acf_rans_weight", + "catalog_type": "object", + "total_rows": 22000, + "epoch": "J2000", + "ra_column": "ra", + "dec_column": "dec" + } + } +} diff --git a/tests/data/raw/acf_gals_weight.fits b/tests/data/raw/acf_gals_weight.fits new file mode 100644 index 0000000..b49ab3a Binary files /dev/null and b/tests/data/raw/acf_gals_weight.fits differ diff --git a/tests/data/raw/acf_rans_weight.fits b/tests/data/raw/acf_rans_weight.fits new file mode 100644 index 0000000..c7a812c Binary files /dev/null and b/tests/data/raw/acf_rans_weight.fits differ