From 0738c76887796c52ebeeac05b03e0abe65ec615d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Valim?= <jose.valim@dashbit.co>
Date: Mon, 13 May 2024 14:46:15 +0200
Subject: [PATCH] Add tests for executable serde

---
 exla/lib/exla/executable.ex        | 32 ++++++++++++++++-----
 exla/test/exla/executable_test.exs | 45 +++++++++++++++++++++---------
 2 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex
index a6a0c8cbdf..1c6e4df49a 100644
--- a/exla/lib/exla/executable.ex
+++ b/exla/lib/exla/executable.ex
@@ -22,6 +22,9 @@ defmodule EXLA.Executable do
     end
   end
 
+  @doc """
+  Serializes the executable to a binary.
+  """
   def serialize(%Executable{
         ref: executable,
         output_typespecs: output_typespecs,
@@ -36,6 +39,7 @@ defmodule EXLA.Executable do
       |> IO.iodata_to_binary()
 
     %{
+      version: 1,
       serialized: serialized_exec,
       output_typespecs: output_typespecs,
       num_replicas: num_replicas,
@@ -45,21 +49,35 @@ defmodule EXLA.Executable do
     |> :erlang.term_to_binary()
   end
 
+  @doc """
+  Deserializes a previous serialized executable.
+  """
   def deserialize(client, binary) do
     case :erlang.binary_to_term(binary) do
-      %{serialized: serialized_exec} = exec_data ->
+      %{version: 1, serialized: serialized} = data ->
+        %{
+          output_typespecs: output_typespecs,
+          num_replicas: num_replicas,
+          num_partitions: num_partitions,
+          device_id: device_id
+        } = data
+
         ref =
-          serialized_exec
+          serialized
           |> then(&EXLA.NIF.deserialize_executable(client.ref, &1))
           |> unwrap!()
 
-        exec_data
-        |> Map.put(:ref, ref)
-        |> Map.put(:client, client)
-        |> then(&struct(__MODULE__, &1))
+        %EXLA.Executable{
+          output_typespecs: output_typespecs,
+          num_replicas: num_replicas,
+          num_partitions: num_partitions,
+          device_id: device_id,
+          ref: ref,
+          client: client
+        }
 
       _other ->
-        raise "invalid serialized executable"
+        raise ArgumentError, "invalid serialized executable"
     end
   end
 
diff --git a/exla/test/exla/executable_test.exs b/exla/test/exla/executable_test.exs
index 49f33f9280..fe0e5c3f4d 100644
--- a/exla/test/exla/executable_test.exs
+++ b/exla/test/exla/executable_test.exs
@@ -11,7 +11,7 @@ defmodule EXLA.ExecutableTest do
   describe "run" do
     test "with no inputs and default options" do
       assert [a = %DeviceBuffer{}] =
-               run_one([], [], Typespec.tensor({:s, 32}, {}), fn b ->
+               run_one([], [], s32_typespec(), fn b ->
                  [Value.constant(b, [1], s32_typespec())]
                end)
 
@@ -19,8 +19,8 @@ defmodule EXLA.ExecutableTest do
     end
 
     test "with 2 inputs and default options" do
-      t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {}))
-      t2 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {}))
+      t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())
+      t2 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())
 
       assert [a = %DeviceBuffer{}] =
                run_one([t1, t2], [], [t1.typespec], fn _b, x, y ->
@@ -34,7 +34,7 @@ defmodule EXLA.ExecutableTest do
       t1 =
         DeviceBuffer.place_on_device(
           <<1::32-native>>,
-          Typespec.tensor({:s, 32}, {}),
+          s32_typespec(),
           client(),
           0
         )
@@ -42,7 +42,7 @@ defmodule EXLA.ExecutableTest do
       t2 =
         DeviceBuffer.place_on_device(
           <<1::32-native>>,
-          Typespec.tensor({:s, 32}, {}),
+          s32_typespec(),
           client(),
           0
         )
@@ -62,8 +62,8 @@ defmodule EXLA.ExecutableTest do
     end
 
     test "with data from a previous run" do
-      t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {}))
-      t2 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {}))
+      t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())
+      t2 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())
 
       exec =
         compile([t1.typespec, t2.typespec], [], [t1.typespec], fn _b, x, y ->
@@ -80,12 +80,12 @@ defmodule EXLA.ExecutableTest do
       t1 =
         DeviceBuffer.place_on_device(
           <<1::32-native>>,
-          Typespec.tensor({:s, 32}, {}),
+          s32_typespec(),
           client(),
           0
         )
 
-      t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {}))
+      t2 = BinaryBuffer.from_binary(<<2::32-native>>, s32_typespec())
 
       assert [a = %DeviceBuffer{}] =
                run_one([t1, t2], [], [t1.typespec], fn _b, x, y ->
@@ -96,8 +96,8 @@ defmodule EXLA.ExecutableTest do
     end
 
     test "with tuple return" do
-      t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {}))
-      t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {}))
+      t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())
+      t2 = BinaryBuffer.from_binary(<<2::32-native>>, s32_typespec())
 
       assert [a = %DeviceBuffer{}, b = %DeviceBuffer{}] =
                run_one([t1, t2], [], [t1.typespec, t2.typespec], fn _b, x, y ->
@@ -110,8 +110,8 @@ defmodule EXLA.ExecutableTest do
 
     @tag :multi_device
     test "runs on a specific device" do
-      t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {}))
-      t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {}))
+      t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())
+      t2 = BinaryBuffer.from_binary(<<2::32-native>>, s32_typespec())
 
       assert [a = %DeviceBuffer{}, b = %DeviceBuffer{}, c = %DeviceBuffer{}] =
                run_one(
@@ -138,6 +138,25 @@ defmodule EXLA.ExecutableTest do
     end
   end
 
+  describe "serialization" do
+    test "run" do
+      t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())
+      t2 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())
+
+      exec =
+        compile([s32_typespec(), s32_typespec()], [], [s32_typespec()], fn _, x, y ->
+          [Value.add(x, y, s32_typespec())]
+        end)
+
+      binary = Executable.serialize(exec)
+      assert is_binary(binary)
+      exec = Executable.deserialize(client(), binary)
+
+      assert [[a = %DeviceBuffer{}]] = EXLA.Executable.run(exec, [[t1, t2]], [])
+      assert <<2::32-native>> == DeviceBuffer.read(a)
+    end
+  end
+
   defp s32_typespec(), do: Typespec.tensor({:s, 32}, {})
 end