diff --git a/lib/emlx/backend.ex b/lib/emlx/backend.ex index eda810c..ea94a84 100644 --- a/lib/emlx/backend.ex +++ b/lib/emlx/backend.ex @@ -73,11 +73,13 @@ defmodule EMLX.Backend do @impl true def to_binary(tensor, limit) do EMLX.to_blob(from_nx(tensor), limit) + |> maybe_modify_binary(to_nx_type(to_mlx_type(tensor.type)), tensor.type) end @impl true def from_binary(%T{type: type, shape: shape} = out, binary, backend_options) do binary + |> maybe_modify_binary(type, to_nx_type(to_mlx_type(type))) |> EMLX.from_blob( shape, to_mlx_type(type), @@ -86,6 +88,104 @@ defmodule EMLX.Backend do |> to_nx(out) end + defp maybe_modify_binary(binary, type, type), do: binary + + defp maybe_modify_binary(binary, {:f, 8}, {:f, 16}) do + for <>, into: <<>> do + case read_f8(<>) do + :nan -> + Nx.Type.nan_binary({:f, 16}) + + :infinity -> + Nx.Type.infinity_binary({:f, 16}) + + :neg_infinity -> + Nx.Type.neg_infinity_binary({:f, 16}) + + number -> + <> + end + end + end + + defp maybe_modify_binary(binary, {:f, 16}, {:f, 8}) do + for <>, into: <<>> do + case <> do + <> -> write_finite_f8(float) + <<0xFC00::16-native>> -> write_non_finite(:neg_infinity, 8) + <<0x7C00::16-native>> -> write_non_finite(:infinity, 8) + _ -> write_non_finite(:nan, 8) + end + end + end + + defp maybe_modify_binary(binary, {:f, 64}, {:f, 32}) do + for <>, into: <<>> do + case <> do + <> -> <> + <<0xFFF0000000000000::64-native>> -> write_non_finite(:neg_infinity, 32) + <<0x7FF0000000000000::64-native>> -> write_non_finite(:infinity, 32) + _ -> write_non_finite(:nan, 32) + end + end + end + + defp maybe_modify_binary(binary, {:f, 32}, {:f, 64}) do + for <>, into: <<>> do + case <> do + <> -> <> + <<0xFF800000::32-native>> -> write_non_finite(:neg_infinity, 64) + <<0x7F800000::32-native>> -> write_non_finite(:infinity, 64) + _ -> write_non_finite(:nan, 64) + end + end + end + + defp maybe_modify_binary(binary, {:u, size}, {:u, 8}) when size in [2, 4] do + for <>, into: <<>> do + <> + end + end + + defp maybe_modify_binary(binary, {:u, 8}, {:u, size}) when size in [2, 4] do + for <>, into: <<>> do + <> + end + end + + defp read_f8(<<0xFC::8-native>>), do: :neg_infinity + defp read_f8(<<0x7C::8-native>>), do: :infinity + defp read_f8(<<_sign::1, 31::5, mantissa::2>>) when mantissa != 0, do: :nan + + defp read_f8(<>) do + float = :math.pow(2, exp - 15) * (1 + mantissa / 4) + + case sign do + 0 -> float + _ -> -float + end + end + + if System.endianness() == :little do + def write_finite_f8(x) do + binary_part(<>, 0, 1) + end + else + def write_finite_f8(x) do + binary_part(<>, 1, 1) + end + end + + for size <- [8, 16, 32, 64] do + def write_non_finite(data, unquote(size)) do + case data do + :infinity -> unquote(Nx.Type.infinity_binary({:f, size})) + :neg_infinity -> unquote(Nx.Type.neg_infinity_binary({:f, size})) + :nan -> unquote(Nx.Type.nan_binary({:f, size})) + end + end + end + @impl true def slice( out, @@ -202,18 +302,25 @@ defmodule EMLX.Backend do defp needs_type_conversion?({:u, 8}, :bool), do: true defp needs_type_conversion?(_, _), do: false + defp to_mlx_type({:u, 2}), do: :uint8 + defp to_mlx_type({:u, 4}), do: :uint8 defp to_mlx_type({:u, 8}), do: :uint8 defp to_mlx_type({:u, 16}), do: :uint16 defp to_mlx_type({:u, 32}), do: :uint32 defp to_mlx_type({:u, 64}), do: :uint64 + defp to_mlx_type({:s, 2}), do: :int8 + defp to_mlx_type({:s, 4}), do: :int8 defp to_mlx_type({:s, 8}), do: :int8 defp to_mlx_type({:s, 16}), do: :int16 defp to_mlx_type({:s, 32}), do: :int32 defp to_mlx_type({:s, 64}), do: :int64 + defp to_mlx_type({:f, 8}), do: :float16 defp to_mlx_type({:f, 16}), do: :float16 defp to_mlx_type({:f, 32}), do: :float32 + defp to_mlx_type({:f, 64}), do: :float32 defp to_mlx_type({:bf, 16}), do: :bfloat16 defp to_mlx_type({:c, 64}), do: :complex64 + defp to_mlx_type({:c, 128}), do: :complex64 defp to_mlx_type(:bool), do: :bool defp to_nx_type(:uint8), do: {:u, 8} @@ -252,16 +359,13 @@ defmodule EMLX.Backend do {{:u, 8}, {:u, qint}} when qint in [2, 4] -> :ok - {{:s, 32}, {:u, 16}} -> - :ok - - {{:s, 64}, {:u, 32}} -> + {{:f, 16}, {:f, 8}} -> :ok - {{:s, 64}, {:u, 64}} -> + {{:f, 32}, {:f, 64}} -> :ok - {{:u, 8}, {:u, 32}} -> + {{:c, 64}, {:c, 128}} -> :ok _ when actual_type != expected_type -> diff --git a/test/emlx/nx_doctest_test.exs b/test/emlx/nx_doctest_test.exs index fd055f7..889d960 100644 --- a/test/emlx/nx_doctest_test.exs +++ b/test/emlx/nx_doctest_test.exs @@ -25,7 +25,11 @@ defmodule EMLX.Nx.DoctestTest do cosh: 1, log10: 1, acos: 1, - covariance: 3 + covariance: 3, + # These fail because we're using different representation types + atan2: 2, + as_type: 2, + from_binary: 3 ] @to_be_fixed [ @@ -51,18 +55,6 @@ defmodule EMLX.Nx.DoctestTest do ] @not_supported [ - # Does not support f8 (yet?) - tensor: 2, - # Does not support u2 (yet?) - bit_size: 1, - # f64 not supported - from_binary: 3, - # f64 not supported - iota: 2, - # f64 not supported - atan2: 2, - # f64 not supported - as_type: 2, reduce: 4, window_reduce: 5, population_count: 1,