Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add quantized int types #1528

Merged
merged 6 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,17 +505,13 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a
return exla::nif::error(env, "Bad argument count.");
}

ErlNifBinary bin;
xla::Shape shape;
exla::ExlaClient** client;
int device_id;

if (!exla::nif::get<exla::ExlaClient*>(env, argv[0], client)) {
return exla::nif::error(env, "Unable to get client.");
}
if (!exla::nif::get_binary(env, argv[1], &bin)) {
return exla::nif::error(env, "Unable to get data.");
}
if (!exla::nif::get_typespec_as_xla_shape(env, argv[2], &shape)) {
return exla::nif::error(env, "Unable to get shape.");
}
Expand Down
5 changes: 4 additions & 1 deletion exla/c_src/exla/exla_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@ xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> PjRtBufferFromBinary(xla::PjRtCl
std::function<void()> on_done_with_host_buffer = [copy_env]() { enif_free_env(copy_env); };

EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client->LookupDevice(xla::PjRtGlobalDeviceId(device_id)));
// Passing std::nullopt should work, but it fails for subbyte types,
// so we build the default strides. See https://github.com/openxla/xla/issues/16795
auto byte_strides = xla::ShapeUtil::ByteStrides(shape);
EXLA_ASSIGN_OR_RETURN(auto buffer, client->BufferFromHostBuffer(
binary.data, shape.element_type(), shape.dimensions(), std::nullopt, semantics, on_done_with_host_buffer, device));
binary.data, shape.element_type(), shape.dimensions(), byte_strides, semantics, on_done_with_host_buffer, device));

return std::move(buffer);
}
Expand Down
2 changes: 2 additions & 0 deletions exla/lib/exla/backend.ex
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ defmodule EXLA.Backend do

@impl true
def to_binary(%T{data: %B{buffer: buffer}, type: {_, bits}}, limit) do
  # Subbyte (u2/u4/s2/s4) elements are stored one per byte on the
  # device, so never read fewer than 8 bits per element
  bytes_per_element = bits |> max(8) |> div(8)
  EXLA.DeviceBuffer.read(buffer, limit * bytes_per_element)
end

Expand Down
42 changes: 39 additions & 3 deletions exla/lib/exla/device_buffer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,30 @@ defmodule EXLA.DeviceBuffer do
Places the given binary `data` on the given `device` using `client`.
"""
def place_on_device(data, %EXLA.Typespec{} = typespec, client = %Client{}, device_id)
when is_integer(device_id) and is_binary(data) do
when is_integer(device_id) and is_bitstring(data) do
# # Pad
# data =
# if is_binary(data) do
# data
# else
# remaining = byte_size(data) * 8 - bit_size(data)
# <<data::bitstring, 0::size(remaining)>>
# end

josevalim marked this conversation as resolved.
Show resolved Hide resolved
# At the moment XLA does not support allocating a packed buffer,
# so we unpack subbyte elements into their own bytes
data =
case typespec.type do
{:u, size} when size in [2, 4] ->
for <<x::native-size(size) <- data>>, into: <<>>, do: <<x::native-8>>

{:s, size} when size in [2, 4] ->
for <<x::native-signed-size(size) <- data>>, into: <<>>, do: <<x::native-signed-8>>

_ ->
data
end

ref =
client.ref
|> EXLA.NIF.binary_to_device_mem(data, EXLA.Typespec.nif_encode(typespec), device_id)
Expand Down Expand Up @@ -47,8 +70,21 @@ defmodule EXLA.DeviceBuffer do
without destroying it. If `size` is negative, then it
reads the whole buffer.
"""
def read(%DeviceBuffer{ref: ref}, size \\ -1) do
EXLA.NIF.read_device_mem(ref, size) |> unwrap!()
def read(%DeviceBuffer{ref: ref, typespec: %{type: type}}, size \\ -1) do
  raw = unwrap!(EXLA.NIF.read_device_mem(ref, size))

  # At the moment XLA does not support reading a packed buffer, so
  # each subbyte element arrives widened to a full byte; repack the
  # u2/u4/s2/s4 elements into their native subbyte width here
  case type do
    {:u, width} when width in [2, 4] ->
      for <<value::native-8 <- raw>>, into: <<>>, do: <<value::native-size(width)>>

    {:s, width} when width in [2, 4] ->
      for <<value::native-signed-8 <- raw>>, into: <<>>, do: <<value::native-signed-size(width)>>

    _ ->
      raw
  end
