Skip to content

Commit

Permalink
Add tests for executable serde (#1484)
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim authored Jun 10, 2024
1 parent a91a97c commit 54e22a4
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 20 deletions.
32 changes: 25 additions & 7 deletions exla/lib/exla/executable.ex
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ defmodule EXLA.Executable do
end
end

@doc """
Serializes the executable to a binary.
"""
def serialize(%Executable{
ref: executable,
output_typespecs: output_typespecs,
Expand All @@ -36,6 +39,7 @@ defmodule EXLA.Executable do
|> IO.iodata_to_binary()

%{
version: 1,
serialized: serialized_exec,
output_typespecs: output_typespecs,
num_replicas: num_replicas,
Expand All @@ -45,21 +49,35 @@ defmodule EXLA.Executable do
|> :erlang.term_to_binary()
end

@doc """
Deserializes a previous serialized executable.
"""
def deserialize(client, binary) do
case :erlang.binary_to_term(binary) do
%{serialized: serialized_exec} = exec_data ->
%{version: 1, serialized: serialized} = data ->
%{
output_typespecs: output_typespecs,
num_replicas: num_replicas,
num_partitions: num_partitions,
device_id: device_id
} = data

ref =
serialized_exec
serialized
|> then(&EXLA.NIF.deserialize_executable(client.ref, &1))
|> unwrap!()

exec_data
|> Map.put(:ref, ref)
|> Map.put(:client, client)
|> then(&struct(__MODULE__, &1))
%EXLA.Executable{
output_typespecs: output_typespecs,
num_replicas: num_replicas,
num_partitions: num_partitions,
device_id: device_id,
ref: ref,
client: client
}

_other ->
raise "invalid serialized executable"
raise ArgumentError, "invalid serialized executable"
end
end

Expand Down
45 changes: 32 additions & 13 deletions exla/test/exla/executable_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ defmodule EXLA.ExecutableTest do
describe "run" do
test "with no inputs and default options" do
assert [a = %DeviceBuffer{}] =
run_one([], [], Typespec.tensor({:s, 32}, {}), fn b ->
run_one([], [], s32_typespec(), fn b ->
[Value.constant(b, [1], s32_typespec())]
end)

assert <<1::32-native>> == DeviceBuffer.read(a)
end

test "with 2 inputs and default options" do
t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {}))
t2 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {}))
t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())
t2 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())

assert [a = %DeviceBuffer{}] =
run_one([t1, t2], [], [t1.typespec], fn _b, x, y ->
Expand All @@ -34,15 +34,15 @@ defmodule EXLA.ExecutableTest do
t1 =
DeviceBuffer.place_on_device(
<<1::32-native>>,
Typespec.tensor({:s, 32}, {}),
s32_typespec(),
client(),
0
)

t2 =
DeviceBuffer.place_on_device(
<<1::32-native>>,
Typespec.tensor({:s, 32}, {}),
s32_typespec(),
client(),
0
)
Expand All @@ -62,8 +62,8 @@ defmodule EXLA.ExecutableTest do
end

test "with data from a previous run" do
t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {}))
t2 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {}))
t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())
t2 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())

exec =
compile([t1.typespec, t2.typespec], [], [t1.typespec], fn _b, x, y ->
Expand All @@ -80,12 +80,12 @@ defmodule EXLA.ExecutableTest do
t1 =
DeviceBuffer.place_on_device(
<<1::32-native>>,
Typespec.tensor({:s, 32}, {}),
s32_typespec(),
client(),
0
)

t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {}))
t2 = BinaryBuffer.from_binary(<<2::32-native>>, s32_typespec())

assert [a = %DeviceBuffer{}] =
run_one([t1, t2], [], [t1.typespec], fn _b, x, y ->
Expand All @@ -96,8 +96,8 @@ defmodule EXLA.ExecutableTest do
end

test "with tuple return" do
t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {}))
t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {}))
t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())
t2 = BinaryBuffer.from_binary(<<2::32-native>>, s32_typespec())

assert [a = %DeviceBuffer{}, b = %DeviceBuffer{}] =
run_one([t1, t2], [], [t1.typespec, t2.typespec], fn _b, x, y ->
Expand All @@ -110,8 +110,8 @@ defmodule EXLA.ExecutableTest do

@tag :multi_device
test "runs on a specific device" do
t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {}))
t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {}))
t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())
t2 = BinaryBuffer.from_binary(<<2::32-native>>, s32_typespec())

assert [a = %DeviceBuffer{}, b = %DeviceBuffer{}, c = %DeviceBuffer{}] =
run_one(
Expand All @@ -138,6 +138,25 @@ defmodule EXLA.ExecutableTest do
end
end

describe "serialization" do
test "run" do
t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())
t2 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec())

exec =
compile([s32_typespec(), s32_typespec()], [], [s32_typespec()], fn _, x, y ->
[Value.add(x, y, s32_typespec())]
end)

binary = Executable.serialize(exec)
assert is_binary(binary)
exec = Executable.deserialize(client(), binary)

assert [[a = %DeviceBuffer{}]] = EXLA.Executable.run(exec, [[t1, t2]], [])
assert <<2::32-native>> == DeviceBuffer.read(a)
end
end

defp s32_typespec(), do: Typespec.tensor({:s, 32}, {})
end

Expand Down

0 comments on commit 54e22a4

Please sign in to comment.