From 27e805829335019ab988a71506f9282753ce68dd Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 3 Sep 2024 17:26:48 -0300 Subject: [PATCH 1/3] refactor: Nx.Random.key as deftransform --- nx/lib/nx/random.ex | 14 +++++--------- nx/test/nx/random_test.exs | 4 ++-- torchx/test/torchx/random_test.exs | 4 ++-- 3 files changed, 9 insertions(+), 13 deletions(-) diff --git a/nx/lib/nx/random.ex b/nx/lib/nx/random.ex index c876aed3b32..e8c90417cdf 100644 --- a/nx/lib/nx/random.ex +++ b/nx/lib/nx/random.ex @@ -80,22 +80,18 @@ defmodule Nx.Random do [0, 12] > - A single key effectively consists of 64 bits, so all possible values - of a 64-bit integer result in a different key. However, when passing - an integer literal, Nx implicitly assumes its type to be `:s32`, - which would result in an overflow for large integers. Therefore, - when dealing with large seeds, make sure to explicitly use a 64 bit - type: - - iex> Nx.Random.key(Nx.u64(999999999999)) + iex> Nx.Random.key(999999999999) #Nx.Tensor< u32[2] [232, 3567587327] > """ - defn key(seed) do + deftransform key(seed) do seed = Nx.as_type(seed, :u64) + key_n(seed) + end + defnp key_n(seed) do k1 = Nx.right_shift(seed, 32) k2 = Nx.bitwise_and(seed, Nx.u64(0xFFFFFFFF)) diff --git a/nx/test/nx/random_test.exs b/nx/test/nx/random_test.exs index 46fad985b70..9197067afd3 100644 --- a/nx/test/nx/random_test.exs +++ b/nx/test/nx/random_test.exs @@ -67,7 +67,7 @@ defmodule Nx.RandomTest do describe "distributions" do defp distribution_case(name, args: args, expected: expected) do seed = :erlang.adler32("#{name}threefry2x32") - key = Nx.Random.key(Nx.u64(seed)) + key = Nx.Random.key(seed) actual = apply(Nx.Random, name, [key | args]) assert_all_close(actual, expected) @@ -261,7 +261,7 @@ defmodule Nx.RandomTest do |> assert_all_close(apply(expected_func, expected_args), rtol: 0.1) seed = :erlang.adler32("uniformthreefry2x32") - key = Nx.Random.key(Nx.tensor(seed, type: :u64)) + key = Nx.Random.key(seed) t = apply(Nx.Random, name, [key | args]) apply(Nx, moment, [t]) diff --git a/torchx/test/torchx/random_test.exs b/torchx/test/torchx/random_test.exs index d6721bc8f68..e19865d35e0 100644 --- a/torchx/test/torchx/random_test.exs +++ b/torchx/test/torchx/random_test.exs @@ -55,7 +55,7 @@ defmodule Torchx.Nx.RandomTest do describe "distributions" do defp distribution_case(name, args: args, expected: expected) do seed = :erlang.adler32("#{name}threefry2x32") - key = Nx.Random.key(Nx.u64(seed)) + key = Nx.Random.key(seed) actual = apply(Nx.Random, name, [key | args]) assert_all_close(actual, expected) @@ -209,7 +209,7 @@ defmodule Torchx.Nx.RandomTest do |> assert_all_close(apply(expected_func, expected_args), rtol: 0.1) seed = :erlang.adler32("#{name}threefry2x32") - key = Nx.Random.key(Nx.tensor(seed, type: :u64)) + key = Nx.Random.key(seed) t = apply(Nx.Random, name, [key | args]) apply(Nx, moment, [t]) From 7d247b4d275bd917c71969c5c73a7d6a37ae86a6 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 4 Sep 2024 11:24:03 -0300 Subject: [PATCH 2/3] Update nx/lib/nx/random.ex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jonatan KÅ‚osko --- nx/lib/nx/random.ex | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/nx/lib/nx/random.ex b/nx/lib/nx/random.ex index e8c90417cdf..49b73b28e85 100644 --- a/nx/lib/nx/random.ex +++ b/nx/lib/nx/random.ex @@ -87,7 +87,18 @@ defmodule Nx.Random do > """ deftransform key(seed) do - seed = Nx.as_type(seed, :u64) + seed = + case seed do + seed when is_integer(seed) -> + Nx.u64(seed) + + %Nx.Tensor{} = seed when seed.type == {:u, 64} -> + seed + + other -> + raise ArgumentError, "expected seed to be an integer or :u64 tensor, got: #{inspect(other)}" + end + key_n(seed) end From f2bb49859c2fd7026286e871e1f90292e306bfbb Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 4 Sep 2024 11:47:46 -0300 Subject: [PATCH 3/3] fix: accept s64 tensors --- nx/lib/nx/random.ex | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/nx/lib/nx/random.ex b/nx/lib/nx/random.ex index 49b73b28e85..77fa286c6f5 100644 --- a/nx/lib/nx/random.ex +++ b/nx/lib/nx/random.ex @@ -94,9 +94,13 @@ defmodule Nx.Random do %Nx.Tensor{} = seed when seed.type == {:u, 64} -> seed - + + %Nx.Tensor{} = seed when seed.type == {:s, 64} -> + Nx.bitcast(seed, {:u, 64}) + other -> - raise ArgumentError, "expected seed to be an integer or :u64 tensor, got: #{inspect(other)}" + raise ArgumentError, + "expected seed to be an integer, u64 tensor or s64 tensor, got: #{inspect(other)}" end key_n(seed)