Add quantized int types
josevalim committed Sep 4, 2024
1 parent ff7ad85 commit 81a2c5e
Showing 8 changed files with 197 additions and 70 deletions.
4 changes: 4 additions & 0 deletions exla/lib/exla/typespec.ex
@@ -53,10 +53,14 @@ defmodule EXLA.Typespec do
type_to_charlist = %{
:token => ~c"token",
{:pred, 8} => ~c"pred",
{:s, 2} => ~c"s2",
{:s, 4} => ~c"s4",
{:s, 8} => ~c"s8",
{:s, 16} => ~c"s16",
{:s, 32} => ~c"s32",
{:s, 64} => ~c"s64",
{:u, 2} => ~c"u2",
{:u, 4} => ~c"u4",
{:u, 8} => ~c"u8",
{:u, 16} => ~c"u16",
{:u, 32} => ~c"u32",
2 changes: 1 addition & 1 deletion nx/README.md
@@ -4,7 +4,7 @@

Nx is a multi-dimensional tensors library for Elixir with multi-staged compilation to the CPU/GPU. Its high-level features are:

* Typed multi-dimensional tensors, where the tensors can be unsigned integers (`u8`, `u16`, `u32`, `u64`), signed integers (`s8`, `s16`, `s32`, `s64`), floats (`f16`, `f32`, `f64`), brain floats (`bf16`), and complex numbers (`c64`, `c128`);
* Typed multi-dimensional tensors, where the tensors can be unsigned integers (`u2`, `u4`, `u8`, `u16`, `u32`, `u64`), signed integers (`s2`, `s4`, `s8`, `s16`, `s32`, `s64`), floats (`f16`, `f32`, `f64`), brain floats (`bf16`), and complex numbers (`c64`, `c128`);

* Named tensors, allowing developers to give names to each dimension, leading to more readable and less error prone codebases;

4 changes: 2 additions & 2 deletions nx/guides/intro-to-nx.livemd
@@ -24,8 +24,8 @@ libraries that support those tensors. Nx has three primary capabilities:
such as machine learning, simulations, curve fitting, and probabilistic models.

Here's more about each of those capabilities. Nx [tensors]() can hold
unsigned integers (u8, u16, u32, u64),
signed integers (s8, s16, s32, s64),
unsigned integers (u2, u4, u8, u16, u32, u64),
signed integers (s2, s4, s8, s16, s32, s64),
floats (f32, f64), brain floats (bf16), and complex (c64, c128).
Tensors support backends implemented outside of Elixir, including Google's
Accelerated Linear Algebra (XLA) and LibTorch.
70 changes: 55 additions & 15 deletions nx/lib/nx.ex
@@ -49,8 +49,8 @@ defmodule Nx do
The tensor types can be one of:
* unsigned integers (`u8`, `u16`, `u32`, `u64`)
* signed integers (`s8`, `s16`, `s32`, `s64`)
* unsigned integers (`u2`, `u4`, `u8`, `u16`, `u32`, `u64`)
* signed integers (`s2`, `s4`, `s8`, `s16`, `s32`, `s64`)
* floats (`f16`, `f32`, `f64`)
* brain floats (`bf16`)
* and complex numbers (`c64`, `c128`)
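A minimal sketch of how the new sub-byte types are used (the outputs in comments are illustrative):

```elixir
# Quantized integer tensors use the same API as the wider types:
# pass the new type atoms to Nx.tensor/2 and query them with Nx.type/1.
t = Nx.tensor([1, 2, 3], type: :u2)
Nx.type(t)
#=> {:u, 2}

Nx.type(Nx.tensor([-1, 0, 1], type: :s4))
#=> {:s, 4}
```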
@@ -431,6 +431,7 @@ defmodule Nx do

import Nx.Shared
import Nx.Defn.Kernel, only: [keyword!: 2]
import Kernel, except: [bit_size: 1]

alias Nx.Tensor, as: T

@@ -844,7 +845,7 @@ defmodule Nx do
{dimensions, acc} = flatten_list(list, type, [], [])

{dimensions |> Enum.reverse() |> List.to_tuple(),
acc |> Enum.reverse() |> :erlang.list_to_binary()}
acc |> Enum.reverse() |> :erlang.list_to_bitstring()}
end

defp flatten_list([], _type, dimensions, acc) do
@@ -929,7 +930,9 @@ defmodule Nx do
%T{shape: shape, type: type, names: names, data: %Nx.TemplateBackend{}}
end

for t <- [:u8, :u16, :u32, :u64, :s8, :s16, :s32, :s64, :bf16, :f16, :f32, :f64] do
for t <-
[:u2, :u4, :u8, :u16, :u32, :u64, :s2, :s4, :s8, :s16, :s32, :s64] ++
[:bf16, :f16, :f32, :f64] do
@doc """
Short-hand function for creating tensor of type `#{t}`.
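With u2, u4, s2, and s4 added to the list above, the generated shorthand constructors cover the quantized types as well; a minimal sketch:

```elixir
# Each type atom in the comprehension defines a helper, so the new
# quantized types get the same shorthand as u8/s8/etc.
Nx.u2([0, 1, 2, 3])  # equivalent to Nx.tensor([0, 1, 2, 3], type: :u2)
Nx.s4([-8, 0, 7])    # equivalent to Nx.tensor([-8, 0, 7], type: :s4)
```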
@@ -1960,13 +1963,13 @@ defmodule Nx do
def from_binary(binary, type, opts \\ []) when is_binary(binary) do
opts = keyword!(opts, [:backend])
{_, size} = type = Nx.Type.normalize!(type)
dim = div(bit_size(binary), size)
dim = div(Kernel.bit_size(binary), size)

if binary == "" do
raise ArgumentError, "cannot build an empty tensor"
end

if rem(bit_size(binary), size) != 0 do
if rem(Kernel.bit_size(binary), size) != 0 do
raise ArgumentError, "binary does not match the given size"
end

@@ -1979,17 +1982,26 @@
@doc """
Returns the underlying tensor as a binary.
**Warning**: converting a tensor to a binary can
potentially be a very expensive operation, as it
may copy a GPU tensor fully to the machine memory.
It returns the in-memory binary representation of
the tensor in a row-major fashion. The binary is
in the system endianness, which has to be taken into
account if the binary is meant to be serialized to
other systems.
Note: This function cannot be used in `defn`.
This function cannot be used in `defn`.
> ### Potentially expensive operation {: .warning}
>
> Converting a tensor to a binary can potentially be a very
> expensive operation, as it may copy a GPU tensor fully to
> the machine memory.
> ### Binaries vs bitstrings {: .info}
>
> If a tensor of type u2/u4/s2/s4 is given to this function,
> this function may not return a binary (where the number of bits
> is divisible by 8) but rather a bitstring (where the number of
> bits may not be divisible by 8).
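A minimal sketch of the distinction, reusing the u2 tensor from the `bit_size/1` doctests and assuming the default binary backend packs sub-byte elements back to back:

```elixir
# Six u2 elements occupy 12 bits, which is not a whole number of bytes,
# so the result is a bitstring rather than a binary.
data = Nx.to_binary(Nx.tensor([[1, 2, 3], [3, 2, 1]], type: :u2))
Kernel.bit_size(data)  #=> 12
is_binary(data)        #=> false
is_bitstring(data)     #=> true
```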
## Options
@@ -4275,6 +4287,10 @@ defmodule Nx do
Returns the byte size of the data in the tensor
computed from its shape and type.
If the tensor has s2/s4/u2/u4 types, the value
will be rounded down. Consider using `bit_size/1`
instead.
## Examples
iex> Nx.byte_size(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
@@ -4293,9 +4309,33 @@
"""
@doc type: :shape
def byte_size(tensor) do
def byte_size(tensor), do: div(bit_size(tensor), 8)
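A short sketch of the rounding described above, reusing the u2 tensor from the doctests:

```elixir
# 6 elements * 2 bits = 12 bits; div(12, 8) rounds the byte size down to 1,
# so bit_size/1 is the lossless measurement for sub-byte types.
t = Nx.tensor([[1, 2, 3], [3, 2, 1]], type: :u2)
Nx.bit_size(t)   #=> 12
Nx.byte_size(t)  #=> 1
```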

@doc """
Returns the bit size of the data in the tensor
computed from its shape and type.
## Examples
iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [4, 5, 6]]))
192
iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [4, 5, 6]], type: :u8))
48
iex> Nx.bit_size(Nx.tensor([[1, 2, 3], [3, 2, 1]], type: :u2))
12
iex> Nx.bit_size(1)
32
Vectorized tensors account for all elements
iex> Nx.bit_size(Nx.tensor([[1, 2], [3, 4]]) |> Nx.vectorize(:x))
128
"""
@doc type: :shape
def bit_size(tensor) do
%{type: {_, bit_size}} = tensor = to_tensor(tensor)
flat_size(tensor) * div(bit_size, 8)
flat_size(tensor) * bit_size
end

@doc """
@@ -15455,9 +15495,9 @@ defmodule Nx do
defp do_numpy_to_tensor(rest, header_size) when is_binary(rest) do
<<header::size(header_size)-binary, array::binary>> = rest
{byte_order, {_, size} = type, shape, fortran_order?} = parse_header(header)
byte_size_of_array = div(size, 8) * Nx.size(shape)
bit_size_of_array = size * Nx.size(shape)

<<data::size(byte_size_of_array)-binary>> = array
<<data::size(bit_size_of_array)-bitstring>> = array

data
|> new_byte_order(size, byte_order)