end

@doc """
Expand Down
4 changes: 4 additions & 0 deletions exla/lib/exla/typespec.ex
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,14 @@ defmodule EXLA.Typespec do
type_to_charlist = %{
:token => ~c"token",
{:pred, 8} => ~c"pred",
{:s, 2} => ~c"s2",
{:s, 4} => ~c"s4",
{:s, 8} => ~c"s8",
{:s, 16} => ~c"s16",
{:s, 32} => ~c"s32",
{:s, 64} => ~c"s64",
{:u, 2} => ~c"u2",
{:u, 4} => ~c"u4",
{:u, 8} => ~c"u8",
{:u, 16} => ~c"u16",
{:u, 32} => ~c"u32",
Expand Down
72 changes: 72 additions & 0 deletions exla/test/exla/backend_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -197,4 +197,76 @@ defmodule EXLA.BackendTest do
assert inspect(Nx.conjugate(~VEC[1 2-0i 3+0i 0-i 0-2i])) =~
"1.0-0.0i, 2.0+0.0i, 3.0-0.0i, 0.0+1.0i, 0.0+2.0i"
end

describe "quantized types" do
  # Each test exercises the subbyte round-trip: scalar to packed
  # binary, list to packed binary, list round-trip, and the
  # byte/bit sizing rules (Nx.byte_size/1 rounds down when the
  # total bit count is not divisible by 8).
  test "s2" do
    tensor = Nx.s2(-1)
    assert <<-1::2-signed-native>> = Nx.to_binary(tensor)

    tensor = Nx.s2([-2, -1, 1])
    assert tensor.type == {:s, 2}

    # Three 2-bit elements pack into a 6-bit bitstring
    assert <<-2::2-signed-native, -1::2-signed-native, 1::2-signed-native>> =
             Nx.to_binary(tensor)

    assert [-2, -1, 1] = Nx.to_flat_list(tensor)
    # 6 bits is less than a byte, so byte_size rounds down to 0
    assert 0 = Nx.byte_size(tensor)
    assert 6 = Nx.bit_size(tensor)

    # Seven 2-bit elements: 14 bits, byte_size rounds down to 1
    tensor = Nx.s2([-2, -1, 0, 1, 0, -1, -2])
    assert 1 = Nx.byte_size(tensor)
    assert 14 = Nx.bit_size(tensor)
  end

  test "s4" do
    tensor = Nx.s4(-1)
    assert <<-1::4-signed-native>> = Nx.to_binary(tensor)

    # -8 and 7 are the extremes of the signed 4-bit range
    tensor = Nx.s4([-8, -1, 7])
    assert tensor.type == {:s, 4}

    assert <<-8::4-signed-native, -1::4-signed-native, 7::4-signed-native>> =
             Nx.to_binary(tensor)

    assert [-8, -1, 7] = Nx.to_flat_list(tensor)
    # 12 bits: byte_size rounds down to 1
    assert 1 = Nx.byte_size(tensor)
    assert 12 = Nx.bit_size(tensor)

    # Seven 4-bit elements: 28 bits, byte_size rounds down to 3
    tensor = Nx.s4([-8, -3, 0, 7, 0, -3, -8])
    assert 3 = Nx.byte_size(tensor)
    assert 28 = Nx.bit_size(tensor)
  end

  test "u2" do
    tensor = Nx.u2(1)
    assert <<1::2-native>> = Nx.to_binary(tensor)

    tensor = Nx.u2([1, 2, 3])
    assert tensor.type == {:u, 2}
    assert <<1::2-native, 2::2-native, 3::2-native>> = Nx.to_binary(tensor)
    assert [1, 2, 3] = Nx.to_flat_list(tensor)
    # 6 bits: byte_size rounds down to 0
    assert 0 = Nx.byte_size(tensor)
    assert 6 = Nx.bit_size(tensor)

    # Seven 2-bit elements: 14 bits, byte_size rounds down to 1
    tensor = Nx.u2([0, 1, 2, 3, 2, 1, 0])
    assert 1 = Nx.byte_size(tensor)
    assert 14 = Nx.bit_size(tensor)
  end

  test "u4" do
    tensor = Nx.u4(1)
    assert <<1::4-native>> = Nx.to_binary(tensor)

    # 0 and 15 are the extremes of the unsigned 4-bit range
    tensor = Nx.u4([0, 7, 15])
    assert tensor.type == {:u, 4}
    assert <<0::4-native, 7::4-native, 15::4-native>> = Nx.to_binary(tensor)
    assert [0, 7, 15] = Nx.to_flat_list(tensor)
    # 12 bits: byte_size rounds down to 1
    assert 1 = Nx.byte_size(tensor)
    assert 12 = Nx.bit_size(tensor)

    # Seven 4-bit elements: 28 bits, byte_size rounds down to 3
    tensor = Nx.u4([0, 1, 2, 3, 13, 14, 15])
    assert 3 = Nx.byte_size(tensor)
    assert 28 = Nx.bit_size(tensor)
  end
end
end
2 changes: 1 addition & 1 deletion nx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

Nx is a multi-dimensional tensors library for Elixir with multi-staged compilation to the CPU/GPU. Its high-level features are:

* Typed multi-dimensional tensors, where the tensors can be unsigned integers (`u8`, `u16`, `u32`, `u64`), signed integers (`s8`, `s16`, `s32`, `s64`), floats (`f16`, `f32`, `f64`), brain floats (`bf16`), and complex numbers (`c64`, `c128`);
* Typed multi-dimensional tensors, where the tensors can be unsigned integers (`u2`, `u4`, `u8`, `u16`, `u32`, `u64`), signed integers (`s2`, `s4`, `s8`, `s16`, `s32`, `s64`), floats (`f16`, `f32`, `f64`), brain floats (`bf16`), and complex numbers (`c64`, `c128`);

* Named tensors, allowing developers to give names to each dimension, leading to more readable and less error prone codebases;

Expand Down
4 changes: 2 additions & 2 deletions nx/guides/intro-to-nx.livemd
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ libraries that support those tensors. Nx has three primary capabilities:
such as machine learning, simulations, curve fitting, and probabilistic models.

Here's more about each of those capabilities. Nx [tensors]() can hold
unsigned integers (u8, u16, u32, u64),
signed integers (s8, s16, s32, s64),
unsigned integers (u2, u4, u8, u16, u32, u64),
signed integers (s2, s4, s8, s16, s32, s64),
floats (f32, f64), brain floats (bf16), and complex (c64, c128).
Tensors support backends implemented outside of Elixir, including Google's
Accelerated Linear Algebra (XLA) and LibTorch.
Expand Down
70 changes: 55 additions & 15 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ defmodule Nx do

The tensor types can be one of:

* unsigned integers (`u8`, `u16`, `u32`, `u64`)
* signed integers (`s8`, `s16`, `s32`, `s64`)
* unsigned integers (`u2`, `u4`, `u8`, `u16`, `u32`, `u64`)
* signed integers (`s2`, `s4`, `s8`, `s16`, `s32`, `s64`)
* floats (`f8`, `f16`, `f32`, `f64`)
* brain floats (`bf16`)
* and complex numbers (`c64`, `c128`)
Expand Down Expand Up @@ -431,6 +431,7 @@ defmodule Nx do

import Nx.Shared
import Nx.Defn.Kernel, only: [keyword!: 2]
import Kernel, except: [bit_size: 1]

alias Nx.Tensor, as: T

Expand Down Expand Up @@ -855,7 +856,7 @@ defmodule Nx do
{dimensions, acc} = flatten_list(list, type, [], [])

{dimensions |> Enum.reverse() |> List.to_tuple(),
acc |> Enum.reverse() |> :erlang.list_to_binary()}
acc |> Enum.reverse() |> :erlang.list_to_bitstring()}
end

defp flatten_list([], _type, dimensions, acc) do
Expand Down Expand Up @@ -940,7 +941,9 @@ defmodule Nx do
%T{shape: shape, type: type, names: names, data: %Nx.TemplateBackend{}}
end

for t <- [:u8, :u16, :u32, :u64, :s8, :s16, :s32, :s64, :bf16, :f8, :f16, :f32, :f64] do
for t <-
[:u2, :u4, :u8, :u16, :u32, :u64, :s2, :s4, :s8, :s16, :s32, :s64] ++
[:f8, :bf16, :f16, :f32, :f64] do
@doc """
Short-hand function for creating tensor of type `#{t}`.

