Skip to content

Commit

Permalink
Fix shallow copy of hc structure
Browse files Browse the repository at this point in the history
  • Loading branch information
camposandro committed Oct 23, 2024
1 parent a282303 commit 21ddee1
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 21 deletions.
53 changes: 48 additions & 5 deletions src/lsdb/catalog/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,21 +614,17 @@ def nest_lists(
recommend setting the following dask config setting to prevent this:
`dask.config.set({"dataframe.convert-string":False})`
"""
new_ddf = super().nest_lists(
catalog = super().nest_lists(
base_columns=base_columns,
list_columns=list_columns,
name=name,
)

catalog = Catalog(new_ddf._ddf, self._ddf_pixel_map, self.hc_structure)

if self.margin is not None:
catalog.margin = self.margin.nest_lists(
base_columns=base_columns,
list_columns=list_columns,
name=name,
)

return catalog

def dropna(
Expand Down Expand Up @@ -708,6 +704,53 @@ def dropna(
return catalog

def reduce(self, func, *args, meta=None, **kwargs) -> Catalog:
"""
Takes a function and applies it to each top-level row of the Catalog.
docstring copied from nested-pandas
The user may specify which columns the function is applied to, with
columns from the 'base' layer being passed to the function as
scalars and columns from the nested layers being passed as numpy arrays.
Parameters
----------
func : callable
Function to apply to each nested dataframe. The first arguments to `func` should be which
columns to apply the function to. See the Notes for recommendations
on writing func outputs.
args : positional arguments
Positional arguments to pass to the function, the first *args should be the names of the
columns to apply the function to.
meta : dataframe or series-like, optional
The dask meta of the output. If append_columns is True, the meta should specify just the
additional columns output by func.
append_columns : bool
If the output columns should be appended to the original dataframe.
kwargs : keyword arguments, optional
Keyword arguments to pass to the function.
Returns
-------
`HealpixDataset`
`HealpixDataset` with the results of the function applied to the columns of the frame.
Notes
-----
By default, `reduce` will produce a `NestedFrame` with enumerated
column names for each returned value of the function. For more useful
naming, it's recommended to have `func` return a dictionary where each
key is an output column of the dataframe returned by `reduce`.
Example User Function:
>>> def my_sum(col1, col2):
>>> '''reduce will return a NestedFrame with two columns'''
>>> return {"sum_col1": sum(col1), "sum_col2": sum(col2)}
>>>
>>> catalog.reduce(my_sum, 'sources.col1', 'sources.col2')
"""
catalog = super().reduce(func, *args, meta=meta, **kwargs)
if self.margin is not None:
catalog.margin = self.margin.reduce(func, *args, meta=meta, **kwargs)
Expand Down
27 changes: 12 additions & 15 deletions src/lsdb/catalog/dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@ def __len__(self):
"""
return len(self.hc_structure)

def _create_modified_hc_structure(self) -> HCHealpixDataset:
"""Copy the catalog structure and invalidate the number of rows.
def _create_modified_hc_structure(self, **kwargs) -> HCHealpixDataset:
"""Copy the catalog structure and override the specified catalog info parameters.
Returns:
A copy of the catalog's structure with the total number of rows set to None.
A copy of the catalog's structure with updated info parameters.
"""
return self.hc_structure.__class__(
catalog_info=self.hc_structure.catalog_info.copy_and_update(total_rows=0),
catalog_info=self.hc_structure.catalog_info.copy_and_update(**kwargs),
pixels=self.hc_structure.pixel_tree,
catalog_path=self.hc_structure.catalog_path,
schema=self.hc_structure.schema,
Expand Down Expand Up @@ -159,7 +159,7 @@ def query(self, expr: str) -> Self:
with the query expression
"""
ndf = self._ddf.query(expr)
hc_structure = self._create_modified_hc_structure()
hc_structure = self._create_modified_hc_structure(total_rows=0)
return self.__class__(ndf, self._ddf_pixel_map, hc_structure)

def _perform_search(
Expand Down Expand Up @@ -539,7 +539,7 @@ def drop_na_part(df: npd.NestedFrame):
return df

ndf = self._ddf.map_partitions(drop_na_part, meta=self._ddf._meta)
hc_structure = self._create_modified_hc_structure()
hc_structure = self._create_modified_hc_structure(total_rows=0)
return self.__class__(ndf, self._ddf_pixel_map, hc_structure)

def nest_lists(
Expand Down Expand Up @@ -585,7 +585,7 @@ def nest_lists(
list_columns=list_columns,
name=name,
)
hc_structure = self._create_modified_hc_structure()
hc_structure = self._create_modified_hc_structure(total_rows=0)
return self.__class__(new_ddf, self._ddf_pixel_map, hc_structure)

def reduce(self, func, *args, meta=None, append_columns=False, **kwargs) -> Self:
Expand Down Expand Up @@ -652,13 +652,10 @@ def reduce_part(df):

ndf = nd.NestedFrame.from_dask_dataframe(self._ddf.map_partitions(reduce_part, meta=meta))

hc_catalog = self.hc_structure
hc_updates: dict = {"total_rows": 0}
if not append_columns:
new_catalog_info = self.hc_structure.catalog_info.copy_and_update(ra_column="", dec_column="")
hc_catalog = self.hc_structure.__class__(
new_catalog_info,
self.hc_structure.pixel_tree,
schema=get_arrow_schema(ndf),
moc=self.hc_structure.moc,
)
hc_updates = {**hc_updates, "ra_column": "", "dec_column": ""}

hc_catalog = self._create_modified_hc_structure(**hc_updates)
hc_catalog.schema = get_arrow_schema(ndf)
return self.__class__(ndf, self._ddf_pixel_map, hc_catalog)
2 changes: 1 addition & 1 deletion tests/lsdb/catalog/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ def test_modified_hc_structure_is_a_deep_copy(small_sky_order1_catalog):
assert small_sky_order1_catalog.hc_structure.moc is not None
assert small_sky_order1_catalog.hc_structure.catalog_info.total_rows == 131

modified_hc_structure = small_sky_order1_catalog._create_modified_hc_structure()
modified_hc_structure = small_sky_order1_catalog._create_modified_hc_structure(total_rows=0)
modified_hc_structure.pixel_tree = None
modified_hc_structure.catalog_path = None
modified_hc_structure.schema = None
Expand Down
3 changes: 3 additions & 0 deletions tests/lsdb/catalog/test_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def mean_mag(ra, dec, mag):
assert isinstance(reduced_cat, Catalog)
assert isinstance(reduced_cat._ddf, nd.NestedFrame)

assert reduced_cat.hc_structure.catalog_info.ra_column == ""
assert reduced_cat.hc_structure.catalog_info.dec_column == ""

reduced_cat_compute = reduced_cat.compute()
assert isinstance(reduced_cat_compute, npd.NestedFrame)

Expand Down

0 comments on commit 21ddee1

Please sign in to comment.