diff --git a/lib/safetensors.ex b/lib/safetensors.ex index 4258196..88d3730 100644 --- a/lib/safetensors.ex +++ b/lib/safetensors.ex @@ -41,6 +41,70 @@ defmodule Safetensors do @dtype_to_type for {k, v} <- @type_to_dtype, into: %{}, do: {v, k} + @doc """ + Writes a map of tensors to a file. + + Tensors are written into the file one by one, without the need to + dump all of them into the memory at once. + """ + @spec write!(path :: Path.t(), %{String.t() => Nx.Tensor.t()}) :: :ok + def write!(path, tensors) when is_map(tensors) do + File.open!(path, [:write, :raw], fn file -> + tensors = Enum.sort(tensors) + + {header_entries, _offset} = + Enum.map_reduce(tensors, 0, fn {tensor_name, tensor}, offset -> + tensor_header_entry(tensor_name, tensor, offset) + end) + + :ok = :file.write(file, header_binary(header_entries)) + + for {_tensor_name, tensor} <- tensors do + :ok = :file.write(file, tensor_to_binary(tensor)) + end + end) + + :ok + end + + defp tensor_header_entry(tensor_name, tensor, offset) do + end_offset = offset + tensor_byte_size(tensor) + + header_entry = { + tensor_name, + Jason.OrderedObject.new( + dtype: tensor |> Nx.type() |> type_to_dtype(), + shape: tensor |> Nx.shape() |> Tuple.to_list(), + data_offsets: [offset, end_offset] + ) + } + + {header_entry, end_offset} + end + + defp header_binary(header_entries) do + header_json = + header_entries + |> Jason.OrderedObject.new() + |> Jason.encode!() + + [<>, header_json] + end + + defp tensor_byte_size(tensor) do + {_, elem_size} = Nx.type(tensor) + elem_byte_size = div(elem_size, 8) + Nx.size(tensor) * elem_byte_size + end + + defp tensor_to_binary(tensor) do + {_, elem_size} = Nx.type(tensor) + + tensor + |> Nx.to_binary() + |> new_byte_order(elem_size, :little) + end + @doc """ Serializes the given map of tensors to iodata. @@ -50,46 +114,23 @@ defmodule Safetensors do """ @spec dump(%{String.t() => Nx.Tensor.t()}) :: iodata() def dump(tensors) when is_map(tensors) do + tensors = Enum.sort(tensors) + {header_entries, {buffer, _offset}} = Enum.map_reduce(tensors, {[], 0}, fn {tensor_name, tensor}, {buffer, offset} -> - {_, elem_size} = Nx.type(tensor) - - binary = - tensor - |> Nx.to_binary() - |> new_byte_order(elem_size, :little) - - end_offset = offset + byte_size(binary) - - header_entry = { - tensor_name, - Jason.OrderedObject.new( - dtype: tensor |> Nx.type() |> type_to_dtype(), - shape: tensor |> Nx.shape() |> Tuple.to_list(), - data_offsets: [offset, end_offset] - ) - } - + {header_entry, end_offset} = tensor_header_entry(tensor_name, tensor, offset) + binary = tensor_to_binary(tensor) {header_entry, {[buffer, binary], end_offset}} end) - header_json = - header_entries - |> Jason.OrderedObject.new() - |> Jason.encode!() - - [ - <>, - header_json, - buffer - ] + [header_binary(header_entries), buffer] end @doc """ - Reads a safe tensor from file. + Reads a serialized map of tensors from a file. - Tensors are loaded into Nx one by one, - without the need to load the entire file from disk into memory. + Tensors are loaded into Nx one by one, without the need to load the + entire file from disk into memory. """ @spec read!(path :: Path.t()) :: %{String.t() => Nx.Tensor.t()} def read!(path) do diff --git a/test/safetensors_test.exs b/test/safetensors_test.exs index ea444e4..a20430b 100644 --- a/test/safetensors_test.exs +++ b/test/safetensors_test.exs @@ -3,6 +3,20 @@ defmodule SafetensorsTest do doctest Safetensors + @tag :tmp_dir + test "write", %{tmp_dir: tmp_dir} do + path = Path.join(tmp_dir, "safetensor") + + data = %{test: Nx.tensor([[1, 2], [3, 4]], type: :s32)} + Safetensors.write!(path, data) + + # source: + # https://github.com/huggingface/safetensors/blob/1a65a3fdebcf280ef0ca32934901d3e2ad3b2c65/bindings/python/tests/test_simple.py#L22-L25 + # with the header padding removed and changed numbers + assert File.read!(path) == + ~s(<\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"I32","shape":[2,2],"data_offsets":[0,16]}}\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00) + end + test "dump" do binary = %{test: Nx.tensor([[1, 2], [3, 4]], type: :s32)}