Skip to content

Commit

Permalink
replace batch lambdas
Browse files Browse the repository at this point in the history
  • Loading branch information
dougbrn committed Feb 2, 2024
1 parent 482f368 commit 62cc1ba
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions src/tape/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,22 +1130,36 @@ def s2n_inter_quartile_range(flux, err):

id_col = self._id_col # pre-compute needed for dask in lambda function

def _apply_func_to_lc(lc, func, *args, **kwargs):
"""
Apply a batch function to a lightcurve
"""
return func(
*[lc[arg].to_numpy() if arg != id_col else lc.index.to_numpy() for arg in args],
**kwargs,
)

if use_map: # use map_partitions

def _batch_apply(df, func, on, *args, **kwargs):
"""
Apply a function to a partition of the dataframe
"""
return df.groupby(on, group_keys=True, sort=False).apply(
_apply_func_to_lc, func, *args, **kwargs
)

id_col = self._id_col # need to grab this before mapping
batch = source_to_batch.map_partitions(
lambda x: x.groupby(on, group_keys=True).apply(
lambda y: func(
*[y[arg].to_numpy() if arg != id_col else y.index.to_numpy() for arg in args],
**kwargs,
),
),
meta=meta,
)

batch = source_to_batch.map_partitions(_batch_apply, func, on, *args, **kwargs, meta=meta)

else: # use groupby
batch = source_to_batch.groupby(on, group_keys=False).apply(
lambda x: func(
*[x[arg].to_numpy() if arg != id_col else x.index.to_numpy() for arg in args], **kwargs
),
# don't use _batch_apply as meta must be specified in the apply call
batch = source_to_batch.groupby(on, group_keys=True, sort=False).apply(
_apply_func_to_lc,
func,
*args,
**kwargs,
meta=meta,
)

Expand Down

0 comments on commit 62cc1ba

Please sign in to comment.