-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create classes for type of correlation (#11)
* Create classes for type of correlation * Review PR comments * Fix method signatures * Raise exceptions on counting
- Loading branch information
1 parent
83224aa
commit 1ff1f9d
Showing
83 changed files
with
249 additions
and
135 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from typing import Callable | ||
|
||
import gundam.cflibfor as cff | ||
import pandas as pd | ||
from hipscat.catalog.catalog_info import CatalogInfo | ||
|
||
from corrgi.correlation.correlation import Correlation | ||
|
||
|
||
class AngularCorrelation(Correlation): | ||
"""The angular correlation utilities.""" | ||
|
||
def _get_auto_method(self): | ||
return cff.mod.th_A_wg if self.use_weights else cff.mod.th_A_naiveway | ||
|
||
def _get_cross_method(self) -> Callable: | ||
return cff.mod.th_C_wg if self.use_weights else cff.mod.th_C_naiveway | ||
|
||
def _construct_auto_args(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> list: | ||
return [ | ||
len(df), # number of particles | ||
*self.get_coords(df, catalog_info), # cartesian coordinates | ||
self.params.nsept, # number of angular separation bins | ||
self.bins, # bins in angular separation [deg] | ||
] | ||
|
||
def _construct_cross_args( | ||
self, | ||
left_df: pd.DataFrame, | ||
right_df: pd.DataFrame, | ||
left_catalog_info: CatalogInfo, | ||
right_catalog_info: CatalogInfo, | ||
) -> list: | ||
return [ | ||
len(left_df), # number of particles of the left partition | ||
*self.get_coords(left_df, left_catalog_info), # X,Y,Z coordinates of particles | ||
len(right_df), # number of particles of the right partition | ||
*self.get_coords(right_df, right_catalog_info), # X,Y,Z coordinates of particles | ||
self.params.nsept, # number of angular separation bins | ||
self.bins, # bins in angular separation [deg] | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Callable | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from hipscat.catalog.catalog_info import CatalogInfo | ||
from munch import Munch | ||
|
||
from corrgi.utils import project_coordinates | ||
|
||
|
||
class Correlation(ABC): | ||
"""Correlation base class.""" | ||
|
||
def __init__( | ||
self, | ||
bins: np.ndarray, | ||
params: Munch, | ||
use_weights: bool = False, | ||
): | ||
self.bins = bins | ||
self.params = params | ||
self.use_weights = use_weights | ||
|
||
def count_auto_pairs(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> np.ndarray: | ||
"""Computes the counts for pairs of the same partition""" | ||
args = self._construct_auto_args(df, catalog_info) | ||
return self._get_auto_method()(*args) | ||
|
||
def count_cross_pairs( | ||
self, | ||
left_df: pd.DataFrame, | ||
right_df: pd.DataFrame, | ||
left_catalog_info: CatalogInfo, | ||
right_catalog_info: CatalogInfo, | ||
) -> np.ndarray: | ||
"""Computes the counts for pairs of different partitions""" | ||
args = self._construct_cross_args(left_df, right_df, left_catalog_info, right_catalog_info) | ||
return self._get_cross_method()(*args) | ||
|
||
@abstractmethod | ||
def _get_auto_method(self) -> Callable: | ||
"""Reference to Fortran routine to be called on auto pairing""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def _construct_auto_args(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> list: | ||
"""Generate the arguments required for the auto pairing method""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def _get_cross_method(self) -> Callable: | ||
"""Reference to Fortran routine to be called on cross pairing""" | ||
raise NotImplementedError() | ||
|
||
@abstractmethod | ||
def _construct_cross_args( | ||
self, | ||
left_df: pd.DataFrame, | ||
right_df: pd.DataFrame, | ||
left_catalog_info: CatalogInfo, | ||
right_catalog_info: CatalogInfo, | ||
) -> list: | ||
"""Generate the arguments required for the cross pairing method""" | ||
raise NotImplementedError() | ||
|
||
@staticmethod | ||
def get_coords(df: pd.DataFrame, catalog_info: CatalogInfo) -> tuple[float, float, float]: | ||
"""Calculate the cartesian coordinates for the points in the partition""" | ||
return project_coordinates( | ||
ra=df[catalog_info.ra_column].to_numpy(), dec=df[catalog_info.dec_column].to_numpy() | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import Callable | ||
|
||
import pandas as pd | ||
from hipscat.catalog.catalog_info import CatalogInfo | ||
|
||
from corrgi.correlation.correlation import Correlation | ||
|
||
|
||
class ProjectedCorrelation(Correlation): | ||
"""The projected correlation utilities.""" | ||
|
||
def _get_auto_method(self) -> Callable: | ||
raise NotImplementedError() | ||
|
||
def _construct_auto_args(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> list: | ||
raise NotImplementedError() | ||
|
||
def _get_cross_method(self) -> Callable: | ||
raise NotImplementedError() | ||
|
||
def _construct_cross_args( | ||
self, | ||
left_df: pd.DataFrame, | ||
right_df: pd.DataFrame, | ||
left_catalog_info: CatalogInfo, | ||
right_catalog_info: CatalogInfo, | ||
) -> list: | ||
raise NotImplementedError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from typing import Callable | ||
|
||
import pandas as pd | ||
from hipscat.catalog.catalog_info import CatalogInfo | ||
|
||
from corrgi.correlation.correlation import Correlation | ||
|
||
|
||
class RedshiftCorrelation(Correlation): | ||
"""The redshift correlation utilities.""" | ||
|
||
def _get_auto_method(self) -> Callable: | ||
raise NotImplementedError() | ||
|
||
def _construct_auto_args(self, df: pd.DataFrame, catalog_info: CatalogInfo) -> list: | ||
raise NotImplementedError() | ||
|
||
def _get_cross_method(self) -> Callable: | ||
raise NotImplementedError() | ||
|
||
def _construct_cross_args( | ||
self, | ||
left_df: pd.DataFrame, | ||
right_df: pd.DataFrame, | ||
left_catalog_info: CatalogInfo, | ||
right_catalog_info: CatalogInfo, | ||
) -> list: | ||
raise NotImplementedError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.