Skip to content

Commit

Permalink
chore: reorganize EXLA tests (#1456)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Feb 26, 2024
1 parent 7b087da commit e173aca
Show file tree
Hide file tree
Showing 13 changed files with 169 additions and 1,236 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- elixir: "1.15.4"
working_directory: "exla"
otp: "25.3"
use_mlir: true
use_mlir: mlir
defaults:
run:
working-directory: ${{ matrix.working_directory }}
Expand Down
14 changes: 8 additions & 6 deletions exla/lib/exla/executable.ex
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,15 @@ defmodule EXLA.Executable do
unwrap!(data)
end

defp decompose_output({data, device_id}, shape, client) do
defp decompose_output({data, device_id}, shapes, client) do
shapes =
case shape do
%Shape{dtype: {:tuple, shapes}} -> shapes
shapes when is_list(shapes) -> shapes
%Shape{} -> [shape]
end
Enum.flat_map(List.wrap(shapes), fn shape ->
case shape do
%Shape{dtype: {:tuple, shapes}} -> shapes
shapes when is_list(shapes) -> shapes
%Shape{} -> [shape]
end
end)

Enum.zip_with(data, shapes, fn
buf, subshape when is_reference(buf) ->
Expand Down
8 changes: 7 additions & 1 deletion exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -604,8 +604,14 @@ defmodule EXLA.MLIR.Value do
def while(
%Function{ref: pred_ref},
%Function{ref: body_ref},
%Value{function: function} = initial
initial
) do
function =
case initial do
[%Value{function: function} | _] -> function
%Value{function: function} -> function
end

refs =
EXLA.NIF.mlir_while(function.ref, pred_ref, body_ref, flatten_tuples(initial)) |> unwrap!()

Expand Down
1 change: 1 addition & 0 deletions exla/test/exla/builder_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ defmodule EXLA.BuilderTest do

alias EXLA.{Builder, Computation, Op}

@moduletag skip: :mlir
test "new/1 succeeds in creating a new builder" do
assert b = %Builder{} = Builder.new("builder")
assert b.name == "builder"
Expand Down
2 changes: 1 addition & 1 deletion exla/test/exla/defn/expr_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -4013,7 +4013,7 @@ defmodule EXLA.Defn.ExprTest do

test "raises on bad precision" do
valid_precision =
if Nx.Defn.default_options()[:compiler_mode] == :mlir do
if compiler_mode() == :mlir do
":default, :high, :highest, or :packed_nibble"
else
":default, :high, or :highest"
Expand Down
201 changes: 141 additions & 60 deletions exla/test/exla/executable_test.exs
Original file line number Diff line number Diff line change
@@ -1,29 +1,19 @@
defmodule EXLA.ExecutableTest do
use ExUnit.Case, async: true

alias EXLA.{BinaryBuffer, DeviceBuffer, Executable, Op, Shape}
alias EXLA.BinaryBuffer
alias EXLA.DeviceBuffer
alias EXLA.Executable
alias EXLA.Op
alias EXLA.Shape
alias EXLA.MLIR.Value
import EXLAHelpers

test "raises on invalid tuples" do
t1 = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {}))
t2 = BinaryBuffer.from_binary(<<2::32-native>>, Shape.make_shape({:s, 32}, {}))

assert_raise ArgumentError, ~r"can only compile computations with a tuple at the root", fn ->
run_one([t1, t2], [], fn b, x, y ->
Op.tuple(b, [Op.tuple(b, [x]), Op.tuple(b, [y])])
end)
end

assert_raise ArgumentError, ~r"can only compile computations with a tuple at the root", fn ->
run_one([t1, t2], [], fn _b, x, y -> Op.add(x, y) end)
end
end

describe "run" do
test "with no inputs and default options" do
assert [a = %DeviceBuffer{}] =
run_one([], fn b ->
Op.tuple(b, [Op.constant_r0(b, 1, {:s, 32})])
run_one([], [], Shape.make_shape({:s, 32}, {}), fn b ->
mod().tuple(b, [mod().constant_r0(b, 1, {:s, 32})])
end)

assert <<1::32-native>> == DeviceBuffer.read(a)
Expand All @@ -34,7 +24,9 @@ defmodule EXLA.ExecutableTest do
t2 = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {}))

assert [a = %DeviceBuffer{}] =
run_one([t1, t2], fn b, x, y -> Op.tuple(b, [Op.add(x, y)]) end)
run_one([t1, t2], [], Shape.make_tuple_shape([t1.shape]), fn b, x, y ->
mod().tuple(b, [add(x, y)])
end)

assert <<2::32-native>> == DeviceBuffer.read(a)
end
Expand All @@ -57,13 +49,13 @@ defmodule EXLA.ExecutableTest do
)

