From 102f6668cccc4cca4b167510823009ac6e9ced96 Mon Sep 17 00:00:00 2001 From: Mi Dai Date: Wed, 2 Oct 2024 14:57:08 -0400 Subject: [PATCH 01/10] [WIP] add efficiency functions --- src/tdastro/astro_utils/obs_utils.py | 68 +++++++++++++++++++++ tests/tdastro/astro_utils/test_obs_utils.py | 0 2 files changed, 68 insertions(+) create mode 100644 src/tdastro/astro_utils/obs_utils.py create mode 100644 tests/tdastro/astro_utils/test_obs_utils.py diff --git a/src/tdastro/astro_utils/obs_utils.py b/src/tdastro/astro_utils/obs_utils.py new file mode 100644 index 00000000..ad8a1a78 --- /dev/null +++ b/src/tdastro/astro_utils/obs_utils.py @@ -0,0 +1,68 @@ +import numpy as np + + +def phot_eff_function(snr): + + """ + Photometric detection efficiency as a simple step function of snr. + + Parameters + ---------- + snr: `list` or `numpy.ndarray` + Signal to noise ratio of a list of observations. + + Returns + ------- + eff: `list` or `numpy.ndarray` + The photometric detection efficiency given snr. + """ + + snr = np.array(snr) + eff = np.where(snr > 5, 1., 0.) + + return eff + + +def spec_eff_function(peak_imag): + + """ + Spectroscopic follow-up efficiency as a function of peak i band magnitude. + + Parameters + ---------- + peak_imag: `list` or `numpy.ndarray` + Peak magnitude in i band. + + Returns + ------- + eff: `list` or `numpy.ndarray` + The spectroscopic efficiency given peak i band magnitude. + Based on Equation (17) in Kessler et al. 2019 + """ + + s0 = + s1 = + s2 = + + peak_imag = np.array(peak_imag) + eff = s0 * np.power((1.+ np.exp(s1 * peak_imag - s2)), -1) + + return eff + + + +class Detection_Efficiency(FunctionNode): + + def __init__(): + + + + + +class Spectroscopic_Efficiency(FunctionNode): + + def __init__(): + + + + diff --git a/tests/tdastro/astro_utils/test_obs_utils.py b/tests/tdastro/astro_utils/test_obs_utils.py new file mode 100644 index 00000000..e69de29b From 70da9c19eea9fe4eef3760624dbc8e4bac5d1f29 Mon Sep 17 00:00:00 2001 From: Mi Dai Date: Fri, 25 Oct 2024 14:11:08 -0400 Subject: [PATCH 02/10] add phot and spec efficiency functions --- src/tdastro/astro_utils/obs_utils.py | 31 ++++++---------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/src/tdastro/astro_utils/obs_utils.py b/src/tdastro/astro_utils/obs_utils.py index ad8a1a78..07d32dad 100644 --- a/src/tdastro/astro_utils/obs_utils.py +++ b/src/tdastro/astro_utils/obs_utils.py @@ -2,7 +2,6 @@ def phot_eff_function(snr): - """ Photometric detection efficiency as a simple step function of snr. @@ -18,13 +17,12 @@ def phot_eff_function(snr): """ snr = np.array(snr) - eff = np.where(snr > 5, 1., 0.) + eff = np.where(snr > 5, 1.0, 0.0) return eff def spec_eff_function(peak_imag): - """ Spectroscopic follow-up efficiency as a function of peak i band magnitude. @@ -38,31 +36,14 @@ def spec_eff_function(peak_imag): eff: `list` or `numpy.ndarray` The spectroscopic efficiency given peak i band magnitude. Based on Equation (17) in Kessler et al. 2019 + s0, s1, s2 are fitted using data from Figure 4 in Kessler et al. 2019 """ - s0 = - s1 = - s2 = + s0 = 1.0 + s1 = 2.36 + s2 = 51.9 peak_imag = np.array(peak_imag) - eff = s0 * np.power((1.+ np.exp(s1 * peak_imag - s2)), -1) + eff = s0 * np.power((1.0 + np.exp(s1 * peak_imag - s2)), -1) return eff - - - -class Detection_Efficiency(FunctionNode): - - def __init__(): - - - - - -class Spectroscopic_Efficiency(FunctionNode): - - def __init__(): - - - - From 01906d070ac96b514dc3b479fd931531f93575c9 Mon Sep 17 00:00:00 2001 From: Mi Dai Date: Mon, 28 Oct 2024 10:52:28 -0400 Subject: [PATCH 03/10] add tests for efficiency functions --- tests/tdastro/astro_utils/test_obs_utils.py | 26 +++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/tdastro/astro_utils/test_obs_utils.py b/tests/tdastro/astro_utils/test_obs_utils.py index e69de29b..a57bfea8 100644 --- a/tests/tdastro/astro_utils/test_obs_utils.py +++ b/tests/tdastro/astro_utils/test_obs_utils.py @@ -0,0 +1,26 @@ +import numpy as np +from tdastro.astro_utils.obs_utils import phot_eff_function, spec_eff_function + + +def test_phot_eff_function(): + """ + test that the phot_eff_function returns correct values. + """ + + snr = [1.0, 3.0, 10.0, 100.0] + eff = phot_eff_function(snr) + expected_eff = [0.0, 0.0, 1.0, 1.0] + + np.testing.assert_allclose(eff, expected_eff) + + +def test_spec_eff_function(): + """ + test the spec efficiency function using data extracted from Figure 4 in Kessler et al. 2019 + """ + + imags = [20.01, 21.09, 21.67, 22.54] + expected_eff = [1.0, 0.9, 0.67, 0.2] + eff = spec_eff_function(imags) + + np.testing.assert_allclose(eff, expected_eff, atol=0.02) From 5ccb50e8dfe0e5a00a521b08ea3243334802c809 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 1 Nov 2024 18:41:32 +0000 Subject: [PATCH 04/10] Bump pre-commit-ci/lite-action from 1.0.3 to 1.1.0 Bumps [pre-commit-ci/lite-action](https://github.com/pre-commit-ci/lite-action) from 1.0.3 to 1.1.0. - [Release notes](https://github.com/pre-commit-ci/lite-action/releases) - [Commits](https://github.com/pre-commit-ci/lite-action/compare/v1.0.3...v1.1.0) --- updated-dependencies: - dependency-name: pre-commit-ci/lite-action dependency-type: direct:production update-type: version-update:semver-minor ... Signed-off-by: dependabot[bot] --- .github/workflows/pre-commit-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pre-commit-ci.yml b/.github/workflows/pre-commit-ci.yml index 86560a8f..54cebb17 100644 --- a/.github/workflows/pre-commit-ci.yml +++ b/.github/workflows/pre-commit-ci.yml @@ -31,5 +31,5 @@ jobs: extra_args: --all-files --verbose env: SKIP: "check-lincc-frameworks-template-version,no-commit-to-branch,check-added-large-files,validate-pyproject,sphinx-build,pytest-check" - - uses: pre-commit-ci/lite-action@v1.0.3 + - uses: pre-commit-ci/lite-action@v1.1.0 if: failure() && github.event_name == 'pull_request' && github.event.pull_request.draft == false \ No newline at end of file From 2c7d1b75239843f0f208408c3e3637176c428fa3 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 3 Nov 2024 18:12:00 -0500 Subject: [PATCH 05/10] A very basic math processing node --- src/tdastro/math_nodes/basic_math_node.py | 154 ++++++++++++++++++ .../math_nodes/test_basic_math_node.py | 109 +++++++++++++ 2 files changed, 263 insertions(+) create mode 100644 src/tdastro/math_nodes/basic_math_node.py create mode 100644 tests/tdastro/math_nodes/test_basic_math_node.py diff --git a/src/tdastro/math_nodes/basic_math_node.py b/src/tdastro/math_nodes/basic_math_node.py new file mode 100644 index 00000000..417d0b7b --- /dev/null +++ b/src/tdastro/math_nodes/basic_math_node.py @@ -0,0 +1,154 @@ +"""Nodes that perform basic math operations.""" + +import ast +import math + +import jax.numpy as jnp +import numpy as np + +from tdastro.base_models import FunctionNode + + +class BasicMathNode(FunctionNode): + """A node that evaluates basic mathematical functions. + + Attributes + ---------- + expression : `str` + The expression to evaluate. + backend : `str` + The math libary to use. Must be one of: math, numpy, or jax. + tree : `ast.*` + The root node of the parsed syntax tree. + + Parameters + ---------- + expression : `str` + The expression to evaluate. + backend : `str` + The math libary to use. Must be one of: math, numpy, or jax. + node_label : `str`, optional + An identifier (or name) for the current node. + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + + # A list of supported Python operations. Used to prevent eval from + # running arbitrary python expressions. + _supported_ast_nodes = ( + ast.Module, # Top level object when parsed as exec. + ast.Expression, # Top level object when parsed as eval. + ast.Expr, # Math expressions. + ast.Constant, # Constant values. + ast.Name, # A named variable or function. + ast.BinOp, # Binary operations + ast.Add, + ast.Sub, + ast.Mult, + ast.Div, + ast.FloorDiv, + ast.Mod, + ast.Pow, + ast.UnaryOp, # Uninary operations + ast.UAdd, + ast.USub, + ast.Invert, + # Call functions (but we do NOT include ast.Call because + # we need to special case that). + ast.Load, + ast.Store, + ) + + # A very limited set of math operations that are supported + # in all of the backends. + _math_funcs = set( + [ + "abs", + "ceil", + "cos", + "cosh", + "degrees", + "exp", + "floor", + "log", + "log10", + "log2", + "radians", + "sin", + "sinh", + "sqrt", + "tan", + "tanh", + ] + ) + + def __init__(self, expression, backend="numpy", node_label=None, **kwargs): + self.expression = expression + self.backend = backend + + # Check that all the functions are supported. + if backend == "math": + supported_funcs = dir(math) + supported_funcs.append("abs") + elif backend == "jax": + supported_funcs = dir(jnp) + elif backend == "numpy": + supported_funcs = dir(np) + else: + raise ValueError(f"Unsupported math backend {backend}") + + for fn_name in self._math_funcs: + if fn_name not in supported_funcs: + raise ValueError(f"Function {fn_name} is not supported by {backend}.") + + # Check the expression is pure math and translate it into the correct backend. + self._compile() + + # Create a function from the expression. Note the expression has + # already been sanitized and validated via _compile(). + def eval_func(**kwargs): + return eval(self.expression, globals(), kwargs) + + super().__init__(eval_func, node_label=node_label, **kwargs) + + def __call__(self, **kwargs): + """Evaluate thge""" + return eval(self.expression, globals(), kwargs) + + def _compile(self, **kwargs): + """Compile a python expression that consists of only basic math. + + Parameters + ---------- + **kwargs : `dict`, optional + Any additional keyword arguments, including the variable + assignments. + + Returns + ------- + tree : `ast.*` + The root node of the parsed syntax tree. + """ + tree = ast.parse(self.expression) + + # Walk the tree and confirm that it only contains the basic math. + for node in ast.walk(tree): + if isinstance(node, self._supported_ast_nodes): + # Nothing to do, this is a valid operation for the ast. + continue + elif isinstance(node, (ast.FunctionType, ast.Call)): + func_name = node.func.id + if func_name not in self._math_funcs: + raise ValueError("Unsupported function {func_name}.") + + if self.backend == "numpy": + node.func.id = f"np.{func_name}" + elif self.backend == "jax": + node.func.id = f"jnp.{func_name}" + elif self.backend == "math" and "func_name" != "abs": + node.func.id = f"math.{func_name}" + else: + raise ValueError(f"Invalid part of expression {type(node)}") + + # Convert the expression back into a string. + self.expression = ast.unparse(tree) diff --git a/tests/tdastro/math_nodes/test_basic_math_node.py b/tests/tdastro/math_nodes/test_basic_math_node.py new file mode 100644 index 00000000..d01ed368 --- /dev/null +++ b/tests/tdastro/math_nodes/test_basic_math_node.py @@ -0,0 +1,109 @@ +import math + +import jax +import pytest +from tdastro.math_nodes.basic_math_node import BasicMathNode +from tdastro.math_nodes.single_value_node import SingleVariableNode + + +def test_basic_math_node(): + """Test that we can perform computations via a BasicMathNode.""" + node_a = SingleVariableNode("a", 10.0) + node_b = SingleVariableNode("b", -5.0) + node = BasicMathNode("a + b", a=node_a.a, b=node_b.b, node_label="test", backend="math") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 5.0 + + # Try with a math function. + node_c = SingleVariableNode("c", 1000.0) + node = BasicMathNode("a + b - log10(c)", a=10.0, b=5.0, c=node_c.c, node_label="test", backend="math") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 12.0 + + # Try with a second math function. + node = BasicMathNode( + "sqrt(a) + b - log10(c)", a=16.0, b=4.0, c=node_c.c, node_label="test", backend="math" + ) + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 5.0 + + # Test that we can reproduce the power function. + node_d = SingleVariableNode("d", 5.0) + node = BasicMathNode("a ** b", a=node_d.d, b=2.5, node_label="test", backend="math") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == pytest.approx(math.pow(5.0, 2.5)) + + +def test_basic_math_node_numpy(): + """Test that we can perform computations via a BasicMathNode.""" + node_a = SingleVariableNode("a", 10.0) + node_b = SingleVariableNode("b", -5.0) + node = BasicMathNode("a + b", a=node_a.a, b=node_b.b, node_label="test", backend="numpy") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 5.0 + + # Try with a math function. + node_c = SingleVariableNode("c", 1000.0) + node = BasicMathNode("a + b - log10(c)", a=10.0, b=5.0, c=node_c.c, node_label="test", backend="numpy") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 12.0 + + # Try with a second math function. + node = BasicMathNode( + "sqrt(a) + b - log10(c)", a=16.0, b=4.0, c=node_c.c, node_label="test", backend="numpy" + ) + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 5.0 + + # Test that we can reproduce the power function. + node_d = SingleVariableNode("d", 5.0) + node = BasicMathNode("a ** b", a=node_d.d, b=2.5, node_label="test", backend="math") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == pytest.approx(math.pow(5.0, 2.5)) + + +def test_basic_math_node_jax(): + """Test that we can perform computations via a BasicMathNode.""" + node_a = SingleVariableNode("a", 10.0) + node_b = SingleVariableNode("b", -5.0) + node = BasicMathNode("a + b", a=node_a.a, b=node_b.b, node_label="test", backend="jax") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 5.0 + + # Try with a math function. + node_c = SingleVariableNode("c", 1000.0) + node = BasicMathNode("a + b - log10(c)", a=10.0, b=5.0, c=node_c.c, node_label="test", backend="jax") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 12.0 + + # Try with a second math function. + node = BasicMathNode( + "sqrt(a) + b - log10(c)", a=16.0, b=4.0, c=node_c.c, node_label="test", backend="jax" + ) + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 5.0 + + # Test that we can reproduce the power function. + node_d = SingleVariableNode("d", 5.0) + node = BasicMathNode("a ** b", a=node_d.d, b=2.5, node_label="test", backend="math") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == pytest.approx(math.pow(5.0, 2.5)) + + +def test_basic_math_node_autodiff_jax(): + """Test that we can do auto-differentiation with JAX.""" + node_a = SingleVariableNode("a", 16.0, node_label="a_node") + node_b = SingleVariableNode("b", 1000.0, node_label="b_node") + + # Create a basic math function and create tghe pytree. + node = BasicMathNode( + "sqrt(a) + 1.0 - log10(b)", a=node_a.a, b=node_b.b, node_label="diff_test", backend="jax" + ) + state = node.sample_parameters() + pytree = node.build_pytree(state) + + gr_func = jax.value_and_grad(node.resample_and_compute) + values, gradients = gr_func(pytree) + assert values == 2.0 + assert gradients["a_node"]["a"] > 0.0 + assert gradients["b_node"]["b"] < 0.0 From 1490e6a3f51ddbb20a8c28a242ab4903e45f1e89 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 4 Nov 2024 10:30:37 -0500 Subject: [PATCH 06/10] Improve robustness --- src/tdastro/math_nodes/basic_math_node.py | 165 +++++++++++------- .../math_nodes/test_basic_math_node.py | 29 +++ 2 files changed, 127 insertions(+), 67 deletions(-) diff --git a/src/tdastro/math_nodes/basic_math_node.py b/src/tdastro/math_nodes/basic_math_node.py index 417d0b7b..dba52b72 100644 --- a/src/tdastro/math_nodes/basic_math_node.py +++ b/src/tdastro/math_nodes/basic_math_node.py @@ -1,10 +1,17 @@ -"""Nodes that perform basic math operations.""" +"""Nodes that perform basic math operations that can be specified as strings. + +The goal of this library is to save users from needing to create a bunch of +small FunctionNodes to perform basic math. +""" import ast -import math -import jax.numpy as jnp -import numpy as np +# Disable unused import because we need all of these imported +# so they can be used during evaluation of the node. +import math # noqa: F401 + +import jax.numpy as jnp # noqa: F401 +import numpy as np # noqa: F401 from tdastro.base_models import FunctionNode @@ -12,14 +19,24 @@ class BasicMathNode(FunctionNode): """A node that evaluates basic mathematical functions. + The BasicMathNode wraps Python's eval() function to sanitize the input string + and thus prevent the execution of arbitrary code. It also allows the user to write + the expression once and execute using math, numpy, or JAX. The names of the + variables in the expression must match the input variables provided by kwargs. + + Example: + my_node = BasicMathNode( + "redshift + 10.0 * sin(phase)", + redshift=host.redshift, + phase=source.phase, + ) + Attributes ---------- expression : `str` The expression to evaluate. backend : `str` The math libary to use. Must be one of: math, numpy, or jax. - tree : `ast.*` - The root node of the parsed syntax tree. Parameters ---------- @@ -30,17 +47,20 @@ class BasicMathNode(FunctionNode): node_label : `str`, optional An identifier (or name) for the current node. **kwargs : `dict`, optional - Any additional keyword arguments. + Any additional keyword arguments. Every variable in the expression + must be included as a kwarg. """ # A list of supported Python operations. Used to prevent eval from - # running arbitrary python expressions. + # running arbitrary python expressions. The Call and Name types are special + # cased so we can do checks and translations. _supported_ast_nodes = ( ast.Module, # Top level object when parsed as exec. ast.Expression, # Top level object when parsed as eval. ast.Expr, # Math expressions. ast.Constant, # Constant values. - ast.Name, # A named variable or function. + ast.Load, # Load a variable - must come from an approved function or variable. + ast.Store, # Store value - must come from an approved function or variable. ast.BinOp, # Binary operations ast.Add, ast.Sub, @@ -53,70 +73,70 @@ class BasicMathNode(FunctionNode): ast.UAdd, ast.USub, ast.Invert, - # Call functions (but we do NOT include ast.Call because - # we need to special case that). - ast.Load, - ast.Store, ) - # A very limited set of math operations that are supported - # in all of the backends. - _math_funcs = set( - [ - "abs", - "ceil", - "cos", - "cosh", - "degrees", - "exp", - "floor", - "log", - "log10", - "log2", - "radians", - "sin", - "sinh", - "sqrt", - "tan", - "tanh", - ] - ) + # A map from aa very limited set of supported math constant/function names to + # the corresponding names in [math, numpy, jax]. This is needed because + # a very few functions have different names in different libraries. + _math_map = { + "abs": ["abs", "np.abs", "jnp.abs"], # Special handling for math. + "acos": ["math.acos", "np.acos", "jnp.acos"], + "acosh": ["math.acosh", "np.acosh", "jnp.acosh"], + "asin": ["math.asin", "np.asin", "jnp.asin"], + "asinh": ["math.asinh", "np.asinh", "jnp.asinh"], + "atan": ["math.atan", "np.atan", "jnp.atan"], + "atan2": ["math.atan2", "np.atan2", "jnp.atan2"], + "cos": ["math.cos", "np.cos", "jnp.cos"], + "cosh": ["math.cosh", "np.cosh", "jnp.cosh"], + "ceil": ["math.ceil", "np.ceil", "jnp.ceil"], + "degrees": ["math.degrees", "np.degrees", "jnp.degrees"], + "deg2rad": ["math.radians", "np.deg2rad", "jnp.deg2rad"], # Special handling for math + "e": ["math.e", "np.e", "jnp.e"], + "exp": ["math.exp", "np.exp", "jnp.exp"], + "fabs": ["math.fabs", "np.fabs", "jnp.fabs"], + "floor": ["math.floor", "np.floor", "jnp.floor"], + "log": ["math.log", "np.log", "jnp.log"], + "log10": ["math.log10", "np.log10", "jnp.log10"], + "log2": ["math.log2", "np.log2", "jnp.log2"], + "max": ["max", "np.max", "jnp.max"], # Special handling for math + "min": ["min", "np.min", "jnp.min"], # Special handling for math + "pi": ["math.pi", "np.pi", "jnp.pi"], + "pow": ["math.pow", "np.power", "jnp.power"], # Special handling for numpy + "power": ["math.pow", "np.power", "jnp.power"], # Special handling for math + "radians": ["math.radians", "np.radians", "jnp.radians"], + "rad2deg": ["math.degrees", "np.rad2deg", "jnp.rad2deg"], # Special handling for math + "sin": ["math.sin", "np.sin", "jnp.sin"], + "sinh": ["math.sinh", "np.sinh", "jnp.sinh"], + "sqrt": ["math.sqrt", "np.sqrt", "jnp.sqrt"], + "tan": ["math.tan", "np.tan", "jnp.tan"], + "tanh": ["math.tanh", "np.tanh", "jnp.tanh"], + "trunc": ["math.trunc", "np.trunc", "jnp.trunc"], + } def __init__(self, expression, backend="numpy", node_label=None, **kwargs): - self.expression = expression - self.backend = backend - - # Check that all the functions are supported. - if backend == "math": - supported_funcs = dir(math) - supported_funcs.append("abs") - elif backend == "jax": - supported_funcs = dir(jnp) - elif backend == "numpy": - supported_funcs = dir(np) - else: + if backend not in ["jax", "math", "numpy"]: raise ValueError(f"Unsupported math backend {backend}") - - for fn_name in self._math_funcs: - if fn_name not in supported_funcs: - raise ValueError(f"Function {fn_name} is not supported by {backend}.") + self.backend = backend # Check the expression is pure math and translate it into the correct backend. - self._compile() + self.expression = expression + self._prepare(**kwargs) # Create a function from the expression. Note the expression has - # already been sanitized and validated via _compile(). + # already been sanitized and validated via _prepare(). def eval_func(**kwargs): return eval(self.expression, globals(), kwargs) super().__init__(eval_func, node_label=node_label, **kwargs) def __call__(self, **kwargs): - """Evaluate thge""" + """Evaluate the expression.""" return eval(self.expression, globals(), kwargs) - def _compile(self, **kwargs): - """Compile a python expression that consists of only basic math. + def _prepare(self, **kwargs): + """Rewrite a python expression that consists of only basic math to use + the prespecified math library. Santizes the string to prevent + arbitrary code execution. Parameters ---------- @@ -136,17 +156,28 @@ def _compile(self, **kwargs): if isinstance(node, self._supported_ast_nodes): # Nothing to do, this is a valid operation for the ast. continue - elif isinstance(node, (ast.FunctionType, ast.Call)): - func_name = node.func.id - if func_name not in self._math_funcs: - raise ValueError("Unsupported function {func_name}.") - - if self.backend == "numpy": - node.func.id = f"np.{func_name}" - elif self.backend == "jax": - node.func.id = f"jnp.{func_name}" - elif self.backend == "math" and "func_name" != "abs": - node.func.id = f"math.{func_name}" + elif isinstance(node, ast.Call): + # Check that function calls are only using items on the allow list. + if node.func.id not in self._math_map: + raise ValueError(f"Unsupported function {node.func.id}") + elif isinstance(node, ast.Name): + if node.id in kwargs: + # This is a user supplied variable. + continue + elif node.id in self._math_map: + # This is a math function or constant. Overwrite + if self.backend == "math": + node.id = self._math_map[node.id][0] + elif self.backend == "numpy": + node.id = self._math_map[node.id][1] + elif self.backend == "jax": + node.id = self._math_map[node.id][2] + else: + raise ValueError( + f"Unrecognized named variable or function {node.id}. " + "This could be because the function is not supported or " + "you forgot to include the variable as an argument." + ) else: raise ValueError(f"Invalid part of expression {type(node)}") diff --git a/tests/tdastro/math_nodes/test_basic_math_node.py b/tests/tdastro/math_nodes/test_basic_math_node.py index d01ed368..8e5b36ea 100644 --- a/tests/tdastro/math_nodes/test_basic_math_node.py +++ b/tests/tdastro/math_nodes/test_basic_math_node.py @@ -34,6 +34,35 @@ def test_basic_math_node(): assert state["test"]["function_node_result"] == pytest.approx(math.pow(5.0, 2.5)) +def test_basic_math_node_special_cases(): + """Test that we can handle some of the special cases for a BasicMathNode.""" + node_a = SingleVariableNode("a", 180.0) + node = BasicMathNode("sin(deg2rad(x) + pi / 2.0)", x=node_a.a, node_label="test", backend="math") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == pytest.approx(-1.0) + + +def test_basic_math_node_fail(): + """Test that we perform the needed checks for a math node.""" + # Imports not allowed + with pytest.raises(ValueError): + _ = BasicMathNode("import os") + + # Ifs not allowed (won't work with JAX) + with pytest.raises(ValueError): + _ = BasicMathNode("x if 1.0 else 1.0", x=2.0) + + # We only allow functions on the allow list. + with pytest.raises(ValueError): + _ = BasicMathNode("fake_delete_everything_no_confirm('./')") + with pytest.raises(ValueError): + _ = BasicMathNode("median(10, 20)") + + # All variables must be defined. + with pytest.raises(ValueError): + _ = BasicMathNode("x + y", x=1.0) + + def test_basic_math_node_numpy(): """Test that we can perform computations via a BasicMathNode.""" node_a = SingleVariableNode("a", 10.0) From ecaf1f3b0a1401ee1b3bce8cbeee7a5a2fd59cb2 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 4 Nov 2024 11:03:24 -0500 Subject: [PATCH 07/10] Update src/tdastro/math_nodes/basic_math_node.py Co-authored-by: Melissa DeLucchi <113376043+delucchi-cmu@users.noreply.github.com> --- src/tdastro/math_nodes/basic_math_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tdastro/math_nodes/basic_math_node.py b/src/tdastro/math_nodes/basic_math_node.py index dba52b72..55c40e8e 100644 --- a/src/tdastro/math_nodes/basic_math_node.py +++ b/src/tdastro/math_nodes/basic_math_node.py @@ -75,7 +75,7 @@ class BasicMathNode(FunctionNode): ast.Invert, ) - # A map from aa very limited set of supported math constant/function names to + # A map from a very limited set of supported math constant/function names to # the corresponding names in [math, numpy, jax]. This is needed because # a very few functions have different names in different libraries. _math_map = { From 4288ae36ef6fd1ed62df56ab73ae90c7c51d174d Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 7 Nov 2024 16:19:09 -0500 Subject: [PATCH 08/10] Add some basic logging for errors in math node --- src/tdastro/math_nodes/basic_math_node.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/tdastro/math_nodes/basic_math_node.py b/src/tdastro/math_nodes/basic_math_node.py index 55c40e8e..5fe90599 100644 --- a/src/tdastro/math_nodes/basic_math_node.py +++ b/src/tdastro/math_nodes/basic_math_node.py @@ -5,6 +5,7 @@ """ import ast +import logging # Disable unused import because we need all of these imported # so they can be used during evaluation of the node. @@ -15,6 +16,8 @@ from tdastro.base_models import FunctionNode +logger = logging.getLogger(__name__) + class BasicMathNode(FunctionNode): """A node that evaluates basic mathematical functions. @@ -125,7 +128,16 @@ def __init__(self, expression, backend="numpy", node_label=None, **kwargs): # Create a function from the expression. Note the expression has # already been sanitized and validated via _prepare(). def eval_func(**kwargs): - return eval(self.expression, globals(), kwargs) + try: + return eval(self.expression, globals(), kwargs) + except Exception as problem: + # Provide more detailed logging, including the expression and parameters + # used, when we encounter a math error like divide by zero. + logger.error( + f"{type(problem)} encountered during operation: {self.expression}\n" + f"with arguments={kwargs}" + ) + raise problem super().__init__(eval_func, node_label=node_label, **kwargs) From 0ef5212885be356e8d6cbb93fa4e228763d8f5da Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 7 Nov 2024 17:06:43 -0500 Subject: [PATCH 09/10] Add the debugging information to the exception --- src/tdastro/math_nodes/basic_math_node.py | 7 ++----- tests/tdastro/math_nodes/test_basic_math_node.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/tdastro/math_nodes/basic_math_node.py b/src/tdastro/math_nodes/basic_math_node.py index 5fe90599..899b2d0b 100644 --- a/src/tdastro/math_nodes/basic_math_node.py +++ b/src/tdastro/math_nodes/basic_math_node.py @@ -133,11 +133,8 @@ def eval_func(**kwargs): except Exception as problem: # Provide more detailed logging, including the expression and parameters # used, when we encounter a math error like divide by zero. - logger.error( - f"{type(problem)} encountered during operation: {self.expression}\n" - f"with arguments={kwargs}" - ) - raise problem + new_message = f"Error during math operation '{self.expression}' with args={kwargs}" + raise type(problem)(new_message) from problem super().__init__(eval_func, node_label=node_label, **kwargs) diff --git a/tests/tdastro/math_nodes/test_basic_math_node.py b/tests/tdastro/math_nodes/test_basic_math_node.py index 8e5b36ea..c753a0aa 100644 --- a/tests/tdastro/math_nodes/test_basic_math_node.py +++ b/tests/tdastro/math_nodes/test_basic_math_node.py @@ -63,6 +63,22 @@ def test_basic_math_node_fail(): _ = BasicMathNode("x + y", x=1.0) +@pytest.mark.filterwarnings("ignore::RuntimeWarning") +def test_basic_math_node_error(): + """Test that we augment the error with information about the expression and parameters.""" + node = BasicMathNode("y / x", x=0.0, y=1.0) + try: + node.sample_parameters() + except ZeroDivisionError as err: + assert str(err) == "Error during math operation 'y / x' with args={'x': 0.0, 'y': 1.0}" + + node = BasicMathNode("sqrt(x)", x=-10.0) + try: + node.sample_parameters() + except ValueError as err: + assert str(err) == "Error during math operation 'np.sqrt(x)' with args={'x': -10.0}" + + def test_basic_math_node_numpy(): """Test that we can perform computations via a BasicMathNode.""" node_a = SingleVariableNode("a", 10.0) From 901a97fbe2f7c00066c0fcf19be71f7a5e20d522 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 7 Nov 2024 17:08:11 -0500 Subject: [PATCH 10/10] Remove unneeded logger --- src/tdastro/math_nodes/basic_math_node.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/tdastro/math_nodes/basic_math_node.py b/src/tdastro/math_nodes/basic_math_node.py index 899b2d0b..f05b0d9b 100644 --- a/src/tdastro/math_nodes/basic_math_node.py +++ b/src/tdastro/math_nodes/basic_math_node.py @@ -5,7 +5,6 @@ """ import ast -import logging # Disable unused import because we need all of these imported # so they can be used during evaluation of the node. @@ -16,8 +15,6 @@ from tdastro.base_models import FunctionNode -logger = logging.getLogger(__name__) - class BasicMathNode(FunctionNode): """A node that evaluates basic mathematical functions.