From 680b76cfa7ce9636252642678a1fb06886b28f7c Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 27 Nov 2024 10:29:00 -0300 Subject: [PATCH 1/2] refactor: remove same-device enforcement --- lib/emlx.ex | 58 ++++++++++++++++++++----------------------------- lib/emlx/nif.ex | 2 +- 2 files changed, 24 insertions(+), 36 deletions(-) diff --git a/lib/emlx.ex b/lib/emlx.ex index 017036e..3788031 100644 --- a/lib/emlx.ex +++ b/lib/emlx.ex @@ -1,3 +1,7 @@ +defmodule EMLX.NIFError do + defexception [:message] +end + defmodule EMLX.Macro do @moduledoc false @@ -97,7 +101,6 @@ defmodule EMLX.Macro do end defmodule EMLX do - alias EMLX.NIF, as: NIF use EMLX.Macro defguard is_tensor(device, ref) when is_reference(ref) and is_atom(device) @@ -122,7 +125,6 @@ defmodule EMLX do end ## Creation / conversion - def eye(size, type, device), do: eye(size, size, type, device) defdevice eye(m, n, type, device) defdevice from_blob(blob, shape, type, device) defdevice scalar_tensor(scalar, type, device) @@ -260,7 +262,7 @@ defmodule EMLX do defp unwrap!(:ok), do: :ok defp unwrap!({:ok, result}), do: result - defp unwrap!({:error, error}), do: raise("EMLX: " <> List.to_string(error)) + defp unwrap!({:error, error}), do: raise(EMLX.NIFError, List.to_string(error)) defp unwrap_tensor!(tagged_result, device) do case unwrap!(tagged_result) do @@ -275,49 +277,35 @@ defmodule EMLX do end end - defp prepare_tensors_list!(tensors_list, dev) do - tensors = - Enum.map(tensors_list, fn - {dev, ref} when is_tensor(dev, ref) -> - ref - - # TODO: double check if this is correct / does not have overhead - # {other_dev, ref} when is_tensor(other_dev, ref) -> - # raise ArgumentError, "cannot perform operation across devices #{dev} and #{other_dev}" - - bad_tensor -> - raise ArgumentError, "expected a EMLX tensor, got: #{inspect(bad_tensor)}" - end) + defp prepare_tensors_list!(tensors_list, device) do + Enum.map_reduce(tensors_list, device, fn + {dev, ref}, device when is_tensor(dev, ref) -> + {ref, merge_device(device, dev)} - {tensors, dev} + bad_tensor, _device -> + raise ArgumentError, "expected a EMLX tensor, got: #{inspect(bad_tensor)}" + end) end defp prepare_tensors!(tensors) do - Enum.map_reduce(tensors, nil, fn - {dev, ref}, nil when is_tensor(dev, ref) -> - {ref, dev} - - {dev, ref}, _dev when is_tensor(dev, ref) -> - {ref, dev} - - [{dev, ref} | _] = tensors, nil when is_tensor(dev, ref) -> - prepare_tensors_list!(tensors, dev) + Enum.map_reduce(tensors, :cpu, fn + {dev, ref}, device when is_tensor(dev, ref) -> + {ref, merge_device(device, dev)} - tensors, dev when is_list(tensors) -> - prepare_tensors_list!(tensors, dev) + [{dev, ref} | _] = tensors, device when is_tensor(dev, ref) -> + prepare_tensors_list!(tensors, device) - bad_tensor, _dev -> + bad_tensor, _device -> raise ArgumentError, "expected a EMLX tensor, got: #{inspect(bad_tensor)}" end) end - def deallocate(tensor_ref) do - NIF.deallocate(tensor_ref) - end + defp merge_device(:gpu, _), do: :gpu + defp merge_device(_, :gpu), do: :gpu + defp merge_device(_, _), do: :cpu - def eval(tensor) do - NIF.eval(tensor) - end + defvalue deallocate(tensor_ref) + defvalue eval(tensor) deftensor slice(tensor, starts, stops, strides) deftensor slice_update(tensor, tensor_updates, starts, stops) diff --git a/lib/emlx/nif.ex b/lib/emlx/nif.ex index ee07470..be1cbc1 100644 --- a/lib/emlx/nif.ex +++ b/lib/emlx/nif.ex @@ -3,7 +3,7 @@ defmodule EMLX.NIF do Elixir bindings for MLX array operations. """ - for {name, arity} <- EMLX.__mlx_functions__() ++ [eval: 1, deallocate: 1] do + for {name, arity} <- EMLX.__mlx_functions__() do args = Macro.generate_arguments(arity, __MODULE__) def unquote(name)(unquote_splicing(args)) do From 48c3c8088c171317df0c82cbcdf15223522c8b6f Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 27 Nov 2024 11:21:00 -0300 Subject: [PATCH 2/2] feat: add type transcoding --- lib/emlx/backend.ex | 116 ++++++++++++++++++++++++++++++++-- test/emlx/nx_doctest_test.exs | 18 ++---- 2 files changed, 115 insertions(+), 19 deletions(-) diff --git a/lib/emlx/backend.ex b/lib/emlx/backend.ex index eda810c..ea94a84 100644 --- a/lib/emlx/backend.ex +++ b/lib/emlx/backend.ex @@ -73,11 +73,13 @@ defmodule EMLX.Backend do @impl true def to_binary(tensor, limit) do EMLX.to_blob(from_nx(tensor), limit) + |> maybe_modify_binary(to_nx_type(to_mlx_type(tensor.type)), tensor.type) end @impl true def from_binary(%T{type: type, shape: shape} = out, binary, backend_options) do binary + |> maybe_modify_binary(type, to_nx_type(to_mlx_type(type))) |> EMLX.from_blob( shape, to_mlx_type(type), @@ -86,6 +88,104 @@ defmodule EMLX.Backend do |> to_nx(out) end + defp maybe_modify_binary(binary, type, type), do: binary + + defp maybe_modify_binary(binary, {:f, 8}, {:f, 16}) do + for <>, into: <<>> do + case read_f8(<>) do + :nan -> + Nx.Type.nan_binary({:f, 16}) + + :infinity -> + Nx.Type.infinity_binary({:f, 16}) + + :neg_infinity -> + Nx.Type.neg_infinity_binary({:f, 16}) + + number -> + <> + end + end + end + + defp maybe_modify_binary(binary, {:f, 16}, {:f, 8}) do + for <>, into: <<>> do + case <> do + <> -> write_finite_f8(float) + <<0xFC00::16-native>> -> write_non_finite(:neg_infinity, 8) + <<0x7C00::16-native>> -> write_non_finite(:infinity, 8) + _ -> write_non_finite(:nan, 8) + end + end + end + + defp maybe_modify_binary(binary, {:f, 64}, {:f, 32}) do + for <>, into: <<>> do + case <> do + <> -> <> + <<0xFFF0000000000000::64-native>> -> write_non_finite(:neg_infinity, 32) + <<0x7FF0000000000000::64-native>> -> write_non_finite(:infinity, 32) + _ -> write_non_finite(:nan, 32) + end + end + end + + defp maybe_modify_binary(binary, {:f, 32}, {:f, 64}) do + for <>, into: <<>> do + case <> do + <> -> <> + <<0xFF800000::32-native>> -> write_non_finite(:neg_infinity, 64) + <<0x7F800000::32-native>> -> write_non_finite(:infinity, 64) + _ -> write_non_finite(:nan, 64) + end + end + end + + defp maybe_modify_binary(binary, {:u, size}, {:u, 8}) when size in [2, 4] do + for <>, into: <<>> do + <> + end + end + + defp maybe_modify_binary(binary, {:u, 8}, {:u, size}) when size in [2, 4] do + for <>, into: <<>> do + <> + end + end + + defp read_f8(<<0xFC::8-native>>), do: :neg_infinity + defp read_f8(<<0x7C::8-native>>), do: :infinity + defp read_f8(<<_sign::1, 31::5, mantissa::2>>) when mantissa != 0, do: :nan + + defp read_f8(<>) do + float = :math.pow(2, exp - 15) * (1 + mantissa / 4) + + case sign do + 0 -> float + _ -> -float + end + end + + if System.endianness() == :little do + def write_finite_f8(x) do + binary_part(<>, 0, 1) + end + else + def write_finite_f8(x) do + binary_part(<>, 1, 1) + end + end + + for size <- [8, 16, 32, 64] do + def write_non_finite(data, unquote(size)) do + case data do + :infinity -> unquote(Nx.Type.infinity_binary({:f, size})) + :neg_infinity -> unquote(Nx.Type.neg_infinity_binary({:f, size})) + :nan -> unquote(Nx.Type.nan_binary({:f, size})) + end + end + end + @impl true def slice( out, @@ -202,18 +302,25 @@ defmodule EMLX.Backend do defp needs_type_conversion?({:u, 8}, :bool), do: true defp needs_type_conversion?(_, _), do: false + defp to_mlx_type({:u, 2}), do: :uint8 + defp to_mlx_type({:u, 4}), do: :uint8 defp to_mlx_type({:u, 8}), do: :uint8 defp to_mlx_type({:u, 16}), do: :uint16 defp to_mlx_type({:u, 32}), do: :uint32 defp to_mlx_type({:u, 64}), do: :uint64 + defp to_mlx_type({:s, 2}), do: :int8 + defp to_mlx_type({:s, 4}), do: :int8 defp to_mlx_type({:s, 8}), do: :int8 defp to_mlx_type({:s, 16}), do: :int16 defp to_mlx_type({:s, 32}), do: :int32 defp to_mlx_type({:s, 64}), do: :int64 + defp to_mlx_type({:f, 8}), do: :float16 defp to_mlx_type({:f, 16}), do: :float16 defp to_mlx_type({:f, 32}), do: :float32 + defp to_mlx_type({:f, 64}), do: :float32 defp to_mlx_type({:bf, 16}), do: :bfloat16 defp to_mlx_type({:c, 64}), do: :complex64 + defp to_mlx_type({:c, 128}), do: :complex64 defp to_mlx_type(:bool), do: :bool defp to_nx_type(:uint8), do: {:u, 8} @@ -252,16 +359,13 @@ defmodule EMLX.Backend do {{:u, 8}, {:u, qint}} when qint in [2, 4] -> :ok - {{:s, 32}, {:u, 16}} -> - :ok - - {{:s, 64}, {:u, 32}} -> + {{:f, 16}, {:f, 8}} -> :ok - {{:s, 64}, {:u, 64}} -> + {{:f, 32}, {:f, 64}} -> :ok - {{:u, 8}, {:u, 32}} -> + {{:c, 64}, {:c, 128}} -> :ok _ when actual_type != expected_type -> diff --git a/test/emlx/nx_doctest_test.exs b/test/emlx/nx_doctest_test.exs index fd055f7..889d960 100644 --- a/test/emlx/nx_doctest_test.exs +++ b/test/emlx/nx_doctest_test.exs @@ -25,7 +25,11 @@ defmodule EMLX.Nx.DoctestTest do cosh: 1, log10: 1, acos: 1, - covariance: 3 + covariance: 3, + # These fail because we're using different representation types + atan2: 2, + as_type: 2, + from_binary: 3 ] @to_be_fixed [ @@ -51,18 +55,6 @@ defmodule EMLX.Nx.DoctestTest do ] @not_supported [ - # Does not support f8 (yet?) - tensor: 2, - # Does not support u2 (yet?) - bit_size: 1, - # f64 not supported - from_binary: 3, - # f64 not supported - iota: 2, - # f64 not supported - atan2: 2, - # f64 not supported - as_type: 2, reduce: 4, window_reduce: 5, population_count: 1,