assert [%DeviceBuffer{}] =
run_one([t1, t2], [], fn b, x, y ->
Op.tuple(b, [Op.add(x, y)])
run_one([t1, t2], [], t1.shape, fn b, x, y ->
mod().tuple(b, [add(x, y)])
end)

assert [%DeviceBuffer{}] =
run_one([t1, t2], [], fn b, x, y ->
Op.tuple(b, [Op.add(x, y)])
run_one([t1, t2], [], Shape.make_tuple_shape([t1.shape]), fn b, x, y ->
mod().tuple(b, [add(x, y)])
end)

assert DeviceBuffer.read(t1) == <<1::32-native>>
Expand All @@ -74,7 +66,11 @@ defmodule EXLA.ExecutableTest do
t1 = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {}))
t2 = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {}))

exec = compile([t1.shape, t2.shape], fn b, x, y -> Op.tuple(b, [Op.add(x, y)]) end)
exec =
compile([t1.shape, t2.shape], [], [t1.shape], fn b, x, y ->
mod().tuple(b, [add(x, y)])
end)

assert [[t3 = %DeviceBuffer{}]] = Executable.run(exec, [[t1, t2]])
assert [[a = %DeviceBuffer{}]] = Executable.run(exec, [[t3, t3]])

Expand All @@ -93,7 +89,7 @@ defmodule EXLA.ExecutableTest do
t2 = BinaryBuffer.from_binary(<<2::32-native>>, Shape.make_shape({:s, 32}, {}))

assert [a = %DeviceBuffer{}] =
run_one([t1, t2], fn b, x, y -> Op.tuple(b, [Op.add(x, y)]) end)
run_one([t1, t2], [], {t1.shape}, fn b, x, y -> mod().tuple(b, [add(x, y)]) end)

assert <<3::32-native>> == DeviceBuffer.read(a)
end
Expand All @@ -103,7 +99,9 @@ defmodule EXLA.ExecutableTest do
t2 = BinaryBuffer.from_binary(<<2::32-native>>, Shape.make_shape({:s, 32}, {}))

assert [a = %DeviceBuffer{}, b = %DeviceBuffer{}] =
run_one([t1, t2], fn b, x, y -> Op.tuple(b, [x, y]) end)
run_one([t1, t2], [], Shape.make_tuple_shape([t1.shape, t2.shape]), fn b, x, y ->
mod().tuple(b, [x, y])
end)

assert <<1::32-native>> == DeviceBuffer.read(a)
assert <<2::32-native>> == DeviceBuffer.read(b)
Expand All @@ -115,9 +113,14 @@ defmodule EXLA.ExecutableTest do
t2 = BinaryBuffer.from_binary(<<2::32-native>>, Shape.make_shape({:s, 32}, {}))

assert [a = %DeviceBuffer{}, b = %DeviceBuffer{}, c = %DeviceBuffer{}] =
run_one([t1, t2], [device_id: 1], fn b, x, y ->
Op.tuple(b, [x, y, Op.add(x, y)])
end)
run_one(
[t1, t2],
[device_id: 1],
EXLA.Shape.make_tuple_shape([t1.shape, t2.shape, t1.shape]),
fn b, x, y ->
mod().tuple(b, [x, y, add(x, y)])
end
)

assert <<1::32-native>> == DeviceBuffer.read(a)
assert a.device_id == 1
Expand All @@ -127,36 +130,65 @@ defmodule EXLA.ExecutableTest do
assert c.device_id == 1

assert_raise RuntimeError, ~r"Expected buffer to be placed on device 0", fn ->
run_one([a, b], [device_id: 0], fn b, x, y ->
Op.tuple(b, [Op.add(x, y)])
run_one([a, b], [device_id: 0], t1.shape, fn b, x, y ->
mod().tuple(b, [add(x, y)])
end)
end
end
end

defp add(x, y) do
if mod() == Value do
Value.add(x.function, x, y)
else
Op.add(x, y)
end
end

defp mod do
if Application.get_env(:exla, :compiler_mode) == :mlir do
Value
else
Op
end
end
end

defmodule EXLA.ExecutableFeedTest do
# infeed/outfeed are global resources, so they either
# need to be locked or we cannot run them concurrently.
use ExUnit.Case, async: false

alias EXLA.{BinaryBuffer, DeviceBuffer, Client, Op, Shape}
alias EXLA.BinaryBuffer
alias EXLA.DeviceBuffer
alias EXLA.Client
alias EXLA.Op
alias EXLA.Shape
alias EXLA.MLIR.Value
import EXLAHelpers

defp mod do
if Application.get_env(:exla, :compiler_mode) == :mlir do
Value
else
Op
end
end

describe "infeed/outfeed" do
test "successfully sends to/from device asynchronously" do
t = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {}))