Expand Down Expand Up @@ -1971,13 +1974,13 @@ defmodule Nx do
def from_binary(binary, type, opts \\ []) when is_binary(binary) do
opts = keyword!(opts, [:backend])
{_, size} = type = Nx.Type.normalize!(type)
dim = div(bit_size(binary), size)
dim = div(Kernel.bit_size(binary), size)

if binary == "" do
raise ArgumentError, "cannot build an empty tensor"
end

if rem(bit_size(binary), size) != 0 do
if rem(Kernel.bit_size(binary), size) != 0 do
raise ArgumentError, "binary does not match the given size"
end

Expand All @@ -1990,17 +1993,26 @@ defmodule Nx do
@doc """
Returns the underlying tensor as a binary.

**Warning**: converting a tensor to a binary can
potentially be a very expensive operation, as it
may copy a GPU tensor fully to the machine memory.

It returns the in-memory binary representation of
the tensor in a row-major fashion. The binary is
in the system endianness, which has to be taken into
account if the binary is meant to be serialized to
other systems.

Note: This function cannot be used in `defn`.
This function cannot be used in `defn`.

> ### Potentially expensive operation {: .warning}
>
> Converting a tensor to a binary can potentially be a very
> expensive operation, as it may copy a GPU tensor fully to
> the machine memory.

> ### Binaries vs bitstrings {: .info}
>
> If a tensor of type u2/u4/s2/s4 is given to this function,
> this function may not return a binary (where the number of bits
> is divisible by 8) but rather a bitstring (where the number of
> bits may not be divisible by 8).

## Options

Expand Down Expand Up @@ -4286,6 +4298,10 @@ defmodule Nx do
Returns the byte size of the data in the tensor
computed from its shape and type.

If the tensor has s2/s4/u2/u4 types, the value
will be rounded down. Consider using `bit_size/1`
instead.

## Examples

iex> Nx.byte_size(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
Expand All @@ -4304,9 +4320,33 @@ defmodule Nx do

"""
@doc type: :shape
def byte_size(tensor) do
def byte_size(tensor), do: div(bit_size(tensor), 8)

