Skip to content

Commit

Permalink
Properly set comparison type attribute on MLIR comparison ops (#1502)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Jun 4, 2024
1 parent bb61c58 commit 38bc042
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,20 @@ defmodule EXLA.MLIR.Value do
%{type: rhs_type} = get_typespec(rhs)

comparison_type =
if Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) do
attr_comparison_type(:totalorder)
else
attr_comparison_type(:notype)
cond do
Nx.Type.complex?(lhs_type) or Nx.Type.complex?(rhs_type) ->
attr_comparison_type(:float)

Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) ->
attr_comparison_type(:float)

true ->
attr_comparison_type(:notype)
end

attributes = [
comparison_direction: attr_comparison_direction(direction),
comparison_type: comparison_type
compare_type: comparison_type
]

result_types = typespecs_to_mlir_types([Typespec.to_type(typespec, {:pred, 8})])
Expand Down Expand Up @@ -929,7 +934,7 @@ defmodule EXLA.MLIR.Value do
defp attr_comparison_direction(value) when value in [:eq, :lt, :le, :gt, :ge, :ne],
do: attr_enum("stablehlo", "comparison_direction", value)

defp attr_comparison_type(value) when value in [:totalorder, :notype],
defp attr_comparison_type(value) when value in [:float, :totalorder, :notype],
do: attr_enum("stablehlo", "comparison_type", value)

defp attr_precision(value) when value in [:default, :high, :highest],
Expand Down

0 comments on commit 38bc042

Please sign in to comment.