Skip to content

Commit

Permalink
Add FP8 support (#1507)
Browse files Browse the repository at this point in the history
Co-authored-by: José Valim <[email protected]>
  • Loading branch information
seanmor5 and josevalim authored Sep 4, 2024
1 parent 13058bb commit e5ec49e
Show file tree
Hide file tree
Showing 13 changed files with 146 additions and 15 deletions.
6 changes: 6 additions & 0 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,7 @@ defmodule EXLA.MLIR.Value do
defp type_number({:pred, 8}), do: "i1"
defp type_number({:s, width}), do: "i#{width}"
defp type_number({:u, width}), do: "ui#{width}"
defp type_number({:f, 8}), do: "f8E5M2"
defp type_number({:f, width}), do: "f#{width}"
defp type_number({:bf, width}), do: "bf#{width}"
defp type_number({:c, 64}), do: "complex<f32>"
Expand Down Expand Up @@ -926,12 +927,17 @@ defmodule EXLA.MLIR.Value do
:nan -> type |> Nx.Type.nan_binary() |> native_to_big()
:infinity -> type |> Nx.Type.infinity_binary() |> native_to_big()
:neg_infinity -> type |> Nx.Type.neg_infinity_binary() |> native_to_big()
value when size == 8 -> f8E5M2_to_big(value)
value -> <<value::float-size(size)-big>>
end

Base.encode16(data)
end

# Encodes `x` as a 16-bit big-endian float and keeps only its first byte,
# dropping the second byte of that encoding.
defp f8E5M2_to_big(x) do
  <<high_byte::binary-size(1), _low_byte::binary-size(1)>> = <<x::float-big-16>>
  high_byte
end

defp native_to_big(binary) do
size = byte_size(binary) * 8
<<value::size(size)-native>> = binary
Expand Down
3 changes: 3 additions & 0 deletions exla/lib/exla/typespec.ex
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ defmodule EXLA.Typespec do
{:c, 128} => ~c"c128"
}

# Hand-written clauses for the 8-bit float. They are defined before the
# clauses generated from the `type_to_charlist` map, so they match first.
# NOTE(review): the outgoing name is `f8e5m2` while the accepted incoming
# name is `f8`, so this pair is not a round trip. Presumably the native side
# reports the type as `f8` — confirm against the NIF.
defp type_to_charlist({:f, 8}), do: ~c"f8e5m2"
defp charlist_to_type(~c"f8"), do: {:f, 8}

for {type, charlist} <- type_to_charlist do
defp charlist_to_type(unquote(charlist)), do: unquote(type)
defp type_to_charlist(unquote(type)), do: unquote(charlist)
Expand Down
8 changes: 8 additions & 0 deletions exla/test/exla/defn/expr_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ defmodule EXLA.Defn.ExprTest do
end
end

describe "float8" do
  # A defn that returns an 8-bit float constant.
  defn return_float8 do
    Nx.tensor(1, type: {:f, 8})
  end

  test "supports float8 return types" do
    expected = Nx.tensor(1, type: {:f, 8})

    assert_equal(return_float8(), expected)
  end
end

describe "float16" do
defn return_float, do: Nx.tensor(1, type: {:f, 16})

Expand Down
13 changes: 11 additions & 2 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ defmodule Nx do
* unsigned integers (`u8`, `u16`, `u32`, `u64`)
* signed integers (`s8`, `s16`, `s32`, `s64`)
* floats (`f16`, `f32`, `f64`)
* floats (`f8`, `f16`, `f32`, `f64`)
* brain floats (`bf16`)
* and complex numbers (`c64`, `c128`)
Expand Down Expand Up @@ -612,6 +612,15 @@ defmodule Nx do
[1.0, 2.0, 3.0]
>
Certain backends and compilers support 8-bit floats. On the binary
backend this behavior is emulated:
iex> Nx.tensor([1, 2, 3], type: :f8)
#Nx.Tensor<
f8[3]
[1.0, 2.0, 3.0]
>
In all cases, the non-finite values negative infinity (-Inf),
infinity (Inf), and "not a number" (NaN) can be represented by
the atoms `:neg_infinity`, `:infinity`, and `:nan` respectively:
Expand Down Expand Up @@ -929,7 +938,7 @@ defmodule Nx do
%T{shape: shape, type: type, names: names, data: %Nx.TemplateBackend{}}
end

