From 83419bbb4a442c0b1d5dfd2d182cf554c8137f81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Sun, 12 May 2024 18:21:17 +0200 Subject: [PATCH] Revert accidental change to optional take --- nx/lib/nx.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 4d3c3d43fd..add9fcd605 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -14132,7 +14132,7 @@ defmodule Nx do indices = devectorize(indices, keep_names: false) out = %{tensor | shape: inner_shape, names: inner_names} - Nx.Shared.optional(:take, [tensor, indices, axis], out, fn tensor, indices, axis -> + Nx.Shared.optional(:take, [tensor, indices, [axis: axis]], out, fn tensor, indices, _opts -> gather_indices = new_axis(indices, rank(indices)) {indices_axes, tensor_axes} = Enum.split(axes(inner_shape), rank(indices)) {leading, trailing} = Enum.split(tensor_axes, axis)