Commit a04cee6

fixing more expressions issues

dougbrn committed Mar 19, 2024
1 parent 949a85d commit a04cee6

Showing 7 changed files with 37 additions and 8 deletions.
12 changes: 9 additions & 3 deletions src/tape/ensemble.py
@@ -5,6 +5,7 @@
 import warnings
 import requests
 import lsdb
+#import dask_expr as dd
 import dask.dataframe as dd
 import numpy as np
 import pandas as pd
@@ -958,7 +959,7 @@ def bin_sources(
         if tmp_time_col in self.source.columns:
             raise KeyError(f"Column '{tmp_time_col}' already exists in source table.")
         self.source[tmp_time_col] = self.source[self._time_col].apply(
-            lambda x: np.floor((x + offset) / time_window) * time_window, meta=pd.Series(dtype=float)
+            lambda x: np.floor((x + offset) / time_window) * time_window, meta=TapeSeries(dtype=float)
         )

         # Set up the aggregation functions for the time and flux columns.
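
Note on the lambda above: flooring (x + offset) / time_window and multiplying back by time_window snaps each timestamp to the left edge of its bin, which is what bin_sources relies on. A standalone sketch of the arithmetic (toy values, not part of this commit):

    import numpy as np

    time_window, offset = 1.0, 0.0  # hypothetical 1-unit bins, no phase shift
    times = np.array([10.2, 10.9, 11.5])
    binned = np.floor((times + offset) / time_window) * time_window
    print(binned)  # [10. 10. 11.] -- each time collapses to its bin's left edge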
@@ -1364,6 +1365,10 @@ def save_ensemble(self, path=".", dirname="ensemble", additional_frames=True, **kwargs):

        # Now write out the frames to subdirectories
        for subdir in created_subdirs:
+            # TODO: Figure this out, peek at the real meta as a stop gap
+            # TODO: It may be best to make sure batch returns valid index names
+            idx_name = self.frames[subdir].head(1).index.name
+            self.frames[subdir].index = self.frames[subdir].index.rename(idx_name)
             self.frames[subdir].to_parquet(os.path.join(ens_path, subdir), write_metadata_file=True, **kwargs)

        print(f"Saved to {os.path.join(path, dirname)}")
@@ -2262,7 +2267,7 @@ def select_random_timeseries(self, seed=None):

        # Scan through the shuffled partition list until a partition with data is found
        while not object_selected:
-            partition_index = self.object.partitions[partitions[i]].index
+            partition_index = self.object.partitions[int(partitions[i])].index
             # Check for empty partitions
             if len(partition_index) > 0:
                 lcid = rng.choice(partition_index.values)  # randomly select lightcurve
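
Note on the int() cast: partitions here typically comes from a NumPy shuffle, so partitions[i] is a numpy integer rather than a Python int, and the dask-expr partition accessor appears stricter about key types than the legacy backend. A quick demonstration of the type mismatch (illustrative only):

    import numpy as np

    rng = np.random.default_rng(seed=42)
    partitions = rng.permutation(10)  # elements are numpy.int64, not int
    key = partitions[0]
    print(type(key))       # <class 'numpy.int64'>
    print(type(int(key)))  # <class 'int'> -- safe for strict indexers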
@@ -2410,7 +2415,8 @@ def sf2(self, sf_method="basic", argument_container=None, use_map=True):

        # Inherit divisions information if known
        if self.source.known_divisions and self.object.known_divisions:
-            result.divisions = self.source.divisions
+            pass  # TODO: Can no longer directly set divisions
+            #result.divisions = self.source.divisions

        return result
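
Note on the disabled assignment: under dask-expr, divisions is derived from the expression graph, so it can no longer be set as a plain attribute. If inheriting divisions turns out to be necessary, one untested possibility is a sorted set_index with explicit divisions; a heavily hedged sketch using TAPE-style names (self._id_col is an assumption, not taken from this diff):

    # hypothetical workaround, not part of this commit
    if self.source.known_divisions and self.object.known_divisions:
        result = result.reset_index().set_index(
            self._id_col, sorted=True, divisions=self.source.divisions
        )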

2 changes: 1 addition & 1 deletion src/tape/expr.py
@@ -201,7 +201,7 @@ def partitions(self):
            A TAPE EnsembleFrame Object
        """
        self.set_dirty(True)
-        return IndexCallable(self._partitions, self.is_dirty(), self.ensemble)
+        return IndexCallable(self._partitions, self.is_dirty(), self.ensemble, self.label)

    def optimize(self, fuse: bool = True):
        result = new_collection(self.expr.optimize(fuse=fuse))
6 changes: 4 additions & 2 deletions src/tape/utils/utils.py
@@ -9,15 +9,17 @@ class IndexCallable:
    4
    """

-    __slots__ = ("fn", "dirty", "ensemble")
+    __slots__ = ("fn", "dirty", "ensemble", "label")

-    def __init__(self, fn, dirty, ensemble):
+    def __init__(self, fn, dirty, ensemble, label):
         self.fn = fn
         self.dirty = dirty  # propagate metadata
         self.ensemble = ensemble  # propagate ensemble metadata
+        self.label = label  # propagate label

     def __getitem__(self, key):
         result = self.fn(key)
         result.dirty = self.dirty  # propagate metadata
         result.ensemble = self.ensemble  # propagate ensemble metadata
+        result.label = self.label  # propagate label
         return result
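
The class now threads three pieces of metadata (dirty, ensemble, label) through any slice it hands back. A self-contained toy showing the propagation, with a dummy object standing in for a sliced TAPE frame and the import path assumed to be tape.utils:

    from tape.utils import IndexCallable

    class DummyFrame:
        """Stands in for a sliced frame; only needs settable attributes."""

    ic = IndexCallable(lambda key: DummyFrame(), True, "ens", "source")
    part = ic[0]
    assert part.dirty is True
    assert part.ensemble == "ens"
    assert part.label == "source"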
16 changes: 16 additions & 0 deletions tests/tape_tests/conftest.py
@@ -303,6 +303,21 @@ def parquet_ensemble_partition_size():
 def parquet_ensemble_with_divisions():
     """Create an Ensemble from parquet data."""
     ens = Ensemble(client=False)
+
+    source_ddf = dd.read_parquet("tests/tape_tests/data/source/test_source.parquet",
+                                 calculate_divisions=True).repartition(npartitions=3)
+    object_ddf = dd.read_parquet("tests/tape_tests/data/object/test_object.parquet",
+                                 calculate_divisions=True).repartition(npartitions=2)
+
+    ens.from_dask_dataframe(source_ddf,
+                            object_ddf,
+                            id_col="ps1_objid",
+                            time_col="midPointTai",
+                            band_col="filterName",
+                            flux_col="psFlux",
+                            err_col="psFluxErr",
+                            sorted=True,)
+    """
     ens.from_parquet(
         "tests/tape_tests/data/source/test_source.parquet",
         "tests/tape_tests/data/object/test_object.parquet",
@@ -313,6 +328,7 @@
         err_col="psFluxErr",
         sort=True,
     )
+    """

     return ens
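
Note on calculate_divisions=True: it asks read_parquet to recover divisions from the parquet row-group min/max statistics on the index, which is what lets this fixture exercise the known-divisions code paths; the subsequent repartition then splits along those divisions. A quick check (hypothetical path):

    import dask.dataframe as dd

    ddf = dd.read_parquet("some/indexed.parquet", calculate_divisions=True)
    assert ddf.known_divisions  # divisions recovered from parquet statistics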

5 changes: 4 additions & 1 deletion tests/tape_tests/test_ensemble.py
@@ -1545,8 +1545,11 @@ def test_calc_nobs(data_fixture, request, by_band, multi_partition):
    # Get the Ensemble from a fixture
    ens = request.getfixturevalue(data_fixture)

+    #if data_fixture == "parquet_ensemble_with_divisions":
+    #    import pdb; pdb.set_trace()
+
    if multi_partition:
-        ens.source = ens.source.repartition(3)
+        ens.source = ens.source.repartition(npartitions=3)

    # Drop the existing nobs columns
    ens.object = ens.object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1)
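
Note on the repartition fix: the first positional parameter of Dask's repartition is divisions, not npartitions, so repartition(3) routes the argument to the wrong slot (or errors under dask-expr); the keyword form is unambiguous. A minimal check (toy frame):

    import dask.dataframe as dd
    import pandas as pd

    ddf = dd.from_pandas(pd.DataFrame({"a": range(6)}), npartitions=1)
    ddf = ddf.repartition(npartitions=3)  # keyword avoids the divisions slot
    print(ddf.npartitions)  # 3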
1 change: 1 addition & 0 deletions tests/tape_tests/test_ensemble_frame.py
@@ -437,6 +437,7 @@ def test_partition_slicing(parquet_ensemble_with_divisions):
    prior_src_len = len(ens.source)

    # slice on object
+    #import pdb;pdb.set_trace()
    ens.object.partitions[0:3].update_ensemble()
    ens._lazy_sync_tables("all")  # sync needed as len() won't trigger one
3 changes: 2 additions & 1 deletion tests/tape_tests/test_utils.py
@@ -9,14 +9,15 @@ def test_index_callable(parquet_ensemble):

    ens = parquet_ensemble

-    source_ic = IndexCallable(ens.source._partitions, True, "ensemble")
+    source_ic = IndexCallable(ens.source._partitions, True, "ensemble", ens.source.label)

    # grab the first (and only) source partition
    sliced_source_frame = source_ic[0]

    # ensure that the metadata propagates to the result
    assert sliced_source_frame.dirty is True
    assert sliced_source_frame.ensemble == "ensemble"
+    assert sliced_source_frame.label == "source"


def test_column_mapper():
