Skip to content

Commit

Permalink
Add quantized int types (#1528)
Browse files Browse the repository at this point in the history
Co-authored-by: Jonatan Kłosko <[email protected]>
  • Loading branch information
josevalim and jonatanklosko authored Sep 11, 2024
1 parent ad28ea7 commit 8a9c2b3
Show file tree
Hide file tree
Showing 13 changed files with 324 additions and 79 deletions.
4 changes: 0 additions & 4 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,17 +505,13 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a
return exla::nif::error(env, "Bad argument count.");
}

ErlNifBinary bin;
xla::Shape shape;
exla::ExlaClient** client;
int device_id;

if (!exla::nif::get<exla::ExlaClient*>(env, argv[0], client)) {
return exla::nif::error(env, "Unable to get client.");
}
if (!exla::nif::get_binary(env, argv[1], &bin)) {
return exla::nif::error(env, "Unable to get data.");
}
if (!exla::nif::get_typespec_as_xla_shape(env, argv[2], &shape)) {
return exla::nif::error(env, "Unable to get shape.");
}
Expand Down
5 changes: 4 additions & 1 deletion exla/c_src/exla/exla_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@ xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> PjRtBufferFromBinary(xla::PjRtCl
std::function<void()> on_done_with_host_buffer = [copy_env]() { enif_free_env(copy_env); };

EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client->LookupDevice(xla::PjRtGlobalDeviceId(device_id)));
// Passing std::nullopt should work, but it fails for subbyte types,
// so we build the default strides. See https://github.com/openxla/xla/issues/16795
auto byte_strides = xla::ShapeUtil::ByteStrides(shape);
EXLA_ASSIGN_OR_RETURN(auto buffer, client->BufferFromHostBuffer(
binary.data, shape.element_type(), shape.dimensions(), std::nullopt, semantics, on_done_with_host_buffer, device));
binary.data, shape.element_type(), shape.dimensions(), byte_strides, semantics, on_done_with_host_buffer, device));

return std::move(buffer);
}
Expand Down
2 changes: 2 additions & 0 deletions exla/lib/exla/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ defmodule EXLA.Backend do

@impl true
def to_binary(%T{data: %B{buffer: buffer}, type: {_, bits}}, limit) do
# Subbyte elements (u2/u4/s2/s4) are stored one-per-byte on the device,
# so we never read fewer than 8 bits per element.
bytes_per_element = bits |> max(8) |> div(8)
EXLA.DeviceBuffer.read(buffer, limit * bytes_per_element)
end

Expand Down
33 changes: 30 additions & 3 deletions exla/lib/exla/device_buffer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,21 @@ defmodule EXLA.DeviceBuffer do
Places the given binary `data` on the given `device` using `client`.
"""
def place_on_device(data, %EXLA.Typespec{} = typespec, client = %Client{}, device_id)
when is_integer(device_id) and is_binary(data) do
when is_integer(device_id) and is_bitstring(data) do
# At the moment XLA does not support allocating a packed buffer,
# so we unpack subbyte elements into their own bytes
data =
case typespec.type do
{:u, size} when size in [2, 4] ->
for <<x::native-size(size) <- data>>, into: <<>>, do: <<x::native-8>>

{:s, size} when size in [2, 4] ->
for <<x::native-signed-size(size) <- data>>, into: <<>>, do: <<x::native-signed-8>>

_ ->
data
end

ref =
client.ref
|> EXLA.NIF.binary_to_device_mem(data, EXLA.Typespec.nif_encode(typespec), device_id)
Expand Down Expand Up @@ -47,8 +61,21 @@ defmodule EXLA.DeviceBuffer do
without destroying it. If `size` is negative, then it
reads the whole buffer.
"""
def read(%DeviceBuffer{ref: ref}, size \\ -1) do
EXLA.NIF.read_device_mem(ref, size) |> unwrap!()
def read(%DeviceBuffer{ref: ref, typespec: typespec}, size \\ -1) do
bytes = unwrap!(EXLA.NIF.read_device_mem(ref, size))

