From 54e22a4c4ccbeaa6c945782cdc6b6d65f35d1207 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Mon, 10 Jun 2024 17:01:22 +0200 Subject: [PATCH] Add tests for executable serde (#1484) --- exla/lib/exla/executable.ex | 32 ++++++++++++++++----- exla/test/exla/executable_test.exs | 45 +++++++++++++++++++++--------- 2 files changed, 57 insertions(+), 20 deletions(-) diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex index a6a0c8cbdf..1c6e4df49a 100644 --- a/exla/lib/exla/executable.ex +++ b/exla/lib/exla/executable.ex @@ -22,6 +22,9 @@ defmodule EXLA.Executable do end end + @doc """ + Serializes the executable to a binary. + """ def serialize(%Executable{ ref: executable, output_typespecs: output_typespecs, @@ -36,6 +39,7 @@ defmodule EXLA.Executable do |> IO.iodata_to_binary() %{ + version: 1, serialized: serialized_exec, output_typespecs: output_typespecs, num_replicas: num_replicas, @@ -45,21 +49,35 @@ defmodule EXLA.Executable do |> :erlang.term_to_binary() end + @doc """ + Deserializes a previous serialized executable. + """ def deserialize(client, binary) do case :erlang.binary_to_term(binary) do - %{serialized: serialized_exec} = exec_data -> + %{version: 1, serialized: serialized} = data -> + %{ + output_typespecs: output_typespecs, + num_replicas: num_replicas, + num_partitions: num_partitions, + device_id: device_id + } = data + ref = - serialized_exec + serialized |> then(&EXLA.NIF.deserialize_executable(client.ref, &1)) |> unwrap!() - exec_data - |> Map.put(:ref, ref) - |> Map.put(:client, client) - |> then(&struct(__MODULE__, &1)) + %EXLA.Executable{ + output_typespecs: output_typespecs, + num_replicas: num_replicas, + num_partitions: num_partitions, + device_id: device_id, + ref: ref, + client: client + } _other -> - raise "invalid serialized executable" + raise ArgumentError, "invalid serialized executable" end end diff --git a/exla/test/exla/executable_test.exs b/exla/test/exla/executable_test.exs index 49f33f9280..fe0e5c3f4d 100644 --- a/exla/test/exla/executable_test.exs +++ b/exla/test/exla/executable_test.exs @@ -11,7 +11,7 @@ defmodule EXLA.ExecutableTest do describe "run" do test "with no inputs and default options" do assert [a = %DeviceBuffer{}] = - run_one([], [], Typespec.tensor({:s, 32}, {}), fn b -> + run_one([], [], s32_typespec(), fn b -> [Value.constant(b, [1], s32_typespec())] end) @@ -19,8 +19,8 @@ defmodule EXLA.ExecutableTest do end test "with 2 inputs and default options" do - t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) - t2 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) + t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) + t2 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) assert [a = %DeviceBuffer{}] = run_one([t1, t2], [], [t1.typespec], fn _b, x, y -> @@ -34,7 +34,7 @@ defmodule EXLA.ExecutableTest do t1 = DeviceBuffer.place_on_device( <<1::32-native>>, - Typespec.tensor({:s, 32}, {}), + s32_typespec(), client(), 0 ) @@ -42,7 +42,7 @@ defmodule EXLA.ExecutableTest do t2 = DeviceBuffer.place_on_device( <<1::32-native>>, - Typespec.tensor({:s, 32}, {}), + s32_typespec(), client(), 0 ) @@ -62,8 +62,8 @@ defmodule EXLA.ExecutableTest do end test "with data from a previous run" do - t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) - t2 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) + t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) + t2 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) exec = compile([t1.typespec, t2.typespec], [], [t1.typespec], fn _b, x, y -> @@ -80,12 +80,12 @@ defmodule EXLA.ExecutableTest do t1 = DeviceBuffer.place_on_device( <<1::32-native>>, - Typespec.tensor({:s, 32}, {}), + s32_typespec(), client(), 0 ) - t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {})) + t2 = BinaryBuffer.from_binary(<<2::32-native>>, s32_typespec()) assert [a = %DeviceBuffer{}] = run_one([t1, t2], [], [t1.typespec], fn _b, x, y -> @@ -96,8 +96,8 @@ defmodule EXLA.ExecutableTest do end test "with tuple return" do - t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) - t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {})) + t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) + t2 = BinaryBuffer.from_binary(<<2::32-native>>, s32_typespec()) assert [a = %DeviceBuffer{}, b = %DeviceBuffer{}] = run_one([t1, t2], [], [t1.typespec, t2.typespec], fn _b, x, y -> @@ -110,8 +110,8 @@ defmodule EXLA.ExecutableTest do @tag :multi_device test "runs on a specific device" do - t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) - t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {})) + t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) + t2 = BinaryBuffer.from_binary(<<2::32-native>>, s32_typespec()) assert [a = %DeviceBuffer{}, b = %DeviceBuffer{}, c = %DeviceBuffer{}] = run_one( @@ -138,6 +138,25 @@ defmodule EXLA.ExecutableTest do end end + describe "serialization" do + test "run" do + t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) + t2 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) + + exec = + compile([s32_typespec(), s32_typespec()], [], [s32_typespec()], fn _, x, y -> + [Value.add(x, y, s32_typespec())] + end) + + binary = Executable.serialize(exec) + assert is_binary(binary) + exec = Executable.deserialize(client(), binary) + + assert [[a = %DeviceBuffer{}]] = EXLA.Executable.run(exec, [[t1, t2]], []) + assert <<2::32-native>> == DeviceBuffer.read(a) + end + end + defp s32_typespec(), do: Typespec.tensor({:s, 32}, {}) end