diff --git a/lib/scholar/options.ex b/lib/scholar/options.ex index b4505eb3..e0173ad1 100644 --- a/lib/scholar/options.ex +++ b/lib/scholar/options.ex @@ -102,7 +102,7 @@ defmodule Scholar.Options do end def beta(beta) do - if (is_number(beta) and beta >= 0) or (Nx.is_tensor(beta) and Nx.size(beta) == 1) do + if (is_number(beta) and beta >= 0) or (Nx.is_tensor(beta) and Nx.rank(beta) == 0) do {:ok, beta} else {:error, "expect 'beta' to be in the range [0, inf]"}