# At the moment XLA does not support reading a packed buffer, so each
# u2/u4/s2/s4 element arrives as a full byte and we re-pack it ourselves.
case typespec.type do
  {:u, bits} when bits in [2, 4] ->
    for <<byte::native-8 <- bytes>>, into: <<>>, do: <<byte::native-size(bits)>>

  {:s, bits} when bits in [2, 4] ->
    for <<byte::native-signed-8 <- bytes>>, into: <<>>, do: <<byte::native-signed-size(bits)>>

  _ ->
    bytes
end
end

@doc """
Expand Down
4 changes: 4 additions & 0 deletions exla/lib/exla/typespec.ex
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,14 @@ defmodule EXLA.Typespec do
type_to_charlist = %{
:token => ~c"token",
{:pred, 8} => ~c"pred",
{:s, 2} => ~c"s2",
{:s, 4} => ~c"s4",
{:s, 8} => ~c"s8",
{:s, 16} => ~c"s16",
{:s, 32} => ~c"s32",
{:s, 64} => ~c"s64",
{:u, 2} => ~c"u2",
{:u, 4} => ~c"u4",
{:u, 8} => ~c"u8",
{:u, 16} => ~c"u16",
{:u, 32} => ~c"u32",
Expand Down
72 changes: 72 additions & 0 deletions exla/test/exla/backend_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -197,4 +197,76 @@ defmodule EXLA.BackendTest do
assert inspect(Nx.conjugate(~VEC[1 2-0i 3+0i 0-i 0-2i])) =~
"1.0-0.0i, 2.0+0.0i, 3.0-0.0i, 0.0+1.0i, 0.0+2.0i"
end

describe "quantized types" do
# Each test checks three things for a subbyte type:
#   1. to_binary/1 packs elements at their true bit width (native endian)
#   2. byte_size/1 rounds a non-byte-aligned total DOWN to whole bytes
#   3. bit_size/1 reports the exact element_count * width total
test "s2" do
tensor = Nx.s2(-1)
assert <<-1::2-signed-native>> = Nx.to_binary(tensor)

tensor = Nx.s2([-2, -1, 1])
assert tensor.type == {:s, 2}

assert <<-2::2-signed-native, -1::2-signed-native, 1::2-signed-native>> =
Nx.to_binary(tensor)

assert [-2, -1, 1] = Nx.to_flat_list(tensor)
# 3 elements * 2 bits = 6 bits, i.e. less than one full byte
assert 0 = Nx.byte_size(tensor)
assert 6 = Nx.bit_size(tensor)

# 7 elements * 2 bits = 14 bits -> byte_size truncates to 1
tensor = Nx.s2([-2, -1, 0, 1, 0, -1, -2])
assert 1 = Nx.byte_size(tensor)
assert 14 = Nx.bit_size(tensor)
end

test "s4" do
tensor = Nx.s4(-1)
assert <<-1::4-signed-native>> = Nx.to_binary(tensor)

# -8 and 7 are the extremes of the s4 range
tensor = Nx.s4([-8, -1, 7])
assert tensor.type == {:s, 4}

assert <<-8::4-signed-native, -1::4-signed-native, 7::4-signed-native>> =
Nx.to_binary(tensor)

assert [-8, -1, 7] = Nx.to_flat_list(tensor)
assert 1 = Nx.byte_size(tensor)
assert 12 = Nx.bit_size(tensor)

# 7 elements * 4 bits = 28 bits -> byte_size truncates to 3
tensor = Nx.s4([-8, -3, 0, 7, 0, -3, -8])
assert 3 = Nx.byte_size(tensor)
assert 28 = Nx.bit_size(tensor)
end

test "u2" do
tensor = Nx.u2(1)
assert <<1::2-native>> = Nx.to_binary(tensor)

tensor = Nx.u2([1, 2, 3])
assert tensor.type == {:u, 2}
assert <<1::2-native, 2::2-native, 3::2-native>> = Nx.to_binary(tensor)
assert [1, 2, 3] = Nx.to_flat_list(tensor)
assert 0 = Nx.byte_size(tensor)
assert 6 = Nx.bit_size(tensor)

# 7 elements * 2 bits = 14 bits -> byte_size truncates to 1
tensor = Nx.u2([0, 1, 2, 3, 2, 1, 0])
assert 1 = Nx.byte_size(tensor)
assert 14 = Nx.bit_size(tensor)
end

test "u4" do
tensor = Nx.u4(1)
assert <<1::4-native>> = Nx.to_binary(tensor)

# 0 and 15 are the extremes of the u4 range
tensor = Nx.u4([0, 7, 15])
assert tensor.type == {:u, 4}
assert <<0::4-native, 7::4-native, 15::4-native>> = Nx.to_binary(tensor)
assert [0, 7, 15] = Nx.to_flat_list(tensor)
assert 1 = Nx.byte_size(tensor)
assert 12 = Nx.bit_size(tensor)

# 7 elements * 4 bits = 28 bits -> byte_size truncates to 3
tensor = Nx.u4([0, 1, 2, 3, 13, 14, 15])
assert 3 = Nx.byte_size(tensor)
assert 28 = Nx.bit_size(tensor)
end
end
end
2 changes: 1 addition & 1 deletion nx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

Nx is a multi-dimensional tensors library for Elixir with multi-staged compilation to the CPU/GPU. Its high-level features are:

* Typed multi-dimensional tensors, where the tensors can be unsigned integers (`u8`, `u16`, `u32`, `u64`), signed integers (`s8`, `s16`, `s32`, `s64`), floats (`f16`, `f32`, `f64`), brain floats (`bf16`), and complex numbers (`c64`, `c128`);
* Typed multi-dimensional tensors, where the tensors can be unsigned integers (`u2`, `u4`, `u8`, `u16`, `u32`, `u64`), signed integers (`s2`, `s4`, `s8`, `s16`, `s32`, `s64`), floats (`f16`, `f32`, `f64`), brain floats (`bf16`), and complex numbers (`c64`, `c128`);

* Named tensors, allowing developers to give names to each dimension, leading to more readable and less error prone codebases;

Expand Down
4 changes: 2 additions & 2 deletions nx/guides/intro-to-nx.livemd
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ libraries that support those tensors. Nx has three primary capabilities:
such as machine learning, simulations, curve fitting, and probabilistic models.

Here's more about each of those capabilities. Nx [tensors]() can hold
unsigned integers (u8, u16, u32, u64),
signed integers (s8, s16, s32, s64),
unsigned integers (u2, u4, u8, u16, u32, u64),
signed integers (s2, s4, s8, s16, s32, s64),
floats (f32, f64), brain floats (bf16), and complex (c64, c128).
Tensors support backends implemented outside of Elixir, including Google's
Accelerated Linear Algebra (XLA) and LibTorch.
Expand Down
70 changes: 55 additions & 15 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ defmodule Nx do
The tensor types can be one of:
* unsigned integers (`u8`, `u16`, `u32`, `u64`)
* signed integers (`s8`, `s16`, `s32`, `s64`)
* unsigned integers (`u2`, `u4`, `u8`, `u16`, `u32`, `u64`)
* signed integers (`s2`, `s4`, `s8`, `s16`, `s32`, `s64`)
* floats (`f8`, `f16`, `f32`, `f64`)
* brain floats (`bf16`)
* and complex numbers (`c64`, `c128`)
Expand Down Expand Up @@ -431,6 +431,7 @@ defmodule Nx do

import Nx.Shared
import Nx.Defn.Kernel, only: [keyword!: 2]
import Kernel, except: [bit_size: 1]

alias Nx.Tensor, as: T

Expand Down Expand Up @@ -855,7 +856,7 @@ defmodule Nx do
{dimensions, acc} = flatten_list(list, type, [], [])

{dimensions |> Enum.reverse() |> List.to_tuple(),
acc |> Enum.reverse() |> :erlang.list_to_binary()}
acc |> Enum.reverse() |> :erlang.list_to_bitstring()}
end

defp flatten_list([], _type, dimensions, acc) do
Expand Down Expand Up @@ -940,7 +941,9 @@ defmodule Nx do
%T{shape: shape, type: type, names: names, data: %Nx.TemplateBackend{}}
end

for t <- [:u8, :u16, :u32, :u64, :s8, :s16, :s32, :s64, :bf16, :f8, :f16, :f32, :f64] do
for t <-
[:u2, :u4, :u8, :u16, :u32, :u64, :s2, :s4, :s8, :s16, :s32, :s64] ++
[:f8, :bf16, :f16, :f32, :f64] do
@doc """
Short-hand function for creating tensor of type `#{t}`.
Expand Down Expand Up @@ -1971,13 +1974,13 @@ defmodule Nx do
def from_binary(binary, type, opts \\ []) when is_binary(binary) do
opts = keyword!(opts, [:backend])
{_, size} = type = Nx.Type.normalize!(type)
dim = div(bit_size(binary), size)
dim = div(Kernel.bit_size(binary), size)

