Skip to content

Commit

Permalink
fix: Fix defining function in case some of them evaluate to constant (#…
Browse files Browse the repository at this point in the history
…144)

* Add test triggering #143

* Fix typo in test

* Fix root cause of the issue - invalid usage of identity_for_numbers

* Add a warning to the docstring to avoid incorrect usage of
identity_for_numbers

* Add comment about the function g
  • Loading branch information
dexter2206 authored Nov 26, 2024
1 parent 3f2fff3 commit dcc5e05
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 21 deletions.
50 changes: 29 additions & 21 deletions src/bartiq/symbolics/sympy_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ def _inner(backend: SympyBackend, expr: TExpr[S], *args: P.args, **kwargs: P.kwa


def identity_for_numbers(func: ExprTransformer[P, T | Number]) -> TExprTransformer[P, T | Number]:
"""Return a new method that preserves originally passed one on expressions and acts as identity on numbers.
Note:
This function can ONLY be used on methods of SympyBackend class.
If you want to use it on a function, add dummy `_backend` parameter as a first arg - but do know
that this is discouraged. Incorrect usage of this decorator on an ordinary function resulted
in an obscure bug: https://github.com/PsiQ/bartiq/issues/143
"""

def _inner(backend: SympyBackend, expr: TExpr[S], *args: P.args, **kwargs: P.kwargs) -> T | Number:
return expr if isinstance(expr, Number) else func(backend, expr, *args, **kwargs)

Expand All @@ -86,26 +95,6 @@ def parse_to_sympy(expression: str, debug: bool = False) -> Expr:
return parse(expression, interpreter=SympyInterpreter(debug=debug))


@identity_for_numbers
def _define_function(expr: Expr, func_name: str, function: Callable) -> TExpr[Expr]:
"""Define an undefined function."""
# Catch attempt to define special function names
if func_name in BUILT_IN_FUNCTIONS:
raise BartiqCompilationError(
f"Attempted to redefine the special function {func_name}; cannot define special functions."
)

# Trying to evaluate a function which cannot be evaluated symbolically raises TypeError.
# This, however, is expected for certain functions (e.g. with conditions)
try:
return expr.replace(
lambda pattern: isinstance(pattern, SYMPY_USER_FUNCTION_TYPES) and str(type(pattern)) == func_name,
lambda match: function(*match.args),
)
except TypeError:
return expr


class SympyBackend:

def __init__(self, parse_function: Callable[[str], Expr] = parse_to_sympy):
Expand Down Expand Up @@ -179,9 +168,28 @@ def substitute(
if functions_map is None:
functions_map = {}
for func_name, func in functions_map.items():
expr = _define_function(expr, func_name, func)
expr = self._define_function(expr, func_name, func)
return value if (value := self.value_of(expr)) is not None else expr

@identity_for_numbers
def _define_function(self, expr: Expr, func_name: str, function: Callable) -> TExpr[Expr]:
"""Define an undefined function."""
# Catch attempt to define special function names
if func_name in BUILT_IN_FUNCTIONS:
raise BartiqCompilationError(
f"Attempted to redefine the special function {func_name}; cannot define special functions."
)

# Trying to evaluate a function which cannot be evaluated symbolically raises TypeError.
# This, however, is expected for certain functions (e.g. with conditions)
try:
return expr.replace(
lambda pattern: isinstance(pattern, SYMPY_USER_FUNCTION_TYPES) and str(type(pattern)) == func_name,
lambda match: function(*match.args),
)
except TypeError:
return expr

def is_constant_int(self, expr: TExpr[Expr]):
"""Return True if a given expression represents a constant int and False otherwise."""
try:
Expand Down
20 changes: 20 additions & 0 deletions tests/symbolics/test_sympy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,23 @@ def test_functions_obtained_from_backend_can_be_called_to_obtain_new_expressions
result = func(arg)

assert backend.as_native(result) == expected_native_result


@pytest.mark.parametrize(
"expression_str, variables, functions, expected_native_result",
[
("8", {}, {"f": lambda x: x + 10}, 8),
# Note: existence of function "g" is crucial in the examples below, even if it is not used explicitly
# These examples triggered issue #143: https://github.com/PsiQ/bartiq/issues/143
("f(x)", {"x": 5}, {"f": lambda x: int(x) + 5, "g": lambda x: x}, 10),
("f(x)+10", {}, {"f": lambda x: 2, "g": lambda x: int(x) ** 2}, 12),
],
)
def test_function_definition_succeeds_even_if_expression_becomes_constant(
expression_str, variables, functions, expected_native_result, backend
):
expr = backend.as_expression(expression_str)

new_expr = backend.substitute(expr, variables, functions)

assert backend.as_native(new_expr) == expected_native_result

0 comments on commit dcc5e05

Please sign in to comment.