diff --git a/src/tdastro/math_nodes/given_sampler.py b/src/tdastro/math_nodes/given_sampler.py index b62dbae..6e8a1e0 100644 --- a/src/tdastro/math_nodes/given_sampler.py +++ b/src/tdastro/math_nodes/given_sampler.py @@ -1,6 +1,11 @@ -"""A sampler used for testing that produces known (given) results.""" +"""Samplers used for testing that produces precomputed results. These +can be used in testing to produce known results or to use data previously +sampled from another method (such as pzflow). +""" import numpy as np +import pandas as pd +from astropy.table import Table from tdastro.base_models import FunctionNode @@ -66,3 +71,85 @@ def compute(self, graph_state, rng_info=None, **kwargs): # Save and return the results. self._save_results(results, graph_state) return results + + +class TableSampler(FunctionNode): + """A FunctionNode that returns values from a table, including + a Pandas DataFrame or AstroPy Table. + + Attributes + ---------- + data : pandas.DataFrame, astropy.table.Table, or dict + The object containing the data to sample. + columns : list of str + The column names for the output columns. + next_ind : int + The next index to sample. + """ + + def __init__(self, data, node_label=None, **kwargs): + self.next_ind = 0 + + if isinstance(data, dict): + self.data = pd.DataFrame(data) + elif isinstance(data, Table): + self.data = data.to_pandas() + elif isinstance(data, pd.DataFrame): + self.data = data.copy() + else: + raise TypeError("Unsupported data type for TableSampler.") + + # Add each of the flow's data columns as an output parameter. + self.columns = [x for x in self.data.columns] + super().__init__(self._non_func, node_label=node_label, outputs=self.columns, **kwargs) + + def _non_func(self): + """This function does nothing. Everything happens in the overloaded compute().""" + pass + + def reset(self): + """Reset the next index to use.""" + self.next_ind = 0 + + def compute(self, graph_state, rng_info=None, **kwargs): + """Return the given values. + + Parameters + ---------- + graph_state : `GraphState` + An object mapping graph parameters to their values. This object is modified + in place as it is sampled. + rng_info : numpy.random._generator.Generator, optional + Unused in this function, but included to provide consistency with other + compute functions. + **kwargs : `dict`, optional + Additional function arguments. + + Returns + ------- + results : any + The result of the computation. This return value is provided so that testing + functions can easily access the results. + """ + # Check that we have enough points left to sample. + end_index = self.next_ind + graph_state.num_samples + if end_index > len(self.data): + raise IndexError() + + # Extract the table for [self.next_ind, end_index) and move + # the index counter. + samples = self.data[self.next_ind : end_index] + self.next_ind = end_index + + # Parse out each column in the flow samples as a result vector. + results = [] + for attr_name in self.columns: + attr_values = samples[attr_name].values + if graph_state.num_samples == 1: + results.append(attr_values[0]) + else: + results.append(np.array(attr_values)) + + # Save and return the results. + self._save_results(results, graph_state) + return results diff --git a/tests/tdastro/math_nodes/test_given_sampler.py b/tests/tdastro/math_nodes/test_given_sampler.py index d67c5c3..45b40b4 100644 --- a/tests/tdastro/math_nodes/test_given_sampler.py +++ b/tests/tdastro/math_nodes/test_given_sampler.py @@ -1,8 +1,10 @@ import numpy as np +import pandas as pd import pytest +from astropy.table import Table from tdastro.base_models import FunctionNode from tdastro.graph_state import GraphState -from tdastro.math_nodes.given_sampler import GivenSampler +from tdastro.math_nodes.given_sampler import GivenSampler, TableSampler def _test_func(value1, value2): @@ -72,3 +74,56 @@ def test_test_given_sampler_compound(): compound_node.get_param(state2, "function_node_result"), [3.0, 3.5, 4.0, 4.5, 5.0, 1.0, 5.5, 6.0], ) + + +@pytest.mark.parametrize("test_data_type", ["dict", "ap_table", "pd_df"]) +def test_table_sampler(test_data_type): + """Test that we can retrieve numbers from a TableSampler from a + dictionary, AstroPy Table, and Panda's DataFrame.""" + raw_data_dict = { + "A": [1, 2, 3, 4, 5, 6, 7, 8], + "B": [1, 1, 1, 1, 1, 1, 1, 1], + "C": [3, 4, 5, 6, 7, 8, 9, 10], + } + + # Convert the data type depending on the parameterized value. + if test_data_type == "dict": + data = raw_data_dict + elif test_data_type == "ap_table": + data = Table(raw_data_dict) + elif test_data_type == "pd_df": + data = pd.DataFrame(raw_data_dict) + else: + data = None + + # Create the table sampler from the data. + table_node = TableSampler(data, node_label="node") + state = table_node.sample_parameters(num_samples=2) + assert len(state) == 3 + assert np.allclose(state["node"]["A"], [1, 2]) + assert np.allclose(state["node"]["B"], [1, 1]) + assert np.allclose(state["node"]["C"], [3, 4]) + + state = table_node.sample_parameters(num_samples=1) + assert len(state) == 3 + assert state["node"]["A"] == 3 + assert state["node"]["B"] == 1 + assert state["node"]["C"] == 5 + + state = table_node.sample_parameters(num_samples=4) + assert len(state) == 3 + assert np.allclose(state["node"]["A"], [4, 5, 6, 7]) + assert np.allclose(state["node"]["B"], [1, 1, 1, 1]) + assert np.allclose(state["node"]["C"], [6, 7, 8, 9]) + + # We go past the end of the data. + with pytest.raises(IndexError): + _ = table_node.sample_parameters(num_samples=4) + + # We can reset and sample from the beginning. + table_node.reset() + state = table_node.sample_parameters(num_samples=2) + assert len(state) == 3 + assert np.allclose(state["node"]["A"], [1, 2]) + assert np.allclose(state["node"]["B"], [1, 1]) + assert np.allclose(state["node"]["C"], [3, 4])