From 741829f378db492ea43943515568a1ee42b292fb Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 21 May 2024 19:56:25 -0300 Subject: [PATCH] feat: update MLIR notation to latest stablehlo spec" (#1488) --- exla/lib/exla/mlir/module.ex | 2 ++ exla/lib/exla/mlir/value.ex | 54 +++++++++++++++++++----------------- 2 files changed, 31 insertions(+), 25 deletions(-) diff --git a/exla/lib/exla/mlir/module.ex b/exla/lib/exla/mlir/module.ex index 3c91024c32..1005d15ed9 100644 --- a/exla/lib/exla/mlir/module.ex +++ b/exla/lib/exla/mlir/module.ex @@ -98,6 +98,8 @@ defmodule EXLA.MLIR.Module do do: -1, else: Keyword.get(options, :device_id, client.default_device_id) + # module.ref |> EXLA.NIF.mlir_module_to_string() |> elem(1) |> IO.puts() + ref = EXLA.NIF.mlir_compile( client.ref, diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index b469048369..b5f853b2da 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -186,13 +186,13 @@ defmodule EXLA.MLIR.Value do def reverse(%Value{function: func} = operand, dims, typespec) do result_types = typespecs_to_mlir_types([typespec]) - attributes = [dimensions: attr_dense_i64_elements(dims)] + attributes = [dimensions: attr_array_i64_elements(dims)] op(func, "stablehlo.reverse", [operand], result_types, attributes: attributes) |> one!() end def transpose(%Value{function: func} = operand, axes, typespec) do result_types = typespecs_to_mlir_types([typespec]) - attributes = [permutation: attr_dense_i64_elements(axes)] + attributes = [permutation: attr_array_i64_elements(axes)] op(func, "stablehlo.transpose", [operand], result_types, attributes: attributes) |> one!() end @@ -200,9 +200,9 @@ defmodule EXLA.MLIR.Value do result_types = typespecs_to_mlir_types([typespec]) attributes = [ - start_indices: attr_dense_i64_elements(starts), - limit_indices: attr_dense_i64_elements(limits), - strides: attr_dense_i64_elements(strides) + start_indices: attr_array_i64_elements(starts), + limit_indices: attr_array_i64_elements(limits), + strides: attr_array_i64_elements(strides) ] op(func, "stablehlo.slice", [operand], result_types, attributes: attributes) |> one!() @@ -211,7 +211,7 @@ defmodule EXLA.MLIR.Value do def dynamic_slice(%Value{function: func} = operand, starts, lengths, typespec) do result_types = typespecs_to_mlir_types([typespec]) operands = [operand] ++ starts - attributes = [slice_sizes: attr_dense_i64_elements(lengths)] + attributes = [slice_sizes: attr_array_i64_elements(lengths)] op(func, "stablehlo.dynamic_slice", operands, result_types, attributes: attributes) |> one!() end @@ -303,7 +303,7 @@ defmodule EXLA.MLIR.Value do result_types = typespecs_to_mlir_types([typespec]) attributes = [ - broadcast_dimensions: attr_dense_i64_elements(axes) + broadcast_dimensions: attr_array_i64_elements(axes) ] op(func, "stablehlo.broadcast_in_dim", [operand], result_types, attributes: attributes) @@ -347,9 +347,9 @@ defmodule EXLA.MLIR.Value do {padding_low, padding_high, padding_mid} = unzip_padding_config(padding_config) attributes = [ - edge_padding_low: attr_dense_i64_elements(padding_low), - edge_padding_high: attr_dense_i64_elements(padding_high), - interior_padding: attr_dense_i64_elements(padding_mid) + edge_padding_low: attr_array_i64_elements(padding_low), + edge_padding_high: attr_array_i64_elements(padding_high), + interior_padding: attr_array_i64_elements(padding_mid) ] op(func, "stablehlo.pad", [operand, pad], result_types, attributes: attributes) |> one!() @@ -375,7 +375,7 @@ defmodule EXLA.MLIR.Value do attributes = [ fft_type: fft_type, - fft_length: attr_dense_i64_elements(List.wrap(fft_length)) + fft_length: attr_array_i64_elements(List.wrap(fft_length)) ] op(func, "stablehlo.fft", [value], result_types, attributes: attributes) |> one!() @@ -451,8 +451,8 @@ defmodule EXLA.MLIR.Value do result_types = typespecs_to_mlir_types([typespec]) attributes = [ - window_dimensions: attr_dense_i64_elements(window_dimensions), - window_strides: attr_dense_i64_elements(window_strides), + window_dimensions: attr_array_i64_elements(window_dimensions), + window_strides: attr_array_i64_elements(window_strides), padding: attr_padding(padding) ] @@ -501,7 +501,7 @@ defmodule EXLA.MLIR.Value do attributes = [ dimension_numbers: dimension_numbers, - slice_sizes: attr_dense_i64_elements(slice_sizes), + slice_sizes: attr_array_i64_elements(slice_sizes), indices_are_sorted: attr_boolean(false) ] @@ -546,10 +546,10 @@ defmodule EXLA.MLIR.Value do attr_precision_config = attr_precision_config(precision_config) attributes = [ - window_strides: attr_dense_i64_elements(strides), + window_strides: attr_array_i64_elements(strides), padding: attr_padding(padding), - lhs_dilation: attr_dense_i64_elements(input_dilation), - rhs_dilation: attr_dense_i64_elements(kernel_dilation), + lhs_dilation: attr_array_i64_elements(input_dilation), + rhs_dilation: attr_array_i64_elements(kernel_dilation), dimension_numbers: attr_conv_dimension_numbers(dimension_numbers), feature_group_count: attr_i64(feature_group_count), batch_group_count: attr_i64(batch_group_count), @@ -625,7 +625,7 @@ defmodule EXLA.MLIR.Value do ) do operands = inputs ++ init_values result_types = typespecs_to_mlir_types(typespecs) - attributes = [dimensions: attr_dense_i64_elements(dimensions)] + attributes = [dimensions: attr_array_i64_elements(dimensions)] regions = [reducer] op(func, "stablehlo.reduce", operands, result_types, attributes: attributes, regions: regions) end @@ -645,10 +645,10 @@ defmodule EXLA.MLIR.Value do result_types = typespecs_to_mlir_types(typespecs) attributes = [ - window_dimensions: attr_dense_i64_elements(window_dimensions), - window_strides: attr_dense_i64_elements(window_strides), - base_dilations: attr_dense_i64_elements(input_dilations), - window_dilations: attr_dense_i64_elements(window_dilations), + window_dimensions: attr_array_i64_elements(window_dimensions), + window_strides: attr_array_i64_elements(window_strides), + base_dilations: attr_array_i64_elements(input_dilations), + window_dilations: attr_array_i64_elements(window_dilations), padding: attr_padding(padding) ] @@ -669,7 +669,7 @@ defmodule EXLA.MLIR.Value do result_types = typespecs_to_mlir_types([typespec]) attributes = [ - dimensions: attr_dense_i64_elements(dimensions) + dimensions: attr_array_i64_elements(dimensions) ] regions = [mapper] @@ -904,8 +904,12 @@ defmodule EXLA.MLIR.Value do <> end - defp attr_dense_i64_elements(list) do - attr_dense_elements(list, {:s, 64}, {length(list)}) + defp attr_array_i64_elements([]) do + "array" + end + + defp attr_array_i64_elements(list) do + "array" end defp attr_dense_elements([], type, {0} = shape) do