From 8a851f18388a330e9eecd0853485fd6b850267e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Wed, 12 Jun 2024 17:44:20 +0700 Subject: [PATCH] Bump jax-metal to 0.1.0 --- exla/lib/exla/defn.ex | 11 ++--------- exla/mix.exs | 2 +- exla/test/exla/backend_test.exs | 4 ++-- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 57c7fa6f5a..ca871a54c6 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -663,9 +663,7 @@ defmodule EXLA.Defn do result = Value.gather( tensor, - # TODO remove conversion (unsigned indices fail) - # Reported in https://github.com/google/jax/issues/21547 - to_type(indices, {:s, 32}), + indices, index_vector_dim, slice_sizes, offset_dims, @@ -1297,9 +1295,6 @@ defmodule EXLA.Defn do defp to_operator(:put_slice, [%Value{} = tensor, start_indices, slice], ans, _state) do tensor = to_type(tensor, ans.type) slice = to_type(slice, ans.type) - # TODO remove conversion (unsigned indices fail) - # Reported in https://github.com/google/jax/issues/21547 - start_indices = Enum.map(start_indices, &to_type(&1, {:s, 32})) Value.dynamic_update_slice(tensor, slice, start_indices, expr_to_typespec(ans)) end @@ -1322,9 +1317,7 @@ defmodule EXLA.Defn do Value.gather( tensor, - # TODO remove conversion (unsigned indices fail) - # Reported in https://github.com/google/jax/issues/21547 - to_type(indices, {:s, 32}), + indices, index_vector_dim, slice_sizes, offset_dims, diff --git a/exla/mix.exs b/exla/mix.exs index daa67d91c5..b36425919e 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -141,7 +141,7 @@ defmodule EXLA.MixProject do plugin_path = Path.join(xla_extension_path, "lib/pjrt_plugin_metal.dylib") wheel_url = - "https://files.pythonhosted.org/packages/d6/4f/f5d128a493b7387fbbe0e6906544214af2a6b86af30302dd6ffb9dc66a74/jax_metal-0.0.7-py3-none-macosx_13_0_arm64.whl" + "https://files.pythonhosted.org/packages/80/af/ed482a421a868726e7ca3f51ac19b0c9a8e37f33f54413312c37e9056acc/jax_metal-0.1.0-py3-none-macosx_11_0_arm64.whl" wheel_path = Path.join(xla_extension_path, "jax_metal.whl") diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index ae83e82824..0182bbc4af 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -39,8 +39,8 @@ defmodule EXLA.BackendTest do window_scatter_min: 5, window_scatter_max: 5, window_mean: 3, - # Argmax/armin fail when a custom :type is passed. - # Reported in https://github.com/google/jax/issues/21577 + # (edge case) Argmax/argmin return wrong value in case of NaN. + # Reported in https://github.com/google/jax/issues/21821 argmin: 2, argmax: 2, # Missing support for general "stablehlo.reduce". Some cases work