diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 06b14fd099..10b9c95f52 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -64,15 +64,20 @@ defmodule EXLA.MLIR.Value do %{type: rhs_type} = get_typespec(rhs) comparison_type = - if Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) do - attr_comparison_type(:totalorder) - else - attr_comparison_type(:notype) + cond do + Nx.Type.complex?(lhs_type) or Nx.Type.complex?(rhs_type) -> + attr_comparison_type(:float) + + Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) -> + attr_comparison_type(:float) + + true -> + attr_comparison_type(:notype) end attributes = [ comparison_direction: attr_comparison_direction(direction), - comparison_type: comparison_type + compare_type: comparison_type ] result_types = typespecs_to_mlir_types([Typespec.to_type(typespec, {:pred, 8})]) @@ -929,7 +934,7 @@ defmodule EXLA.MLIR.Value do defp attr_comparison_direction(value) when value in [:eq, :lt, :le, :gt, :ge, :ne], do: attr_enum("stablehlo", "comparison_direction", value) - defp attr_comparison_type(value) when value in [:totalorder, :notype], + defp attr_comparison_type(value) when value in [:float, :totalorder, :notype], do: attr_enum("stablehlo", "comparison_type", value) defp attr_precision(value) when value in [:default, :high, :highest],