Skip to content

Commit

Permalink
Move counting routines to estimator class (#25)
Browse files Browse the repository at this point in the history
* Create estimator wrapper

* Move count routines to estimators

* Increase test tolerance

* Improve code clarity
  • Loading branch information
camposandro authored Jul 18, 2024
1 parent a21b9aa commit 39e15e9
Show file tree
Hide file tree
Showing 15 changed files with 197 additions and 92 deletions.
5 changes: 5 additions & 0 deletions src/corrgi/correlation/angular_correlation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Callable

import gundam.cflibfor as cff
import numpy as np
import pandas as pd
from gundam import gundam
from hipscat.catalog.catalog_info import CatalogInfo
Expand Down Expand Up @@ -64,3 +65,7 @@ def _construct_cross_args(
*args[5:],
]
return args

def get_bdd_counts(self) -> np.ndarray:
"""Returns the boostrap counts for the angular correlation"""
return np.zeros([self.params.nsept, 0])
9 changes: 9 additions & 0 deletions src/corrgi/correlation/correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ def _construct_cross_args(
"""Generate the arguments required for the cross pairing method"""
raise NotImplementedError()

@abstractmethod
def get_bdd_counts(self) -> np.ndarray:
"""Returns the boostrap counts for the correlation"""
raise NotImplementedError()

def transform_counts(self, counts: list[np.ndarray]) -> list[np.ndarray]:
"""Applies final transformations to the correlation counts"""
return counts

@staticmethod
def get_coords(df: pd.DataFrame, catalog_info: CatalogInfo) -> tuple[float, float, float]:
"""Calculate the cartesian coordinates for the points in the partition"""
Expand Down
8 changes: 8 additions & 0 deletions src/corrgi/correlation/projected_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,11 @@ def _construct_cross_args(
*args[7:],
]
return args

def transform_counts(self, counts: list[np.ndarray]) -> list[np.ndarray]:
"""The projected counts need to be transposed before being sent to Fortran"""
return [c.transpose([1, 0]) for c in counts]

def get_bdd_counts(self) -> np.ndarray:
"""Returns the boostrap counts for the projected correlation"""
return np.zeros([self.params.nsepp, self.params.nsepv, 0])
9 changes: 9 additions & 0 deletions src/corrgi/correlation/redshift_correlation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable

import numpy as np
import pandas as pd
from hipscat.catalog.catalog_info import CatalogInfo
from lsdb import Catalog
Expand Down Expand Up @@ -35,3 +36,11 @@ def _construct_cross_args(
right_catalog_info: CatalogInfo,
) -> list:
raise NotImplementedError()

def get_bdd_counts(self) -> np.ndarray:
"""Returns the boostrap counts for the correlation"""
raise NotImplementedError()

def transform_counts(self, counts: list[np.ndarray]) -> list[np.ndarray]:
"""Applies final transformations to the correlation counts"""
raise NotImplementedError()
10 changes: 3 additions & 7 deletions src/corrgi/corrgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from munch import Munch

from corrgi.correlation.correlation import Correlation
from corrgi.dask import compute_autocorrelation_counts
from corrgi.estimators import calculate_natural_estimate
from corrgi.utils import compute_catalog_size
from corrgi.estimators.estimator_factory import get_estimator_for_correlation


def compute_autocorrelation(
Expand All @@ -25,10 +23,8 @@ def compute_autocorrelation(
"""
correlation = corr_type(**kwargs)
correlation.validate([catalog, random])
num_galaxies = compute_catalog_size(catalog)
num_random = compute_catalog_size(random)
counts_dd, counts_rr = compute_autocorrelation_counts(catalog, random, correlation)
return calculate_natural_estimate(counts_dd, counts_rr, num_galaxies, num_random)
estimator = get_estimator_for_correlation(correlation)
return estimator.compute_auto_estimate(catalog, random)


def compute_crosscorrelation(left: Catalog, right: Catalog, random: Catalog, params: Munch) -> np.ndarray:
Expand Down
16 changes: 0 additions & 16 deletions src/corrgi/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,6 @@
from corrgi.utils import join_count_histograms


def compute_autocorrelation_counts(catalog: Catalog, random: Catalog, correlation: Correlation) -> np.ndarray:
"""Computes the auto-correlation counts for a catalog.
Args:
catalog (Catalog): The catalog with galaxy samples.
random (Catalog): The catalog with random samples.
correlation (Correlation): The correlation instance.
Returns:
The histogram counts to calculate the auto-correlation.
"""
counts_dd = perform_auto_counts(catalog, correlation)
counts_rr = perform_auto_counts(random, correlation)
return dask.compute(*[counts_dd, counts_rr])


def perform_auto_counts(catalog: Catalog, *args) -> np.ndarray:
"""Aligns the pixel of a single catalog and performs the pairs counting.
Expand Down
25 changes: 0 additions & 25 deletions src/corrgi/estimators.py

This file was deleted.

Empty file.
66 changes: 66 additions & 0 deletions src/corrgi/estimators/estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable

import numpy as np
from gundam.gundam import tpcf, tpcf_wrp
from lsdb import Catalog

from corrgi.correlation.correlation import Correlation
from corrgi.correlation.projected_correlation import ProjectedCorrelation
from corrgi.utils import compute_catalog_size


class Estimator(ABC):
"""Estimator base class"""

def __init__(self, correlation: Correlation):
self.correlation = correlation

def compute_auto_estimate(self, catalog: Catalog, random: Catalog) -> np.ndarray:
"""Computes the auto-correlation for this estimator.
Args:
catalog (Catalog): The catalog of galaxy samples.
random (Catalog): The catalog of random samples.
Returns:
The statistical estimate of the auto-correlation function, as a numpy array.
"""
num_galaxies = compute_catalog_size(catalog)
num_random = compute_catalog_size(random)
dd, rr, dr = self.compute_autocorrelation_counts(catalog, random)
args = self._get_auto_args(num_galaxies, num_random, dd, rr, dr)
estimate, _ = self._get_auto_subroutine()(*args)
return estimate

@abstractmethod
def compute_autocorrelation_counts(
self, catalog: Catalog, random: Catalog
) -> list[np.ndarray, np.ndarray, np.ndarray | int]:
"""Computes the auto-correlation counts (DD, RR, DR). These counts are
represented as numpy arrays but DR may be 0 if it isn't used (e.g. with
the natural estimator)."""
raise NotImplementedError()

def _get_auto_subroutine(self) -> Callable:
"""Returns the Fortran routine to calculate the correlation estimate"""
return tpcf_wrp if isinstance(self.correlation, ProjectedCorrelation) else tpcf

def _get_auto_args(
self,
num_galaxies: int,
num_random: int,
counts_dd: np.ndarray,
counts_rr: np.ndarray,
counts_dr: np.ndarray,
) -> list:
"""Returns the args for the auto-correlation estimator routine"""
counts_bdd = self.correlation.get_bdd_counts()
args = [num_galaxies, num_random, counts_dd, counts_bdd, counts_rr, counts_dr]
if isinstance(self.correlation, ProjectedCorrelation):
# The projected routines require an additional parameter
args.append(self.correlation.params.dsepv)
args.append(self.correlation.params.estimator)
return args
22 changes: 22 additions & 0 deletions src/corrgi/estimators/estimator_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from corrgi.correlation.correlation import Correlation
from corrgi.estimators.estimator import Estimator
from corrgi.estimators.natural_estimator import NaturalEstimator

estimator_class_for_type: dict[str, type[Estimator]] = {"NAT": NaturalEstimator}


def get_estimator_for_correlation(correlation: Correlation) -> Estimator:
"""Constructs an Estimator instance for the specified correlation.
Args:
correlation (Correlation): The correlation instance. The type of
"estimator" to use is specified in its parameters.
Returns:
An initialized Estimator object wrapping the correlation to compute.
"""
type_to_use = correlation.params.estimator
if type_to_use not in estimator_class_for_type:
raise ValueError(f"Cannot load estimator type: {str(type_to_use)}")
estimator_class = estimator_class_for_type[type_to_use]
return estimator_class(correlation)
31 changes: 31 additions & 0 deletions src/corrgi/estimators/natural_estimator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

import dask
import numpy as np
from lsdb import Catalog

from corrgi.dask import perform_auto_counts
from corrgi.estimators.estimator import Estimator


class NaturalEstimator(Estimator):
"""Natural Estimator (`DD/RR - 1`)"""

def compute_autocorrelation_counts(
self, catalog: Catalog, random: Catalog
) -> list[np.ndarray, np.ndarray, np.ndarray | int]:
"""Computes the auto-correlation counts for the provided catalog.
Args:
catalog (Catalog): A galaxy samples catalog.
random (Catalog): A random samples catalog.
Returns:
The DD, RR and DR counts for the natural estimator.
"""
counts_dd = perform_auto_counts(catalog, self.correlation)
counts_rr = perform_auto_counts(random, self.correlation)
counts_dr = 0 # The natural estimator does not use DR counts
counts_dd_rr = dask.compute(*[counts_dd, counts_rr])
counts_dd_rr = self.correlation.transform_counts(counts_dd_rr)
return [*counts_dd_rr, counts_dr]
5 changes: 5 additions & 0 deletions tests/corrgi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,11 @@ def acf_nat_estimate(acf_expected_results):
return np.load(acf_expected_results / "w_acf_nat.npy")


@pytest.fixture
def pcf_nat_estimate(pcf_expected_results):
return np.load(pcf_expected_results / "w_pcf_nat.npy")


@pytest.fixture
def single_data_partition(data_catalog_dir):
return pd.read_parquet(data_catalog_dir / "Norder=0" / "Dir=0" / "Npix=1.parquet")
Expand Down
44 changes: 14 additions & 30 deletions tests/corrgi/test_acf.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,52 @@
import numpy as np
import numpy.testing as npt
import pytest
from gundam import gundam
import hipscat

from corrgi.correlation.angular_correlation import AngularCorrelation
from corrgi.corrgi import compute_autocorrelation
from corrgi.dask import compute_autocorrelation_counts
from corrgi.estimators import calculate_natural_estimate
from corrgi.estimators.natural_estimator import NaturalEstimator


def test_acf_bins_are_correct(acf_bins_left_edges, acf_bins_right_edges, acf_params):
bins, _ = gundam.makebins(
acf_params.nsept,
acf_params.septmin,
acf_params.dsept,
acf_params.logsept,
)
bins = AngularCorrelation(params=acf_params).make_bins()
all_bins = np.append(acf_bins_left_edges, acf_bins_right_edges[-1])
assert np.array_equal(bins, all_bins)


def test_acf_counts_are_correct(
def test_acf_natural_counts_are_correct(
dask_client, data_catalog, rand_catalog, acf_dd_counts, acf_rr_counts, acf_params
):
ang_corr = AngularCorrelation(params=acf_params)
counts_dd, counts_rr = compute_autocorrelation_counts(
data_catalog, rand_catalog, ang_corr
estimator = NaturalEstimator(AngularCorrelation(params=acf_params))
counts_dd, counts_rr, _ = estimator.compute_autocorrelation_counts(
data_catalog, rand_catalog
)
npt.assert_allclose(counts_dd, acf_dd_counts, rtol=1e-3)
npt.assert_allclose(counts_rr, acf_rr_counts, rtol=2e-3)


def test_acf_natural_estimate_is_correct(
data_catalog_dir, rand_catalog_dir, acf_dd_counts, acf_rr_counts, acf_nat_estimate
dask_client, data_catalog, rand_catalog, acf_nat_estimate, acf_params
):
galaxy_hc_catalog = hipscat.read_from_hipscat(data_catalog_dir)
random_hc_catalog = hipscat.read_from_hipscat(rand_catalog_dir)
num_galaxies = galaxy_hc_catalog.catalog_info.total_rows
num_random = random_hc_catalog.catalog_info.total_rows
estimate = calculate_natural_estimate(
acf_dd_counts, acf_rr_counts, num_galaxies, num_random
)
npt.assert_allclose(acf_nat_estimate, estimate, rtol=2e-3)


def test_acf_e2e(dask_client, data_catalog, rand_catalog, acf_nat_estimate, acf_params):
acf_params.estimator = "NAT"
estimate = compute_autocorrelation(
data_catalog, rand_catalog, AngularCorrelation, params=acf_params
)
npt.assert_allclose(estimate, acf_nat_estimate, rtol=1e-7)


def test_acf_counts_with_weights_are_correct(
def test_acf_natural_counts_with_weights_are_correct(
dask_client,
acf_gals_weight_catalog,
acf_rans_weight_catalog,
acf_dd_counts_with_weights,
acf_rr_counts_with_weights,
acf_params,
):
ang_corr = AngularCorrelation(params=acf_params, use_weights=True)
counts_dd, counts_rr = compute_autocorrelation_counts(
acf_gals_weight_catalog, acf_rans_weight_catalog, ang_corr
estimator = NaturalEstimator(
AngularCorrelation(params=acf_params, use_weights=True)
)
counts_dd, counts_rr, _ = estimator.compute_autocorrelation_counts(
acf_gals_weight_catalog, acf_rans_weight_catalog
)
npt.assert_allclose(counts_dd, acf_dd_counts_with_weights, rtol=1e-3)
npt.assert_allclose(counts_rr, acf_rr_counts_with_weights, rtol=2e-3)
Expand Down
Loading

0 comments on commit 39e15e9

Please sign in to comment.