Skip to content

Commit

Permalink
EXLA int types WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Sep 4, 2024
1 parent ca18083 commit 4e4e397
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 6 deletions.
4 changes: 0 additions & 4 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,17 +505,13 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a
return exla::nif::error(env, "Bad argument count.");
}

ErlNifBinary bin;
xla::Shape shape;
exla::ExlaClient** client;
int device_id;

if (!exla::nif::get<exla::ExlaClient*>(env, argv[0], client)) {
return exla::nif::error(env, "Unable to get client.");
}
if (!exla::nif::get_binary(env, argv[1], &bin)) {
return exla::nif::error(env, "Unable to get data.");
}
if (!exla::nif::get_typespec_as_xla_shape(env, argv[2], &shape)) {
return exla::nif::error(env, "Unable to get shape.");
}
Expand Down
24 changes: 22 additions & 2 deletions exla/lib/exla/device_buffer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,30 @@ defmodule EXLA.DeviceBuffer do
Places the given binary `data` on the given `device` using `client`.
"""
def place_on_device(data, %EXLA.Typespec{} = typespec, client = %Client{}, device_id)
when is_integer(device_id) and is_binary(data) do
when is_integer(device_id) and is_bitstring(data) do
padded =
if is_binary(data) do
data
else
remaining = byte_size(data) * 8 - bit_size(data)
<<data::bitstring, 0::size(remaining)>>
end

# padded =
# case typespec.type do
# {:u, size} when size in [2, 4] ->
# for <<x::native-size(size) <- data>>, into: <<>>, do: <<x::native-8>>

# {:s, size} when size in [2, 4] ->
# for <<x::native-signed-size(size) <- data>>, into: <<>>, do: <<x::native-signed-8>>

# _ ->
# data
# end

ref =
client.ref
|> EXLA.NIF.binary_to_device_mem(data, EXLA.Typespec.nif_encode(typespec), device_id)
|> EXLA.NIF.binary_to_device_mem(padded, EXLA.Typespec.nif_encode(typespec), device_id)
|> unwrap!()

%DeviceBuffer{ref: ref, client_name: client.name, device_id: device_id, typespec: typespec}
Expand Down
60 changes: 60 additions & 0 deletions exla/test/exla/backend_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -197,4 +197,64 @@ defmodule EXLA.BackendTest do
assert inspect(Nx.conjugate(~VEC[1 2-0i 3+0i 0-i 0-2i])) =~
"1.0-0.0i, 2.0+0.0i, 3.0-0.0i, 0.0+1.0i, 0.0+2.0i"
end

describe "quantized types" do
test "s2" do
tensor = Nx.s2([-2, -1, 1])
assert tensor.type == {:s, 2}

assert <<-2::2-signed-native, -1::2-signed-native, 1::2-signed-native>> =
Nx.to_binary(tensor)

assert [-2, -1, 1] = Nx.to_flat_list(tensor)
assert 0 = Nx.byte_size(tensor)
assert 6 = Nx.bit_size(tensor)

tensor = Nx.s2([-2, -1, 0, 1, 0, -1, -2])
assert 1 = Nx.byte_size(tensor)
assert 14 = Nx.bit_size(tensor)
end

test "s4" do
tensor = Nx.s4([-8, -1, 7])
assert tensor.type == {:s, 4}

assert <<-8::4-signed-native, -1::4-signed-native, 7::4-signed-native>> =
Nx.to_binary(tensor)

assert [-8, -1, 7] = Nx.to_flat_list(tensor)
assert 1 = Nx.byte_size(tensor)
assert 12 = Nx.bit_size(tensor)

tensor = Nx.s4([-8, -3, 0, 7, 0, -3, -8])
assert 3 = Nx.byte_size(tensor)
assert 28 = Nx.bit_size(tensor)
end

test "u2" do
tensor = Nx.u2([1, 2, 3])
assert tensor.type == {:u, 2}
assert <<1::2-native, 2::2-native, 3::2-native>> = Nx.to_binary(tensor)
assert [1, 2, 3] = Nx.to_flat_list(tensor)
assert 0 = Nx.byte_size(tensor)
assert 6 = Nx.bit_size(tensor)

tensor = Nx.u2([0, 1, 2, 3, 2, 1, 0])
assert 1 = Nx.byte_size(tensor)
assert 14 = Nx.bit_size(tensor)
end

test "u4" do
tensor = Nx.u4([0, 7, 15])
assert tensor.type == {:u, 4}
assert <<0::4-native, 7::4-native, 15::4-native>> = Nx.to_binary(tensor)
assert [0, 7, 15] = Nx.to_flat_list(tensor)
assert 1 = Nx.byte_size(tensor)
assert 12 = Nx.bit_size(tensor)

tensor = Nx.u4([0, 1, 2, 3, 13, 14, 15])
assert 3 = Nx.byte_size(tensor)
assert 28 = Nx.bit_size(tensor)
end
end
end

0 comments on commit 4e4e397

Please sign in to comment.