Commit
full local unit test passes
dougbrn committed Mar 20, 2024
1 parent a04cee6 commit 10be686
Showing 4 changed files with 45 additions and 41 deletions.
46 changes: 27 additions & 19 deletions src/tape/ensemble.py
@@ -5,7 +5,8 @@
import warnings
import requests
import lsdb
-#import dask_expr as dd
+
+# import dask_expr as dd
import dask.dataframe as dd
import numpy as np
import pandas as pd
@@ -329,15 +330,18 @@ def insert_sources(
rows[key] = value

# Create the new row and set the paritioning to match the original dataframe.
-df2 = dd.DataFrame.from_dict(rows, npartitions=1)
+df2 = dd.DataFrame.from_dict(rows, npartitions=2)  # need at least 2 partitions for div
df2 = df2.set_index(self._id_col, drop=True, sort=True)

# Save the divisions and number of partitions.
prev_div = self.source.divisions
prev_num = self.source.npartitions

# Append the new rows to the correct divisions.
-self.update_frame(dd.concat([self.source, df2], axis=0, interleave_partitions=True))
+result = dd.concat([self.source, df2], axis=0, interleave_partitions=True)
+self.update_frame(
+    self.source._propagate_metadata(result)
+)  # propagate source metadata and update frame
self.source.set_dirty(True)

# Do the repartitioning if requested. If the divisions were set, reuse them.
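
A standalone sketch of the pattern this hunk adopts (illustrative only; the column names and values here are hypothetical, not TAPE code): build the new rows as a small dask DataFrame with at least two partitions so that set_index produces usable divisions, then interleave it into an existing indexed frame.

import dask.dataframe as dd
import pandas as pd

# New rows as a dict of columns; use >=2 partitions, per the fix above.
rows = {"id": [10, 20], "flux": [1.0, 2.0]}
df2 = dd.from_dict(rows, npartitions=2)
df2 = df2.set_index("id", drop=True, sort=True)

# An existing indexed frame to append to.
base = dd.from_pandas(pd.DataFrame({"flux": [0.5]}, index=pd.Index([15], name="id")), npartitions=1)

# Append the new rows into the matching divisions.
result = dd.concat([base, df2], axis=0, interleave_partitions=True)
print(result.compute())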
@@ -993,12 +997,16 @@ def bin_sources(
aggr_funs[key] = custom_aggr[key]

# Group the columns by id, band, and time bucket and aggregate.
-self.update_frame(
-self.source.groupby([self._id_col, self._band_col, tmp_time_col]).aggregate(aggr_funs)
-)
+result = self.source.groupby([self._id_col, self._band_col, tmp_time_col]).aggregate(aggr_funs)
+# self.update_frame(self.source._propagate_metadata(result))

# Fix the indices and remove the temporary column.
-self.update_frame(self.source.reset_index().set_index(self._id_col).drop(tmp_time_col, axis=1))
+result = self.source._propagate_metadata(
+    result.reset_index().set_index(self._id_col).drop(columns=[tmp_time_col])
+)
+
+# Updates the source frame
+self.update_frame(result)

# Mark the source table as dirty.
self.source.set_dirty(True)
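
The _propagate_metadata calls threaded through this hunk refer to a TAPE-internal helper. As a rough sketch (an assumption about its behavior, not the actual implementation), such a helper only needs to copy frame-level attributes back onto the result, since dask operations return plain collections; the attribute names below mirror the label/ensemble fields set in expr.py later in this diff.

def _propagate_metadata(self, new_frame):
    # Copy TAPE's frame-level bookkeeping onto the derived frame (sketch).
    new_frame.label = getattr(self, "label", None)
    new_frame.ensemble = getattr(self, "ensemble", None)
    return new_frame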
@@ -1215,34 +1223,34 @@ def _standardize_batch(self, batch, on, by_band):
# interpretted by dask as a single "index" column

# [expr] added map_partitions meta assignment
-#batch._meta = TapeFrame(columns=on + res_cols)
-batch = batch.map_partitions(TapeFrame, meta = TapeFrame(columns=on + res_cols))
+# batch._meta = TapeFrame(columns=on + res_cols)
+batch = batch.map_partitions(TapeFrame, meta=TapeFrame(columns=on + res_cols))

# Further reformatting for per-band results
# Pivots on the band column to generate a result column for each
# photometric band.
if by_band:
batch = batch.categorize(self._band_col)
# [expr] added values

-#import pdb;pdb.set_trace()
col_values = [col for col in batch.columns if col not in [on[0], self._band_col]]
-batch = batch.pivot_table(index=on[0], columns=self._band_col, values=col_values, aggfunc="sum")

+batch = batch.pivot_table(
+    index=on[0], columns=self._band_col, values=col_values, aggfunc="sum"
+)

# Need to once again reestablish meta for the pivot
band_labels = batch.columns.values
out_cols = []
# To align with pandas pivot_table results, the columns should be generated in reverse order
for col in res_cols[::-1]:
for band in band_labels:
# [expr] adjusted labeling
-#out_cols += [(str(col), str(band))]
+# out_cols += [(str(col), str(band))]
out_cols += [(str(band[0]), str(band[1]))]

-#import pdb; pdb.set_trace()
# [expr] added map_partitions meta assignment
-#batch._meta = TapeFrame(columns=out_cols) # apply new meta
-#apply new meta
-batch = batch.map_partitions(TapeFrame, meta = TapeFrame(columns=band_labels))
+# batch._meta = TapeFrame(columns=out_cols) # apply new meta
+# apply new meta
+batch = batch.map_partitions(TapeFrame, meta=TapeFrame(columns=band_labels))

# Flatten the columns to a new column per band
batch.columns = ["_".join(col) for col in batch.columns.values]
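
A minimal, self-contained sketch of the pivot-and-rewrap pattern above, using plain pandas/dask stand-ins rather than TapeFrame: dask's pivot_table requires a categorized pivot column, and map_partitions then reimposes the desired partition type, which is what this commit uses in place of assigning batch._meta directly.

import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({"id": [1, 1, 2], "band": ["g", "r", "g"], "result": [1.0, 2.0, 3.0]})
batch = dd.from_pandas(pdf, npartitions=1)

# pivot_table needs a categorical column with known categories
batch = batch.categorize(columns=["band"])
batch = batch.pivot_table(index="id", columns="band", values="result", aggfunc="sum")

# Re-wrap each partition; pd.DataFrame stands in for TapeFrame here
batch = batch.map_partitions(pd.DataFrame)
print(batch.compute())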
@@ -2415,8 +2423,8 @@ def sf2(self, sf_method="basic", argument_container=None, use_map=True):

# Inherit divisions information if known
if self.source.known_divisions and self.object.known_divisions:
-pass # TODO: Can no longer directly set divisions
-#result.divisions = self.source.divisions
+pass  # TODO: Can no longer directly set divisions
+# result.divisions = self.source.divisions

return result
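
On the TODO above: under dask-expr the divisions of a collection can no longer be assigned directly. One possible workaround (an assumption, not something this commit does) is to reimpose known divisions through repartition:

import dask.dataframe as dd
import pandas as pd

source = dd.from_pandas(pd.DataFrame({"x": range(10)}), npartitions=2)  # known divisions
result = source.map_partitions(lambda part: part + 1)  # stand-in for the sf2 result

# Hypothetical fix: repartition the result by the known source divisions
if source.known_divisions:
    result = result.repartition(divisions=source.divisions)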

14 changes: 7 additions & 7 deletions src/tape/expr.py
@@ -27,12 +27,12 @@
elemwise,
from_graph,
get_collection_type,
-#from_dict,
+# from_dict,
)
from dask_expr._collection import new_collection, from_dict
from dask_expr._expr import _emulate, ApplyConcatApply

-#from .ensemble_frame import TapeFrame, TapeSeries
+# from .ensemble_frame import TapeFrame, TapeSeries

SOURCE_FRAME_LABEL = "source" # Reserved label for source table
OBJECT_FRAME_LABEL = "object" # Reserved label for object table.
@@ -53,6 +53,7 @@
ArrowDatasetEngine as DaskArrowDatasetEngine,
)


class TapeSeries(pd.Series):
"""A barebones extension of a Pandas series to be used for underlying Ensemble data.
@@ -683,7 +684,7 @@ def set_index(
result: `tape._Frame`
The indexed frame
"""
-result = super().set_index(other, drop, sorted, npartitions, divisions, inplace, sort, **kwargs)
+result = super().set_index(other, drop, sorted, npartitions, divisions, sort, **kwargs)
return self._propagate_metadata(result)

def map_partitions(self, func, *args, **kwargs):
@@ -869,9 +870,6 @@ def repartition(
return self._propagate_metadata(result)


-
-
-
class EnsembleSeries(_Frame, dd.Series):
"""A barebones extension of a Dask Series for Ensemble data."""

@@ -978,7 +976,9 @@ def from_dict(
cl, data, npartitions, orient="columns", dtype=None, columns=None, label=None, ensemble=None
):
""""""
-result = from_dict(data, npartitions=npartitions, orient=orient, dtype=dtype, columns=columns, constructor=TapeFrame)
+result = from_dict(
+    data, npartitions=npartitions, orient=orient, dtype=dtype, columns=columns, constructor=TapeFrame
+)
result.label = label
result.ensemble = ensemble
return result
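
A hedged usage sketch of the classmethod above; the enclosing class and import path are assumptions (the frame classes may be exported elsewhere in TAPE), and the label value is hypothetical:

from tape.ensemble_frame import EnsembleFrame  # assumed import path

frame = EnsembleFrame.from_dict(
    {"flux": [1.0, 2.0], "band": ["g", "r"]},
    npartitions=1,
    label="example_label",  # hypothetical label
)
assert frame.label == "example_label"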
23 changes: 10 additions & 13 deletions tests/tape_tests/test_ensemble.py
@@ -850,8 +850,8 @@ def test_insert(parquet_ensemble):
assert new_source.shape[0] == old_size + 10


-def test_insert_paritioned(dask_client):
-ens = Ensemble(client=dask_client)
+def test_insert_paritioned():
+ens = Ensemble()

# Create all fake source data with known divisions.
num_points = 1000
@@ -883,6 +883,7 @@ def test_insert_paritioned(dask_client):
new_times = [1.0, 1.1, 1.2, 1.3, 1.4]
new_fluxes = [2.0, 2.5, 3.0, 3.5, 4.0]
new_errs = [0.1, 0.05, 0.01, 0.05, 0.01]

ens.insert_sources(new_inds, new_bands, new_times, new_fluxes, new_errs, force_repartition=True)

# Check we did not increase the number of partitions and the points
@@ -906,6 +907,7 @@

# Insert a bunch of points into the second partition.
new_inds = [8804, 8804, 8804, 8804, 8804]

ens.insert_sources(new_inds, new_bands, new_times, new_fluxes, new_errs, force_repartition=True)

# Check we did not increase the number of partitions and the points
@@ -1545,9 +1547,6 @@ def test_calc_nobs(data_fixture, request, by_band, multi_partition):
# Get the Ensemble from a fixture
ens = request.getfixturevalue(data_fixture)

-#if data_fixture == "parquet_ensemble_with_divisions":
-# import pdb; pdb.set_trace()
-
if multi_partition:
ens.source = ens.source.repartition(npartitions=3)

@@ -1700,10 +1699,10 @@ def test_select(dask_client):


@pytest.mark.parametrize("legacy", [True, False])
-def test_assign(dask_client, legacy):
+def test_assign(legacy):
"""Tests assign for column-manipulation, using Ensemble.assign when `legacy` is `True`,
and EnsembleFrame.assign when `legacy` is `False`."""
-ens = Ensemble(client=dask_client)
+ens = Ensemble()

num_points = 1000
all_bands = ["r", "g", "b", "i"]
@@ -1735,7 +1734,7 @@ def test_assign(dask_client, legacy):
assert new_source.iloc[i]["lower_bnd"] == expected

# Create a series directly from the table.
res_col = ens.source["band"] + "2"
res_col = ens.source["band"]
if legacy:
ens.assign(table="source", band2=res_col)
else:
@@ -1746,7 +1745,7 @@
# Check the values in the new column.
new_source = ens.source.compute() if not legacy else ens.compute(table="source")
for i in range(1000):
assert new_source.iloc[i]["band2"] == new_source.iloc[i]["band"] + "2"
assert new_source.iloc[i]["band2"] == new_source.iloc[i]["band"]


@pytest.mark.parametrize("zero_point", [("zp_mag", "zp_flux"), (25.0, 10**10)])
@@ -1858,6 +1857,7 @@ def test_bin_sources_day(dask_client):
custom_aggr={ens._time_col: "min"},
count_col="aggregated_bin_count",
)

new_source = ens.compute("source")
assert new_source.shape[0] == 6
assert new_source.shape[1] == 5
@@ -1989,9 +1989,6 @@ def my_mean(flux):
# An EnsembleFrame should be returned
assert isinstance(res, EnsembleFrame)


-
-#import pdb; pdb.set_trace()
-
# Make sure we get all the expected columns
assert all([col in res.columns for col in ["result_g", "result_r"]])

@@ -2029,7 +2026,7 @@ def my_bounds(flux):

# [expr] need this TODO: investigate typing issue
filter_res.index = filter_res.index.astype("int")

assert (
res.loc[88472935274829959]["max_g"]
.compute()
3 changes: 1 addition & 2 deletions tests/tape_tests/test_ensemble_frame.py
@@ -337,7 +337,7 @@ def test_object_and_source_frame_propagation(data_fixture, request):
object_frame.copy(), on=[ens._id_col], suffixes=(None, "_drop_me")
)
cols_to_drop = [col for col in merged_frame.columns if "_drop_me" in col]
-merged_frame = merged_frame.drop(cols_to_drop, axis=1).persist()
+merged_frame = merged_frame.drop(columns=cols_to_drop).persist()
assert isinstance(merged_frame, SourceFrame)
assert merged_frame.label == SOURCE_LABEL
assert merged_frame.ensemble == ens
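
The drop call above trades the axis=1 form for the keyword form; in plain pandas the two are equivalent, e.g.:

import pandas as pd

df = pd.DataFrame({"a": [1], "a_drop_me": [2]})
assert df.drop(columns=["a_drop_me"]).equals(df.drop(["a_drop_me"], axis=1))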
@@ -437,7 +437,6 @@ def test_partition_slicing(parquet_ensemble_with_divisions):
prior_src_len = len(ens.source)

# slice on object
-#import pdb;pdb.set_trace()
ens.object.partitions[0:3].update_ensemble()
ens._lazy_sync_tables("all") # sync needed as len() won't trigger one