if binary == "" do
raise ArgumentError, "cannot build an empty tensor"
end

if rem(bit_size(binary), size) != 0 do
if rem(Kernel.bit_size(binary), size) != 0 do
raise ArgumentError, "binary does not match the given size"
end

Expand All @@ -1990,17 +1993,26 @@ defmodule Nx do
@doc """
Returns the underlying tensor as a binary.
**Warning**: converting a tensor to a binary can
potentially be a very expensive operation, as it
may copy a GPU tensor fully to the machine memory.
It returns the in-memory binary representation of
the tensor in a row-major fashion. The binary is
in the system endianness, which has to be taken into
account if the binary is meant to be serialized to
other systems.
Note: This function cannot be used in `defn`.
This function cannot be used in `defn`.
> ### Potentially expensive operation {: .warning}
>
> Converting a tensor to a binary can potentially be a very
> expensive operation, as it may copy a GPU tensor fully to
> the machine memory.
> ### Binaries vs bitstrings {: .info}
>
> If a tensor of type u2/u4/s2/s4 is given to this function,
> this function may not return a binary (where the number of bits
> is divisible by 8) but rather a bitstring (where the number of
> bits may not be divisible by 8).
## Options
Expand Down Expand Up @@ -4286,6 +4298,10 @@ defmodule Nx do
Returns the byte size of the data in the tensor
computed from its shape and type.
If the tensor has s2/s4/u2/u4 types, the value
will be rounded down. Consider using `bit_size/1`
instead.
## Examples
iex> Nx.byte_size(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
Expand All @@ -4304,9 +4320,33 @@ defmodule Nx do
"""
@doc type: :shape
# Whole bytes only: integer division drops any trailing subbyte bits.
def byte_size(tensor), do: tensor |> bit_size() |> div(8)

