Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to pull information from Pandas DataFrame or AstroPy Table #183

Merged
merged 1 commit into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 88 additions & 1 deletion src/tdastro/math_nodes/given_sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""A sampler used for testing that produces known (given) results."""
"""Samplers used for testing that produces precomputed results. These
can be used in testing to produce known results or to use data previously
sampled from another method (such as pzflow).
"""

import numpy as np
import pandas as pd
from astropy.table import Table

from tdastro.base_models import FunctionNode

Expand Down Expand Up @@ -66,3 +71,85 @@ def compute(self, graph_state, rng_info=None, **kwargs):
# Save and return the results.
self._save_results(results, graph_state)
return results


class TableSampler(FunctionNode):
"""A FunctionNode that returns values from a table, including
a Pandas DataFrame or AstroPy Table.

Attributes
----------
data : pandas.DataFrame, astropy.table.Table, or dict
The object containing the data to sample.
columns : list of str
The column names for the output columns.
next_ind : int
The next index to sample.
"""

def __init__(self, data, node_label=None, **kwargs):
self.next_ind = 0

if isinstance(data, dict):
self.data = pd.DataFrame(data)
elif isinstance(data, Table):
self.data = data.to_pandas()
elif isinstance(data, pd.DataFrame):
self.data = data.copy()
else:
raise TypeError("Unsupported data type for TableSampler.")

# Add each of the flow's data columns as an output parameter.
self.columns = [x for x in self.data.columns]
super().__init__(self._non_func, node_label=node_label, outputs=self.columns, **kwargs)

def _non_func(self):
"""This function does nothing. Everything happens in the overloaded compute()."""
pass

def reset(self):
"""Reset the next index to use."""
self.next_ind = 0

def compute(self, graph_state, rng_info=None, **kwargs):
"""Return the given values.

Parameters
----------
graph_state : `GraphState`
An object mapping graph parameters to their values. This object is modified
in place as it is sampled.
rng_info : numpy.random._generator.Generator, optional
Unused in this function, but included to provide consistency with other
compute functions.
**kwargs : `dict`, optional
Additional function arguments.

Returns
-------
results : any
The result of the computation. This return value is provided so that testing
functions can easily access the results.
"""
# Check that we have enough points left to sample.
end_index = self.next_ind + graph_state.num_samples
if end_index > len(self.data):
raise IndexError()

# Extract the table for [self.next_ind, end_index) and move
# the index counter.
samples = self.data[self.next_ind : end_index]
self.next_ind = end_index

# Parse out each column in the flow samples as a result vector.
results = []
for attr_name in self.columns:
attr_values = samples[attr_name].values
if graph_state.num_samples == 1:
results.append(attr_values[0])
else:
results.append(np.array(attr_values))

# Save and return the results.
self._save_results(results, graph_state)
return results
57 changes: 56 additions & 1 deletion tests/tdastro/math_nodes/test_given_sampler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import numpy as np
import pandas as pd
import pytest
from astropy.table import Table
from tdastro.base_models import FunctionNode
from tdastro.graph_state import GraphState
from tdastro.math_nodes.given_sampler import GivenSampler
from tdastro.math_nodes.given_sampler import GivenSampler, TableSampler


def _test_func(value1, value2):
Expand Down Expand Up @@ -72,3 +74,56 @@ def test_test_given_sampler_compound():
compound_node.get_param(state2, "function_node_result"),
[3.0, 3.5, 4.0, 4.5, 5.0, 1.0, 5.5, 6.0],
)


@pytest.mark.parametrize("test_data_type", ["dict", "ap_table", "pd_df"])
def test_table_sampler(test_data_type):
"""Test that we can retrieve numbers from a TableSampler from a
dictionary, AstroPy Table, and Panda's DataFrame."""
raw_data_dict = {
"A": [1, 2, 3, 4, 5, 6, 7, 8],
"B": [1, 1, 1, 1, 1, 1, 1, 1],
"C": [3, 4, 5, 6, 7, 8, 9, 10],
}

# Convert the data type depending on the parameterized value.
if test_data_type == "dict":
data = raw_data_dict
elif test_data_type == "ap_table":
data = Table(raw_data_dict)
elif test_data_type == "pd_df":
data = pd.DataFrame(raw_data_dict)
else:
data = None

# Create the table sampler from the data.
table_node = TableSampler(data, node_label="node")
state = table_node.sample_parameters(num_samples=2)
assert len(state) == 3
assert np.allclose(state["node"]["A"], [1, 2])
assert np.allclose(state["node"]["B"], [1, 1])
assert np.allclose(state["node"]["C"], [3, 4])

state = table_node.sample_parameters(num_samples=1)
assert len(state) == 3
assert state["node"]["A"] == 3
assert state["node"]["B"] == 1
assert state["node"]["C"] == 5

state = table_node.sample_parameters(num_samples=4)
assert len(state) == 3
assert np.allclose(state["node"]["A"], [4, 5, 6, 7])
assert np.allclose(state["node"]["B"], [1, 1, 1, 1])
assert np.allclose(state["node"]["C"], [6, 7, 8, 9])

# We go past the end of the data.
with pytest.raises(IndexError):
_ = table_node.sample_parameters(num_samples=4)

# We can reset and sample from the beginning.
table_node.reset()
state = table_node.sample_parameters(num_samples=2)
assert len(state) == 3
assert np.allclose(state["node"]["A"], [1, 2])
assert np.allclose(state["node"]["B"], [1, 1])
assert np.allclose(state["node"]["C"], [3, 4])
Loading