diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py
index f0cb4540..9ea7234b 100644
--- a/src/tape/ensemble.py
+++ b/src/tape/ensemble.py
@@ -49,9 +49,6 @@ def __init__(self, client=True, **kwargs):
        """
        self.result = None  # holds the latest query

-        self._source = None  # Source Table
-        self._object = None  # Object Table
-
        self.frames = {}  # Frames managed by this Ensemble, keyed by label

        # A unique ID to allocate new result frame labels.
@@ -63,9 +60,6 @@ def __init__(self, client=True, **kwargs):
        self._source_temp = []  # List of temporary columns in Source
        self._object_temp = []  # List of temporary columns in Object

-        self._source_temp = []  # List of temporary columns in Source
-        self._object_temp = []  # List of temporary columns in Object
-
        # Default to removing empty objects.
        self.keep_empty_objects = kwargs.get("keep_empty_objects", False)

@@ -154,10 +148,8 @@ def update_frame(self, frame):
            if frame.label != expected_label:
                raise ValueError(f"Unable to update frame with reserved label " f"'{frame.label}'")
            if isinstance(frame, SourceFrame):
-                self._source = frame
                self.source = frame
            elif isinstance(frame, ObjectFrame):
-                self._object = frame
                self.object = frame

        # Ensure this frame is assigned to this Ensemble.
@@ -334,20 +326,20 @@ def insert_sources(
        df2 = df2.set_index(self._id_col, drop=True, sort=True)

        # Save the divisions and number of partitions.
-        prev_div = self._source.divisions
-        prev_num = self._source.npartitions
+        prev_div = self.source.divisions
+        prev_num = self.source.npartitions

        # Append the new rows to the correct divisions.
-        self.update_frame(dd.concat([self._source, df2], axis=0, interleave_partitions=True))
-        self._source.set_dirty(True)
+        self.update_frame(dd.concat([self.source, df2], axis=0, interleave_partitions=True))
+        self.source.set_dirty(True)

        # Do the repartitioning if requested. If the divisions were set, reuse them.
        # Otherwise, use the same number of partitions.
        if force_repartition:
            if all(prev_div):
-                self.update_frame(self._source.repartition(divisions=prev_div))
-            elif self._source.npartitions != prev_num:
-                self._source = self._source.repartition(npartitions=prev_num)
+                self.update_frame(self.source.repartition(divisions=prev_div))
+            elif self.source.npartitions != prev_num:
+                self.source = self.source.repartition(npartitions=prev_num)

        return self
@@ -383,9 +375,9 @@ def info(self, verbose=True, memory_usage=True, **kwargs):
        self._lazy_sync_tables(table="all")

        print("Object Table")
-        self._object.info(verbose=verbose, memory_usage=memory_usage, **kwargs)
+        self.object.info(verbose=verbose, memory_usage=memory_usage, **kwargs)
        print("Source Table")
-        self._source.info(verbose=verbose, memory_usage=memory_usage, **kwargs)
+        self.source.info(verbose=verbose, memory_usage=memory_usage, **kwargs)

    def check_sorted(self, table="object"):
        """Checks to see if an Ensemble Dataframe is sorted (increasing) on
@@ -402,9 +394,9 @@ def check_sorted(self, table="object"):
            or not (False)
        """
        if table == "object":
-            idx = self._object.index
+            idx = self.object.index
        elif table == "source":
-            idx = self._source.index
+            idx = self.source.index
        else:
            raise ValueError(f"{table} is not one of 'object' or 'source'")
@@ -428,7 +420,7 @@ def check_lightcurve_cohesion(self):
            across multiple partitions (False)
        """

-        idx = self._source.index
+        idx = self.source.index
        counts = idx.map_partitions(lambda a: Counter(a.unique())).compute()

        unq_counter = counts[0]
@@ -457,12 +449,12 @@ def compute(self, table=None, **kwargs):
        if table:
            self._lazy_sync_tables(table)
            if table == "object":
-                return self._object.compute(**kwargs)
+                return self.object.compute(**kwargs)
            elif table == "source":
-                return self._source.compute(**kwargs)
+                return self.source.compute(**kwargs)
        else:
            self._lazy_sync_tables(table="all")
-            return (self._object.compute(**kwargs), self._source.compute(**kwargs))
+            return (self.object.compute(**kwargs), self.source.compute(**kwargs))

    def persist(self, **kwargs):
        """Wrapper for dask.dataframe.DataFrame.persist()
@@ -473,15 +465,15 @@ def persist(self, **kwargs):
        of the computation.
        """
        self._lazy_sync_tables("all")
-        self.update_frame(self._object.persist(**kwargs))
-        self.update_frame(self._source.persist(**kwargs))
+        self.update_frame(self.object.persist(**kwargs))
+        self.update_frame(self.source.persist(**kwargs))

    def columns(self, table="object"):
        """Retrieve columns from dask dataframe"""
        if table == "object":
-            return self._object.columns
+            return self.object.columns
        elif table == "source":
-            return self._source.columns
+            return self.source.columns
        else:
            raise ValueError(f"{table} is not one of 'object' or 'source'")
@@ -490,9 +482,9 @@ def head(self, table="object", n=5, **kwargs):
        self._lazy_sync_tables(table)

        if table == "object":
-            return self._object.head(n=n, **kwargs)
+            return self.object.head(n=n, **kwargs)
        elif table == "source":
-            return self._source.head(n=n, **kwargs)
+            return self.source.head(n=n, **kwargs)
        else:
            raise ValueError(f"{table} is not one of 'object' or 'source'")
@@ -501,9 +493,9 @@ def tail(self, table="object", n=5, **kwargs):
        self._lazy_sync_tables(table)

        if table == "object":
-            return self._object.tail(n=n, **kwargs)
+            return self.object.tail(n=n, **kwargs)
        elif table == "source":
-            return self._source.tail(n=n, **kwargs)
+            return self.source.tail(n=n, **kwargs)
        else:
            raise ValueError(f"{table} is not one of 'object' or 'source'")
@@ -526,9 +518,9 @@ def dropna(self, table="source", **kwargs):
            scheme
        """
        if table == "object":
-            self.update_frame(self._object.dropna(**kwargs))
+            self.update_frame(self.object.dropna(**kwargs))
        elif table == "source":
-            self.update_frame(self._source.dropna(**kwargs))
+            self.update_frame(self.source.dropna(**kwargs))
        else:
            raise ValueError(f"{table} is not one of 'object' or 'source'")
@@ -548,11 +540,11 @@ def select(self, columns, table="object"):
        """
        self._lazy_sync_tables(table)
        if table == "object":
-            cols_to_drop = [col for col in self._object.columns if col not in columns]
-            self.update_frame(self._object.drop(cols_to_drop, axis=1))
+            cols_to_drop = [col for col in self.object.columns if col not in columns]
+            self.update_frame(self.object.drop(cols_to_drop, axis=1))
        elif table == "source":
-            cols_to_drop = [col for col in self._source.columns if col not in columns]
-            self.update_frame(self._source.drop(cols_to_drop, axis=1))
+            cols_to_drop = [col for col in self.source.columns if col not in columns]
+            self.update_frame(self.source.drop(cols_to_drop, axis=1))
        else:
            raise ValueError(f"{table} is not one of 'object' or 'source'")
@@ -581,9 +573,9 @@ def query(self, expr, table="object"):
        """
        self._lazy_sync_tables(table)
        if table == "object":
-            self.update_frame(self._object.query(expr))
+            self.update_frame(self.object.query(expr))
        elif table == "source":
-            self.update_frame(self._source.query(expr))
+            self.update_frame(self.source.query(expr))
        return self

    def filter_from_series(self, keep_series, table="object"):
@@ -601,10 +593,10 @@ def filter_from_series(self, keep_series, table="object"):
        """
        self._lazy_sync_tables(table)
        if table == "object":
-            self.update_frame(self._object[keep_series])
+            self.update_frame(self.object[keep_series])

        elif table == "source":
-            self.update_frame(self._source[keep_series])
+            self.update_frame(self.source[keep_series])
        return self

    def assign(self, table="object", temporary=False, **kwargs):
@@ -642,17 +634,17 @@ def assign(self, table="object", temporary=False, **kwargs):
        self._lazy_sync_tables(table)

        if table == "object":
-            pre_cols = self._object.columns
-            self.update_frame(self._object.assign(**kwargs))
-            post_cols = self._object.columns
+            pre_cols = self.object.columns
+            self.update_frame(self.object.assign(**kwargs))
+            post_cols = self.object.columns

            if temporary:
                self._object_temp.extend(col for col in post_cols if col not in pre_cols)

        elif table == "source":
-            pre_cols = self._source.columns
-            self.update_frame(self._source.assign(**kwargs))
-            post_cols = self._source.columns
+            pre_cols = self.source.columns
+            self.update_frame(self.source.assign(**kwargs))
+            post_cols = self.source.columns

            if temporary:
                self._source_temp.extend(col for col in post_cols if col not in pre_cols)
@@ -687,9 +679,9 @@ def coalesce(self, input_cols, output_col, table="object", drop_inputs=False):
        """
        # we shouldn't need to sync for this
        if table == "object":
-            table_ddf = self._object
+            table_ddf = self.object
        elif table == "source":
-            table_ddf = self._source
+            table_ddf = self.source
        else:
            raise ValueError(f"{table} is not one of 'object' or 'source'")
@@ -777,27 +769,27 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True):
        if by_band:
            # repartition the result to align with object
-            if self._object.known_divisions:
+            if self.object.known_divisions:
                # Grab these up front to help out the task graph
                id_col = self._id_col
                band_col = self._band_col

                # Get the band metadata
-                unq_bands = np.unique(self._source[band_col])
+                unq_bands = np.unique(self.source[band_col])
                meta = {band: float for band in unq_bands}

                # Map the groupby to each partition
-                band_counts = self._source.map_partitions(
+                band_counts = self.source.map_partitions(
                    lambda x: x.groupby(id_col)[[band_col]]
                    .value_counts()
                    .to_frame()
                    .reset_index()
                    .pivot_table(values=band_col, index=id_col, columns=band_col, aggfunc="sum"),
                    meta=meta,
-                ).repartition(divisions=self._object.divisions)
+                ).repartition(divisions=self.object.divisions)
            else:
                band_counts = (
-                    self._source.groupby([self._id_col])[self._band_col]  # group by each object
+                    self.source.groupby([self._id_col])[self._band_col]  # group by each object
                    .value_counts()  # count occurence of each band
                    .to_frame()  # convert series to dataframe
                    .rename(columns={self._band_col: "counts"})  # rename column
@@ -808,13 +800,13 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True):
                    )
                )  # the pivot_table call makes each band_count a column of the id_col row
-                band_counts = band_counts.repartition(npartitions=self._object.npartitions)
+                band_counts = band_counts.repartition(npartitions=self.object.npartitions)

            # short-hand for calculating nobs_total
            band_counts["total"] = band_counts[list(band_counts.columns)].sum(axis=1)

            bands = band_counts.columns.values
-            self._object = self._object.assign(
+            self.object = self.object.assign(
                **{label + "_" + str(band): band_counts[band] for band in bands}
            )
@@ -822,24 +814,24 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True):
                self._object_temp.extend(label + "_" + str(band) for band in bands)

        else:
-            if self._object.known_divisions and self._source.known_divisions:
+            if self.object.known_divisions and self.source.known_divisions:
                # Grab these up front to help out the task graph
                id_col = self._id_col
                band_col = self._band_col

                # Map the groupby to each partition
-                counts = self._source.map_partitions(
+                counts = self.source.map_partitions(
                    lambda x: x.groupby([id_col])[[band_col]].aggregate("count")
-                ).repartition(divisions=self._object.divisions)
+                ).repartition(divisions=self.object.divisions)
            else:
                # Just do a groupby on all source
                counts = (
-                    self._source.groupby([self._id_col])[[self._band_col]]
+                    self.source.groupby([self._id_col])[[self._band_col]]
                    .aggregate("count")
-                    .repartition(npartitions=self._object.npartitions)
+                    .repartition(npartitions=self.object.npartitions)
                )

-            self._object = self._object.assign(**{label + "_total": counts[self._band_col]})
+            self.object = self.object.assign(**{label + "_total": counts[self._band_col]})

            if temporary:
                self._object_temp.extend([label + "_total"])
@@ -876,7 +868,7 @@ def prune(self, threshold=50, col_name=None):
        # Mask on object table
        self = self.query(f"{col_name} >= {threshold}", table="object")

-        self._object.set_dirty(True)  # Object table is now dirty
+        self.object.set_dirty(True)  # Object table is now dirty

        return self
@@ -902,7 +894,7 @@ def find_day_gap_offset(self):
        self._lazy_sync_tables(table="source")

        # Compute a histogram of observations by hour of the day.
-        hours = self._source[self._time_col].apply(
+        hours = self.source[self._time_col].apply(
            lambda x: np.floor(x * 24.0).astype(int) % 24, meta=pd.Series(dtype=int)
        )
        hour_counts = hours.value_counts().compute()
@@ -978,9 +970,9 @@ def bin_sources(
        # Bin the time and add it as a column. We create a temporary column that
        # truncates the time into increments of `time_window`.
        tmp_time_col = "tmp_time_for_aggregation"
-        if tmp_time_col in self._source.columns:
+        if tmp_time_col in self.source.columns:
            raise KeyError(f"Column '{tmp_time_col}' already exists in source table.")
-        self._source[tmp_time_col] = self._source[self._time_col].apply(
+        self.source[tmp_time_col] = self.source[self._time_col].apply(
            lambda x: np.floor((x + offset) / time_window) * time_window, meta=pd.Series(dtype=float)
        )
@@ -988,7 +980,7 @@ def bin_sources(
        aggr_funs = {self._time_col: "mean", self._flux_col: "mean"}

        # If the source table has errors then add an aggregation function for it.
-        if self._err_col in self._source.columns:
+        if self._err_col in self.source.columns:
            aggr_funs[self._err_col] = dd.Aggregation(
                name="err_agg",
                chunk=lambda x: (x.count(), x.apply(lambda s: np.sum(np.power(s, 2)))),
@@ -1000,8 +992,8 @@ def bin_sources(
        # adding an initial column of all ones if needed.
        if count_col is not None:
            self._bin_count_col = count_col
-            if self._bin_count_col not in self._source.columns:
-                self._source[self._bin_count_col] = self._source[self._time_col].apply(
+            if self._bin_count_col not in self.source.columns:
+                self.source[self._bin_count_col] = self.source[self._time_col].apply(
                    lambda x: 1, meta=pd.Series(dtype=int)
                )
            aggr_funs[self._bin_count_col] = "sum"
@@ -1016,14 +1008,14 @@ def bin_sources(
        # Group the columns by id, band, and time bucket and aggregate.
        self.update_frame(
-            self._source.groupby([self._id_col, self._band_col, tmp_time_col]).aggregate(aggr_funs)
+            self.source.groupby([self._id_col, self._band_col, tmp_time_col]).aggregate(aggr_funs)
        )

        # Fix the indices and remove the temporary column.
-        self.update_frame(self._source.reset_index().set_index(self._id_col).drop(tmp_time_col, axis=1))
+        self.update_frame(self.source.reset_index().set_index(self._id_col).drop(tmp_time_col, axis=1))

        # Mark the source table as dirty.
-        self._source.set_dirty(True)
+        self.source.set_dirty(True)

        return self

    def batch(self, func, *args, meta=None, use_map=True, compute=True, on=None, label="", **kwargs):
@@ -1129,15 +1121,15 @@ def s2n_inter_quartile_range(flux, err):
            on = [on]  # Convert to list if only one column is passed

        # Handle object columns to group on
-        source_cols = list(self._source.columns)
-        object_cols = list(self._object.columns)
+        source_cols = list(self.source.columns)
+        object_cols = list(self.object.columns)
        object_group_cols = [col for col in on if (col in object_cols) and (col not in source_cols)]

        if len(object_group_cols) > 0:
-            object_col_dd = self._object[object_group_cols]
-            source_to_batch = self._source.merge(object_col_dd, how="left")
+            object_col_dd = self.object[object_group_cols]
+            source_to_batch = self.source.merge(object_col_dd, how="left")
        else:
-            source_to_batch = self._source  # Can directly use the source table
+            source_to_batch = self.source  # Can directly use the source table

        id_col = self._id_col  # pre-compute needed for dask in lambda function
@@ -1162,8 +1154,8 @@ def s2n_inter_quartile_range(flux, err):

        # Inherit divisions if known from source and the resulting index is the id
        # Groupby on index should always return a subset that adheres to the same divisions criteria
-        if self._source.known_divisions and batch.index.name == self._id_col:
-            batch.divisions = self._source.divisions
+        if self.source.known_divisions and batch.index.name == self._id_col:
+            batch.divisions = self.source.divisions

        if label is not None:
            if label == "":
@@ -1285,21 +1277,21 @@ def from_dask_dataframe(
        else:
            self.update_frame(ObjectFrame.from_dask_dataframe(object_frame, ensemble=self))

-        self.update_frame(self._object.set_index(self._id_col, sorted=sorted, sort=sort))
+        self.update_frame(self.object.set_index(self._id_col, sorted=sorted, sort=sort))

        # Optionally sync the tables, recalculates nobs columns
        if sync_tables:
-            self._source.set_dirty(True)
-            self._object.set_dirty(True)
+            self.source.set_dirty(True)
+            self.object.set_dirty(True)
            self._sync_tables()

        if npartitions and npartitions > 1:
-            self._source = self._source.repartition(npartitions=npartitions)
+            self.source = self.source.repartition(npartitions=npartitions)
        elif partition_size:
-            self._source = self._source.repartition(partition_size=partition_size)
+            self.source = self.source.repartition(partition_size=partition_size)

        # Check that Divisions are established, warn if not.
-        for name, table in [("object", self._object), ("source", self._source)]:
+        for name, table in [("object", self.object), ("source", self.source)]:
            if not table.known_divisions:
                warnings.warn(
                    f"Divisions for {name} are not set, certain downstream dask operations may fail as a result. We recommend setting the `sort` or `sorted` flags when loading data to establish division information."
@@ -1670,25 +1662,25 @@ def convert_flux_to_mag(self, zero_point, zp_form="mag", out_col_name=None, flux
        if zp_form == "flux":  # mag = -2.5*np.log10(flux/zp)
            if isinstance(zero_point, str):
                self.update_frame(
-                    self._source.assign(
+                    self.source.assign(
                        **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / x[zero_point])}
                    )
                )
            else:
                self.update_frame(
-                    self._source.assign(**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / zero_point)})
+                    self.source.assign(**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / zero_point)})
                )

        elif zp_form == "magnitude" or zp_form == "mag":  # mag = -2.5*np.log10(flux) + zp
            if isinstance(zero_point, str):
                self.update_frame(
-                    self._source.assign(
+                    self.source.assign(
                        **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + x[zero_point]}
                    )
                )
            else:
                self.update_frame(
-                    self._source.assign(**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + zero_point})
+                    self.source.assign(**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + zero_point})
                )
        else:
            raise ValueError(f"{zp_form} is not a valid zero_point format.")
@@ -1696,7 +1688,7 @@ def convert_flux_to_mag(self, zero_point, zp_form="mag", out_col_name=None, flux
        # Calculate Errors
        if err_col is not None:
            self.update_frame(
-                self._source.assign(
+                self.source.assign(
                    **{out_col_name + "_err": lambda x: (2.5 / np.log(10)) * (x[err_col] / x[flux_col])}
                )
            )
@@ -1705,7 +1697,7 @@ def convert_flux_to_mag(self, zero_point, zp_form="mag", out_col_name=None, flux

    def _generate_object_table(self):
        """Generate an empty object table from the source table."""
-        res = self._source.map_partitions(lambda x: TapeObjectFrame(index=x.index.unique()))
+        res = self.source.map_partitions(lambda x: TapeObjectFrame(index=x.index.unique()))

        return res
@@ -1740,11 +1732,11 @@ def _lazy_sync_tables(self, table="object"):

            The table being modified. Should be one of "object", "source", or "all"
        """
-        if table == "object" and self._source.is_dirty():  # object table should be updated
+        if table == "object" and self.source.is_dirty():  # object table should be updated
            self._sync_tables()
-        elif table == "source" and self._object.is_dirty():  # source table should be updated
+        elif table == "source" and self.object.is_dirty():  # source table should be updated
            self._sync_tables()
-        elif table == "all" and (self._source.is_dirty() or self._object.is_dirty()):
+        elif table == "all" and (self.source.is_dirty() or self.object.is_dirty()):
            self._sync_tables()

        return self
@@ -1756,55 +1748,55 @@ def _sync_tables(self):
        keep_empty_objects attribute is set to True.
        """
-        if self._object.is_dirty():
+        if self.object.is_dirty():
            # Sync Object to Source; remove any missing objects from source
-            if self._object.known_divisions and self._source.known_divisions:
+            if self.object.known_divisions and self.source.known_divisions:
                # Lazily Create an empty object table (just index) for joining
-                empty_obj = self._object.map_partitions(lambda x: TapeObjectFrame(index=x.index))
-                if type(empty_obj) != type(self._object):
+                empty_obj = self.object.map_partitions(lambda x: TapeObjectFrame(index=x.index))
+                if type(empty_obj) != type(self.object):
                    raise ValueError("Bad type for empty_obj: " + str(type(empty_obj)))

                # Join source onto the empty object table to align
-                self.update_frame(self._source.join(empty_obj, how="inner"))
+                self.update_frame(self.source.join(empty_obj, how="inner"))
            else:
                warnings.warn("Divisions are not known, syncing using a non-lazy method.")
-                obj_idx = list(self._object.index.compute())
-                self.update_frame(self._source.map_partitions(lambda x: x[x.index.isin(obj_idx)]))
-                self.update_frame(self._source.persist())  # persist the source frame
+                obj_idx = list(self.object.index.compute())
+                self.update_frame(self.source.map_partitions(lambda x: x[x.index.isin(obj_idx)]))
+                self.update_frame(self.source.persist())  # persist the source frame

            # Drop Temporary Source Columns on Sync
            if len(self._source_temp):
-                self.update_frame(self._source.drop(columns=self._source_temp))
+                self.update_frame(self.source.drop(columns=self._source_temp))
                print(f"Temporary columns dropped from Source Table: {self._source_temp}")
                self._source_temp = []

-        if self._source.is_dirty():  # not elif
+        if self.source.is_dirty():  # not elif
            if not self.keep_empty_objects:
-                if self._object.known_divisions and self._source.known_divisions:
+                if self.object.known_divisions and self.source.known_divisions:
                    # Lazily Create an empty source table (just unique indexes) for joining
-                    empty_src = self._source.map_partitions(lambda x: TapeSourceFrame(index=x.index.unique()))
-                    if type(empty_src) != type(self._source):
+                    empty_src = self.source.map_partitions(lambda x: TapeSourceFrame(index=x.index.unique()))
+                    if type(empty_src) != type(self.source):
                        raise ValueError("Bad type for empty_src: " + str(type(empty_src)))

                    # Join object onto the empty unique source table to align
-                    self.update_frame(self._object.join(empty_src, how="inner"))
+                    self.update_frame(self.object.join(empty_src, how="inner"))
                else:
                    warnings.warn("Divisions are not known, syncing using a non-lazy method.")
                    # Sync Source to Object; remove any objects that do not have sources
-                    sor_idx = list(self._source.index.unique().compute())
-                    self.update_frame(self._object.map_partitions(lambda x: x[x.index.isin(sor_idx)]))
-                    self.update_frame(self._object.persist())  # persist the object frame
+                    sor_idx = list(self.source.index.unique().compute())
+                    self.update_frame(self.object.map_partitions(lambda x: x[x.index.isin(sor_idx)]))
+                    self.update_frame(self.object.persist())  # persist the object frame

            # Drop Temporary Object Columns on Sync
            if len(self._object_temp):
-                self.update_frame(self._object.drop(columns=self._object_temp))
+                self.update_frame(self.object.drop(columns=self._object_temp))
                print(f"Temporary columns dropped from Object Table: {self._object_temp}")
                self._object_temp = []

        # Now synced and clean
-        self._source.set_dirty(False)
-        self._object.set_dirty(False)
+        self.source.set_dirty(False)
+        self.object.set_dirty(False)

        return self

    def to_timeseries(
@@ -1857,7 +1849,7 @@ def to_timeseries(
        if band_col is None:
            band_col = self._band_col

-        df = self._source.loc[target].compute()
+        df = self.source.loc[target].compute()
        ts = TimeSeries().from_dataframe(
            data=df,
            object_id=target,
@@ -1929,11 +1921,11 @@ def sf2(self, sf_method="basic", argument_container=None, use_map=True, compute=

        if argument_container.combine:
            result = calc_sf2(
-                self._source[self._time_col],
-                self._source[self._flux_col],
-                self._source[self._err_col],
-                self._source[self._band_col],
-                self._source.index,
+                self.source[self._time_col],
+                self.source[self._flux_col],
+                self.source[self._err_col],
+                self.source[self._band_col],
+                self.source.index,
                argument_container=argument_container,
            )
@@ -1943,8 +1935,8 @@ def sf2(self, sf_method="basic", argument_container=None, use_map=True, compute=
            )

            # Inherit divisions information if known
-            if self._source.known_divisions and self._object.known_divisions:
-                result.divisions = self._source.divisions
+            if self.source.known_divisions and self.object.known_divisions:
+                result.divisions = self.source.divisions

            return result
diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py
index f0118b5c..ef61d7e0 100644
--- a/tests/tape_tests/test_ensemble.py
+++ b/tests/tape_tests/test_ensemble.py
@@ -59,13 +59,13 @@ def test_parquet_construction(data_fixture, request):
    parquet_ensemble = request.getfixturevalue(data_fixture)

    # Check to make sure the source and object tables were created
-    assert parquet_ensemble._source is not None
-    assert parquet_ensemble._object is not None
+    assert parquet_ensemble.source is not None
+    assert parquet_ensemble.object is not None

    # Make sure divisions are set
    if data_fixture == "parquet_ensemble_with_divisions":
-        assert parquet_ensemble._source.known_divisions
-        assert parquet_ensemble._object.known_divisions
+        assert parquet_ensemble.source.known_divisions
+        assert parquet_ensemble.object.known_divisions

    # Check that the data is not empty.
    obj, source = parquet_ensemble.compute()
@@ -84,7 +84,7 @@ def test_parquet_construction(data_fixture, request):
        parquet_ensemble._provenance_col,
    ]:
        # Check to make sure the critical quantity labels are bound to real columns
-        assert parquet_ensemble._source[col] is not None
+        assert parquet_ensemble.source[col] is not None


@pytest.mark.parametrize(
@@ -107,8 +107,8 @@ def test_dataframe_constructors(data_fixture, request):
    ens = request.getfixturevalue(data_fixture)

    # Check to make sure the source and object tables were created
-    assert ens._source is not None
-    assert ens._object is not None
+    assert ens.source is not None
+    assert ens.object is not None

    # Check that the data is not empty.
    obj, source = ens.compute()
@@ -126,7 +126,7 @@ def test_dataframe_constructors(data_fixture, request):
        ens._band_col,
    ]:
        # Check to make sure the critical quantity labels are bound to real columns
-        assert ens._source[col] is not None
+        assert ens.source[col] is not None

    # Check that we can compute an analysis function on the ensemble.
    amplitude = ens.batch(calc_stetson_J)
@@ -146,33 +146,33 @@ def test_update_ensemble(data_fixture, request):
    ens = request.getfixturevalue(data_fixture)

    # Filter the object table and have the ensemble track the updated table.
-    updated_obj = ens._object.query("nobs_total > 50")
-    assert updated_obj is not ens._object
+    updated_obj = ens.object.query("nobs_total > 50")
+    assert updated_obj is not ens.object
    assert updated_obj.is_dirty()

    # Update the ensemble and validate that it marks the object table dirty
-    assert ens._object.is_dirty() == False
+    assert ens.object.is_dirty() == False
    updated_obj.update_ensemble()
-    assert ens._object.is_dirty() == True
-    assert updated_obj is ens._object
+    assert ens.object.is_dirty() == True
+    assert updated_obj is ens.object

    # Filter the source table and have the ensemble track the updated table.
-    updated_src = ens._source.query("psFluxErr > 0.1")
-    assert updated_src is not ens._source
+    updated_src = ens.source.query("psFluxErr > 0.1")
+    assert updated_src is not ens.source

    # Update the ensemble and validate that it marks the source table dirty
-    assert ens._source.is_dirty() == False
+    assert ens.source.is_dirty() == False
    updated_src.update_ensemble()
-    assert ens._source.is_dirty() == True
-    assert updated_src is ens._source
+    assert ens.source.is_dirty() == True
+    assert updated_src is ens.source

    # Compute a result to trigger a table sync
    obj, src = ens.compute()
    assert len(obj) > 0
    assert len(src) > 0
-    assert ens._object.is_dirty() == False
-    assert ens._source.is_dirty() == False
+    assert ens.object.is_dirty() == False
+    assert ens.source.is_dirty() == False

    # Create an additional result table for the ensemble to track.
-    cnts = ens._source.groupby([ens._id_col, ens._band_col])[ens._time_col].aggregate("count")
+    cnts = ens.source.groupby([ens._id_col, ens._band_col])[ens._time_col].aggregate("count")
    res = (
        cnts.to_frame()
        .reset_index()
@@ -464,7 +464,7 @@ def test_read_source_dict(dask_client):


def test_insert(parquet_ensemble):
-    num_partitions = parquet_ensemble._source.npartitions
+    num_partitions = parquet_ensemble.source.npartitions
    (old_object, old_source) = parquet_ensemble.compute()
    old_size = old_source.shape[0]
@@ -486,7 +486,7 @@ def test_insert(parquet_ensemble):
    )

    # Check we did not increase the number of partitions.
-    assert parquet_ensemble._source.npartitions == num_partitions
+    assert parquet_ensemble.source.npartitions == num_partitions

    # Check that all the new data points are in there. The order may be different
    # due to the repartitioning.
@@ -515,7 +515,7 @@ def test_insert(parquet_ensemble):
    )

    # Check we *did* increase the number of partitions and the size increased.
-    assert parquet_ensemble._source.npartitions != num_partitions
+    assert parquet_ensemble.source.npartitions != num_partitions
    (new_obj, new_source) = parquet_ensemble.compute()
    assert new_source.shape[0] == old_size + 10
@@ -544,8 +544,8 @@ def test_insert_paritioned(dask_client):

    # Save the old data for comparison.
    old_data = ens.compute("source")
-    old_div = copy.copy(ens._source.divisions)
-    old_sizes = [len(ens._source.partitions[i]) for i in range(4)]
+    old_div = copy.copy(ens.source.divisions)
+    old_sizes = [len(ens.source.partitions[i]) for i in range(4)]
    assert old_data.shape[0] == num_points

    # Test an insertion of 5 observations.
@@ -558,12 +558,12 @@ def test_insert_paritioned(dask_client):

    # Check we did not increase the number of partitions and the points
    # were placed in the correct partitions.
-    assert ens._source.npartitions == 4
-    assert ens._source.divisions == old_div
-    assert len(ens._source.partitions[0]) == old_sizes[0] + 3
-    assert len(ens._source.partitions[1]) == old_sizes[1]
-    assert len(ens._source.partitions[2]) == old_sizes[2] + 2
-    assert len(ens._source.partitions[3]) == old_sizes[3]
+    assert ens.source.npartitions == 4
+    assert ens.source.divisions == old_div
+    assert len(ens.source.partitions[0]) == old_sizes[0] + 3
+    assert len(ens.source.partitions[1]) == old_sizes[1]
+    assert len(ens.source.partitions[2]) == old_sizes[2] + 2
+    assert len(ens.source.partitions[3]) == old_sizes[3]

    # Check that all the new data points are in there. The order may be different
    # due to the repartitioning.
@@ -581,12 +581,12 @@ def test_insert_paritioned(dask_client):

    # Check we did not increase the number of partitions and the points
    # were placed in the correct partitions.
-    assert ens._source.npartitions == 4
-    assert ens._source.divisions == old_div
-    assert len(ens._source.partitions[0]) == old_sizes[0] + 3
-    assert len(ens._source.partitions[1]) == old_sizes[1] + 5
-    assert len(ens._source.partitions[2]) == old_sizes[2] + 2
-    assert len(ens._source.partitions[3]) == old_sizes[3]
+    assert ens.source.npartitions == 4
+    assert ens.source.divisions == old_div
+    assert len(ens.source.partitions[0]) == old_sizes[0] + 3
+    assert len(ens.source.partitions[1]) == old_sizes[1] + 5
+    assert len(ens.source.partitions[2]) == old_sizes[2] + 2
+    assert len(ens.source.partitions[3]) == old_sizes[3]


def test_core_wrappers(parquet_ensemble):
@@ -677,9 +677,9 @@ def test_persist(dask_client):
    ens.query("flux <= 1.5", table="source")

    # Compute the task graph size before and after the persist.
-    old_graph_size = len(ens._source.dask)
+    old_graph_size = len(ens.source.dask)
    ens.persist()
-    new_graph_size = len(ens._source.dask)
+    new_graph_size = len(ens.source.dask)
    assert new_graph_size < old_graph_size
@@ -782,7 +782,7 @@ def test_sync_tables(data_fixture, request, legacy):
    filtered_src.update_ensemble()

    # Verify that the object ID we removed from the source table is present in the object table
-    assert dropped_obj_id in parquet_ensemble._object.index.compute().values
+    assert dropped_obj_id in parquet_ensemble.object.index.compute().values

    # Perform an operation which should trigger syncing both tables.
    parquet_ensemble.compute()
@@ -824,8 +824,8 @@ def test_lazy_sync_tables(parquet_ensemble, legacy):

    # Modify only the object table.
    parquet_ensemble.prune(50, col_name="nobs_r").prune(50, col_name="nobs_g")
-    assert parquet_ensemble._object.is_dirty()
-    assert not parquet_ensemble._source.is_dirty()
+    assert parquet_ensemble.object.is_dirty()
+    assert not parquet_ensemble.source.is_dirty()

    # For a lazy sync on the object table, nothing should change, because
    # it is already dirty.
@@ -833,34 +833,34 @@ def test_lazy_sync_tables(parquet_ensemble, legacy):
        parquet_ensemble.compute("object")
    else:
        parquet_ensemble.object.compute()
-    assert parquet_ensemble._object.is_dirty()
-    assert not parquet_ensemble._source.is_dirty()
+    assert parquet_ensemble.object.is_dirty()
+    assert not parquet_ensemble.source.is_dirty()

    # For a lazy sync on the source table, the source table should be updated.
    if legacy:
        parquet_ensemble.compute("source")
    else:
        parquet_ensemble.source.compute()
-    assert not parquet_ensemble._object.is_dirty()
-    assert not parquet_ensemble._source.is_dirty()
+    assert not parquet_ensemble.object.is_dirty()
+    assert not parquet_ensemble.source.is_dirty()

    # Modify only the source table.
    # Replace the maximum flux value with a NaN so that we will have a row to drop.
-    max_flux = max(parquet_ensemble._source[parquet_ensemble._flux_col])
-    parquet_ensemble._source[parquet_ensemble._flux_col] = parquet_ensemble._source[
+    max_flux = max(parquet_ensemble.source[parquet_ensemble._flux_col])
+    parquet_ensemble.source[parquet_ensemble._flux_col] = parquet_ensemble.source[
        parquet_ensemble._flux_col].apply(
        lambda x: np.nan if x == max_flux else x, meta=pd.Series(dtype=float)
    )

-    assert not parquet_ensemble._object.is_dirty()
-    assert not parquet_ensemble._source.is_dirty()
+    assert not parquet_ensemble.object.is_dirty()
+    assert not parquet_ensemble.source.is_dirty()

    if legacy:
        parquet_ensemble.dropna(table="source")
    else:
        parquet_ensemble.source.dropna().update_ensemble()
-    assert not parquet_ensemble._object.is_dirty()
-    assert parquet_ensemble._source.is_dirty()
+    assert not parquet_ensemble.object.is_dirty()
+    assert parquet_ensemble.source.is_dirty()

    # For a lazy sync on the source table, nothing should change, because
    # it is already dirty.
@@ -868,16 +868,16 @@ def test_lazy_sync_tables(parquet_ensemble, legacy):
        parquet_ensemble.compute("source")
    else:
        parquet_ensemble.source.compute()
-    assert not parquet_ensemble._object.is_dirty()
-    assert parquet_ensemble._source.is_dirty()
+    assert not parquet_ensemble.object.is_dirty()
+    assert parquet_ensemble.source.is_dirty()

    # For a lazy sync on the source, the object table should be updated.
    if legacy:
        parquet_ensemble.compute("object")
    else:
        parquet_ensemble.object.compute()
-    assert not parquet_ensemble._object.is_dirty()
-    assert not parquet_ensemble._source.is_dirty()
+    assert not parquet_ensemble.object.is_dirty()
+    assert not parquet_ensemble.source.is_dirty()


def test_compute_triggers_syncing(parquet_ensemble):
@@ -931,7 +931,7 @@ def test_temporary_cols(parquet_ensemble):
    """
    ens = parquet_ensemble

-    ens.update_frame(ens._object.drop(columns=["nobs_r", "nobs_g", "nobs_total"]))
+    ens.update_frame(ens.object.drop(columns=["nobs_r", "nobs_g", "nobs_total"]))

    # Make sure temp lists are available but empty
    assert not len(ens._source_temp)
@@ -941,29 +941,29 @@ def test_temporary_cols(parquet_ensemble):

    # nobs_total should be a temporary column
    assert "nobs_total" in ens._object_temp
-    assert "nobs_total" in ens._object.columns
+    assert "nobs_total" in ens.object.columns

    ens.assign(nobs2=lambda x: x["nobs_total"] * 2, table="object", temporary=True)

    # nobs2 should be a temporary column
    assert "nobs2" in ens._object_temp
-    assert "nobs2" in ens._object.columns
+    assert "nobs2" in ens.object.columns

    # drop NaNs from source, source should be dirty now
    ens.dropna(how="any", table="source")

-    assert ens._source.is_dirty()
+    assert ens.source.is_dirty()

    # try a sync
    ens._sync_tables()

    # nobs_total should be removed from object
    assert "nobs_total" not in ens._object_temp
-    assert "nobs_total" not in ens._object.columns
+    assert "nobs_total" not in ens.object.columns

    # nobs2 should be removed from object
    assert "nobs2" not in ens._object_temp
-    assert "nobs2" not in ens._object.columns
+    assert "nobs2" not in ens.object.columns

    # add a source column that we manually set as dirty, don't have a function
    # that adds temporary source columns at the moment
@@ -972,14 +972,14 @@ def test_temporary_cols(parquet_ensemble):

    # prune object, object should be dirty
    ens.prune(threshold=10)
-    assert ens._object.is_dirty()
+    assert ens.object.is_dirty()

    # try a sync
    ens._sync_tables()

    # f2 should be removed from source
    assert "f2" not in ens._source_temp
-    assert "f2" not in ens._source.columns
+    assert "f2" not in ens.source.columns


def test_temporary_cols(parquet_ensemble):
@@ -988,7 +988,7 @@ def test_temporary_cols(parquet_ensemble):
    """
    ens = parquet_ensemble

-    ens._object = ens._object.drop(columns=["nobs_r", "nobs_g", "nobs_total"])
+    ens.object = ens.object.drop(columns=["nobs_r", "nobs_g", "nobs_total"])

    # Make sure temp lists are available but empty
    assert not len(ens._source_temp)
@@ -998,17 +998,17 @@ def test_temporary_cols(parquet_ensemble):

    # nobs_total should be a temporary column
    assert "nobs_total" in ens._object_temp
-    assert "nobs_total" in ens._object.columns
+    assert "nobs_total" in ens.object.columns

    ens.assign(nobs2=lambda x: x["nobs_total"] * 2, table="object", temporary=True)

    # nobs2 should be a temporary column
    assert "nobs2" in ens._object_temp
-    assert "nobs2" in ens._object.columns
+    assert "nobs2" in ens.object.columns

    # Replace the maximum flux value with a NaN so that we will have a row to drop.
-    max_flux = max(parquet_ensemble._source[parquet_ensemble._flux_col])
-    parquet_ensemble._source[parquet_ensemble._flux_col] = parquet_ensemble._source[
+    max_flux = max(parquet_ensemble.source[parquet_ensemble._flux_col])
+    parquet_ensemble.source[parquet_ensemble._flux_col] = parquet_ensemble.source[
        parquet_ensemble._flux_col].apply(
        lambda x: np.nan if x == max_flux else x, meta=pd.Series(dtype=float)
    )
@@ -1016,18 +1016,18 @@ def test_temporary_cols(parquet_ensemble):

    # drop NaNs from source, source should be dirty now
    ens.dropna(how="any", table="source")

-    assert ens._source.is_dirty()
+    assert ens.source.is_dirty()

    # try a sync
    ens._sync_tables()

    # nobs_total should be removed from object
    assert "nobs_total" not in ens._object_temp
-    assert "nobs_total" not in ens._object.columns
+    assert "nobs_total" not in ens.object.columns

    # nobs2 should be removed from object
    assert "nobs2" not in ens._object_temp
-    assert "nobs2" not in ens._object.columns
+    assert "nobs2" not in ens.object.columns

    # add a source column that we manually set as dirty, don't have a function
    # that adds temporary source columns at the moment
@@ -1036,14 +1036,14 @@ def test_temporary_cols(parquet_ensemble):

    # prune object, object should be dirty
    ens.prune(threshold=10)
-    assert ens._object.is_dirty()
+    assert ens.object.is_dirty()

    # try a sync
    ens._sync_tables()

    # f2 should be removed from source
    assert "f2" not in ens._source_temp
-    assert "f2" not in ens._source.columns
+    assert "f2" not in ens.source.columns


@pytest.mark.parametrize(
@@ -1089,11 +1089,11 @@ def test_dropna(data_fixture, request, legacy):
        parquet_ensemble.dropna(table="source")
    else:
        parquet_ensemble.source.dropna().update_ensemble()
-    assert len(parquet_ensemble._source.compute().index) == source_length - occurrences_source
+    assert len(parquet_ensemble.source.compute().index) == source_length - occurrences_source

    if data_fixture == "parquet_ensemble_with_divisions":
        # divisions should be preserved
-        assert parquet_ensemble._source.known_divisions
+        assert parquet_ensemble.source.known_divisions

    # Now test dropping na from 'object' table

    # Sync the tables
@@ -1116,7 +1116,7 @@ def test_dropna(data_fixture, request, legacy):
    # Set the nobs_g values for one object to NaN so we can drop it.
    # We do this on the instantiated object (pdf) and convert it back into a
    # ObjectFrame.
-    object_pdf.loc[valid_object_id, parquet_ensemble._object.columns[0]] = pd.NA
+    object_pdf.loc[valid_object_id, parquet_ensemble.object.columns[0]] = pd.NA
    parquet_ensemble.update_frame(ObjectFrame.from_tapeframe(TapeObjectFrame(object_pdf), label="object", npartitions=1))

    # Try dropping NaNs from object and confirm that we dropped a row
@@ -1128,7 +1128,7 @@ def test_dropna(data_fixture, request, legacy):

    if data_fixture == "parquet_ensemble_with_divisions":
        # divisions should be preserved
-        assert parquet_ensemble._object.known_divisions
+        assert parquet_ensemble.object.known_divisions

    new_objects_pdf = parquet_ensemble.object.compute()
    assert len(new_objects_pdf.index) == len(object_pdf.index) - 1
@@ -1147,16 +1147,16 @@ def test_keep_zeros(parquet_ensemble, legacy):
    Ensemble.dropna when `legacy` is `True`, and EnsembleFrame.dropna when `legacy` is `False`."""
    parquet_ensemble.keep_empty_objects = True

-    prev_npartitions = parquet_ensemble._object.npartitions
-    old_objects_pdf = parquet_ensemble._object.compute()
-    pdf = parquet_ensemble._source.compute()
+    prev_npartitions = parquet_ensemble.object.npartitions
+    old_objects_pdf = parquet_ensemble.object.compute()
+    pdf = parquet_ensemble.source.compute()

    # Set the psFlux values for one object to NaN so we can drop it.
    # We do this on the instantiated object (pdf) and convert it back into a
    # Dask DataFrame.
    valid_id = pdf.index.values[1]
    pdf.loc[valid_id, parquet_ensemble._flux_col] = pd.NA
-    parquet_ensemble._source = dd.from_pandas(pdf, npartitions=1)
+    parquet_ensemble.source = dd.from_pandas(pdf, npartitions=1)
    parquet_ensemble.update_frame(SourceFrame.from_tapeframe(TapeSourceFrame(pdf), npartitions=1, label="source"))

    # Sync the table and check that the number of objects decreased.
@@ -1167,9 +1167,9 @@ def test_keep_zeros(parquet_ensemble, legacy):
        parquet_ensemble._sync_tables()

    # Check that objects are preserved after sync
-    new_objects_pdf = parquet_ensemble._object.compute()
+    new_objects_pdf = parquet_ensemble.object.compute()
    assert len(new_objects_pdf.index) == len(old_objects_pdf.index)
-    assert parquet_ensemble._object.npartitions == prev_npartitions
+    assert parquet_ensemble.object.npartitions == prev_npartitions


@pytest.mark.parametrize(
@@ -1186,29 +1186,29 @@ def test_calc_nobs(data_fixture, request, by_band, multi_partition):
    ens = request.getfixturevalue(data_fixture)

    if multi_partition:
-        ens._source = ens._source.repartition(3)
+        ens.source = ens.source.repartition(3)

    # Drop the existing nobs columns
-    ens._object = ens._object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1)
+    ens.object = ens.object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1)

    # Calculate nobs
    ens.calc_nobs(by_band)

    # Check that things turned out as we expect
-    lc = ens._object.loc[88472935274829959].compute()
+    lc = ens.object.loc[88472935274829959].compute()

    if by_band:
-        assert np.all([col in ens._object.columns for col in ["nobs_g", "nobs_r"]])
+        assert np.all([col in ens.object.columns for col in ["nobs_g", "nobs_r"]])
        assert lc["nobs_g"].values[0] == 98
        assert lc["nobs_r"].values[0] == 401

-    assert "nobs_total" in ens._object.columns
+    assert "nobs_total" in ens.object.columns
    assert lc["nobs_total"].values[0] == 499

    # Make sure that if divisions were set previously, they are preserved
    if data_fixture == "parquet_ensemble_with_divisions":
-        assert ens._object.known_divisions
-        assert ens._source.known_divisions
+        assert ens.object.known_divisions
+        assert ens.source.known_divisions


@pytest.mark.parametrize(
@@ -1231,19 +1231,19 @@ def test_prune(data_fixture, request, generate_nobs):

    # Generate the nobs cols from within prune
    if generate_nobs:
        # Drop the existing nobs columns
-        parquet_ensemble._object = parquet_ensemble._object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1)
+        parquet_ensemble.object = parquet_ensemble.object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1)
        parquet_ensemble.prune(threshold)

    # Use an existing column
    else:
        parquet_ensemble.prune(threshold, col_name="nobs_total")

-    assert not np.any(parquet_ensemble._object["nobs_total"].values < threshold)
+    assert not np.any(parquet_ensemble.object["nobs_total"].values < threshold)

    # Make sure that if divisions were set previously, they are preserved
    if data_fixture == "parquet_ensemble_with_divisions":
-        assert parquet_ensemble._source.known_divisions
-        assert parquet_ensemble._object.known_divisions
+        assert parquet_ensemble.source.known_divisions
+        assert parquet_ensemble.object.known_divisions


def test_query(dask_client):
@@ -1285,7 +1285,7 @@ def test_filter_from_series(dask_client):
    ens.from_source_dict(rows, column_mapper=cmap, npartitions=2)

    # Filter the data set to low flux sources only.
-    keep_series = ens._source[ens._time_col] >= 250.0
+    keep_series = ens.source[ens._time_col] >= 250.0
    ens.filter_from_series(keep_series, table="source")

    # Check that all of the filtered rows are value.
@@ -1310,22 +1310,22 @@ def test_select(dask_client):
    }
    cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band")
    ens.from_source_dict(rows, column_mapper=cmap, npartitions=2)
-    assert len(ens._source.columns) == 5
-    assert "time" in ens._source.columns
-    assert "flux" in ens._source.columns
-    assert "band" in ens._source.columns
-    assert "count" in ens._source.columns
-    assert "something_else" in ens._source.columns
+    assert len(ens.source.columns) == 5
+    assert "time" in ens.source.columns
+    assert "flux" in ens.source.columns
+    assert "band" in ens.source.columns
+    assert "count" in ens.source.columns
+    assert "something_else" in ens.source.columns

    # Select on just time and flux
    ens.select(["time", "flux"], table="source")
-    assert len(ens._source.columns) == 2
-    assert "time" in ens._source.columns
-    assert "flux" in ens._source.columns
-    assert "band" not in ens._source.columns
-    assert "count" not in ens._source.columns
-    assert "something_else" not in ens._source.columns
+    assert len(ens.source.columns) == 2
+    assert "time" in ens.source.columns
+    assert "flux" in ens.source.columns
+    assert "band" not in ens.source.columns
+    assert "count" not in ens.source.columns
+    assert "something_else" not in ens.source.columns


@pytest.mark.parametrize("legacy", [True, False])
def test_assign(dask_client, legacy):
@@ -1345,7 +1345,7 @@ def test_assign(dask_client, legacy):
    cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band")
    ens.from_source_dict(rows, column_mapper=cmap, npartitions=1)
    assert len(ens.source.columns) == 4
-    assert "lower_bnd" not in ens._source.columns
+    assert "lower_bnd" not in ens.source.columns

    # Insert a new column for the "lower bound" computation.
    if legacy:
@@ -1400,7 +1400,7 @@ def test_coalesce(dask_client, drop_inputs):
    ens.coalesce(["flux1", "flux2", "flux3"], "flux", table="source", drop_inputs=drop_inputs)

    # Coalesce should return this exact flux array
-    assert list(ens._source["flux"].values.compute()) == [5.0, 3.0, 4.0, 10.0, 7.0]
+    assert list(ens.source["flux"].values.compute()) == [5.0, 3.0, 4.0, 10.0, 7.0]

    if drop_inputs:
        # The column mapping should be updated
@@ -1408,7 +1408,7 @@ def test_coalesce(dask_client, drop_inputs):

        # The columns to drop should be dropped
        for col in ["flux1", "flux2", "flux3"]:
-            assert col not in ens._source.columns
+            assert col not in ens.source.columns

        # Test for the drop warning
        with pytest.warns(UserWarning):
@@ -1417,7 +1417,7 @@ def test_coalesce(dask_client, drop_inputs):
    else:
        # The input columns should still be present
        for col in ["flux1", "flux2", "flux3"]:
-            assert col in ens._source.columns
+            assert col in ens.source.columns


@pytest.mark.parametrize("zero_point", [("zp_mag", "zp_flux"), (25.0, 10**10)])
@@ -1448,19 +1448,19 @@ def test_convert_flux_to_mag(dask_client, zero_point, zp_form, out_col_name):

    if zp_form == "flux":
        ens.convert_flux_to_mag(zero_point[1], zp_form, out_col_name)

-        res_mag = ens._source.compute()[output_column].to_list()[0]
+        res_mag = ens.source.compute()[output_column].to_list()[0]
        assert pytest.approx(res_mag, 0.001) == 21.28925

-        res_err = ens._source.compute()[output_column + "_err"].to_list()[0]
+        res_err = ens.source.compute()[output_column + "_err"].to_list()[0]
        assert pytest.approx(res_err, 0.001) == 0.355979

    elif zp_form == "mag" or zp_form == "magnitude":
        ens.convert_flux_to_mag(zero_point[0], zp_form, out_col_name)

-        res_mag = ens._source.compute()[output_column].to_list()[0]
+        res_mag = ens.source.compute()[output_column].to_list()[0]
        assert pytest.approx(res_mag, 0.001) == 21.28925

-        res_err = ens._source.compute()[output_column + "_err"].to_list()[0]
+        res_err = ens.source.compute()[output_column + "_err"].to_list()[0]
        assert pytest.approx(res_err, 0.001) == 0.355979

    else:
@@ -1626,7 +1626,7 @@ def test_batch(data_fixture, request, use_map, on):
        assert result is tracked_result

    # Make sure that divisions information is propagated if known
-    if parquet_ensemble._source.known_divisions and parquet_ensemble._object.known_divisions:
+    if parquet_ensemble.source.known_divisions and parquet_ensemble.object.known_divisions:
        assert result.known_divisions

    result = result.compute()
@@ -1790,7 +1790,7 @@ def test_sf2(data_fixture, request, method, combine, sthresh, use_map=False):
    res_sf2 = parquet_ensemble.sf2(argument_container=arg_container, use_map=use_map)
    res_batch = parquet_ensemble.batch(calc_sf2, use_map=use_map, argument_container=arg_container)

-    if parquet_ensemble._source.known_divisions and parquet_ensemble._object.known_divisions:
+    if parquet_ensemble.source.known_divisions and parquet_ensemble.object.known_divisions:
        if not combine:
            assert res_sf2.known_divisions
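Taken together, the patch routes every Source/Object table access through the public `Ensemble.source` and `Ensemble.object` attributes (backed by `SourceFrame`/`ObjectFrame`) rather than the private `_source`/`_object` fields, with `update_frame()` / `update_ensemble()` handing updated frames back to the Ensemble and the dirty flags driving the lazy table sync. Below is a minimal sketch of the resulting workflow, built only from calls that appear in this patch (`from_source_dict`, `ColumnMapper`, `source.query`, `update_ensemble`, `compute`); the toy rows, column names, and the `client=False` setup are illustrative assumptions, not part of the change:

    from tape import Ensemble, ColumnMapper  # assumed import path; ColumnMapper may live under tape.utils in some versions

    # Toy data; the column names only need to match the mapper below.
    rows = {
        "id": [1, 1, 1, 2, 2, 2],
        "time": [10.0, 11.0, 12.0, 10.5, 11.5, 12.5],
        "flux": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
        "err": [0.1, 0.1, 0.1, 0.2, 0.2, 0.2],
        "band": ["g", "r", "g", "r", "g", "r"],
    }
    cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band")

    ens = Ensemble(client=False)  # skip creating a Dask client for a toy example
    ens.from_source_dict(rows, column_mapper=cmap, npartitions=1)

    # Source and Object tables are now public attributes of the ensemble.
    print(ens.source.columns)
    print(ens.object.columns)

    # Filtering a frame returns a new, untracked frame; update_ensemble() hands it
    # back to the Ensemble and marks it dirty, so the next compute() re-syncs tables.
    filtered = ens.source.query("flux > 1.5")
    filtered.update_ensemble()
    obj, src = ens.compute()

Code that previously reached into `ens._source` / `ens._object`, as the removed assertions in the test suite did, migrates to the public attributes in the same way.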