for t <- [:u8, :u16, :u32, :u64, :s8, :s16, :s32, :s64, :bf16, :f16, :f32, :f64] do
for t <- [:u8, :u16, :u32, :u64, :s8, :s16, :s32, :s64, :bf16, :f8, :f16, :f32, :f64] do
@doc """
Short-hand function for creating tensor of type `#{t}`.
Expand Down
5 changes: 3 additions & 2 deletions nx/lib/nx/binary_backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2460,8 +2460,9 @@ defmodule Nx.BinaryBackend do
"expected a number or a scalar tensor of type #{inspect(type)}, got: #{inspect(t)}"
end

defp number_to_binary(number, type),
do: match_types([type], do: <<write!(number, 0)>>)
# Builds the binary for a single `number` of the given `type`.
# `match_types` and `write!` are defined outside this view; presumably
# `match_types` specialises the body for `type` and `write!` expands to the
# matching bitstring segment — confirm against their definitions.
defp number_to_binary(number, type) do
  match_types([type], do: <<write!(number, 0)>>)
end

defp binary_to_number(bin, type) do
match_types [type] do
Expand Down
48 changes: 48 additions & 0 deletions nx/lib/nx/constants.ex
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ defmodule Nx.Constants do
## Examples
iex> Nx.Constants.nan({:f, 8})
#Nx.Tensor<
f8
NaN
>
iex> Nx.Constants.nan({:bf, 16})
#Nx.Tensor<
bf16
Expand Down Expand Up @@ -66,6 +72,12 @@ defmodule Nx.Constants do
## Examples
iex> Nx.Constants.infinity({:f, 8})
#Nx.Tensor<
f8
Inf
>
iex> Nx.Constants.infinity({:bf, 16})
#Nx.Tensor<
bf16
Expand Down Expand Up @@ -110,6 +122,12 @@ defmodule Nx.Constants do
## Examples
iex> Nx.Constants.neg_infinity({:f, 8})
#Nx.Tensor<
f8
-Inf
>
iex> Nx.Constants.neg_infinity({:bf, 16})
#Nx.Tensor<
bf16
Expand Down Expand Up @@ -334,6 +352,12 @@ defmodule Nx.Constants do
1.1754943508222875e-38
>
iex> Nx.Constants.smallest_positive_normal(:f8)
#Nx.Tensor<
f8
6.103515625e-5
>
iex> Nx.Constants.smallest_positive_normal({:s, 32})
** (ArgumentError) only floating types are supported, got: {:s, 32}
"""
Expand Down Expand Up @@ -377,6 +401,12 @@ defmodule Nx.Constants do
0.0078125
>
iex> Nx.Constants.epsilon(:f8)
#Nx.Tensor<
f8
0.25
>
iex> Nx.Constants.epsilon({:s, 32})
** (ArgumentError) only floating types are supported, got: {:s, 32}
"""
Expand Down Expand Up @@ -423,6 +453,12 @@ defmodule Nx.Constants do
3.140625
>
iex> Nx.Constants.pi({:f, 8})
#Nx.Tensor<
f8
3.0
>
iex> Nx.Constants.pi({:s, 32})
** (ArgumentError) only floating types are supported, got: {:s, 32}
"""
Expand Down Expand Up @@ -469,6 +505,12 @@ defmodule Nx.Constants do
2.703125
>
iex> Nx.Constants.e({:f, 8})
#Nx.Tensor<
f8
2.5
>
iex> Nx.Constants.e({:s, 32})
** (ArgumentError) only floating types are supported, got: {:s, 32}
"""
Expand Down Expand Up @@ -515,6 +557,12 @@ defmodule Nx.Constants do
0.57421875
>
iex> Nx.Constants.euler_gamma({:f, 8})
#Nx.Tensor<
f8
0.5
>
iex> Nx.Constants.euler_gamma({:s, 32})
** (ArgumentError) only floating types are supported, got: {:s, 32}
"""
Expand Down
1 change: 1 addition & 0 deletions nx/lib/nx/random.ex
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ defmodule Nx.Random do
deftransformp mantissa_shift(nbits, type) do
mantissa =
case type do
{:f, 8} -> 2
{:bf, 16} -> 7
{:f, 16} -> 10
{:f, 32} -> 23
Expand Down
50 changes: 45 additions & 5 deletions nx/lib/nx/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,14 @@ defmodule Nx.Shared do
quote do: Nx.Shared.read_bf16(unquote(var))
end

# 8-bit floats are always decoded through `Nx.Shared.read_f8/1`; this clause
# only emits the quoted call with `var` spliced in.
defp read_bin_modifier(var, :f, 8) do
  quote do
    Nx.Shared.read_f8(unquote(var))
  end
end

