Skip to content

Commit

Permalink
Add option to read tensors lazily (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Feb 23, 2024
1 parent 2287789 commit 5d530c7
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 35 deletions.
76 changes: 42 additions & 34 deletions lib/safetensors.ex
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ defmodule Safetensors do
"""

alias Safetensors.Shared

@header_metadata_key "__metadata__"

@type_to_dtype %{
Expand Down Expand Up @@ -60,7 +62,7 @@ defmodule Safetensors do
:ok = :file.write(file, header_binary(header_entries))

for {_tensor_name, tensor} <- tensors do
:ok = :file.write(file, tensor_to_binary(tensor))
:ok = :file.write(file, tensor_to_iodata(tensor))
end
end)

Expand Down Expand Up @@ -97,12 +99,12 @@ defmodule Safetensors do
Nx.size(tensor) * elem_byte_size
end

defp tensor_to_binary(tensor) do
defp tensor_to_iodata(tensor) do
{_, elem_size} = Nx.type(tensor)

tensor
|> Nx.to_binary()
|> new_byte_order(elem_size, :little)
|> Shared.new_byte_order(elem_size, :little)
end

@doc """
Expand All @@ -119,7 +121,7 @@ defmodule Safetensors do
{header_entries, {buffer, _offset}} =
Enum.map_reduce(tensors, {[], 0}, fn {tensor_name, tensor}, {buffer, offset} ->
{header_entry, end_offset} = tensor_header_entry(tensor_name, tensor, offset)
binary = tensor_to_binary(tensor)
binary = tensor_to_iodata(tensor)
{header_entry, {[buffer, binary], end_offset}}
end)

Expand All @@ -131,9 +133,19 @@ defmodule Safetensors do
Tensors are loaded into Nx one by one, without the need to load the
entire file from disk into memory.
## Options
* `:lazy` - when `true`, instead of returning tensors, the function
returns lazy containers. Such a container can be converted to a
tensor using `Nx.to_tensor/1` and it is only at that point that
    it is read from the file. Defaults to `false`.
"""
@spec read!(path :: Path.t()) :: %{String.t() => Nx.Tensor.t()}
def read!(path) do
@spec read!(path :: Path.t(), keyword()) :: %{String.t() => Nx.LazyContainer.t()}
def read!(path, opts \\ []) do
opts = Keyword.validate!(opts, lazy: false)

File.open!(path, [:read, :raw], fn file ->
{:ok, <<header_size::unsigned-64-integer-little>>} = :file.read(file, 8)
{:ok, header_json} = :file.read(file, header_size)
Expand All @@ -143,10 +155,26 @@ defmodule Safetensors do
for {tensor_name, tensor_info} <- header, into: %{} do
%{"data_offsets" => [offset_start, offset_end]} = tensor_info

{:ok, binary} =
:file.pread(file, header_size + 8 + offset_start, offset_end - offset_start)

{tensor_name, build_tensor(binary, tensor_info)}
{shape, type} = shape_and_type(tensor_info)

byte_offset = header_size + 8 + offset_start
byte_size = offset_end - offset_start

value =
if opts[:lazy] do
%Safetensors.FileTensor{
shape: shape,
type: type,
path: path,
byte_offset: byte_offset,
byte_size: byte_size
}
else
{:ok, binary} = :file.pread(file, byte_offset, byte_size)
Shared.build_tensor(binary, shape, type)
end

{tensor_name, value}
end
end)
end
Expand All @@ -170,11 +198,12 @@ defmodule Safetensors do

for {tensor_name, tensor_info} <- header, into: %{} do
%{"data_offsets" => [offset_start, offset_end]} = tensor_info
{shape, type} = shape_and_type(tensor_info)

tensor =
buffer
|> binary_slice(offset_start, offset_end - offset_start)
|> build_tensor(tensor_info)
|> Shared.build_tensor(shape, type)

{tensor_name, tensor}
end
Expand All @@ -189,14 +218,8 @@ defmodule Safetensors do
header
end

defp build_tensor(binary, tensor_info) do
%{"dtype" => dtype, "shape" => shape} = tensor_info
{_, elem_size} = type = dtype_to_type(dtype)

binary
|> new_byte_order(elem_size, :little)
|> Nx.from_binary(type)
|> Nx.reshape(List.to_tuple(shape))
defp shape_and_type(%{"dtype" => dtype, "shape" => shape}) do
{List.to_tuple(shape), dtype_to_type(dtype)}
end

defp type_to_dtype(type) do
Expand All @@ -206,19 +229,4 @@ defmodule Safetensors do
defp dtype_to_type(dtype) do
@dtype_to_type[dtype] || raise "unrecognized dtype #{inspect(dtype)}"
end

defp new_byte_order(binary, size, endianness) do
if System.endianness() == endianness do
binary
else
data =
for <<data::size(size)-binary <- binary>> do
data
|> :binary.decode_unsigned()
|> :binary.encode_unsigned(endianness)
end

IO.iodata_to_binary(data)
end
end
end
20 changes: 20 additions & 0 deletions lib/safetensors/file_tensor.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
defmodule Safetensors.FileTensor do
  @moduledoc false

  # Lazy handle to a single tensor stored in a safetensors file.
  #
  # Fields:
  #   * :shape       - tensor shape as a tuple, e.g. {2, 2}
  #   * :type        - Nx type tuple, e.g. {:s, 32}
  #   * :path        - path to the safetensors file on disk
  #   * :byte_offset - offset of the tensor data within the file
  #   * :byte_size   - number of bytes of tensor data at that offset
  defstruct [:shape, :type, :path, :byte_offset, :byte_size]
end

defimpl Nx.LazyContainer, for: Safetensors.FileTensor do
  # Traverses the single lazy tensor: hands `fun` a template (shape and
  # type, no data) together with a zero-arity loader. The tensor bytes
  # are read from the file only when the loader is invoked.
  def traverse(%{shape: shape, type: type} = file_tensor, acc, fun) do
    loader = fn ->
      %{path: path, byte_offset: offset, byte_size: size} = file_tensor

      File.open!(path, [:read, :raw], fn file ->
        {:ok, binary} = :file.pread(file, offset, size)
        Safetensors.Shared.build_tensor(binary, shape, type)
      end)
    end

    fun.(Nx.template(shape, type), loader, acc)
  end
end
36 changes: 36 additions & 0 deletions lib/safetensors/shared.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
defmodule Safetensors.Shared do
  @moduledoc false

  @doc """
  Builds an Nx tensor from the given safetensors binary.

  `binary` holds the raw little-endian tensor data, `shape` is the
  target shape tuple and `type` is the Nx type, e.g. `{:s, 32}`.
  """
  @spec build_tensor(binary(), tuple(), Nx.Type.t()) :: Nx.Tensor.t()
  def build_tensor(binary, shape, type) do
    {_, elem_size} = type

    binary
    |> new_byte_order(elem_size, :little)
    |> IO.iodata_to_binary()
    |> Nx.from_binary(type)
    |> Nx.reshape(shape)
  end

  @doc """
  Changes the endianness of `binary` if `endianness` does not match
  the system endianness, otherwise returns `binary` unchanged.

  `size` is the element size in bits (as found in Nx types) and is
  expected to be a multiple of 8.
  """
  @spec new_byte_order(binary(), pos_integer(), :little | :big) :: iodata()
  def new_byte_order(binary, size, endianness) do
    if System.endianness() == endianness do
      binary
    else
      # The element size is given in bits, but we chunk the data in bytes.
      bytes = div(size, 8)

      # Reverse the bytes within each element. We reverse the byte list
      # directly rather than decoding and re-encoding an unsigned integer
      # (:binary.encode_unsigned/2), since the latter drops leading zero
      # bytes and would corrupt the data.
      for <<chunk::binary-size(bytes) <- binary>>, into: <<>> do
        chunk |> :binary.bin_to_list() |> Enum.reverse() |> :binary.list_to_bin()
      end
    end
  end
end
2 changes: 1 addition & 1 deletion mix.lock
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
"makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"},
"nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"},
"nx": {:hex, :nx, "0.5.3", "6ad5534f9b82429dafa12329952708c2fdd6ab01b306e86333fdea72383147ee", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "d1072fc4423809ed09beb729e73c200ce177ddecac425d9eb6fba643669623ec"},
"nx": {:hex, :nx, "0.6.4", "948d9f42f81e63fc901d243ac0a985c8bb87358be62e27826cfd67f58bc640af", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "bb9c2e2e3545b5eb4739d69046a988daaa212d127dba7d97801c291616aff6d6"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
}
17 changes: 17 additions & 0 deletions test/safetensors_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ defmodule SafetensorsTest do
assert Safetensors.read!(path) == %{"test" => Nx.tensor([[0, 0], [0, 0]], type: :s32)}
end

@tag :tmp_dir
test "read lazy", %{tmp_dir: tmp_dir} do
path = Path.join(tmp_dir, "safetensor")

# source:
# https://github.com/huggingface/safetensors/blob/1a65a3fdebcf280ef0ca32934901d3e2ad3b2c65/bindings/python/tests/test_simple.py#L35-L40
File.write!(
path,
~s(<\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"I32","shape":[2,2],"data_offsets":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00)
)

assert %{"test" => %Safetensors.FileTensor{} = file_tensor} =
Safetensors.read!(path, lazy: true)

assert Nx.to_tensor(file_tensor) == Nx.tensor([[0, 0], [0, 0]], type: :s32)
end

test "load" do
# source:
# https://github.com/huggingface/safetensors/blob/1a65a3fdebcf280ef0ca32934901d3e2ad3b2c65/bindings/python/tests/test_simple.py#L35-L40
Expand Down

0 comments on commit 5d530c7

Please sign in to comment.