diff --git a/docs/tutorials/binning_slowly_changing_sources.ipynb b/docs/tutorials/binning_slowly_changing_sources.ipynb index 853e62b8..767b34c8 100644 --- a/docs/tutorials/binning_slowly_changing_sources.ipynb +++ b/docs/tutorials/binning_slowly_changing_sources.ipynb @@ -60,7 +60,7 @@ "outputs": [], "source": [ "fig, ax = plt.subplots(1, 1)\n", - "ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 500)\n", + "ax.hist(ens.source[\"midPointTai\"].compute().tolist(), 500)\n", "ax.set_xlabel(\"Time (MJD)\")\n", "ax.set_ylabel(\"Source Count\")" ] @@ -90,7 +90,7 @@ "source": [ "ens.bin_sources(time_window=7.0, offset=0.0)\n", "fig, ax = plt.subplots(1, 1)\n", - "ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 500)\n", + "ax.hist(ens.source[\"midPointTai\"].compute().tolist(), 500)\n", "ax.set_xlabel(\"Time (MJD)\")\n", "ax.set_ylabel(\"Source Count\")" ] @@ -120,7 +120,7 @@ "source": [ "ens.bin_sources(time_window=28.0, offset=0.0, custom_aggr={\"midPointTai\": \"min\"})\n", "fig, ax = plt.subplots(1, 1)\n", - "ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 500)\n", + "ax.hist(ens.source[\"midPointTai\"].compute().tolist(), 500)\n", "ax.set_xlabel(\"Time (MJD)\")\n", "ax.set_ylabel(\"Source Count\")" ] @@ -150,7 +150,7 @@ "ens.from_source_dict(rows, column_mapper=cmap)\n", "\n", "fig, ax = plt.subplots(1, 1)\n", - "ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 60)\n", + "ax.hist(ens.source[\"midPointTai\"].compute().tolist(), 60)\n", "ax.set_xlabel(\"Time (MJD)\")\n", "ax.set_ylabel(\"Source Count\")" ] @@ -179,7 +179,7 @@ "ens.bin_sources(time_window=1.0, offset=0.0)\n", "\n", "fig, ax = plt.subplots(1, 1)\n", - "ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 60)\n", + "ax.hist(ens.source[\"midPointTai\"].compute().tolist(), 60)\n", "ax.set_xlabel(\"Time (MJD)\")\n", "ax.set_ylabel(\"Source Count\")" ] @@ -209,7 +209,7 @@ "ens.bin_sources(time_window=1.0, offset=0.5)\n", "\n", "fig, ax = plt.subplots(1, 1)\n", - "ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 60)\n", + "ax.hist(ens.source[\"midPointTai\"].compute().tolist(), 60)\n", "ax.set_xlabel(\"Time (MJD)\")\n", "ax.set_ylabel(\"Source Count\")" ] @@ -259,7 +259,7 @@ "ens.bin_sources(time_window=1.0, offset=0.5)\n", "\n", "fig, ax = plt.subplots(1, 1)\n", - "ax.hist(ens._source[\"midPointTai\"].compute().tolist(), 60)\n", + "ax.hist(ens.source[\"midPointTai\"].compute().tolist(), 60)\n", "ax.set_xlabel(\"Time (MJD)\")\n", "ax.set_ylabel(\"Source Count\")" ] @@ -290,7 +290,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.13" }, "vscode": { "interpreter": { diff --git a/docs/tutorials/scaling_to_large_data.ipynb b/docs/tutorials/scaling_to_large_data.ipynb index b1238409..9e38f6d2 100644 --- a/docs/tutorials/scaling_to_large_data.ipynb +++ b/docs/tutorials/scaling_to_large_data.ipynb @@ -216,7 +216,7 @@ "\n", "print(\"number of lightcurve results in mapres: \", len(mapres))\n", "print(\"number of lightcurve results in groupres: \", len(groupres))\n", - "print(\"True number of lightcurves in the dataset:\", len(np.unique(ens._source.index)))" + "print(\"True number of lightcurves in the dataset:\", len(np.unique(ens.source.index)))" ] }, { @@ -263,7 +263,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.13" }, "vscode": { "interpreter": { diff --git a/docs/tutorials/structure_function_showcase.ipynb 
b/docs/tutorials/structure_function_showcase.ipynb index 592436fe..f2168f23 100644 --- a/docs/tutorials/structure_function_showcase.ipynb +++ b/docs/tutorials/structure_function_showcase.ipynb @@ -267,7 +267,7 @@ "metadata": {}, "outputs": [], "source": [ - "ens.head(\"object\", 5) \n" + "ens.object.head(5) \n" ] }, { @@ -276,7 +276,7 @@ "metadata": {}, "outputs": [], "source": [ - "ens.head(\"source\", 5) " + "ens.source.head(5) " ] }, { diff --git a/docs/tutorials/tape_datasets.ipynb b/docs/tutorials/tape_datasets.ipynb index 1cd3670f..ddcec2de 100644 --- a/docs/tutorials/tape_datasets.ipynb +++ b/docs/tutorials/tape_datasets.ipynb @@ -52,7 +52,7 @@ " column_mapper=col_map\n", " )\n", "\n", - "ens.head(\"source\") # View the first 5 entries of the source table" + "ens.source.head(5) # View the first 5 entries of the source table" ] }, { @@ -93,7 +93,7 @@ " column_mapper=col_map\n", " )\n", "\n", - "ens.head(\"object\") # View the first 5 entries of the object table" + "ens.object.head(5) # View the first 5 entries of the object table" ] }, { @@ -168,7 +168,7 @@ "source": [ "ens.from_dataset(\"s82_rrlyrae\") # Let's grab the Stripe 82 RR Lyrae\n", "\n", - "ens.head(\"object\", 5)" + "ens.object.head(5)" ] }, { @@ -270,7 +270,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.13" }, "vscode": { "interpreter": { diff --git a/docs/tutorials/using_ray_with_the_ensemble.ipynb b/docs/tutorials/using_ray_with_the_ensemble.ipynb index f0ba09a0..b19ca28f 100644 --- a/docs/tutorials/using_ray_with_the_ensemble.ipynb +++ b/docs/tutorials/using_ray_with_the_ensemble.ipynb @@ -81,7 +81,7 @@ "outputs": [], "source": [ "ens.from_dataset(\"s82_qso\")\n", - "ens._source = ens._source.repartition(npartitions=10)\n", + "ens.source = ens.source.repartition(npartitions=10)\n", "ens.batch(calc_sf2, use_map=False) # use_map is false as we repartition naively, splitting per-object sources across partitions" ] }, @@ -116,7 +116,7 @@ "\n", "ens=Ensemble(client=False) # Do not use a client\n", "ens.from_dataset(\"s82_qso\")\n", - "ens._source = ens._source.repartition(npartitions=10)\n", + "ens.source = ens.source.repartition(npartitions=10)\n", "ens.batch(calc_sf2, use_map=False)" ] }, @@ -150,7 +150,7 @@ "\n", "ens = Ensemble()\n", "ens.from_dataset(\"s82_qso\")\n", - "ens._source = ens._source.repartition(npartitions=10)\n", + "ens.source = ens.source.repartition(npartitions=10)\n", "ens.batch(calc_sf2, use_map=False)" ] } @@ -171,7 +171,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.13" }, "vscode": { "interpreter": { diff --git a/docs/tutorials/working_with_the_ensemble.ipynb b/docs/tutorials/working_with_the_ensemble.ipynb index 10110329..2d2eb993 100644 --- a/docs/tutorials/working_with_the_ensemble.ipynb +++ b/docs/tutorials/working_with_the_ensemble.ipynb @@ -20,32 +20,41 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2023-08-30T14:58:34.203827Z", - "start_time": "2023-08-30T14:58:34.187300Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", - "np.random.seed(1)\n", + "import pandas as pd\n", + "\n", + "np.random.seed(1) \n", "\n", - "# initialize a dictionary of empty arrays\n", - "source_dict = {\"id\": np.array([]),\n", - " \"time\": np.array([]),\n", - " \"flux\": np.array([]),\n", - " \"error\": np.array([]),\n", - " \"band\": np.array([])}\n", + "# Generate 10 
astronomical objects\n", + "n_obj = 10\n", + "ids = 8000 + np.arange(n_obj)\n", + "names = ids.astype(str)\n", + "object_table = pd.DataFrame(\n", + " {\n", + " \"id\": ids, \n", + " \"name\": names,\n", + " \"ddf_bool\": np.random.randint(0, 2, n_obj), # 1 if from deep drilling field, 0 otherwise\n", + " \"libid_cadence\": np.random.randint(1, 130, n_obj),\n", + " }\n", + ")\n", "\n", - "# Create 10 lightcurves with 100 measurements each\n", "lc_len = 100\n", - "for i in range(10):\n", - " source_dict[\"id\"] = np.append(source_dict[\"id\"], np.array([i]*lc_len)).astype(int)\n", - " source_dict[\"time\"] = np.append(source_dict[\"time\"], np.linspace(1, lc_len, lc_len))\n", - " source_dict[\"flux\"] = np.append(source_dict[\"flux\"], 100 + 50 * np.random.rand(lc_len))\n", - " source_dict[\"error\"] = np.append(source_dict[\"error\"], 10 + 5 * np.random.rand(lc_len))\n", - " source_dict[\"band\"] = np.append(source_dict[\"band\"], [\"g\"]*50+[\"r\"]*50)" + "# Create 10 lightcurves with 100 measurements each (1000 source rows in total)\n", + "lc_len = 100\n", + "num_points = 1000\n", + "all_bands = np.array([\"r\", \"g\", \"b\", \"i\"])\n", + "source_table = pd.DataFrame(\n", + " {\n", + " \"id\": 8000 + (np.arange(num_points) % n_obj),\n", + " \"time\": np.arange(num_points),\n", + " \"flux\": np.random.random_sample(size=num_points)*10,\n", + " \"band\": np.repeat(all_bands, num_points / len(all_bands)),\n", + " \"error\": np.random.random_sample(size=num_points),\n", + " \"count\": np.arange(num_points),\n", + " },\n", + ")" ] }, { @@ -53,7 +62,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can load these into the `Ensemble` using `Ensemble.from_source_dict()`:" + "We can load these into the `Ensemble` using `Ensemble.from_pandas()`:" ] }, { @@ -72,12 +81,15 @@ "ens = Ensemble() # initialize an ensemble object\n", "\n", "# Read in the generated lightcurve data\n", - "ens.from_source_dict(source_dict, \n", - " id_col=\"id\",\n", - " time_col=\"time\",\n", - " flux_col=\"flux\",\n", - " err_col=\"error\",\n", - " band_col=\"band\")" + "ens.from_pandas(\n", + " source_frame=source_table,\n", + " object_frame=object_table,\n", + " id_col=\"id\",\n", + " time_col=\"time\",\n", + " flux_col=\"flux\",\n", + " err_col=\"error\",\n", + " band_col=\"band\",\n", + " npartitions=1)" ] }, { @@ -85,7 +97,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We now have an `Ensemble` object, and have provided it with the constructed data in the source dictionary. Within the call to `Ensemble.from_source_dict`, we specified which columns of the input file mapped to timeseries quantities that the `Ensemble` needs to understand. It's important to link these arguments properly, as the `Ensemble` will use these columns when operations are requested on understood quantities. For example, if an TAPE analysis function requires the time column, from this linking the `Ensemble` will automatically supply that function with the 'time' column." + "We now have an `Ensemble` object, and have provided it with the constructed source and object tables. Within the call to `Ensemble.from_pandas`, we specified which columns of the input data map to timeseries quantities that the `Ensemble` needs to understand. It's important to link these arguments properly, as the `Ensemble` will use these columns when operations are requested on understood quantities. For example, if a TAPE analysis function requires the time column, from this linking the `Ensemble` will automatically supply that function with the 'time' column."
] }, { @@ -95,7 +107,7 @@ "source": [ "## Column Mapping with the ColumnMapper\n", "\n", - "In the above example, we manually provide the column labels within the call to `Ensemble.from_source_dict`. Alternatively, the `tape.utils.ColumnMapper` class offers a means to assign the column mappings. Either manually as shown before, or even populated from a known mapping scheme." + "In the above example, we manually provide the column labels within the call to `Ensemble.from_pandas`. Alternatively, the `tape.utils.ColumnMapper` class offers a means to assign the column mappings. Either manually as shown before, or even populated from a known mapping scheme." ] }, { @@ -118,8 +130,12 @@ " err_col=\"error\",\n", " band_col=\"band\")\n", "\n", - "# Pass the ColumnMapper along to from_source_dict\n", - "ens.from_source_dict(source_dict, column_mapper=col_map)" + "# Pass the ColumnMapper along to from_pandas\n", + "ens.from_pandas(\n", + " source_frame=source_table,\n", + " object_frame=object_table,\n", + " column_mapper=col_map,\n", + " npartitions=1)" ] }, { @@ -128,7 +144,9 @@ "metadata": {}, "source": [ "## The Object and Source Frames\n", - "The `Ensemble` maintains two dataframes under the hood, the \"object dataframe\" and the \"source dataframe\". This borrows from the Rubin Observatories object-source convention, where object denotes a given astronomical object and source is the collection of measurements of that object. Essentially, the Object frame stores one-off information about objects, and the source frame stores the available time-domain data. As a result, `Ensemble` functions that operate on the underlying dataframes need to be pointed at either object or source. In most cases, the default is the object table as it's a more helpful interface for understanding the contents of the `Ensemble`, especially when dealing with large volumes of data." + "The `Ensemble` maintains two dataframes under the hood, the \"object dataframe\" and the \"source dataframe\". This borrows from the Rubin Observatories object-source convention, where object denotes a given astronomical object and source is the collection of measurements of that object. Essentially, the Object frame stores one-off information about objects, and the source frame stores the available time-domain data. As a result, `Ensemble` functions that operate on the underlying dataframes need to be pointed at either object or source. In most cases, the default is the object table as it's a more helpful interface for understanding the contents of the `Ensemble`, especially when dealing with large volumes of data.\n", + "\n", + "We can also access Ensemble frames individually with `Ensemble.source` and `Ensemble.object`" ] }, { @@ -151,14 +169,14 @@ }, "outputs": [], "source": [ - "ens._source # We have not actually loaded any data into memory" + "ens.source # We have not actually loaded any data into memory" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Here we are accessing the Dask dataframe underneath, and despite running a command to read in our data, we only see an empty dataframe with some high-level information available. To explicitly bring the data into memory, we must run a `compute()` command." + "Here we are accessing the Dask dataframe and despite running a command to read in our source data, we only see an empty dataframe with some high-level information available. To explicitly bring the data into memory, we must run a `compute()` command on the data's frame." 
] }, { @@ -172,7 +190,7 @@ }, "outputs": [], "source": [ - "ens.compute(\"source\") # Compute lets dask know we're ready to bring the data into memory" + "ens.source.compute() # Compute lets dask know we're ready to bring the data into memory" ] }, { @@ -180,9 +198,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "With this compute, we see above that we now have a populated dataframe (a Pandas dataframe in fact!). From this, many workflows in Dask and by extension TAPE, will look like a series of lazily evaluated commands that are chained together and then executed with a .compute() call at the end of the workflow.\n", + "With this compute, we see above that we have returned a populated dataframe (a Pandas dataframe in fact!). From this, many workflows in Dask, and by extension TAPE, will look like a series of lazily evaluated commands that are chained together and then executed with a .compute() call at the end of the workflow.\n", + "\n", + "Alternatively we can use `ens.persist()` to execute the chained commands without loading the result into memory. This can speed up future `compute()` calls.\n", "\n", - "Alternatively we can use `ens.persist()` to execute the chained commands without loading the result into memory. This can speed up future `compute()` calls." + "Note that `Ensemble.source` and `Ensemble.object` are instances of the `tape.SourceFrame` and `tape.ObjectFrame` classes respectively. These are subclasses of Dask dataframes that provide some additional utility for tracking by the ensemble while supporting most of the Dask dataframe API. " ] }, { @@ -223,7 +243,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`Ensemble.info` shows that we have 2000 rows with 54.7 KBs of used memory, and shows the columns we've brought in with their respective data types. If you'd like to actually bring a few rows into memory to inspect, `Ensemble.head` and `Ensemble.tail` provide access to the first n and last n rows respectively." + "`Ensemble.info` shows the number of rows in each table and the memory they use, and it also shows the columns we've brought in with their respective data types. If you'd like to actually bring a few rows into memory to inspect, `EnsembleFrame.head` and `EnsembleFrame.tail` provide access to the first n and last n rows respectively." ] }, { @@ -237,7 +257,7 @@ }, "outputs": [], "source": [ - "ens.head(\"object\", 5) # Grabs the first 5 rows of the object table" + "ens.object.head(5) # Grabs the first 5 rows of the object table" ] }, { @@ -251,7 +271,7 @@ }, "outputs": [], "source": [ - "ens.tail(\"source\", 5) # Grabs the last 5 rows of the source table" + "ens.source.tail(5) # Grabs the last 5 rows of the source table" ] }, { @@ -272,7 +292,7 @@ }, "outputs": [], "source": [ - "ens.compute(\"source\")" + "ens.source.compute()" ] }, { @@ -281,9 +301,9 @@ "source": [ "### Filtering\n", "\n", - "The `Ensemble` provides a general filtering function `query` that mirrors a Pandas or Dask `query` command. Specifically, the function takes a string that provides an expression indicating which rows to **keep**. As with other `Ensemble` functions, an optional `table` parameter allows you to filter on either the object or the source table.\n", + "The `Ensemble` provides a general filtering function [`query`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.query.html) that mirrors a Pandas or Dask `query` command. Specifically, the function takes a string that provides an expression indicating which rows to **keep**. 
As with other `Ensemble` functions, an optional `table` parameter allows you to filter on either the object or the source table.\n", "\n", - "For example, the following code filters the sources to only include rows with a flux value above 18.2. It uses `ens._flux_col` to retrieve the name of the column with that information." + "For example, the following code filters the sources to only include rows with flux values below the 95th percentile. It uses `ens._flux_col` to retrieve the name of the column with that information." ] }, { @@ -297,8 +317,8 @@ }, "outputs": [], "source": [ - "ens.query(f\"{ens._flux_col} > 130.0\", table=\"source\")\n", - "ens.compute(\"source\")" + "highest_flux = ens.source[ens._flux_col].quantile(0.95).compute()\n", + "ens.source.query(f\"{ens._flux_col} < {highest_flux}\").compute()" ] }, { @@ -319,7 +339,8 @@ }, "outputs": [], "source": [ - "keep_rows = ens._source[\"error\"] < 12.0\n", + "# Find all of the source points with the lowest 90% of errors.\n", + "keep_rows = ens.source[\"error\"] < ens.source[\"error\"].quantile(0.9)\n", "keep_rows.compute()" ] }, { @@ -327,7 +348,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We then pass that series to a `filter_from_series` function:" + "We also provide filtering at the `Ensemble` level, so you can pass the above series to the `Ensemble.filter_from_series` function:" ] }, { @@ -342,7 +363,7 @@ "outputs": [], "source": [ "ens.filter_from_series(keep_rows, table=\"source\")\n", - "ens.compute(\"source\")" + "ens.source.compute()" ] }, { @@ -364,8 +385,8 @@ "outputs": [], "source": [ "# Cleaning nans\n", - "ens.dropna(table=\"source\") # clean nans from source table\n", - "ens.dropna(table=\"object\") # clean nans from object table\n", + "ens.source.dropna() # clean nans from source table\n", + "ens.object.dropna() # clean nans from object table\n", "\n", "# Filtering on number of observations\n", "ens.prune(threshold=10) # threshold is the minimum number of observations needed to retain the object\n", @@ -402,8 +423,7 @@ "outputs": [], "source": [ "# Add a new column so we can filter it out later.\n", - "ens._source = ens._source.assign(band2=ens._source[\"band\"] + \"2\")\n", - "ens.compute(\"source\")" + "ens.source.assign(band2=ens.source[\"band\"] + \"2\").compute()" ] }, { @@ -418,7 +438,68 @@ "outputs": [], "source": [ "ens.select([\"time\", \"flux\", \"error\", \"band\"], table=\"source\")\n", - "ens.compute(\"source\")" + "print(\"The Source table is dirty: \" + str(ens.source.is_dirty()))\n", + "ens.source.compute()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Updating an Ensemble's Frames" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `Ensemble` is a manager of `EnsembleFrame` objects (of which `Ensemble.source` and `Ensemble.object` are special cases). 
When performing operations on one of the tables, the results are not automatically sent to the `Ensemble`.\n", + "\n", + "So while in the above examples we demonstrate several methods where we generated filtered views of the source table, note that the underlying data remained unchanged, with no changes to the rows or columns of `Ensemble.source`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "queried_src = ens.source.query(f\"{ens._flux_col} < {highest_flux}\")\n", + "\n", + "print(len(queried_src))\n", + "print(len(ens.source))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When modifying the views of a dataframe tracked by the `Ensemble`, we can update the `Source` or `Object` frame to use the updated view by calling\n", + "\n", + "`Ensemble.update_frame(view_frame)`\n", + "\n", + "Or alternatively:\n", + "\n", + "`view_frame.update_ensemble()`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Now apply the view's filter to the source frame.\n", + "queried_src.update_ensemble()\n", + "\n", + "ens.source.compute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that the above is still a series of lazy operations that will not be fully evaluated until an operation such as `compute`. So a call to `update_ensemble` will not yet alter or move any underlying data." ] }, { @@ -443,8 +524,8 @@ }, "outputs": [], "source": [ - "ens.assign(table=\"source\", lower_bnd=lambda x: x[\"flux\"] - 2.0 * x[\"error\"])\n", - "ens.compute(table=\"source\")" + "lower_bnd = ens.source.assign(lower_bnd=lambda x: x[\"flux\"] - 2.0 * x[\"error\"])\n", + "lower_bnd" ] }, { @@ -475,6 +556,175 @@ "res" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Storing and Accessing Result Frames" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that for the above `batch` operation, we also printed:\n", + "\n", + "`Using generated label, result_1, for a batch result.`\n", + "\n", + "In addition to the source and object frames, the `Ensemble` may track other frames as well, accessed by either generated or user-provided labels.\n", + "\n", + "We can access a saved frame with `Ensemble.select_frame(label)`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ens.select_frame(\"result_1\").compute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`Ensemble.batch` has an optional `label` argument that will store the result with a user-provided label." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "res = ens.batch(calc_stetson_J, compute=True, label=\"stetson_j\")\n", + "\n", + "ens.select_frame(\"stetson_j\").compute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Likewise we can rename a frame with a new label, and drop the original frame."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ens.add_frame(ens.select_frame(\"stetson_j\"), \"stetson_j_result_1\") # Add result under new label\n", + "ens.drop_frame(\"stetson_j\") # Drop original label\n", + "\n", + "ens.select_frame(\"stetson_j_result_1\").compute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also add our own frames with `Ensemble.add_frame(frame, label)`. For instance, we can copy this result and add it to a new frame for the `Ensemble` to track as well." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ens.add_frame(res.copy(), \"new_res\")\n", + "ens.select_frame(\"new_res\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally we can also drop frames we are no longer interested in having the `Ensemble` track." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ens.drop_frame(\"result_1\")\n", + "\n", + "try:\n", + " ens.select_frame(\"result_1\") # This should result in a KeyError since the frame has been dropped.\n", + "except Exception as e:\n", + " print(\"As expected, the frame 'result_1' was dropped.\\n\" + str(e))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Keeping the Object and Source Tables in Sync\n", + "\n", + "The TAPE `Ensemble` attempts to lazily \"sync\" the Object and Source tables such that:\n", + "\n", + "* If a series of operations removes all lightcurves for a particular object from the Source table, we will lazily remove that object from the Object table.\n", + "* If a series of operations removes an object from the Object table, we will lazily remove all light curves for that object from the Source table.\n", + "\n", + "As an example, let's filter the Object table only for objects observed from deep drilling fields. This operation marks the result table as `dirty`, indicating to the `Ensemble` that if used as part of a result computation, it should check if the object and source tables are synced. \n", + "\n", + "Note that because we have not called `update_ensemble()`, the `Ensemble` is still using the original Object table which is **not** marked `dirty`.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ddf_only = ens.object.query(\"ddf_bool == True\")\n", + "\n", + "print(\"Object table is dirty: \" + str(ens.object.is_dirty()))\n", + "print(\"ddf_only is dirty: \" + str(ddf_only.is_dirty()))\n", + "ddf_only.compute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now let's update the `Ensemble`'s Object table. We can see that the Object table is now considered \"dirty\" so a sync between the Source and Object tables will be triggered by computing a `batch` operation. \n", + "\n", + "As part of the sync, the Source table has been modified to drop all sources for objects not observed via Deep Drilling Fields. This is reflected both in the `batch` result output and in the reduced number of rows in the Source table."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ddf_only.update_ensemble()\n", + "print(\"Updated object table is now dirty: \" + str(ens.object.is_dirty()))\n", + "\n", + "print(\"Length of the Source table before the batch operation: \" + str(len(ens.source)))\n", + "res = ens.batch(calc_stetson_J, compute=True)\n", + "print(\"Post-computation object table is now dirty: \" + str(ens.object.is_dirty()))\n", + "print(\"Length of the Source table after the batch operation: \" + str(len(ens.source)))\n", + "res" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To summarize:\n", + "\n", + "* An operation that alters a frame marks that frame as \"dirty\"\n", + "* Such an operation on `Ensemble.source` or `Ensemble.object` won't cause a sync unless the output frame is stored back to either `Ensemble.source` or `Ensemble.object` respectively. This is usually done by a call to `EnsembleFrame.update_ensemble()`\n", + "* Syncs are done lazily such that even when the Object and/or Source frames are \"dirty\", a sync between tables won't be triggered until a relevant computation yields an observable output, such as `batch(..., compute=True)`" + ] + }, { "cell_type": "markdown", "metadata": { @@ -587,6 +837,13 @@ "source": [ "ens.client.close() # Tear down the ensemble client" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -605,7 +862,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.13" }, "vscode": { "interpreter": { diff --git a/src/tape/__init__.py b/src/tape/__init__.py index ad639c70..46eba57c 100644 --- a/src/tape/__init__.py +++ b/src/tape/__init__.py @@ -1,5 +1,6 @@ from .analysis import * # noqa from .ensemble import * # noqa +from .ensemble_frame import * # noqa from .timeseries import * # noqa from .ensemble_readers import * # noqa from ._version import __version__ # noqa diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 00233c44..5d654232 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -13,9 +13,24 @@ from .analysis.feature_extractor import BaseLightCurveFeature, FeatureExtractor from .analysis.structure_function import SF_METHODS from .analysis.structurefunction2 import calc_sf2 +from .ensemble_frame import ( + EnsembleFrame, + EnsembleSeries, + ObjectFrame, + SourceFrame, + TapeFrame, + TapeObjectFrame, + TapeSourceFrame, + TapeSeries, +) from .timeseries import TimeSeries from .utils import ColumnMapper +SOURCE_FRAME_LABEL = "source" +OBJECT_FRAME_LABEL = "object" + +DEFAULT_FRAME_LABEL = "result" # A base default label for an Ensemble's result frames. + class Ensemble: """Ensemble object is a collection of light curve ids""" @@ -34,11 +49,13 @@ def __init__(self, client=True, **kwargs): """ self.result = None # holds the latest query - self._source = None # Source Table - self._object = None # Object Table + self.frames = {} # Frames managed by this Ensemble, keyed by label - self._source_dirty = False # Source Dirty Flag - self._object_dirty = False # Object Dirty Flag + # A unique ID to allocate new result frame labels. 
+ self.default_frame_id = 1 + + self.source = None # Source Table EnsembleFrame + self.object = None # Object Table EnsembleFrame self._source_temp = [] # List of temporary columns in Source self._object_temp = [] # List of temporary columns in Object @@ -78,6 +95,160 @@ def __del__(self): self.client.close() return self + def add_frame(self, frame, label): + """Adds a new frame for the Ensemble to track. + + Parameters + ---------- + frame: `tape.ensemble.EnsembleFrame` + The frame object for the Ensemble to track. + label: `str` + | The label for the Ensemble to use to track the frame. + + Returns + ------- + self: `Ensemble` + + Raises + ------ + ValueError if the label is "source", "object", or already tracked by the Ensemble. + """ + if label == SOURCE_FRAME_LABEL or label == OBJECT_FRAME_LABEL: + raise ValueError(f"Unable to add frame with reserved label " f"'{label}'") + if label in self.frames: + raise ValueError(f"Unable to add frame: a frame with label " f"'{label}'" f"is in the Ensemble.") + # Assign the frame to the requested tracking label. + frame.label = label + # Update the ensemble to track this labeled frame. + self.update_frame(frame) + return self + + def update_frame(self, frame): + """Updates a frame tracked by the Ensemble or otherwise adds it to the Ensemble. + The frame is tracked by its `EnsembleFrame.label` field. + + Parameters + ---------- + frame: `tape.ensemble.EnsembleFrame` + The frame for the Ensemble to update. If not already tracked, it is added. + + Returns + ------- + self: `Ensemble` + + Raises + ------ + ValueError if the `frame.label` is unpopulated, or if the frame is not a SourceFrame or ObjectFrame + but uses the reserved labels. + """ + if frame.label is None: + raise ValueError(f"Unable to update frame with no populated `EnsembleFrame.label`.") + if isinstance(frame, SourceFrame) or isinstance(frame, ObjectFrame): + expected_label = SOURCE_FRAME_LABEL if isinstance(frame, SourceFrame) else OBJECT_FRAME_LABEL + if frame.label != expected_label: + raise ValueError(f"Unable to update frame with reserved label " f"'{frame.label}'") + if isinstance(frame, SourceFrame): + self.source = frame + elif isinstance(frame, ObjectFrame): + self.object = frame + + # Ensure this frame is assigned to this Ensemble. + frame.ensemble = self + self.frames[frame.label] = frame + return self + + def drop_frame(self, label): + """Drops a frame tracked by the Ensemble. + + Parameters + ---------- + label: `str` + | The label of the frame to be dropped by the Ensemble. + + Returns + ------- + self: `Ensemble` + + Raises + ------ + ValueError if the label is "source", or "object". + KeyError if the label is not tracked by the Ensemble. + """ + if label == SOURCE_FRAME_LABEL or label == OBJECT_FRAME_LABEL: + raise ValueError(f"Unable to drop frame with reserved label " f"'{label}'") + if label not in self.frames: + raise KeyError(f"Unable to drop frame: no frame with label " f"'{label}'" f"is in the Ensemble.") + del self.frames[label] + return self + + def select_frame(self, label): + """Selects and returns frame tracked by the Ensemble. + + Parameters + ---------- + label: `str` + | The label of a frame tracked by the Ensemble to be selected. + + Returns + ------- + result: `tape.ensemble.EnsembleFrame` + + Raises + ------ + KeyError if the label is not tracked by the Ensemble. + """ + if label not in self.frames: + raise KeyError( + f"Unable to select frame: no frame with label" f"'{label}'" f" is in the Ensemble." 
+ ) + return self.frames[label] + + def frame_info(self, labels=None, verbose=True, memory_usage=True, **kwargs): + """Wrapper for calling dask.dataframe.DataFrame.info() on frames tracked by the Ensemble. + + Parameters + ---------- + labels: `list`, optional + A list of labels for Ensemble frames to summarize. + If None, info is printed for all tracked frames. + verbose: `bool`, optional + Whether to print the whole summary + memory_usage: `bool`, optional + Specifies whether total memory usage of the DataFrame elements + (including the index) should be displayed. + **kwargs: + keyword arguments passed along to + `dask.dataframe.DataFrame.info()` + Returns + ------- + None + + Raises + ------ + KeyError if a label in labels is not tracked by the Ensemble. + """ + if labels is None: + labels = self.frames.keys() + for label in labels: + if label not in self.frames: + raise KeyError( + f"Unable to get frame info: no frame with label " f"'{label}'" f" is in the Ensemble." + ) + print(label, "Frame") + print(self.frames[label].info(verbose=verbose, memory_usage=memory_usage, **kwargs)) + + def _generate_frame_label(self): + """Generates a new unique label for a result frame.""" + result = DEFAULT_FRAME_LABEL + "_" + str(self.default_frame_id) + self.default_frame_id += 1 # increment to guarantee uniqueness + while result in self.frames: + # If the generated label has been taken by a user, increment again. + # In most workflows, we expect the number of frames to be O(100) so it's unlikely for + # the performance cost of this method to be high. + result = DEFAULT_FRAME_LABEL + "_" + str(self.default_frame_id) + self.default_frame_id += 1 + return result + def insert_sources( self, obj_ids, @@ -155,20 +326,20 @@ def insert_sources( df2 = df2.set_index(self._id_col, drop=True, sort=True) # Save the divisions and number of partitions. - prev_div = self._source.divisions - prev_num = self._source.npartitions + prev_div = self.source.divisions + prev_num = self.source.npartitions # Append the new rows to the correct divisions. - self._source = dd.concat([self._source, df2], axis=0, interleave_partitions=True) - self._source_dirty = True + self.update_frame(dd.concat([self.source, df2], axis=0, interleave_partitions=True)) + self.source.set_dirty(True) # Do the repartitioning if requested. If the divisions were set, reuse them. # Otherwise, use the same number of partitions. if force_repartition: if all(prev_div): - self._source = self._source.repartition(divisions=prev_div) - elif self._source.npartitions != prev_num: - self._source = self._source.repartition(npartitions=prev_num) + self.update_frame(self.source.repartition(divisions=prev_div)) + elif self.source.npartitions != prev_num: + self.source = self.source.repartition(npartitions=prev_num) return self @@ -187,7 +358,7 @@ def client_info(self): return self.client # Prints Dask dashboard to screen def info(self, verbose=True, memory_usage=True, **kwargs): - """Wrapper for dask.dataframe.DataFrame.info() + """Wrapper for dask.dataframe.DataFrame.info() for the Source and Object tables Parameters ---------- @@ -198,16 +369,15 @@ def info(self, verbose=True, memory_usage=True, **kwargs): (including the index) should be displayed. 
Returns ---------- - counts: `pandas.series` - A series of counts by object + None """ # Sync tables if user wants to retrieve their information self._lazy_sync_tables(table="all") print("Object Table") - self._object.info(verbose=verbose, memory_usage=memory_usage, **kwargs) + self.object.info(verbose=verbose, memory_usage=memory_usage, **kwargs) print("Source Table") - self._source.info(verbose=verbose, memory_usage=memory_usage, **kwargs) + self.source.info(verbose=verbose, memory_usage=memory_usage, **kwargs) def check_sorted(self, table="object"): """Checks to see if an Ensemble Dataframe is sorted (increasing) on @@ -224,9 +394,9 @@ def check_sorted(self, table="object"): or not (False) """ if table == "object": - idx = self._object.index + idx = self.object.index elif table == "source": - idx = self._source.index + idx = self.source.index else: raise ValueError(f"{table} is not one of 'object' or 'source'") @@ -250,7 +420,7 @@ def check_lightcurve_cohesion(self): across multiple partitions (False) """ - idx = self._source.index + idx = self.source.index counts = idx.map_partitions(lambda a: Counter(a.unique())).compute() unq_counter = counts[0] @@ -279,12 +449,12 @@ def compute(self, table=None, **kwargs): if table: self._lazy_sync_tables(table) if table == "object": - return self._object.compute(**kwargs) + return self.object.compute(**kwargs) elif table == "source": - return self._source.compute(**kwargs) + return self.source.compute(**kwargs) else: self._lazy_sync_tables(table="all") - return (self._object.compute(**kwargs), self._source.compute(**kwargs)) + return (self.object.compute(**kwargs), self.source.compute(**kwargs)) def persist(self, **kwargs): """Wrapper for dask.dataframe.DataFrame.persist() @@ -295,15 +465,15 @@ def persist(self, **kwargs): of the computation. 
""" self._lazy_sync_tables("all") - self._object = self._object.persist(**kwargs) - self._source = self._source.persist(**kwargs) + self.update_frame(self.object.persist(**kwargs)) + self.update_frame(self.source.persist(**kwargs)) def columns(self, table="object"): """Retrieve columns from dask dataframe""" if table == "object": - return self._object.columns + return self.object.columns elif table == "source": - return self._source.columns + return self.source.columns else: raise ValueError(f"{table} is not one of 'object' or 'source'") @@ -312,9 +482,9 @@ def head(self, table="object", n=5, **kwargs): self._lazy_sync_tables(table) if table == "object": - return self._object.head(n=n, **kwargs) + return self.object.head(n=n, **kwargs) elif table == "source": - return self._source.head(n=n, **kwargs) + return self.source.head(n=n, **kwargs) else: raise ValueError(f"{table} is not one of 'object' or 'source'") @@ -323,9 +493,9 @@ def tail(self, table="object", n=5, **kwargs): self._lazy_sync_tables(table) if table == "object": - return self._object.tail(n=n, **kwargs) + return self.object.tail(n=n, **kwargs) elif table == "source": - return self._source.tail(n=n, **kwargs) + return self.source.tail(n=n, **kwargs) else: raise ValueError(f"{table} is not one of 'object' or 'source'") @@ -348,11 +518,9 @@ def dropna(self, table="source", **kwargs): scheme """ if table == "object": - self._object = self._object.dropna(**kwargs) - self._object_dirty = True # This operation modifies the object table + self.update_frame(self.object.dropna(**kwargs)) elif table == "source": - self._source = self._source.dropna(**kwargs) - self._source_dirty = True # This operation modifies the source table + self.update_frame(self.source.dropna(**kwargs)) else: raise ValueError(f"{table} is not one of 'object' or 'source'") @@ -372,13 +540,11 @@ def select(self, columns, table="object"): """ self._lazy_sync_tables(table) if table == "object": - cols_to_drop = [col for col in self._object.columns if col not in columns] - self._object = self._object.drop(cols_to_drop, axis=1) - self._object_dirty = True + cols_to_drop = [col for col in self.object.columns if col not in columns] + self.update_frame(self.object.drop(cols_to_drop, axis=1)) elif table == "source": - cols_to_drop = [col for col in self._source.columns if col not in columns] - self._source = self._source.drop(cols_to_drop, axis=1) - self._source_dirty = True + cols_to_drop = [col for col in self.source.columns if col not in columns] + self.update_frame(self.source.drop(cols_to_drop, axis=1)) else: raise ValueError(f"{table} is not one of 'object' or 'source'") @@ -407,11 +573,9 @@ def query(self, expr, table="object"): """ self._lazy_sync_tables(table) if table == "object": - self._object = self._object.query(expr) - self._object_dirty = True + self.update_frame(self.object.query(expr)) elif table == "source": - self._source = self._source.query(expr) - self._source_dirty = True + self.update_frame(self.source.query(expr)) return self def filter_from_series(self, keep_series, table="object"): @@ -429,11 +593,10 @@ def filter_from_series(self, keep_series, table="object"): """ self._lazy_sync_tables(table) if table == "object": - self._object = self._object[keep_series] - self._object_dirty = True + self.update_frame(self.object[keep_series]) + elif table == "source": - self._source = self._source[keep_series] - self._source_dirty = True + self.update_frame(self.source[keep_series]) return self def assign(self, table="object", temporary=False, **kwargs): @@ 
-471,19 +634,17 @@ def assign(self, table="object", temporary=False, **kwargs): self._lazy_sync_tables(table) if table == "object": - pre_cols = self._object.columns - self._object = self._object.assign(**kwargs) - self._object_dirty = True - post_cols = self._object.columns + pre_cols = self.object.columns + self.update_frame(self.object.assign(**kwargs)) + post_cols = self.object.columns if temporary: self._object_temp.extend(col for col in post_cols if col not in pre_cols) elif table == "source": - pre_cols = self._source.columns - self._source = self._source.assign(**kwargs) - self._source_dirty = True - post_cols = self._source.columns + pre_cols = self.source.columns + self.update_frame(self.source.assign(**kwargs)) + post_cols = self.source.columns if temporary: self._source_temp.extend(col for col in post_cols if col not in pre_cols) @@ -518,9 +679,9 @@ def coalesce(self, input_cols, output_col, table="object", drop_inputs=False): """ # we shouldn't need to sync for this if table == "object": - table_ddf = self._object + table_ddf = self.object elif table == "source": - table_ddf = self._source + table_ddf = self.source else: raise ValueError(f"{table} is not one of 'object' or 'source'") @@ -575,9 +736,9 @@ def coalesce_partition(df, input_cols, output_col): table_ddf = table_ddf.drop(columns=input_cols) if table == "object": - self._object = table_ddf + self.update_frame(table_ddf) elif table == "source": - self._source = table_ddf + self.update_frame(table_ddf) return self @@ -608,27 +769,27 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True): if by_band: # repartition the result to align with object - if self._object.known_divisions: + if self.object.known_divisions: # Grab these up front to help out the task graph id_col = self._id_col band_col = self._band_col # Get the band metadata - unq_bands = np.unique(self._source[band_col]) + unq_bands = np.unique(self.source[band_col]) meta = {band: float for band in unq_bands} # Map the groupby to each partition - band_counts = self._source.map_partitions( + band_counts = self.source.map_partitions( lambda x: x.groupby(id_col)[[band_col]] .value_counts() .to_frame() .reset_index() .pivot_table(values=band_col, index=id_col, columns=band_col, aggfunc="sum"), meta=meta, - ).repartition(divisions=self._object.divisions) + ).repartition(divisions=self.object.divisions) else: band_counts = ( - self._source.groupby([self._id_col])[self._band_col] # group by each object + self.source.groupby([self._id_col])[self._band_col] # group by each object .value_counts() # count occurence of each band .to_frame() # convert series to dataframe .rename(columns={self._band_col: "counts"}) # rename column @@ -639,38 +800,36 @@ def calc_nobs(self, by_band=False, label="nobs", temporary=True): ) ) # the pivot_table call makes each band_count a column of the id_col row - band_counts = band_counts.repartition(npartitions=self._object.npartitions) + band_counts = band_counts.repartition(npartitions=self.object.npartitions) # short-hand for calculating nobs_total band_counts["total"] = band_counts[list(band_counts.columns)].sum(axis=1) bands = band_counts.columns.values - self._object = self._object.assign( - **{label + "_" + str(band): band_counts[band] for band in bands} - ) + self.object = self.object.assign(**{label + "_" + str(band): band_counts[band] for band in bands}) if temporary: self._object_temp.extend(label + "_" + str(band) for band in bands) else: - if self._object.known_divisions and self._source.known_divisions: + if 
self.object.known_divisions and self.source.known_divisions: # Grab these up front to help out the task graph id_col = self._id_col band_col = self._band_col # Map the groupby to each partition - counts = self._source.map_partitions( + counts = self.source.map_partitions( lambda x: x.groupby([id_col])[[band_col]].aggregate("count") - ).repartition(divisions=self._object.divisions) + ).repartition(divisions=self.object.divisions) else: # Just do a groupby on all source counts = ( - self._source.groupby([self._id_col])[[self._band_col]] + self.source.groupby([self._id_col])[[self._band_col]] .aggregate("count") - .repartition(npartitions=self._object.npartitions) + .repartition(npartitions=self.object.npartitions) ) - self._object = self._object.assign(**{label + "_total": counts[self._band_col]}) + self.object = self.object.assign(**{label + "_total": counts[self._band_col]}) if temporary: self._object_temp.extend([label + "_total"]) @@ -707,7 +866,7 @@ def prune(self, threshold=50, col_name=None): # Mask on object table self = self.query(f"{col_name} >= {threshold}", table="object") - self._object_dirty = True # Object Table is now dirty + self.object.set_dirty(True) # Object table is now dirty return self @@ -733,7 +892,7 @@ def find_day_gap_offset(self): self._lazy_sync_tables(table="source") # Compute a histogram of observations by hour of the day. - hours = self._source[self._time_col].apply( + hours = self.source[self._time_col].apply( lambda x: np.floor(x * 24.0).astype(int) % 24, meta=pd.Series(dtype=int) ) hour_counts = hours.value_counts().compute() @@ -809,9 +968,9 @@ def bin_sources( # Bin the time and add it as a column. We create a temporary column that # truncates the time into increments of `time_window`. tmp_time_col = "tmp_time_for_aggregation" - if tmp_time_col in self._source.columns: + if tmp_time_col in self.source.columns: raise KeyError(f"Column '{tmp_time_col}' already exists in source table.") - self._source[tmp_time_col] = self._source[self._time_col].apply( + self.source[tmp_time_col] = self.source[self._time_col].apply( lambda x: np.floor((x + offset) / time_window) * time_window, meta=pd.Series(dtype=float) ) @@ -819,7 +978,7 @@ def bin_sources( aggr_funs = {self._time_col: "mean", self._flux_col: "mean"} # If the source table has errors then add an aggregation function for it. - if self._err_col in self._source.columns: + if self._err_col in self.source.columns: aggr_funs[self._err_col] = dd.Aggregation( name="err_agg", chunk=lambda x: (x.count(), x.apply(lambda s: np.sum(np.power(s, 2)))), @@ -831,8 +990,8 @@ def bin_sources( # adding an initial column of all ones if needed. if count_col is not None: self._bin_count_col = count_col - if self._bin_count_col not in self._source.columns: - self._source[self._bin_count_col] = self._source[self._time_col].apply( + if self._bin_count_col not in self.source.columns: + self.source[self._bin_count_col] = self.source[self._time_col].apply( lambda x: 1, meta=pd.Series(dtype=int) ) aggr_funs[self._bin_count_col] = "sum" @@ -846,16 +1005,18 @@ def bin_sources( aggr_funs[key] = custom_aggr[key] # Group the columns by id, band, and time bucket and aggregate. - self._source = self._source.groupby([self._id_col, self._band_col, tmp_time_col]).aggregate(aggr_funs) + self.update_frame( + self.source.groupby([self._id_col, self._band_col, tmp_time_col]).aggregate(aggr_funs) + ) # Fix the indices and remove the temporary column. 
- self._source = self._source.reset_index().set_index(self._id_col).drop(tmp_time_col, axis=1) + self.update_frame(self.source.reset_index().set_index(self._id_col).drop(tmp_time_col, axis=1)) # Mark the source table as dirty. - self._source_dirty = True + self.source.set_dirty(True) return self - def batch(self, func, *args, meta=None, use_map=True, compute=True, on=None, **kwargs): + def batch(self, func, *args, meta=None, use_map=True, compute=True, on=None, label="", **kwargs): """Run a function from tape.TimeSeries on the available ids Parameters @@ -893,6 +1054,11 @@ def batch(self, func, *args, meta=None, use_map=True, compute=True, on=None, **k Designates which column(s) to groupby. Columns may be from the source or object tables. For TAPE and `light-curve` functions this is populated automatically. + label: 'str', optional + If provided the ensemble will use this label to track the result + dataframe. If not provided, a label of the from "result_{x}" where x + is a monotonically increasing integer is generated. If `None`, + the result frame will not be tracked. **kwargs: Additional optional parameters passed for the selected function @@ -943,21 +1109,25 @@ def s2n_inter_quartile_range(flux, err): if meta is None: meta = (self._id_col, float) # return a series of ids, default assume a float is returned + # Translate the meta into an appropriate TapeFrame or TapeSeries. This ensures that the + # batch result will be an EnsembleFrame or EnsembleSeries. + meta = self._translate_meta(meta) + if on is None: on = self._id_col # Default grouping is by id_col if isinstance(on, str): on = [on] # Convert to list if only one column is passed # Handle object columns to group on - source_cols = list(self._source.columns) - object_cols = list(self._object.columns) + source_cols = list(self.source.columns) + object_cols = list(self.object.columns) object_group_cols = [col for col in on if (col in object_cols) and (col not in source_cols)] if len(object_group_cols) > 0: - object_col_dd = self._object[object_group_cols] - source_to_batch = self._source.merge(object_col_dd, how="left") + object_col_dd = self.object[object_group_cols] + source_to_batch = self.source.merge(object_col_dd, how="left") else: - source_to_batch = self._source # Can directly use the source table + source_to_batch = self.source # Can directly use the source table id_col = self._id_col # pre-compute needed for dask in lambda function @@ -982,8 +1152,15 @@ def s2n_inter_quartile_range(flux, err): # Inherit divisions if known from source and the resulting index is the id # Groupby on index should always return a subset that adheres to the same divisions criteria - if self._source.known_divisions and batch.index.name == self._id_col: - batch.divisions = self._source.divisions + if self.source.known_divisions and batch.index.name == self._id_col: + batch.divisions = self.source.divisions + + if label is not None: + if label == "": + label = self._generate_frame_label() + print(f"Using generated label, {label}, for a batch result.") + # Track the result frame under the provided label + self.add_frame(batch, label) if compute: return batch.compute() @@ -1088,30 +1265,31 @@ def from_dask_dataframe( The ensemble object with the Dask dataframe data loaded. 
""" self._load_column_mapper(column_mapper, **kwargs) + source_frame = SourceFrame.from_dask_dataframe(source_frame, self) # Set the index of the source frame and save the resulting table - self._source = source_frame.set_index(self._id_col, drop=True, sorted=sorted, sort=sort) + self.update_frame(source_frame.set_index(self._id_col, drop=True, sorted=sorted, sort=sort)) if object_frame is None: # generate an indexed object table from source - self._object = self._generate_object_table() + self.update_frame(self._generate_object_table()) else: - self._object = object_frame - self._object = self._object.set_index(self._id_col, sorted=sorted, sort=sort) + self.update_frame(ObjectFrame.from_dask_dataframe(object_frame, ensemble=self)) + self.update_frame(self.object.set_index(self._id_col, sorted=sorted, sort=sort)) # Optionally sync the tables, recalculates nobs columns if sync_tables: - self._source_dirty = True - self._object_dirty = True + self.source.set_dirty(True) + self.object.set_dirty(True) self._sync_tables() if npartitions and npartitions > 1: - self._source = self._source.repartition(npartitions=npartitions) + self.source = self.source.repartition(npartitions=npartitions) elif partition_size: - self._source = self._source.repartition(partition_size=partition_size) + self.source = self.source.repartition(partition_size=partition_size) # Check that Divisions are established, warn if not. - for name, table in [("object", self._object), ("source", self._source)]: + for name, table in [("object", self.object), ("source", self.source)]: if not table.known_divisions: warnings.warn( f"Divisions for {name} are not set, certain downstream dask operations may fail as a result. We recommend setting the `sort` or `sorted` flags when loading data to establish division information." @@ -1261,7 +1439,7 @@ def from_parquet( source_file: 'str' Path to a parquet file, or multiple parquet files that contain source information to be read into the ensemble - object_file: 'str' + object_file: 'str', optional Path to a parquet file, or multiple parquet files that contain object information. If not specified, it is generated from the source table @@ -1313,7 +1491,7 @@ def from_parquet( # Index is set False so that we can set it with a future set_index call # This has the advantage of letting Dask set partition boundaries based # on the divisions between the sources of different objects. - source = dd.read_parquet(source_file, index=False, columns=columns, split_row_groups=True) + source = SourceFrame.from_parquet(source_file, index=False, columns=columns, ensemble=self) # Generate a provenance column if not provided if self._provenance_col is None: @@ -1325,7 +1503,7 @@ def from_parquet( # Read in the object file(s) # Index is False so that we can set it with a future set_index call # More meaningful for source than object but parity seems good here - object = dd.read_parquet(object_file, index=False, split_row_groups=True) + object = ObjectFrame.from_parquet(object_file, index=False, ensemble=self) return self.from_dask_dataframe( source_frame=source, object_frame=object, @@ -1421,7 +1599,7 @@ def from_source_dict( """ # Load the source data into a dataframe. 
- source_frame = dd.DataFrame.from_dict(source_dict, npartitions=npartitions) + source_frame = SourceFrame.from_dict(source_dict, npartitions=npartitions) return self.from_dask_dataframe( source_frame, @@ -1481,40 +1659,65 @@ def convert_flux_to_mag(self, zero_point, zp_form="mag", out_col_name=None, flux if zp_form == "flux": # mag = -2.5*np.log10(flux/zp) if isinstance(zero_point, str): - self._source = self._source.assign( - **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / x[zero_point])} + self.update_frame( + self.source.assign( + **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / x[zero_point])} + ) ) else: - self._source = self._source.assign( - **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / zero_point)} + self.update_frame( + self.source.assign(**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / zero_point)}) ) elif zp_form == "magnitude" or zp_form == "mag": # mag = -2.5*np.log10(flux) + zp if isinstance(zero_point, str): - self._source = self._source.assign( - **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + x[zero_point]} + self.update_frame( + self.source.assign( + **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + x[zero_point]} + ) ) else: - self._source = self._source.assign( - **{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + zero_point} + self.update_frame( + self.source.assign(**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + zero_point}) ) else: raise ValueError(f"{zp_form} is not a valid zero_point format.") # Calculate Errors if err_col is not None: - self._source = self._source.assign( - **{out_col_name + "_err": lambda x: (2.5 / np.log(10)) * (x[err_col] / x[flux_col])} + self.update_frame( + self.source.assign( + **{out_col_name + "_err": lambda x: (2.5 / np.log(10)) * (x[err_col] / x[flux_col])} + ) ) return self def _generate_object_table(self): """Generate an empty object table from the source table.""" - res = self._source.map_partitions(lambda x: pd.DataFrame(index=x.index.unique())) + res = self.source.map_partitions(lambda x: TapeObjectFrame(index=x.index.unique())) return res + def _lazy_sync_tables_from_frame(self, frame): + """Call the sync operation for the frame only if the + table being modified (`frame`) needs to be synced. + Does nothing in the case that only the table to be modified + is dirty or if it is not the object or source frame for this + `Ensemble`. + + Parameters + ---------- + frame: `tape.EnsembleFrame` + The frame being modified. Only an `ObjectFrame` or + `SourceFrame tracked by this `Ensemble` may trigger + a sync. + """ + if frame is self.object or frame is self.source: + # See if we should sync the Object or Source tables. + self._lazy_sync_tables(frame.label) + return self + def _lazy_sync_tables(self, table="object"): """Call the sync operation for the table only if the the table being modified (`table`) needs to be synced. @@ -1527,11 +1730,11 @@ def _lazy_sync_tables(self, table="object"): The table being modified. 
Should be one of "object", "source", or "all" """ - if table == "object" and self._source_dirty: # object table should be updated + if table == "object" and self.source.is_dirty(): # object table should be updated self._sync_tables() - elif table == "source" and self._object_dirty: # source table should be updated + elif table == "source" and self.object.is_dirty(): # source table should be updated self._sync_tables() - elif table == "all" and (self._source_dirty or self._object_dirty): + elif table == "all" and (self.source.is_dirty() or self.object.is_dirty()): self._sync_tables() return self @@ -1543,53 +1746,55 @@ def _sync_tables(self): keep_empty_objects attribute is set to True. """ - if self._object_dirty: + if self.object.is_dirty(): # Sync Object to Source; remove any missing objects from source - if self._object.known_divisions and self._source.known_divisions: + if self.object.known_divisions and self.source.known_divisions: # Lazily Create an empty object table (just index) for joining - empty_obj = self._object.map_partitions(lambda x: pd.DataFrame(index=x.index)) + empty_obj = self.object.map_partitions(lambda x: TapeObjectFrame(index=x.index)) + if type(empty_obj) != type(self.object): + raise ValueError("Bad type for empty_obj: " + str(type(empty_obj))) - # Join source onto the empty object table to remove IDs not present in both tables - self._source = self._source.join(empty_obj, how="inner") + # Join source onto the empty object table to align + self.update_frame(self.source.join(empty_obj, how="inner")) else: warnings.warn("Divisions are not known, syncing using a non-lazy method.") - obj_idx = list(self._object.index.compute()) - self._source = self._source.map_partitions(lambda x: x[x.index.isin(obj_idx)]) - self._source = self._source.persist() # persist the source frame + obj_idx = list(self.object.index.compute()) + self.update_frame(self.source.map_partitions(lambda x: x[x.index.isin(obj_idx)])) + self.update_frame(self.source.persist()) # persist the source frame # Drop Temporary Source Columns on Sync if len(self._source_temp): - self._source = self._source.drop(columns=self._source_temp) + self.update_frame(self.source.drop(columns=self._source_temp)) print(f"Temporary columns dropped from Source Table: {self._source_temp}") self._source_temp = [] - if self._source_dirty: # not elif + if self.source.is_dirty(): # not elif if not self.keep_empty_objects: - if self._object.known_divisions and self._source.known_divisions: + if self.object.known_divisions and self.source.known_divisions: # Lazily Create an empty source table (just unique indexes) for joining - empty_src = self._source.map_partitions(lambda x: pd.DataFrame(index=x.index.unique())) - - # Join object onto the empty unique source table to remove IDs not present in - # both tables - self._object = self._object.join(empty_src, how="inner") + empty_src = self.source.map_partitions(lambda x: TapeSourceFrame(index=x.index.unique())) + if type(empty_src) != type(self.source): + raise ValueError("Bad type for empty_src: " + str(type(empty_src))) + # Join object onto the empty unique source table to align + self.update_frame(self.object.join(empty_src, how="inner")) else: warnings.warn("Divisions are not known, syncing using a non-lazy method.") # Sync Source to Object; remove any objects that do not have sources - sor_idx = list(self._source.index.unique().compute()) - self._object = self._object.map_partitions(lambda x: x[x.index.isin(sor_idx)]) - self._object = self._object.persist() # persist the object 
frame + sor_idx = list(self.source.index.unique().compute()) + self.update_frame(self.object.map_partitions(lambda x: x[x.index.isin(sor_idx)])) + self.update_frame(self.object.persist()) # persist the object frame # Drop Temporary Object Columns on Sync if len(self._object_temp): - self._object = self._object.drop(columns=self._object_temp) + self.update_frame(self.object.drop(columns=self._object_temp)) print(f"Temporary columns dropped from Object Table: {self._object_temp}") self._object_temp = [] # Now synced and clean - self._source_dirty = False - self._object_dirty = False + self.source.set_dirty(False) + self.object.set_dirty(False) return self def to_timeseries( @@ -1642,7 +1847,7 @@ def to_timeseries( if band_col is None: band_col = self._band_col - df = self._source.loc[target].compute() + df = self.source.loc[target].compute() ts = TimeSeries().from_dataframe( data=df, object_id=target, @@ -1714,11 +1919,11 @@ def sf2(self, sf_method="basic", argument_container=None, use_map=True, compute= if argument_container.combine: result = calc_sf2( - self._source[self._time_col], - self._source[self._flux_col], - self._source[self._err_col], - self._source[self._band_col], - self._source.index, + self.source[self._time_col], + self.source[self._flux_col], + self.source[self._err_col], + self.source[self._band_col], + self.source.index, argument_container=argument_container, ) @@ -1728,7 +1933,37 @@ def sf2(self, sf_method="basic", argument_container=None, use_map=True, compute= ) # Inherit divisions information if known - if self._source.known_divisions and self._object.known_divisions: - result.divisions = self._source.divisions + if self.source.known_divisions and self.object.known_divisions: + result.divisions = self.source.divisions return result + + def _translate_meta(self, meta): + """Translates Dask-style meta into a TapeFrame or TapeSeries object. + + Parameters + ---------- + meta : `dict`, `tuple`, `list`, `pd.Series`, `pd.DataFrame`, `pd.Index`, `dtype`, `scalar` + + Returns + ---------- + result : `ensemble.TapeFrame` or `ensemble.TapeSeries` + The appropriate meta for Dask producing an `Ensemble.EnsembleFrame` or + `Ensemble.EnsembleSeries` respectively + """ + if isinstance(meta, TapeFrame) or isinstance(meta, TapeSeries): + return meta + + # If the meta is not a DataFrame or Series, have Dask attempt translate the meta into an + # appropriate Pandas object. + meta_object = meta + if not (isinstance(meta_object, pd.DataFrame) or isinstance(meta_object, pd.Series)): + meta_object = dd.backends.make_meta_object(meta_object) + + # Convert meta_object into the appropriate TAPE extension. 
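+        # Returning TapeFrame/TapeSeries here means that Dask collections later built from
+        # this meta resolve to EnsembleFrame or EnsembleSeries through the get_parallel_type
+        # dispatch registered in ensemble_frame.py.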
+ if isinstance(meta_object, pd.DataFrame): + return TapeFrame(meta_object) + elif isinstance(meta_object, pd.Series): + return TapeSeries(meta_object) + else: + raise ValueError("Unsupported Meta: " + str(meta) + "\nTry a Pandas DataFrame or Series instead.") diff --git a/src/tape/ensemble_frame.py b/src/tape/ensemble_frame.py new file mode 100644 index 00000000..7a910f51 --- /dev/null +++ b/src/tape/ensemble_frame.py @@ -0,0 +1,1150 @@ +from collections.abc import Sequence + +import dask.dataframe as dd + +import dask +from dask.dataframe.dispatch import make_meta_dispatch +from dask.dataframe.backends import _nonempty_index, meta_nonempty, meta_nonempty_dataframe, _nonempty_series + +from dask.dataframe.core import get_parallel_type +from dask.dataframe.extensions import make_array_nonempty + +import numpy as np +import pandas as pd + +from typing import Literal + + +from functools import partial +from dask.dataframe.io.parquet.arrow import ( + ArrowDatasetEngine as DaskArrowDatasetEngine, +) + +SOURCE_FRAME_LABEL = "source" # Reserved label for source table +OBJECT_FRAME_LABEL = "object" # Reserved label for object table. + +__all__ = [ + "EnsembleFrame", + "EnsembleSeries", + "ObjectFrame", + "SourceFrame", + "TapeFrame", + "TapeObjectFrame", + "TapeSourceFrame", + "TapeSeries", +] + + +class TapeArrowEngine(DaskArrowDatasetEngine): + """ + Engine for reading parquet files into Tape and assigning the appropriate Dask meta. + + Based off of the approach used in dask_geopandas.io + """ + + @classmethod + def _creates_meta(cls, meta, schema): + """ + Converts the meta to a TapeFrame. + """ + return TapeFrame(meta) + + @classmethod + def _create_dd_meta(cls, dataset_info, use_nullable_dtypes=False): + """Overriding private method for dask >= 2021.10.0""" + meta = super()._create_dd_meta(dataset_info) + + schema = dataset_info["schema"] + if not schema.names and not schema.metadata: + if len(list(dataset_info["ds"].get_fragments())) == 0: + raise ValueError( + "No dataset parts discovered. Use dask.dataframe.read_parquet " + "to read it as an empty DataFrame" + ) + meta = cls._creates_meta(meta, schema) + return meta + + +class TapeSourceArrowEngine(TapeArrowEngine): + """ + Barebones subclass of TapeArrowEngine for assigning the meta when loading from a parquet file + of source data. + """ + + @classmethod + def _creates_meta(cls, meta, schema): + """ + Convert meta to a TapeSourceFrame + """ + return TapeSourceFrame(meta) + + +class TapeObjectArrowEngine(TapeArrowEngine): + """ + Barebones subclass of TapeArrowEngine for assigning the meta when loading from a parquet file + of object data. + """ + + @classmethod + def _creates_meta(cls, meta, schema): + """ + Convert meta to a TapeObjectFrame + """ + return TapeObjectFrame(meta) + + +class _Frame(dd.core._Frame): + """Base class for extensions of Dask Dataframes that track additional Ensemble-related metadata.""" + + def __init__(self, dsk, name, meta, divisions, label=None, ensemble=None): + # We define relevant object fields before super().__init__ since that call may lead to a + # map_partitions call which will assume these fields exist. + self.label = label # A label used by the Ensemble to identify this frame. + self.ensemble = ensemble # The Ensemble object containing this frame. 
+ self.dirty = False # True if the underlying data is out of sync with the Ensemble + + super().__init__(dsk, name, meta, divisions) + + def is_dirty(self): + return self.dirty + + def set_dirty(self, dirty): + self.dirty = dirty + + @property + def _args(self): + # Ensure our Dask extension can correctly be used by pickle. + # See https://github.com/geopandas/dask-geopandas/issues/237 + return super()._args + (self.label, self.ensemble) + + def _propagate_metadata(self, new_frame): + """Propagatees any relevant metadata to a new frame. + + Parameters + ---------- + new_frame: `_Frame` + | A frame to propage metadata to + + Returns + ---------- + new_frame: `_Frame` + The modifed frame + """ + new_frame.label = self.label + new_frame.ensemble = self.ensemble + new_frame.set_dirty(self.is_dirty()) + return new_frame + + def copy(self): + self_copy = super().copy() + return self._propagate_metadata(self_copy) + + def assign(self, **kwargs): + """Assign new columns to a DataFrame. + + This docstring was copied from dask.dataframe.DataFrame.assign. + + Some inconsistencies with the Dask version may exist. + + Returns a new object with all original columns in addition to new ones. Existing columns + that are re-assigned will be overwritten. + + Parameters + ---------- + **kwargs: `dict` + The column names are keywords. If the values are callable, they are computed on the + DataFrame and assigned to the new columns. The callable must not change input DataFrame + (though pandas doesn’t check it). If the values are not callable, (e.g. a Series, + scalar, or array), they are simply assigned. + + Returns + ---------- + result: `tape._Frame` + The modifed frame + """ + result = self._propagate_metadata(super().assign(**kwargs)) + result.set_dirty(True) + return result + + def query(self, expr, **kwargs): + """Filter dataframe with complex expression + + Doc string below derived from dask.dataframe.core + + Blocked version of pd.DataFrame.query + + Parameters + ---------- + expr: str + The query string to evaluate. + You can refer to column names that are not valid Python variable names + by surrounding them in backticks. + Dask does not fully support referring to variables using the '@' character, + use f-strings or the ``local_dict`` keyword argument instead. + **kwargs: `dict` + See the documentation for eval() for complete details on the keyword arguments accepted + by pandas.DataFrame.query(). + + Returns + ---------- + result: `tape._Frame` + The modifed frame + + Notes + ----- + This is like the sequential version except that this will also happen + in many threads. This may conflict with ``numexpr`` which will use + multiple threads itself. We recommend that you set ``numexpr`` to use a + single thread: + + .. code-block:: python + + import numexpr + numexpr.set_num_threads(1) + """ + result = self._propagate_metadata(super().query(expr, **kwargs)) + result.set_dirty(True) + return result + + def merge(self, right, **kwargs): + """Merge the Dataframe with another DataFrame + + Doc string below derived from dask.dataframe.core + + This will merge the two datasets, either on the indices, a certain column + in each dataset or the index in one dataset and the column in another. 
+ + Parameters + ---------- + right: dask.dataframe.DataFrame + how : {'left', 'right', 'outer', 'inner'}, default: 'inner' + How to handle the operation of the two objects: + + - left: use calling frame's index (or column if on is specified) + - right: use other frame's index + - outer: form union of calling frame's index (or column if on is + specified) with other frame's index, and sort it + lexicographically + - inner: form intersection of calling frame's index (or column if + on is specified) with other frame's index, preserving the order + of the calling's one + + on : label or list + Column or index level names to join on. These must be found in both + DataFrames. If on is None and not merging on indexes then this + defaults to the intersection of the columns in both DataFrames. + left_on : label or list, or array-like + Column to join on in the left DataFrame. Other than in pandas + arrays and lists are only support if their length is 1. + right_on : label or list, or array-like + Column to join on in the right DataFrame. Other than in pandas + arrays and lists are only support if their length is 1. + left_index : boolean, default False + Use the index from the left DataFrame as the join key. + right_index : boolean, default False + Use the index from the right DataFrame as the join key. + suffixes : 2-length sequence (tuple, list, ...) + Suffix to apply to overlapping column names in the left and + right side, respectively + indicator : boolean or string, default False + If True, adds a column to output DataFrame called "_merge" with + information on the source of each row. If string, column with + information on source of each row will be added to output DataFrame, + and column will be named value of string. Information column is + Categorical-type and takes on a value of "left_only" for observations + whose merge key only appears in `left` DataFrame, "right_only" for + observations whose merge key only appears in `right` DataFrame, + and "both" if the observation’s merge key is found in both. + npartitions: int or None, optional + The ideal number of output partitions. This is only utilised when + performing a hash_join (merging on columns only). If ``None`` then + ``npartitions = max(lhs.npartitions, rhs.npartitions)``. + Default is ``None``. + shuffle: {'disk', 'tasks', 'p2p'}, optional + Either ``'disk'`` for single-node operation or ``'tasks'`` and + ``'p2p'``` for distributed operation. Will be inferred by your + current scheduler. + broadcast: boolean or float, optional + Whether to use a broadcast-based join in lieu of a shuffle-based + join for supported cases. By default, a simple heuristic will be + used to select the underlying algorithm. If a floating-point value + is specified, that number will be used as the ``broadcast_bias`` + within the simple heuristic (a large number makes Dask more likely + to choose the ``broacast_join`` code path). See ``broadcast_join`` + for more information. + + Notes + ----- + + There are three ways to join dataframes: + + 1. Joining on indices. In this case the divisions are + aligned using the function ``dask.dataframe.multi.align_partitions``. + Afterwards, each partition is merged with the pandas merge function. + + 2. Joining one on index and one on column. In this case the divisions of + dataframe merged by index (:math:`d_i`) are used to divide the column + merged dataframe (:math:`d_c`) one using + ``dask.dataframe.multi.rearrange_by_divisions``. 
In this case the + merged dataframe (:math:`d_m`) has the exact same divisions + as (:math:`d_i`). This can lead to issues if you merge multiple rows from + (:math:`d_c`) to one row in (:math:`d_i`). + + 3. Joining both on columns. In this case a hash join is performed using + ``dask.dataframe.multi.hash_join``. + + In some cases, you may see a ``MemoryError`` if the ``merge`` operation requires + an internal ``shuffle``, because shuffling places all rows that have the same + index in the same partition. To avoid this error, make sure all rows with the + same ``on``-column value can fit on a single partition. + """ + result = super().merge(right, **kwargs) + return self._propagate_metadata(result) + + def join(self, other, **kwargs): + """Join columns of another DataFrame. Note that if `other` is a different type, + we expect the result to have the type of this object regardless of the value + of the`how` parameter. + + This docstring was copied from pandas.core.frame.DataFrame.join. + + Some inconsistencies with this version may exist. + + Join columns with `other` DataFrame either on index or on a key + column. Efficiently join multiple DataFrame objects by index at once by + passing a list. + + Parameters + ---------- + other : DataFrame, Series, or a list containing any combination of them + Index should be similar to one of the columns in this one. If a + Series is passed, its name attribute must be set, and that will be + used as the column name in the resulting joined DataFrame. + on : str, list of str, or array-like, optional + Column or index level name(s) in the caller to join on the index + in `other`, otherwise joins index-on-index. If multiple + values given, the `other` DataFrame must have a MultiIndex. Can + pass an array as the join key if it is not already contained in + the calling DataFrame. Like an Excel VLOOKUP operation. + how : {'left', 'right', 'outer', 'inner', 'cross'}, default 'left' + How to handle the operation of the two objects. + + * left: use calling frame's index (or column if on is specified) + * right: use `other`'s index. + * outer: form union of calling frame's index (or column if on is + specified) with `other`'s index, and sort it lexicographically. + * inner: form intersection of calling frame's index (or column if + on is specified) with `other`'s index, preserving the order + of the calling's one. + * cross: creates the cartesian product from both frames, preserves the order + of the left keys. + lsuffix : str, default '' + Suffix to use from left frame's overlapping columns. + rsuffix : str, default '' + Suffix to use from right frame's overlapping columns. + sort : bool, default False + Order result DataFrame lexicographically by the join key. If False, + the order of the join key depends on the join type (how keyword). + validate : str, optional + If specified, checks if join is of specified type. + + * "one_to_one" or "1:1": check if join keys are unique in both left + and right datasets. + * "one_to_many" or "1:m": check if join keys are unique in left dataset. + * "many_to_one" or "m:1": check if join keys are unique in right dataset. + * "many_to_many" or "m:m": allowed, but does not result in checks. + + Returns + ------- + result: `tape._Frame` + A TAPE dataframe containing columns from both the caller and `other`. + + """ + result = super().join(other, **kwargs) + return self._propagate_metadata(result) + + def drop(self, labels=None, axis=0, columns=None, errors="raise"): + """Drop specified labels from rows or columns. 
+ + Doc string below derived from dask.dataframe.core + + Remove rows or columns by specifying label names and corresponding + axis, or by directly specifying index or column names. When using a + multi-index, labels on different levels can be removed by specifying + the level. See the :ref:`user guide ` + for more information about the now unused levels. + + Parameters + ---------- + labels : single label or list-like + Index or column labels to drop. A tuple will be used as a single + label and not treated as a list-like. + axis : {0 or 'index', 1 or 'columns'}, default 0 + Whether to drop labels from the index (0 or 'index') or + columns (1 or 'columns'). + is equivalent to ``index=labels``). + columns : single label or list-like + Alternative to specifying axis (``labels, axis=1`` + is equivalent to ``columns=labels``). + errors : {'ignore', 'raise'}, default 'raise' + If 'ignore', suppress error and only existing labels are + dropped. + + Returns + ------- + result: `tape._Frame` + Returns the frame or None with the specified + index or column labels removed or None if inplace=True. + """ + result = self._propagate_metadata( + super().drop(labels=labels, axis=axis, columns=columns, errors=errors) + ) + result.set_dirty(True) + return result + + def dropna(self, **kwargs): + """ + Remove missing values. + + Doc string below derived from dask.dataframe.core + + Parameters + ---------- + + how : {'any', 'all'}, default 'any' + Determine if row or column is removed from DataFrame, when we have + at least one NA or all NA. + + * 'any' : If any NA values are present, drop that row or column. + * 'all' : If all values are NA, drop that row or column. + + thresh : int, optional + Require that many non-NA values. Cannot be combined with how. + subset : column label or sequence of labels, optional + Labels along other axis to consider, e.g. if you are dropping rows + these would be a list of columns to include. + + Returns + ---------- + result: `tape._Frame` + The modifed frame with NA entries dropped from it or None if ``inplace=True``. + """ + result = self._propagate_metadata(super().dropna(**kwargs)) + result.set_dirty(True) + return result + + def persist(self, **kwargs): + """Persist this dask collection into memory + + Doc string below derived from dask.base + + This turns a lazy Dask collection into a Dask collection with the same + metadata, but now with the results fully computed or actively computing + in the background. + + The action of function differs significantly depending on the active + task scheduler. If the task scheduler supports asynchronous computing, + such as is the case of the dask.distributed scheduler, then persist + will return *immediately* and the return value's task graph will + contain Dask Future objects. However if the task scheduler only + supports blocking computation then the call to persist will *block* + and the return value's task graph will contain concrete Python results. + + This function is particularly useful when using distributed systems, + because the results will be kept in distributed memory, rather than + returned to the local process as with compute. + + Parameters + ---------- + **kwargs + Extra keywords to forward to the scheduler function. 
+ + Returns + ------- + result: `tape._Frame` + The modifed frame backed by in-memory data + """ + result = super().persist(**kwargs) + return self._propagate_metadata(result) + + def set_index( + self, + other, + drop=True, + sorted=False, + npartitions=None, + divisions=None, + inplace=False, + sort=True, + **kwargs, + ): + """Set the DataFrame index (row labels) using an existing column. + + Doc string below derived from dask.dataframe.core + + If ``sort=False``, this function operates exactly like ``pandas.set_index`` + and sets the index on the DataFrame. If ``sort=True`` (default), + this function also sorts the DataFrame by the new index. This can have a + significant impact on performance, because joins, groupbys, lookups, etc. + are all much faster on that column. However, this performance increase + comes with a cost, sorting a parallel dataset requires expensive shuffles. + Often we ``set_index`` once directly after data ingest and filtering and + then perform many cheap computations off of the sorted dataset. + + With ``sort=True``, this function is much more expensive. Under normal + operation this function does an initial pass over the index column to + compute approximate quantiles to serve as future divisions. It then passes + over the data a second time, splitting up each input partition into several + pieces and sharing those pieces to all of the output partitions now in + sorted order. + + In some cases we can alleviate those costs, for example if your dataset is + sorted already then we can avoid making many small pieces or if you know + good values to split the new index column then we can avoid the initial + pass over the data. For example if your new index is a datetime index and + your data is already sorted by day then this entire operation can be done + for free. You can control these options with the following parameters. + + Parameters + ---------- + other: string or Dask Series + Column to use as index. + drop: boolean, default True + Delete column to be used as the new index. + sorted: bool, optional + If the index column is already sorted in increasing order. + Defaults to False + npartitions: int, None, or 'auto' + The ideal number of output partitions. If None, use the same as + the input. If 'auto' then decide by memory use. + Only used when ``divisions`` is not given. If ``divisions`` is given, + the number of output partitions will be ``len(divisions) - 1``. + divisions: list, optional + The "dividing lines" used to split the new index into partitions. + For ``divisions=[0, 10, 50, 100]``, there would be three output partitions, + where the new index contained [0, 10), [10, 50), and [50, 100), respectively. + See https://docs.dask.org/en/latest/dataframe-design.html#partitions. + If not given (default), good divisions are calculated by immediately computing + the data and looking at the distribution of its values. For large datasets, + this can be expensive. + Note that if ``sorted=True``, specified divisions are assumed to match + the existing partitions in the data; if this is untrue you should + leave divisions empty and call ``repartition`` after ``set_index``. + inplace: bool, optional + Modifying the DataFrame in place is not supported by Dask. + Defaults to False. + sort: bool, optional + If ``True``, sort the DataFrame by the new index. Otherwise + set the index on the individual existing partitions. + Defaults to ``True``. 
+ shuffle: {'disk', 'tasks', 'p2p'}, optional + Either ``'disk'`` for single-node operation or ``'tasks'`` and + ``'p2p'`` for distributed operation. Will be inferred by your + current scheduler. + compute: bool, default False + Whether or not to trigger an immediate computation. Defaults to False. + Note, that even if you set ``compute=False``, an immediate computation + will still be triggered if ``divisions`` is ``None``. + partition_size: int, optional + Desired size of each partitions in bytes. + Only used when ``npartitions='auto'`` + + Returns + ---------- + result: `tape._Frame` + The indexed frame + """ + result = super().set_index(other, drop, sorted, npartitions, divisions, inplace, sort, **kwargs) + return self._propagate_metadata(result) + + def map_partitions(self, func, *args, **kwargs): + """Apply Python function on each DataFrame partition. + + Doc string below derived from dask.dataframe.core + + If ``sort=False``, this function operates exactly like ``pandas.set_index`` + and sets the index on the DataFrame. If ``sort=True`` (default), + this function also sorts the DataFrame by the new index. This can have a + significant impact on performance, because joins, groupbys, lookups, etc. + are all much faster on that column. However, this performance increase + comes with a cost, sorting a parallel dataset requires expensive shuffles. + Often we ``set_index`` once directly after data ingest and filtering and + then perform many cheap computations off of the sorted dataset. + + With ``sort=True``, this function is much more expensive. Under normal + operation this function does an initial pass over the index column to + compute approximate quantiles to serve as future divisions. It then passes + over the data a second time, splitting up each input partition into several + pieces and sharing those pieces to all of the output partitions now in + sorted order. + + In some cases we can alleviate those costs, for example if your dataset is + sorted already then we can avoid making many small pieces or if you know + good values to split the new index column then we can avoid the initial + pass over the data. For example if your new index is a datetime index and + your data is already sorted by day then this entire operation can be done + for free. You can control these options with the following parameters. + + Parameters + ---------- + other: string or Dask Series + Column to use as index. + drop: boolean, default True + Delete column to be used as the new index. + sorted: bool, optional + If the index column is already sorted in increasing order. + Defaults to False + npartitions: int, None, or 'auto' + The ideal number of output partitions. If None, use the same as + the input. If 'auto' then decide by memory use. + Only used when ``divisions`` is not given. If ``divisions`` is given, + the number of output partitions will be ``len(divisions) - 1``. + divisions: list, optional + The "dividing lines" used to split the new index into partitions. + For ``divisions=[0, 10, 50, 100]``, there would be three output partitions, + where the new index contained [0, 10), [10, 50), and [50, 100), respectively. + See https://docs.dask.org/en/latest/dataframe-design.html#partitions. + If not given (default), good divisions are calculated by immediately computing + the data and looking at the distribution of its values. For large datasets, + this can be expensive. 
+ Note that if ``sorted=True``, specified divisions are assumed to match + the existing partitions in the data; if this is untrue you should + leave divisions empty and call ``repartition`` after ``set_index``. + inplace: bool, optional + Modifying the DataFrame in place is not supported by Dask. + Defaults to False. + sort: bool, optional + If ``True``, sort the DataFrame by the new index. Otherwise + set the index on the individual existing partitions. + Defaults to ``True``. + shuffle: {'disk', 'tasks', 'p2p'}, optional + Either ``'disk'`` for single-node operation or ``'tasks'`` and + ``'p2p'`` for distributed operation. Will be inferred by your + current scheduler. + compute: bool, default False + Whether or not to trigger an immediate computation. Defaults to False. + Note, that even if you set ``compute=False``, an immediate computation + will still be triggered if ``divisions`` is ``None``. + partition_size: int, optional + Desired size of each partitions in bytes. + Only used when ``npartitions='auto'`` + """ + result = super().map_partitions(func, *args, **kwargs) + if isinstance(result, self.__class__): + # If the output of func is another _Frame, let's propagate any metadata. + return self._propagate_metadata(result) + return result + + def compute(self, **kwargs): + """Compute this Dask collection, returning the underlying dataframe or series. + If tracked by an `Ensemble`, the `Ensemble` is informed of this operation and + is given the opportunity to sync any of its tables prior to this Dask collection + being computed. + + Doc string below derived from dask.dataframe.DataFrame.compute + + This turns a lazy Dask collection into its in-memory equivalent. For example + a Dask array turns into a NumPy array and a Dask dataframe turns into a + Pandas dataframe. The entire dataset must fit into memory before calling + this operation. + + Parameters + ---------- + scheduler: `string`, optional + Which scheduler to use like “threads”, “synchronous” or “processes”. + If not provided, the default is to check the global settings first, + and then fall back to the collection defaults. + optimize_graph: `bool`, optional + If True [default], the graph is optimized before computation. + Otherwise the graph is run as is. This can be useful for debugging. + **kwargs: `dict`, optional + Extra keywords to forward to the scheduler function. + """ + if self.ensemble is not None: + self.ensemble._lazy_sync_tables_from_frame(self) + return super().compute(**kwargs) + + +class TapeSeries(pd.Series): + """A barebones extension of a Pandas series to be used for underlying Ensemble data. + + See https://pandas.pydata.org/docs/development/extending.html#subclassing-pandas-data-structures + """ + + @property + def _constructor(self): + return TapeSeries + + @property + def _constructor_sliced(self): + return TapeSeries + + +class TapeFrame(pd.DataFrame): + """A barebones extension of a Pandas frame to be used for underlying Ensemble data. + + See https://pandas.pydata.org/docs/development/extending.html#subclassing-pandas-data-structures + """ + + @property + def _constructor(self): + return TapeFrame + + @property + def _constructor_expanddim(self): + return TapeFrame + + +class EnsembleSeries(_Frame, dd.core.Series): + """A barebones extension of a Dask Series for Ensemble data.""" + + _partition_type = TapeSeries # Tracks the underlying data type + + +class EnsembleFrame(_Frame, dd.core.DataFrame): + """An extension for a Dask Dataframe for data used by a lightcurve Ensemble. 
+ + The underlying non-parallel dataframes are TapeFrames and TapeSeries which extend Pandas frames. + + Example + ---------- + import tape + ens = tape.Ensemble() + data = {...} # Some data you want tracked by the Ensemble + ensemble_frame = tape.EnsembleFrame.from_dict(data, label="my_frame", ensemble=ens) + """ + + _partition_type = TapeFrame # Tracks the underlying data type + + def __getitem__(self, key): + result = super().__getitem__(key) + if isinstance(result, _Frame): + # Ensures that any _Frame metadata is propagated. + result = self._propagate_metadata(result) + return result + + @classmethod + def from_tapeframe(cls, data, npartitions=None, chunksize=None, sort=True, label=None, ensemble=None): + """Returns an EnsembleFrame constructed from a TapeFrame. + Parameters + ---------- + data: `TapeFrame` + Frame containing the underlying data fro the EnsembleFram + npartitions: `int`, optional + The number of partitions of the index to create. Note that depending on + the size and index of the dataframe, the output may have fewer + partitions than requested. + chunksize: `int`, optional + Size of the individual chunks of data in non-parallel objects that make up Dask frames. + sort: `bool`, optional + Whether to sort the frame by a default index. + label: `str`, optional + | The label used to by the Ensemble to identify the frame. + ensemble: `tape.Ensemble`, optional + | A link to the Ensemble object that owns this frame. + Returns + result: `tape.EnsembleFrame` + The constructed EnsembleFrame object. + """ + result = dd.from_pandas(data, npartitions=npartitions, chunksize=chunksize, sort=sort) + result.label = label + result.ensemble = ensemble + return result + + @classmethod + def from_dask_dataframe(cl, df, ensemble=None, label=None): + """Returns an EnsembleFrame constructed from a Dask dataframe. + Parameters + ---------- + df: `dask.dataframe.DataFrame` or `list` + a Dask dataframe to convert to an EnsembleFrame + ensemble: `tape.ensemble.Ensemble`, optional + | A link to the Ensemble object that owns this frame. + label: `str`, optional + | The label used to by the Ensemble to identify the frame. + Returns + result: `tape.EnsembleFrame` + The constructed EnsembleFrame object. + """ + # Create a EnsembleFrame by mapping the partitions to the appropriate meta, TapeFrame + # TODO(wbeebe@uw.edu): Determine if there is a better method + result = df.map_partitions(TapeFrame) + result.ensemble = ensemble + result.label = label + return result + + def update_ensemble(self): + """Updates the Ensemble linked by the `EnsembelFrame.ensemble` property to track this frame. + + Returns + result: `tape.Ensemble` + The Ensemble object which tracks this frame, `None` if no such Ensemble. + """ + if self.ensemble is None: + return None + # Update the Ensemble to track this frame and return the ensemble. + return self.ensemble.update_frame(self) + + def convert_flux_to_mag( + self, + flux_col, + zero_point, + err_col=None, + zp_form="mag", + out_col_name=None, + ): + """Converts this EnsembleFrame's flux column into a magnitude column, returning a new + EnsembleFrame. + + Parameters + ---------- + flux_col: 'str' + The name of the EnsembleFrame flux column to convert into magnitudes. + zero_point: 'str' + The name of the EnsembleFrame column containing the zero point + information for column transformation. + err_col: 'str', optional + The name of the EnsembleFrame column containing the errors to propagate. 
+ Errors are propagated using the following approximation: + Err= (2.5/log(10))*(flux_error/flux), which holds mainly when the + error in flux is much smaller than the flux. + zp_form: `str`, optional + The form of the zero point column, either "flux" or + "magnitude"/"mag". Determines how the zero point (zp) is applied in + the conversion. If "flux", then the function is applied as + mag=-2.5*log10(flux/zp), or if "magnitude", then + mag=-2.5*log10(flux)+zp. + out_col_name: 'str', optional + The name of the output magnitude column, if None then the output + is just the flux column name + "_mag". The error column is also + generated as the out_col_name + "_err". + Returns + ---------- + result: `tape.EnsembleFrame` + A new EnsembleFrame object with a new magnitude (and error) column. + """ + if out_col_name is None: + out_col_name = flux_col + "_mag" + + result = None + if zp_form == "flux": # mag = -2.5*np.log10(flux/zp) + result = self.assign(**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col] / x[zero_point])}) + + elif zp_form == "magnitude" or zp_form == "mag": # mag = -2.5*np.log10(flux) + zp + result = self.assign(**{out_col_name: lambda x: -2.5 * np.log10(x[flux_col]) + x[zero_point]}) + else: + raise ValueError(f"{zp_form} is not a valid zero_point format.") + + # Calculate Errors + if err_col is not None: + result = result.assign( + **{out_col_name + "_err": lambda x: (2.5 / np.log(10)) * (x[err_col] / x[flux_col])} + ) + + return result + + @classmethod + def from_parquet( + cl, + path, + index=None, + columns=None, + label=None, + ensemble=None, + ): + """Returns an EnsembleFrame constructed from loading a parquet file. + Parameters + ---------- + path: `str` or `list` + Source directory for data, or path(s) to individual parquet files. Prefix with a + protocol like s3:// to read from alternative filesystems. To read from multiple + files you can pass a globstring or a list of paths, with the caveat that they must all + have the same protocol. + index: `str`, `list`, `False`, optional + Field name(s) to use as the output frame index. Default is None and index will be + inferred from the pandas parquet file metadata, if present. Use False to read all + fields as columns. + columns: `str` or `list`, optional + Field name(s) to read in as columns in the output. By default all non-index fields will + be read (as determined by the pandas parquet metadata, if present). Provide a single + field name instead of a list to read in the data as a Series. + label: `str`, optional + | The label used to by the Ensemble to identify the frame. + ensemble: `tape.ensemble.Ensemble`, optional + | A link to the Ensemble object that owns this frame. + Returns + result: `tape.EnsembleFrame` + The constructed EnsembleFrame object. + """ + # Read the parquet file with an engine that will assume the meta is a TapeFrame which Dask will + # instantiate as EnsembleFrame via its dispatcher. 
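+        # Illustrative usage (hypothetical path and Ensemble handle `ens`):
+        #     frame = EnsembleFrame.from_parquet("my_results.parquet", label="result", ensemble=ens)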
+ result = dd.read_parquet( + path, + index=index, + columns=columns, + split_row_groups=True, + engine=TapeArrowEngine, + ) + result.label = label + result.ensemble = ensemble + + return result + + +class TapeSourceFrame(TapeFrame): + """A barebones extension of a Pandas frame to be used for underlying Ensemble source data + + See https://pandas.pydata.org/docs/development/extending.html#subclassing-pandas-data-structures + """ + + @property + def _constructor(self): + return TapeSourceFrame + + @property + def _constructor_expanddim(self): + return TapeSourceFrame + + +class TapeObjectFrame(TapeFrame): + """A barebones extension of a Pandas frame to be used for underlying Ensemble object data. + + See https://pandas.pydata.org/docs/development/extending.html#subclassing-pandas-data-structures + """ + + @property + def _constructor(self): + return TapeObjectFrame + + @property + def _constructor_expanddim(self): + return TapeObjectFrame + + +class SourceFrame(EnsembleFrame): + """A subclass of EnsembleFrame for Source data.""" + + _partition_type = TapeSourceFrame # Tracks the underlying data type + + def __init__(self, dsk, name, meta, divisions, ensemble=None): + super().__init__(dsk, name, meta, divisions) + self.label = SOURCE_FRAME_LABEL # A label used by the Ensemble to identify this frame. + self.ensemble = ensemble # The Ensemble object containing this frame. + + def __getitem__(self, key): + result = super().__getitem__(key) + if isinstance(result, _Frame): + # Ensures that we have any metadata + result = self._propagate_metadata(result) + return result + + @classmethod + def from_parquet( + cl, + path, + index=None, + columns=None, + ensemble=None, + ): + """Returns a SourceFrame constructed from loading a parquet file. + Parameters + ---------- + path: `str` or `list` + Source directory for data, or path(s) to individual parquet files. Prefix with a + protocol like s3:// to read from alternative filesystems. To read from multiple + files you can pass a globstring or a list of paths, with the caveat that they must all + have the same protocol. + columns: `str` or `list`, optional + Field name(s) to read in as columns in the output. By default all non-index fields will + be read (as determined by the pandas parquet metadata, if present). Provide a single + field name instead of a list to read in the data as a Series. + index: `str`, `list`, `False`, optional + Field name(s) to use as the output frame index. Default is None and index will be + inferred from the pandas parquet file metadata, if present. Use False to read all + fields as columns. + ensemble: `tape.ensemble.Ensemble`, optional + | A link to the Ensemble object that owns this frame. + Returns + result: `tape.EnsembleFrame` + The constructed EnsembleFrame object. + """ + # Read the source parquet file with an engine that will assume the meta is a + # TapeSourceFrame which tells Dask to instantiate a SourceFrame via its + # dispatcher. + result = dd.read_parquet( + path, + index=index, + columns=columns, + split_row_groups=True, + engine=TapeSourceArrowEngine, + ) + result.ensemble = ensemble + result.label = SOURCE_FRAME_LABEL + + return result + + @classmethod + def from_dask_dataframe(cl, df, ensemble=None): + """Returns a SourceFrame constructed from a Dask dataframe.. + Parameters + ---------- + df: `dask.dataframe.DataFrame` or `list` + a Dask dataframe to convert to a SourceFrame + ensemble: `tape.ensemble.Ensemble`, optional + | A link to the Ensemble object that owns this frame. 
+ Returns + result: `tape.SourceFrame` + The constructed SourceFrame object. + """ + # Create a SourceFrame by mapping the partitions to the appropriate meta, TapeSourceFrame + # TODO(wbeebe@uw.edu): Determine if there is a better method + result = df.map_partitions(TapeSourceFrame) + result.ensemble = ensemble + result.label = SOURCE_FRAME_LABEL + return result + + +class ObjectFrame(EnsembleFrame): + """A subclass of EnsembleFrame for Object data.""" + + _partition_type = TapeObjectFrame # Tracks the underlying data type + + def __init__(self, dsk, name, meta, divisions, ensemble=None): + super().__init__(dsk, name, meta, divisions) + self.label = OBJECT_FRAME_LABEL # A label used by the Ensemble to identify this frame. + self.ensemble = ensemble # The Ensemble object containing this frame. + + @classmethod + def from_parquet( + cl, + path, + index=None, + columns=None, + ensemble=None, + ): + """Returns an ObjectFrame constructed from loading a parquet file. + Parameters + ---------- + path: `str` or `list` + Source directory for data, or path(s) to individual parquet files. Prefix with a + protocol like s3:// to read from alternative filesystems. To read from multiple + files you can pass a globstring or a list of paths, with the caveat that they must all + have the same protocol. + columns: `str` or `list`, optional + Field name(s) to read in as columns in the output. By default all non-index fields will + be read (as determined by the pandas parquet metadata, if present). Provide a single + field name instead of a list to read in the data as a Series. + index: `str`, `list`, `False`, optional + Field name(s) to use as the output frame index. Default is None and index will be + inferred from the pandas parquet file metadata, if present. Use False to read all + fields as columns. + ensemble: `tape.ensemble.Ensemble`, optional + | A link to the Ensemble object that owns this frame. + Returns + result: `tape.ObjectFrame` + The constructed ObjectFrame object. + """ + # Read in the object Parquet file + result = dd.read_parquet( + path, + index=index, + columns=columns, + split_row_groups=True, + engine=TapeObjectArrowEngine, + ) + result.ensemble = ensemble + result.label = OBJECT_FRAME_LABEL + + return result + + @classmethod + def from_dask_dataframe(cl, df, ensemble=None): + """Returns an ObjectFrame constructed from a Dask dataframe.. + Parameters + ---------- + df: `dask.dataframe.DataFrame` or `list` + a Dask dataframe to convert to an ObjectFrame + ensemble: `tape.ensemble.Ensemble`, optional + | A link to the Ensemble object that owns this frame. + Returns + result: `tape.ObjectFrame` + The constructed ObjectFrame object. + """ + # Create an ObjectFrame by mapping the partitions to the appropriate meta, TapeObjectFrame + # TODO(wbeebe@uw.edu): Determine if there is a better method + result = df.map_partitions(TapeObjectFrame) + result.ensemble = ensemble + result.label = OBJECT_FRAME_LABEL + return result + + +# Dask Dataframes are constructed indirectly using method dispatching and inference on the +# underlying data. So to ensure our subclasses behave correctly, we register the methods +# below. +# +# For more information, see https://docs.dask.org/en/latest/dataframe-extend.html +# +# The following should ensure that any Dask Dataframes which use TapeSeries or TapeFrames as their +# underlying data will be resolved as EnsembleFrames or EnsembleSeries as their parrallel +# counterparts. The underlying Dask Dataframe _meta will be a TapeSeries or TapeFrame. 
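+# For instance (illustrative only), once these registrations are in place a Dask collection
+# built on top of a TapeFrame comes back as an EnsembleFrame:
+#
+#     ddf = dd.from_pandas(TapeFrame({"flux": [1.0, 2.0]}), npartitions=1)
+#     isinstance(ddf, EnsembleFrame)  # True
+#
+# EnsembleFrame.from_tapeframe relies on exactly this behaviour.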
+ +get_parallel_type.register(TapeSeries, lambda _: EnsembleSeries) +get_parallel_type.register(TapeFrame, lambda _: EnsembleFrame) +get_parallel_type.register(TapeObjectFrame, lambda _: ObjectFrame) +get_parallel_type.register(TapeSourceFrame, lambda _: SourceFrame) + + +@make_meta_dispatch.register(TapeSeries) +def make_meta_series(x, index=None): + # Create an empty TapeSeries to use as Dask's underlying object meta. + result = x.head(0) + return result + + +@make_meta_dispatch.register(TapeFrame) +def make_meta_frame(x, index=None): + # Create an empty TapeFrame to use as Dask's underlying object meta. + result = x.head(0) + return result + + +@meta_nonempty.register(TapeSeries) +def _nonempty_tapeseries(x, index=None): + # Construct a new TapeSeries with the same underlying data. + data = _nonempty_series(x) + return TapeSeries(data) + + +@meta_nonempty.register(TapeFrame) +def _nonempty_tapeseries(x, index=None): + # Construct a new TapeFrame with the same underlying data. + df = meta_nonempty_dataframe(x) + return TapeFrame(df) + + +@make_meta_dispatch.register(TapeObjectFrame) +def make_meta_frame(x, index=None): + # Create an empty TapeObjectFrame to use as Dask's underlying object meta. + result = x.head(0) + return result + + +@meta_nonempty.register(TapeObjectFrame) +def _nonempty_tapesourceframe(x, index=None): + # Construct a new TapeObjectFrame with the same underlying data. + df = meta_nonempty_dataframe(x) + return TapeObjectFrame(df) + + +@make_meta_dispatch.register(TapeSourceFrame) +def make_meta_frame(x, index=None): + # Create an empty TapeSourceFrame to use as Dask's underlying object meta. + result = x.head(0) + return result + + +@meta_nonempty.register(TapeSourceFrame) +def _nonempty_tapesourceframe(x, index=None): + # Construct a new TapeSourceFrame with the same underlying data. + df = meta_nonempty_dataframe(x) + return TapeSourceFrame(df) diff --git a/tests/tape_tests/conftest.py b/tests/tape_tests/conftest.py index 06040690..c0af84c3 100644 --- a/tests/tape_tests/conftest.py +++ b/tests/tape_tests/conftest.py @@ -233,6 +233,24 @@ def parquet_ensemble_without_client(): return ens +@pytest.fixture +def parquet_files_and_ensemble_without_client(): + """Create an Ensemble from parquet data without a dask client.""" + ens = Ensemble(client=False) + source_file = "tests/tape_tests/data/source/test_source.parquet" + object_file = "tests/tape_tests/data/object/test_object.parquet" + colmap = ColumnMapper().assign( + id_col="ps1_objid", + time_col="midPointTai", + flux_col="psFlux", + err_col="psFluxErr", + band_col="filterName", + ) + ens = ens.from_parquet( + source_file, + object_file, + column_mapper=colmap) + return ens, source_file, object_file, colmap # pylint: disable=redefined-outer-name @pytest.fixture @@ -471,3 +489,26 @@ def pandas_with_object_ensemble(dask_client): ) return ens + +# pylint: disable=redefined-outer-name +@pytest.fixture +def ensemble_from_source_dict(dask_client): + """Create an Ensemble from a source dict, returning the ensemble and the source dict.""" + ens = Ensemble(client=dask_client) + + # Create some fake data with two IDs (8001, 8002), two bands ["g", "b"] + # a few time steps, flux, and data for zero point calculations. 
+ source_dict = { + "id": [8001, 8001, 8002, 8002, 8002], + "time": [1, 2, 3, 4, 5], + "flux": [30.5, 70, 80.6, 30.2, 60.3], + "zp_mag": [25.0, 25.0, 25.0, 25.0, 25.0], + "zp_flux": [10**10, 10**10, 10**10, 10**10, 10**10], + "error": [10, 10, 10, 10, 10], + "band": ["g", "g", "b", "b", "b"], + } + # map flux_col to one of the flux columns at the start + cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="error", band_col="band") + ens.from_source_dict(source_dict, column_mapper=cmap) + + return ens, source_dict \ No newline at end of file diff --git a/tests/tape_tests/test_analysis.py b/tests/tape_tests/test_analysis.py index c75a9621..824e4954 100644 --- a/tests/tape_tests/test_analysis.py +++ b/tests/tape_tests/test_analysis.py @@ -28,7 +28,7 @@ def test_analysis_function(cls, dask_client): "flux": [1.0, 2.0, 5.0, 3.0, 1.0, 2.0, 3.0, 4.0, 5.0], } cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band") - ens = Ensemble(client=dask_client).from_source_dict(rows, column_mapper=cmap) + ens = Ensemble(client=False).from_source_dict(rows, column_mapper=cmap) assert isinstance(obj.cols(ens), list) assert len(obj.cols(ens)) > 0 diff --git a/tests/tape_tests/test_ensemble.py b/tests/tape_tests/test_ensemble.py index 7c8a96ba..40a0264a 100644 --- a/tests/tape_tests/test_ensemble.py +++ b/tests/tape_tests/test_ensemble.py @@ -7,7 +7,17 @@ import pytest import tape -from tape import Ensemble +from tape import ( + Ensemble, + EnsembleFrame, + EnsembleSeries, + ObjectFrame, + SourceFrame, + TapeFrame, + TapeSeries, + TapeObjectFrame, + TapeSourceFrame, +) from tape.analysis.stetsonj import calc_stetson_J from tape.analysis.structure_function.base_argument_container import StructureFunctionArgumentContainer from tape.analysis.structurefunction2 import calc_sf2 @@ -44,6 +54,12 @@ def test_with_client(): "read_parquet_ensemble_from_hipscat", "read_parquet_ensemble_with_column_mapper", "read_parquet_ensemble_with_known_column_mapper", + "read_parquet_ensemble", + "read_parquet_ensemble_without_client", + "read_parquet_ensemble_from_source", + "read_parquet_ensemble_from_hipscat", + "read_parquet_ensemble_with_column_mapper", + "read_parquet_ensemble_with_known_column_mapper", ], ) def test_parquet_construction(data_fixture, request): @@ -53,13 +69,13 @@ def test_parquet_construction(data_fixture, request): parquet_ensemble = request.getfixturevalue(data_fixture) # Check to make sure the source and object tables were created - assert parquet_ensemble._source is not None - assert parquet_ensemble._object is not None + assert parquet_ensemble.source is not None + assert parquet_ensemble.object is not None # Make sure divisions are set if data_fixture == "parquet_ensemble_with_divisions": - assert parquet_ensemble._source.known_divisions - assert parquet_ensemble._object.known_divisions + assert parquet_ensemble.source.known_divisions + assert parquet_ensemble.object.known_divisions # Check that the data is not empty. 
obj, source = parquet_ensemble.compute() @@ -78,7 +94,7 @@ def test_parquet_construction(data_fixture, request): parquet_ensemble._provenance_col, ]: # Check to make sure the critical quantity labels are bound to real columns - assert parquet_ensemble._source[col] is not None + assert parquet_ensemble.source[col] is not None @pytest.mark.parametrize( @@ -101,8 +117,8 @@ def test_dataframe_constructors(data_fixture, request): ens = request.getfixturevalue(data_fixture) # Check to make sure the source and object tables were created - assert ens._source is not None - assert ens._object is not None + assert ens.source is not None + assert ens.object is not None # Check that the data is not empty. obj, source = ens.compute() @@ -120,13 +136,71 @@ def test_dataframe_constructors(data_fixture, request): ens._band_col, ]: # Check to make sure the critical quantity labels are bound to real columns - assert ens._source[col] is not None + assert ens.source[col] is not None # Check that we can compute an analysis function on the ensemble. amplitude = ens.batch(calc_stetson_J) assert len(amplitude) == 5 +@pytest.mark.parametrize( + "data_fixture", + [ + "parquet_ensemble", + "parquet_ensemble_without_client", + ], +) +def test_update_ensemble(data_fixture, request): + """ + Tests that the ensemble can be updated with a result frame. + """ + ens = request.getfixturevalue(data_fixture) + + # Filter the object table and have the ensemble track the updated table. + updated_obj = ens.object.query("nobs_total > 50") + assert updated_obj is not ens.object + assert updated_obj.is_dirty() + # Update the ensemble and validate that it marks the object table dirty + assert ens.object.is_dirty() == False + updated_obj.update_ensemble() + assert ens.object.is_dirty() == True + assert updated_obj is ens.object + + # Filter the source table and have the ensemble track the updated table. + updated_src = ens.source.query("psFluxErr > 0.1") + assert updated_src is not ens.source + # Update the ensemble and validate that it marks the source table dirty + assert ens.source.is_dirty() == False + updated_src.update_ensemble() + assert ens.source.is_dirty() == True + assert updated_src is ens.source + + # Compute a result to trigger a table sync + obj, src = ens.compute() + assert len(obj) > 0 + assert len(src) > 0 + assert ens.object.is_dirty() == False + assert ens.source.is_dirty() == False + + # Create an additional result table for the ensemble to track. + cnts = ens.source.groupby([ens._id_col, ens._band_col])[ens._time_col].aggregate("count") + res = ( + cnts.to_frame() + .reset_index() + .categorize(columns=[ens._band_col]) + .pivot_table(values=ens._time_col, index=ens._id_col, columns=ens._band_col, aggfunc="sum") + ) + + # Convert the resulting dataframe into an EnsembleFrame and update the Ensemble + result_frame = EnsembleFrame.from_dask_dataframe(res, ensemble=ens, label="result") + result_frame.update_ensemble() + assert ens.select_frame("result") is result_frame + + # Test update_ensemble when a frame is unlinked to its parent ensemble. 
+ result_frame.ensemble = None + assert result_frame.update_ensemble() is None + + def test_available_datasets(dask_client): """ Test that the ensemble is able to successfully read in the list of available TAPE datasets @@ -139,6 +213,101 @@ def test_available_datasets(dask_client): assert len(datasets) > 0 # Find at least one +@pytest.mark.parametrize( + "data_fixture", + [ + "parquet_files_and_ensemble_without_client", + ], +) +def test_frame_tracking(data_fixture, request): + """ + Tests a workflow of adding and removing the frames tracked by the Ensemble. + """ + ens, source_file, object_file, colmap = request.getfixturevalue(data_fixture) + + # Since we load the ensemble from a parquet, we expect the Source and Object frames to be populated. + assert len(ens.frames) == 2 + assert isinstance(ens.select_frame("source"), SourceFrame) + assert isinstance(ens.select_frame("object"), ObjectFrame) + + # Check that we can select source and object frames + assert len(ens.frames) == 2 + assert ens.select_frame("source") is ens.source + assert isinstance(ens.select_frame("source"), SourceFrame) + assert ens.select_frame("object") is ens.object + assert isinstance(ens.select_frame("object"), ObjectFrame) + + # Construct some result frames for the Ensemble to track. Underlying data is irrelevant for + # this test. + num_points = 100 + data = TapeFrame( + { + "id": [8000 + 2 * i for i in range(num_points)], + "time": [float(i) for i in range(num_points)], + "flux": [0.5 * float(i % 4) for i in range(num_points)], + } + ) + # Labels to give the EnsembleFrames + label1, label2, label3 = "frame1", "frame2", "frame3" + ens_frame1 = EnsembleFrame.from_tapeframe(data, npartitions=1, ensemble=ens, label=label1) + ens_frame2 = EnsembleFrame.from_tapeframe(data, npartitions=1, ensemble=ens, label=label2) + ens_frame3 = EnsembleFrame.from_tapeframe(data, npartitions=1, ensemble=ens, label=label3) + + # Validate that new source and object frames can't be added or updated. + with pytest.raises(ValueError): + ens.add_frame(ens_frame1, "source") + with pytest.raises(ValueError): + ens.add_frame(ens_frame1, "object") + + # Test that we can add and select a new ensemble frame + assert ens.add_frame(ens_frame1, label1).select_frame(label1) is ens_frame1 + assert len(ens.frames) == 3 + + # Validate that we can't add a new frame that uses an exisiting label + with pytest.raises(ValueError): + ens.add_frame(ens_frame2, label1) + + # We add two more frames to track + ens.add_frame(ens_frame2, label2).add_frame(ens_frame3, label3) + assert ens.select_frame(label2) is ens_frame2 + assert ens.select_frame(label3) is ens_frame3 + assert len(ens.frames) == 5 + + # Now we begin dropping frames. First verify that we can't drop object or source. + with pytest.raises(ValueError): + ens.drop_frame("source") + with pytest.raises(ValueError): + ens.drop_frame("object") + + # And verify that we can't call drop with an unknown label. + with pytest.raises(KeyError): + ens.drop_frame("nonsense") + + # Drop an existing frame and that it can no longer be selected. + ens.drop_frame(label3) + assert len(ens.frames) == 4 + with pytest.raises(KeyError): + ens.select_frame(label3) + + # Update the ensemble with the dropped frame, and then select the frame + assert ens.update_frame(ens_frame3).select_frame(label3) is ens_frame3 + assert len(ens.frames) == 5 + + # Update the ensemble with an unlabeled frame, verifying a missing label generates an error. 
+ ens_frame4 = EnsembleFrame.from_tapeframe(data, npartitions=1, ensemble=ens, label=None) + label4 = "frame4" + with pytest.raises(ValueError): + ens.update_frame(ens_frame4) + ens_frame4.label = label4 + assert ens.update_frame(ens_frame4).select_frame(label4) is ens_frame4 + assert len(ens.frames) == 6 + + # Change the label of the 4th ensemble frame to verify update overrides an existing frame + ens_frame4.label = label3 + assert ens.update_frame(ens_frame4).select_frame(label3) is ens_frame4 + assert len(ens.frames) == 6 + + def test_from_rrl_dataset(dask_client): """ Test a basic load and analyze workflow from the S82 RR Lyrae Dataset @@ -310,7 +479,7 @@ def test_read_source_dict(dask_client): def test_insert(parquet_ensemble): - num_partitions = parquet_ensemble._source.npartitions + num_partitions = parquet_ensemble.source.npartitions (old_object, old_source) = parquet_ensemble.compute() old_size = old_source.shape[0] @@ -332,7 +501,7 @@ def test_insert(parquet_ensemble): ) # Check we did not increase the number of partitions. - assert parquet_ensemble._source.npartitions == num_partitions + assert parquet_ensemble.source.npartitions == num_partitions # Check that all the new data points are in there. The order may be different # due to the repartitioning. @@ -361,7 +530,7 @@ def test_insert(parquet_ensemble): ) # Check we *did* increase the number of partitions and the size increased. - assert parquet_ensemble._source.npartitions != num_partitions + assert parquet_ensemble.source.npartitions != num_partitions (new_obj, new_source) = parquet_ensemble.compute() assert new_source.shape[0] == old_size + 10 @@ -390,8 +559,8 @@ def test_insert_paritioned(dask_client): # Save the old data for comparison. old_data = ens.compute("source") - old_div = copy.copy(ens._source.divisions) - old_sizes = [len(ens._source.partitions[i]) for i in range(4)] + old_div = copy.copy(ens.source.divisions) + old_sizes = [len(ens.source.partitions[i]) for i in range(4)] assert old_data.shape[0] == num_points # Test an insertion of 5 observations. @@ -404,12 +573,12 @@ def test_insert_paritioned(dask_client): # Check we did not increase the number of partitions and the points # were placed in the correct partitions. - assert ens._source.npartitions == 4 - assert ens._source.divisions == old_div - assert len(ens._source.partitions[0]) == old_sizes[0] + 3 - assert len(ens._source.partitions[1]) == old_sizes[1] - assert len(ens._source.partitions[2]) == old_sizes[2] + 2 - assert len(ens._source.partitions[3]) == old_sizes[3] + assert ens.source.npartitions == 4 + assert ens.source.divisions == old_div + assert len(ens.source.partitions[0]) == old_sizes[0] + 3 + assert len(ens.source.partitions[1]) == old_sizes[1] + assert len(ens.source.partitions[2]) == old_sizes[2] + 2 + assert len(ens.source.partitions[3]) == old_sizes[3] # Check that all the new data points are in there. The order may be different # due to the repartitioning. @@ -427,12 +596,12 @@ def test_insert_paritioned(dask_client): # Check we did not increase the number of partitions and the points # were placed in the correct partitions. 
- assert ens._source.npartitions == 4 - assert ens._source.divisions == old_div - assert len(ens._source.partitions[0]) == old_sizes[0] + 3 - assert len(ens._source.partitions[1]) == old_sizes[1] + 5 - assert len(ens._source.partitions[2]) == old_sizes[2] + 2 - assert len(ens._source.partitions[3]) == old_sizes[3] + assert ens.source.npartitions == 4 + assert ens.source.divisions == old_div + assert len(ens.source.partitions[0]) == old_sizes[0] + 3 + assert len(ens.source.partitions[1]) == old_sizes[1] + 5 + assert len(ens.source.partitions[2]) == old_sizes[2] + 2 + assert len(ens.source.partitions[3]) == old_sizes[3] def test_core_wrappers(parquet_ensemble): @@ -442,6 +611,9 @@ def test_core_wrappers(parquet_ensemble): # Just test if these execute successfully parquet_ensemble.client_info() parquet_ensemble.info() + parquet_ensemble.frame_info() + with pytest.raises(KeyError): + parquet_ensemble.frame_info(labels=["source", "invalid_label"]) parquet_ensemble.columns() parquet_ensemble.head(n=5) parquet_ensemble.tail(n=5) @@ -520,9 +692,9 @@ def test_persist(dask_client): ens.query("flux <= 1.5", table="source") # Compute the task graph size before and after the persist. - old_graph_size = len(ens._source.dask) + old_graph_size = len(ens.source.dask) ens.persist() - new_graph_size = len(ens._source.dask) + new_graph_size = len(ens.source.dask) assert new_graph_size < old_graph_size @@ -579,92 +751,192 @@ def test_update_column_map(dask_client): "parquet_ensemble_with_divisions", ], ) -def test_sync_tables(data_fixture, request): +@pytest.mark.parametrize("legacy", [True, False]) +def test_sync_tables(data_fixture, request, legacy): """ - Test that _sync_tables works as expected + Test that _sync_tables works as expected, using Ensemble-level APIs + when `legacy` is `True`, and EnsembleFrame APIs when `legacy` is `False`. """ - parquet_ensemble = request.getfixturevalue(data_fixture) - assert len(parquet_ensemble.compute("object")) == 15 - assert len(parquet_ensemble.compute("source")) == 2000 + if legacy: + assert len(parquet_ensemble.compute("object")) == 15 + assert len(parquet_ensemble.compute("source")) == 2000 + else: + assert len(parquet_ensemble.object.compute()) == 15 + assert len(parquet_ensemble.source.compute()) == 2000 parquet_ensemble.prune(50, col_name="nobs_r").prune(50, col_name="nobs_g") - assert parquet_ensemble._object_dirty # Prune should set the object dirty flag + assert parquet_ensemble.object.is_dirty() # Prune should set the object dirty flag - parquet_ensemble.dropna(table="source") - assert parquet_ensemble._source_dirty # Dropna should set the source dirty flag + if legacy: + assert len(parquet_ensemble.compute("object")) == 5 + else: + assert len(parquet_ensemble.object.compute()) == 5 + + if legacy: + parquet_ensemble.dropna(table="source") + else: + parquet_ensemble.source.dropna().update_ensemble() + assert parquet_ensemble.source.is_dirty() # Dropna should set the source dirty flag # Drop a whole object from Source to test that the object is dropped in the object table dropped_obj_id = 88472935274829959 - parquet_ensemble.query(f"{parquet_ensemble._id_col} != {dropped_obj_id}", table="source") + if legacy: + parquet_ensemble.query(f"{parquet_ensemble._id_col} != {dropped_obj_id}", table="source") + else: + filtered_src = parquet_ensemble.source.query(f"{parquet_ensemble._id_col} != 88472935274829959") - # Marks the Object table as dirty without triggering a sync. This is good to test since - # we always sync the object table first. 
- parquet_ensemble.dropna("object") + # Since we have not yet called update_ensemble, the compute call should not trigger + # a sync and the source table should remain dirty. + assert parquet_ensemble.source.is_dirty() + filtered_src.compute() + assert parquet_ensemble.source.is_dirty() + + # Update the ensemble to use the filtered source. + filtered_src.update_ensemble() # Verify that the object ID we removed from the source table is present in the object table - assert dropped_obj_id in parquet_ensemble._object.index.compute().values + assert dropped_obj_id in parquet_ensemble.object.index.compute().values # Perform an operation which should trigger syncing both tables. parquet_ensemble.compute() # Both tables should have the expected number of rows after a sync - assert len(parquet_ensemble.compute("object")) == 4 - assert len(parquet_ensemble.compute("source")) == 1063 + if legacy: + assert len(parquet_ensemble.compute("object")) == 4 + assert len(parquet_ensemble.compute("source")) == 1063 + else: + assert len(parquet_ensemble.object.compute()) == 4 + assert len(parquet_ensemble.source.compute()) == 1063 # Validate that the filtered object has been removed from both tables. - assert dropped_obj_id not in parquet_ensemble._source.index.compute().values - assert dropped_obj_id not in parquet_ensemble._object.index.compute().values + assert dropped_obj_id not in parquet_ensemble.source.index.compute().values + assert dropped_obj_id not in parquet_ensemble.object.index.compute().values # Dirty flags should be unset after sync - assert not parquet_ensemble._object_dirty - assert not parquet_ensemble._source_dirty + assert not parquet_ensemble.object.is_dirty() + assert not parquet_ensemble.source.is_dirty() # Make sure that divisions are preserved if data_fixture == "parquet_ensemble_with_divisions": - assert parquet_ensemble._source.known_divisions - assert parquet_ensemble._object.known_divisions + assert parquet_ensemble.source.known_divisions + assert parquet_ensemble.object.known_divisions -def test_lazy_sync_tables(parquet_ensemble): +@pytest.mark.parametrize("legacy", [True, False]) +def test_lazy_sync_tables(parquet_ensemble, legacy): """ - Test that _lazy_sync_tables works as expected + Test that _lazy_sync_tables works as expected, using Ensemble-level APIs + when `legacy` is `True`, and EsnembleFrame APIs when `legacy` is `False`. """ - assert len(parquet_ensemble.compute("object")) == 15 - assert len(parquet_ensemble.compute("source")) == 2000 + if legacy: + assert len(parquet_ensemble.compute("object")) == 15 + assert len(parquet_ensemble.compute("source")) == 2000 + else: + assert len(parquet_ensemble.object.compute()) == 15 + assert len(parquet_ensemble.source.compute()) == 2000 # Modify only the object table. parquet_ensemble.prune(50, col_name="nobs_r").prune(50, col_name="nobs_g") - assert parquet_ensemble._object_dirty - assert not parquet_ensemble._source_dirty + assert parquet_ensemble.object.is_dirty() + assert not parquet_ensemble.source.is_dirty() # For a lazy sync on the object table, nothing should change, because # it is already dirty. - parquet_ensemble._lazy_sync_tables(table="object") - assert parquet_ensemble._object_dirty - assert not parquet_ensemble._source_dirty + if legacy: + parquet_ensemble.compute("object") + else: + parquet_ensemble.object.compute() + assert parquet_ensemble.object.is_dirty() + assert not parquet_ensemble.source.is_dirty() # For a lazy sync on the source table, the source table should be updated. 
- parquet_ensemble._lazy_sync_tables(table="source") - assert not parquet_ensemble._object_dirty - assert not parquet_ensemble._source_dirty + if legacy: + parquet_ensemble.compute("source") + else: + parquet_ensemble.source.compute() + assert not parquet_ensemble.object.is_dirty() + assert not parquet_ensemble.source.is_dirty() # Modify only the source table. - parquet_ensemble.dropna(table="source") - assert not parquet_ensemble._object_dirty - assert parquet_ensemble._source_dirty + # Replace the maximum flux value with a NaN so that we will have a row to drop. + max_flux = max(parquet_ensemble.source[parquet_ensemble._flux_col]) + parquet_ensemble.source[parquet_ensemble._flux_col] = parquet_ensemble.source[ + parquet_ensemble._flux_col + ].apply(lambda x: np.nan if x == max_flux else x, meta=pd.Series(dtype=float)) + + assert not parquet_ensemble.object.is_dirty() + assert not parquet_ensemble.source.is_dirty() + + if legacy: + parquet_ensemble.dropna(table="source") + else: + parquet_ensemble.source.dropna().update_ensemble() + assert not parquet_ensemble.object.is_dirty() + assert parquet_ensemble.source.is_dirty() # For a lazy sync on the source table, nothing should change, because # it is already dirty. - parquet_ensemble._lazy_sync_tables(table="source") - assert not parquet_ensemble._object_dirty - assert parquet_ensemble._source_dirty + if legacy: + parquet_ensemble.compute("source") + else: + parquet_ensemble.source.compute() + assert not parquet_ensemble.object.is_dirty() + assert parquet_ensemble.source.is_dirty() # For a lazy sync on the source, the object table should be updated. - parquet_ensemble._lazy_sync_tables(table="object") - assert not parquet_ensemble._object_dirty - assert not parquet_ensemble._source_dirty + if legacy: + parquet_ensemble.compute("object") + else: + parquet_ensemble.object.compute() + assert not parquet_ensemble.object.is_dirty() + assert not parquet_ensemble.source.is_dirty() + + +def test_compute_triggers_syncing(parquet_ensemble): + """ + Tests that tape.EnsembleFrame.compute() only triggers an Ensemble sync if the + frame is the actively tracked source or object table of the Ensemble. + """ + # Test that an object table can trigger a sync that will clean a dirty + # source table. + parquet_ensemble.source.set_dirty(True) + updated_obj = parquet_ensemble.object.dropna() + + # Because we have not yet called update_ensemble(), a sync is not triggered + # and the source table remains dirty. + updated_obj.compute() + assert parquet_ensemble.source.is_dirty() + + # Update the Ensemble so that computing the object table will trigger + # a sync + updated_obj.update_ensemble() + updated_obj.compute() # Now equivalent to Ensemble.object.compute() + assert not parquet_ensemble.source.is_dirty() + + # Test that an source table can trigger a sync that will clean a dirty + # object table. + parquet_ensemble.object.set_dirty(True) + updated_src = parquet_ensemble.source.dropna() + + # Because we have not yet called update_ensemble(), a sync is not triggered + # and the object table remains dirty. + updated_src.compute() + assert parquet_ensemble.object.is_dirty() + + # Update the Ensemble so that computing the object table will trigger + # a sync + updated_src.update_ensemble() + updated_src.compute() # Now equivalent to Ensemble.source.compute() + assert not parquet_ensemble.object.is_dirty() + + # Generate a new Object frame and set the Ensemble to None to + # validate that we return a valid result even for untracked frames + # which cannot be synced. 
+ new_obj_frame = parquet_ensemble.object.dropna() + new_obj_frame.ensemble = None + assert len(new_obj_frame.compute()) > 0 def test_temporary_cols(parquet_ensemble): @@ -673,7 +945,7 @@ def test_temporary_cols(parquet_ensemble): """ ens = parquet_ensemble - ens._object = ens._object.drop(columns=["nobs_r", "nobs_g", "nobs_total"]) + ens.update_frame(ens.object.drop(columns=["nobs_r", "nobs_g", "nobs_total"])) # Make sure temp lists are available but empty assert not len(ens._source_temp) @@ -683,29 +955,29 @@ def test_temporary_cols(parquet_ensemble): # nobs_total should be a temporary column assert "nobs_total" in ens._object_temp - assert "nobs_total" in ens._object.columns + assert "nobs_total" in ens.object.columns ens.assign(nobs2=lambda x: x["nobs_total"] * 2, table="object", temporary=True) # nobs2 should be a temporary column assert "nobs2" in ens._object_temp - assert "nobs2" in ens._object.columns + assert "nobs2" in ens.object.columns # drop NaNs from source, source should be dirty now ens.dropna(how="any", table="source") - assert ens._source_dirty + assert ens.source.is_dirty() # try a sync ens._sync_tables() # nobs_total should be removed from object assert "nobs_total" not in ens._object_temp - assert "nobs_total" not in ens._object.columns + assert "nobs_total" not in ens.object.columns # nobs2 should be removed from object assert "nobs2" not in ens._object_temp - assert "nobs2" not in ens._object.columns + assert "nobs2" not in ens.object.columns # add a source column that we manually set as dirty, don't have a function # that adds temporary source columns at the moment @@ -714,14 +986,77 @@ def test_temporary_cols(parquet_ensemble): # prune object, object should be dirty ens.prune(threshold=10) - assert ens._object_dirty + assert ens.object.is_dirty() # try a sync ens._sync_tables() # f2 should be removed from source assert "f2" not in ens._source_temp - assert "f2" not in ens._source.columns + assert "f2" not in ens.source.columns + + +def test_temporary_cols(parquet_ensemble): + """ + Test that temporary columns are tracked and dropped as expected. + """ + + ens = parquet_ensemble + ens.object = ens.object.drop(columns=["nobs_r", "nobs_g", "nobs_total"]) + + # Make sure temp lists are available but empty + assert not len(ens._source_temp) + assert not len(ens._object_temp) + + ens.calc_nobs(temporary=True) # Generates "nobs_total" + + # nobs_total should be a temporary column + assert "nobs_total" in ens._object_temp + assert "nobs_total" in ens.object.columns + + ens.assign(nobs2=lambda x: x["nobs_total"] * 2, table="object", temporary=True) + + # nobs2 should be a temporary column + assert "nobs2" in ens._object_temp + assert "nobs2" in ens.object.columns + + # Replace the maximum flux value with a NaN so that we will have a row to drop. 
+ max_flux = max(parquet_ensemble.source[parquet_ensemble._flux_col]) + parquet_ensemble.source[parquet_ensemble._flux_col] = parquet_ensemble.source[ + parquet_ensemble._flux_col + ].apply(lambda x: np.nan if x == max_flux else x, meta=pd.Series(dtype=float)) + + # drop NaNs from source, source should be dirty now + ens.dropna(how="any", table="source") + + assert ens.source.is_dirty() + + # try a sync + ens._sync_tables() + + # nobs_total should be removed from object + assert "nobs_total" not in ens._object_temp + assert "nobs_total" not in ens.object.columns + + # nobs2 should be removed from object + assert "nobs2" not in ens._object_temp + assert "nobs2" not in ens.object.columns + + # add a source column that we manually set as dirty, don't have a function + # that adds temporary source columns at the moment + ens.assign(f2=lambda x: x[ens._flux_col] ** 2, table="source", temporary=True) + + # prune object, object should be dirty + ens.prune(threshold=10) + + assert ens.object.is_dirty() + + # try a sync + ens._sync_tables() + + # f2 should be removed from source + assert "f2" not in ens._source_temp + assert "f2" not in ens.source.columns @pytest.mark.parametrize( @@ -731,7 +1066,10 @@ def test_temporary_cols(parquet_ensemble): "parquet_ensemble_with_divisions", ], ) -def test_dropna(data_fixture, request): +@pytest.mark.parametrize("legacy", [True, False]) +def test_dropna(data_fixture, request, legacy): + """Tests dropna, using Ensemble.dropna when `legacy` is `True`, and + EnsembleFrame.dropna when `legacy` is `False`.""" parquet_ensemble = request.getfixturevalue(data_fixture) # Try passing in an unrecognized 'table' parameter and verify an exception is thrown @@ -739,13 +1077,15 @@ def test_dropna(data_fixture, request): parquet_ensemble.dropna(table="banana") # First test dropping na from the 'source' table - # - source_pdf = parquet_ensemble._source.compute() + source_pdf = parquet_ensemble.source.compute() source_length = len(source_pdf.index) # Try dropping NaNs from source and confirm nothing is dropped (there are no NaNs). - parquet_ensemble.dropna(table="source") - assert len(parquet_ensemble._source) == source_length + if legacy: + parquet_ensemble.dropna(table="source") + else: + parquet_ensemble.source.dropna().update_ensemble() + assert len(parquet_ensemble.source) == source_length # Get a valid ID to use and count its occurrences. valid_source_id = source_pdf.index.values[1] @@ -753,45 +1093,61 @@ def test_dropna(data_fixture, request): # Set the psFlux values for one source to NaN so we can drop it. # We do this on the instantiated source (pdf) and convert it back into a - # Dask DataFrame. + # SourceFrame. source_pdf.loc[valid_source_id, parquet_ensemble._flux_col] = pd.NA - parquet_ensemble._source = dd.from_pandas(source_pdf, npartitions=1) + parquet_ensemble.update_frame( + SourceFrame.from_tapeframe(TapeSourceFrame(source_pdf), label="source", npartitions=1) + ) # Try dropping NaNs from source and confirm that we did. 
- parquet_ensemble.dropna(table="source") - assert len(parquet_ensemble._source.compute().index) == source_length - occurrences_source + if legacy: + parquet_ensemble.dropna(table="source") + else: + parquet_ensemble.source.dropna().update_ensemble() + assert len(parquet_ensemble.source.compute().index) == source_length - occurrences_source if data_fixture == "parquet_ensemble_with_divisions": # divisions should be preserved - assert parquet_ensemble._source.known_divisions + assert parquet_ensemble.source.known_divisions # Now test dropping na from 'object' table + # Sync the tables + parquet_ensemble._sync_tables() - object_pdf = parquet_ensemble._object.compute() + # Sync (triggered by the compute) the table and check that the number of objects decreased. + object_pdf = parquet_ensemble.object.compute() object_length = len(object_pdf.index) # Try dropping NaNs from object and confirm nothing is dropped (there are no NaNs). - parquet_ensemble.dropna(table="object") - assert len(parquet_ensemble._object.compute().index) == object_length + if legacy: + parquet_ensemble.dropna(table="object") + else: + parquet_ensemble.object.dropna().update_ensemble() + assert len(parquet_ensemble.object.compute().index) == object_length # select an id from the object table valid_object_id = object_pdf.index.values[1] # Set the nobs_g values for one object to NaN so we can drop it. # We do this on the instantiated object (pdf) and convert it back into a - # Dask DataFrame. - object_pdf.loc[valid_object_id, parquet_ensemble._object.columns[0]] = pd.NA - parquet_ensemble._object = dd.from_pandas(object_pdf, npartitions=1) + # ObjectFrame. + object_pdf.loc[valid_object_id, parquet_ensemble.object.columns[0]] = pd.NA + parquet_ensemble.update_frame( + ObjectFrame.from_tapeframe(TapeObjectFrame(object_pdf), label="object", npartitions=1) + ) # Try dropping NaNs from object and confirm that we dropped a row - parquet_ensemble.dropna(table="object") - assert len(parquet_ensemble._object.compute().index) == object_length - 1 + if legacy: + parquet_ensemble.dropna(table="object") + else: + parquet_ensemble.object.dropna().update_ensemble() + assert len(parquet_ensemble.object.compute().index) == object_length - 1 if data_fixture == "parquet_ensemble_with_divisions": # divisions should be preserved - assert parquet_ensemble._object.known_divisions + assert parquet_ensemble.object.known_divisions - new_objects_pdf = parquet_ensemble._object.compute() + new_objects_pdf = parquet_ensemble.object.compute() assert len(new_objects_pdf.index) == len(object_pdf.index) - 1 # Assert the filtered ID is no longer in the objects. 
@@ -803,29 +1159,37 @@ def test_dropna(data_fixture, request): assert new_objects_pdf.loc[i, c] == object_pdf.loc[i, c] -def test_keep_zeros(parquet_ensemble): - """Test that we can sync the tables and keep objects with zero sources.""" +@pytest.mark.parametrize("legacy", [True, False]) +def test_keep_zeros(parquet_ensemble, legacy): + """Test that we can sync the tables and keep objects with zero sources, using + Ensemble.dropna when `legacy` is `True`, and EnsembleFrame.dropna when `legacy` is `False`.""" parquet_ensemble.keep_empty_objects = True - prev_npartitions = parquet_ensemble._object.npartitions - old_objects_pdf = parquet_ensemble._object.compute() - pdf = parquet_ensemble._source.compute() + prev_npartitions = parquet_ensemble.object.npartitions + old_objects_pdf = parquet_ensemble.object.compute() + pdf = parquet_ensemble.source.compute() # Set the psFlux values for one object to NaN so we can drop it. # We do this on the instantiated object (pdf) and convert it back into a # Dask DataFrame. valid_id = pdf.index.values[1] pdf.loc[valid_id, parquet_ensemble._flux_col] = pd.NA - parquet_ensemble._source = dd.from_pandas(pdf, npartitions=1) + parquet_ensemble.source = dd.from_pandas(pdf, npartitions=1) + parquet_ensemble.update_frame( + SourceFrame.from_tapeframe(TapeSourceFrame(pdf), npartitions=1, label="source") + ) # Sync the table and check that the number of objects decreased. - parquet_ensemble.dropna(table="source") + if legacy: + parquet_ensemble.dropna("source") + else: + parquet_ensemble.source.dropna().update_ensemble() parquet_ensemble._sync_tables() # Check that objects are preserved after sync - new_objects_pdf = parquet_ensemble._object.compute() + new_objects_pdf = parquet_ensemble.object.compute() assert len(new_objects_pdf.index) == len(old_objects_pdf.index) - assert parquet_ensemble._object.npartitions == prev_npartitions + assert parquet_ensemble.object.npartitions == prev_npartitions @pytest.mark.parametrize( @@ -842,29 +1206,29 @@ def test_calc_nobs(data_fixture, request, by_band, multi_partition): ens = request.getfixturevalue(data_fixture) if multi_partition: - ens._source = ens._source.repartition(3) + ens.source = ens.source.repartition(3) # Drop the existing nobs columns - ens._object = ens._object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1) + ens.object = ens.object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1) # Calculate nobs ens.calc_nobs(by_band) # Check that things turned out as we expect - lc = ens._object.loc[88472935274829959].compute() + lc = ens.object.loc[88472935274829959].compute() if by_band: - assert np.all([col in ens._object.columns for col in ["nobs_g", "nobs_r"]]) + assert np.all([col in ens.object.columns for col in ["nobs_g", "nobs_r"]]) assert lc["nobs_g"].values[0] == 98 assert lc["nobs_r"].values[0] == 401 - assert "nobs_total" in ens._object.columns + assert "nobs_total" in ens.object.columns assert lc["nobs_total"].values[0] == 499 # Make sure that if divisions were set previously, they are preserved if data_fixture == "parquet_ensemble_with_divisions": - assert ens._object.known_divisions - assert ens._source.known_divisions + assert ens.object.known_divisions + assert ens.source.known_divisions @pytest.mark.parametrize( @@ -887,19 +1251,19 @@ def test_prune(data_fixture, request, generate_nobs): # Generate the nobs cols from within prune if generate_nobs: # Drop the existing nobs columns - parquet_ensemble._object = parquet_ensemble._object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1) + parquet_ensemble.object 
= parquet_ensemble.object.drop(["nobs_g", "nobs_r", "nobs_total"], axis=1) parquet_ensemble.prune(threshold) # Use an existing column else: parquet_ensemble.prune(threshold, col_name="nobs_total") - assert not np.any(parquet_ensemble._object["nobs_total"].values < threshold) + assert not np.any(parquet_ensemble.object["nobs_total"].values < threshold) # Make sure that if divisions were set previously, they are preserved if data_fixture == "parquet_ensemble_with_divisions": - assert parquet_ensemble._source.known_divisions - assert parquet_ensemble._object.known_divisions + assert parquet_ensemble.source.known_divisions + assert parquet_ensemble.object.known_divisions def test_query(dask_client): @@ -941,7 +1305,7 @@ def test_filter_from_series(dask_client): ens.from_source_dict(rows, column_mapper=cmap, npartitions=2) # Filter the data set to low flux sources only. - keep_series = ens._source[ens._time_col] >= 250.0 + keep_series = ens.source[ens._time_col] >= 250.0 ens.filter_from_series(keep_series, table="source") # Check that all of the filtered rows are value. @@ -966,25 +1330,28 @@ def test_select(dask_client): } cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band") ens.from_source_dict(rows, column_mapper=cmap, npartitions=2) - assert len(ens._source.columns) == 5 - assert "time" in ens._source.columns - assert "flux" in ens._source.columns - assert "band" in ens._source.columns - assert "count" in ens._source.columns - assert "something_else" in ens._source.columns + assert len(ens.source.columns) == 5 + assert "time" in ens.source.columns + assert "flux" in ens.source.columns + assert "band" in ens.source.columns + assert "count" in ens.source.columns + assert "something_else" in ens.source.columns # Select on just time and flux ens.select(["time", "flux"], table="source") - assert len(ens._source.columns) == 2 - assert "time" in ens._source.columns - assert "flux" in ens._source.columns - assert "band" not in ens._source.columns - assert "count" not in ens._source.columns - assert "something_else" not in ens._source.columns + assert len(ens.source.columns) == 2 + assert "time" in ens.source.columns + assert "flux" in ens.source.columns + assert "band" not in ens.source.columns + assert "count" not in ens.source.columns + assert "something_else" not in ens.source.columns -def test_assign(dask_client): +@pytest.mark.parametrize("legacy", [True, False]) +def test_assign(dask_client, legacy): + """Tests assign for column-manipulation, using Ensemble.assign when `legacy` is `True`, + and EnsembleFrame.assign when `legacy` is `False`.""" ens = Ensemble(client=dask_client) num_points = 1000 @@ -998,29 +1365,35 @@ def test_assign(dask_client): } cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band") ens.from_source_dict(rows, column_mapper=cmap, npartitions=1) - assert len(ens._source.columns) == 4 - assert "lower_bnd" not in ens._source.columns + assert len(ens.source.columns) == 4 + assert "lower_bnd" not in ens.source.columns # Insert a new column for the "lower bound" computation. 
- ens.assign(table="source", lower_bnd=lambda x: x["flux"] - 2.0 * x["err"]) - assert len(ens._source.columns) == 5 - assert "lower_bnd" in ens._source.columns + if legacy: + ens.assign(table="source", lower_bnd=lambda x: x["flux"] - 2.0 * x["err"]) + else: + ens.source.assign(lower_bnd=lambda x: x["flux"] - 2.0 * x["err"]).update_ensemble() + assert len(ens.source.columns) == 5 + assert "lower_bnd" in ens.source.columns # Check the values in the new column. - new_source = ens.compute(table="source") + new_source = ens.source.compute() if not legacy else ens.compute(table="source") assert new_source.shape[0] == 1000 for i in range(1000): expected = new_source.iloc[i]["flux"] - 2.0 * new_source.iloc[i]["err"] assert new_source.iloc[i]["lower_bnd"] == expected # Create a series directly from the table. - res_col = ens._source["band"] + "2" - ens.assign(table="source", band2=res_col) - assert len(ens._source.columns) == 6 - assert "band2" in ens._source.columns + res_col = ens.source["band"] + "2" + if legacy: + ens.assign(table="source", band2=res_col) + else: + ens.source.assign(band2=res_col).update_ensemble() + assert len(ens.source.columns) == 6 + assert "band2" in ens.source.columns # Check the values in the new column. - new_source = ens.compute(table="source") + new_source = ens.source.compute() if not legacy else ens.compute(table="source") for i in range(1000): assert new_source.iloc[i]["band2"] == new_source.iloc[i]["band"] + "2" @@ -1048,7 +1421,7 @@ def test_coalesce(dask_client, drop_inputs): ens.coalesce(["flux1", "flux2", "flux3"], "flux", table="source", drop_inputs=drop_inputs) # Coalesce should return this exact flux array - assert list(ens._source["flux"].values.compute()) == [5.0, 3.0, 4.0, 10.0, 7.0] + assert list(ens.source["flux"].values.compute()) == [5.0, 3.0, 4.0, 10.0, 7.0] if drop_inputs: # The column mapping should be updated @@ -1056,7 +1429,7 @@ def test_coalesce(dask_client, drop_inputs): # The columns to drop should be dropped for col in ["flux1", "flux2", "flux3"]: - assert col not in ens._source.columns + assert col not in ens.source.columns # Test for the drop warning with pytest.warns(UserWarning): @@ -1065,7 +1438,7 @@ def test_coalesce(dask_client, drop_inputs): else: # The input columns should still be present for col in ["flux1", "flux2", "flux3"]: - assert col in ens._source.columns + assert col in ens.source.columns @pytest.mark.parametrize("zero_point", [("zp_mag", "zp_flux"), (25.0, 10**10)]) @@ -1096,19 +1469,19 @@ def test_convert_flux_to_mag(dask_client, zero_point, zp_form, out_col_name): if zp_form == "flux": ens.convert_flux_to_mag(zero_point[1], zp_form, out_col_name) - res_mag = ens._source.compute()[output_column].to_list()[0] + res_mag = ens.source.compute()[output_column].to_list()[0] assert pytest.approx(res_mag, 0.001) == 21.28925 - res_err = ens._source.compute()[output_column + "_err"].to_list()[0] + res_err = ens.source.compute()[output_column + "_err"].to_list()[0] assert pytest.approx(res_err, 0.001) == 0.355979 elif zp_form == "mag" or zp_form == "magnitude": ens.convert_flux_to_mag(zero_point[0], zp_form, out_col_name) - res_mag = ens._source.compute()[output_column].to_list()[0] + res_mag = ens.source.compute()[output_column].to_list()[0] assert pytest.approx(res_mag, 0.001) == 21.28925 - res_err = ens._source.compute()[output_column + "_err"].to_list()[0] + res_err = ens.source.compute()[output_column + "_err"].to_list()[0] assert pytest.approx(res_err, 0.001) == 0.355979 else: @@ -1252,17 +1625,23 @@ def 
test_batch(data_fixture, request, use_map, on): """ Test that ensemble.batch() returns the correct values of the first result """ - parquet_ensemble = request.getfixturevalue(data_fixture) + frame_cnt = len(parquet_ensemble.frames) result = ( parquet_ensemble.prune(10) .dropna(table="source") - .batch(calc_stetson_J, use_map=use_map, on=on, band_to_calc=None, compute=False) + .batch(calc_stetson_J, use_map=use_map, on=on, band_to_calc=None, compute=False, label="stetson_j") ) + # Validate that the ensemble is now tracking a new result frame. + assert len(parquet_ensemble.frames) == frame_cnt + 1 + tracked_result = parquet_ensemble.select_frame("stetson_j") + assert isinstance(tracked_result, EnsembleSeries) + assert result is tracked_result + # Make sure that divisions information is propagated if known - if parquet_ensemble._source.known_divisions and parquet_ensemble._object.known_divisions: + if parquet_ensemble.source.known_divisions and parquet_ensemble.object.known_divisions: assert result.known_divisions result = result.compute() @@ -1278,6 +1657,43 @@ def test_batch(data_fixture, request, use_map, on): assert pytest.approx(result.values[1]["r"], 0.001) == -0.49639028 +def test_batch_labels(parquet_ensemble): + """ + Test that ensemble.batch() generates unique labels for result frames when none are provided. + """ + # Since no label was provided we generate a label of "result_1" + parquet_ensemble.prune(10).batch(np.mean, parquet_ensemble._flux_col) + assert "result_1" in parquet_ensemble.frames + assert len(parquet_ensemble.select_frame("result_1")) > 0 + + # Now give a user-provided custom label. + parquet_ensemble.batch(np.mean, parquet_ensemble._flux_col, label="flux_mean") + assert "flux_mean" in parquet_ensemble.frames + assert len(parquet_ensemble.select_frame("flux_mean")) > 0 + + # Since this is the second batch call where a label is *not* provided, we generate label "result_2" + parquet_ensemble.batch(np.mean, parquet_ensemble._flux_col) + assert "result_2" in parquet_ensemble.frames + assert len(parquet_ensemble.select_frame("result_2")) > 0 + + # Explicitly provide label "result_3" + parquet_ensemble.batch(np.mean, parquet_ensemble._flux_col, label="result_3") + assert "result_3" in parquet_ensemble.frames + assert len(parquet_ensemble.select_frame("result_3")) > 0 + + # Validate that the next generated label is "result_4" since "result_3" is taken. + parquet_ensemble.batch(np.mean, parquet_ensemble._flux_col) + assert "result_4" in parquet_ensemble.frames + assert len(parquet_ensemble.select_frame("result_4")) > 0 + + frame_cnt = len(parquet_ensemble.frames) + + # Validate that when the label is None, the result frame isn't tracked by the Ensemble. + result = parquet_ensemble.batch(np.mean, parquet_ensemble._flux_col, label=None) + assert frame_cnt == len(parquet_ensemble.frames) + assert len(result) > 0 + + def test_batch_with_custom_func(parquet_ensemble): """ Test Ensemble.batch with a custom analysis function @@ -1287,6 +1703,67 @@ def test_batch_with_custom_func(parquet_ensemble): assert len(result) > 0 +@pytest.mark.parametrize( + "custom_meta", + [ + ("flux_mean", float), # A tuple representing a series + pd.Series(name="flux_mean_pandas", dtype="float64"), + TapeSeries(name="flux_mean_tape", dtype="float64"), + ], +) +def test_batch_with_custom_series_meta(parquet_ensemble, custom_meta): + """ + Test Ensemble.batch with various styles of output meta for a Series-style result. 
+ """ + num_frames = len(parquet_ensemble.frames) + + parquet_ensemble.prune(10).batch(np.mean, parquet_ensemble._flux_col, meta=custom_meta, label="flux_mean") + + assert len(parquet_ensemble.frames) == num_frames + 1 + assert len(parquet_ensemble.select_frame("flux_mean")) > 0 + assert isinstance(parquet_ensemble.select_frame("flux_mean"), EnsembleSeries) + + +@pytest.mark.parametrize( + "custom_meta", + [ + {"lc_id": int, "band": str, "dt": float, "sf2": float, "1_sigma": float}, + [("lc_id", int), ("band", str), ("dt", float), ("sf2", float), ("1_sigma", float)], + pd.DataFrame( + { + "lc_id": pd.Series([], dtype=int), + "band": pd.Series([], dtype=str), + "dt": pd.Series([], dtype=float), + "sf2": pd.Series([], dtype=float), + "1_sigma": pd.Series([], dtype=float), + } + ), + TapeFrame( + { + "lc_id": pd.Series([], dtype=int), + "band": pd.Series([], dtype=str), + "dt": pd.Series([], dtype=float), + "sf2": pd.Series([], dtype=float), + "1_sigma": pd.Series([], dtype=float), + } + ), + ], +) +def test_batch_with_custom_frame_meta(parquet_ensemble, custom_meta): + """ + Test Ensemble.batch with various sytles of output meta for a DataFrame-style result. + """ + num_frames = len(parquet_ensemble.frames) + + parquet_ensemble.prune(10).batch( + calc_sf2, parquet_ensemble._flux_col, meta=custom_meta, label="sf2_result" + ) + + assert len(parquet_ensemble.frames) == num_frames + 1 + assert len(parquet_ensemble.select_frame("sf2_result")) > 0 + assert isinstance(parquet_ensemble.select_frame("sf2_result"), EnsembleFrame) + + def test_to_timeseries(parquet_ensemble): """ Test that ensemble.to_timeseries() runs and assigns the correct metadata @@ -1346,7 +1823,7 @@ def test_sf2(data_fixture, request, method, combine, sthresh, use_map=False): res_sf2 = parquet_ensemble.sf2(argument_container=arg_container, use_map=use_map) res_batch = parquet_ensemble.batch(calc_sf2, use_map=use_map, argument_container=arg_container) - if parquet_ensemble._source.known_divisions and parquet_ensemble._object.known_divisions: + if parquet_ensemble.source.known_divisions and parquet_ensemble.object.known_divisions: if not combine: assert res_sf2.known_divisions diff --git a/tests/tape_tests/test_ensemble_frame.py b/tests/tape_tests/test_ensemble_frame.py new file mode 100644 index 00000000..1937f457 --- /dev/null +++ b/tests/tape_tests/test_ensemble_frame.py @@ -0,0 +1,360 @@ +""" Test EnsembleFrame (inherited from Dask.DataFrame) creation and manipulations. """ +import numpy as np +import pandas as pd +from tape import ( + ColumnMapper, + EnsembleFrame, + ObjectFrame, + SourceFrame, + TapeObjectFrame, + TapeSourceFrame, + TapeFrame, +) + +import pytest + +TEST_LABEL = "test_frame" +SOURCE_LABEL = "source" +OBJECT_LABEL = "object" + + +# pylint: disable=protected-access +@pytest.mark.parametrize( + "data_fixture", + [ + "ensemble_from_source_dict", + ], +) +def test_from_dict(data_fixture, request): + """ + Test creating an EnsembleFrame from a dictionary and verify that dask lazy evaluation was appropriately inherited. + """ + _, data = request.getfixturevalue(data_fixture) + ens_frame = EnsembleFrame.from_dict(data, npartitions=1) + + assert isinstance(ens_frame, EnsembleFrame) + assert isinstance(ens_frame._meta, TapeFrame) + + # The calculation for finding the max flux from the data. Note that the + # inherited dask compute method must be called to obtain the result. 
+ assert ens_frame.flux.max().compute() == 80.6 + + +@pytest.mark.parametrize( + "data_fixture", + [ + "ensemble_from_source_dict", + ], +) +def test_from_pandas(data_fixture, request): + """ + Test creating an EnsembleFrame from a Pandas dataframe and verify that dask lazy evaluation was appropriately inherited. + """ + ens, data = request.getfixturevalue(data_fixture) + frame = TapeFrame(data) + ens_frame = EnsembleFrame.from_tapeframe(frame, label=TEST_LABEL, ensemble=ens, npartitions=1) + + assert isinstance(ens_frame, EnsembleFrame) + assert isinstance(ens_frame._meta, TapeFrame) + assert ens_frame.label == TEST_LABEL + assert ens_frame.ensemble is ens + + # The calculation for finding the max flux from the data. Note that the + # inherited dask compute method must be called to obtain the result. + assert ens_frame.flux.max().compute() == 80.6 + + +def test_from_parquet(): + """ + Test creating an EnsembleFrame from a parquet file. + """ + frame = EnsembleFrame.from_parquet( + "tests/tape_tests/data/source/test_source.parquet", label=TEST_LABEL, ensemble=None + ) + assert isinstance(frame, EnsembleFrame) + assert isinstance(frame._meta, TapeFrame) + assert frame.label == TEST_LABEL + assert frame.ensemble is None + + # Validate that we loaded a non-empty frame. + assert len(frame) > 0 + + +@pytest.mark.parametrize( + "data_fixture", + [ + "ensemble_from_source_dict", + ], +) +def test_ensemble_frame_propagation(data_fixture, request): + """ + Test ensuring that slices and copies of an EnsembleFrame are still the same class. + """ + ens, data = request.getfixturevalue(data_fixture) + ens_frame = EnsembleFrame.from_dict(data, npartitions=1) + # Set a label and ensemble for the frame and verify that copies/transformations retain them. + ens_frame.label = TEST_LABEL + ens_frame.ensemble = ens + assert not ens_frame.is_dirty() + ens_frame.set_dirty(True) + + # Create a copy of an EnsembleFrame and verify that it's still a proper + # EnsembleFrame with appropriate metadata propagated. + copied_frame = ens_frame.copy() + assert isinstance(copied_frame, EnsembleFrame) + assert isinstance(copied_frame._meta, TapeFrame) + assert copied_frame.label == TEST_LABEL + assert copied_frame.ensemble == ens + assert copied_frame.is_dirty() + + # Verify that the above is also true by calling copy via map_partitions + mapped_frame = ens_frame.copy().map_partitions(lambda x: x.copy()) + assert isinstance(mapped_frame, EnsembleFrame) + assert isinstance(mapped_frame._meta, TapeFrame) + assert mapped_frame.label == TEST_LABEL + assert mapped_frame.ensemble == ens + assert mapped_frame.is_dirty() + + # Test that a filtered EnsembleFrame is still an EnsembleFrame. + filtered_frame = ens_frame[["id", "time"]] + assert isinstance(filtered_frame, EnsembleFrame) + assert isinstance(filtered_frame._meta, TapeFrame) + assert filtered_frame.label == TEST_LABEL + assert filtered_frame.ensemble == ens + assert filtered_frame.is_dirty() + + # Test that the output of an EnsembleFrame query is still an EnsembleFrame + queried_rows = ens_frame.query("flux > 3.0") + assert isinstance(queried_rows, EnsembleFrame) + assert isinstance(queried_rows._meta, TapeFrame) + assert queried_rows.label == TEST_LABEL + assert queried_rows.ensemble == ens + assert queried_rows.is_dirty() + + # Test merging two subsets of the dataframe, dropping some columns, and persisting the result. 
+ merged_frame = ens_frame.copy()[["id", "time", "error"]].merge( + ens_frame.copy()[["id", "time", "flux"]], on=["id"], suffixes=(None, "_drop_me") + ) + cols_to_drop = [col for col in merged_frame.columns if "_drop_me" in col] + merged_frame = merged_frame.drop(cols_to_drop, axis=1).persist() + assert isinstance(merged_frame, EnsembleFrame) + assert merged_frame.label == TEST_LABEL + assert merged_frame.ensemble == ens + assert merged_frame.is_dirty() + + # Test that head returns a subset of the underlying TapeFrame. + h = ens_frame.head(5) + assert isinstance(h, TapeFrame) + assert len(h) == 5 + + # Test that the inherited dask.DataFrame.compute method returns + # the underlying TapeFrame. + assert isinstance(ens_frame.compute(), TapeFrame) + assert len(ens_frame) == len(ens_frame.compute()) + + # Set an index and then group by that index. + ens_frame = ens_frame.set_index("id", drop=True) + assert ens_frame.label == TEST_LABEL + assert ens_frame.ensemble == ens + group_result = ens_frame.groupby(["id"]).count() + assert len(group_result) > 0 + assert isinstance(group_result, EnsembleFrame) + assert isinstance(group_result._meta, TapeFrame) + + +@pytest.mark.parametrize( + "data_fixture", + [ + "ensemble_from_source_dict", + ], +) +@pytest.mark.parametrize("err_col", [None, "error"]) +@pytest.mark.parametrize("zp_form", ["flux", "mag", "magnitude", "lincc"]) +@pytest.mark.parametrize("out_col_name", [None, "mag"]) +def test_convert_flux_to_mag(data_fixture, request, err_col, zp_form, out_col_name): + ens, data = request.getfixturevalue(data_fixture) + + if out_col_name is None: + output_column = "flux_mag" + else: + output_column = out_col_name + + ens_frame = EnsembleFrame.from_dict(data, npartitions=1) + ens_frame.label = TEST_LABEL + ens_frame.ensemble = ens + + if zp_form == "flux": + ens_frame = ens_frame.convert_flux_to_mag("flux", "zp_flux", err_col, zp_form, out_col_name) + + res_mag = ens_frame.compute()[output_column].to_list()[0] + assert pytest.approx(res_mag, 0.001) == 21.28925 + + if err_col is not None: + res_err = ens_frame.compute()[output_column + "_err"].to_list()[0] + assert pytest.approx(res_err, 0.001) == 0.355979 + else: + assert output_column + "_err" not in ens_frame.columns + + elif zp_form == "mag" or zp_form == "magnitude": + ens_frame = ens_frame.convert_flux_to_mag("flux", "zp_mag", err_col, zp_form, out_col_name) + + res_mag = ens_frame.compute()[output_column].to_list()[0] + assert pytest.approx(res_mag, 0.001) == 21.28925 + + if err_col is not None: + res_err = ens_frame.compute()[output_column + "_err"].to_list()[0] + assert pytest.approx(res_err, 0.001) == 0.355979 + else: + assert output_column + "_err" not in ens_frame.columns + + else: + with pytest.raises(ValueError): + ens_frame = ens_frame.convert_flux_to_mag("flux", "zp_mag", err_col, zp_form, "mag") + + # Verify that if we converted to a new frame, it's still an EnsembleFrame. + assert isinstance(ens_frame, EnsembleFrame) + assert ens_frame.label == TEST_LABEL + assert ens_frame.ensemble is ens + + +@pytest.mark.parametrize( + "data_fixture", + [ + "parquet_files_and_ensemble_without_client", + ], +) +def test_object_and_source_frame_propagation(data_fixture, request): + """ + Test that SourceFrame and ObjectFrame metadata and class type is correctly preserved across + typical Pandas operations. 
+ """ + ens, source_file, object_file, _ = request.getfixturevalue(data_fixture) + + assert ens is not None + + # Create a SourceFrame from a parquet file + source_frame = SourceFrame.from_parquet(source_file, ensemble=ens) + + assert isinstance(source_frame, EnsembleFrame) + assert isinstance(source_frame, SourceFrame) + assert isinstance(source_frame._meta, TapeSourceFrame) + + assert source_frame.ensemble is not None + assert source_frame.ensemble == ens + assert source_frame.ensemble is ens + + assert not source_frame.is_dirty() + source_frame.set_dirty(True) + + # Perform a series of operations on the SourceFrame and then verify the result is still a + # proper SourceFrame with appropriate metadata propagated. + source_frame["psFlux"].mean().compute() + result_source_frame = source_frame.copy()[["psFlux", "psFluxErr"]] + result_source_frame = result_source_frame.map_partitions(lambda x: x.copy()) + assert isinstance(result_source_frame, SourceFrame) + assert isinstance(result_source_frame._meta, TapeSourceFrame) + assert len(result_source_frame) > 0 + assert result_source_frame.label == SOURCE_LABEL + assert result_source_frame.ensemble is not None + assert result_source_frame.ensemble is ens + assert result_source_frame.is_dirty() + + # Mark the frame clean to verify that we propagate that state as well + result_source_frame.set_dirty(False) + + # Set an index and then group by that index. + result_source_frame = result_source_frame.set_index("psFlux", drop=True) + assert result_source_frame.label == SOURCE_LABEL + assert result_source_frame.ensemble == ens + assert not result_source_frame.is_dirty() # frame is still clean. + group_result = result_source_frame.groupby(["psFlux"]).count() + assert len(group_result) > 0 + assert isinstance(group_result, SourceFrame) + assert isinstance(group_result._meta, TapeSourceFrame) + + # Create an ObjectFrame from a parquet file + object_frame = ObjectFrame.from_parquet( + object_file, + ensemble=ens, + index="ps1_objid", + ) + + assert isinstance(object_frame, EnsembleFrame) + assert isinstance(object_frame, ObjectFrame) + assert isinstance(object_frame._meta, TapeObjectFrame) + + assert not object_frame.is_dirty() + object_frame.set_dirty(True) + # Verify that the source frame stays clean when object frame is marked dirty. + assert not result_source_frame.is_dirty() + + # Perform a series of operations on the ObjectFrame and then verify the result is still a + # proper ObjectFrame with appropriate metadata propagated. + result_object_frame = object_frame.copy()[["nobs_g", "nobs_total"]] + result_object_frame = result_object_frame.map_partitions(lambda x: x.copy()) + assert isinstance(result_object_frame, ObjectFrame) + assert isinstance(result_object_frame._meta, TapeObjectFrame) + assert result_object_frame.label == OBJECT_LABEL + assert result_object_frame.ensemble is ens + assert result_object_frame.is_dirty() + + # Mark the frame clean to verify that we propagate that state as well + result_object_frame.set_dirty(False) + + # Set an index and then group by that index. 
+ result_object_frame = result_object_frame.set_index("nobs_g", drop=True) + assert result_object_frame.label == OBJECT_LABEL + assert result_object_frame.ensemble == ens + assert not result_object_frame.is_dirty() # frame is still clean + group_result = result_object_frame.groupby(["nobs_g"]).count() + assert len(group_result) > 0 + assert isinstance(group_result, ObjectFrame) + assert isinstance(group_result._meta, TapeObjectFrame) + + # Test merging source and object frames, dropping some columns, and persisting the result. + merged_frame = source_frame.copy().merge( + object_frame.copy(), on=[ens._id_col], suffixes=(None, "_drop_me") + ) + cols_to_drop = [col for col in merged_frame.columns if "_drop_me" in col] + merged_frame = merged_frame.drop(cols_to_drop, axis=1).persist() + assert isinstance(merged_frame, SourceFrame) + assert merged_frame.label == SOURCE_LABEL + assert merged_frame.ensemble == ens + assert merged_frame.is_dirty() + + +def test_object_and_source_joins(parquet_ensemble): + """ + Test that SourceFrame and ObjectFrame metadata and class type are correctly propagated across + joins. + """ + # Get Source and object frames to test joins on. + source_frame, object_frame = parquet_ensemble.source.copy(), parquet_ensemble.object.copy() + + # Verify their metadata was preserved in the copy() + assert source_frame.label == SOURCE_LABEL + assert source_frame.ensemble is parquet_ensemble + assert object_frame.label == OBJECT_LABEL + assert object_frame.ensemble is parquet_ensemble + + # Join a SourceFrame (left) with an ObjectFrame (right) + # Validate that metadata is preserved and the outputted object is a SourceFrame + joined_source = source_frame.join(object_frame, how="left") + assert joined_source.label is SOURCE_LABEL + assert type(joined_source) is SourceFrame + assert joined_source.ensemble is parquet_ensemble + + # Now the same form of join (in terms of left/right) but produce an ObjectFrame. This is + # because frame1.join(frame2) will yield frame1's type regardless of left vs right. + assert type(object_frame.join(source_frame, how="right")) is ObjectFrame + + # Join an ObjectFrame (left) with a SourceFrame (right) + # Validate that metadata is preserved and the outputted object is an ObjectFrame + joined_object = object_frame.join(source_frame, how="left") + assert joined_object.label is OBJECT_LABEL + assert type(joined_object) is ObjectFrame + assert joined_object.ensemble is parquet_ensemble + + # Now the same form of join (in terms of left/right) but produce a SourceFrame. This is + # because frame1.join(frame2) will yield frame1's type regardless of left vs right. + assert type(source_frame.join(object_frame, how="right")) is SourceFrame
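As an illustration of the workflow these tests exercise, here is a minimal usage sketch (not taken from the patch above); it assumes the TAPE API referenced in the tests (Ensemble, ColumnMapper, from_source_dict, EnsembleFrame.update_ensemble, Ensemble.batch, Ensemble.select_frame), and the example data and column names are hypothetical.

import numpy as np
from tape import ColumnMapper, Ensemble

# Hypothetical in-memory data, loosely mirroring the test fixtures above.
ens = Ensemble(client=False)  # assumed: run without a Dask client; a client could be passed instead
rows = {
    "id": [8001, 8001, 8002, 8002, 8002],
    "time": [10.1, 10.2, 10.1, 11.1, 11.2],
    "flux": [1.0, 2.0, 5.0, 3.0, 4.0],
    "err": [0.1, 0.1, 0.1, 0.1, 0.1],
    "band": ["g", "g", "r", "r", "r"],
}
cmap = ColumnMapper(id_col="id", time_col="time", flux_col="flux", err_col="err", band_col="band")
ens.from_source_dict(rows, column_mapper=cmap, npartitions=1)

# Filtering returns a new frame that the ensemble does not track until
# update_ensemble() is called; computing afterwards triggers the object/source sync.
filtered_src = ens.source.query("flux > 1.5")
filtered_src.update_ensemble()
obj, src = ens.compute()

# Batch results are tracked under a label and can be retrieved later.
ens.batch(np.mean, ens._flux_col, label="flux_mean")
print(ens.select_frame("flux_mean").compute())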