Skip to content

Commit

Permalink
Get correct catalog length (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
camposandro authored Jul 17, 2024
1 parent 59d83c6 commit a21b9aa
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/corrgi/corrgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from corrgi.correlation.correlation import Correlation
from corrgi.dask import compute_autocorrelation_counts
from corrgi.estimators import calculate_natural_estimate
from corrgi.utils import compute_catalog_size


def compute_autocorrelation(
Expand All @@ -24,9 +25,9 @@ def compute_autocorrelation(
"""
correlation = corr_type(**kwargs)
correlation.validate([catalog, random])
num_galaxies = compute_catalog_size(catalog)
num_random = compute_catalog_size(random)
counts_dd, counts_rr = compute_autocorrelation_counts(catalog, random, correlation)
num_galaxies = catalog.hc_structure.catalog_info.total_rows
num_random = random.hc_structure.catalog_info.total_rows
return calculate_natural_estimate(counts_dd, counts_rr, num_galaxies, num_random)


Expand Down
13 changes: 13 additions & 0 deletions src/corrgi/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from gundam import gundam
from lsdb import Catalog
from numpy import deg2rad


Expand Down Expand Up @@ -42,3 +43,15 @@ def join_count_histograms(partial_histograms: list[np.ndarray]) -> np.ndarray:
The numpy array with the total counts for the partial histograms.
"""
return np.sum(np.stack(partial_histograms), axis=0)


def compute_catalog_size(catalog: Catalog) -> int:
"""Compute the number of rows in a catalog.
Args:
catalog (Catalog): An LSDB catalog.
Returns:
The number of rows in the catalog.
"""
return catalog._ddf.shape[0].compute()

0 comments on commit a21b9aa

Please sign in to comment.