-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
421 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import numpy as np | ||
|
||
|
||
def phot_eff_function(snr): | ||
""" | ||
Photometric detection efficiency as a simple step function of snr. | ||
Parameters | ||
---------- | ||
snr: `list` or `numpy.ndarray` | ||
Signal to noise ratio of a list of observations. | ||
Returns | ||
------- | ||
eff: `list` or `numpy.ndarray` | ||
The photometric detection efficiency given snr. | ||
""" | ||
|
||
snr = np.array(snr) | ||
eff = np.where(snr > 5, 1.0, 0.0) | ||
|
||
return eff | ||
|
||
|
||
def spec_eff_function(peak_imag): | ||
""" | ||
Spectroscopic follow-up efficiency as a function of peak i band magnitude. | ||
Parameters | ||
---------- | ||
peak_imag: `list` or `numpy.ndarray` | ||
Peak magnitude in i band. | ||
Returns | ||
------- | ||
eff: `list` or `numpy.ndarray` | ||
The spectroscopic efficiency given peak i band magnitude. | ||
Based on Equation (17) in Kessler et al. 2019 | ||
s0, s1, s2 are fitted using data from Figure 4 in Kessler et al. 2019 | ||
""" | ||
|
||
s0 = 1.0 | ||
s1 = 2.36 | ||
s2 = 51.9 | ||
|
||
peak_imag = np.array(peak_imag) | ||
eff = s0 * np.power((1.0 + np.exp(s1 * peak_imag - s2)), -1) | ||
|
||
return eff |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,191 @@ | ||
"""Nodes that perform basic math operations that can be specified as strings. | ||
The goal of this library is to save users from needing to create a bunch of | ||
small FunctionNodes to perform basic math. | ||
""" | ||
|
||
import ast | ||
|
||
# Disable unused import because we need all of these imported | ||
# so they can be used during evaluation of the node. | ||
import math # noqa: F401 | ||
|
||
import jax.numpy as jnp # noqa: F401 | ||
import numpy as np # noqa: F401 | ||
|
||
from tdastro.base_models import FunctionNode | ||
|
||
|
||
class BasicMathNode(FunctionNode): | ||
"""A node that evaluates basic mathematical functions. | ||
The BasicMathNode wraps Python's eval() function to sanitize the input string | ||
and thus prevent the execution of arbitrary code. It also allows the user to write | ||
the expression once and execute using math, numpy, or JAX. The names of the | ||
variables in the expression must match the input variables provided by kwargs. | ||
Example: | ||
my_node = BasicMathNode( | ||
"redshift + 10.0 * sin(phase)", | ||
redshift=host.redshift, | ||
phase=source.phase, | ||
) | ||
Attributes | ||
---------- | ||
expression : `str` | ||
The expression to evaluate. | ||
backend : `str` | ||
The math libary to use. Must be one of: math, numpy, or jax. | ||
Parameters | ||
---------- | ||
expression : `str` | ||
The expression to evaluate. | ||
backend : `str` | ||
The math libary to use. Must be one of: math, numpy, or jax. | ||
node_label : `str`, optional | ||
An identifier (or name) for the current node. | ||
**kwargs : `dict`, optional | ||
Any additional keyword arguments. Every variable in the expression | ||
must be included as a kwarg. | ||
""" | ||
|
||
# A list of supported Python operations. Used to prevent eval from | ||
# running arbitrary python expressions. The Call and Name types are special | ||
# cased so we can do checks and translations. | ||
_supported_ast_nodes = ( | ||
ast.Module, # Top level object when parsed as exec. | ||
ast.Expression, # Top level object when parsed as eval. | ||
ast.Expr, # Math expressions. | ||
ast.Constant, # Constant values. | ||
ast.Load, # Load a variable - must come from an approved function or variable. | ||
ast.Store, # Store value - must come from an approved function or variable. | ||
ast.BinOp, # Binary operations | ||
ast.Add, | ||
ast.Sub, | ||
ast.Mult, | ||
ast.Div, | ||
ast.FloorDiv, | ||
ast.Mod, | ||
ast.Pow, | ||
ast.UnaryOp, # Uninary operations | ||
ast.UAdd, | ||
ast.USub, | ||
ast.Invert, | ||
) | ||
|
||
# A map from a very limited set of supported math constant/function names to | ||
# the corresponding names in [math, numpy, jax]. This is needed because | ||
# a very few functions have different names in different libraries. | ||
_math_map = { | ||
"abs": ["abs", "np.abs", "jnp.abs"], # Special handling for math. | ||
"acos": ["math.acos", "np.acos", "jnp.acos"], | ||
"acosh": ["math.acosh", "np.acosh", "jnp.acosh"], | ||
"asin": ["math.asin", "np.asin", "jnp.asin"], | ||
"asinh": ["math.asinh", "np.asinh", "jnp.asinh"], | ||
"atan": ["math.atan", "np.atan", "jnp.atan"], | ||
"atan2": ["math.atan2", "np.atan2", "jnp.atan2"], | ||
"cos": ["math.cos", "np.cos", "jnp.cos"], | ||
"cosh": ["math.cosh", "np.cosh", "jnp.cosh"], | ||
"ceil": ["math.ceil", "np.ceil", "jnp.ceil"], | ||
"degrees": ["math.degrees", "np.degrees", "jnp.degrees"], | ||
"deg2rad": ["math.radians", "np.deg2rad", "jnp.deg2rad"], # Special handling for math | ||
"e": ["math.e", "np.e", "jnp.e"], | ||
"exp": ["math.exp", "np.exp", "jnp.exp"], | ||
"fabs": ["math.fabs", "np.fabs", "jnp.fabs"], | ||
"floor": ["math.floor", "np.floor", "jnp.floor"], | ||
"log": ["math.log", "np.log", "jnp.log"], | ||
"log10": ["math.log10", "np.log10", "jnp.log10"], | ||
"log2": ["math.log2", "np.log2", "jnp.log2"], | ||
"max": ["max", "np.max", "jnp.max"], # Special handling for math | ||
"min": ["min", "np.min", "jnp.min"], # Special handling for math | ||
"pi": ["math.pi", "np.pi", "jnp.pi"], | ||
"pow": ["math.pow", "np.power", "jnp.power"], # Special handling for numpy | ||
"power": ["math.pow", "np.power", "jnp.power"], # Special handling for math | ||
"radians": ["math.radians", "np.radians", "jnp.radians"], | ||
"rad2deg": ["math.degrees", "np.rad2deg", "jnp.rad2deg"], # Special handling for math | ||
"sin": ["math.sin", "np.sin", "jnp.sin"], | ||
"sinh": ["math.sinh", "np.sinh", "jnp.sinh"], | ||
"sqrt": ["math.sqrt", "np.sqrt", "jnp.sqrt"], | ||
"tan": ["math.tan", "np.tan", "jnp.tan"], | ||
"tanh": ["math.tanh", "np.tanh", "jnp.tanh"], | ||
"trunc": ["math.trunc", "np.trunc", "jnp.trunc"], | ||
} | ||
|
||
def __init__(self, expression, backend="numpy", node_label=None, **kwargs): | ||
if backend not in ["jax", "math", "numpy"]: | ||
raise ValueError(f"Unsupported math backend {backend}") | ||
self.backend = backend | ||
|
||
# Check the expression is pure math and translate it into the correct backend. | ||
self.expression = expression | ||
self._prepare(**kwargs) | ||
|
||
# Create a function from the expression. Note the expression has | ||
# already been sanitized and validated via _prepare(). | ||
def eval_func(**kwargs): | ||
try: | ||
return eval(self.expression, globals(), kwargs) | ||
except Exception as problem: | ||
# Provide more detailed logging, including the expression and parameters | ||
# used, when we encounter a math error like divide by zero. | ||
new_message = f"Error during math operation '{self.expression}' with args={kwargs}" | ||
raise type(problem)(new_message) from problem | ||
|
||
super().__init__(eval_func, node_label=node_label, **kwargs) | ||
|
||
def __call__(self, **kwargs): | ||
"""Evaluate the expression.""" | ||
return eval(self.expression, globals(), kwargs) | ||
|
||
def _prepare(self, **kwargs): | ||
"""Rewrite a python expression that consists of only basic math to use | ||
the prespecified math library. Santizes the string to prevent | ||
arbitrary code execution. | ||
Parameters | ||
---------- | ||
**kwargs : `dict`, optional | ||
Any additional keyword arguments, including the variable | ||
assignments. | ||
Returns | ||
------- | ||
tree : `ast.*` | ||
The root node of the parsed syntax tree. | ||
""" | ||
tree = ast.parse(self.expression) | ||
|
||
# Walk the tree and confirm that it only contains the basic math. | ||
for node in ast.walk(tree): | ||
if isinstance(node, self._supported_ast_nodes): | ||
# Nothing to do, this is a valid operation for the ast. | ||
continue | ||
elif isinstance(node, ast.Call): | ||
# Check that function calls are only using items on the allow list. | ||
if node.func.id not in self._math_map: | ||
raise ValueError(f"Unsupported function {node.func.id}") | ||
elif isinstance(node, ast.Name): | ||
if node.id in kwargs: | ||
# This is a user supplied variable. | ||
continue | ||
elif node.id in self._math_map: | ||
# This is a math function or constant. Overwrite | ||
if self.backend == "math": | ||
node.id = self._math_map[node.id][0] | ||
elif self.backend == "numpy": | ||
node.id = self._math_map[node.id][1] | ||
elif self.backend == "jax": | ||
node.id = self._math_map[node.id][2] | ||
else: | ||
raise ValueError( | ||
f"Unrecognized named variable or function {node.id}. " | ||
"This could be because the function is not supported or " | ||
"you forgot to include the variable as an argument." | ||
) | ||
else: | ||
raise ValueError(f"Invalid part of expression {type(node)}") | ||
|
||
# Convert the expression back into a string. | ||
self.expression = ast.unparse(tree) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import numpy as np | ||
from tdastro.astro_utils.obs_utils import phot_eff_function, spec_eff_function | ||
|
||
|
||
def test_phot_eff_function(): | ||
""" | ||
test that the phot_eff_function returns correct values. | ||
""" | ||
|
||
snr = [1.0, 3.0, 10.0, 100.0] | ||
eff = phot_eff_function(snr) | ||
expected_eff = [0.0, 0.0, 1.0, 1.0] | ||
|
||
np.testing.assert_allclose(eff, expected_eff) | ||
|
||
|
||
def test_spec_eff_function(): | ||
""" | ||
test the spec efficiency function using data extracted from Figure 4 in Kessler et al. 2019 | ||
""" | ||
|
||
imags = [20.01, 21.09, 21.67, 22.54] | ||
expected_eff = [1.0, 0.9, 0.67, 0.2] | ||
eff = spec_eff_function(imags) | ||
|
||
np.testing.assert_allclose(eff, expected_eff, atol=0.02) |
Oops, something went wrong.