defp read_bin_modifier(var, :f, size) do
quote do
case unquote(var) do
_ when unquote(size) == 8 -> Nx.Shared.read_f8(unquote(var))
<<var::float-native-size(unquote(size))>> -> var
var -> Nx.Shared.read_non_finite(var, unquote(size))
end
Expand All @@ -122,14 +127,14 @@ defmodule Nx.Shared do
quote do
case unquote(var) do
x when is_number(x) -> binary_part(<<x::float-native-32>>, 2, 2)
x -> Nx.Shared.write_bf16(x)
x -> Nx.Shared.write_non_finite_bf16(x)
end :: binary
end
else
quote do
case unquote(var) do
x when is_number(x) -> binary_part(<<x::float-native-32>>, 0, 2)
x -> Nx.Shared.write_bf16(x)
x -> Nx.Shared.write_non_finite_bf16(x)
end :: binary
end
end
Expand All @@ -155,7 +160,8 @@ defmodule Nx.Shared do
defp write_bin_modifier(var, :f, size) do
quote do
case unquote(var) do
x when is_number(x) -> <<x::float-native-size(unquote(size))>>
x when is_number(x) and unquote(size) != 8 -> <<x::float-native-size(unquote(size))>>
x when is_number(x) -> Nx.Shared.write_finite_f8(unquote(var))
x -> Nx.Shared.write_non_finite(x, unquote(size))
end :: binary
end
Expand Down Expand Up @@ -192,6 +198,22 @@ defmodule Nx.Shared do
end
end

@doc """
F8 (E5M2) read callback.

Decodes a single byte laid out as 1 sign bit, 5 exponent bits and
2 mantissa bits.

Returns a float for finite values (including signed zero and subnormals),
or one of the atoms `:infinity`, `:neg_infinity` and `:nan` for the
non-finite bit patterns.
"""
def read_f8(<<0xFC::8-native>>), do: :neg_infinity
def read_f8(<<0x7C::8-native>>), do: :infinity
def read_f8(<<_sign::1, 31::5, mantissa::2>>) when mantissa != 0, do: :nan

def read_f8(<<byte::8>>) do
  # An E5M2 byte is exactly the high byte of an IEEE binary16 value (this is
  # also how `write_finite_f8/1` produces it), so padding with a zero low
  # byte and decoding as a 16-bit float yields the exact value. This handles
  # a zero exponent field correctly: byte 0x00 decodes to 0.0 and subnormals
  # decode to mantissa * 2^-16, rather than being treated as normal numbers
  # with an implicit leading one.
  <<float::float-big-16>> = <<byte::8, 0::8>>
  float
end

@doc """
C64 and C128 callback.
"""
Expand All @@ -217,14 +239,24 @@ defmodule Nx.Shared do
@doc """
BF16 write callback.
"""
def write_bf16(data) do
def write_non_finite_bf16(data) do
  # Each branch is a binary computed at compile time by the corresponding
  # `Nx.Type.*_binary/1` function for `{:bf, 16}`. Any other value raises
  # because no clause matches.
  case data do
    :nan -> unquote(Nx.Type.nan_binary({:bf, 16}))
    :neg_infinity -> unquote(Nx.Type.neg_infinity_binary({:bf, 16}))
    :infinity -> unquote(Nx.Type.infinity_binary({:bf, 16}))
  end
end

@doc """
Finite F8 write callback.

Encodes `x` as a 16-bit float and returns the one-byte binary holding the
most significant byte of that encoding. Encoding big-endian puts that byte
first on every platform, so no endianness branch is needed.
"""
def write_finite_f8(x) do
  <<high_byte::binary-size(1), _low_byte::binary-size(1)>> = <<x::float-big-16>>
  high_byte
end

@doc """
Complex write callback.
"""
Expand All @@ -247,6 +279,14 @@ defmodule Nx.Shared do
@doc """
Non-finite read callback.
"""
# 8-bit case: the two infinity bit patterns are matched explicitly and every
# other input is reported as NaN.
def read_non_finite(<<0xFC::8-native>>, 8), do: :neg_infinity
def read_non_finite(<<0x7C::8-native>>, 8), do: :infinity
def read_non_finite(_data, 8), do: :nan

def read_non_finite(data, 16) do
case data do
<<0xFC00::16-native>> -> :neg_infinity
Expand Down Expand Up @@ -274,7 +314,7 @@ defmodule Nx.Shared do
@doc """
Non-finite write callback.
"""
for size <- [16, 32, 64] do
for size <- [8, 16, 32, 64] do
def write_non_finite(data, unquote(size)) do
case data do
:infinity -> unquote(Nx.Type.infinity_binary({:f, size}))
Expand Down
Loading

0 comments on commit e5ec49e

Please sign in to comment.