diff --git a/src/corrgi/correlation/angular_correlation.py b/src/corrgi/correlation/angular_correlation.py index a79c052..5b9e54e 100644 --- a/src/corrgi/correlation/angular_correlation.py +++ b/src/corrgi/correlation/angular_correlation.py @@ -1,6 +1,7 @@ from typing import Callable import gundam.cflibfor as cff +import numpy as np import pandas as pd from gundam import gundam from hipscat.catalog.catalog_info import CatalogInfo @@ -64,3 +65,7 @@ def _construct_cross_args( *args[5:], ] return args + + def get_bdd_counts(self) -> np.ndarray: + """Returns the bootstrap counts for the angular correlation""" + return np.zeros([self.params.nsept, 0]) diff --git a/src/corrgi/correlation/correlation.py b/src/corrgi/correlation/correlation.py index ba8a1d9..51988ce 100644 --- a/src/corrgi/correlation/correlation.py +++ b/src/corrgi/correlation/correlation.py @@ -83,6 +83,15 @@ def _construct_cross_args( """Generate the arguments required for the cross pairing method""" raise NotImplementedError() + @abstractmethod + def get_bdd_counts(self) -> np.ndarray: + """Returns the bootstrap counts for the correlation""" + raise NotImplementedError() + + def transform_counts(self, counts: list[np.ndarray]) -> list[np.ndarray]: + """Applies final transformations to the correlation counts""" + return counts + @staticmethod def get_coords(df: pd.DataFrame, catalog_info: CatalogInfo) -> tuple[float, float, float]: """Calculate the cartesian coordinates for the points in the partition""" diff --git a/src/corrgi/correlation/projected_correlation.py b/src/corrgi/correlation/projected_correlation.py index d8ff591..cf6f3cf 100644 --- a/src/corrgi/correlation/projected_correlation.py +++ b/src/corrgi/correlation/projected_correlation.py @@ -94,3 +94,11 @@ def _construct_cross_args( *args[7:], ] return args + + def transform_counts(self, counts: list[np.ndarray]) -> list[np.ndarray]: + """The projected counts need to be transposed before being sent to Fortran""" + return [c.transpose([1, 0]) 
for c in counts] + + def get_bdd_counts(self) -> np.ndarray: + """Returns the bootstrap counts for the projected correlation""" + return np.zeros([self.params.nsepp, self.params.nsepv, 0]) diff --git a/src/corrgi/correlation/redshift_correlation.py b/src/corrgi/correlation/redshift_correlation.py index 80a0bf9..38bfd72 100644 --- a/src/corrgi/correlation/redshift_correlation.py +++ b/src/corrgi/correlation/redshift_correlation.py @@ -1,5 +1,6 @@ from typing import Callable +import numpy as np import pandas as pd from hipscat.catalog.catalog_info import CatalogInfo from lsdb import Catalog @@ -35,3 +36,11 @@ def _construct_cross_args( right_catalog_info: CatalogInfo, ) -> list: raise NotImplementedError() + + def get_bdd_counts(self) -> np.ndarray: + """Returns the bootstrap counts for the correlation""" + raise NotImplementedError() + + def transform_counts(self, counts: list[np.ndarray]) -> list[np.ndarray]: + """Applies final transformations to the correlation counts""" + raise NotImplementedError() diff --git a/src/corrgi/corrgi.py b/src/corrgi/corrgi.py index 1c86a39..49e2d79 100644 --- a/src/corrgi/corrgi.py +++ b/src/corrgi/corrgi.py @@ -3,9 +3,7 @@ from munch import Munch from corrgi.correlation.correlation import Correlation -from corrgi.dask import compute_autocorrelation_counts -from corrgi.estimators import calculate_natural_estimate -from corrgi.utils import compute_catalog_size +from corrgi.estimators.estimator_factory import get_estimator_for_correlation def compute_autocorrelation( @@ -25,10 +23,8 @@ def compute_autocorrelation( """ correlation = corr_type(**kwargs) correlation.validate([catalog, random]) - num_galaxies = compute_catalog_size(catalog) - num_random = compute_catalog_size(random) - counts_dd, counts_rr = compute_autocorrelation_counts(catalog, random, correlation) - return calculate_natural_estimate(counts_dd, counts_rr, num_galaxies, num_random) + estimator = get_estimator_for_correlation(correlation) + return 
estimator.compute_auto_estimate(catalog, random) def compute_crosscorrelation(left: Catalog, right: Catalog, random: Catalog, params: Munch) -> np.ndarray: diff --git a/src/corrgi/dask.py b/src/corrgi/dask.py index 1457ea0..a810b9e 100644 --- a/src/corrgi/dask.py +++ b/src/corrgi/dask.py @@ -12,22 +12,6 @@ from corrgi.utils import join_count_histograms -def compute_autocorrelation_counts(catalog: Catalog, random: Catalog, correlation: Correlation) -> np.ndarray: - """Computes the auto-correlation counts for a catalog. - - Args: - catalog (Catalog): The catalog with galaxy samples. - random (Catalog): The catalog with random samples. - correlation (Correlation): The correlation instance. - - Returns: - The histogram counts to calculate the auto-correlation. - """ - counts_dd = perform_auto_counts(catalog, correlation) - counts_rr = perform_auto_counts(random, correlation) - return dask.compute(*[counts_dd, counts_rr]) - - def perform_auto_counts(catalog: Catalog, *args) -> np.ndarray: """Aligns the pixel of a single catalog and performs the pairs counting. diff --git a/src/corrgi/estimators.py b/src/corrgi/estimators.py deleted file mode 100644 index 7d03b38..0000000 --- a/src/corrgi/estimators.py +++ /dev/null @@ -1,25 +0,0 @@ -import numpy as np -from gundam import tpcf - - -def calculate_natural_estimate( - counts_dd: np.ndarray, - counts_rr: np.ndarray, - num_galaxies: int, - num_random: int, -) -> np.ndarray: - """Calculates the auto-correlation value for the natural estimator. - - Args: - counts_dd (np.ndarray): The counts for the galaxy samples. - counts_rr (np.ndarray): The counts for the random samples. - num_galaxies (int): The number of galaxy samples. - num_random (int): The number of random samples. - - Returns: - The natural correlation function estimate. 
- """ - dr = 0 # We do not use DR counts for the natural estimator - bdd = np.zeros([len(counts_dd), 0]) # We do not compute the bootstrap counts - wth, _ = tpcf(num_galaxies, num_random, counts_dd, bdd, counts_rr, dr, estimator="NAT") - return wth diff --git a/src/corrgi/estimators/__init__.py b/src/corrgi/estimators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/corrgi/estimators/estimator.py b/src/corrgi/estimators/estimator.py new file mode 100644 index 0000000..6ba4df1 --- /dev/null +++ b/src/corrgi/estimators/estimator.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Callable + +import numpy as np +from gundam.gundam import tpcf, tpcf_wrp +from lsdb import Catalog + +from corrgi.correlation.correlation import Correlation +from corrgi.correlation.projected_correlation import ProjectedCorrelation +from corrgi.utils import compute_catalog_size + + +class Estimator(ABC): + """Estimator base class""" + + def __init__(self, correlation: Correlation): + self.correlation = correlation + + def compute_auto_estimate(self, catalog: Catalog, random: Catalog) -> np.ndarray: + """Computes the auto-correlation for this estimator. + + Args: + catalog (Catalog): The catalog of galaxy samples. + random (Catalog): The catalog of random samples. + + Returns: + The statistical estimate of the auto-correlation function, as a numpy array. + """ + num_galaxies = compute_catalog_size(catalog) + num_random = compute_catalog_size(random) + dd, rr, dr = self.compute_autocorrelation_counts(catalog, random) + args = self._get_auto_args(num_galaxies, num_random, dd, rr, dr) + estimate, _ = self._get_auto_subroutine()(*args) + return estimate + + @abstractmethod + def compute_autocorrelation_counts( + self, catalog: Catalog, random: Catalog + ) -> list[np.ndarray, np.ndarray, np.ndarray | int]: + """Computes the auto-correlation counts (DD, RR, DR). 
These counts are + represented as numpy arrays but DR may be 0 if it isn't used (e.g. with + the natural estimator).""" + raise NotImplementedError() + + def _get_auto_subroutine(self) -> Callable: + """Returns the Fortran routine to calculate the correlation estimate""" + return tpcf_wrp if isinstance(self.correlation, ProjectedCorrelation) else tpcf + + def _get_auto_args( + self, + num_galaxies: int, + num_random: int, + counts_dd: np.ndarray, + counts_rr: np.ndarray, + counts_dr: np.ndarray, + ) -> list: + """Returns the args for the auto-correlation estimator routine""" + counts_bdd = self.correlation.get_bdd_counts() + args = [num_galaxies, num_random, counts_dd, counts_bdd, counts_rr, counts_dr] + if isinstance(self.correlation, ProjectedCorrelation): + # The projected routines require an additional parameter + args.append(self.correlation.params.dsepv) + args.append(self.correlation.params.estimator) + return args diff --git a/src/corrgi/estimators/estimator_factory.py b/src/corrgi/estimators/estimator_factory.py new file mode 100644 index 0000000..16b2298 --- /dev/null +++ b/src/corrgi/estimators/estimator_factory.py @@ -0,0 +1,22 @@ +from corrgi.correlation.correlation import Correlation +from corrgi.estimators.estimator import Estimator +from corrgi.estimators.natural_estimator import NaturalEstimator + +estimator_class_for_type: dict[str, type[Estimator]] = {"NAT": NaturalEstimator} + + +def get_estimator_for_correlation(correlation: Correlation) -> Estimator: + """Constructs an Estimator instance for the specified correlation. + + Args: + correlation (Correlation): The correlation instance. The type of + "estimator" to use is specified in its parameters. + + Returns: + An initialized Estimator object wrapping the correlation to compute. 
+ """ + type_to_use = correlation.params.estimator + if type_to_use not in estimator_class_for_type: + raise ValueError(f"Cannot load estimator type: {str(type_to_use)}") + estimator_class = estimator_class_for_type[type_to_use] + return estimator_class(correlation) diff --git a/src/corrgi/estimators/natural_estimator.py b/src/corrgi/estimators/natural_estimator.py new file mode 100644 index 0000000..728e2de --- /dev/null +++ b/src/corrgi/estimators/natural_estimator.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import dask +import numpy as np +from lsdb import Catalog + +from corrgi.dask import perform_auto_counts +from corrgi.estimators.estimator import Estimator + + +class NaturalEstimator(Estimator): + """Natural Estimator (`DD/RR - 1`)""" + + def compute_autocorrelation_counts( + self, catalog: Catalog, random: Catalog + ) -> list[np.ndarray, np.ndarray, np.ndarray | int]: + """Computes the auto-correlation counts for the provided catalog. + + Args: + catalog (Catalog): A galaxy samples catalog. + random (Catalog): A random samples catalog. + + Returns: + The DD, RR and DR counts for the natural estimator. 
+ """ + counts_dd = perform_auto_counts(catalog, self.correlation) + counts_rr = perform_auto_counts(random, self.correlation) + counts_dr = 0 # The natural estimator does not use DR counts + counts_dd_rr = dask.compute(*[counts_dd, counts_rr]) + counts_dd_rr = self.correlation.transform_counts(counts_dd_rr) + return [*counts_dd_rr, counts_dr] diff --git a/tests/corrgi/conftest.py b/tests/corrgi/conftest.py index 3a92811..52ec601 100644 --- a/tests/corrgi/conftest.py +++ b/tests/corrgi/conftest.py @@ -161,6 +161,11 @@ def acf_nat_estimate(acf_expected_results): return np.load(acf_expected_results / "w_acf_nat.npy") +@pytest.fixture +def pcf_nat_estimate(pcf_expected_results): + return np.load(pcf_expected_results / "w_pcf_nat.npy") + + @pytest.fixture def single_data_partition(data_catalog_dir): return pd.read_parquet(data_catalog_dir / "Norder=0" / "Dir=0" / "Npix=1.parquet") diff --git a/tests/corrgi/test_acf.py b/tests/corrgi/test_acf.py index 078e1ff..a5f0481 100644 --- a/tests/corrgi/test_acf.py +++ b/tests/corrgi/test_acf.py @@ -1,58 +1,40 @@ import numpy as np import numpy.testing as npt import pytest -from gundam import gundam -import hipscat from corrgi.correlation.angular_correlation import AngularCorrelation from corrgi.corrgi import compute_autocorrelation -from corrgi.dask import compute_autocorrelation_counts -from corrgi.estimators import calculate_natural_estimate +from corrgi.estimators.natural_estimator import NaturalEstimator def test_acf_bins_are_correct(acf_bins_left_edges, acf_bins_right_edges, acf_params): - bins, _ = gundam.makebins( - acf_params.nsept, - acf_params.septmin, - acf_params.dsept, - acf_params.logsept, - ) + bins = AngularCorrelation(params=acf_params).make_bins() all_bins = np.append(acf_bins_left_edges, acf_bins_right_edges[-1]) assert np.array_equal(bins, all_bins) -def test_acf_counts_are_correct( +def test_acf_natural_counts_are_correct( dask_client, data_catalog, rand_catalog, acf_dd_counts, acf_rr_counts, acf_params ): - 
ang_corr = AngularCorrelation(params=acf_params) - counts_dd, counts_rr = compute_autocorrelation_counts( - data_catalog, rand_catalog, ang_corr + estimator = NaturalEstimator(AngularCorrelation(params=acf_params)) + counts_dd, counts_rr, _ = estimator.compute_autocorrelation_counts( + data_catalog, rand_catalog ) npt.assert_allclose(counts_dd, acf_dd_counts, rtol=1e-3) npt.assert_allclose(counts_rr, acf_rr_counts, rtol=2e-3) def test_acf_natural_estimate_is_correct( - data_catalog_dir, rand_catalog_dir, acf_dd_counts, acf_rr_counts, acf_nat_estimate + dask_client, data_catalog, rand_catalog, acf_nat_estimate, acf_params ): - galaxy_hc_catalog = hipscat.read_from_hipscat(data_catalog_dir) - random_hc_catalog = hipscat.read_from_hipscat(rand_catalog_dir) - num_galaxies = galaxy_hc_catalog.catalog_info.total_rows - num_random = random_hc_catalog.catalog_info.total_rows - estimate = calculate_natural_estimate( - acf_dd_counts, acf_rr_counts, num_galaxies, num_random - ) - npt.assert_allclose(acf_nat_estimate, estimate, rtol=2e-3) - - -def test_acf_e2e(dask_client, data_catalog, rand_catalog, acf_nat_estimate, acf_params): + acf_params.estimator = "NAT" estimate = compute_autocorrelation( data_catalog, rand_catalog, AngularCorrelation, params=acf_params ) npt.assert_allclose(estimate, acf_nat_estimate, rtol=1e-7) -def test_acf_counts_with_weights_are_correct( +def test_acf_natural_counts_with_weights_are_correct( dask_client, acf_gals_weight_catalog, acf_rans_weight_catalog, @@ -60,9 +42,11 @@ def test_acf_counts_with_weights_are_correct( acf_rr_counts_with_weights, acf_params, ): - ang_corr = AngularCorrelation(params=acf_params, use_weights=True) - counts_dd, counts_rr = compute_autocorrelation_counts( - acf_gals_weight_catalog, acf_rans_weight_catalog, ang_corr + estimator = NaturalEstimator( + AngularCorrelation(params=acf_params, use_weights=True) + ) + counts_dd, counts_rr, _ = estimator.compute_autocorrelation_counts( + acf_gals_weight_catalog, 
acf_rans_weight_catalog ) npt.assert_allclose(counts_dd, acf_dd_counts_with_weights, rtol=1e-3) npt.assert_allclose(counts_rr, acf_rr_counts_with_weights, rtol=2e-3) diff --git a/tests/corrgi/test_pcf.py b/tests/corrgi/test_pcf.py index 0924a6f..caf9f9a 100644 --- a/tests/corrgi/test_pcf.py +++ b/tests/corrgi/test_pcf.py @@ -1,20 +1,30 @@ import pytest from corrgi.correlation.projected_correlation import ProjectedCorrelation from corrgi.corrgi import compute_autocorrelation -from corrgi.dask import compute_autocorrelation_counts import numpy.testing as npt +from corrgi.estimators.natural_estimator import NaturalEstimator -def test_pcf_counts_are_correct( + +def test_pcf_natural_counts_are_correct( dask_client, data_catalog, rand_catalog, pcf_dd_counts, pcf_rr_counts, pcf_params ): - proj_corr = ProjectedCorrelation(params=pcf_params) - counts_dd, counts_rr = compute_autocorrelation_counts( - data_catalog, rand_catalog, proj_corr + estimator = NaturalEstimator(ProjectedCorrelation(params=pcf_params)) + counts_dd, counts_rr, _ = estimator.compute_autocorrelation_counts( + data_catalog, rand_catalog + ) + npt.assert_allclose(counts_dd, pcf_dd_counts, rtol=1e-3) + npt.assert_allclose(counts_rr, pcf_rr_counts, rtol=2e-3) + + +def test_pcf_natural_estimate_is_correct( + dask_client, data_catalog, rand_catalog, pcf_nat_estimate, pcf_params +): + pcf_params.estimator = "NAT" + estimate = compute_autocorrelation( + data_catalog, rand_catalog, ProjectedCorrelation, params=pcf_params ) - expected_dd, expected_rr = counts_dd.transpose([1, 0]), counts_rr.transpose([1, 0]) - npt.assert_allclose(expected_dd, pcf_dd_counts, rtol=1e-3) - npt.assert_allclose(expected_rr, pcf_rr_counts, rtol=2e-3) + npt.assert_allclose(estimate, pcf_nat_estimate, rtol=1e-3) def test_pcf_counts_with_weights_are_correct( @@ -25,13 +35,14 @@ def test_pcf_counts_with_weights_are_correct( pcf_rr_counts_with_weights, pcf_params, ): - proj_corr = ProjectedCorrelation(params=pcf_params, use_weights=True) - 
counts_dd, counts_rr = compute_autocorrelation_counts( - pcf_gals_weight_catalog, pcf_rans_weight_catalog, proj_corr + estimator = NaturalEstimator( + ProjectedCorrelation(params=pcf_params, use_weights=True) + ) + counts_dd, counts_rr, _ = estimator.compute_autocorrelation_counts( + pcf_gals_weight_catalog, pcf_rans_weight_catalog ) - expected_dd, expected_rr = counts_dd.transpose([1, 0]), counts_rr.transpose([1, 0]) - npt.assert_allclose(expected_dd, pcf_dd_counts_with_weights, rtol=1e-3) - npt.assert_allclose(expected_rr, pcf_rr_counts_with_weights, rtol=2e-3) + npt.assert_allclose(counts_dd, pcf_dd_counts_with_weights, rtol=1e-3) + npt.assert_allclose(counts_rr, pcf_rr_counts_with_weights, rtol=2e-3) def test_pcf_catalog_has_no_redshift(data_catalog, rand_catalog, pcf_params): diff --git a/tests/data/expected_results/pcf/w_pcf_nat.npy b/tests/data/expected_results/pcf/w_pcf_nat.npy new file mode 100644 index 0000000..efc6a4f Binary files /dev/null and b/tests/data/expected_results/pcf/w_pcf_nat.npy differ