Skip to content

Commit

Permalink
Merge branch 'main' into math_eval
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Nov 7, 2024
2 parents f701b89 + 9ea64af commit 178a303
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/tdastro/math_nodes/basic_math_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,13 @@ def __init__(self, expression, backend="numpy", node_label=None, **kwargs):
# Create a function from the expression. Note the expression has
# already been sanitized and validated via _prepare().
def eval_func(**kwargs):
return eval(self.expression, globals(), kwargs)
try:
return eval(self.expression, globals(), kwargs)
except Exception as problem:
# Provide more detailed logging, including the expression and parameters
# used, when we encounter a math error like divide by zero.
new_message = f"Error during math operation '{self.expression}' with args={kwargs}"
raise type(problem)(new_message) from problem

super().__init__(eval_func, node_label=node_label, **kwargs)

Expand Down
16 changes: 16 additions & 0 deletions tests/tdastro/math_nodes/test_basic_math_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,22 @@ def test_basic_math_node_fail():
_ = BasicMathNode("x + y", x=1.0)


@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_basic_math_node_error():
"""Test that we augment the error with information about the expression and parameters."""
node = BasicMathNode("y / x", x=0.0, y=1.0)
try:
node.sample_parameters()
except ZeroDivisionError as err:
assert str(err) == "Error during math operation 'y / x' with args={'x': 0.0, 'y': 1.0}"

node = BasicMathNode("sqrt(x)", x=-10.0)
try:
node.sample_parameters()
except ValueError as err:
assert str(err) == "Error during math operation 'np.sqrt(x)' with args={'x': -10.0}"


def test_basic_math_node_numpy():
"""Test that we can perform computations via a BasicMathNode."""
node_a = SingleVariableNode("a", 10.0)
Expand Down

0 comments on commit 178a303

Please sign in to comment.