From 2c7d1b75239843f0f208408c3e3637176c428fa3 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Sun, 3 Nov 2024 18:12:00 -0500 Subject: [PATCH 1/3] A very basic math processing node --- src/tdastro/math_nodes/basic_math_node.py | 154 ++++++++++++++++++ .../math_nodes/test_basic_math_node.py | 109 +++++++++++++ 2 files changed, 263 insertions(+) create mode 100644 src/tdastro/math_nodes/basic_math_node.py create mode 100644 tests/tdastro/math_nodes/test_basic_math_node.py diff --git a/src/tdastro/math_nodes/basic_math_node.py b/src/tdastro/math_nodes/basic_math_node.py new file mode 100644 index 00000000..417d0b7b --- /dev/null +++ b/src/tdastro/math_nodes/basic_math_node.py @@ -0,0 +1,154 @@ +"""Nodes that perform basic math operations.""" + +import ast +import math + +import jax.numpy as jnp +import numpy as np + +from tdastro.base_models import FunctionNode + + +class BasicMathNode(FunctionNode): + """A node that evaluates basic mathematical functions. + + Attributes + ---------- + expression : `str` + The expression to evaluate. + backend : `str` + The math libary to use. Must be one of: math, numpy, or jax. + tree : `ast.*` + The root node of the parsed syntax tree. + + Parameters + ---------- + expression : `str` + The expression to evaluate. + backend : `str` + The math libary to use. Must be one of: math, numpy, or jax. + node_label : `str`, optional + An identifier (or name) for the current node. + **kwargs : `dict`, optional + Any additional keyword arguments. + """ + + # A list of supported Python operations. Used to prevent eval from + # running arbitrary python expressions. + _supported_ast_nodes = ( + ast.Module, # Top level object when parsed as exec. + ast.Expression, # Top level object when parsed as eval. + ast.Expr, # Math expressions. + ast.Constant, # Constant values. + ast.Name, # A named variable or function. 
+ ast.BinOp, # Binary operations + ast.Add, + ast.Sub, + ast.Mult, + ast.Div, + ast.FloorDiv, + ast.Mod, + ast.Pow, + ast.UnaryOp, # Unary operations + ast.UAdd, + ast.USub, + ast.Invert, + # Call functions (but we do NOT include ast.Call because + # we need to special case that). + ast.Load, + ast.Store, + ) + + # A very limited set of math operations that are supported + # in all of the backends. + _math_funcs = set( + [ + "abs", + "ceil", + "cos", + "cosh", + "degrees", + "exp", + "floor", + "log", + "log10", + "log2", + "radians", + "sin", + "sinh", + "sqrt", + "tan", + "tanh", + ] + ) + + def __init__(self, expression, backend="numpy", node_label=None, **kwargs): + self.expression = expression + self.backend = backend + + # Check that all the functions are supported. + if backend == "math": + supported_funcs = dir(math) + supported_funcs.append("abs") + elif backend == "jax": + supported_funcs = dir(jnp) + elif backend == "numpy": + supported_funcs = dir(np) + else: + raise ValueError(f"Unsupported math backend {backend}") + + for fn_name in self._math_funcs: + if fn_name not in supported_funcs: + raise ValueError(f"Function {fn_name} is not supported by {backend}.") + + # Check the expression is pure math and translate it into the correct backend. + self._compile() + + # Create a function from the expression. Note the expression has + # already been sanitized and validated via _compile(). + def eval_func(**kwargs): + return eval(self.expression, globals(), kwargs) + + super().__init__(eval_func, node_label=node_label, **kwargs) + + def __call__(self, **kwargs): + """Evaluate thge""" + return eval(self.expression, globals(), kwargs) + + def _compile(self, **kwargs): + """Compile a python expression that consists of only basic math. + + Parameters + ---------- + **kwargs : `dict`, optional + Any additional keyword arguments, including the variable + assignments. + + Returns + ------- + tree : `ast.*` + The root node of the parsed syntax tree. 
+ """ + tree = ast.parse(self.expression) + + # Walk the tree and confirm that it only contains the basic math. + for node in ast.walk(tree): + if isinstance(node, self._supported_ast_nodes): + # Nothing to do, this is a valid operation for the ast. + continue + elif isinstance(node, (ast.FunctionType, ast.Call)): + func_name = node.func.id + if func_name not in self._math_funcs: + raise ValueError("Unsupported function {func_name}.") + + if self.backend == "numpy": + node.func.id = f"np.{func_name}" + elif self.backend == "jax": + node.func.id = f"jnp.{func_name}" + elif self.backend == "math" and "func_name" != "abs": + node.func.id = f"math.{func_name}" + else: + raise ValueError(f"Invalid part of expression {type(node)}") + + # Convert the expression back into a string. + self.expression = ast.unparse(tree) diff --git a/tests/tdastro/math_nodes/test_basic_math_node.py b/tests/tdastro/math_nodes/test_basic_math_node.py new file mode 100644 index 00000000..d01ed368 --- /dev/null +++ b/tests/tdastro/math_nodes/test_basic_math_node.py @@ -0,0 +1,109 @@ +import math + +import jax +import pytest +from tdastro.math_nodes.basic_math_node import BasicMathNode +from tdastro.math_nodes.single_value_node import SingleVariableNode + + +def test_basic_math_node(): + """Test that we can perform computations via a BasicMathNode.""" + node_a = SingleVariableNode("a", 10.0) + node_b = SingleVariableNode("b", -5.0) + node = BasicMathNode("a + b", a=node_a.a, b=node_b.b, node_label="test", backend="math") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 5.0 + + # Try with a math function. + node_c = SingleVariableNode("c", 1000.0) + node = BasicMathNode("a + b - log10(c)", a=10.0, b=5.0, c=node_c.c, node_label="test", backend="math") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 12.0 + + # Try with a second math function. 
+ node = BasicMathNode( + "sqrt(a) + b - log10(c)", a=16.0, b=4.0, c=node_c.c, node_label="test", backend="math" + ) + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 5.0 + + # Test that we can reproduce the power function. + node_d = SingleVariableNode("d", 5.0) + node = BasicMathNode("a ** b", a=node_d.d, b=2.5, node_label="test", backend="math") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == pytest.approx(math.pow(5.0, 2.5)) + + +def test_basic_math_node_numpy(): + """Test that we can perform computations via a BasicMathNode.""" + node_a = SingleVariableNode("a", 10.0) + node_b = SingleVariableNode("b", -5.0) + node = BasicMathNode("a + b", a=node_a.a, b=node_b.b, node_label="test", backend="numpy") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 5.0 + + # Try with a math function. + node_c = SingleVariableNode("c", 1000.0) + node = BasicMathNode("a + b - log10(c)", a=10.0, b=5.0, c=node_c.c, node_label="test", backend="numpy") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 12.0 + + # Try with a second math function. + node = BasicMathNode( + "sqrt(a) + b - log10(c)", a=16.0, b=4.0, c=node_c.c, node_label="test", backend="numpy" + ) + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 5.0 + + # Test that we can reproduce the power function. 
+ node_d = SingleVariableNode("d", 5.0) + node = BasicMathNode("a ** b", a=node_d.d, b=2.5, node_label="test", backend="math") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == pytest.approx(math.pow(5.0, 2.5)) + + +def test_basic_math_node_jax(): + """Test that we can perform computations via a BasicMathNode.""" + node_a = SingleVariableNode("a", 10.0) + node_b = SingleVariableNode("b", -5.0) + node = BasicMathNode("a + b", a=node_a.a, b=node_b.b, node_label="test", backend="jax") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 5.0 + + # Try with a math function. + node_c = SingleVariableNode("c", 1000.0) + node = BasicMathNode("a + b - log10(c)", a=10.0, b=5.0, c=node_c.c, node_label="test", backend="jax") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 12.0 + + # Try with a second math function. + node = BasicMathNode( + "sqrt(a) + b - log10(c)", a=16.0, b=4.0, c=node_c.c, node_label="test", backend="jax" + ) + state = node.sample_parameters() + assert state["test"]["function_node_result"] == 5.0 + + # Test that we can reproduce the power function. + node_d = SingleVariableNode("d", 5.0) + node = BasicMathNode("a ** b", a=node_d.d, b=2.5, node_label="test", backend="math") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == pytest.approx(math.pow(5.0, 2.5)) + + +def test_basic_math_node_autodiff_jax(): + """Test that we can do auto-differentiation with JAX.""" + node_a = SingleVariableNode("a", 16.0, node_label="a_node") + node_b = SingleVariableNode("b", 1000.0, node_label="b_node") + + # Create a basic math function and create the pytree. 
+ node = BasicMathNode( + "sqrt(a) + 1.0 - log10(b)", a=node_a.a, b=node_b.b, node_label="diff_test", backend="jax" + ) + state = node.sample_parameters() + pytree = node.build_pytree(state) + + gr_func = jax.value_and_grad(node.resample_and_compute) + values, gradients = gr_func(pytree) + assert values == 2.0 + assert gradients["a_node"]["a"] > 0.0 + assert gradients["b_node"]["b"] < 0.0 From 1490e6a3f51ddbb20a8c28a242ab4903e45f1e89 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 4 Nov 2024 10:30:37 -0500 Subject: [PATCH 2/3] Improve robustness --- src/tdastro/math_nodes/basic_math_node.py | 165 +++++++++++------- .../math_nodes/test_basic_math_node.py | 29 +++ 2 files changed, 127 insertions(+), 67 deletions(-) diff --git a/src/tdastro/math_nodes/basic_math_node.py b/src/tdastro/math_nodes/basic_math_node.py index 417d0b7b..dba52b72 100644 --- a/src/tdastro/math_nodes/basic_math_node.py +++ b/src/tdastro/math_nodes/basic_math_node.py @@ -1,10 +1,17 @@ -"""Nodes that perform basic math operations.""" +"""Nodes that perform basic math operations that can be specified as strings. + +The goal of this library is to save users from needing to create a bunch of +small FunctionNodes to perform basic math. +""" import ast -import math -import jax.numpy as jnp -import numpy as np +# Disable unused import because we need all of these imported +# so they can be used during evaluation of the node. +import math # noqa: F401 + +import jax.numpy as jnp # noqa: F401 +import numpy as np # noqa: F401 from tdastro.base_models import FunctionNode @@ -12,14 +19,24 @@ class BasicMathNode(FunctionNode): """A node that evaluates basic mathematical functions. + The BasicMathNode wraps Python's eval() function to sanitize the input string + and thus prevent the execution of arbitrary code. It also allows the user to write + the expression once and execute using math, numpy, or JAX. 
The names of the + variables in the expression must match the input variables provided by kwargs. + + Example: + my_node = BasicMathNode( + "redshift + 10.0 * sin(phase)", + redshift=host.redshift, + phase=source.phase, + ) + Attributes ---------- expression : `str` The expression to evaluate. backend : `str` The math libary to use. Must be one of: math, numpy, or jax. - tree : `ast.*` - The root node of the parsed syntax tree. Parameters ---------- @@ -30,17 +47,20 @@ class BasicMathNode(FunctionNode): node_label : `str`, optional An identifier (or name) for the current node. **kwargs : `dict`, optional - Any additional keyword arguments. + Any additional keyword arguments. Every variable in the expression + must be included as a kwarg. """ # A list of supported Python operations. Used to prevent eval from - # running arbitrary python expressions. + # running arbitrary python expressions. The Call and Name types are special + # cased so we can do checks and translations. _supported_ast_nodes = ( ast.Module, # Top level object when parsed as exec. ast.Expression, # Top level object when parsed as eval. ast.Expr, # Math expressions. ast.Constant, # Constant values. - ast.Name, # A named variable or function. + ast.Load, # Load a variable - must come from an approved function or variable. + ast.Store, # Store value - must come from an approved function or variable. ast.BinOp, # Binary operations ast.Add, ast.Sub, @@ -53,70 +73,70 @@ class BasicMathNode(FunctionNode): ast.UAdd, ast.USub, ast.Invert, - # Call functions (but we do NOT include ast.Call because - # we need to special case that). - ast.Load, - ast.Store, ) - # A very limited set of math operations that are supported - # in all of the backends. 
- _math_funcs = set( - [ - "abs", - "ceil", - "cos", - "cosh", - "degrees", - "exp", - "floor", - "log", - "log10", - "log2", - "radians", - "sin", - "sinh", - "sqrt", - "tan", - "tanh", - ] - ) + # A map from aa very limited set of supported math constant/function names to + # the corresponding names in [math, numpy, jax]. This is needed because + # a very few functions have different names in different libraries. + _math_map = { + "abs": ["abs", "np.abs", "jnp.abs"], # Special handling for math. + "acos": ["math.acos", "np.acos", "jnp.acos"], + "acosh": ["math.acosh", "np.acosh", "jnp.acosh"], + "asin": ["math.asin", "np.asin", "jnp.asin"], + "asinh": ["math.asinh", "np.asinh", "jnp.asinh"], + "atan": ["math.atan", "np.atan", "jnp.atan"], + "atan2": ["math.atan2", "np.atan2", "jnp.atan2"], + "cos": ["math.cos", "np.cos", "jnp.cos"], + "cosh": ["math.cosh", "np.cosh", "jnp.cosh"], + "ceil": ["math.ceil", "np.ceil", "jnp.ceil"], + "degrees": ["math.degrees", "np.degrees", "jnp.degrees"], + "deg2rad": ["math.radians", "np.deg2rad", "jnp.deg2rad"], # Special handling for math + "e": ["math.e", "np.e", "jnp.e"], + "exp": ["math.exp", "np.exp", "jnp.exp"], + "fabs": ["math.fabs", "np.fabs", "jnp.fabs"], + "floor": ["math.floor", "np.floor", "jnp.floor"], + "log": ["math.log", "np.log", "jnp.log"], + "log10": ["math.log10", "np.log10", "jnp.log10"], + "log2": ["math.log2", "np.log2", "jnp.log2"], + "max": ["max", "np.max", "jnp.max"], # Special handling for math + "min": ["min", "np.min", "jnp.min"], # Special handling for math + "pi": ["math.pi", "np.pi", "jnp.pi"], + "pow": ["math.pow", "np.power", "jnp.power"], # Special handling for numpy + "power": ["math.pow", "np.power", "jnp.power"], # Special handling for math + "radians": ["math.radians", "np.radians", "jnp.radians"], + "rad2deg": ["math.degrees", "np.rad2deg", "jnp.rad2deg"], # Special handling for math + "sin": ["math.sin", "np.sin", "jnp.sin"], + "sinh": ["math.sinh", "np.sinh", "jnp.sinh"], + "sqrt": 
["math.sqrt", "np.sqrt", "jnp.sqrt"], + "tan": ["math.tan", "np.tan", "jnp.tan"], + "tanh": ["math.tanh", "np.tanh", "jnp.tanh"], + "trunc": ["math.trunc", "np.trunc", "jnp.trunc"], + } def __init__(self, expression, backend="numpy", node_label=None, **kwargs): - self.expression = expression - self.backend = backend - - # Check that all the functions are supported. - if backend == "math": - supported_funcs = dir(math) - supported_funcs.append("abs") - elif backend == "jax": - supported_funcs = dir(jnp) - elif backend == "numpy": - supported_funcs = dir(np) - else: + if backend not in ["jax", "math", "numpy"]: raise ValueError(f"Unsupported math backend {backend}") - - for fn_name in self._math_funcs: - if fn_name not in supported_funcs: - raise ValueError(f"Function {fn_name} is not supported by {backend}.") + self.backend = backend # Check the expression is pure math and translate it into the correct backend. - self._compile() + self.expression = expression + self._prepare(**kwargs) # Create a function from the expression. Note the expression has - # already been sanitized and validated via _compile(). + # already been sanitized and validated via _prepare(). def eval_func(**kwargs): return eval(self.expression, globals(), kwargs) super().__init__(eval_func, node_label=node_label, **kwargs) def __call__(self, **kwargs): - """Evaluate thge""" + """Evaluate the expression.""" return eval(self.expression, globals(), kwargs) - def _compile(self, **kwargs): - """Compile a python expression that consists of only basic math. + def _prepare(self, **kwargs): + """Rewrite a python expression that consists of only basic math to use + the prespecified math library. Sanitizes the string to prevent + arbitrary code execution. Parameters ---------- @@ -136,17 +156,28 @@ def _compile(self, **kwargs): if isinstance(node, self._supported_ast_nodes): # Nothing to do, this is a valid operation for the ast. 
continue - elif isinstance(node, (ast.FunctionType, ast.Call)): - func_name = node.func.id - if func_name not in self._math_funcs: - raise ValueError("Unsupported function {func_name}.") - - if self.backend == "numpy": - node.func.id = f"np.{func_name}" - elif self.backend == "jax": - node.func.id = f"jnp.{func_name}" - elif self.backend == "math" and "func_name" != "abs": - node.func.id = f"math.{func_name}" + elif isinstance(node, ast.Call): + # Check that function calls are only using items on the allow list. + if node.func.id not in self._math_map: + raise ValueError(f"Unsupported function {node.func.id}") + elif isinstance(node, ast.Name): + if node.id in kwargs: + # This is a user supplied variable. + continue + elif node.id in self._math_map: + # This is a math function or constant. Overwrite + if self.backend == "math": + node.id = self._math_map[node.id][0] + elif self.backend == "numpy": + node.id = self._math_map[node.id][1] + elif self.backend == "jax": + node.id = self._math_map[node.id][2] + else: + raise ValueError( + f"Unrecognized named variable or function {node.id}. " + "This could be because the function is not supported or " + "you forgot to include the variable as an argument." 
+ ) else: raise ValueError(f"Invalid part of expression {type(node)}") diff --git a/tests/tdastro/math_nodes/test_basic_math_node.py b/tests/tdastro/math_nodes/test_basic_math_node.py index d01ed368..8e5b36ea 100644 --- a/tests/tdastro/math_nodes/test_basic_math_node.py +++ b/tests/tdastro/math_nodes/test_basic_math_node.py @@ -34,6 +34,35 @@ def test_basic_math_node(): assert state["test"]["function_node_result"] == pytest.approx(math.pow(5.0, 2.5)) +def test_basic_math_node_special_cases(): + """Test that we can handle some of the special cases for a BasicMathNode.""" + node_a = SingleVariableNode("a", 180.0) + node = BasicMathNode("sin(deg2rad(x) + pi / 2.0)", x=node_a.a, node_label="test", backend="math") + state = node.sample_parameters() + assert state["test"]["function_node_result"] == pytest.approx(-1.0) + + +def test_basic_math_node_fail(): + """Test that we perform the needed checks for a math node.""" + # Imports not allowed + with pytest.raises(ValueError): + _ = BasicMathNode("import os") + + # Ifs not allowed (won't work with JAX) + with pytest.raises(ValueError): + _ = BasicMathNode("x if 1.0 else 1.0", x=2.0) + + # We only allow functions on the allow list. + with pytest.raises(ValueError): + _ = BasicMathNode("fake_delete_everything_no_confirm('./')") + with pytest.raises(ValueError): + _ = BasicMathNode("median(10, 20)") + + # All variables must be defined. 
+ with pytest.raises(ValueError): + _ = BasicMathNode("x + y", x=1.0) + + def test_basic_math_node_numpy(): """Test that we can perform computations via a BasicMathNode.""" node_a = SingleVariableNode("a", 10.0) From ecaf1f3b0a1401ee1b3bce8cbeee7a5a2fd59cb2 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 4 Nov 2024 11:03:24 -0500 Subject: [PATCH 3/3] Update src/tdastro/math_nodes/basic_math_node.py Co-authored-by: Melissa DeLucchi <113376043+delucchi-cmu@users.noreply.github.com> --- src/tdastro/math_nodes/basic_math_node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tdastro/math_nodes/basic_math_node.py b/src/tdastro/math_nodes/basic_math_node.py index dba52b72..55c40e8e 100644 --- a/src/tdastro/math_nodes/basic_math_node.py +++ b/src/tdastro/math_nodes/basic_math_node.py @@ -75,7 +75,7 @@ class BasicMathNode(FunctionNode): ast.Invert, ) - # A map from aa very limited set of supported math constant/function names to + # A map from a very limited set of supported math constant/function names to # the corresponding names in [math, numpy, jax]. This is needed because # a very few functions have different names in different libraries. _math_map = {