Skip to content

Commit

Permalink
feat: update MLIR notation to latest stablehlo spec" (#1488)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored May 21, 2024
1 parent e26112d commit 741829f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 25 deletions.
2 changes: 2 additions & 0 deletions exla/lib/exla/mlir/module.ex
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ defmodule EXLA.MLIR.Module do
do: -1,
else: Keyword.get(options, :device_id, client.default_device_id)

# module.ref |> EXLA.NIF.mlir_module_to_string() |> elem(1) |> IO.puts()

ref =
EXLA.NIF.mlir_compile(
client.ref,
Expand Down
54 changes: 29 additions & 25 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -186,23 +186,23 @@ defmodule EXLA.MLIR.Value do

def reverse(%Value{function: func} = operand, dims, typespec) do
result_types = typespecs_to_mlir_types([typespec])
attributes = [dimensions: attr_dense_i64_elements(dims)]
attributes = [dimensions: attr_array_i64_elements(dims)]
op(func, "stablehlo.reverse", [operand], result_types, attributes: attributes) |> one!()
end

def transpose(%Value{function: func} = operand, axes, typespec) do
result_types = typespecs_to_mlir_types([typespec])
attributes = [permutation: attr_dense_i64_elements(axes)]
attributes = [permutation: attr_array_i64_elements(axes)]
op(func, "stablehlo.transpose", [operand], result_types, attributes: attributes) |> one!()
end

def slice(%Value{function: func} = operand, starts, limits, strides, typespec) do
result_types = typespecs_to_mlir_types([typespec])

attributes = [
start_indices: attr_dense_i64_elements(starts),
limit_indices: attr_dense_i64_elements(limits),
strides: attr_dense_i64_elements(strides)
start_indices: attr_array_i64_elements(starts),
limit_indices: attr_array_i64_elements(limits),
strides: attr_array_i64_elements(strides)
]

op(func, "stablehlo.slice", [operand], result_types, attributes: attributes) |> one!()
Expand All @@ -211,7 +211,7 @@ defmodule EXLA.MLIR.Value do
def dynamic_slice(%Value{function: func} = operand, starts, lengths, typespec) do
result_types = typespecs_to_mlir_types([typespec])
operands = [operand] ++ starts
attributes = [slice_sizes: attr_dense_i64_elements(lengths)]
attributes = [slice_sizes: attr_array_i64_elements(lengths)]
op(func, "stablehlo.dynamic_slice", operands, result_types, attributes: attributes) |> one!()
end

Expand Down Expand Up @@ -303,7 +303,7 @@ defmodule EXLA.MLIR.Value do
result_types = typespecs_to_mlir_types([typespec])

attributes = [
broadcast_dimensions: attr_dense_i64_elements(axes)
broadcast_dimensions: attr_array_i64_elements(axes)
]

op(func, "stablehlo.broadcast_in_dim", [operand], result_types, attributes: attributes)
Expand Down Expand Up @@ -347,9 +347,9 @@ defmodule EXLA.MLIR.Value do
{padding_low, padding_high, padding_mid} = unzip_padding_config(padding_config)

attributes = [
edge_padding_low: attr_dense_i64_elements(padding_low),
edge_padding_high: attr_dense_i64_elements(padding_high),
interior_padding: attr_dense_i64_elements(padding_mid)
edge_padding_low: attr_array_i64_elements(padding_low),
edge_padding_high: attr_array_i64_elements(padding_high),
interior_padding: attr_array_i64_elements(padding_mid)
]

op(func, "stablehlo.pad", [operand, pad], result_types, attributes: attributes) |> one!()
Expand All @@ -375,7 +375,7 @@ defmodule EXLA.MLIR.Value do

attributes = [
fft_type: fft_type,
fft_length: attr_dense_i64_elements(List.wrap(fft_length))
fft_length: attr_array_i64_elements(List.wrap(fft_length))
]

op(func, "stablehlo.fft", [value], result_types, attributes: attributes) |> one!()
Expand Down Expand Up @@ -451,8 +451,8 @@ defmodule EXLA.MLIR.Value do
result_types = typespecs_to_mlir_types([typespec])

attributes = [
window_dimensions: attr_dense_i64_elements(window_dimensions),
window_strides: attr_dense_i64_elements(window_strides),
window_dimensions: attr_array_i64_elements(window_dimensions),
window_strides: attr_array_i64_elements(window_strides),
padding: attr_padding(padding)
]

Expand Down Expand Up @@ -501,7 +501,7 @@ defmodule EXLA.MLIR.Value do

attributes = [
dimension_numbers: dimension_numbers,
slice_sizes: attr_dense_i64_elements(slice_sizes),
slice_sizes: attr_array_i64_elements(slice_sizes),
indices_are_sorted: attr_boolean(false)
]

Expand Down Expand Up @@ -546,10 +546,10 @@ defmodule EXLA.MLIR.Value do
attr_precision_config = attr_precision_config(precision_config)

attributes = [
window_strides: attr_dense_i64_elements(strides),
window_strides: attr_array_i64_elements(strides),
padding: attr_padding(padding),
lhs_dilation: attr_dense_i64_elements(input_dilation),
rhs_dilation: attr_dense_i64_elements(kernel_dilation),
lhs_dilation: attr_array_i64_elements(input_dilation),
rhs_dilation: attr_array_i64_elements(kernel_dilation),
dimension_numbers: attr_conv_dimension_numbers(dimension_numbers),
feature_group_count: attr_i64(feature_group_count),
batch_group_count: attr_i64(batch_group_count),
Expand Down Expand Up @@ -625,7 +625,7 @@ defmodule EXLA.MLIR.Value do
) do
operands = inputs ++ init_values
result_types = typespecs_to_mlir_types(typespecs)
attributes = [dimensions: attr_dense_i64_elements(dimensions)]
attributes = [dimensions: attr_array_i64_elements(dimensions)]
regions = [reducer]
op(func, "stablehlo.reduce", operands, result_types, attributes: attributes, regions: regions)
end
Expand All @@ -645,10 +645,10 @@ defmodule EXLA.MLIR.Value do
result_types = typespecs_to_mlir_types(typespecs)

attributes = [
window_dimensions: attr_dense_i64_elements(window_dimensions),
window_strides: attr_dense_i64_elements(window_strides),
base_dilations: attr_dense_i64_elements(input_dilations),
window_dilations: attr_dense_i64_elements(window_dilations),
window_dimensions: attr_array_i64_elements(window_dimensions),
window_strides: attr_array_i64_elements(window_strides),
base_dilations: attr_array_i64_elements(input_dilations),
window_dilations: attr_array_i64_elements(window_dilations),
padding: attr_padding(padding)
]

Expand All @@ -669,7 +669,7 @@ defmodule EXLA.MLIR.Value do
result_types = typespecs_to_mlir_types([typespec])

attributes = [
dimensions: attr_dense_i64_elements(dimensions)
dimensions: attr_array_i64_elements(dimensions)
]

regions = [mapper]
Expand Down Expand Up @@ -904,8 +904,12 @@ defmodule EXLA.MLIR.Value do
<<value::size(size)-big>>
end

defp attr_dense_i64_elements(list) do
attr_dense_elements(list, {:s, 64}, {length(list)})
defp attr_array_i64_elements([]) do
"array<i64>"
end

defp attr_array_i64_elements(list) do
"array<i64: #{Enum.join(list, ", ")}>"
end

defp attr_dense_elements([], type, {0} = shape) do
Expand Down

0 comments on commit 741829f

Please sign in to comment.