Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement cross correlation API #15

Merged
merged 15 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/corrgi/corrgi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import numpy as np
from lsdb import Catalog
from munch import Munch

from corrgi.correlation.correlation import Correlation
from corrgi.dask import compute_autocorrelation_counts
from corrgi.estimators import calculate_natural_estimate
from corrgi.dask import compute_autocorrelation_counts, compute_crosscorrelation_counts
from corrgi.estimators import calculate_tpccf, calculate_tpcf


def compute_autocorrelation(
Expand All @@ -25,21 +24,30 @@
correlation = corr_type(**kwargs)
correlation.validate([catalog, random])
counts_dd, counts_rr = compute_autocorrelation_counts(catalog, random, correlation)
num_galaxies = catalog.hc_structure.catalog_info.total_rows
num_particles = catalog.hc_structure.catalog_info.total_rows
camposandro marked this conversation as resolved.
Show resolved Hide resolved
num_random = random.hc_structure.catalog_info.total_rows
return calculate_natural_estimate(counts_dd, counts_rr, num_galaxies, num_random)
return calculate_tpcf(counts_dd, counts_rr, num_particles, num_random)


def compute_crosscorrelation(left: Catalog, right: Catalog, random: Catalog, params: Munch) -> np.ndarray:
def compute_crosscorrelation(
camposandro marked this conversation as resolved.
Show resolved Hide resolved
left: Catalog, right: Catalog, random: Catalog, corr_type: type[Correlation], **kwargs
) -> np.ndarray:
"""Computes the cross-correlation between two catalogs.

Args:
left (Catalog): Left catalog for the cross-correlation.
right (Catalog): Right catalog for the cross-correlation.
random (Catalog): A random samples catalog.
params (Munch): The parameters dictionary to run gundam with.
corr_type (type[Correlation]): The corrgi class corresponding to the type of
correlation (AngularCorrelation, RedshiftCorrelation, or ProjectedCorrelation).
**kwargs (dict): The arguments for the creation of the correlation instance.

Returns:
A numpy array with the result of the cross-correlation, using the natural estimator.
"""
raise NotImplementedError()
correlation = corr_type(**kwargs)
correlation.validate([left, right, random])
counts_cd, counts_cr = compute_crosscorrelation_counts(left, right, random, correlation)
num_particles = left.hc_structure.catalog_info.total_rows
num_random = random.hc_structure.catalog_info.total_rows
return calculate_tpccf(counts_cd, counts_cr, num_particles, num_random)

Check warning on line 53 in src/corrgi/corrgi.py

View check run for this annotation

Codecov / codecov/patch

src/corrgi/corrgi.py#L48-L53

Added lines #L48 - L53 were not covered by tests
19 changes: 19 additions & 0 deletions src/corrgi/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ def perform_auto_counts(catalog: Catalog, *args) -> np.ndarray:
return join_count_histograms(all_partials)


def compute_crosscorrelation_counts(
left: Catalog, right: Catalog, random: Catalog, correlation: Correlation
) -> np.ndarray:
"""Computes the cross-correlation counts for a catalog.

Args:
left (Catalog): The left catalog with galaxy samples.
right (Catalog): The right catalog with galaxy samples.
random (Catalog): The catalog with random samples.
correlation (Correlation): The correlation instance.

Returns:
The histogram counts to calculate the cross-correlation.
"""
counts_cd = perform_cross_counts(right, left, correlation)
counts_cr = perform_cross_counts(right, random, correlation)
return dask.compute(*[counts_cd, counts_cr])


def perform_cross_counts(left: Catalog, right: Catalog, *args) -> np.ndarray:
"""Aligns the pixel of two catalogs and performs the pairs counting.

Expand Down
36 changes: 28 additions & 8 deletions src/corrgi/estimators.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,45 @@
import numpy as np
from gundam import tpcf
from gundam import tpccf, tpcf


def calculate_natural_estimate(
counts_dd: np.ndarray,
counts_rr: np.ndarray,
num_galaxies: int,
num_random: int,
def calculate_tpcf(
counts_dd: np.ndarray, counts_rr: np.ndarray, num_particles: int, num_random: int
) -> np.ndarray:
"""Calculates the auto-correlation value for the natural estimator.

Evaluation given data (D), and random (R) samples.

Args:
counts_dd (np.ndarray): The counts for the galaxy samples.
counts_rr (np.ndarray): The counts for the random samples.
num_galaxies (int): The number of galaxy samples.
num_particles (int): The number of particles in data (D).
num_random (int): The number of random samples.

Returns:
The natural correlation function estimate.
"""
dr = 0 # We do not use DR counts for the natural estimator
bdd = np.zeros([len(counts_dd), 0]) # We do not compute the bootstrap counts
wth, _ = tpcf(num_galaxies, num_random, counts_dd, bdd, counts_rr, dr, estimator="NAT")
wth, _ = tpcf(num_particles, num_random, counts_dd, bdd, counts_rr, dr, estimator="NAT")
return wth


def calculate_tpccf(
counts_cd: np.ndarray, counts_cr: np.ndarray, num_particles: int, num_random: int
) -> np.ndarray:
"""Calculates the cross-correlation value for the natural estimator.

Evaluation given data (D), random (R) and cross (C) samples.

Args:
counts_cd (np.ndarray): The counts for data-cross.
counts_cr (np.ndarray): The counts for cross-random.
num_particles (int): The number of particles in data (D).
num_random (int): The number of particles in random samples (R).

Returns:
The natural correlation function estimate.
"""
bcd = np.zeros([len(counts_cd), 0]) # We do not compute the bootstrap counts
wth, _ = tpccf(num_particles, num_random, counts_cd, bcd, counts_cr, estimator="NAT")

Check warning on line 44 in src/corrgi/estimators.py

View check run for this annotation

Codecov / codecov/patch

src/corrgi/estimators.py#L43-L44

Added lines #L43 - L44 were not covered by tests
return wth
30 changes: 30 additions & 0 deletions tests/corrgi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def pcf_expected_results(test_data_dir):
return test_data_dir / "expected_results" / "pcf"


@pytest.fixture
def pccf_expected_results(test_data_dir):
return test_data_dir / "expected_results" / "pccf"


@pytest.fixture
def data_catalog_dir(hipscat_catalogs_dir):
return hipscat_catalogs_dir / "DATA"
Expand Down Expand Up @@ -86,6 +91,21 @@ def acf_rans_weight_catalog(hipscat_catalogs_dir):
return lsdb.read_hipscat(hipscat_catalogs_dir / "acf_rans_weight")


@pytest.fixture
def pcf_gals_weight_catalog(hipscat_catalogs_dir):
return lsdb.read_hipscat(hipscat_catalogs_dir / "pcf_gals_weight")


@pytest.fixture
def pcf_gals1_weight_catalog(hipscat_catalogs_dir):
return lsdb.read_hipscat(hipscat_catalogs_dir / "pcf_gals1_weight")


@pytest.fixture
def pcf_rans_weight_catalog(hipscat_catalogs_dir):
return lsdb.read_hipscat(hipscat_catalogs_dir / "pcf_rans_weight")


@pytest.fixture
def acf_bins_left_edges(acf_expected_results):
return np.load(acf_expected_results / "l_binedges_acf.npy")
Expand Down Expand Up @@ -126,6 +146,16 @@ def pcf_rr_counts(pcf_expected_results):
return np.load(pcf_expected_results / "rr_pcf.npy")


@pytest.fixture
def pccf_cd_counts(pccf_expected_results):
return np.load(pccf_expected_results / "cd_pccf_weight.npy")


@pytest.fixture
def pccf_cr_counts(pccf_expected_results):
return np.load(pccf_expected_results / "cr_pccf_weight.npy")


@pytest.fixture
def acf_nat_estimate(acf_expected_results):
return np.load(acf_expected_results / "w_acf_nat.npy")
Expand Down
6 changes: 2 additions & 4 deletions tests/corrgi/test_acf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from corrgi.correlation.angular_correlation import AngularCorrelation
from corrgi.corrgi import compute_autocorrelation
from corrgi.dask import compute_autocorrelation_counts
from corrgi.estimators import calculate_natural_estimate
from corrgi.estimators import calculate_tpcf


def test_acf_bins_are_correct(acf_bins_left_edges, acf_bins_right_edges, acf_params):
Expand Down Expand Up @@ -48,9 +48,7 @@ def test_acf_natural_estimate_is_correct(
random_hc_catalog = hipscat.read_from_hipscat(rand_catalog_dir)
num_galaxies = galaxy_hc_catalog.catalog_info.total_rows
num_random = random_hc_catalog.catalog_info.total_rows
estimate = calculate_natural_estimate(
acf_dd_counts, acf_rr_counts, num_galaxies, num_random
)
estimate = calculate_tpcf(acf_dd_counts, acf_rr_counts, num_galaxies, num_random)
npt.assert_allclose(acf_nat_estimate, estimate, rtol=2e-3)


Expand Down
23 changes: 22 additions & 1 deletion tests/corrgi/test_pcf.py → tests/corrgi/test_projected.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from corrgi.correlation.projected_correlation import ProjectedCorrelation
from corrgi.corrgi import compute_autocorrelation
from corrgi.dask import compute_autocorrelation_counts
from corrgi.dask import compute_autocorrelation_counts, compute_crosscorrelation_counts
import numpy.testing as npt


Expand All @@ -26,3 +26,24 @@ def test_pcf_catalog_has_no_redshift(data_catalog, rand_catalog, pcf_params):
params=pcf_params,
redshift_column="ph_z",
)


def test_pccf_with_weights(
dask_client,
pcf_gals_weight_catalog,
pcf_gals1_weight_catalog,
pcf_rans_weight_catalog,
pccf_cd_counts,
pccf_cr_counts,
pcf_params,
):
proj_corr = ProjectedCorrelation(params=pcf_params, use_weights=True)
counts_cd, counts_cr = compute_crosscorrelation_counts(
pcf_gals_weight_catalog,
pcf_gals1_weight_catalog,
pcf_rans_weight_catalog,
proj_corr,
)
expected_cd, expected_cr = counts_cd.transpose([1, 0]), counts_cr.transpose([1, 0])
npt.assert_allclose(expected_cd, pccf_cd_counts, rtol=1e-3)
npt.assert_allclose(expected_cr, pccf_cr_counts, rtol=2e-3)
Binary file added tests/data/expected_results/pccf/cd_pccf_weight.npy
Binary file not shown.
Binary file added tests/data/expected_results/pccf/cr_pccf_weight.npy
Binary file not shown.
7 changes: 5 additions & 2 deletions tests/data/generate_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
"generate_catalog(\"DR7-lrg\")\n",
"\n",
"# With weights\n",
"generate_catalog(\"acf_gals_weight\")"
"generate_catalog(\"acf_gals_weight\")\n",
"generate_catalog(\"pcf_gals_weight\")\n",
"generate_catalog(\"pcf_gals1_weight\")"
]
},
{
Expand All @@ -73,7 +75,8 @@
"generate_catalog(\"DR7-lrg-rand\")\n",
"\n",
"# With weights\n",
"generate_catalog(\"acf_rans_weight\")"
"generate_catalog(\"acf_rans_weight\")\n",
"generate_catalog(\"pcf_rans_weight\")"
]
}
],
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/hipscat/pcf_gals1_weight/_metadata
Binary file not shown.
8 changes: 8 additions & 0 deletions tests/data/hipscat/pcf_gals1_weight/catalog_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"catalog_name": "pcf_gals1_weight",
"catalog_type": "object",
"total_rows": 35924,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec"
}
8 changes: 8 additions & 0 deletions tests/data/hipscat/pcf_gals1_weight/partition_info.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Norder,Dir,Npix,num_rows
0,0,1,10494
0,0,2,14244
0,0,5,255
0,0,6,9433
0,0,7,1456
0,0,9,2
0,0,10,40
25 changes: 25 additions & 0 deletions tests/data/hipscat/pcf_gals1_weight/provenance_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"catalog_name": "pcf_gals1_weight",
"catalog_type": "object",
"total_rows": 35924,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec",
"version": "0.3.4",
"generation_date": "2024.07.09",
"tool_args": {
"tool_name": "lsdb",
"version": "0.2.5",
"runtime_args": {
"catalog_name": "pcf_gals1_weight",
"output_path": "hipscat/pcf_gals1_weight",
"output_catalog_name": "pcf_gals1_weight",
"catalog_path": "hipscat/pcf_gals1_weight",
"catalog_type": "object",
"total_rows": 35924,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec"
}
}
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/hipscat/pcf_gals_weight/_metadata
Binary file not shown.
8 changes: 8 additions & 0 deletions tests/data/hipscat/pcf_gals_weight/catalog_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"catalog_name": "pcf_gals_weight",
"catalog_type": "object",
"total_rows": 84383,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec"
}
8 changes: 8 additions & 0 deletions tests/data/hipscat/pcf_gals_weight/partition_info.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Norder,Dir,Npix,num_rows
0,0,1,23743
0,0,2,34965
0,0,5,632
0,0,6,21350
0,0,7,3481
0,0,9,3
0,0,10,209
25 changes: 25 additions & 0 deletions tests/data/hipscat/pcf_gals_weight/provenance_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"catalog_name": "pcf_gals_weight",
"catalog_type": "object",
"total_rows": 84383,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec",
"version": "0.3.4",
"generation_date": "2024.07.09",
"tool_args": {
"tool_name": "lsdb",
"version": "0.2.5",
"runtime_args": {
"catalog_name": "pcf_gals_weight",
"output_path": "hipscat/pcf_gals_weight",
"output_catalog_name": "pcf_gals_weight",
"catalog_path": "hipscat/pcf_gals_weight",
"catalog_type": "object",
"total_rows": 84383,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec"
}
}
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/hipscat/pcf_rans_weight/_metadata
Binary file not shown.
8 changes: 8 additions & 0 deletions tests/data/hipscat/pcf_rans_weight/catalog_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"catalog_name": "pcf_rans_weight",
"catalog_type": "object",
"total_rows": 227336,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec"
}
8 changes: 8 additions & 0 deletions tests/data/hipscat/pcf_rans_weight/partition_info.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Norder,Dir,Npix,num_rows
0,0,1,73222
0,0,2,82167
0,0,5,1742
0,0,6,60914
0,0,7,8953
0,0,9,17
0,0,10,321
25 changes: 25 additions & 0 deletions tests/data/hipscat/pcf_rans_weight/provenance_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"catalog_name": "pcf_rans_weight",
"catalog_type": "object",
"total_rows": 227336,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec",
"version": "0.3.4",
"generation_date": "2024.07.09",
"tool_args": {
"tool_name": "lsdb",
"version": "0.2.5",
"runtime_args": {
"catalog_name": "pcf_rans_weight",
"output_path": "hipscat/pcf_rans_weight",
"output_catalog_name": "pcf_rans_weight",
"catalog_path": "hipscat/pcf_rans_weight",
"catalog_type": "object",
"total_rows": 227336,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec"
}
}
}
Binary file added tests/data/raw/pcf_gals1_weight.fits
Binary file not shown.
Binary file added tests/data/raw/pcf_gals_weight.fits
Binary file not shown.
Binary file added tests/data/raw/pcf_rans_weight.fits
Binary file not shown.