@doc """
Returns the bit size of the data in the tensor
computed from its shape and type.
## Examples
iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
192
iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [4, 5, 6]], type: :u8))
48
iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [3, 2, 1]], type: :u2))
12
iex> Nx.bit_size(1)
32
Vectorized tensors account for all elements
iex> Nx.bit_size(Nx.tensor([[1, 2], [3, 4]]) |> Nx.vectorize(:x))
128
"""
@doc type: :shape
def bit_size(tensor) do
# Total bits = element count (flat_size covers vectorized axes) * width.
%{type: {_, bits}} = tensor = to_tensor(tensor)
bits * flat_size(tensor)
end

@doc """
Expand Down Expand Up @@ -15466,9 +15506,9 @@ defmodule Nx do
defp do_numpy_to_tensor(rest, header_size) when is_binary(rest) do
<<header::size(header_size)-binary, array::binary>> = rest
{byte_order, {_, size} = type, shape, fortran_order?} = parse_header(header)
byte_size_of_array = div(size, 8) * Nx.size(shape)
bit_size_of_array = size * Nx.size(shape)

<<data::size(byte_size_of_array)-binary>> = array
<<data::size(bit_size_of_array)-bitstring>> = array

data
|> new_byte_order(size, byte_order)
Expand Down
Loading

0 comments on commit 8a9c2b3

Please sign in to comment.