Commit 750fe4b

Merge pull request #383 from lincc-frameworks/use_dask_expr
Enable dask-expr
dougbrn authored Mar 27, 2024
2 parents dae414b + 4ee4856 commit 750fe4b
Showing 12 changed files with 361 additions and 217 deletions.
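The recurring pattern in this diff is replacing direct `_meta` assignment, which does not carry over cleanly to dask-expr collections, with a `map_partitions` call that rebuilds each partition and passes the intended meta. A minimal sketch of that migration using plain `dask.dataframe` and `pandas` (TAPE's `TapeFrame` plays the role of `pd.DataFrame` here; column names are illustrative):

```python
import dask
import pandas as pd

# Query planning (dask-expr) must be enabled before dask.dataframe is imported;
# with dask>=2024.3.0 it is the default.
dask.config.set({"dataframe.query-planning": True})

import dask.dataframe as dd

ddf = dd.from_pandas(
    pd.DataFrame({"band": ["g", "r"], "flux": [1.0, 2.0]}), npartitions=1
)

# Old pattern (pre-dask-expr): overwrite the collection's meta in place
# ddf._meta = pd.DataFrame(columns=["band", "flux"])

# New pattern: rebuild each partition through map_partitions and pass the meta
new_meta = pd.DataFrame({"band": pd.Series(dtype=str), "flux": pd.Series(dtype=float)})
ddf = ddf.map_partitions(pd.DataFrame, meta=new_meta)
print(ddf.compute())
```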
6 changes: 3 additions & 3 deletions docs/tutorials/batch_showcase.ipynb
@@ -391,10 +391,10 @@
"metadata": {},
"outputs": [],
"source": [
"# Overwrite the _meta property\n",
"# Update the metadata\n",
"\n",
"res1_noindex = res1.reset_index()\n",
"res1_noindex._meta = real_meta_from_dataframe\n",
"res1_noindex = res1_noindex.map_partitions(TapeFrame, meta=real_meta_from_dataframe)\n",
"res1_noindex"
]
},
@@ -584,7 +584,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.11"
},
"vscode": {
"interpreter": {
36 changes: 29 additions & 7 deletions docs/tutorials/binning_slowly_changing_sources.ipynb
@@ -36,6 +36,7 @@
" flux_col=\"psFlux\",\n",
" err_col=\"psFluxErr\",\n",
" band_col=\"filterName\",\n",
" sorted=True,\n",
")"
]
},
@@ -118,6 +119,19 @@
"metadata": {},
"outputs": [],
"source": [
"ens = Ensemble() # initialize an ensemble object\n",
"\n",
"# Read in data from a parquet file\n",
"ens.from_parquet(\n",
" \"../../tests/tape_tests/data/source/test_source.parquet\",\n",
" id_col=\"ps1_objid\",\n",
" time_col=\"midPointTai\",\n",
" flux_col=\"psFlux\",\n",
" err_col=\"psFluxErr\",\n",
" band_col=\"filterName\",\n",
" sorted=True,\n",
")\n",
"\n",
"ens.bin_sources(time_window=28.0, offset=0.0, custom_aggr={\"midPointTai\": \"min\"})\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.hist(ens.source[\"midPointTai\"].compute().tolist(), 500)\n",
@@ -147,7 +161,7 @@
" \"band\": [\"g\", \"g\", \"g\", \"g\", \"g\", \"g\"],\n",
"}\n",
"cmap = ColumnMapper(id_col=\"id\", time_col=\"midPointTai\", flux_col=\"flux\", err_col=\"err\", band_col=\"band\")\n",
"ens.from_source_dict(rows, column_mapper=cmap)\n",
"ens.from_source_dict(rows, column_mapper=cmap, sorted=True)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.hist(ens.source[\"midPointTai\"].compute().tolist(), 60)\n",
@@ -175,7 +189,7 @@
" \"band\": [\"g\", \"g\", \"g\", \"g\", \"g\", \"g\"],\n",
"}\n",
"cmap = ColumnMapper(id_col=\"id\", time_col=\"midPointTai\", flux_col=\"flux\", err_col=\"err\", band_col=\"band\")\n",
"ens.from_source_dict(rows, column_mapper=cmap)\n",
"ens.from_source_dict(rows, column_mapper=cmap, sorted=True)\n",
"ens.bin_sources(time_window=1.0, offset=0.0)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
@@ -205,7 +219,7 @@
" \"band\": [\"g\", \"g\", \"g\", \"g\", \"g\", \"g\"],\n",
"}\n",
"cmap = ColumnMapper(id_col=\"id\", time_col=\"midPointTai\", flux_col=\"flux\", err_col=\"err\", band_col=\"band\")\n",
"ens.from_source_dict(rows, column_mapper=cmap)\n",
"ens.from_source_dict(rows, column_mapper=cmap, sorted=True)\n",
"ens.bin_sources(time_window=1.0, offset=0.5)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
@@ -243,6 +257,7 @@
" flux_col=\"psFlux\",\n",
" err_col=\"psFluxErr\",\n",
" band_col=\"filterName\",\n",
" sorted=True,\n",
")\n",
"suggested_offset = ens.find_day_gap_offset()\n",
"print(f\"Suggested offset is {suggested_offset}\")\n",
@@ -255,19 +270,26 @@
" \"band\": [\"g\", \"g\", \"g\", \"g\", \"g\", \"g\"],\n",
"}\n",
"cmap = ColumnMapper(id_col=\"id\", time_col=\"midPointTai\", flux_col=\"flux\", err_col=\"err\", band_col=\"band\")\n",
"ens.from_source_dict(rows, column_mapper=cmap)\n",
"ens.from_source_dict(rows, column_mapper=cmap, sorted=True)\n",
"ens.bin_sources(time_window=1.0, offset=0.5)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"ax.hist(ens.source[\"midPointTai\"].compute().tolist(), 60)\n",
"ax.set_xlabel(\"Time (MJD)\")\n",
"ax.set_ylabel(\"Source Count\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py310",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@@ -281,11 +303,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.11"
},
"vscode": {
"interpreter": {
"hash": "08968836a6367873274ed1d5e98a07391f42fc3a62bd5aba54afbd7b11ba8673"
"hash": "83afbb17b435d9bf8b0d0042367da76f26510da1c5781f0ff6e6c518eab621ec"
}
}
},
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -18,8 +18,8 @@ requires-python = ">=3.9"
dependencies = [
'pandas',
'numpy',
'dask>=2023.6.1,<2024.3.0', # We currently do not support dask/dask-expr
'dask[distributed]',
'dask>=2024.3.0',
'dask[distributed]>=2024.3.0',
'pyarrow',
'pyvo',
'scipy',
11 changes: 11 additions & 0 deletions src/tape/__init__.py
@@ -1,6 +1,17 @@
import dask
import warnings


QUERY_PLANNING_ON = dask.config.get("dataframe.query-planning")
# Force the use of the dask-expr (query-planning) backend
if QUERY_PLANNING_ON is False:
warnings.warn("This version of tape (v0.4.0+) requires dataframe query-planning, which has been enabled.")
dask.config.set({"dataframe.query-planning": True})

from .analysis import * # noqa
from .ensemble import * # noqa
from .ensemble_frame import * # noqa
from .timeseries import * # noqa
from .ensemble_readers import * # noqa
from ._version import __version__ # noqa
from .ensemble_frame import * # noqa
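A sketch of what the new import-time guard does in a session where a user disabled query planning before importing the package (assumes tape v0.4.0+ is installed; behavior follows the snippet above):

```python
import warnings
import dask

# Simulate an environment where the user turned query planning off
dask.config.set({"dataframe.query-planning": False})

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import tape  # noqa: F401  -- re-enables query planning and warns, as above

print(dask.config.get("dataframe.query-planning"))  # True after the import
print([str(w.message) for w in caught])  # includes the tape query-planning warning
```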
64 changes: 41 additions & 23 deletions src/tape/ensemble.py
@@ -5,18 +5,22 @@
import warnings
import requests
import lsdb
import dask

import dask.dataframe as dd
import numpy as np
import pandas as pd

from dask.distributed import Client
from dask import config
from collections import Counter
from collections.abc import Iterable

from .analysis.base import AnalysisFunction
from .analysis.feature_extractor import BaseLightCurveFeature, FeatureExtractor
from .analysis.structure_function import SF_METHODS
from .analysis.structurefunction2 import calc_sf2

from .ensemble_frame import (
EnsembleFrame,
EnsembleSeries,
@@ -326,15 +330,18 @@ def insert_sources(
rows[key] = value

# Create the new row and set the partitioning to match the original dataframe.
df2 = dd.DataFrame.from_dict(rows, npartitions=1)
df2 = dd.DataFrame.from_dict(rows, npartitions=2) # need at least 2 partitions for div
df2 = df2.set_index(self._id_col, drop=True, sort=True)

# Save the divisions and number of partitions.
prev_div = self.source.divisions
prev_num = self.source.npartitions

# Append the new rows to the correct divisions.
self.update_frame(dd.concat([self.source, df2], axis=0, interleave_partitions=True))
result = dd.concat([self.source, df2], axis=0, interleave_partitions=True)
self.update_frame(
self.source._propagate_metadata(result)
) # propagate source metadata and update frame
self.source.set_dirty(True)

# Do the repartitioning if requested. If the divisions were set, reuse them.
@@ -996,7 +1003,7 @@ def bin_sources(
if tmp_time_col in self.source.columns:
raise KeyError(f"Column '{tmp_time_col}' already exists in source table.")
self.source[tmp_time_col] = self.source[self._time_col].apply(
lambda x: np.floor((x + offset) / time_window) * time_window, meta=pd.Series(dtype=float)
lambda x: np.floor((x + offset) / time_window) * time_window, meta=TapeSeries(dtype=float)
)

# Set up the aggregation functions for the time and flux columns.
@@ -1030,12 +1037,14 @@ def bin_sources(
aggr_funs[key] = custom_aggr[key]

# Group the columns by id, band, and time bucket and aggregate.
self.update_frame(
self.source.groupby([self._id_col, self._band_col, tmp_time_col]).aggregate(aggr_funs)
result = self.source.groupby([self._id_col, self._band_col, tmp_time_col]).aggregate(aggr_funs)
# Fix the indices and remove the temporary column.
result = self.source._propagate_metadata(
result.reset_index().set_index(self._id_col).drop(columns=[tmp_time_col])
)

# Fix the indices and remove the temporary column.
self.update_frame(self.source.reset_index().set_index(self._id_col).drop(tmp_time_col, axis=1))
# Update the source frame
self.update_frame(result)

# Mark the source table as dirty.
self.source.set_dirty(True)
@@ -1217,11 +1226,6 @@ def _batch_apply(df, func, on, *args, **kwargs):
# Output standardization
batch = self._standardize_batch(batch, on, by_band)

# Inherit divisions if known from source and the resulting index is the id
# Groupby on index should always return a subset that adheres to the same divisions criteria
if self.source.known_divisions and batch.index.name == self._id_col:
batch.divisions = self.source.divisions

if label is not None:
if label == "":
label = self._generate_frame_label()
@@ -1239,10 +1243,18 @@ def _standardize_batch(self, batch, on, by_band):
# make sure the output is separated from the id column
if batch.name == self._id_col:
batch = batch.rename("result")

# need to set the index name
set_idx_name = True
else:
set_idx_name = False

res_cols = [batch.name] # grab the series name to use as a column label

# convert the series to an EnsembleFrame object
batch = EnsembleFrame.from_dask_dataframe(batch.to_frame())
if set_idx_name and len(on) < 2:
batch.index = batch.index.rename(self._id_col)

elif isinstance(batch, EnsembleFrame):
# collect output columns
@@ -1260,23 +1272,34 @@

# Need to overwrite the meta manually as the multiindex will be
# interpreted by dask as a single "index" column
batch._meta = TapeFrame(columns=on + res_cols)

# [expr] added map_partitions meta assignment
# batch._meta = TapeFrame(columns=on + res_cols)
batch = batch.map_partitions(TapeFrame, meta=TapeFrame(columns=on + res_cols))

# Further reformatting for per-band results
# Pivots on the band column to generate a result column for each
# photometric band.
if by_band:
batch = batch.categorize(self._band_col)
batch = batch.pivot_table(index=on[0], columns=self._band_col, aggfunc="sum")
# [expr] added values
col_values = [col for col in batch.columns if col not in [on[0], self._band_col]]
batch = batch.pivot_table(
index=on[0], columns=self._band_col, values=col_values, aggfunc="sum"
)

# Need to once again reestablish meta for the pivot
band_labels = batch.columns.values
out_cols = []
# To align with pandas pivot_table results, the columns should be generated in reverse order
for col in res_cols[::-1]:
for band in band_labels:
out_cols += [(str(col), str(band))]
batch._meta = TapeFrame(columns=out_cols) # apply new meta
# [expr] adjusted labeling
out_cols += [(str(band[0]), str(band[1]))]

# [expr] added map_partitions meta assignment
# apply new meta
batch = batch.map_partitions(TapeFrame, meta=TapeFrame(columns=band_labels))

# Flatten the columns to a new column per band
batch.columns = ["_".join(col) for col in batch.columns.values]
@@ -2160,7 +2183,6 @@ def convert_flux_to_mag(self, zero_point, zp_form="mag", out_col_name=None, flux
def _generate_object_table(self):
"""Generate an empty object table from the source table."""
res = self.source.map_partitions(lambda x: TapeObjectFrame(index=x.index.unique()))

return res

def _lazy_sync_tables_from_frame(self, frame):
Expand Down Expand Up @@ -2246,7 +2268,7 @@ def _sync_tables(self):
else:
warnings.warn("Divisions are not known, syncing using a non-lazy method.")
# Sync Source to Object; remove any objects that do not have sources
sor_idx = list(self.source.index.unique().compute())
sor_idx = list(self.source.index.compute().unique())
self.update_frame(self.object.map_partitions(lambda x: x[x.index.isin(sor_idx)]))
self.update_frame(self.object.persist()) # persist the object frame

Expand Down Expand Up @@ -2296,7 +2318,7 @@ def select_random_timeseries(self, seed=None):

# Scan through the shuffled partition list until a partition with data is found
while not object_selected:
partition_index = self.object.partitions[partitions[i]].index
partition_index = self.object.partitions[int(partitions[i])].index
# Check for empty partitions
if len(partition_index) > 0:
lcid = rng.choice(partition_index.values) # randomly select lightcurve
@@ -2442,10 +2464,6 @@ def sf2(self, sf_method="basic", argument_container=None, use_map=True):
else:
result = self.batch(calc_sf2, use_map=use_map, argument_container=argument_container)

# Inherit divisions information if known
if self.source.known_divisions and self.object.known_divisions:
result.divisions = self.source.divisions

return result

def _translate_meta(self, meta):
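One of the `[expr]`-tagged changes above passes the value columns to `pivot_table` explicitly. A minimal standalone sketch of that call shape with plain `dask.dataframe` rather than TAPE's frames (names such as `stetson_J` are illustrative); the pivot column still has to be categorical with known categories:

```python
import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame(
    {
        "id": [0, 0, 1, 1],
        "band": ["g", "r", "g", "r"],
        "stetson_J": [1.0, 2.0, 3.0, 4.0],
    }
)
ddf = dd.from_pandas(pdf, npartitions=1)

# Categorize the pivot column so its categories are known,
# then name the value column explicitly, as the PR now does
ddf = ddf.categorize("band")
wide = ddf.pivot_table(index="id", columns="band", values="stetson_J", aggfunc="sum")
print(wide.compute())  # one column per band
```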