Implement map type and merging (#511)
* Initial work on map type

* Checkpoint of distributed map merging.

* Clean up typing

* Clean up expectations.

* isort .

* Bad merge, pylint, and better docstrings.
delucchi-cmu authored Dec 2, 2024
1 parent 5236ce2 commit d56f99d
Showing 27 changed files with 332 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/lsdb/catalog/__init__.py
@@ -1,2 +1,3 @@
from .catalog import Catalog
from .map_catalog import MapCatalog
from .margin_catalog import MarginCatalog
56 changes: 54 additions & 2 deletions src/lsdb/catalog/catalog.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import List, Tuple, Type
from typing import Callable, List, Tuple, Type

import hats as hc
import nested_dask as nd
@@ -14,6 +14,7 @@

from lsdb.catalog.association_catalog import AssociationCatalog
from lsdb.catalog.dataset.healpix_dataset import HealpixDataset
from lsdb.catalog.map_catalog import MapCatalog
from lsdb.catalog.margin_catalog import MarginCatalog
from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm
from lsdb.core.crossmatch.crossmatch_algorithms import BuiltInCrossmatchAlgorithm
@@ -28,12 +29,13 @@
join_catalog_data_through,
merge_asof_catalog_data,
)
from lsdb.dask.merge_map_catalog_data import merge_map_catalog_data
from lsdb.dask.partition_indexer import PartitionIndexer
from lsdb.io.schema import get_arrow_schema
from lsdb.types import DaskDFPixelMap


# pylint: disable=R0903, W0212
# pylint: disable=protected-access,too-many-public-methods
class Catalog(HealpixDataset):
"""LSDB Catalog DataFrame to perform analysis of sky catalogs and efficient
spatial operations.
@@ -227,6 +229,56 @@ def crossmatch(
)
return Catalog(ddf, ddf_map, hc_catalog)

def merge_map(
self,
map_catalog: MapCatalog,
func: Callable[..., npd.NestedFrame],
*args,
meta: npd.NestedFrame | None = None,
**kwargs,
) -> Catalog:
"""Applies a function to each pair of partitions in this catalog and the map catalog.
The pixels from each catalog are aligned via a `PixelAlignment`, and the respective dataframes
are passed to the function. The resulting catalog will have the same partitions as the
point-source catalog.
Args:
map_catalog (MapCatalog): The continuous map to merge.
func (Callable): The function applied to each catalog partition, which will be called with:
`func(catalog_partition: npd.NestedFrame, map_partition: npd.NestedFrame, `
` catalog_pixel: HealpixPixel, map_pixel: HealpixPixel, *args, **kwargs)`
with the additional args and kwargs passed to the `merge_map` function.
*args: Additional positional arguments to call `func` with.
meta (pd.DataFrame | pd.Series | Dict | Iterable | Tuple | None): An empty pandas DataFrame that
has columns matching the output of the function applied to the catalog partition. Other types
are accepted to describe the output dataframe format, for full details see the dask
documentation https://blog.dask.org/2022/08/09/understanding-meta-keyword-argument
If meta is None (default), LSDB will try to work out the output schema of the function by
calling the function with an empty DataFrame. If the function does not work with an empty
DataFrame, this will raise an error and meta must be set. Note that some operations in LSDB
will generate empty partitions, though these can be removed by calling the
`Catalog.prune_empty_partitions` method.
**kwargs: Additional keyword args to pass to the function. These are passed to the Dask DataFrame
`dask.dataframe.map_partitions` function, so any of the dask function's keyword args such as
`transform_divisions` will be passed through and work as described in the dask documentation
https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.map_partitions.html
Returns:
A Catalog with the result of applying `func` to each pair of aligned partitions, with the
same partitioning as the point-source catalog.
"""
ddf, ddf_map, alignment = merge_map_catalog_data(self, map_catalog, func, *args, meta=meta, **kwargs)
new_catalog_info = self.hc_structure.catalog_info.copy_and_update(total_rows=0)
hc_catalog = hc.catalog.Catalog(
new_catalog_info, alignment.pixel_tree, schema=get_arrow_schema(ddf), moc=alignment.moc
)
return Catalog(ddf, ddf_map, hc_catalog)
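For context, a minimal usage sketch of the new method. The catalog paths, the `dust` column, and the one-map-row-per-aligned-pixel layout are illustrative assumptions, not part of this change:

```python
import lsdb

# Hypothetical paths: any point-source catalog and any map-type catalog.
catalog = lsdb.read_hats("data/point_catalog")
dust_map = lsdb.read_hats("data/dust_map")

def attach_dust(cat_part, map_part, cat_pixel, map_pixel):
    # Copy so the input partition is not modified in place.
    result = cat_part.copy()
    # Assumes one map row per aligned pixel; fall back to NaN on the
    # empty call LSDB makes to infer the output schema.
    result["dust"] = map_part["dust"].iloc[0] if len(map_part) else float("nan")
    return result

merged = catalog.merge_map(dust_map, attach_dust)
merged.compute()
```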

def cone_search(self, ra: float, dec: float, radius_arcsec: float, fine: bool = True) -> Catalog:
"""Perform a cone search to filter the catalog
2 changes: 1 addition & 1 deletion src/lsdb/catalog/dataset/healpix_dataset.py
@@ -40,7 +40,7 @@
from lsdb.types import DaskDFPixelMap


# pylint: disable=W0212
# pylint: disable=protected-access
class HealpixDataset(Dataset):
"""LSDB Catalog DataFrame to perform analysis of sky catalogs and efficient
spatial operations.
14 changes: 14 additions & 0 deletions src/lsdb/catalog/map_catalog.py
@@ -0,0 +1,14 @@
import hats as hc

from lsdb.catalog.dataset.healpix_dataset import HealpixDataset


class MapCatalog(HealpixDataset):
"""LSDB DataFrame to contain a continuous map.
Attributes:
hc_structure: `hats.MapCatalog` object representing the structure
and metadata of the HATS catalog
"""

hc_structure: hc.catalog.MapCatalog
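A `MapCatalog` inherits the usual `HealpixDataset` accessors. A quick sketch against the `square_map` test catalog added in this change, assuming it is run from the repository root:

```python
import lsdb
from lsdb.catalog.map_catalog import MapCatalog

square_map = lsdb.read_hats("tests/data/square_map")
assert isinstance(square_map, MapCatalog)
# One partition per order-0 pixel, matching partition_info.csv below.
print(square_map.get_healpix_pixels())
```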
137 changes: 137 additions & 0 deletions src/lsdb/dask/merge_map_catalog_data.py
@@ -0,0 +1,137 @@
# pylint: disable=duplicate-code
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Tuple

import dask
import nested_dask as nd
import nested_pandas as npd
from hats.catalog import TableProperties
from hats.pixel_math import HealpixPixel
from hats.pixel_tree import PixelAlignment, PixelAlignmentType
from hats.pixel_tree.pixel_alignment import align_with_mocs

from lsdb.dask.merge_catalog_functions import (
align_and_apply,
construct_catalog_args,
filter_by_spatial_index_to_pixel,
get_healpix_pixels_from_alignment,
)
from lsdb.types import DaskDFPixelMap

if TYPE_CHECKING:
from lsdb.catalog import Catalog, MapCatalog


# pylint: disable=too-many-arguments, unused-argument
@dask.delayed
def perform_merge_map(
catalog_partition: npd.NestedFrame,
map_partition: npd.NestedFrame,
catalog_pixel: HealpixPixel,
map_pixel: HealpixPixel,
catalog_structure: TableProperties,
map_structure: TableProperties,
func: Callable[..., npd.NestedFrame],
*args,
**kwargs,
):
"""Applies a function to each pair of partitions in this catalog and the map catalog.
Args:
catalog_partition (npd.NestedFrame): partition of the point-source catalog
map_partition (npd.NestedFrame): partition of the continuous map catalog
catalog_pixel (HealpixPixel): the HEALPix pixel of the catalog partition
map_pixel (HealpixPixel): the HEALPix pixel of the map partition
catalog_structure (hc.TableProperties): the catalog info of the catalog
map_structure (hc.TableProperties): the catalog info of the map
func (Callable): method to apply to the two partitions
Returns:
A dataframe with the result of calling `func`
"""
if map_pixel.order > catalog_pixel.order:
catalog_partition = filter_by_spatial_index_to_pixel(
catalog_partition, map_pixel.order, map_pixel.pixel
)

catalog_partition.sort_index(inplace=True)
map_partition.sort_index(inplace=True)
return func(catalog_partition, map_partition, catalog_pixel, map_pixel, *args, **kwargs)
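The order check above handles map pixels that are finer than the catalog pixel: the catalog partition must first be cut down to only the rows inside the map pixel. A simplified stand-in for `filter_by_spatial_index_to_pixel`, assuming the frame is indexed by the HATS `_healpix_29` spatial index:

```python
import pandas as pd
from hats.pixel_math.spatial_index import healpix_to_spatial_index

def filter_to_pixel_sketch(frame: pd.DataFrame, order: int, pixel: int) -> pd.DataFrame:
    # Spatial indexes inside a single (order, pixel) form one contiguous
    # range, so two bounds are enough to select the matching rows.
    lower = healpix_to_spatial_index(order, pixel)
    upper = healpix_to_spatial_index(order, pixel + 1)
    return frame[(frame.index >= lower) & (frame.index < upper)]
```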


# pylint: disable=protected-access
def merge_map_catalog_data(
point_catalog: Catalog,
map_catalog: MapCatalog,
func: Callable[..., npd.NestedFrame],
*args,
meta: npd.NestedFrame | None = None,
**kwargs,
) -> Tuple[nd.NestedFrame, DaskDFPixelMap, PixelAlignment]:
"""Applies a function to each pair of partitions in this catalog and the map catalog.
The pixels from each catalog are aligned via a `PixelAlignment`, and the respective dataframes
are passed to the function. The resulting catalog will have the same partitions as the
point-source catalog.
Args:
point_catalog (lsdb.Catalog): the point-source catalog to merge
map_catalog (lsdb.MapCatalog): the continuous map catalog to merge with
func (Callable): The function applied to each catalog partition, which will be called with:
`func(catalog_partition: npd.NestedFrame, map_partition: npd.NestedFrame, `
` catalog_pixel: HealpixPixel, map_pixel: HealpixPixel, *args, **kwargs)`
with the additional args and kwargs passed to the `merge_map` function.
*args: Additional positional arguments to call `func` with.
meta (pd.DataFrame | pd.Series | Dict | Iterable | Tuple | None): An empty pandas DataFrame that
has columns matching the output of the function applied to the catalog partition. Other types
are accepted to describe the output dataframe format, for full details see the dask
documentation https://blog.dask.org/2022/08/09/understanding-meta-keyword-argument
If meta is None (default), LSDB will try to work out the output schema of the function by
calling the function with an empty DataFrame. If the function does not work with an empty
DataFrame, this will raise an error and meta must be set. Note that some operations in LSDB
will generate empty partitions, though these can be removed by calling the
`Catalog.prune_empty_partitions` method.
**kwargs: Additional keyword args to pass to the function. These are passed to the Dask DataFrame
`dask.dataframe.map_partitions` function, so any of the dask function's keyword args such as
`transform_divisions` will be passed through and work as described in the dask documentation
https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.map_partitions.html
Returns:
A tuple of the dask dataframe with the result of the operation, the pixel map from HEALPix
pixel to partition index within the dataframe, and the PixelAlignment of the two input
catalogs.
"""
if meta is None:
meta = func(
point_catalog._ddf._meta.copy(),
map_catalog._ddf._meta.copy(),
HealpixPixel(0, 0),
HealpixPixel(0, 0),
)
if meta is None:
raise ValueError(
"func returned None for empty DataFrame input. The function must return a value; "
"changing the partitions in place will not work. If the function does not work for "
"empty inputs, please specify a `meta` argument."
)
alignment = align_with_mocs(
point_catalog.hc_structure.pixel_tree,
map_catalog.hc_structure.pixel_tree,
point_catalog.hc_structure.moc,
map_catalog.hc_structure.moc,
alignment_type=PixelAlignmentType.INNER,
)

left_pixels, right_pixels = get_healpix_pixels_from_alignment(alignment)

partitions_with_func = align_and_apply(
[(point_catalog, left_pixels), (map_catalog, right_pixels)],
perform_merge_map,
func,
*args,
**kwargs,
)

return construct_catalog_args(partitions_with_func, meta, alignment)
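When `func` cannot run on the empty frames used for schema inference, an explicit `meta` describes the output instead. A hedged sketch, with `attach_star_count` as an illustrative user function and the `star_count` dtype assumed to match the map's column:

```python
import lsdb
import pandas as pd

catalog = lsdb.read_hats("data/point_catalog")  # hypothetical path
square_map = lsdb.read_hats("tests/data/square_map")

def attach_star_count(cat_part, map_part, cat_pixel, map_pixel):
    result = cat_part.copy()
    # .iloc[0] fails on empty input, so schema inference cannot be used here.
    result["star_count"] = map_part["star_count"].iloc[0]
    return result

# Spell out the output schema: the catalog's own empty meta frame plus
# the one column the merge function adds.
meta = catalog._ddf._meta.copy()
meta["star_count"] = pd.Series(dtype="int64")

merged = catalog.merge_map(square_map, attach_star_count, meta=meta)
```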
2 changes: 1 addition & 1 deletion src/lsdb/io/to_hats.py
Expand Up @@ -67,7 +67,7 @@ def calculate_histogram(df: npd.NestedFrame, histogram_order: int) -> SparseHist
return SparseHistogram.make_from_counts(indexes, counts_at_indexes, histogram_order)


# pylint: disable=W0212
# pylint: disable=protected-access
def to_hats(
catalog: HealpixDataset,
*,
13 changes: 13 additions & 0 deletions src/lsdb/loaders/hats/read_hats.py
@@ -18,6 +18,7 @@

from lsdb.catalog.association_catalog import AssociationCatalog
from lsdb.catalog.catalog import Catalog, DaskDFPixelMap, MarginCatalog
from lsdb.catalog.map_catalog import MapCatalog
from lsdb.catalog.margin_catalog import _validate_margin_catalog
from lsdb.core.search.abstract_search import AbstractSearch
from lsdb.dask.divisions import get_pixels_divisions
@@ -83,6 +84,8 @@ def read_hats(
return _load_margin_catalog(hc_catalog, config)
if catalog_type == CatalogType.ASSOCIATION:
return _load_association_catalog(hc_catalog, config)
if catalog_type == CatalogType.MAP:
return _load_map_catalog(hc_catalog, config)

raise NotImplementedError(f"Cannot load catalog of type {catalog_type}")

@@ -154,6 +157,16 @@ def _load_object_catalog(hc_catalog, config):
return catalog


def _load_map_catalog(hc_catalog, config):
"""Load a catalog from the configuration specified when the loader was created
Returns:
Catalog object with data from the source given at loader initialization
"""
dask_df, dask_df_pixel_map = _load_dask_df_and_map(hc_catalog, config)
return MapCatalog(dask_df, dask_df_pixel_map, hc_catalog)
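With this loader in place, map-type HATS directories go through the same `read_hats` entry point as the other catalog types. A short sketch against the `square_map` test data, path assumed relative to the repository root:

```python
import lsdb

square_map = lsdb.read_hats("tests/data/square_map")
# A map-type catalog loads as a MapCatalog; computing it yields the
# star_count value stored for each order-0 pixel.
print(square_map.compute())
```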


def _create_dask_meta_schema(schema: pa.Schema, config) -> npd.NestedFrame:
"""Creates the Dask meta DataFrame from the HATS catalog schema."""
dask_meta_schema = schema.empty_table().to_pandas(types_mapper=config.get_dtype_mapper())
58 changes: 56 additions & 2 deletions tests/data/generate_data.ipynb
@@ -35,6 +35,7 @@
"from hats_import.soap import SoapArguments\n",
"import tempfile\n",
"from dask.distributed import Client\n",
"from pathlib import Path\n",
"\n",
"tmp_path = tempfile.TemporaryDirectory()\n",
"tmp_dir = tmp_path.name\n",
@@ -742,6 +743,59 @@
"cone_search_output.to_csv(\"raw/cone_search_expected/margin.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Square map\n",
"\n",
"Create a trivial map-type catalog. This just contains a `star_count` per order 0\n",
"healpix tile. The value is the square of the healpix index."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from hats.pixel_math.spatial_index import healpix_to_spatial_index\n",
"\n",
"target_pixels = np.arange(0, 12)\n",
"\n",
"healpix_29 = healpix_to_spatial_index(0, target_pixels)\n",
"\n",
"square_vals = target_pixels * target_pixels\n",
"value_frame = pd.DataFrame({\"_healpix_29\": healpix_29, \"star_count\": square_vals})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with tempfile.TemporaryDirectory() as pipeline_tmp:\n",
" csv_file = Path(pipeline_tmp) / \"square_map.csv\"\n",
" value_frame.to_csv(csv_file, index=False)\n",
" args = ImportArguments(\n",
" constant_healpix_order=0, ## forces the moc to order 0.\n",
" catalog_type=\"map\",\n",
" use_healpix_29=True,\n",
" ra_column=None,\n",
" dec_column=None,\n",
" file_reader=\"csv\",\n",
" input_file_list=[csv_file],\n",
" output_artifact_name=\"square_map\",\n",
" output_path=\".\",\n",
" tmp_dir=pipeline_tmp,\n",
" )\n",
"\n",
" runner.pipeline_with_client(args, client)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -772,7 +826,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "hipscatenv",
"display_name": "demo",
"language": "python",
"name": "python3"
},
@@ -786,7 +840,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.12.3"
}
},
"nbformat": 4,
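Once the pipeline cell above has run, the generated map can be read back to spot-check the values; a quick sketch, assuming the notebook's working directory:

```python
import lsdb

square_map = lsdb.read_hats("square_map")
counts = square_map.compute()["star_count"]
# Squares of the order-0 pixel indexes: 0, 1, 4, 9, ..., 121.
print(counts.to_list())
```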
Binary file added tests/data/square_map/dataset/_common_metadata
Binary file added tests/data/square_map/dataset/_metadata
13 changes: 13 additions & 0 deletions tests/data/square_map/partition_info.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Norder,Npix
0,0
0,1
0,2
0,3
0,4
0,5
0,6
0,7
0,8
0,9
0,10
0,11
Binary file added tests/data/square_map/point_map.fits