Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Nx.Random.key as deftransform #1525

Merged
merged 3 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions nx/lib/nx/random.ex
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,33 @@ defmodule Nx.Random do
[0, 12]
>

A single key effectively consists of 64 bits, so all possible values
of a 64-bit integer result in a different key. However, when passing
an integer literal, Nx implicitly assumes its type to be `:s32`,
which would result in an overflow for large integers. Therefore,
when dealing with large seeds, make sure to explicitly use a 64 bit
type:

iex> Nx.Random.key(Nx.u64(999999999999))
iex> Nx.Random.key(999999999999)
#Nx.Tensor<
u32[2]
[232, 3567587327]
>
"""
defn key(seed) do
seed = Nx.as_type(seed, :u64)
deftransform key(seed) do
seed =
case seed do
seed when is_integer(seed) ->
Nx.u64(seed)

%Nx.Tensor{} = seed when seed.type == {:u, 64} ->
seed

%Nx.Tensor{} = seed when seed.type == {:s, 64} ->
Nx.bitcast(seed, {:u, 64})
Copy link
Collaborator

@josevalim josevalim Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need this clause? No-one should be passing s64 now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or do we also want to do bitcasts for f64?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do logical and with another :u64, so that's good.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the s64 case because there are explicit tests for this, and it has the same entropy as u64. Not sure about floats.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I misread the question. I think it's fair to assume seed is always an integer.

No preference on {:s, 64}, if there are tests, then yeah, bitcast sounds good.


other ->
raise ArgumentError,
"expected seed to be an integer, u64 tensor or s64 tensor, got: #{inspect(other)}"
end

key_n(seed)
end

defnp key_n(seed) do
josevalim marked this conversation as resolved.
Show resolved Hide resolved
k1 = Nx.right_shift(seed, 32)
k2 = Nx.bitwise_and(seed, Nx.u64(0xFFFFFFFF))

Expand Down
4 changes: 2 additions & 2 deletions nx/test/nx/random_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ defmodule Nx.RandomTest do
describe "distributions" do
defp distribution_case(name, args: args, expected: expected) do
seed = :erlang.adler32("#{name}threefry2x32")
key = Nx.Random.key(Nx.u64(seed))
key = Nx.Random.key(seed)
actual = apply(Nx.Random, name, [key | args])

assert_all_close(actual, expected)
Expand Down Expand Up @@ -261,7 +261,7 @@ defmodule Nx.RandomTest do
|> assert_all_close(apply(expected_func, expected_args), rtol: 0.1)

seed = :erlang.adler32("uniformthreefry2x32")
key = Nx.Random.key(Nx.tensor(seed, type: :u64))
key = Nx.Random.key(seed)
t = apply(Nx.Random, name, [key | args])

apply(Nx, moment, [t])
Expand Down
4 changes: 2 additions & 2 deletions torchx/test/torchx/random_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ defmodule Torchx.Nx.RandomTest do
describe "distributions" do
defp distribution_case(name, args: args, expected: expected) do
seed = :erlang.adler32("#{name}threefry2x32")
key = Nx.Random.key(Nx.u64(seed))
key = Nx.Random.key(seed)
actual = apply(Nx.Random, name, [key | args])

assert_all_close(actual, expected)
Expand Down Expand Up @@ -209,7 +209,7 @@ defmodule Torchx.Nx.RandomTest do
|> assert_all_close(apply(expected_func, expected_args), rtol: 0.1)

seed = :erlang.adler32("#{name}threefry2x32")
key = Nx.Random.key(Nx.tensor(seed, type: :u64))
key = Nx.Random.key(seed)
t = apply(Nx.Random, name, [key | args])

apply(Nx, moment, [t])
Expand Down
Loading