reduce scope of sync_tables
dougbrn committed Oct 4, 2023
1 parent 0f962c5 commit 3a6e3fb
Showing 11 changed files with 175 additions and 901 deletions.
7 changes: 6 additions & 1 deletion docs/examples/rrlyr-period.ipynb
@@ -128,7 +128,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
"version": "3.10.11"
},
"vscode": {
"interpreter": {
"hash": "83afbb17b435d9bf8b0d0042367da76f26510da1c5781f0ff6e6c518eab621ec"
}
}
},
"nbformat": 4,
1 change: 1 addition & 0 deletions docs/gettingstarted/quickstart.ipynb
@@ -71,6 +71,7 @@
"metadata": {},
"outputs": [],
"source": [
"ens.calc_nobs() # calculates number of observations, produces \"nobs_total\" column \n",
"ens = ens.query(\"nobs_total >= 95 & nobs_total <= 105\", \"object\")"
]
},
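With this commit the nobs columns are no longer created automatically, so the quickstart now computes them explicitly before filtering. A minimal sketch of the resulting workflow, assuming a populated Ensemble `ens` (the thresholds are the ones used in the notebook):

```python
# Compute observation counts on demand; adds a "nobs_total" column to
# the object table (temporary by default; see calc_nobs below)
ens.calc_nobs()

# Filter objects on the freshly computed column
ens = ens.query("nobs_total >= 95 & nobs_total <= 105", "object")
```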
44 changes: 22 additions & 22 deletions docs/tutorials/binning_slowly_changing_sources.ipynb
@@ -60,9 +60,9 @@
"outputs": [],
"source": [
"fig, ax = plt.subplots(1, 1)\n",
"_ = ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 500)\n",
"_ = ax.set_xlabel(\"Time (MJD)\")\n",
"_ = ax.set_ylabel(\"Source Count\")"
"ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 500)\n",
"ax.set_xlabel(\"Time (MJD)\")\n",
"ax.set_ylabel(\"Source Count\")"
]
},
{
@@ -90,9 +90,9 @@
"source": [
"ens.bin_sources(time_window=7.0, offset=0.0)\n",
"fig, ax = plt.subplots(1, 1)\n",
"_ = ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 500)\n",
"_ = ax.set_xlabel(\"Time (MJD)\")\n",
"_ = ax.set_ylabel(\"Source Count\")"
"ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 500)\n",
"ax.set_xlabel(\"Time (MJD)\")\n",
"ax.set_ylabel(\"Source Count\")"
]
},
{
@@ -120,9 +120,9 @@
"source": [
"ens.bin_sources(time_window=28.0, offset=0.0, custom_aggr={\"midPointTai\": \"min\"})\n",
"fig, ax = plt.subplots(1, 1)\n",
"_ = ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 500)\n",
"_ = ax.set_xlabel(\"Time (MJD)\")\n",
"_ = ax.set_ylabel(\"Source Count\")"
"ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 500)\n",
"ax.set_xlabel(\"Time (MJD)\")\n",
"ax.set_ylabel(\"Source Count\")"
]
},
{
@@ -150,9 +150,9 @@
"ens.from_source_dict(rows, column_mapper=cmap)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"_ = ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 60)\n",
"_ = ax.set_xlabel(\"Time (MJD)\")\n",
"_ = ax.set_ylabel(\"Source Count\")"
"ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 60)\n",
"ax.set_xlabel(\"Time (MJD)\")\n",
"ax.set_ylabel(\"Source Count\")"
]
},
{
@@ -179,9 +179,9 @@
"ens.bin_sources(time_window=1.0, offset=0.0)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"_ = ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 60)\n",
"_ = ax.set_xlabel(\"Time (MJD)\")\n",
"_ = ax.set_ylabel(\"Source Count\")"
"ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 60)\n",
"ax.set_xlabel(\"Time (MJD)\")\n",
"ax.set_ylabel(\"Source Count\")"
]
},
{
@@ -209,9 +209,9 @@
"ens.bin_sources(time_window=1.0, offset=0.5)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"_ = ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 60)\n",
"_ = ax.set_xlabel(\"Time (MJD)\")\n",
"_ = ax.set_ylabel(\"Source Count\")"
"ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 60)\n",
"ax.set_xlabel(\"Time (MJD)\")\n",
"ax.set_ylabel(\"Source Count\")"
]
},
{
@@ -259,9 +259,9 @@
"ens.bin_sources(time_window=1.0, offset=0.5)\n",
"\n",
"fig, ax = plt.subplots(1, 1)\n",
"_ = ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 60)\n",
"_ = ax.set_xlabel(\"Time (MJD)\")\n",
"_ = ax.set_ylabel(\"Source Count\")"
"ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 60)\n",
"ax.set_xlabel(\"Time (MJD)\")\n",
"ax.set_ylabel(\"Source Count\")"
]
},
{
@@ -290,7 +290,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.6"
},
"vscode": {
"interpreter": {
528 changes: 40 additions & 488 deletions docs/tutorials/structure_function_showcase.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions docs/tutorials/tape_datasets.ipynb
@@ -85,8 +85,7 @@
" flux_col=\"psFlux\",\n",
" err_col=\"psFluxErr\",\n",
" band_col=\"filterName\",\n",
" nobs_total_col=\"nobs_total\",\n",
" nobs_band_cols=[\"nobs_g\", \"nobs_r\"])\n",
")\n",
"\n",
"# Read in data from a parquet file that contains source (timeseries) data\n",
"ens.from_parquet(source_file=f\"{rel_path}/source/test_source.parquet\",\n",
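As this notebook's diff shows, the `ColumnMapper` no longer takes `nobs_total_col` or `nobs_band_cols` arguments (see the `make_column_map` and `_load_column_mapper` changes below). A sketch of the reduced constructor call; the import path and the `id_col`/`time_col` values are assumptions, since they fall outside the rendered hunk:

```python
from tape.utils import ColumnMapper  # import path assumed

colmap = ColumnMapper(
    id_col="ps1_objid",      # assumed id column name
    time_col="midPointTai",  # assumed time column name
    flux_col="psFlux",
    err_col="psFluxErr",
    band_col="filterName",
)
```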
316 changes: 50 additions & 266 deletions docs/tutorials/working_with_the_ensemble.ipynb

Large diffs are not rendered by default.

129 changes: 47 additions & 82 deletions src/tape/ensemble.py
@@ -39,22 +39,20 @@ def __init__(self, client=True, **kwargs):
self._source_dirty = False # Source Dirty Flag
self._object_dirty = False # Object Dirty Flag

self._source_temp = [] # List of temporary columns in Source
self._object_temp = [] # List of temporary columns in Object

# Default to removing empty objects.
self.keep_empty_objects = kwargs.get("keep_empty_objects", False)

# Initialize critical column quantities
# Source
self._id_col = None
self._time_col = None
self._flux_col = None
self._err_col = None
self._band_col = None
self._provenance_col = None

# Object, _id_col is shared
self._nobs_tot_col = None
self._nobs_band_cols = []

self.client = None
self.cleanup_client = False

@@ -510,7 +508,7 @@ def coalesce_partition(df, input_cols, output_col):

return self

def calc_nobs(self, by_band=False, label="nobs"):
def calc_nobs(self, by_band=False, label="nobs", temporary=True):
"""Calculates the number of observations per lightcurve.
Parameters
@@ -521,6 +519,13 @@ def calc_nobs(self, by_band=False, label="nobs"):
label: `str`, optional
The label used to generate output columns. "_total" and the band
labels (e.g. "_g") are appended.
temporary: `bool`, optional
Dictates whether the resulting columns are flagged as "temporary"
columns within the Ensemble. Temporary columns are dropped when
table syncs are performed, as their information is often
invalidated by subsequent operations. For example, observation
counts become invalid once the source table is filtered. Defaults
to True.
Returns
-------
@@ -547,11 +552,17 @@
bands = band_counts.columns.values
self._object = self._object.assign(**{label + "_" + band: band_counts[band] for band in bands})

if temporary:
self._object_temp.extend([label + "_" + band for band in bands])

else:
counts = self._source.groupby([self._id_col])[self._band_col].aggregate("count")
counts = counts.repartition(obj_npartitions) # counts inherits the source partitions
self._object = self._object.assign(**{label + "_total": counts}) # assign new columns

if temporary:
self._object_temp.extend([label + "_total"])

return self
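A usage sketch of the updated method, assuming a loaded Ensemble `ens`; the column labels follow the docstring above:

```python
# Total counts only: adds a "nobs_total" column to the object table.
# The column is temporary by default and is dropped at the next sync.
ens.calc_nobs()

# Per-band counts: adds "nobs_g", "nobs_r", ... for each band present;
# temporary=False keeps the columns through future table syncs.
ens.calc_nobs(by_band=True, temporary=False)
```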

def prune(self, threshold=50, col_name=None):
Expand All @@ -563,19 +574,24 @@ def prune(self, threshold=50, col_name=None):
The minimum number of observations needed to retain an object.
Default is 50.
col_name: `str`, optional
The name of the column to assess the threshold
The name of the object-table column to compare against the
threshold. If not specified, the ensemble calculates the number
of observations and filters on the total (summed across bands).
Returns
-------
ensemble: `tape.ensemble.Ensemble`
The ensemble object with pruned rows removed
"""
if not col_name:
col_name = self._nobs_tot_col

# Sync Required if source is dirty
self._lazy_sync_tables(table="object")

if not col_name:
self.calc_nobs(label="nobs")
col_name = "nobs_total"

# Mask on object table
mask = self._object[col_name] >= threshold
self._object = self._object[mask]
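With this fallback, `prune` no longer requires a precomputed count column; a sketch assuming a loaded Ensemble `ens`:

```python
# No col_name: prune runs calc_nobs internally and filters on "nobs_total"
ens.prune(threshold=50)

# Alternatively, filter on a column computed (and kept) beforehand
ens.calc_nobs(temporary=False)
ens.prune(threshold=100, col_name="nobs_total")
```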
@@ -952,21 +968,9 @@ def from_dask_dataframe(

if object_frame is None: # generate an indexed object table from source
self._object = self._generate_object_table()
self._nobs_bands = [col for col in list(self._object.columns) if col != self._nobs_tot_col]

else:
self._object = object_frame
if self._nobs_band_cols is None:
# sets empty nobs cols in object
unq_filters = np.unique(self._source[self._band_col])
self._nobs_band_cols = [f"nobs_{filt}" for filt in unq_filters]
for col in self._nobs_band_cols:
self._object[col] = np.nan

# Handle nobs_total column
if self._nobs_tot_col is None:
self._object["nobs_total"] = np.nan
self._nobs_tot_col = "nobs_total"

self._object = self._object.set_index(self._id_col)

# Optionally sync the tables, recalculates nobs columns
@@ -1037,8 +1041,6 @@ def make_column_map(self):
err_col=self._err_col,
band_col=self._band_col,
provenance_col=self._provenance_col,
nobs_total_col=self._nobs_tot_col,
nobs_band_cols=self._nobs_band_cols,
)
return result

@@ -1100,10 +1102,6 @@ def _load_column_mapper(self, column_mapper, **kwargs):
# Assign optional columns if provided
if column_mapper.map["provenance_col"] is not None:
self._provenance_col = column_mapper.map["provenance_col"]
if column_mapper.map["nobs_total_col"] is not None:
self._nobs_tot_col = column_mapper.map["nobs_total_col"]
if column_mapper.map["nobs_band_cols"] is not None:
self._nobs_band_cols = column_mapper.map["nobs_band_cols"]

else:
raise ValueError(f"Missing required column mapping information: {needed}")
@@ -1170,11 +1168,6 @@ def from_parquet(
columns = [self._time_col, self._flux_col, self._err_col, self._band_col]
if self._provenance_col is not None:
columns.append(self._provenance_col)
if self._nobs_tot_col is not None:
columns.append(self._nobs_tot_col)
if self._nobs_band_cols is not None:
for col in self._nobs_band_cols:
columns.append(col)

# Read in the source parquet file(s)
source = dd.read_parquet(source_file, index=self._id_col, columns=columns, split_row_groups=True)
@@ -1360,47 +1353,10 @@ def convert_flux_to_mag(self, zero_point, zp_form="mag", out_col_name=None, flux
return self

def _generate_object_table(self):
"""Generate the object table from the source table."""
counts = self._source.groupby([self._id_col, self._band_col])[self._time_col].aggregate("count")
res = (
counts.to_frame()
.reset_index()
.categorize(columns=[self._band_col])
.pivot_table(values=self._time_col, index=self._id_col, columns=self._band_col, aggfunc="sum")
)

# If the ensemble's keep_empty_objects attribute is True and there are previous
# objects, then copy them into the res table with counts of zero.
if self.keep_empty_objects and self._object is not None:
prev_partitions = self._object.npartitions

# Check that there are existing object ids.
object_inds = self._object.index.unique().values.compute()
if len(object_inds) > 0:
# Determine which object IDs are missing from the source table.
source_inds = self._source.index.unique().values.compute()
missing_inds = np.setdiff1d(object_inds, source_inds).tolist()

# Create a dataframe of the missing IDs with zeros for all bands and counts.
rows = {self._id_col: missing_inds}
for i in res.columns.values:
rows[i] = [0] * len(missing_inds)

zero_pdf = pd.DataFrame(rows, dtype=int).set_index(self._id_col)
zero_ddf = dd.from_pandas(zero_pdf, sort=True, npartitions=1)

# Concatenate the zero dataframe onto the results.
res = dd.concat([res, zero_ddf], interleave_partitions=True).astype(int)
res = res.repartition(npartitions=prev_partitions)

# Rename bands to nobs_[band]
band_cols = {col: f"nobs_{col}" for col in list(res.columns)}
res = res.rename(columns=band_cols)

# Add total nobs by summing across each band.
if self._nobs_tot_col is None:
self._nobs_tot_col = "nobs_total"
res[self._nobs_tot_col] = res.sum(axis=1)
"""Generate an empty object table from the source table."""
sor_idx = self._source.index.unique()
obj_df = pd.DataFrame(index=sor_idx)
res = dd.from_pandas(obj_df, npartitions=int(np.ceil(self._source.npartitions / 100)))

return res
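The rewritten generator builds an index-only object frame instead of recomputing per-band counts. A self-contained sketch of the same pattern on toy data; unlike the lazy version above, it materializes the unique ids eagerly for clarity:

```python
import numpy as np
import pandas as pd
import dask.dataframe as dd

# Toy source table indexed by object id, standing in for Ensemble._source
source = dd.from_pandas(
    pd.DataFrame(
        {"midPointTai": [59000.1, 59000.2, 59001.3, 59002.4]},
        index=pd.Index([10, 10, 11, 12], name="ps1_objid"),
    ),
    npartitions=2,
)

# Build an empty (column-less) object frame from the unique source ids,
# with roughly one object partition per 100 source partitions
obj_ids = source.index.unique().compute()
npart = max(1, int(np.ceil(source.npartitions / 100)))
objects = dd.from_pandas(pd.DataFrame(index=obj_ids), npartitions=npart)
```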

@@ -1438,15 +1394,24 @@ def _sync_tables(self):
self._source = self._source.map_partitions(lambda x: x[x.index.isin(obj_idx)])
self._source = self._source.persist() # persist the source frame

# Drop Temporary Source Columns on Sync
if len(self._source_temp):
self._source = self._source.drop(columns=self._source_temp)
print(f"Temporary columns dropped from Source Table: {self._source_temp}")
self._source_temp = []

(Codecov/patch warning: added lines src/tape/ensemble.py#L1399-L1401 were not covered by tests.)

if self._source_dirty: # not elif
# Generate a new object table; updates n_obs, removes missing ids
new_obj = self._generate_object_table()

# Join old obj to new obj; pulls in other existing obj columns
self._object = new_obj.join(self._object, on=self._id_col, how="left", lsuffix="", rsuffix="_old")
old_cols = [col for col in list(self._object.columns) if "_old" in col]
self._object = self._object.drop(old_cols, axis=1)
self._object = self._object.persist() # persist object
if not self.keep_empty_objects:
# Sync Source to Object; remove any objects that do not have sources
sor_idx = list(self._object.index.unique().compute())
self._object = self._object.map_partitions(lambda x: x[x.index.isin(sor_idx)])
self._object = self._object.persist() # persist the object frame

# Drop Temporary Object Columns on Sync
if len(self._object_temp):
self._object = self._object.drop(columns=self._object_temp)
print(f"Temporary columns dropped from Object Table: {self._object_temp}")
self._object_temp = []

# Now synced and clean
self._source_dirty = False
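The temporary-column cleanup added at the end of `_sync_tables` follows a simple register-then-drop pattern. A standalone sketch with plain pandas and illustrative names (not TAPE's API):

```python
import pandas as pd

# Object table with a column registered as temporary when calc_nobs ran
obj = pd.DataFrame({"g_mean": [21.2, 19.8], "nobs_total": [10, 12]})
object_temp = ["nobs_total"]

# On sync: drop the temporary columns (drop returns a new frame, so the
# result must be reassigned), report, then reset the registry
if object_temp:
    obj = obj.drop(columns=object_temp)
    print(f"Temporary columns dropped from Object Table: {object_temp}")
    object_temp = []
```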