diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index b06f139418..d3f67763a5 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -647,12 +647,12 @@ defmodule EXLA.Defn do %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state, cache ) do - dbg({type_kind}) # We match only on platform: :host for MLIR, as we want to support # eigh-on-cpu as a custom call only in this case {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!() - # convert to float and ensure that we're either using f32 or f64 + # convert to float and ensure that we're either using f32 or f64, because Eigen + # only supports f32 and f64 easily. out_type = Nx.Type.merge(Nx.Type.to_floating(eigenvecs_expr.type), {:f, 32}) tensor = @@ -669,8 +669,6 @@ defmodule EXLA.Defn do expr_to_typespec(%{eigenvals_expr | type: out_type}) ) - dbg(eigenvecs) - {[to_type(eigenvecs, eigenvecs_expr.type), to_type(eigenvals, eigenvals_expr.type)], cache} end