-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move counting routines to estimator class (#25)
* Create estimator wrapper * Move count routines to estimators * Increase test tolerance * Improve code clarity
- Loading branch information
1 parent
a21b9aa
commit 39e15e9
Showing
15 changed files
with
197 additions
and
92 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Callable | ||
|
||
import numpy as np | ||
from gundam.gundam import tpcf, tpcf_wrp | ||
from lsdb import Catalog | ||
|
||
from corrgi.correlation.correlation import Correlation | ||
from corrgi.correlation.projected_correlation import ProjectedCorrelation | ||
from corrgi.utils import compute_catalog_size | ||
|
||
|
||
class Estimator(ABC): | ||
"""Estimator base class""" | ||
|
||
def __init__(self, correlation: Correlation): | ||
self.correlation = correlation | ||
|
||
def compute_auto_estimate(self, catalog: Catalog, random: Catalog) -> np.ndarray: | ||
"""Computes the auto-correlation for this estimator. | ||
Args: | ||
catalog (Catalog): The catalog of galaxy samples. | ||
random (Catalog): The catalog of random samples. | ||
Returns: | ||
The statistical estimate of the auto-correlation function, as a numpy array. | ||
""" | ||
num_galaxies = compute_catalog_size(catalog) | ||
num_random = compute_catalog_size(random) | ||
dd, rr, dr = self.compute_autocorrelation_counts(catalog, random) | ||
args = self._get_auto_args(num_galaxies, num_random, dd, rr, dr) | ||
estimate, _ = self._get_auto_subroutine()(*args) | ||
return estimate | ||
|
||
@abstractmethod | ||
def compute_autocorrelation_counts( | ||
self, catalog: Catalog, random: Catalog | ||
) -> list[np.ndarray, np.ndarray, np.ndarray | int]: | ||
"""Computes the auto-correlation counts (DD, RR, DR). These counts are | ||
represented as numpy arrays but DR may be 0 if it isn't used (e.g. with | ||
the natural estimator).""" | ||
raise NotImplementedError() | ||
|
||
def _get_auto_subroutine(self) -> Callable: | ||
"""Returns the Fortran routine to calculate the correlation estimate""" | ||
return tpcf_wrp if isinstance(self.correlation, ProjectedCorrelation) else tpcf | ||
|
||
def _get_auto_args( | ||
self, | ||
num_galaxies: int, | ||
num_random: int, | ||
counts_dd: np.ndarray, | ||
counts_rr: np.ndarray, | ||
counts_dr: np.ndarray, | ||
) -> list: | ||
"""Returns the args for the auto-correlation estimator routine""" | ||
counts_bdd = self.correlation.get_bdd_counts() | ||
args = [num_galaxies, num_random, counts_dd, counts_bdd, counts_rr, counts_dr] | ||
if isinstance(self.correlation, ProjectedCorrelation): | ||
# The projected routines require an additional parameter | ||
args.append(self.correlation.params.dsepv) | ||
args.append(self.correlation.params.estimator) | ||
return args |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from corrgi.correlation.correlation import Correlation | ||
from corrgi.estimators.estimator import Estimator | ||
from corrgi.estimators.natural_estimator import NaturalEstimator | ||
|
||
estimator_class_for_type: dict[str, type[Estimator]] = {"NAT": NaturalEstimator} | ||
|
||
|
||
def get_estimator_for_correlation(correlation: Correlation) -> Estimator: | ||
"""Constructs an Estimator instance for the specified correlation. | ||
Args: | ||
correlation (Correlation): The correlation instance. The type of | ||
"estimator" to use is specified in its parameters. | ||
Returns: | ||
An initialized Estimator object wrapping the correlation to compute. | ||
""" | ||
type_to_use = correlation.params.estimator | ||
if type_to_use not in estimator_class_for_type: | ||
raise ValueError(f"Cannot load estimator type: {str(type_to_use)}") | ||
estimator_class = estimator_class_for_type[type_to_use] | ||
return estimator_class(correlation) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from __future__ import annotations | ||
|
||
import dask | ||
import numpy as np | ||
from lsdb import Catalog | ||
|
||
from corrgi.dask import perform_auto_counts | ||
from corrgi.estimators.estimator import Estimator | ||
|
||
|
||
class NaturalEstimator(Estimator): | ||
"""Natural Estimator (`DD/RR - 1`)""" | ||
|
||
def compute_autocorrelation_counts( | ||
self, catalog: Catalog, random: Catalog | ||
) -> list[np.ndarray, np.ndarray, np.ndarray | int]: | ||
"""Computes the auto-correlation counts for the provided catalog. | ||
Args: | ||
catalog (Catalog): A galaxy samples catalog. | ||
random (Catalog): A random samples catalog. | ||
Returns: | ||
The DD, RR and DR counts for the natural estimator. | ||
""" | ||
counts_dd = perform_auto_counts(catalog, self.correlation) | ||
counts_rr = perform_auto_counts(random, self.correlation) | ||
counts_dr = 0 # The natural estimator does not use DR counts | ||
counts_dd_rr = dask.compute(*[counts_dd, counts_rr]) | ||
counts_dd_rr = self.correlation.transform_counts(counts_dd_rr) | ||
return [*counts_dd_rr, counts_dr] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.