Merge pull request #283 from lincc-frameworks/lazy_object_generation
Check Divisions, Enable Lazy Sync Operations on Divisions-enabled Ensembles, Overhaul Object Table Generation
dougbrn authored Nov 14, 2023
2 parents 4243b85 + 996b1a2 commit 1082730
Showing 3 changed files with 196 additions and 47 deletions.
84 changes: 61 additions & 23 deletions src/tape/ensemble.py
@@ -611,15 +611,16 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True):
self._source.groupby([self._id_col])[self._band_col] # group by each object
.value_counts() # count occurrence of each band
.to_frame() # convert series to dataframe
.rename(columns={self._band_col: "counts"}) # rename column
.reset_index() # break up the multiindex
.categorize(columns=[self._band_col]) # retype the band labels as categories
.pivot_table(values=self._band_col, index=self._id_col, columns=self._band_col, aggfunc="sum")
) # the pivot_table call makes each band_count a column of the id_col row

# repartition the result to align with object
if self._object.known_divisions:
- self._object.divisions = tuple([None for i in range(self._object.npartitions + 1)])
- band_counts = band_counts.repartition(npartitions=self._object.npartitions)
+ band_counts.divisions = self._source.divisions
+ band_counts = band_counts.repartition(divisions=self._object.divisions)
else:
band_counts = band_counts.repartition(npartitions=self._object.npartitions)

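For context on the hunk above: the groupby result comes back with unknown divisions, so the commit restamps the parent's divisions and then repartitions against the object table. A minimal, self-contained sketch of that idiom (synthetic data; classic `dask.dataframe` API as of late 2023):

```python
import pandas as pd
import dask.dataframe as dd

source = dd.from_pandas(
    pd.DataFrame({"flux": [1.0, 2.0, 3.0, 4.0]}, index=[1, 1, 2, 2]),
    npartitions=2,
)

# Simulate an operation that keeps row/partition order but drops
# division metadata.
derived = source.map_partitions(lambda df: df * 2).clear_divisions()
assert not derived.known_divisions

# Restamp the parent's divisions (valid because the partition count and
# row order are unchanged), then repartition to a target's divisions.
derived.divisions = source.divisions
derived = derived.repartition(divisions=source.divisions)
assert derived.known_divisions
```

Restamping is only safe when the derived frame really does preserve the parent's partition boundaries; the `repartition(divisions=...)` call then does the actual alignment work.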
@@ -636,9 +637,9 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True):
counts = self._source.groupby([self._id_col])[[self._band_col]].aggregate("count")

# repartition the result to align with object
- if self._object.known_divisions:
- self._object.divisions = tuple([None for i in range(self._object.npartitions + 1)])
- counts = counts.repartition(npartitions=self._object.npartitions)
+ if self._object.known_divisions and self._source.known_divisions:
+ counts.divisions = self._source.divisions
+ counts = counts.repartition(divisions=self._object.divisions)
else:
counts = counts.repartition(npartitions=self._object.npartitions)

@@ -677,8 +678,7 @@ def prune(self, threshold=50, col_name=None):
col_name = "nobs_total"

# Mask on object table
- mask = self._object[col_name] >= threshold
- self._object = self._object[mask]
+ self = self.query(f"{col_name} >= {threshold}", table="object")

self._object_dirty = True # Object Table is now dirty

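A hedged usage sketch of the new pruning path (assuming an `Ensemble` named `ens` with source data loaded; `calc_nobs` and the `nobs_total` column come from the surrounding code):

```python
ens.calc_nobs(by_band=False, label="nobs")  # produces an "nobs_total" column
ens.prune(threshold=50)                     # col_name defaults to "nobs_total"

# Under the new implementation the prune is just the lazy query
#   ens.query("nobs_total >= 50", table="object")
# rather than an eagerly materialized boolean mask.
```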
@@ -953,6 +953,11 @@ def s2n_inter_quartile_range(flux, err):
meta=meta,
)

+ # Inherit divisions if known from source and the resulting index is the id
+ # Groupby on index should always return a subset that adheres to the same divisions criteria
+ if self._source.known_divisions and batch.index.name == self._id_col:
+ batch.divisions = self._source.divisions

if compute:
return batch.compute()
else:
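The two added comment lines carry the key invariant: a groupby over the index never moves a row outside its object's partition, so the parent's divisions still bound the result. A sketch of that inheritance under the same guard (synthetic data; the extra `npartitions` check is a defensive addition of mine, not in the commit):

```python
import pandas as pd
import dask.dataframe as dd

source = dd.from_pandas(
    pd.DataFrame(
        {"flux": [1.0, 2.0, 3.0, 4.0]},
        index=pd.Index([10, 10, 20, 20], name="id"),
    ),
    npartitions=2,
)

# With known divisions, all rows of one id share a partition, so the
# per-object result respects the same partition bounds as the parent.
per_object = source.groupby("id")["flux"].apply(lambda s: s.mean(), meta=("flux", "f8"))
if (source.known_divisions and per_object.index.name == "id"
        and per_object.npartitions == source.npartitions):
    per_object.divisions = source.divisions
```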
@@ -1078,6 +1083,12 @@ def from_dask_dataframe(
elif partition_size:
self._source = self._source.repartition(partition_size=partition_size)

+ # Check that Divisions are established, warn if not.
+ for name, table in [("object", self._object), ("source", self._source)]:
+ if not table.known_divisions:
+ warnings.warn(
+ f"Divisions for {name} are not set, certain downstream dask operations may fail as a result. We recommend setting the `sort` or `sorted` flags when loading data to establish division information."
+ )
return self

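To make the warning actionable, here is the recommended load path, sketched with placeholder file paths and a placeholder `dask_client`; the column names and `sort=True` flag mirror the test fixture added at the bottom of this commit:

```python
ens = Ensemble(client=dask_client)
ens.from_parquet(
    "data/source.parquet",
    "data/object.parquet",
    id_col="ps1_objid",
    time_col="midPointTai",
    band_col="filterName",
    flux_col="psFlux",
    err_col="psFluxErr",
    sort=True,  # sort on the id column and record divisions at load time
)
assert ens._source.known_divisions and ens._object.known_divisions
```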
def from_hipscat(self, dir, source_subdir="source", object_subdir="object", column_mapper=None, **kwargs):
@@ -1272,7 +1283,10 @@ def from_parquet(
columns.append(self._provenance_col)

# Read in the source parquet file(s)
- source = dd.read_parquet(source_file, index=self._id_col, columns=columns, split_row_groups=True)
+ # Index is set False so that we can set it with a future set_index call
+ # This has the advantage of letting Dask set partition boundaries based
+ # on the divisions between the sources of different objects.
+ source = dd.read_parquet(source_file, index=False, columns=columns, split_row_groups=True)

# Generate a provenance column if not provided
if self._provenance_col is None:
@@ -1282,7 +1296,9 @@
object = None
if object_file:
# Read in the object file(s)
- object = dd.read_parquet(object_file, index=self._id_col, split_row_groups=True)
+ # Index is False so that we can set it with a future set_index call
+ # More meaningful for source than object but parity seems good here
+ object = dd.read_parquet(object_file, index=False, split_row_groups=True)
return self.from_dask_dataframe(
source_frame=source,
object_frame=object,
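A sketch of the deferred-index pattern the new comments describe (paths and column name are placeholders): read without an index, then let `set_index` shuffle by object id and compute divisions that fall between objects rather than wherever the on-disk row groups happened to end.

```python
import dask.dataframe as dd

source = dd.read_parquet("data/source.parquet", index=False, split_row_groups=True)
# set_index shuffles on the id column and computes partition boundaries
# (divisions) between the sources of different objects.
source = source.set_index("ps1_objid")
assert source.known_divisions
```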
@@ -1468,9 +1484,7 @@ def convert_flux_to_mag(self, zero_point, zp_form="mag", out_col_name=None, flux

def _generate_object_table(self):
"""Generate an empty object table from the source table."""
- sor_idx = self._source.index.unique()
- obj_df = pd.DataFrame(index=sor_idx)
- res = dd.from_pandas(obj_df, npartitions=int(np.ceil(self._source.npartitions / 100)))
+ res = self._source.map_partitions(lambda x: pd.DataFrame(index=x.index.unique()))

return res

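The rewrite above trades an eager `index.unique()` (which forced a compute) for a per-partition reduction. A minimal sketch of why this stays lazy and keeps the source's partitioning:

```python
import pandas as pd
import dask.dataframe as dd

source = dd.from_pandas(
    pd.DataFrame(
        {"flux": [1.0, 2.0, 3.0]},
        index=pd.Index([10, 10, 20], name="id"),
    ),
    npartitions=2,
)

# Each partition collapses to the unique ids it holds; nothing is
# computed, and the partition count (plus any known divisions, since the
# index bounds are unchanged) carries over from the source.
# (One row per id overall, provided ids do not straddle partitions.)
objects = source.map_partitions(lambda df: pd.DataFrame(index=df.index.unique()))
assert objects.npartitions == source.npartitions
```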
@@ -1504,9 +1518,18 @@ def _sync_tables(self):

if self._object_dirty:
# Sync Object to Source; remove any missing objects from source
- obj_idx = list(self._object.index.compute())
- self._source = self._source.map_partitions(lambda x: x[x.index.isin(obj_idx)])
- self._source = self._source.persist() # persist the source frame

+ if self._object.known_divisions and self._source.known_divisions:
+ # Lazily Create an empty object table (just index) for joining
+ empty_obj = self._object.map_partitions(lambda x: pd.DataFrame(index=x.index))
+
+ # Join source onto the empty object table to align
+ self._source = empty_obj.join(self._source)
+ else:
+ warnings.warn("Divisions are not known, syncing using a non-lazy method.")
+ obj_idx = list(self._object.index.compute())
+ self._source = self._source.map_partitions(lambda x: x[x.index.isin(obj_idx)])
+ self._source = self._source.persist() # persist the source frame

# Drop Temporary Source Columns on Sync
if len(self._source_temp):
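The lazy sync above joins against a columnless index-only frame instead of computing the full id list. A self-contained sketch of that alignment trick, with a source row (id 30) that has no object entry:

```python
import pandas as pd
import dask.dataframe as dd

obj = dd.from_pandas(
    pd.DataFrame({"nobs": [3, 5]}, index=pd.Index([10, 20], name="id")),
    npartitions=1,
)
src = dd.from_pandas(
    pd.DataFrame({"flux": [1.0, 2.0, 9.0]}, index=pd.Index([10, 20, 30], name="id")),
    npartitions=1,
)

# Strip the object table to a bare index and left-join source onto it:
# source rows whose id is missing from object (30 here) fall away, and
# no compute() of the id list is ever triggered.
empty_obj = obj.map_partitions(lambda df: pd.DataFrame(index=df.index))
synced_src = empty_obj.join(src)  # join defaults to how="left" on the index
print(synced_src.compute())       # rows for ids 10 and 20 only
```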
@@ -1516,10 +1539,19 @@

if self._source_dirty: # not elif
if not self.keep_empty_objects:
- # Sync Source to Object; remove any objects that do not have sources
- sor_idx = list(self._source.index.unique().compute())
- self._object = self._object.map_partitions(lambda x: x[x.index.isin(sor_idx)])
- self._object = self._object.persist() # persist the object frame
+ if self._object.known_divisions and self._source.known_divisions:
+ # Lazily Create an empty source table (just unique indexes) for joining
+ empty_src = self._source.map_partitions(lambda x: pd.DataFrame(index=x.index.unique()))
+
+ # Join object onto the empty unique source table to align
+ self._object = empty_src.join(self._object)
+
+ else:
+ warnings.warn("Divisions are not known, syncing using a non-lazy method.")
+ # Sync Source to Object; remove any objects that do not have sources
+ sor_idx = list(self._source.index.unique().compute())
+ self._object = self._object.map_partitions(lambda x: x[x.index.isin(sor_idx)])
+ self._object = self._object.persist() # persist the object frame

# Drop Temporary Object Columns on Sync
if len(self._object_temp):
@@ -1619,7 +1651,7 @@ def _build_index(self, obj_id, band):
index = pd.MultiIndex.from_tuples(tuples, names=["object_id", "band", "index"])
return index

def sf2(self, sf_method="basic", argument_container=None, use_map=True):
def sf2(self, sf_method="basic", argument_container=None, use_map=True, compute=True):
"""Wrapper interface for calling structurefunction2 on the ensemble
Parameters
@@ -1661,8 +1693,14 @@ def sf2(self, sf_method="basic", argument_container=None, use_map=True):
self._source.index,
argument_container=argument_container,
)
return result

else:
- result = self.batch(calc_sf2, use_map=use_map, argument_container=argument_container)
+ result = self.batch(
+ calc_sf2, use_map=use_map, argument_container=argument_container, compute=compute
+ )

+ # Inherit divisions information if known
+ if self._source.known_divisions and self._object.known_divisions:
+ result.divisions = self._source.divisions

+ return result
- return result
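A hedged usage sketch of the widened `sf2` signature (assuming an ensemble `ens` loaded with `sort=True` so divisions are known):

```python
# compute=False keeps the structure-function result lazy; its divisions
# are inherited from the source, so further division-aware operations
# can be chained before a single final compute().
result = ens.sf2(sf_method="basic", use_map=False, compute=False)
print(result.known_divisions)  # True when divisions were established
sf2_values = result.compute()
```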
19 changes: 19 additions & 0 deletions tests/tape_tests/conftest.py
@@ -252,6 +252,25 @@ def parquet_ensemble(dask_client):
return ens


+ # pylint: disable=redefined-outer-name
+ @pytest.fixture
+ def parquet_ensemble_with_divisions(dask_client):
+ """Create an Ensemble from parquet data."""
+ ens = Ensemble(client=dask_client)
+ ens.from_parquet(
+ "tests/tape_tests/data/source/test_source.parquet",
+ "tests/tape_tests/data/object/test_object.parquet",
+ id_col="ps1_objid",
+ time_col="midPointTai",
+ band_col="filterName",
+ flux_col="psFlux",
+ err_col="psFluxErr",
+ sort=True,
+ )
+
+ return ens


# pylint: disable=redefined-outer-name
@pytest.fixture
def parquet_ensemble_from_source(dask_client):
