diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 3503344a207..c91b9fb6213 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -505,7 +505,6 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a return exla::nif::error(env, "Bad argument count."); } - ErlNifBinary bin; xla::Shape shape; exla::ExlaClient** client; int device_id; @@ -513,9 +512,6 @@ ERL_NIF_TERM binary_to_device_mem(ErlNifEnv* env, int argc, const ERL_NIF_TERM a if (!exla::nif::get(env, argv[0], client)) { return exla::nif::error(env, "Unable to get client."); } - if (!exla::nif::get_binary(env, argv[1], &bin)) { - return exla::nif::error(env, "Unable to get data."); - } if (!exla::nif::get_typespec_as_xla_shape(env, argv[2], &shape)) { return exla::nif::error(env, "Unable to get shape."); } diff --git a/exla/c_src/exla/exla_client.cc b/exla/c_src/exla/exla_client.cc index 7e30bc76628..7e6af18a1a7 100644 --- a/exla/c_src/exla/exla_client.cc +++ b/exla/c_src/exla/exla_client.cc @@ -66,8 +66,11 @@ xla::StatusOr> PjRtBufferFromBinary(xla::PjRtCl std::function on_done_with_host_buffer = [copy_env]() { enif_free_env(copy_env); }; EXLA_ASSIGN_OR_RETURN(xla::PjRtDevice * device, client->LookupDevice(xla::PjRtGlobalDeviceId(device_id))); + // Passing std::nullopt should work, but it fails for subbyte types, + // so we build the default strides. 
See https://github.com/openxla/xla/issues/16795 + auto byte_strides = xla::ShapeUtil::ByteStrides(shape); EXLA_ASSIGN_OR_RETURN(auto buffer, client->BufferFromHostBuffer( - binary.data, shape.element_type(), shape.dimensions(), std::nullopt, semantics, on_done_with_host_buffer, device)); + binary.data, shape.element_type(), shape.dimensions(), byte_strides, semantics, on_done_with_host_buffer, device)); return std::move(buffer); } diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex index d6f8fc81b59..dc4b70460c5 100644 --- a/exla/lib/exla/backend.ex +++ b/exla/lib/exla/backend.ex @@ -187,6 +187,8 @@ defmodule EXLA.Backend do @impl true def to_binary(%T{data: %B{buffer: buffer}, type: {_, size}}, limit) do + # Subbyte elements are read as individual bytes + size = max(size, 8) EXLA.DeviceBuffer.read(buffer, limit * div(size, 8)) end diff --git a/exla/lib/exla/device_buffer.ex b/exla/lib/exla/device_buffer.ex index 9be97d4faba..d1753d76304 100644 --- a/exla/lib/exla/device_buffer.ex +++ b/exla/lib/exla/device_buffer.ex @@ -18,7 +18,21 @@ defmodule EXLA.DeviceBuffer do Places the given binary `data` on the given `device` using `client`. """ def place_on_device(data, %EXLA.Typespec{} = typespec, client = %Client{}, device_id) - when is_integer(device_id) and is_binary(data) do + when is_integer(device_id) and is_bitstring(data) do + # At the moment XLA does not support allocating a packed buffer, + # so we unpack subbyte elements into their own bytes + data = + case typespec.type do + {:u, size} when size in [2, 4] -> + for <>, into: <<>>, do: <> + + {:s, size} when size in [2, 4] -> + for <>, into: <<>>, do: <> + + _ -> + data + end + ref = client.ref |> EXLA.NIF.binary_to_device_mem(data, EXLA.Typespec.nif_encode(typespec), device_id) @@ -47,8 +61,21 @@ defmodule EXLA.DeviceBuffer do without destroying it. If `size` is negative, then it reads the whole buffer. 
""" - def read(%DeviceBuffer{ref: ref}, size \\ -1) do - EXLA.NIF.read_device_mem(ref, size) |> unwrap!() + def read(%DeviceBuffer{ref: ref, typespec: typespec}, size \\ -1) do + data = EXLA.NIF.read_device_mem(ref, size) |> unwrap!() + + # At the moment XLA does not support reading a packed buffer, + # so we pack the elements ourselves + case typespec.type do + {:u, size} when size in [2, 4] -> + for <>, into: <<>>, do: <> + + {:s, size} when size in [2, 4] -> + for <>, into: <<>>, do: <> + + _ -> + data + end end @doc """ diff --git a/exla/lib/exla/typespec.ex b/exla/lib/exla/typespec.ex index 0c56bf07f23..60166ef84d0 100644 --- a/exla/lib/exla/typespec.ex +++ b/exla/lib/exla/typespec.ex @@ -53,10 +53,14 @@ defmodule EXLA.Typespec do type_to_charlist = %{ :token => ~c"token", {:pred, 8} => ~c"pred", + {:s, 2} => ~c"s2", + {:s, 4} => ~c"s4", {:s, 8} => ~c"s8", {:s, 16} => ~c"s16", {:s, 32} => ~c"s32", {:s, 64} => ~c"s64", + {:u, 2} => ~c"u2", + {:u, 4} => ~c"u4", {:u, 8} => ~c"u8", {:u, 16} => ~c"u16", {:u, 32} => ~c"u32", diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index edcd55c52f3..22e3c60850e 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -197,4 +197,76 @@ defmodule EXLA.BackendTest do assert inspect(Nx.conjugate(~VEC[1 2-0i 3+0i 0-i 0-2i])) =~ "1.0-0.0i, 2.0+0.0i, 3.0-0.0i, 0.0+1.0i, 0.0+2.0i" end + + describe "quantized types" do + test "s2" do + tensor = Nx.s2(-1) + assert <<-1::2-signed-native>> = Nx.to_binary(tensor) + + tensor = Nx.s2([-2, -1, 1]) + assert tensor.type == {:s, 2} + + assert <<-2::2-signed-native, -1::2-signed-native, 1::2-signed-native>> = + Nx.to_binary(tensor) + + assert [-2, -1, 1] = Nx.to_flat_list(tensor) + assert 0 = Nx.byte_size(tensor) + assert 6 = Nx.bit_size(tensor) + + tensor = Nx.s2([-2, -1, 0, 1, 0, -1, -2]) + assert 1 = Nx.byte_size(tensor) + assert 14 = Nx.bit_size(tensor) + end + + test "s4" do + tensor = Nx.s4(-1) + assert <<-1::4-signed-native>> = 
Nx.to_binary(tensor) + + tensor = Nx.s4([-8, -1, 7]) + assert tensor.type == {:s, 4} + + assert <<-8::4-signed-native, -1::4-signed-native, 7::4-signed-native>> = + Nx.to_binary(tensor) + + assert [-8, -1, 7] = Nx.to_flat_list(tensor) + assert 1 = Nx.byte_size(tensor) + assert 12 = Nx.bit_size(tensor) + + tensor = Nx.s4([-8, -3, 0, 7, 0, -3, -8]) + assert 3 = Nx.byte_size(tensor) + assert 28 = Nx.bit_size(tensor) + end + + test "u2" do + tensor = Nx.u2(1) + assert <<1::2-native>> = Nx.to_binary(tensor) + + tensor = Nx.u2([1, 2, 3]) + assert tensor.type == {:u, 2} + assert <<1::2-native, 2::2-native, 3::2-native>> = Nx.to_binary(tensor) + assert [1, 2, 3] = Nx.to_flat_list(tensor) + assert 0 = Nx.byte_size(tensor) + assert 6 = Nx.bit_size(tensor) + + tensor = Nx.u2([0, 1, 2, 3, 2, 1, 0]) + assert 1 = Nx.byte_size(tensor) + assert 14 = Nx.bit_size(tensor) + end + + test "u4" do + tensor = Nx.u4(1) + assert <<1::4-native>> = Nx.to_binary(tensor) + + tensor = Nx.u4([0, 7, 15]) + assert tensor.type == {:u, 4} + assert <<0::4-native, 7::4-native, 15::4-native>> = Nx.to_binary(tensor) + assert [0, 7, 15] = Nx.to_flat_list(tensor) + assert 1 = Nx.byte_size(tensor) + assert 12 = Nx.bit_size(tensor) + + tensor = Nx.u4([0, 1, 2, 3, 13, 14, 15]) + assert 3 = Nx.byte_size(tensor) + assert 28 = Nx.bit_size(tensor) + end + end end diff --git a/nx/README.md b/nx/README.md index 45f2676209c..a4e2b653ae6 100644 --- a/nx/README.md +++ b/nx/README.md @@ -4,7 +4,7 @@ Nx is a multi-dimensional tensors library for Elixir with multi-staged compilation to the CPU/GPU. 
Its high-level features are: - * Typed multi-dimensional tensors, where the tensors can be unsigned integers (`u8`, `u16`, `u32`, `u64`), signed integers (`s8`, `s16`, `s32`, `s64`), floats (`f16`, `f32`, `f64`), brain floats (`bf16`), and complex numbers (`c64`, `c128`); + * Typed multi-dimensional tensors, where the tensors can be unsigned integers (`u2`, `u4`, `u8`, `u16`, `u32`, `u64`), signed integers (`s2`, `s4`, `s8`, `s16`, `s32`, `s64`), floats (`f16`, `f32`, `f64`), brain floats (`bf16`), and complex numbers (`c64`, `c128`); * Named tensors, allowing developers to give names to each dimension, leading to more readable and less error prone codebases; diff --git a/nx/guides/intro-to-nx.livemd b/nx/guides/intro-to-nx.livemd index fa2711c1b0d..e44214735d5 100644 --- a/nx/guides/intro-to-nx.livemd +++ b/nx/guides/intro-to-nx.livemd @@ -24,8 +24,8 @@ libraries that support those tensors. Nx has three primary capabilities: such as machine learning, simulations, curve fitting, and probabilistic models. Here's more about each of those capabilities. Nx [tensors]() can hold -unsigned integers (u8, u16, u32, u64), -signed integers (s8, s16, s32, s64), +unsigned integers (u2, u4, u8, u16, u32, u64), +signed integers (s2, s4, s8, s16, s32, s64), floats (f32, f64), brain floats (bf16), and complex (c64, c128). Tensors support backends implemented outside of Elixir, including Google's Accelerated Linear Algebra (XLA) and LibTorch. 
diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 2667f6d7761..320e682d22f 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -49,8 +49,8 @@ defmodule Nx do The tensor types can be one of: - * unsigned integers (`u8`, `u16`, `u32`, `u64`) - * signed integers (`s8`, `s16`, `s32`, `s64`) + * unsigned integers (`u2`, `u4`, `u8`, `u16`, `u32`, `u64`) + * signed integers (`s2`, `s4`, `s8`, `s16`, `s32`, `s64`) * floats (`f8`, `f16`, `f32`, `f64`) * brain floats (`bf16`) * and complex numbers (`c64`, `c128`) @@ -431,6 +431,7 @@ defmodule Nx do import Nx.Shared import Nx.Defn.Kernel, only: [keyword!: 2] + import Kernel, except: [bit_size: 1] alias Nx.Tensor, as: T @@ -855,7 +856,7 @@ defmodule Nx do {dimensions, acc} = flatten_list(list, type, [], []) {dimensions |> Enum.reverse() |> List.to_tuple(), - acc |> Enum.reverse() |> :erlang.list_to_binary()} + acc |> Enum.reverse() |> :erlang.list_to_bitstring()} end defp flatten_list([], _type, dimensions, acc) do @@ -940,7 +941,9 @@ defmodule Nx do %T{shape: shape, type: type, names: names, data: %Nx.TemplateBackend{}} end - for t <- [:u8, :u16, :u32, :u64, :s8, :s16, :s32, :s64, :bf16, :f8, :f16, :f32, :f64] do + for t <- + [:u2, :u4, :u8, :u16, :u32, :u64, :s2, :s4, :s8, :s16, :s32, :s64] ++ + [:f8, :bf16, :f16, :f32, :f64] do @doc """ Short-hand function for creating tensor of type `#{t}`. @@ -1971,13 +1974,13 @@ defmodule Nx do def from_binary(binary, type, opts \\ []) when is_binary(binary) do opts = keyword!(opts, [:backend]) {_, size} = type = Nx.Type.normalize!(type) - dim = div(bit_size(binary), size) + dim = div(Kernel.bit_size(binary), size) if binary == "" do raise ArgumentError, "cannot build an empty tensor" end - if rem(bit_size(binary), size) != 0 do + if rem(Kernel.bit_size(binary), size) != 0 do raise ArgumentError, "binary does not match the given size" end @@ -1990,17 +1993,26 @@ defmodule Nx do @doc """ Returns the underlying tensor as a binary. 
- **Warning**: converting a tensor to a binary can - potentially be a very expensive operation, as it - may copy a GPU tensor fully to the machine memory. - It returns the in-memory binary representation of the tensor in a row-major fashion. The binary is in the system endianness, which has to be taken into account if the binary is meant to be serialized to other systems. - Note: This function cannot be used in `defn`. + This function cannot be used in `defn`. + + > ### Potentially expensive operation {: .warning} + > + > Converting a tensor to a binary can potentially be a very + > expensive operation, as it may copy a GPU tensor fully to + > the machine memory. + + > ### Binaries vs bitstrings {: .info} + > + > If a tensor of type u2/u4/s2/s4 is given to this function, + > this function may not return a binary (where the number of bits + > is divisible by 8) but rather a bitstring (where the number of + > bits may not be divisible by 8). ## Options @@ -4286,6 +4298,10 @@ defmodule Nx do Returns the byte size of the data in the tensor computed from its shape and type. + If the tensor has s2/s4/u2/u4 types, the value + will be rounded down. Consider using `bit_size/1` + instead. + ## Examples iex> Nx.byte_size(Nx.tensor([[1, 2, 3], [4, 5, 6]])) @@ -4304,9 +4320,33 @@ defmodule Nx do """ @doc type: :shape - def byte_size(tensor) do + def byte_size(tensor), do: div(bit_size(tensor), 8) + + @doc """ + Returns the bit size of the data in the tensor + computed from its shape and type. 
+ + ## Examples + + iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [4, 5, 6]])) + 192 + iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [4, 5, 6]], type: :u8)) + 48 + iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [3, 2, 1]], type: :u2)) + 12 + iex> Nx.bit_size(1) + 32 + + Vectorized tensors account for all elements + + iex> Nx.bit_size(Nx.tensor([[1, 2], [3, 4]]) |> Nx.vectorize(:x)) + 128 + + """ + @doc type: :shape + def bit_size(tensor) do %{type: {_, bit_size}} = tensor = to_tensor(tensor) - flat_size(tensor) * div(bit_size, 8) + flat_size(tensor) * bit_size end @doc """ @@ -15466,9 +15506,9 @@ defmodule Nx do defp do_numpy_to_tensor(rest, header_size) when is_binary(rest) do <> = rest {byte_order, {_, size} = type, shape, fortran_order?} = parse_header(header) - byte_size_of_array = div(size, 8) * Nx.size(shape) + bit_size_of_array = size * Nx.size(shape) - <> = array + <> = array data |> new_byte_order(size, byte_order) diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 49fd97beb58..583b964bb69 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -20,6 +20,9 @@ defmodule Nx.BinaryBackend do import Nx.Shared import Bitwise, only: [>>>: 2, &&&: 2] + # Remove functions which work at the byte-level. We need to work at the bit-level. 
+ import Kernel, except: [byte_size: 1, binary_part: 3, binary_slice: 2, binary_slice: 3] + @impl true def init(opts) do if opts != [] do @@ -33,7 +36,7 @@ defmodule Nx.BinaryBackend do @impl true def constant(%{type: type, shape: shape} = out, constant, _backend_options) do - data = :binary.copy(number_to_binary(constant, type), Nx.size(shape)) + data = bitstring_copy(number_to_binary(constant, type), Nx.size(shape)) from_binary(out, data) end @@ -107,33 +110,28 @@ defmodule Nx.BinaryBackend do def from_binary(t, binary, _backend_options), do: from_binary(t, binary) if Application.compile_env(:nx, :verify_binary_size) do - defp from_binary(%{type: {_, bitsize}, shape: shape} = t, binary) when is_binary(binary) do - actual = byte_size(binary) - expected = Tuple.product(shape) * div(bitsize, 8) + defp from_binary(%{type: {_, bitsize}, shape: shape} = t, binary) when is_bitstring(binary) do + actual = bit_size(binary) + expected = Tuple.product(shape) * bitsize unless actual == expected do raise ArgumentError, - "unexpected size for tensor data, expected #{expected} bytes got: #{actual} bytes" + "unexpected size for tensor data, expected #{expected} bits got: #{actual} bits" end %{t | data: %B{state: binary}} end else - defp from_binary(t, binary) when is_binary(binary), do: %{t | data: %B{state: binary}} + defp from_binary(t, binary) when is_bitstring(binary), do: %{t | data: %B{state: binary}} end - defp from_binary(t, other), do: from_binary(t, IO.iodata_to_binary(other)) + defp from_binary(t, other), do: from_binary(t, :erlang.list_to_bitstring(other)) @impl true def to_binary(%{type: {_backend_options, size}} = t, limit) do - limit = limit * div(size, 8) - binary = to_binary(t) - - if byte_size(binary) == limit do - binary - else - binary_part(binary, 0, limit) - end + t + |> to_binary() + |> bitstring_part(0, limit * size) end defp to_binary(%T{data: %{state: data}}), do: data @@ -189,18 +187,21 @@ defmodule Nx.BinaryBackend do end binary = to_binary(tensor) - 
batch_bytes = Nx.size(out) * div(size, 8) + batch_bits = Nx.size(out) * size Stream.map(range, fn ^num_full_batches -> - before = num_full_batches * batch_bytes - available = byte_size(binary) - before - missing = batch_bytes - available + before = num_full_batches * batch_bits + available = bit_size(binary) - before + missing = batch_bits - available - from_binary(out, [binary_part(binary, before, available), binary_part(binary, 0, missing)]) + from_binary(out, [ + bitstring_part(binary, before, available), + bitstring_part(binary, 0, missing) + ]) i -> - from_binary(out, binary_part(binary, i * batch_bytes, batch_bytes)) + from_binary(out, bitstring_part(binary, i * batch_bits, batch_bits)) end) end @@ -237,7 +238,7 @@ defmodule Nx.BinaryBackend do new_shape |> Tuple.to_list() |> unary_broadcast(0, old_shape, 0, axes, to_binary(t), chunk_size) - |> IO.iodata_to_binary() + |> :erlang.list_to_bitstring() end # Old and new match @@ -376,7 +377,7 @@ defmodule Nx.BinaryBackend do <> end - new_bytes = byte_size(padded) * 8 - interior_padding_size + new_bytes = bit_size(padded) - interior_padding_size <> = padded new_bin end @@ -387,7 +388,7 @@ defmodule Nx.BinaryBackend do edge_low < 0 and edge_high < 0 -> low_byte = abs(edge_low) * size high_byte = abs(edge_high) * size - new_bytes = byte_size(bin) * 8 - high_byte - low_byte + new_bytes = bit_size(bin) - high_byte - low_byte <<_::size(low_byte)-bitstring, new_bin::size(new_bytes)-bitstring, _::bitstring>> = bin @@ -401,7 +402,7 @@ defmodule Nx.BinaryBackend do edge_low >= 0 and edge_high < 0 -> high_byte = abs(edge_high) * size - new_bytes = byte_size(bin) * 8 - high_byte + new_bytes = bit_size(bin) - high_byte <> = bin <> @@ -1159,7 +1160,7 @@ defmodule Nx.BinaryBackend do window = batch_weighted_shape |> weighted_traverse(batch, input_size, offset) - |> IO.iodata_to_binary() + |> :erlang.list_to_bitstring() # The receptive field size of each binary in bytes input_field_size = Nx.size(filter_shape) * input_size @@ 
-1290,21 +1291,21 @@ defmodule Nx.BinaryBackend do m = elem(a_shape, tuple_size(a_shape) - 1) - a_batch_byte_size = (m * m * a_size) |> div(8) - batches_num = byte_size(a_data) |> div(a_batch_byte_size) + a_batch_bit_size = m * m * a_size + batches_num = bit_size(a_data) |> div(a_batch_bit_size) a_batches = Enum.map( 0..(batches_num - 1), - &binary_part(a_data, &1 * a_batch_byte_size, a_batch_byte_size) + &bitstring_part(a_data, &1 * a_batch_bit_size, a_batch_bit_size) ) - b_batch_byte_size = byte_size(b_data) |> div(batches_num) + b_batch_bit_size = bit_size(b_data) |> div(batches_num) b_batches = Enum.map( 0..(batches_num - 1), - &binary_part(b_data, &1 * b_batch_byte_size, b_batch_byte_size) + &bitstring_part(b_data, &1 * b_batch_bit_size, b_batch_bit_size) ) b_batch_shape = @@ -1503,7 +1504,9 @@ defmodule Nx.BinaryBackend do match_types [type] do for anchor <- anchors, into: <<>> do offset = weighted_offset(weighted_shape, anchor, dilations) - window = IO.iodata_to_binary(weighted_traverse(weighted_shape, data, size, offset)) + + window = + :erlang.list_to_bitstring(weighted_traverse(weighted_shape, data, size, offset)) window_val = for <>, @@ -1619,7 +1622,9 @@ defmodule Nx.BinaryBackend do offset = weighted_offset(input_weighted_shape, anchor) window = - IO.iodata_to_binary(weighted_traverse(input_weighted_shape, input_data, size, offset)) + :erlang.list_to_bitstring( + weighted_traverse(input_weighted_shape, input_data, size, offset) + ) # Get the index where `select_fn` is true {_, index, _} = @@ -1674,14 +1679,14 @@ defmodule Nx.BinaryBackend do num_vals_before = div(offset - acc_offset, output_size) vals_before = List.duplicate(init_binary, num_vals_before) source_val = to_binary(value) - new_binary = IO.iodata_to_binary([vals_before, source_val]) + new_binary = :erlang.list_to_bitstring([vals_before, source_val]) {offset + output_size, <>} end num_vals_left = div(output_size * Nx.size(output_shape) - final_offset, output_size) - vals_left = 
IO.iodata_to_binary(List.duplicate(init_binary, num_vals_left)) + vals_left = :erlang.list_to_bitstring(List.duplicate(init_binary, num_vals_left)) output_data = <> from_binary(out, output_data) @@ -1804,7 +1809,7 @@ defmodule Nx.BinaryBackend do binary_to_binary(tail, target.type, out.type, & &1) end - from_binary(out, IO.iodata_to_binary([result, tail])) + from_binary(out, :erlang.list_to_bitstring([result, tail])) end) end @@ -1851,9 +1856,9 @@ defmodule Nx.BinaryBackend do start_indices = clamp_indices(start_indices, shape, lengths) if hd(strides) == 1 and top_dimension_slice?(tuple_size(shape), shape, output_shape) do - length = Nx.size(output_shape) * div(size, 8) + length = Nx.size(output_shape) * size offset = div(length, elem(output_shape, 0)) * hd(start_indices) - binary_part(data, offset, length) + bitstring_part(data, offset, length) else # Anchored around the start indices weighted_shape = weighted_shape(shape, size, output_shape) @@ -1866,7 +1871,7 @@ defmodule Nx.BinaryBackend do {d, dim_size + (s - 1) * dim_size} end) - IO.iodata_to_binary(weighted_traverse(weighted_shape, data, size, offset)) + :erlang.list_to_bitstring(weighted_traverse(weighted_shape, data, size, offset)) end end @@ -1950,8 +1955,7 @@ defmodule Nx.BinaryBackend do last_dim_bin_size = indices_depth * indices_size data = to_binary(tensor) - byte_size = div(size, 8) - byte_count = div(Tuple.product(out.shape), div(indices_count, indices_depth)) + count = div(Tuple.product(out.shape), div(indices_count, indices_depth)) new_data = for <>, into: <<>> do @@ -1960,7 +1964,7 @@ defmodule Nx.BinaryBackend do do: binary_to_number(bin, indices.type) offset = index_to_binary_offset(slice_start, shape) - binary_part(data, offset * byte_size, byte_size * byte_count) + bitstring_part(data, offset * size, size * count) end from_binary(out, new_data) @@ -2012,14 +2016,14 @@ defmodule Nx.BinaryBackend do end defp bin_concatenate(binaries_shapes, _size, 0, _output_shape) do - binaries_shapes |> 
Enum.map(&elem(&1, 0)) |> IO.iodata_to_binary() + binaries_shapes |> Enum.map(&elem(&1, 0)) |> :erlang.list_to_bitstring() end defp bin_concatenate(binaries_shapes, size, axis, output_shape) do rank = tuple_size(output_shape) steps = product_part(output_shape, 0, axis) - # We don't use lists plus IO.iodata_to_binary on purpose. + # We don't use lists plus :erlang.list_to_bitstring on purpose. # Because the number of steps can be really large, we could create large # intermediate lists. So we build the binary directly. bin_concatenate_outer(0, steps, binaries_shapes, "", fn binary, shape, step -> @@ -2202,7 +2206,7 @@ defmodule Nx.BinaryBackend do Enum.sort(data, comparator) end - IO.iodata_to_binary(sorted) + :erlang.list_to_bitstring(sorted) end from_binary(output, new_data) @@ -2427,18 +2431,23 @@ defmodule Nx.BinaryBackend do end defp bin_batch_reduce(bin, batch_size, {_, size}, acc, fun) do - batch_byte_size = (batch_size * size) |> div(8) - batches = byte_size(bin) |> div(batch_byte_size) + batch_bit_size = batch_size * size + batches = bit_size(bin) |> div(batch_bit_size) for i <- 0..(batches - 1), reduce: acc do acc -> - batch = binary_part(bin, i * batch_byte_size, batch_byte_size) + batch = bitstring_part(bin, i * batch_bit_size, batch_bit_size) fun.(batch, acc) end end ## Conversion helpers + defp bitstring_part(bitstring, skip, size) do + <<_::size(skip)-bitstring, part::size(size)-bitstring, _::bitstring>> = bitstring + part + end + defp scalar_to_number(n) when is_number(n) or n in [:nan, :neg_infinity, :infinity], do: n defp scalar_to_number(%Complex{} = n), do: n defp scalar_to_number(t), do: binary_to_number(to_binary(t), t.type) @@ -2541,7 +2550,7 @@ defmodule Nx.BinaryBackend do {reverse_pos, read_size} = aggregate_read(reverse_pos, tuple_size(shape) - 1, Enum.reverse(axes), size) - path = Enum.reverse(reverse_pre, [(&IO.iodata_to_binary/1) | Enum.reverse(reverse_pos)]) + path = Enum.reverse(reverse_pre, [(&:erlang.list_to_bitstring/1) | 
Enum.reverse(reverse_pos)]) {chunk_size, read_size, path} end @@ -2692,4 +2701,8 @@ defmodule Nx.BinaryBackend do div(size, dilation_factor) * x + weighted_offset(dims, pos, dilation) end + + defp bitstring_copy(bitstring, n) do + for _ <- 1..n, into: <<>>, do: bitstring + end end diff --git a/nx/lib/nx/type.ex b/nx/lib/nx/type.ex index a20af58bb27..44f72cf6363 100644 --- a/nx/lib/nx/type.ex +++ b/nx/lib/nx/type.ex @@ -25,10 +25,14 @@ defmodule Nx.Type do """ @type t :: - {:s, 8} + {:s, 2} + | {:s, 4} + | {:s, 8} | {:s, 16} | {:s, 32} | {:s, 64} + | {:u, 2} + | {:u, 4} | {:u, 8} | {:u, 16} | {:u, 32} @@ -64,6 +68,8 @@ defmodule Nx.Type do """ def min_finite_binary(type) + def min_finite_binary({:s, 2}), do: <<-2::2-signed-native>> + def min_finite_binary({:s, 4}), do: <<-8::4-signed-native>> def min_finite_binary({:s, 8}), do: <<-128::8-signed-native>> def min_finite_binary({:s, 16}), do: <<-32768::16-signed-native>> def min_finite_binary({:s, 32}), do: <<-2_147_483_648::32-signed-native>> @@ -87,10 +93,14 @@ defmodule Nx.Type do """ def max_finite_binary(type) + def max_finite_binary({:s, 2}), do: <<1::2-signed-native>> + def max_finite_binary({:s, 4}), do: <<7::4-signed-native>> def max_finite_binary({:s, 8}), do: <<127::8-signed-native>> def max_finite_binary({:s, 16}), do: <<32767::16-signed-native>> def max_finite_binary({:s, 32}), do: <<2_147_483_647::32-signed-native>> def max_finite_binary({:s, 64}), do: <<9_223_372_036_854_775_807::64-signed-native>> + def max_finite_binary({:u, 2}), do: <<3::2-native>> + def max_finite_binary({:u, 4}), do: <<15::4-native>> def max_finite_binary({:u, 8}), do: <<255::8-native>> def max_finite_binary({:u, 16}), do: <<65535::16-native>> def max_finite_binary({:u, 32}), do: <<4_294_967_295::32-native>> @@ -191,8 +201,8 @@ defmodule Nx.Type do end type_variants = [ - s: [8, 16, 32, 64], - u: [8, 16, 32, 64], + s: [2, 4, 8, 16, 32, 64], + u: [2, 4, 8, 16, 32, 64], f: [8, 16, 32, 64], bf: [16], c: [64, 128] diff --git 
a/nx/test/nx_test.exs b/nx/test/nx_test.exs index 98aa75c9902..80ea1e64ba4 100644 --- a/nx/test/nx_test.exs +++ b/nx/test/nx_test.exs @@ -2430,7 +2430,7 @@ defmodule NxTest do end test "works with all integer types in indices" do - for kind <- [:u, :s], width <- [8, 16, 32, 64] do + for kind <- [:u, :s], width <- [2, 4, 8, 16, 32, 64] do indices = Nx.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], type: {kind, width}) assert Nx.add(Nx.iota({2, 2}), 1) == @@ -3259,4 +3259,64 @@ defmodule NxTest do ]) end end + + describe "quantized types" do + test "s2" do + tensor = Nx.s2([-2, -1, 1]) + assert tensor.type == {:s, 2} + + assert <<-2::2-signed-native, -1::2-signed-native, 1::2-signed-native>> = + Nx.to_binary(tensor) + + assert [-2, -1, 1] = Nx.to_flat_list(tensor) + assert 0 = Nx.byte_size(tensor) + assert 6 = Nx.bit_size(tensor) + + tensor = Nx.s2([-2, -1, 0, 1, 0, -1, -2]) + assert 1 = Nx.byte_size(tensor) + assert 14 = Nx.bit_size(tensor) + end + + test "s4" do + tensor = Nx.s4([-8, -1, 7]) + assert tensor.type == {:s, 4} + + assert <<-8::4-signed-native, -1::4-signed-native, 7::4-signed-native>> = + Nx.to_binary(tensor) + + assert [-8, -1, 7] = Nx.to_flat_list(tensor) + assert 1 = Nx.byte_size(tensor) + assert 12 = Nx.bit_size(tensor) + + tensor = Nx.s4([-8, -3, 0, 7, 0, -3, -8]) + assert 3 = Nx.byte_size(tensor) + assert 28 = Nx.bit_size(tensor) + end + + test "u2" do + tensor = Nx.u2([1, 2, 3]) + assert tensor.type == {:u, 2} + assert <<1::2-native, 2::2-native, 3::2-native>> = Nx.to_binary(tensor) + assert [1, 2, 3] = Nx.to_flat_list(tensor) + assert 0 = Nx.byte_size(tensor) + assert 6 = Nx.bit_size(tensor) + + tensor = Nx.u2([0, 1, 2, 3, 2, 1, 0]) + assert 1 = Nx.byte_size(tensor) + assert 14 = Nx.bit_size(tensor) + end + + test "u4" do + tensor = Nx.u4([0, 7, 15]) + assert tensor.type == {:u, 4} + assert <<0::4-native, 7::4-native, 15::4-native>> = Nx.to_binary(tensor) + assert [0, 7, 15] = Nx.to_flat_list(tensor) + assert 1 = Nx.byte_size(tensor) + assert 12 = 
Nx.bit_size(tensor) + + tensor = Nx.u4([0, 1, 2, 3, 13, 14, 15]) + assert 3 = Nx.byte_size(tensor) + assert 28 = Nx.bit_size(tensor) + end + end end diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex index cfc2b637b64..18f453dc6d7 100644 --- a/torchx/lib/torchx/backend.ex +++ b/torchx/lib/torchx/backend.ex @@ -223,6 +223,14 @@ defmodule Torchx.Backend do for <>, into: <<>>, do: <> end + defp maybe_pad_binary(bin, {:u, size}) when size in [2, 4] do + for <>, into: <<>>, do: <> + end + + defp maybe_pad_binary(bin, {:s, size}) when size in [2, 4] do + for <>, into: <<>>, do: <> + end + defp maybe_pad_binary(bin, _), do: bin ## Shape @@ -1723,10 +1731,14 @@ defmodule Torchx.Backend do def from_torch_type(:complex_double), do: {:c, 128} defp to_torch_type(nx_type, hint \\ "") + defp to_torch_type({:u, 2}, _), do: :byte + defp to_torch_type({:u, 4}, _), do: :byte defp to_torch_type({:u, 8}, _), do: :byte defp to_torch_type({:u, 16}, _), do: :int defp to_torch_type({:u, 32}, _), do: :long defp to_torch_type({:u, 64}, _), do: :long + defp to_torch_type({:s, 2}, _), do: :char + defp to_torch_type({:s, 4}, _), do: :char defp to_torch_type({:s, 8}, _), do: :char defp to_torch_type({:s, 16}, _), do: :short defp to_torch_type({:s, 32}, _), do: :int @@ -1744,6 +1756,12 @@ defmodule Torchx.Backend do current_type = Torchx.scalar_type(device_ref) |> from_torch_type() case {current_type, type} do + {{:s, 8}, {:s, qint}} when qint in [2, 4] -> + :ok + + {{:u, 8}, {:u, qint}} when qint in [2, 4] -> + :ok + {{:s, 32}, {:u, 16}} -> :ok