diff --git a/src/tdastro/math_nodes/basic_math_node.py b/src/tdastro/math_nodes/basic_math_node.py index 55c40e8..f05b0d9 100644 --- a/src/tdastro/math_nodes/basic_math_node.py +++ b/src/tdastro/math_nodes/basic_math_node.py @@ -125,7 +125,13 @@ def __init__(self, expression, backend="numpy", node_label=None, **kwargs): # Create a function from the expression. Note the expression has # already been sanitized and validated via _prepare(). def eval_func(**kwargs): - return eval(self.expression, globals(), kwargs) + try: + return eval(self.expression, globals(), kwargs) + except Exception as problem: + # Provide more detailed logging, including the expression and parameters + # used, when we encounter a math error like divide by zero. + new_message = f"Error during math operation '{self.expression}' with args={kwargs}" + raise type(problem)(new_message) from problem super().__init__(eval_func, node_label=node_label, **kwargs) diff --git a/tests/tdastro/math_nodes/test_basic_math_node.py b/tests/tdastro/math_nodes/test_basic_math_node.py index 8e5b36e..c753a0a 100644 --- a/tests/tdastro/math_nodes/test_basic_math_node.py +++ b/tests/tdastro/math_nodes/test_basic_math_node.py @@ -63,6 +63,22 @@ def test_basic_math_node_fail(): _ = BasicMathNode("x + y", x=1.0) +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +def test_basic_math_node_error(): + """Test that we augment the error with information about the expression and parameters.""" + node = BasicMathNode("y / x", x=0.0, y=1.0) + try: + node.sample_parameters() + except ZeroDivisionError as err: + assert str(err) == "Error during math operation 'y / x' with args={'x': 0.0, 'y': 1.0}" + + node = BasicMathNode("sqrt(x)", x=-10.0) + try: + node.sample_parameters() + except ValueError as err: + assert str(err) == "Error during math operation 'np.sqrt(x)' with args={'x': -10.0}" + + def test_basic_math_node_numpy(): """Test that we can perform computations via a BasicMathNode.""" node_a = SingleVariableNode("a", 10.0)