Skip to content

Commit

Permalink
Compute angular correlation with weights (#12)
Browse files Browse the repository at this point in the history
* Add acf with weights

* Improve code coverage

* Simplify arg construction

* Remove nbins attribute

* Add weight column as a parameter
  • Loading branch information
camposandro authored Jul 8, 2024
1 parent 1ff1f9d commit f64c7b7
Show file tree
Hide file tree
Showing 34 changed files with 185 additions and 10 deletions.
22 changes: 17 additions & 5 deletions src/corrgi/correlation/angular_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,21 @@ class AngularCorrelation(Correlation):
"""The angular correlation utilities."""

def _get_auto_method(self):
return cff.mod.th_A_wg if self.use_weights else cff.mod.th_A_naiveway
return cff.mod.th_A_wg_naiveway if self.use_weights else cff.mod.th_A_naiveway

def _get_cross_method(self) -> Callable:
return cff.mod.th_C_wg if self.use_weights else cff.mod.th_C_naiveway
return cff.mod.th_C_wg_naiveway if self.use_weights else cff.mod.th_C_naiveway

def _construct_auto_args(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> list:
return [
len(df), # number of particles
args = [
len(df),
*self.get_coords(df, catalog_info), # cartesian coordinates
self.params.nsept, # number of angular separation bins
self.bins, # bins in angular separation [deg]
]
if self.use_weights:
args = [args[0], df[self.weight_column].to_numpy(), *args[1:]]
return args

def _construct_cross_args(
self,
Expand All @@ -31,11 +34,20 @@ def _construct_cross_args(
left_catalog_info: CatalogInfo,
right_catalog_info: CatalogInfo,
) -> list:
return [
args = [
len(left_df), # number of particles of the left partition
*self.get_coords(left_df, left_catalog_info), # X,Y,Z coordinates of particles
len(right_df), # number of particles of the right partition
*self.get_coords(right_df, right_catalog_info), # X,Y,Z coordinates of particles
self.params.nsept, # number of angular separation bins
self.bins, # bins in angular separation [deg]
]
if self.use_weights:
args = [
args[0],
left_df[self.weight_column].to_numpy(),
*args[1:5],
right_df[self.weight_column].to_numpy(),
*args[5:],
]
return args
2 changes: 2 additions & 0 deletions src/corrgi/correlation/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ def __init__(
self,
bins: np.ndarray,
params: Munch,
weight_column: str = "wei",
use_weights: bool = False,
):
self.bins = bins
self.params = params
self.weight_column = weight_column
self.use_weights = use_weights

def count_auto_pairs(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> np.ndarray:
Expand Down
8 changes: 7 additions & 1 deletion src/corrgi/corrgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ def compute_autocorrelation(
catalog: Catalog,
random: Catalog,
params: Munch,
weight_column: str = "wei",
use_weights: bool = False,
) -> np.ndarray:
"""Calculates the auto-correlation for a catalog.
Expand All @@ -21,13 +23,17 @@ def compute_autocorrelation(
catalog (Catalog): The catalog.
random (Catalog): A random samples catalog.
params (Munch): The parameters dictionary to run gundam with.
weight_column (str): The weights column name. Defaults to "wei".
use_weights (bool): Whether to use weights or not. Defaults to False.
Returns:
A numpy array with the result of the auto-correlation, using the natural estimator.
"""
if use_weights and weight_column not in catalog.columns:
raise ValueError(f"Weight column {weight_column} does not exist")
num_galaxies = catalog.hc_structure.catalog_info.total_rows
num_random = random.hc_structure.catalog_info.total_rows
counts_dd, counts_rr = compute_autocorrelation_counts(corr_type, catalog, random, params)
counts_dd, counts_rr = compute_autocorrelation_counts(corr_type, catalog, random, params, use_weights)
return calculate_natural_estimate(counts_dd, counts_rr, num_galaxies, num_random)


Expand Down
11 changes: 9 additions & 2 deletions src/corrgi/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@


def compute_autocorrelation_counts(
corr_type: type[Correlation], catalog: Catalog, random: Catalog, params: Munch
corr_type: type[Correlation],
catalog: Catalog,
random: Catalog,
params: Munch,
weight_column: str = "wei",
use_weights: bool = False,
) -> np.ndarray:
"""Computes the auto-correlation counts for a catalog.
Expand All @@ -24,14 +29,16 @@ def compute_autocorrelation_counts(
catalog (Catalog): The catalog with galaxy samples.
random (Catalog): The catalog with random samples.
params (dict): The gundam parameters for the Fortran subroutine.
weight_column (str): The weights column name. Defaults to "wei".
use_weights (bool): Whether to use weights or not. Defaults to False.
Returns:
The histogram counts to calculate the auto-correlation.
"""
# Calculate the angular separation bins
bins, _ = gundam.makebins(params.nsept, params.septmin, params.dsept, params.logsept)
# Create correlation with bins and params
correlation = corr_type(bins, params)
correlation = corr_type(bins, params, weight_column, use_weights)
# Generate the histograms with counts for each catalog
counts_dd = perform_auto_counts(catalog, correlation)
counts_rr = perform_auto_counts(random, correlation)
Expand Down
20 changes: 20 additions & 0 deletions tests/corrgi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ def rand_catalog_dir(hipscat_catalogs_dir):
return hipscat_catalogs_dir / "RAND"


@pytest.fixture
def acf_gals_weight_dir(hipscat_catalogs_dir):
return hipscat_catalogs_dir / "acf_gals_weight"


@pytest.fixture
def acf_rans_weight_dir(hipscat_catalogs_dir):
return hipscat_catalogs_dir / "acf_rans_weight"


@pytest.fixture
def acf_bins_left_edges(acf_expected_results):
return np.load(acf_expected_results / "l_binedges_acf.npy")
Expand All @@ -70,6 +80,16 @@ def acf_rr_counts(acf_expected_results):
return np.load(acf_expected_results / "rr_acf.npy")


@pytest.fixture
def acf_dd_counts_with_weights(acf_expected_results):
return np.load(acf_expected_results / "dd_acf_weight.npy")


@pytest.fixture
def acf_rr_counts_with_weights(acf_expected_results):
return np.load(acf_expected_results / "rr_acf_weight.npy")


@pytest.fixture
def acf_nat_estimate(acf_expected_results):
return np.load(acf_expected_results / "w_acf_nat.npy")
Expand Down
39 changes: 39 additions & 0 deletions tests/corrgi/test_acf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import lsdb
import numpy as np
import numpy.testing as npt
import pytest
from gundam import gundam

from corrgi.correlation.angular_correlation import AngularCorrelation
Expand Down Expand Up @@ -70,3 +71,41 @@ def test_acf_e2e(
AngularCorrelation, galaxy_catalog, random_catalog, autocorr_params
)
npt.assert_allclose(estimate, acf_nat_estimate, rtol=1e-7)


def test_acf_counts_with_weights_are_correct(
dask_client,
acf_gals_weight_dir,
acf_rans_weight_dir,
acf_dd_counts_with_weights,
acf_rr_counts_with_weights,
autocorr_params,
):
galaxy_catalog = lsdb.read_hipscat(acf_gals_weight_dir)
random_catalog = lsdb.read_hipscat(acf_rans_weight_dir)
assert isinstance(galaxy_catalog, lsdb.Catalog)
assert isinstance(random_catalog, lsdb.Catalog)
counts_dd, counts_rr = compute_autocorrelation_counts(
AngularCorrelation,
galaxy_catalog,
random_catalog,
autocorr_params,
use_weights=True,
)
npt.assert_allclose(counts_dd, acf_dd_counts_with_weights, rtol=1e-3)
npt.assert_allclose(counts_rr, acf_rr_counts_with_weights, rtol=2e-3)


def test_acf_weights_not_provided(data_catalog_dir, rand_catalog_dir, autocorr_params):
galaxy_catalog = lsdb.read_hipscat(data_catalog_dir)
random_catalog = lsdb.read_hipscat(rand_catalog_dir)
assert isinstance(galaxy_catalog, lsdb.Catalog)
assert isinstance(random_catalog, lsdb.Catalog)
with pytest.raises(ValueError, match="does not exist"):
compute_autocorrelation(
AngularCorrelation,
galaxy_catalog,
random_catalog,
autocorr_params,
use_weights=True,
)
Binary file added tests/data/expected_results/acf/dd_acf_weight.npy
Binary file not shown.
Binary file added tests/data/expected_results/acf/rr_acf_weight.npy
Binary file not shown.
12 changes: 10 additions & 2 deletions tests/data/generate_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@
"metadata": {},
"outputs": [],
"source": [
"# Without weights\n",
"generate_catalog(\"DATA\")\n",
"generate_catalog(\"DR7-lrg\")"
"generate_catalog(\"DR7-lrg\")\n",
"\n",
"# With weights\n",
"generate_catalog(\"acf_gals_weight\")"
]
},
{
Expand All @@ -64,8 +68,12 @@
"metadata": {},
"outputs": [],
"source": [
"# Without weights\n",
"generate_catalog(\"RAND\")\n",
"generate_catalog(\"DR7-lrg-rand\")"
"generate_catalog(\"DR7-lrg-rand\")\n",
"\n",
"# With weights\n",
"generate_catalog(\"acf_rans_weight\")"
]
}
],
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/hipscat/acf_gals_weight/_metadata
Binary file not shown.
8 changes: 8 additions & 0 deletions tests/data/hipscat/acf_gals_weight/catalog_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"catalog_name": "acf_gals_weight",
"catalog_type": "object",
"total_rows": 8000,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec"
}
7 changes: 7 additions & 0 deletions tests/data/hipscat/acf_gals_weight/partition_info.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Norder,Dir,Npix,num_rows
0,0,1,2169
0,0,2,3372
0,0,5,71
0,0,6,2048
0,0,7,330
0,0,10,10
25 changes: 25 additions & 0 deletions tests/data/hipscat/acf_gals_weight/provenance_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"catalog_name": "acf_gals_weight",
"catalog_type": "object",
"total_rows": 8000,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec",
"version": "0.3.4",
"generation_date": "2024.07.02",
"tool_args": {
"tool_name": "lsdb",
"version": "0.2.5",
"runtime_args": {
"catalog_name": "acf_gals_weight",
"output_path": "hipscat/acf_gals_weight",
"output_catalog_name": "acf_gals_weight",
"catalog_path": "hipscat/acf_gals_weight",
"catalog_type": "object",
"total_rows": 8000,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec"
}
}
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/hipscat/acf_rans_weight/_metadata
Binary file not shown.
8 changes: 8 additions & 0 deletions tests/data/hipscat/acf_rans_weight/catalog_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"catalog_name": "acf_rans_weight",
"catalog_type": "object",
"total_rows": 22000,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec"
}
8 changes: 8 additions & 0 deletions tests/data/hipscat/acf_rans_weight/partition_info.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
Norder,Dir,Npix,num_rows
0,0,1,7052
0,0,2,8024
0,0,5,188
0,0,6,5904
0,0,7,809
0,0,9,1
0,0,10,22
25 changes: 25 additions & 0 deletions tests/data/hipscat/acf_rans_weight/provenance_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{
"catalog_name": "acf_rans_weight",
"catalog_type": "object",
"total_rows": 22000,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec",
"version": "0.3.4",
"generation_date": "2024.07.02",
"tool_args": {
"tool_name": "lsdb",
"version": "0.2.5",
"runtime_args": {
"catalog_name": "acf_rans_weight",
"output_path": "hipscat/acf_rans_weight",
"output_catalog_name": "acf_rans_weight",
"catalog_path": "hipscat/acf_rans_weight",
"catalog_type": "object",
"total_rows": 22000,
"epoch": "J2000",
"ra_column": "ra",
"dec_column": "dec"
}
}
}
Binary file added tests/data/raw/acf_gals_weight.fits
Binary file not shown.
Binary file added tests/data/raw/acf_rans_weight.fits
Binary file not shown.

0 comments on commit f64c7b7

Please sign in to comment.