Skip to content

Commit

Permalink
Improve robustness
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Nov 4, 2024
1 parent 2c7d1b7 commit 1490e6a
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 67 deletions.
165 changes: 98 additions & 67 deletions src/tdastro/math_nodes/basic_math_node.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,42 @@
"""Nodes that perform basic math operations."""
"""Nodes that perform basic math operations that can be specified as strings.
The goal of this library is to save users from needing to create a bunch of
small FunctionNodes to perform basic math.
"""

import ast
import math

import jax.numpy as jnp
import numpy as np
# Disable unused import because we need all of these imported
# so they can be used during evaluation of the node.
import math # noqa: F401

import jax.numpy as jnp # noqa: F401
import numpy as np # noqa: F401

from tdastro.base_models import FunctionNode


class BasicMathNode(FunctionNode):
"""A node that evaluates basic mathematical functions.
The BasicMathNode wraps Python's eval() function to sanitize the input string
and thus prevent the execution of arbitrary code. It also allows the user to write
the expression once and execute using math, numpy, or JAX. The names of the
variables in the expression must match the input variables provided by kwargs.
Example:
my_node = BasicMathNode(
"redshift + 10.0 * sin(phase)",
redshift=host.redshift,
phase=source.phase,
)
Attributes
----------
expression : `str`
The expression to evaluate.
backend : `str`
The math libary to use. Must be one of: math, numpy, or jax.
tree : `ast.*`
The root node of the parsed syntax tree.
Parameters
----------
Expand All @@ -30,17 +47,20 @@ class BasicMathNode(FunctionNode):
node_label : `str`, optional
An identifier (or name) for the current node.
**kwargs : `dict`, optional
Any additional keyword arguments.
Any additional keyword arguments. Every variable in the expression
must be included as a kwarg.
"""

# A list of supported Python operations. Used to prevent eval from
# running arbitrary python expressions.
# running arbitrary python expressions. The Call and Name types are special
# cased so we can do checks and translations.
_supported_ast_nodes = (
ast.Module, # Top level object when parsed as exec.
ast.Expression, # Top level object when parsed as eval.
ast.Expr, # Math expressions.
ast.Constant, # Constant values.
ast.Name, # A named variable or function.
ast.Load, # Load a variable - must come from an approved function or variable.
ast.Store, # Store value - must come from an approved function or variable.
ast.BinOp, # Binary operations
ast.Add,
ast.Sub,
Expand All @@ -53,70 +73,70 @@ class BasicMathNode(FunctionNode):
ast.UAdd,
ast.USub,
ast.Invert,
# Call functions (but we do NOT include ast.Call because
# we need to special case that).
ast.Load,
ast.Store,
)

# A very limited set of math operations that are supported
# in all of the backends.
_math_funcs = set(
[
"abs",
"ceil",
"cos",
"cosh",
"degrees",
"exp",
"floor",
"log",
"log10",
"log2",
"radians",
"sin",
"sinh",
"sqrt",
"tan",
"tanh",
]
)
# A map from aa very limited set of supported math constant/function names to
# the corresponding names in [math, numpy, jax]. This is needed because
# a very few functions have different names in different libraries.
_math_map = {
"abs": ["abs", "np.abs", "jnp.abs"], # Special handling for math.
"acos": ["math.acos", "np.acos", "jnp.acos"],
"acosh": ["math.acosh", "np.acosh", "jnp.acosh"],
"asin": ["math.asin", "np.asin", "jnp.asin"],
"asinh": ["math.asinh", "np.asinh", "jnp.asinh"],
"atan": ["math.atan", "np.atan", "jnp.atan"],
"atan2": ["math.atan2", "np.atan2", "jnp.atan2"],
"cos": ["math.cos", "np.cos", "jnp.cos"],
"cosh": ["math.cosh", "np.cosh", "jnp.cosh"],
"ceil": ["math.ceil", "np.ceil", "jnp.ceil"],
"degrees": ["math.degrees", "np.degrees", "jnp.degrees"],
"deg2rad": ["math.radians", "np.deg2rad", "jnp.deg2rad"], # Special handling for math
"e": ["math.e", "np.e", "jnp.e"],
"exp": ["math.exp", "np.exp", "jnp.exp"],
"fabs": ["math.fabs", "np.fabs", "jnp.fabs"],
"floor": ["math.floor", "np.floor", "jnp.floor"],
"log": ["math.log", "np.log", "jnp.log"],
"log10": ["math.log10", "np.log10", "jnp.log10"],
"log2": ["math.log2", "np.log2", "jnp.log2"],
"max": ["max", "np.max", "jnp.max"], # Special handling for math
"min": ["min", "np.min", "jnp.min"], # Special handling for math
"pi": ["math.pi", "np.pi", "jnp.pi"],
"pow": ["math.pow", "np.power", "jnp.power"], # Special handling for numpy
"power": ["math.pow", "np.power", "jnp.power"], # Special handling for math
"radians": ["math.radians", "np.radians", "jnp.radians"],
"rad2deg": ["math.degrees", "np.rad2deg", "jnp.rad2deg"], # Special handling for math
"sin": ["math.sin", "np.sin", "jnp.sin"],
"sinh": ["math.sinh", "np.sinh", "jnp.sinh"],
"sqrt": ["math.sqrt", "np.sqrt", "jnp.sqrt"],
"tan": ["math.tan", "np.tan", "jnp.tan"],
"tanh": ["math.tanh", "np.tanh", "jnp.tanh"],
"trunc": ["math.trunc", "np.trunc", "jnp.trunc"],
}

