diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 3503344a20..c91b9fb621 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -505,7 +505,6 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a return exla::nif::error(env, "Bad argument count."); } - ErlNifBinary bin; xla::Shape shape; exla::ExlaClient** client; int device_id; @@ -513,9 +512,6 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a if (!exla::nif::get(env, argv[0], client)) { return exla::nif::error(env, "Unable to get client."); } - if (!exla::nif::get_binary(env, argv[1], &bin)) { - return exla::nif::error(env, "Unable to get data."); - } if (!exla::nif::get_typespec_as_xla_shape(env, argv[2], &shape)) { return exla::nif::error(env, "Unable to get shape."); } diff --git a/exla/lib/exla/device_buffer.ex b/exla/lib/exla/device_buffer.ex index 9be97d4fab..3a070fc94a 100644 --- a/exla/lib/exla/device_buffer.ex +++ b/exla/lib/exla/device_buffer.ex @@ -18,10 +18,30 @@ defmodule EXLA.DeviceBuffer do Places the given binary `data` on the given `device` using `client`. """ def place_on_device(data, %EXLA.Typespec{} = typespec, client = %Client{}, device_id) - when is_integer(device_id) and is_binary(data) do + when is_integer(device_id) and is_bitstring(data) do + padded = + if is_binary(data) do + data + else + remaining = byte_size(data) * 8 - bit_size(data) + <> + end + + # padded = + # case typespec.type do + # {:u, size} when size in [2, 4] -> + # for <>, into: <<>>, do: <> + + # {:s, size} when size in [2, 4] -> + # for <>, into: <<>>, do: <> + + # _ -> + # data + # end + ref = client.ref - |> EXLA.NIF.binary_to_device_mem(data, EXLA.Typespec.nif_encode(typespec), device_id) + |> EXLA.NIF.binary_to_device_mem(padded, EXLA.Typespec.nif_encode(typespec), device_id) |> unwrap!() %DeviceBuffer{ref: ref, client_name: client.name, device_id: device_id, typespec: typespec} diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index edcd55c52f..8da166312f 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -197,4 +197,64 @@ defmodule EXLA.BackendTest do assert inspect(Nx.conjugate(~VEC[1 2-0i 3+0i 0-i 0-2i])) =~ "1.0-0.0i, 2.0+0.0i, 3.0-0.0i, 0.0+1.0i, 0.0+2.0i" end + + describe "quantized types" do + test "s2" do + tensor = Nx.s2([-2, -1, 1]) + assert tensor.type == {:s, 2} + + assert <<-2::2-signed-native, -1::2-signed-native, 1::2-signed-native>> = + Nx.to_binary(tensor) + + assert [-2, -1, 1] = Nx.to_flat_list(tensor) + assert 0 = Nx.byte_size(tensor) + assert 6 = Nx.bit_size(tensor) + + tensor = Nx.s2([-2, -1, 0, 1, 0, -1, -2]) + assert 1 = Nx.byte_size(tensor) + assert 14 = Nx.bit_size(tensor) + end + + test "s4" do + tensor = Nx.s4([-8, -1, 7]) + assert tensor.type == {:s, 4} + + assert <<-8::4-signed-native, -1::4-signed-native, 7::4-signed-native>> = + Nx.to_binary(tensor) + + assert [-8, -1, 7] = Nx.to_flat_list(tensor) + assert 1 = Nx.byte_size(tensor) + assert 12 = Nx.bit_size(tensor) + + tensor = Nx.s4([-8, -3, 0, 7, 0, -3, -8]) + assert 3 = Nx.byte_size(tensor) + assert 28 = Nx.bit_size(tensor) + end + + test "u2" do + tensor = Nx.u2([1, 2, 3]) + assert tensor.type == {:u, 2} + assert <<1::2-native, 2::2-native, 3::2-native>> = Nx.to_binary(tensor) + assert [1, 2, 3] = Nx.to_flat_list(tensor) + assert 0 = Nx.byte_size(tensor) + assert 6 = Nx.bit_size(tensor) + + tensor = Nx.u2([0, 1, 2, 3, 2, 1, 0]) + assert 1 = Nx.byte_size(tensor) + assert 14 = Nx.bit_size(tensor) + end + + test "u4" do + tensor = Nx.u4([0, 7, 15]) + assert tensor.type == {:u, 4} + assert <<0::4-native, 7::4-native, 15::4-native>> = Nx.to_binary(tensor) + assert [0, 7, 15] = Nx.to_flat_list(tensor) + assert 1 = Nx.byte_size(tensor) + assert 12 = Nx.bit_size(tensor) + + tensor = Nx.u4([0, 1, 2, 3, 13, 14, 15]) + assert 3 = Nx.byte_size(tensor) + assert 28 = Nx.bit_size(tensor) + end + end end