Create classes for type of correlation (#11)
* Create classes for type of correlation

* Review PR comments

* Fix method signatures

* Raise exceptions on counting
camposandro authored Jul 2, 2024
1 parent 83224aa commit 1ff1f9d
Showing 83 changed files with 249 additions and 135 deletions.
Empty file.
41 changes: 41 additions & 0 deletions src/corrgi/correlation/angular_correlation.py
@@ -0,0 +1,41 @@
from typing import Callable

import gundam.cflibfor as cff
import pandas as pd
from hipscat.catalog.catalog_info import CatalogInfo

from corrgi.correlation.correlation import Correlation


class AngularCorrelation(Correlation):
"""The angular correlation utilities."""

def _get_auto_method(self):
return cff.mod.th_A_wg if self.use_weights else cff.mod.th_A_naiveway

def _get_cross_method(self) -> Callable:
return cff.mod.th_C_wg if self.use_weights else cff.mod.th_C_naiveway

def _construct_auto_args(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> list:
return [
len(df), # number of particles
*self.get_coords(df, catalog_info), # cartesian coordinates
self.params.nsept, # number of angular separation bins
self.bins, # bins in angular separation [deg]
]

def _construct_cross_args(
self,
left_df: pd.DataFrame,
right_df: pd.DataFrame,
left_catalog_info: CatalogInfo,
right_catalog_info: CatalogInfo,
) -> list:
return [
len(left_df), # number of particles of the left partition
*self.get_coords(left_df, left_catalog_info), # X,Y,Z coordinates of particles
len(right_df), # number of particles of the right partition
*self.get_coords(right_df, right_catalog_info), # X,Y,Z coordinates of particles
self.params.nsept, # number of angular separation bins
self.bins, # bins in angular separation [deg]
]
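
The class above plugs gundam's angular pair counters (th_A_* for auto pairs, th_C_* for cross pairs) into the shared Correlation interface defined in the next file. A usage sketch for counting pairs inside a single partition (illustrative only, not part of this commit; `df`, `catalog_info`, and the gundam `params` Munch are assumed to already exist):

import gundam

from corrgi.correlation.angular_correlation import AngularCorrelation

# Separation bins, built the same way compute_autocorrelation_counts does below
bins, _ = gundam.makebins(params.nsept, params.septmin, params.dsept, params.logsept)

# use_weights defaults to False, selecting the naive (unweighted) counters
correlation = AngularCorrelation(bins, params)
counts = correlation.count_auto_pairs(df, catalog_info)  # histogram over the angular bins
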
74 changes: 74 additions & 0 deletions src/corrgi/correlation/correlation.py
@@ -0,0 +1,74 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable

import numpy as np
import pandas as pd
from hipscat.catalog.catalog_info import CatalogInfo
from munch import Munch

from corrgi.utils import project_coordinates


class Correlation(ABC):
"""Correlation base class."""

def __init__(
self,
bins: np.ndarray,
params: Munch,
use_weights: bool = False,
):
self.bins = bins
self.params = params
self.use_weights = use_weights

def count_auto_pairs(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> np.ndarray:
"""Computes the counts for pairs of the same partition"""
args = self._construct_auto_args(df, catalog_info)
return self._get_auto_method()(*args)

def count_cross_pairs(
self,
left_df: pd.DataFrame,
right_df: pd.DataFrame,
left_catalog_info: CatalogInfo,
right_catalog_info: CatalogInfo,
) -> np.ndarray:
"""Computes the counts for pairs of different partitions"""
args = self._construct_cross_args(left_df, right_df, left_catalog_info, right_catalog_info)
return self._get_cross_method()(*args)

@abstractmethod
def _get_auto_method(self) -> Callable:
"""Reference to Fortran routine to be called on auto pairing"""
raise NotImplementedError()

@abstractmethod
def _construct_auto_args(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> list:
"""Generate the arguments required for the auto pairing method"""
raise NotImplementedError()

@abstractmethod
def _get_cross_method(self) -> Callable:
"""Reference to Fortran routine to be called on cross pairing"""
raise NotImplementedError()

@abstractmethod
def _construct_cross_args(
self,
left_df: pd.DataFrame,
right_df: pd.DataFrame,
left_catalog_info: CatalogInfo,
right_catalog_info: CatalogInfo,
) -> list:
"""Generate the arguments required for the cross pairing method"""
raise NotImplementedError()

@staticmethod
def get_coords(df: pd.DataFrame, catalog_info: CatalogInfo) -> tuple[float, float, float]:
"""Calculate the cartesian coordinates for the points in the partition"""
return project_coordinates(
ra=df[catalog_info.ra_column].to_numpy(), dec=df[catalog_info.dec_column].to_numpy()
)
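
Correlation follows a template-method pattern: count_auto_pairs and count_cross_pairs look up the Fortran routine and its argument list through the four abstract hooks, so a concrete correlation only has to fill those in. A minimal, purely illustrative subclass with a plain-Python stand-in counter (nothing below is part of this commit):

from typing import Callable

import numpy as np
import pandas as pd
from hipscat.catalog.catalog_info import CatalogInfo

from corrgi.correlation.correlation import Correlation


def dummy_counter(*args) -> np.ndarray:
    # Stand-in for a gundam.cflibfor routine: ignores its inputs and returns one empty bin
    return np.zeros(1)


class DummyCorrelation(Correlation):
    """Counts nothing; only illustrates the hooks a real subclass must provide."""

    def _get_auto_method(self) -> Callable:
        return dummy_counter

    def _get_cross_method(self) -> Callable:
        return dummy_counter

    def _construct_auto_args(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> list:
        return [len(df), *self.get_coords(df, catalog_info), self.params.nsept, self.bins]

    def _construct_cross_args(
        self,
        left_df: pd.DataFrame,
        right_df: pd.DataFrame,
        left_catalog_info: CatalogInfo,
        right_catalog_info: CatalogInfo,
    ) -> list:
        return [
            len(left_df),
            *self.get_coords(left_df, left_catalog_info),
            len(right_df),
            *self.get_coords(right_df, right_catalog_info),
            self.params.nsept,
            self.bins,
        ]
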
28 changes: 28 additions & 0 deletions src/corrgi/correlation/projected_correlation.py
@@ -0,0 +1,28 @@
from typing import Callable

import pandas as pd
from hipscat.catalog.catalog_info import CatalogInfo

from corrgi.correlation.correlation import Correlation


class ProjectedCorrelation(Correlation):
"""The projected correlation utilities."""

def _get_auto_method(self) -> Callable:
raise NotImplementedError()

def _construct_auto_args(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> list:
raise NotImplementedError()

def _get_cross_method(self) -> Callable:
raise NotImplementedError()

def _construct_cross_args(
self,
left_df: pd.DataFrame,
right_df: pd.DataFrame,
left_catalog_info: CatalogInfo,
right_catalog_info: CatalogInfo,
) -> list:
raise NotImplementedError()
28 changes: 28 additions & 0 deletions src/corrgi/correlation/redshift_correlation.py
@@ -0,0 +1,28 @@
from typing import Callable

import pandas as pd
from hipscat.catalog.catalog_info import CatalogInfo

from corrgi.correlation.correlation import Correlation


class RedshiftCorrelation(Correlation):
"""The redshift correlation utilities."""

def _get_auto_method(self) -> Callable:
raise NotImplementedError()

def _construct_auto_args(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> list:
raise NotImplementedError()

def _get_cross_method(self) -> Callable:
raise NotImplementedError()

def _construct_cross_args(
self,
left_df: pd.DataFrame,
right_df: pd.DataFrame,
left_catalog_info: CatalogInfo,
right_catalog_info: CatalogInfo,
) -> list:
raise NotImplementedError()
12 changes: 10 additions & 2 deletions src/corrgi/corrgi.py
@@ -2,14 +2,22 @@
from lsdb import Catalog
from munch import Munch

from corrgi.correlation.correlation import Correlation
from corrgi.dask import compute_autocorrelation_counts
from corrgi.estimators import calculate_natural_estimate


def compute_autocorrelation(catalog: Catalog, random: Catalog, params: Munch) -> np.ndarray:
def compute_autocorrelation(
corr_type: type[Correlation],
catalog: Catalog,
random: Catalog,
params: Munch,
) -> np.ndarray:
"""Calculates the auto-correlation for a catalog.
Args:
corr_type (type[Correlation]): The corrgi class corresponding to the type of
correlation (AngularCorrelation, RedshiftCorrelation, or ProjectedCorrelation).
catalog (Catalog): The catalog.
random (Catalog): A random samples catalog.
params (Munch): The parameters dictionary to run gundam with.
@@ -19,7 +27,7 @@ def compute_autocorrelation(catalog: Catalog, random: Catalog, params: Munch) ->
"""
num_galaxies = catalog.hc_structure.catalog_info.total_rows
num_random = random.hc_structure.catalog_info.total_rows
counts_dd, counts_rr = compute_autocorrelation_counts(catalog, random, params)
counts_dd, counts_rr = compute_autocorrelation_counts(corr_type, catalog, random, params)
return calculate_natural_estimate(counts_dd, counts_rr, num_galaxies, num_random)
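
With this change the correlation class is threaded through from the top-level call. A usage sketch of the new signature (illustrative; the catalog paths and the pre-built gundam `params` Munch are assumptions, not part of this diff). The natural estimator applied at the end roughly amounts to DD/RR − 1 after normalizing the pair counts by the catalog sizes:

import lsdb

from corrgi.correlation.angular_correlation import AngularCorrelation
from corrgi.corrgi import compute_autocorrelation

galaxies = lsdb.read_hipscat("path/to/galaxy_catalog")  # hypothetical paths
randoms = lsdb.read_hipscat("path/to/random_catalog")

# params: the Munch of gundam parameters described in the docstring above
acf = compute_autocorrelation(AngularCorrelation, galaxies, randoms, params)
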


104 changes: 18 additions & 86 deletions src/corrgi/dask.py
@@ -1,5 +1,4 @@
import dask
import gundam.cflibfor as cff
import numpy as np
import pandas as pd
from dask.distributed import print as dask_print
@@ -11,14 +10,17 @@
from munch import Munch

from corrgi.alignment import autocorrelation_alignment, crosscorrelation_alignment
from corrgi.parameters import generate_dd_rr_params
from corrgi.utils import join_count_histograms, project_coordinates
from corrgi.correlation.correlation import Correlation
from corrgi.utils import join_count_histograms


def compute_autocorrelation_counts(catalog: Catalog, random: Catalog, params: Munch) -> np.ndarray:
def compute_autocorrelation_counts(
corr_type: type[Correlation], catalog: Catalog, random: Catalog, params: Munch
) -> np.ndarray:
"""Computes the auto-correlation counts for a catalog.
Args:
corr_type (type[Correlation]): The correlation class.
catalog (Catalog): The catalog with galaxy samples.
random (Catalog): The catalog with random samples.
params (dict): The gundam parameters for the Fortran subroutine.
@@ -28,11 +30,11 @@ def compute_autocorrelation_counts(catalog: Catalog, random: Catalog, params: Mu
"""
# Calculate the angular separation bins
bins, _ = gundam.makebins(params.nsept, params.septmin, params.dsept, params.logsept)
params_dd, params_rr = generate_dd_rr_params(params)
# Create correlation with bins and params
correlation = corr_type(bins, params)
# Generate the histograms with counts for each catalog
counts_dd = perform_auto_counts(catalog, bins, params_dd)
counts_rr = perform_auto_counts(random, bins, params_rr)
# Actually compute the results
counts_dd = perform_auto_counts(catalog, correlation)
counts_rr = perform_auto_counts(random, correlation)
return dask.compute(*[counts_dd, counts_rr])


@@ -82,25 +84,24 @@ def perform_cross_counts(left: Catalog, right: Catalog, *args) -> np.ndarray:
def count_auto_pairs(
df: pd.DataFrame,
catalog_info: CatalogInfo,
bins: np.ndarray,
params: Munch,
correlation: Correlation,
) -> np.ndarray:
"""Calls the fortran routine to compute the counts for pairs of
partitions belonging to the same catalog.
Args:
df (pd.DataFrame): The partition dataframe.
catalog_info (CatalogInfo): The catalog metadata.
bins (np.ndarray): The separation bins, in angular space.
params (Munch): The gundam subroutine parameters.
correlation (Correlation): The correlation instance.
Returns:
The count histogram for the partition pair.
"""
try:
return _count_auto_pairs(df, catalog_info, bins, params)
return correlation.count_auto_pairs(df, catalog_info)
except Exception as exception:
dask_print(exception)
raise exception


@dask.delayed
@@ -111,8 +112,7 @@ def count_cross_pairs(
right_pix: HealpixPixel,
left_catalog_info: CatalogInfo,
right_catalog_info: CatalogInfo,
bins: np.ndarray,
params: Munch,
correlation: Correlation,
) -> np.ndarray:
"""Calls the fortran routine to compute the counts for pairs of
partitions belonging to two different catalogs.
@@ -124,86 +124,18 @@
right_pix (HealpixPixel): The pixel corresponding to `right_df`.
left_catalog_info (CatalogInfo): The left catalog metadata.
right_catalog_info (CatalogInfo): The right catalog metadata.
bins (np.ndarray): The separation bins, in angular space.
params (Munch): The gundam subroutine parameters.
correlation (Correlation): The correlation instance.
Returns:
The count histogram for the partition pair.
"""
try:
return _count_cross_pairs(
return correlation.count_cross_pairs(
left_df,
right_df,
left_catalog_info,
right_catalog_info,
bins,
params,
)
except Exception as exception:
dask_print(exception)


def _count_auto_pairs(
df: pd.DataFrame,
catalog_info: CatalogInfo,
bins: np.ndarray,
params: Munch,
) -> np.ndarray:
x, y, z = project_coordinates(
ra=df[catalog_info.ra_column].to_numpy(),
dec=df[catalog_info.dec_column].to_numpy(),
)
args = [
len(df), # number of particles
x,
y,
z, # X,Y,Z coordinates of particles
params.nsept, # number of angular separation bins
bins, # bins in angular separation [deg]
]
counts = cff.mod.th_A_naiveway(*args) # fast unweighted counting
return counts


def _count_cross_pairs(
left_df: pd.DataFrame,
right_df: pd.DataFrame,
left_catalog_info: CatalogInfo,
right_catalog_info: CatalogInfo,
bins: np.ndarray,
params: Munch,
) -> np.ndarray:
left_x, left_y, left_z = project_coordinates(
ra=left_df[left_catalog_info.ra_column].to_numpy(),
dec=left_df[left_catalog_info.dec_column].to_numpy(),
)
right_x, right_y, right_z = project_coordinates(
ra=right_df[right_catalog_info.ra_column].to_numpy(),
dec=right_df[right_catalog_info.dec_column].to_numpy(),
)
args = [
1, # number of threads OpenMP
len(left_df), # number of particles of the left partition
left_df[left_catalog_info.ra_column].to_numpy(), # RA of particles [deg]
left_df[left_catalog_info.dec_column].to_numpy(), # DEC of particles [deg]
left_x,
left_y,
left_z, # X,Y,Z coordinates of particles
len(right_df), # number of particles of the right partition
right_x,
right_y,
right_z, # X,Y,Z coordinates of particles
params.nsept, # number of angular separation bins
bins, # bins in angular separation [deg]
params.sbound,
params.mxh1,
params.mxh2,
params.cntid,
params.logf,
params.sk1,
np.zeros(len(right_df)),
params.grid,
]
# TODO: Create gundam th_C_naive_way that accepts only the necessary arguments
counts = cff.mod.th_C(*args) # fast unweighted counting
return counts
raise exception
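
The per-partition helpers are now thin wrappers that forward to the Correlation instance and surface any Fortran-side error through dask_print before re-raising. For callers that want the raw DD/RR histograms rather than the finished estimate, the lower-level entry point takes the same leading argument; a sketch (galaxies, randoms, and params as in the earlier sketch, all assumptions rather than part of this diff):

from corrgi.correlation.angular_correlation import AngularCorrelation
from corrgi.dask import compute_autocorrelation_counts

# Raw pair-count histograms, before the natural estimator is applied
counts_dd, counts_rr = compute_autocorrelation_counts(AngularCorrelation, galaxies, randoms, params)
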