From 08e33304128836f87d8ac9c28f236aaa3f510b45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Fri, 27 Sep 2024 09:09:13 +0200 Subject: [PATCH] Fix encoding of bf16 --- exla/lib/exla/mlir/value.ex | 7 ++++++- exla/test/exla/defn/expr_test.exs | 22 +++++++++++++--------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index b2c766ebbf..5dfd72ca23 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -921,13 +921,14 @@ defmodule EXLA.MLIR.Value do end end - defp float_hex(value, {_, size} = type) do + defp float_hex(value, {mod, size} = type) do data = case value do :nan -> type |> Nx.Type.nan_binary() |> native_to_big() :infinity -> type |> Nx.Type.infinity_binary() |> native_to_big() :neg_infinity -> type |> Nx.Type.neg_infinity_binary() |> native_to_big() value when size == 8 -> f8E5M2_to_big(value) + value when mod == :bf and size == 16 -> bf16_to_big(value) value -> <> end @@ -938,6 +939,10 @@ defmodule EXLA.MLIR.Value do binary_part(<>, 0, 1) end + defp bf16_to_big(x) do + binary_part(<>, 0, 2) + end + defp native_to_big(binary) do size = byte_size(binary) * 8 <> = binary diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index ecb95cb0ee..b05388869d 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -86,19 +86,23 @@ defmodule EXLA.Defn.ExprTest do end end - describe "float8" do - defn return_float8, do: Nx.tensor(1, type: {:f, 8}) + describe "types" do + defn return_f8, do: Nx.tensor(1, type: {:f, 8}) - test "supports float8 return types" do - assert_equal(return_float8(), Nx.tensor(1, type: {:f, 8})) + test "f8" do + assert_equal(return_f8(), Nx.tensor(1, type: {:f, 8})) + end + + defn return_f16, do: Nx.tensor(1, type: {:f, 16}) + + test "f16" do + assert_equal(return_f16(), Nx.tensor(1, type: {:f, 16})) end - end - describe "float16" do - defn return_float, do: Nx.tensor(1, type: {:f, 16}) + defn return_bf16, do: Nx.tensor(1, type: {:bf, 16}) - test "supports float16 return types" do - assert_equal(return_float(), Nx.tensor(1, type: {:f, 16})) + test "bf16" do + assert_equal(return_bf16(), Nx.tensor(1, type: {:bf, 16})) end end