diff --git a/nx/test/nx/defn/tree_test.exs b/nx/test/nx/defn/tree_test.exs index 6586119039..5ccd4fb300 100644 --- a/nx/test/nx/defn/tree_test.exs +++ b/nx/test/nx/defn/tree_test.exs @@ -40,7 +40,9 @@ defmodule Nx.Defn.TreeTest do test "ignores constants" do a = Expr.parameter(:root, {:u, 64}, {}, 0) - assert [{_, :parameter}, {_, :add}] = plus_constant(a) |> Tree.scope_ids() |> Enum.sort() + + assert [{_, :add}, {_, :parameter}] = + plus_constant(a) |> Tree.scope_ids() |> Enum.sort_by(&elem(&1, 1)) end defn inside_cond(bool, a, b) do @@ -54,8 +56,8 @@ defmodule Nx.Defn.TreeTest do test "ignores expressions inside cond" do {bool, cond} = Nx.Defn.jit(&{&1, inside_cond(&1, &2, &3)}).(0, 1, 2) - assert cond |> Tree.scope_ids() |> Enum.sort() == - [{bool.data.id, :parameter}, {cond.data.id, :cond}] + assert cond |> Tree.scope_ids() |> Enum.sort_by(&elem(&1, 1)) == + [{cond.data.id, :cond}, {bool.data.id, :parameter}] end defn inside_both_cond(bool, a, b) do @@ -84,14 +86,14 @@ defmodule Nx.Defn.TreeTest do b = Expr.parameter(:root, {:u, 64}, {}, 2) assert [ - {_, :parameter}, - {_, :parameter}, - {_, :parameter}, {_, :add}, {_, :cond}, {_, :cond}, - {_, :multiply} - ] = inside_both_cond(bool, a, b) |> Tree.scope_ids() |> Enum.sort() + {_, :multiply}, + {_, :parameter}, + {_, :parameter}, + {_, :parameter} + ] = inside_both_cond(bool, a, b) |> Tree.scope_ids() |> Enum.sort_by(&elem(&1, 1)) end end