From e5ec49efca78ad539403c4ad1e3a89e6c3a26b53 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Wed, 4 Sep 2024 08:28:40 -0700 Subject: [PATCH] Add FP8 support (#1507) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- exla/lib/exla/mlir/value.ex | 6 ++++ exla/lib/exla/typespec.ex | 3 ++ exla/test/exla/defn/expr_test.exs | 8 +++++ nx/lib/nx.ex | 13 ++++++-- nx/lib/nx/binary_backend.ex | 5 ++-- nx/lib/nx/constants.ex | 48 +++++++++++++++++++++++++++++ nx/lib/nx/random.ex | 1 + nx/lib/nx/shared.ex | 50 +++++++++++++++++++++++++++---- nx/lib/nx/type.ex | 16 ++++++++-- nx/test/nx_test.exs | 4 +-- torchx/c_src/torchx.cpp | 4 +-- torchx/lib/torchx.ex | 1 + torchx/lib/torchx/backend.ex | 2 ++ 13 files changed, 146 insertions(+), 15 deletions(-) diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 1a145c41311..b2c766ebbf0 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -880,6 +880,7 @@ defmodule EXLA.MLIR.Value do defp type_number({:pred, 8}), do: "i1" defp type_number({:s, width}), do: "i#{width}" defp type_number({:u, width}), do: "ui#{width}" + defp type_number({:f, 8}), do: "f8E5M2" defp type_number({:f, width}), do: "f#{width}" defp type_number({:bf, width}), do: "bf#{width}" defp type_number({:c, 64}), do: "complex" @@ -926,12 +927,17 @@ defmodule EXLA.MLIR.Value do :nan -> type |> Nx.Type.nan_binary() |> native_to_big() :infinity -> type |> Nx.Type.infinity_binary() |> native_to_big() :neg_infinity -> type |> Nx.Type.neg_infinity_binary() |> native_to_big() + value when size == 8 -> f8E5M2_to_big(value) value -> <> end Base.encode16(data) end + defp f8E5M2_to_big(x) do + binary_part(<>, 0, 1) + end + defp native_to_big(binary) do size = byte_size(binary) * 8 <> = binary diff --git a/exla/lib/exla/typespec.ex b/exla/lib/exla/typespec.ex index 471a25aace8..0c56bf07f23 100644 --- a/exla/lib/exla/typespec.ex +++ b/exla/lib/exla/typespec.ex @@ -69,6 +69,9 @@ defmodule EXLA.Typespec do {:c, 128} => ~c"c128" } + defp type_to_charlist({:f, 8}), do: ~c"f8e5m2" + defp charlist_to_type(~c"f8"), do: {:f, 8} + for {type, charlist} <- type_to_charlist do defp charlist_to_type(unquote(charlist)), do: unquote(type) defp type_to_charlist(unquote(type)), do: unquote(charlist) diff --git a/exla/test/exla/defn/expr_test.exs b/exla/test/exla/defn/expr_test.exs index cedf1be49df..ecb95cb0eee 100644 --- a/exla/test/exla/defn/expr_test.exs +++ b/exla/test/exla/defn/expr_test.exs @@ -86,6 +86,14 @@ defmodule EXLA.Defn.ExprTest do end end + describe "float8" do + defn return_float8, do: Nx.tensor(1, type: {:f, 8}) + + test "supports float8 return types" do + assert_equal(return_float8(), Nx.tensor(1, type: {:f, 8})) + end + end + describe "float16" do defn return_float, do: Nx.tensor(1, type: {:f, 16}) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index fcde90da435..d384e8e898b 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -51,7 +51,7 @@ defmodule Nx do * unsigned integers (`u8`, `u16`, `u32`, `u64`) * signed integers (`s8`, `s16`, `s32`, `s64`) - * floats (`f16`, `f32`, `f64`) + * floats (`f8`, `f16`, `f32`, `f64`) * brain floats (`bf16`) * and complex numbers (`c64`, `c128`) @@ -612,6 +612,15 @@ defmodule Nx do [1.0, 2.0, 3.0] > + Certain backends and compilers support 8-bit floats. On the binary + backend this behavior is emulated: + + iex> Nx.tensor([1, 2, 3], type: :f8) + #Nx.Tensor< + f8[3] + [1.0, 2.0, 3.0] + > + In all cases, the non-finite values negative infinity (-Inf), infinity (Inf), and "not a number" (NaN) can be represented by the atoms `:neg_infinity`, `:infinity`, and `:nan` respectively: @@ -929,7 +938,7 @@ defmodule Nx do %T{shape: shape, type: type, names: names, data: %Nx.TemplateBackend{}} end - for t <- [:u8, :u16, :u32, :u64, :s8, :s16, :s32, :s64, :bf16, :f16, :f32, :f64] do + for t <- [:u8, :u16, :u32, :u64, :s8, :s16, :s32, :s64, :bf16, :f8, :f16, :f32, :f64] do @doc """ Short-hand function for creating tensor of type `#{t}`. diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index cbf1539e9b6..49fd97beb58 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -2460,8 +2460,9 @@ defmodule Nx.BinaryBackend do "expected a number or a scalar tensor of type #{inspect(type)}, got: #{inspect(t)}" end - defp number_to_binary(number, type), - do: match_types([type], do: <>) + defp number_to_binary(number, type) do + match_types([type], do: <>) + end defp binary_to_number(bin, type) do match_types [type] do diff --git a/nx/lib/nx/constants.ex b/nx/lib/nx/constants.ex index b05b756d94c..0f41b64aca9 100644 --- a/nx/lib/nx/constants.ex +++ b/nx/lib/nx/constants.ex @@ -22,6 +22,12 @@ defmodule Nx.Constants do ## Examples + iex> Nx.Constants.nan({:f, 8}) + #Nx.Tensor< + f8 + NaN + > + iex> Nx.Constants.nan({:bf, 16}) #Nx.Tensor< bf16 @@ -66,6 +72,12 @@ defmodule Nx.Constants do ## Examples + iex> Nx.Constants.infinity({:f, 8}) + #Nx.Tensor< + f8 + Inf + > + iex> Nx.Constants.infinity({:bf, 16}) #Nx.Tensor< bf16 @@ -110,6 +122,12 @@ defmodule Nx.Constants do ## Examples + iex> Nx.Constants.neg_infinity({:f, 8}) + #Nx.Tensor< + f8 + -Inf + > + iex> Nx.Constants.neg_infinity({:bf, 16}) #Nx.Tensor< bf16 @@ -334,6 +352,12 @@ defmodule Nx.Constants do 1.1754943508222875e-38 > + iex> Nx.Constants.smallest_positive_normal(:f8) + #Nx.Tensor< + f8 + 6.103515625e-5 + > + iex> Nx.Constants.smallest_positive_normal({:s, 32}) ** (ArgumentError) only floating types are supported, got: {:s, 32} """ @@ -377,6 +401,12 @@ defmodule Nx.Constants do 0.0078125 > + iex> Nx.Constants.epsilon(:f8) + #Nx.Tensor< + f8 + 0.25 + > + iex> Nx.Constants.epsilon({:s, 32}) ** (ArgumentError) only floating types are supported, got: {:s, 32} """ @@ -423,6 +453,12 @@ defmodule Nx.Constants do 3.140625 > + iex> Nx.Constants.pi({:f, 8}) + #Nx.Tensor< + f8 + 3.0 + > + iex> Nx.Constants.pi({:s, 32}) ** (ArgumentError) only floating types are supported, got: {:s, 32} """ @@ -469,6 +505,12 @@ defmodule Nx.Constants do 2.703125 > + iex> Nx.Constants.e({:f, 8}) + #Nx.Tensor< + f8 + 2.5 + > + iex> Nx.Constants.e({:s, 32}) ** (ArgumentError) only floating types are supported, got: {:s, 32} """ @@ -515,6 +557,12 @@ defmodule Nx.Constants do 0.57421875 > + iex> Nx.Constants.euler_gamma({:f, 8}) + #Nx.Tensor< + f8 + 0.5 + > + iex> Nx.Constants.euler_gamma({:s, 32}) ** (ArgumentError) only floating types are supported, got: {:s, 32} """ diff --git a/nx/lib/nx/random.ex b/nx/lib/nx/random.ex index c876aed3b32..5b63a4718b6 100644 --- a/nx/lib/nx/random.ex +++ b/nx/lib/nx/random.ex @@ -299,6 +299,7 @@ defmodule Nx.Random do deftransformp mantissa_shift(nbits, type) do mantissa = case type do + {:f, 8} -> 2 {:bf, 16} -> 7 {:f, 16} -> 10 {:f, 32} -> 23 diff --git a/nx/lib/nx/shared.ex b/nx/lib/nx/shared.ex index 88ec8748e92..392e7416d36 100644 --- a/nx/lib/nx/shared.ex +++ b/nx/lib/nx/shared.ex @@ -105,9 +105,14 @@ defmodule Nx.Shared do quote do: Nx.Shared.read_bf16(unquote(var)) end + defp read_bin_modifier(var, :f, 8) do + quote do: Nx.Shared.read_f8(unquote(var)) + end + defp read_bin_modifier(var, :f, size) do quote do case unquote(var) do + _ when unquote(size) == 8 -> Nx.Shared.read_f8(unquote(var)) <> -> var var -> Nx.Shared.read_non_finite(var, unquote(size)) end @@ -122,14 +127,14 @@ defmodule Nx.Shared do quote do case unquote(var) do x when is_number(x) -> binary_part(<>, 2, 2) - x -> Nx.Shared.write_bf16(x) + x -> Nx.Shared.write_non_finite_bf16(x) end :: binary end else quote do case unquote(var) do x when is_number(x) -> binary_part(<>, 0, 2) - x -> Nx.Shared.write_bf16(x) + x -> Nx.Shared.write_non_finite_bf16(x) end :: binary end end @@ -155,7 +160,8 @@ defmodule Nx.Shared do defp write_bin_modifier(var, :f, size) do quote do case unquote(var) do - x when is_number(x) -> <> + x when is_number(x) and unquote(size) != 8 -> <> + x when is_number(x) -> Nx.Shared.write_finite_f8(unquote(var)) x -> Nx.Shared.write_non_finite(x, unquote(size)) end :: binary end @@ -192,6 +198,22 @@ defmodule Nx.Shared do end end + @doc """ + F8 read callback. + """ + def read_f8(<<0xFC::8-native>>), do: :neg_infinity + def read_f8(<<0x7C::8-native>>), do: :infinity + def read_f8(<<_sign::1, 31::5, mantissa::2>>) when mantissa != 0, do: :nan + + def read_f8(<>) do + float = :math.pow(2, exp - 15) * (1 + mantissa / 4) + + case sign do + 0 -> float + _ -> -float + end + end + @doc """ C64 and C128 callback. """ @@ -217,7 +239,7 @@ defmodule Nx.Shared do @doc """ BF16 write callback. """ - def write_bf16(data) do + def write_non_finite_bf16(data) do case data do :infinity -> unquote(Nx.Type.infinity_binary({:bf, 16})) :neg_infinity -> unquote(Nx.Type.neg_infinity_binary({:bf, 16})) @@ -225,6 +247,16 @@ defmodule Nx.Shared do end end + if System.endianness() == :little do + def write_finite_f8(x) do + binary_part(<>, 1, 1) + end + else + def write_finite_f8(x) do + binary_part(<>, 0, 1) + end + end + @doc """ Complex write callback. """ @@ -247,6 +279,14 @@ defmodule Nx.Shared do @doc """ Non-finite read callback. """ + def read_non_finite(data, 8) do + case data do + <<0xFC::8-native>> -> :neg_infinity + <<0x7C::8-native>> -> :infinity + _ -> :nan + end + end + def read_non_finite(data, 16) do case data do <<0xFC00::16-native>> -> :neg_infinity @@ -274,7 +314,7 @@ defmodule Nx.Shared do @doc """ Non-finite write callback. """ - for size <- [16, 32, 64] do + for size <- [8, 16, 32, 64] do def write_non_finite(data, unquote(size)) do case data do :infinity -> unquote(Nx.Type.infinity_binary({:f, size})) diff --git a/nx/lib/nx/type.ex b/nx/lib/nx/type.ex index 4b65d6e1ad3..a20af58bb27 100644 --- a/nx/lib/nx/type.ex +++ b/nx/lib/nx/type.ex @@ -7,7 +7,7 @@ defmodule Nx.Type do * `:s` - signed integer (8, 16, 32, 64) * `:u` - unsigned integer (8, 16, 32, 64) - * `:f` - float (16, 32, 64) + * `:f` - float (8, 16, 32, 64) * `:bf` - a brain floating point (16) * `:c` - a complex number, represented as a pair of floats (64, 128) @@ -33,6 +33,7 @@ defmodule Nx.Type do | {:u, 16} | {:u, 32} | {:u, 64} + | {:f, 8} | {:f, 16} | {:f, 32} | {:f, 64} @@ -50,6 +51,7 @@ defmodule Nx.Type do | :u16 | :u32 | :u64 + | :f8 | :f16 | :f32 | :f64 @@ -68,6 +70,7 @@ defmodule Nx.Type do def min_finite_binary({:s, 64}), do: <<-9_223_372_036_854_775_808::64-signed-native>> def min_finite_binary({:u, size}), do: <<0::size(size)-native>> def min_finite_binary({:bf, 16}), do: <<0xFF7F::16-native>> + def min_finite_binary({:f, 8}), do: <<0xFB::8-native>> def min_finite_binary({:f, 16}), do: <<0xFBFF::16-native>> def min_finite_binary({:f, 32}), do: <<0xFF7FFFFF::32-native>> def min_finite_binary({:f, 64}), do: <<0xFFEFFFFFFFFFFFFF::64-native>> @@ -93,6 +96,7 @@ defmodule Nx.Type do def max_finite_binary({:u, 32}), do: <<4_294_967_295::32-native>> def max_finite_binary({:u, 64}), do: <<18_446_744_073_709_551_615::64-native>> def max_finite_binary({:bf, 16}), do: <<0x7F7F::16-native>> + def max_finite_binary({:f, 8}), do: <<0x7B::8-native>> def max_finite_binary({:f, 16}), do: <<0x7BFF::16-native>> def max_finite_binary({:f, 32}), do: <<0x7F7FFFFF::32-native>> def max_finite_binary({:f, 64}), do: <<0x7FEFFFFFFFFFFFFF::64-native>> @@ -109,6 +113,7 @@ defmodule Nx.Type do """ def nan_binary(type) def nan_binary({:bf, 16}), do: <<0x7FC0::16-native>> + def nan_binary({:f, 8}), do: <<0x7E::8-native>> def nan_binary({:f, 16}), do: <<0x7E00::16-native>> def nan_binary({:f, 32}), do: <<0x7FC00000::32-native>> def nan_binary({:f, 64}), do: <<0x7FF8000000000000::64-native>> @@ -118,6 +123,7 @@ defmodule Nx.Type do """ def infinity_binary(type) def infinity_binary({:bf, 16}), do: <<0x7F80::16-native>> + def infinity_binary({:f, 8}), do: <<0x7C::8-native>> def infinity_binary({:f, 16}), do: <<0x7C00::16-native>> def infinity_binary({:f, 32}), do: <<0x7F800000::32-native>> def infinity_binary({:f, 64}), do: <<0x7FF0000000000000::64-native>> @@ -127,6 +133,7 @@ defmodule Nx.Type do """ def neg_infinity_binary(type) def neg_infinity_binary({:bf, 16}), do: <<0xFF80::16-native>> + def neg_infinity_binary({:f, 8}), do: <<0xFC::8-native>> def neg_infinity_binary({:f, 16}), do: <<0xFC00::16-native>> def neg_infinity_binary({:f, 32}), do: <<0xFF800000::32-native>> def neg_infinity_binary({:f, 64}), do: <<0xFFF0000000000000::64-native>> @@ -186,7 +193,7 @@ defmodule Nx.Type do type_variants = [ s: [8, 16, 32, 64], u: [8, 16, 32, 64], - f: [16, 32, 64], + f: [8, 16, 32, 64], bf: [16], c: [64, 128] ] @@ -599,6 +606,7 @@ defmodule Nx.Type do """ def smallest_positive_normal_binary(type) def smallest_positive_normal_binary({:bf, 16}), do: <<0x0080::16-native>> + def smallest_positive_normal_binary({:f, 8}), do: <<0x04::8-native>> def smallest_positive_normal_binary({:f, 16}), do: <<0x0400::16-native>> def smallest_positive_normal_binary({:f, 32}), do: <<0x0080_0000::32-native>> def smallest_positive_normal_binary({:f, 64}), do: <<0x0010_0000_0000_0000::64-native>> @@ -611,6 +619,7 @@ defmodule Nx.Type do """ def epsilon_binary(type) def epsilon_binary({:bf, 16}), do: <<0, 60>> + def epsilon_binary({:f, 8}), do: <<52>> def epsilon_binary({:f, 16}), do: <<0, 20>> def epsilon_binary({:f, 32}), do: <<0, 0, 0, 52>> def epsilon_binary({:f, 64}), do: <<0, 0, 0, 0, 0, 0, 176, 60>> @@ -636,6 +645,7 @@ defmodule Nx.Type do """ def pi_binary(type) def pi_binary({:bf, 16}), do: <<73, 64>> + def pi_binary({:f, 8}), do: <<66>> def pi_binary({:f, 16}), do: <<72, 66>> def pi_binary({:f, 32}), do: <<219, 15, 73, 64>> def pi_binary({:f, 64}), do: <<24, 45, 68, 84, 251, 33, 9, 64>> @@ -648,6 +658,7 @@ defmodule Nx.Type do """ def e_binary(type) def e_binary({:bf, 16}), do: <<45, 64>> + def e_binary({:f, 8}), do: <<65>> def e_binary({:f, 16}), do: <<112, 65>> def e_binary({:f, 32}), do: <<84, 248, 45, 64>> def e_binary({:f, 64}), do: <<105, 87, 20, 139, 10, 191, 5, 64>> @@ -660,6 +671,7 @@ defmodule Nx.Type do """ def euler_gamma_binary(type) def euler_gamma_binary({:bf, 16}), do: <<19, 63>> + def euler_gamma_binary({:f, 8}), do: <<56>> def euler_gamma_binary({:f, 16}), do: <<158, 56>> def euler_gamma_binary({:f, 32}), do: <<104, 196, 19, 63>> def euler_gamma_binary({:f, 64}), do: <<25, 182, 111, 252, 140, 120, 226, 63>> diff --git a/nx/test/nx_test.exs b/nx/test/nx_test.exs index 8981ce650fd..98aa75c9902 100644 --- a/nx/test/nx_test.exs +++ b/nx/test/nx_test.exs @@ -2572,9 +2572,9 @@ defmodule NxTest do test "raises on invalid type" do assert_raise( ArgumentError, - "invalid numerical type: {:f, 8} (see Nx.Type docs for all supported types)", + "invalid numerical type: {:f, 4} (see Nx.Type docs for all supported types)", fn -> - eval("~MAT[1 2 3 4]f8") + eval("~MAT[1 2 3 4]f4") end ) end diff --git a/torchx/c_src/torchx.cpp b/torchx/c_src/torchx.cpp index 09cbe160b83..5dfa403f2df 100644 --- a/torchx/c_src/torchx.cpp +++ b/torchx/c_src/torchx.cpp @@ -11,8 +11,8 @@ #include "nx_nif_utils.hpp" -std::map dtypes = {{"byte", torch::kByte}, {"char", torch::kChar}, {"short", torch::kShort}, {"int", torch::kInt}, {"long", torch::kLong}, {"half", torch::kHalf}, {"brain", torch::kBFloat16}, {"float", torch::kFloat}, {"double", torch::kDouble}, {"bool", torch::kBool}, {"complex", at::ScalarType::ComplexFloat}, {"complex_double", at::ScalarType::ComplexDouble}}; -std::map dtype_sizes = {{"byte", 1}, {"char", 1}, {"short", 2}, {"int", 4}, {"long", 8}, {"half", 2}, {"brain", 2}, {"float", 4}, {"double", 8}, {"complex", 8}, {"complex_double", 16}}; +std::map dtypes = {{"byte", torch::kByte}, {"char", torch::kChar}, {"short", torch::kShort}, {"int", torch::kInt}, {"long", torch::kLong}, {"float8_e5m2", torch::kFloat8_e5m2}, {"half", torch::kHalf}, {"brain", torch::kBFloat16}, {"float", torch::kFloat}, {"double", torch::kDouble}, {"bool", torch::kBool}, {"complex", at::ScalarType::ComplexFloat}, {"complex_double", at::ScalarType::ComplexDouble}}; +std::map dtype_sizes = {{"byte", 1}, {"char", 1}, {"short", 2}, {"int", 4}, {"long", 8}, {"float8_e5m2", 1}, {"half", 2}, {"brain", 2}, {"float", 4}, {"double", 8}, {"complex", 8}, {"complex_double", 16}}; inline torch::ScalarType string2type(const std::string &atom) { return dtypes[atom]; diff --git a/torchx/lib/torchx.ex b/torchx/lib/torchx.ex index 051e0b4a3f4..7e9fc61ff17 100644 --- a/torchx/lib/torchx.ex +++ b/torchx/lib/torchx.ex @@ -144,6 +144,7 @@ defmodule Torchx do `{:s, 32}` | `:int` | Signed 32-bit integer `{:s, 64}` | `:long` | Signed 64-bit integer `{:bf, 16}` | `:brain` | 16-bit brain floating-point number + `{:f, 8}` | `:float8_e5m2` | 8-bit floating-point number `{:f, 16}` | `:half` | 16-bit floating-point number `{:f, 32}` | `:float` | 32-bit floating-point number `{:f, 64}` | `:double` | 64-bit floating-point number diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index de85a1d7001..cfc2b637b64 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -1715,6 +1715,7 @@ defmodule Torchx.Backend do def from_torch_type(:int), do: {:s, 32} def from_torch_type(:long), do: {:s, 64} def from_torch_type(:brain), do: {:bf, 16} + def from_torch_type(:float8_e5m2), do: {:f, 8} def from_torch_type(:half), do: {:f, 16} def from_torch_type(:float), do: {:f, 32} def from_torch_type(:double), do: {:f, 64} @@ -1731,6 +1732,7 @@ defmodule Torchx.Backend do defp to_torch_type({:s, 32}, _), do: :int defp to_torch_type({:s, 64}, _), do: :long defp to_torch_type({:bf, 16}, _), do: :brain + defp to_torch_type({:f, 8}, _), do: :float8_e5m2 defp to_torch_type({:f, 16}, _), do: :half defp to_torch_type({:f, 32}, _), do: :float defp to_torch_type({:f, 64}, _), do: :double