assert res =
Task.async(fn ->
run_one([], [], fn b ->
token = Op.create_token(b)
val_and_token = Op.infeed(token, t.shape)
val = Op.get_tuple_element(val_and_token, 0)
new_token = Op.get_tuple_element(val_and_token, 1)
outfeed_val = Op.add(val, val)
_outfeed_token = Op.outfeed(outfeed_val, new_token)
Op.tuple(b, [Op.add(outfeed_val, val)])
run_one([], [], Shape.make_tuple_shape([Shape.make_token_shape()]), fn b ->
token = mod().create_token(b)
val_and_token = mod().infeed(token, t.shape)
val = mod().get_tuple_element(val_and_token, 0)
new_token = mod().get_tuple_element(val_and_token, 1)
outfeed_val = add(val, val)
_outfeed_token = mod().outfeed(outfeed_val, new_token)
mod().tuple(b, [add(outfeed_val, val)])
end)
end)

Expand All @@ -170,28 +202,69 @@ defmodule EXLA.ExecutableFeedTest do
test "successfully sends to/from device asynchronously in a loop" do
t = BinaryBuffer.from_binary(<<1::32-native>>, Shape.make_shape({:s, 32}, {}))

token_shape = Shape.make_token_shape()

assert res =
Task.async(fn ->
run_one([], [], fn b ->
token_shape = Shape.make_token_shape()
tuple_shape = Shape.make_tuple_shape([t.shape, token_shape])

condition_b = EXLA.Builder.new(b, "condition")
param = EXLA.Op.parameter(condition_b, 0, tuple_shape, "arg")
zero = Op.constant_r0(condition_b, 0, {:s, 32})
val = Op.get_tuple_element(param, 0)
condition = EXLA.Builder.build(Op.not_equal(val, zero))

while_b = EXLA.Builder.new(b, "while")
param = EXLA.Op.parameter(while_b, 0, tuple_shape, "arg")
val = Op.get_tuple_element(param, 0)
token = Op.get_tuple_element(param, 1)
token = Op.outfeed(Op.add(val, val), token)
while = EXLA.Builder.build(Op.infeed(token, t.shape))

token = Op.create_token(b)
while = Op.while(condition, while, Op.infeed(token, t.shape))
Op.tuple(b, [Op.get_tuple_element(while, 0)])
run_one([], [], {token_shape, t.shape}, fn b ->
if mod() == Value do
condition =
EXLA.MLIR.Module.add_function(
b.module,
"condition",
[token_shape, t.shape],
[
token_shape,
Shape.make_shape({:pred, 8}, {})
]
)

[_token, val] = EXLA.MLIR.Function.get_arguments(condition)
zero = Value.constant_r0(condition, 0, {:s, 32})
Value.variadic_return([Value.not_equal(condition, val, zero)])

body =
EXLA.MLIR.Module.add_function(b.module, "body", [token_shape, t.shape], [
token_shape,
t.shape
])

[token, val] = EXLA.MLIR.Function.get_arguments(body)

token = Value.outfeed(Value.add(body, val, val), token)

infeed = Value.infeed(token, t.shape)
input = Value.get_tuple_element(infeed, 0)
token = Value.get_tuple_element(infeed, 1)

Value.variadic_return([token, input])

token = Value.create_token(b)
infeed = Value.infeed(token, t.shape)
input = Value.get_tuple_element(infeed, 0)
token = Value.get_tuple_element(infeed, 1)

[_token, result] = Value.while(condition, body, [token, input])
Value.tuple(b, [result])
else
tuple_shape = Shape.make_tuple_shape([t.shape, Shape.make_token_shape()])
condition_b = EXLA.Builder.new(b, "condition")
param = mod().parameter(condition_b, 0, tuple_shape, "arg")
zero = mod().constant_r0(condition_b, 0, {:s, 32})
val = mod().get_tuple_element(param, 0)
condition = EXLA.Builder.build(mod().not_equal(val, zero))

while_b = EXLA.Builder.new(b, "while")
param = mod().parameter(while_b, 0, tuple_shape, "arg")
val = mod().get_tuple_element(param, 0)
token = mod().get_tuple_element(param, 1)
token = mod().outfeed(add(val, val), token)
while = EXLA.Builder.build(mod().infeed(token, t.shape))

token = mod().create_token(b)
while = mod().while(condition, while, mod().infeed(token, t.shape))
mod().tuple(b, [mod().get_tuple_element(while, 0)])
end
end)
end)

Expand All @@ -216,4 +289,12 @@ defmodule EXLA.ExecutableFeedTest do
{^ref, msg} -> msg
end
end

defp add(x, y) do
if mod() == Value do
Value.add(x.function, x, y)
else
Op.add(x, y)
end
end
end
Loading

0 comments on commit e173aca

Please sign in to comment.