diff --git a/lib/safetensors.ex b/lib/safetensors.ex
index 7d84356..cdd95bc 100644
--- a/lib/safetensors.ex
+++ b/lib/safetensors.ex
@@ -85,6 +85,35 @@ defmodule Safetensors do
     ]
   end
 
+  @doc """
+  Reads a map of tensors from the safetensors file at `path`.
+
+  Tensors are loaded into Nx one by one,
+  without loading the whole file into memory.
+  """
+  @spec read!(path :: Path.t()) :: %{String.t() => Nx.Tensor.t()}
+  def read!(path) do
+    File.open!(path, [:read, :raw], fn file ->
+      # The file starts with the header size as a little-endian unsigned
+      # 64-bit integer, followed by the JSON header of that many bytes.
+      {:ok, <<header_size::unsigned-64-integer-little>>} = :file.read(file, 8)
+      {:ok, header_json} = :file.read(file, header_size)
+
+      header = decode_header!(header_json)
+
+      for {tensor_name, tensor_info} <- header, into: %{} do
+        %{"data_offsets" => [offset_start, offset_end]} = tensor_info
+
+        # Offsets are relative to the data section, which begins right after
+        # the 8-byte size prefix and the header.
+        {:ok, binary} =
+          :file.pread(file, header_size + 8 + offset_start, offset_end - offset_start)
+
+        {tensor_name, build_tensor(binary, tensor_info)}
+      end
+    end)
+  end
+
   @doc """
   Loads a serialized map of tensors.
 
@@ -100,34 +129,43 @@ defmodule Safetensors do
       buffer::binary
     >> = data
 
-    {_metadata, header} =
-      header_json
-      |> Jason.decode!()
-      |> Map.pop(@header_metadata_key)
+    header = decode_header!(header_json)
 
     for {tensor_name, tensor_info} <- header, into: %{} do
-      %{
-        "data_offsets" => [offset_start, offset_end],
-        "dtype" => dtype,
-        "shape" => shape
-      } = tensor_info
-
-      {_, elem_size} = type = dtype_to_type(dtype)
+      %{"data_offsets" => [offset_start, offset_end]} = tensor_info
 
-      binary =
+      tensor =
         buffer
         |> binary_slice(offset_start, offset_end - offset_start)
-        |> new_byte_order(elem_size, :little)
-
-      tensor =
-        binary
-        |> Nx.from_binary(type)
-        |> Nx.reshape(List.to_tuple(shape))
+        |> build_tensor(tensor_info)
 
       {tensor_name, tensor}
     end
   end
 
+  # Decodes the JSON header and pops the entry stored under
+  # @header_metadata_key, returning the remaining entries.
+  defp decode_header!(header_json) do
+    {_metadata, header} =
+      header_json
+      |> Jason.decode!()
+      |> Map.pop(@header_metadata_key)
+
+    header
+  end
+
+  # Builds an Nx tensor from the raw bytes of one tensor, using the
+  # "dtype" and "shape" fields of its header entry.
+  defp build_tensor(binary, tensor_info) do
+    %{"dtype" => dtype, "shape" => shape} = tensor_info
+    {_, elem_size} = type = dtype_to_type(dtype)
+
+    binary
+    |> new_byte_order(elem_size, :little)
+    |> Nx.from_binary(type)
+    |> Nx.reshape(List.to_tuple(shape))
+  end
+
   defp type_to_dtype(type) do
     @type_to_dtype[type] || raise "unrecognized type #{inspect(type)}"
   end
diff --git a/test/safetensors_test.exs b/test/safetensors_test.exs
index ae4a155..ea444e4 100644
--- a/test/safetensors_test.exs
+++ b/test/safetensors_test.exs
@@ -16,6 +16,20 @@ defmodule SafetensorsTest do
              ~s(<\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"I32","shape":[2,2],"data_offsets":[0,16]}}\x01\x00\x00\x00\x02\x00\x00\x00\x03\x00\x00\x00\x04\x00\x00\x00)
   end
 
+  @tag :tmp_dir
+  test "read", %{tmp_dir: tmp_dir} do
+    path = Path.join(tmp_dir, "safetensor")
+
+    # source:
+    # https://github.com/huggingface/safetensors/blob/1a65a3fdebcf280ef0ca32934901d3e2ad3b2c65/bindings/python/tests/test_simple.py#L35-L40
+    File.write!(
+      path,
+      ~s(<\x00\x00\x00\x00\x00\x00\x00{"test":{"dtype":"I32","shape":[2,2],"data_offsets":[0,16]}}\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00)
+    )
+
+    assert Safetensors.read!(path) == %{"test" => Nx.tensor([[0, 0], [0, 0]], type: :s32)}
+  end
+
   test "load" do
     # source:
     # https://github.com/huggingface/safetensors/blob/1a65a3fdebcf280ef0ca32934901d3e2ad3b2c65/bindings/python/tests/test_simple.py#L35-L40