Skip to content

Commit

Permalink
Bump jax-metal to 0.1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Jun 12, 2024
1 parent 102e62b commit 8a851f1
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 12 deletions.
11 changes: 2 additions & 9 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -663,9 +663,7 @@ defmodule EXLA.Defn do
result =
Value.gather(
tensor,
# TODO remove conversion (unsigned indices fail)
# Reported in https://github.com/google/jax/issues/21547
to_type(indices, {:s, 32}),
indices,
index_vector_dim,
slice_sizes,
offset_dims,
Expand Down Expand Up @@ -1297,9 +1295,6 @@ defmodule EXLA.Defn do
defp to_operator(:put_slice, [%Value{} = tensor, start_indices, slice], ans, _state) do
tensor = to_type(tensor, ans.type)
slice = to_type(slice, ans.type)
# TODO remove conversion (unsigned indices fail)
# Reported in https://github.com/google/jax/issues/21547
start_indices = Enum.map(start_indices, &to_type(&1, {:s, 32}))
Value.dynamic_update_slice(tensor, slice, start_indices, expr_to_typespec(ans))
end

Expand All @@ -1322,9 +1317,7 @@ defmodule EXLA.Defn do

Value.gather(
tensor,
# TODO remove conversion (unsigned indices fail)
# Reported in https://github.com/google/jax/issues/21547
to_type(indices, {:s, 32}),
indices,
index_vector_dim,
slice_sizes,
offset_dims,
Expand Down
2 changes: 1 addition & 1 deletion exla/mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ defmodule EXLA.MixProject do
plugin_path = Path.join(xla_extension_path, "lib/pjrt_plugin_metal.dylib")

wheel_url =
"https://files.pythonhosted.org/packages/d6/4f/f5d128a493b7387fbbe0e6906544214af2a6b86af30302dd6ffb9dc66a74/jax_metal-0.0.7-py3-none-macosx_13_0_arm64.whl"
"https://files.pythonhosted.org/packages/80/af/ed482a421a868726e7ca3f51ac19b0c9a8e37f33f54413312c37e9056acc/jax_metal-0.1.0-py3-none-macosx_11_0_arm64.whl"

wheel_path = Path.join(xla_extension_path, "jax_metal.whl")

Expand Down
4 changes: 2 additions & 2 deletions exla/test/exla/backend_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ defmodule EXLA.BackendTest do
window_scatter_min: 5,
window_scatter_max: 5,
window_mean: 3,
# Argmax/armin fail when a custom :type is passed.
# Reported in https://github.com/google/jax/issues/21577
# (edge case) Argmax/argmin return wrong value in case of NaN.
# Reported in https://github.com/google/jax/issues/21821
argmin: 2,
argmax: 2,
# Missing support for general "stablehlo.reduce". Some cases work
Expand Down

0 comments on commit 8a851f1

Please sign in to comment.