Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement cross correlation API #15

Merged
merged 15 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions src/corrgi/corrgi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np
from lsdb import Catalog
from munch import Munch

from corrgi.correlation.correlation import Correlation
from corrgi.estimators.estimator_factory import get_estimator_for_correlation
Expand All @@ -12,31 +11,42 @@ def compute_autocorrelation(
"""Calculates the auto-correlation for a catalog.

Args:
catalog (Catalog): The catalog.
random (Catalog): A random samples catalog.
catalog (Catalog): The galaxies catalog (D).
random (Catalog): A random samples catalog (R).
corr_type (type[Correlation]): The corrgi class corresponding to the type of
correlation (AngularCorrelation, RedshiftCorrelation, or ProjectedCorrelation).
**kwargs (dict): The arguments for the creation of the correlation instance.

Returns:
A numpy array with the result of the auto-correlation, using the natural estimator.
A numpy array with the result of the auto-correlation, according to the estimator
provided in the correlation kwargs. More information on how to set up the input parameters
in https://gundam.readthedocs.io/en/latest/introduction.html#set-up-input-parameters.
"""
correlation = corr_type(**kwargs)
correlation.validate([catalog, random])
estimator = get_estimator_for_correlation(correlation)
return estimator.compute_auto_estimate(catalog, random)


def compute_crosscorrelation(left: Catalog, right: Catalog, random: Catalog, params: Munch) -> np.ndarray:
def compute_crosscorrelation(
camposandro marked this conversation as resolved.
Show resolved Hide resolved
left: Catalog, right: Catalog, random: Catalog, corr_type: type[Correlation], **kwargs
) -> np.ndarray:
"""Computes the cross-correlation between two catalogs.

Args:
left (Catalog): Left catalog for the cross-correlation.
right (Catalog): Right catalog for the cross-correlation.
random (Catalog): A random samples catalog.
params (Munch): The parameters dictionary to run gundam with.
left (Catalog): Left catalog for the cross-correlation (D).
right (Catalog): Right catalog for the cross-correlation (C).
random (Catalog): A random samples catalog (R).
corr_type (type[Correlation]): The corrgi class corresponding to the type of
correlation (AngularCorrelation, RedshiftCorrelation, or ProjectedCorrelation).
**kwargs (dict): The arguments for the creation of the correlation instance.

Returns:
A numpy array with the result of the cross-correlation, using the natural estimator.
A numpy array with the result of the cross-correlation, according to the estimator
provided in the correlation kwargs. More information on how to set up the input parameters
in https://gundam.readthedocs.io/en/latest/introduction.html#set-up-input-parameters.
"""
raise NotImplementedError()
correlation = corr_type(**kwargs)
correlation.validate([left, right, random])
estimator = get_estimator_for_correlation(correlation)
return estimator.compute_cross_estimate(left, right, random)
36 changes: 36 additions & 0 deletions src/corrgi/estimators/davis_peebles_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from __future__ import annotations

import dask
import numpy as np
from lsdb import Catalog

from corrgi.dask import perform_cross_counts
from corrgi.estimators.estimator import Estimator


class DavisPeeblesEstimator(Estimator):
"""Davis-Peebles Estimator"""

def compute_autocorrelation_counts(
    self, catalog: Catalog, random: Catalog
) -> list[np.ndarray, np.ndarray, np.ndarray | int]:
    """Computes the auto-correlation counts for the provided catalog.

    This operation is not implemented for this estimator, so the call
    always raises.

    Args:
        catalog (Catalog): A galaxy samples catalog (D).
        random (Catalog): A random samples catalog (R).

    Raises:
        NotImplementedError: Always.
    """
    # Carry a message so callers can tell which estimator/operation is missing
    raise NotImplementedError(
        "Auto-correlation counts are not implemented for the Davis-Peebles estimator"
    )

Check warning on line 18 in src/corrgi/estimators/davis_peebles_estimator.py

View check run for this annotation

Codecov / codecov/patch

src/corrgi/estimators/davis_peebles_estimator.py#L18

Added line #L18 was not covered by tests

def compute_crosscorrelation_counts(
    self, left: Catalog, right: Catalog, random: Catalog
) -> list[np.ndarray, np.ndarray]:
    """Counts the CD and CR pairs used by the Davis-Peebles estimator.

    Args:
        left (Catalog): The left galaxy samples catalog (D).
        right (Catalog): The right galaxy samples catalog (C).
        random (Catalog): The random samples catalog (R).

    Returns:
        The (CD, CR) counts, after being passed through the correlation's
        `transform_counts`.
    """
    # Both counts pair the right catalog (C) with another catalog: first D, then R
    pending_counts = [
        perform_cross_counts(right, other, self.correlation) for other in (left, random)
    ]
    # A single dask.compute call evaluates both counts together
    computed_counts = dask.compute(*pending_counts)
    return self.correlation.transform_counts(computed_counts)
53 changes: 49 additions & 4 deletions src/corrgi/estimators/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable

import numpy as np
from gundam.gundam import tpcf, tpcf_wrp
from gundam.gundam import tpccf, tpccf_wrp, tpcf, tpcf_wrp
from lsdb import Catalog

from corrgi.correlation.correlation import Correlation
Expand All @@ -22,8 +22,8 @@
"""Computes the auto-correlation for this estimator.

Args:
catalog (Catalog): The catalog of galaxy samples.
random (Catalog): The catalog of random samples.
catalog (Catalog): The catalog of galaxy samples (D).
random (Catalog): The catalog of random samples (R).

Returns:
The statistical estimate of the auto-correlation function, as a numpy array.
Expand All @@ -35,6 +35,24 @@
estimate, _ = self._get_auto_subroutine()(*args)
return estimate

def compute_cross_estimate(self, left: Catalog, right: Catalog, random: Catalog) -> np.ndarray:
    """Computes the cross-correlation estimate for this estimator.

    Args:
        left (Catalog): The left catalog of galaxy samples (D).
        right (Catalog): The right catalog of galaxy samples (C).
        random (Catalog): The catalog of random samples (R).

    Returns:
        The statistical estimate of the cross-correlation function, as a numpy array.
    """
    # Only the sizes of the left (D) and random (R) catalogs are handed to the routine
    size_left = compute_catalog_size(left)
    size_random = compute_catalog_size(random)
    counts_cd, counts_cr = self.compute_crosscorrelation_counts(left, right, random)
    routine_args = self._get_cross_args(size_left, size_random, counts_cd, counts_cr)
    routine = self._get_cross_subroutine()
    # The routine returns a pair; only its first element is the estimate
    estimate, _ = routine(*routine_args)
    return estimate

@abstractmethod
def compute_autocorrelation_counts(
self, catalog: Catalog, random: Catalog
Expand All @@ -44,8 +62,15 @@
the natural estimator)."""
raise NotImplementedError()

@abstractmethod
def compute_crosscorrelation_counts(
    self,
    left: Catalog,
    right: Catalog,
    random: Catalog,
) -> list[np.ndarray, np.ndarray]:
    """Computes the pair counts needed for the cross-correlation (CD, CR).

    Args:
        left (Catalog): The left catalog of galaxy samples (D).
        right (Catalog): The right catalog of galaxy samples (C).
        random (Catalog): The catalog of random samples (R).

    Raises:
        NotImplementedError: When this base implementation is invoked.
    """
    raise NotImplementedError()

Check warning on line 70 in src/corrgi/estimators/estimator.py

View check run for this annotation

Codecov / codecov/patch

src/corrgi/estimators/estimator.py#L70

Added line #L70 was not covered by tests

def _get_auto_subroutine(self) -> Callable:
"""Returns the Fortran routine to calculate the correlation estimate"""
"""Returns the Fortran routine to calculate the auto-correlation estimate"""
return tpcf_wrp if isinstance(self.correlation, ProjectedCorrelation) else tpcf

def _get_auto_args(
Expand All @@ -64,3 +89,23 @@
args.append(self.correlation.params.dsepv)
args.append(self.correlation.params.estimator)
return args

def _get_cross_subroutine(self) -> Callable:
    """Selects the gundam routine that calculates the cross-correlation estimate.

    Returns:
        `tpccf_wrp` when the correlation is a `ProjectedCorrelation`,
        `tpccf` otherwise.
    """
    if isinstance(self.correlation, ProjectedCorrelation):
        return tpccf_wrp
    return tpccf

def _get_cross_args(
    self,
    num_galaxies: int,
    num_random: int,
    counts_cd: np.ndarray,
    counts_cr: np.ndarray,
) -> list:
    """Builds the positional arguments for the cross-correlation estimator routine.

    Args:
        num_galaxies (int): The number of galaxy samples.
        num_random (int): The number of random samples.
        counts_cd (np.ndarray): The CD counts.
        counts_cr (np.ndarray): The CR counts.

    Returns:
        The list of arguments, in the order expected by the routine returned
        by `_get_cross_subroutine`.
    """
    bdd_counts = self.correlation.get_bdd_counts()
    leading = [num_galaxies, num_random, counts_cd, bdd_counts, counts_cr]
    params = self.correlation.params
    # The projected routines require an additional parameter before the estimator
    extra = [params.dsepv] if isinstance(self.correlation, ProjectedCorrelation) else []
    return [*leading, *extra, params.estimator]
3 changes: 2 additions & 1 deletion src/corrgi/estimators/estimator_factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from corrgi.correlation.correlation import Correlation
from corrgi.estimators.davis_peebles_estimator import DavisPeeblesEstimator
from corrgi.estimators.estimator import Estimator
from corrgi.estimators.natural_estimator import NaturalEstimator

estimator_class_for_type: dict[str, type[Estimator]] = {"NAT": NaturalEstimator}
estimator_class_for_type: dict[str, type[Estimator]] = {"NAT": NaturalEstimator, "DP": DavisPeeblesEstimator}


def get_estimator_for_correlation(correlation: Correlation) -> Estimator:
Expand Down
14 changes: 10 additions & 4 deletions src/corrgi/estimators/natural_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@


class NaturalEstimator(Estimator):
"""Natural Estimator (`DD/RR - 1`)"""
"""Natural Estimator"""

def compute_autocorrelation_counts(
self, catalog: Catalog, random: Catalog
) -> list[np.ndarray, np.ndarray, np.ndarray | int]:
"""Computes the auto-correlation counts for the provided catalog.
"""Computes the auto-correlation counts for the provided catalog (`DD/RR - 1`).

Args:
catalog (Catalog): A galaxy samples catalog.
random (Catalog): A random samples catalog.
catalog (Catalog): A galaxy samples catalog (D).
random (Catalog): A random samples catalog (R).

Returns:
The DD, RR and DR counts for the natural estimator.
Expand All @@ -29,3 +29,9 @@
counts_dd_rr = dask.compute(*[counts_dd, counts_rr])
counts_dd_rr = self.correlation.transform_counts(counts_dd_rr)
return [*counts_dd_rr, counts_dr]

def compute_crosscorrelation_counts(
    self, left: Catalog, right: Catalog, random: Catalog
) -> list[np.ndarray, np.ndarray, np.ndarray]:
    """Computes the cross-correlation counts for the provided catalogs.

    This operation is not implemented for this estimator, so the call
    always raises.

    Args:
        left (Catalog): The left catalog of galaxy samples (D).
        right (Catalog): The right catalog of galaxy samples (C).
        random (Catalog): The catalog of random samples (R).

    Raises:
        NotImplementedError: Always.
    """
    # Carry a message so callers can tell which estimator/operation is missing
    raise NotImplementedError(
        "Cross-correlation counts are not implemented for the natural estimator"
    )

Check warning on line 37 in src/corrgi/estimators/natural_estimator.py

View check run for this annotation

Codecov / codecov/patch

src/corrgi/estimators/natural_estimator.py#L37

Added line #L37 was not covered by tests
30 changes: 30 additions & 0 deletions tests/corrgi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def pcf_expected_results(test_data_dir):
return test_data_dir / "expected_results" / "pcf"


@pytest.fixture
def pccf_expected_results(test_data_dir):
    """Location of the `pccf` folder inside the expected results directory."""
    expected_root = test_data_dir / "expected_results"
    return expected_root / "pccf"


@pytest.fixture
def data_catalog_dir(hipscat_catalogs_dir):
    """Location of the `DATA` catalog inside the hipscat catalogs directory."""
    catalog_dir = hipscat_catalogs_dir / "DATA"
    return catalog_dir
Expand Down Expand Up @@ -96,6 +101,16 @@ def pcf_gals_weight_catalog(pcf_gals_weight_dir):
return lsdb.read_hipscat(pcf_gals_weight_dir)


@pytest.fixture
def pcf_gals1_weight_dir(hipscat_catalogs_dir):
    """Location of the `pcf_gals1_weight` catalog inside the hipscat catalogs directory."""
    catalog_dir = hipscat_catalogs_dir / "pcf_gals1_weight"
    return catalog_dir


@pytest.fixture
def pcf_gals1_weight_catalog(pcf_gals1_weight_dir):
    """The `pcf_gals1_weight` catalog, read with `lsdb.read_hipscat`."""
    catalog = lsdb.read_hipscat(pcf_gals1_weight_dir)
    return catalog


@pytest.fixture
def pcf_rans_weight_dir(hipscat_catalogs_dir):
    """Location of the `pcf_rans_weight` catalog inside the hipscat catalogs directory."""
    catalog_dir = hipscat_catalogs_dir / "pcf_rans_weight"
    return catalog_dir
Expand Down Expand Up @@ -156,6 +171,16 @@ def pcf_rr_counts_with_weights(pcf_expected_results):
return np.load(pcf_expected_results / "rr_pcf_weight.npy")


@pytest.fixture
def pccf_cd_counts_with_weights(pccf_expected_results):
    """Expected CD counts loaded from `cd_pccf_weight.npy`."""
    counts_file = pccf_expected_results / "cd_pccf_weight.npy"
    return np.load(counts_file)


@pytest.fixture
def pccf_cr_counts_with_weights(pccf_expected_results):
    """Expected CR counts loaded from `cr_pccf_weight.npy`."""
    counts_file = pccf_expected_results / "cr_pccf_weight.npy"
    return np.load(counts_file)


@pytest.fixture
def acf_nat_estimate(acf_expected_results):
    """Expected estimate loaded from `w_acf_nat.npy`."""
    estimate_file = acf_expected_results / "w_acf_nat.npy"
    return np.load(estimate_file)
Expand All @@ -166,6 +191,11 @@ def pcf_nat_estimate(pcf_expected_results):
return np.load(pcf_expected_results / "w_pcf_nat.npy")


@pytest.fixture
def pccf_with_weights_dp_estimate(pccf_expected_results):
    """Expected estimate loaded from `w_pccf_weights_dp.npy`."""
    estimate_file = pccf_expected_results / "w_pccf_weights_dp.npy"
    return np.load(estimate_file)


@pytest.fixture
def single_data_partition(data_catalog_dir):
    """Contents of `Norder=0/Dir=0/Npix=1.parquet` in the data catalog, read with pandas."""
    partition_file = data_catalog_dir / "Norder=0" / "Dir=0" / "Npix=1.parquet"
    return pd.read_parquet(partition_file)
Expand Down
42 changes: 41 additions & 1 deletion tests/corrgi/test_pcf.py → tests/corrgi/test_projected.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest
from corrgi.correlation.projected_correlation import ProjectedCorrelation
from corrgi.corrgi import compute_autocorrelation
from corrgi.corrgi import compute_autocorrelation, compute_crosscorrelation
import numpy.testing as npt

from corrgi.estimators.natural_estimator import NaturalEstimator
from corrgi.estimators.davis_peebles_estimator import DavisPeeblesEstimator


def test_pcf_natural_counts_are_correct(
Expand Down Expand Up @@ -54,3 +55,42 @@ def test_pcf_catalog_has_no_redshift(data_catalog, rand_catalog, pcf_params):
params=pcf_params,
redshift_column="ph_z",
)


def test_pccf_with_weights_davis_peebles_estimate_is_correct(
    dask_client,
    pcf_gals_weight_catalog,
    pcf_gals1_weight_catalog,
    pcf_rans_weight_catalog,
    pccf_with_weights_dp_estimate,
    pcf_params,
):
    """The weighted projected cross-correlation with the "DP" estimator
    matches the stored expected estimate (relative tolerance 1e-2)."""
    pcf_params.estimator = "DP"
    # Order matters: left, right, random
    catalogs = (
        pcf_gals_weight_catalog,
        pcf_gals1_weight_catalog,
        pcf_rans_weight_catalog,
    )
    correlation_kwargs = {"params": pcf_params, "use_weights": True}
    estimate = compute_crosscorrelation(*catalogs, ProjectedCorrelation, **correlation_kwargs)
    npt.assert_allclose(estimate, pccf_with_weights_dp_estimate, rtol=1e-2)


def test_pccf_counts_with_weights_are_correct(
    dask_client,
    pcf_gals_weight_catalog,
    pcf_gals1_weight_catalog,
    pcf_rans_weight_catalog,
    pccf_cd_counts_with_weights,
    pccf_cr_counts_with_weights,
    pcf_params,
):
    """The weighted CD and CR counts from the Davis-Peebles estimator match
    the stored expected counts."""
    correlation = ProjectedCorrelation(params=pcf_params, use_weights=True)
    estimator = DavisPeeblesEstimator(correlation)
    counts_cd, counts_cr = estimator.compute_crosscorrelation_counts(
        pcf_gals_weight_catalog,
        pcf_gals1_weight_catalog,
        pcf_rans_weight_catalog,
    )
    # The two counts are compared with different relative tolerances
    comparisons = (
        (counts_cd, pccf_cd_counts_with_weights, 1e-3),
        (counts_cr, pccf_cr_counts_with_weights, 2e-3),
    )
    for actual, expected, tolerance in comparisons:
        npt.assert_allclose(actual, expected, rtol=tolerance)
Binary file added tests/data/expected_results/pccf/cd_pccf_weight.npy
Binary file not shown.
Binary file added tests/data/expected_results/pccf/cr_pccf_weight.npy
Binary file not shown.
Binary file not shown.
3 changes: 2 additions & 1 deletion tests/data/generate_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@
"\n",
"# With weights\n",
"generate_catalog(\"acf_gals_weight\")\n",
"generate_catalog(\"pcf_gals_weight\")"
"generate_catalog(\"pcf_gals_weight\")\n",
"generate_catalog(\"pcf_gals1_weight\")"
]
},
{
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/hipscat/pcf_gals1_weight/_metadata
Binary file not shown.
8 changes: 8 additions & 0 deletions tests/data/hipscat/pcf_gals1_weight/catalog_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"catalog_name": "pcf_gals1_weight",
"catalog_type": "object",
"total_rows": 35924,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec"
}
8 changes: 8 additions & 0 deletions tests/data/hipscat/pcf_gals1_weight/partition_info.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Norder,Dir,Npix,num_rows
0,0,1,10494
0,0,2,14244
0,0,5,255
0,0,6,9433
0,0,7,1456
0,0,9,2
0,0,10,40
25 changes: 25 additions & 0 deletions tests/data/hipscat/pcf_gals1_weight/provenance_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"catalog_name": "pcf_gals1_weight",
"catalog_type": "object",
"total_rows": 35924,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec",
"version": "0.3.4",
"generation_date": "2024.07.09",
"tool_args": {
"tool_name": "lsdb",
"version": "0.2.5",
"runtime_args": {
"catalog_name": "pcf_gals1_weight",
"output_path": "hipscat/pcf_gals1_weight",
"output_catalog_name": "pcf_gals1_weight",
"catalog_path": "hipscat/pcf_gals1_weight",
"catalog_type": "object",
"total_rows": 35924,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec"
}
}
}
Binary file added tests/data/raw/pcf_gals1_weight.fits
Binary file not shown.