@doc """
Returns the bit size of the data in the tensor
computed from its shape and type.

## Examples

iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
192
iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [4, 5, 6]], type: :u8))
48
iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [3, 2, 1]], type: :u2))
12
iex> Nx.bit_size(1)
32

Vectorized tensors account for all elements

iex> Nx.bit_size(Nx.tensor([[1, 2], [3, 4]]) |> Nx.vectorize(:x))
128

"""
@doc type: :shape
def bit_size(tensor) do
%{type: {_, bit_size}} = tensor = to_tensor(tensor)
flat_size(tensor) * div(bit_size, 8)
flat_size(tensor) * bit_size
end

@doc """
Expand Down Expand Up @@ -15466,9 +15506,9 @@ defmodule Nx do
defp do_numpy_to_tensor(rest, header_size) when is_binary(rest) do
<<header::size(header_size)-binary, array::binary>> = rest
{byte_order, {_, size} = type, shape, fortran_order?} = parse_header(header)
byte_size_of_array = div(size, 8) * Nx.size(shape)
bit_size_of_array = size * Nx.size(shape)

<<data::size(byte_size_of_array)-binary>> = array
<<data::size(bit_size_of_array)-bitstring>> = array

data
|> new_byte_order(size, byte_order)
Expand Down
Loading
Loading