def __init__(self, expression, backend="numpy", node_label=None, **kwargs):
self.expression = expression
self.backend = backend

# Check that all the functions are supported.
if backend == "math":
supported_funcs = dir(math)
supported_funcs.append("abs")
elif backend == "jax":
supported_funcs = dir(jnp)
elif backend == "numpy":
supported_funcs = dir(np)
else:
if backend not in ["jax", "math", "numpy"]:
raise ValueError(f"Unsupported math backend {backend}")

for fn_name in self._math_funcs:
if fn_name not in supported_funcs:
raise ValueError(f"Function {fn_name} is not supported by {backend}.")
self.backend = backend

# Check the expression is pure math and translate it into the correct backend.
self._compile()
self.expression = expression
self._prepare(**kwargs)

# Create a function from the expression. Note the expression has
# already been sanitized and validated via _compile().
# already been sanitized and validated via _prepare().
def eval_func(**kwargs):
return eval(self.expression, globals(), kwargs)

super().__init__(eval_func, node_label=node_label, **kwargs)

def __call__(self, **kwargs):
"""Evaluate thge"""
"""Evaluate the expression."""
return eval(self.expression, globals(), kwargs)

def _compile(self, **kwargs):
"""Compile a python expression that consists of only basic math.
def _prepare(self, **kwargs):
"""Rewrite a python expression that consists of only basic math to use
the prespecified math library. Santizes the string to prevent
arbitrary code execution.
Parameters
----------
Expand All @@ -136,17 +156,28 @@ def _compile(self, **kwargs):
if isinstance(node, self._supported_ast_nodes):
# Nothing to do, this is a valid operation for the ast.
continue
elif isinstance(node, (ast.FunctionType, ast.Call)):
func_name = node.func.id
if func_name not in self._math_funcs:
raise ValueError("Unsupported function {func_name}.")

if self.backend == "numpy":
node.func.id = f"np.{func_name}"
elif self.backend == "jax":
node.func.id = f"jnp.{func_name}"
elif self.backend == "math" and "func_name" != "abs":
node.func.id = f"math.{func_name}"
elif isinstance(node, ast.Call):
# Check that function calls are only using items on the allow list.
if node.func.id not in self._math_map:
raise ValueError(f"Unsupported function {node.func.id}")
elif isinstance(node, ast.Name):
if node.id in kwargs:
# This is a user supplied variable.
continue
elif node.id in self._math_map:
# This is a math function or constant. Overwrite
if self.backend == "math":
node.id = self._math_map[node.id][0]
elif self.backend == "numpy":
node.id = self._math_map[node.id][1]
elif self.backend == "jax":
node.id = self._math_map[node.id][2]
else:
raise ValueError(
f"Unrecognized named variable or function {node.id}. "
"This could be because the function is not supported or "
"you forgot to include the variable as an argument."
)
else:
raise ValueError(f"Invalid part of expression {type(node)}")

Expand Down
29 changes: 29 additions & 0 deletions tests/tdastro/math_nodes/test_basic_math_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,35 @@ def test_basic_math_node():
assert state["test"]["function_node_result"] == pytest.approx(math.pow(5.0, 2.5))


def test_basic_math_node_special_cases():
"""Test that we can handle some of the special cases for a BasicMathNode."""
node_a = SingleVariableNode("a", 180.0)
node = BasicMathNode("sin(deg2rad(x) + pi / 2.0)", x=node_a.a, node_label="test", backend="math")
state = node.sample_parameters()
assert state["test"]["function_node_result"] == pytest.approx(-1.0)


def test_basic_math_node_fail():
"""Test that we perform the needed checks for a math node."""
# Imports not allowed
with pytest.raises(ValueError):
_ = BasicMathNode("import os")

# Ifs not allowed (won't work with JAX)
with pytest.raises(ValueError):
_ = BasicMathNode("x if 1.0 else 1.0", x=2.0)

# We only allow functions on the allow list.
with pytest.raises(ValueError):
_ = BasicMathNode("fake_delete_everything_no_confirm('./')")
with pytest.raises(ValueError):
_ = BasicMathNode("median(10, 20)")

# All variables must be defined.
with pytest.raises(ValueError):
_ = BasicMathNode("x + y", x=1.0)


def test_basic_math_node_numpy():
"""Test that we can perform computations via a BasicMathNode."""
node_a = SingleVariableNode("a", 10.0)
Expand Down

0 comments on commit 1490e6a

Please sign in to comment.