From 38bc042174525022b023a48cc11ab31a4157fa54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 4 Jun 2024 15:43:26 +0200 Subject: [PATCH] Properly set comparison type attribute on MLIR comparison ops (#1502) --- exla/lib/exla/mlir/value.ex | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 06b14fd099..10b9c95f52 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -64,15 +64,20 @@ defmodule EXLA.MLIR.Value do %{type: rhs_type} = get_typespec(rhs) comparison_type = - if Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) do - attr_comparison_type(:totalorder) - else - attr_comparison_type(:notype) + cond do + Nx.Type.complex?(lhs_type) or Nx.Type.complex?(rhs_type) -> + attr_comparison_type(:float) + + Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) -> + attr_comparison_type(:float) + + true -> + attr_comparison_type(:notype) end attributes = [ comparison_direction: attr_comparison_direction(direction), - comparison_type: comparison_type + compare_type: comparison_type ] result_types = typespecs_to_mlir_types([Typespec.to_type(typespec, {:pred, 8})]) @@ -929,7 +934,7 @@ defmodule EXLA.MLIR.Value do defp attr_comparison_direction(value) when value in [:eq, :lt, :le, :gt, :ge, :ne], do: attr_enum("stablehlo", "comparison_direction", value) - defp attr_comparison_type(value) when value in [:totalorder, :notype], + defp attr_comparison_type(value) when value in [:float, :totalorder, :notype], do: attr_enum("stablehlo", "comparison_type", value) defp attr_precision(value) when value in [:default, :high, :highest],