Skip to content

Commit

Permalink
Simplify conditional (#340)
Browse files Browse the repository at this point in the history
* Simplify conditional

* Fix bug in ArityMismatch message

* Comments
  • Loading branch information
pbrubeck authored Jan 21, 2025
1 parent 7d7c676 commit a4e9408
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 6 deletions.
40 changes: 40 additions & 0 deletions test/test_check_arities.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
TestFunction,
TrialFunction,
adjoint,
as_tensor,
cofac,
conditional,
conj,
derivative,
ds,
Expand Down Expand Up @@ -84,3 +86,41 @@ def test_product_arity():
with pytest.raises(ArityMismatch):
L = inner(v, v) * dx
compute_form_data(L, complex_mode=False)


def test_zero_simplify_arity():
"""
Test that adding verious zero-like expressions to a form is simplified,
such that one can compute form data for the integral.
"""
cell = tetrahedron
D = Mesh(FiniteElement("Lagrange", cell, 1, (3,), identity_pullback, H1))
V = FunctionSpace(D, FiniteElement("Lagrange", cell, 2, (), identity_pullback, H1))
v = TestFunction(V)
u = Coefficient(V)

nonzero = 1
with pytest.raises(ArityMismatch):
F = inner(u, v + nonzero) * dx
compute_form_data(F)
z = Coefficient(V)

# Add a Zero-component (rank-0) of a tensor to a rank-1 tensor
zero = as_tensor([0, z])[0]
F = inner(u, v + zero) * dx
fd = compute_form_data(F)
assert fd.num_coefficients == 1

# Add a conditional that should have been simplified to zero (rank-0)
# to a rank-1 tensor
zero = conditional(z < 0, 0, 0)
F = inner(u, v + zero) * dx
fd = compute_form_data(F)
assert fd.num_coefficients == 1

# Check that nested zero conditionals are simplifed to zero (rank-0)
# and can be added to a rank-1 tensor
zero = conditional(z < 0, 0, conditional(z == 0, 0, 0))
F = inner(u, v + zero) * dx
fd = compute_form_data(F)
assert fd.num_coefficients == 1
9 changes: 5 additions & 4 deletions ufl/algorithms/check_arities.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def sum(self, o, a, b):
"""Apply to sum."""
if a != b:
raise ArityMismatch(
f"Adding expressions with non-matching form arguments {_afmt(a)} vs {_afmt(b)}."
f"Adding expressions with non-matching form arguments "
f"{tuple(map(_afmt, a))} vs {tuple(map(_afmt, b))}."
)
return a

Expand Down Expand Up @@ -86,7 +87,7 @@ def product(self, o, a, b):
if len(c) != len(a) + len(b) or len(c) != len({x[0] for x in c}):
raise ArityMismatch(
"Multiplying expressions with overlapping form arguments "
f"{_afmt(a)} vs {_afmt(b)}."
f"{tuple(map(_afmt, a))} vs {tuple(map(_afmt, b))}."
)
# It's fine for argument parts to overlap
return c
Expand Down Expand Up @@ -138,7 +139,7 @@ def variable(self, o, f, a):
def conditional(self, o, c, a, b):
"""Apply to conditional."""
if c:
raise ArityMismatch(f"Condition cannot depend on form arguments ({_afmt(a)}).")
raise ArityMismatch("Condition cannot depend on form arguments.")
if a and isinstance(o.ufl_operands[2], Zero):
# Allow conditional(c, arg, 0)
return a
Expand All @@ -153,7 +154,7 @@ def conditional(self, o, c, a, b):
# conditional(c, test, nonzeroconstant)
raise ArityMismatch(
"Conditional subexpressions with non-matching form arguments "
f"{_afmt(a)} vs {_afmt(b)}."
f"{tuple(map(_afmt, a))} vs {tuple(map(_afmt, b))}."
)

def linear_indexed_type(self, o, a, i):
Expand Down
17 changes: 15 additions & 2 deletions ufl/conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,23 @@ class Conditional(Operator):
In C++ these take the format `(condition ? true_value : false_value)`.
"""

__slots__ = ()
__slots__ = ("_initialised",)

def __new__(cls, condition, true_value, false_value):
"""Create a new Conditional."""
# Simplify
if bool(true_value == false_value):
return true_value
# Construct a new instance to be initialised
self = Operator.__new__(cls)
self._initialised = False
return self

def __init__(self, condition, true_value, false_value):
"""Initialise."""
if self._initialised:
return
# Checks
if not isinstance(condition, Condition):
raise ValueError("Expecting condition as first argument.")
true_value = as_ufl(true_value)
Expand All @@ -290,8 +303,8 @@ def __init__(self, condition, true_value, false_value):
)
):
raise ValueError("Non-scalar == or != is not allowed.")

Operator.__init__(self, (condition, true_value, false_value))
self._initialised = True

def evaluate(self, x, mapping, component, index_values):
"""Evaluate."""
Expand Down

0 comments on commit a4e9408

Please sign in to comment.