diff --git a/exla/c_src/exla/exla.cc b/exla/c_src/exla/exla.cc index 42d0cbb9a4..031ed8faf5 100644 --- a/exla/c_src/exla/exla.cc +++ b/exla/c_src/exla/exla.cc @@ -144,7 +144,10 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { build_options.set_use_spmd_partitioning(use_spmd); bool compile_portable_executable = false; - if (device_id >= 0) { + + bool is_mps = (*client)->client()->platform_name() == "METAL"; + + if (device_id >= 0 && !is_mps) { compile_portable_executable = true; build_options.set_device_ordinal(device_id); } @@ -877,6 +880,16 @@ ERL_NIF_TERM get_tpu_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) return exla::nif::ok(env, exla::nif::make(env, client)); } +ERL_NIF_TERM get_mps_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { + if (argc != 0) { + return exla::nif::error(env, "Bad argument count."); + } + + EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaClient * client, exla::GetMpsClient(), env); + + return exla::nif::ok(env, exla::nif::make(env, client)); +} + ERL_NIF_TERM get_c_api_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) { if (argc != 1) { return exla::nif::error(env, "Bad argument count."); @@ -1065,6 +1078,7 @@ static ErlNifFunc exla_funcs[] = { {"get_host_client", 0, get_host_client}, {"get_gpu_client", 2, get_gpu_client}, {"get_tpu_client", 0, get_tpu_client}, + {"get_mps_client", 0, get_mps_client}, {"get_c_api_client", 1, get_c_api_client}, {"load_pjrt_plugin", 2, load_pjrt_plugin}, {"get_device_count", 1, get_device_count}, diff --git a/exla/c_src/exla/exla_client.cc b/exla/c_src/exla/exla_client.cc index b6bba1806b..884c2a32d6 100644 --- a/exla/c_src/exla/exla_client.cc +++ b/exla/c_src/exla/exla_client.cc @@ -495,6 +495,30 @@ xla::StatusOr GetTpuClient() { return new ExlaClient(std::move(client)); } +xla::StatusOr GetMpsClient() { + // The plugin may be compiled for a different version of PjRt C API + // than present in our XLA compilation. 
By default pjrt::LoadPjrtPlugin + // raises if the version does not match. By setting this environment + // variable, we relax this check to allow different versions, as long + // as they satisfy compatibility constraints. + // + // See https://github.com/openxla/xla/blob/4e8e23f16bc925b6f27817de098a8e1e81296bb5/xla/pjrt/pjrt_api.cc + setenv("ENABLE_PJRT_COMPATIBILITY", "1", 1); + + EXLA_ASSIGN_OR_RETURN(const PJRT_Api* pjrt_api, pjrt::LoadPjrtPlugin("METAL", "pjrt_plugin_metal.dylib")); + + xla::Status status = pjrt::InitializePjrtPlugin("METAL"); + + if (!status.ok()) { + return status; + } + + EXLA_ASSIGN_OR_RETURN(std::unique_ptr client, + xla::GetCApiClient("METAL")); + + return new ExlaClient(std::move(client)); +} + xla::StatusOr GetCApiClient(std::string device_type) { EXLA_ASSIGN_OR_RETURN(std::unique_ptr client, xla::GetCApiClient(device_type)); diff --git a/exla/c_src/exla/exla_client.h b/exla/c_src/exla/exla_client.h index 229853c4c9..26156af369 100644 --- a/exla/c_src/exla/exla_client.h +++ b/exla/c_src/exla/exla_client.h @@ -110,6 +110,8 @@ xla::StatusOr GetGpuClient(double memory_fraction, xla::StatusOr GetTpuClient(); +xla::StatusOr GetMpsClient(); + xla::StatusOr GetCApiClient(std::string device_type); } // namespace exla diff --git a/exla/lib/exla/client.ex b/exla/lib/exla/client.ex index 688be5a70f..def253af24 100644 --- a/exla/lib/exla/client.ex +++ b/exla/lib/exla/client.ex @@ -159,6 +159,9 @@ defmodule EXLA.Client do :tpu -> EXLA.NIF.get_tpu_client() + :mps -> + EXLA.NIF.get_mps_client() + _ -> raise ArgumentError, "unknown EXLA platform: #{inspect(platform)}" end diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index fbf392862f..4350bc7de3 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -712,6 +712,10 @@ defmodule EXLA.Defn do ) do precision = state.precision + # Ensure both have the same type + left = to_type(left, ans.type) + right = to_type(right, ans.type) + Value.dot_general( left, right, diff --git 
a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 2b25c6f8f6..8f2ba7894f 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -672,9 +672,15 @@ defmodule EXLA.MLIR.Value do typespecs ) do result_types = typespecs_to_mlir_types(typespecs) - regions = [on_true, on_false] - pred = convert(pred, Typespec.tensor({:pred, 8}, {})) - op(func, "stablehlo.if", [pred], result_types, regions: regions) + + # TODO Jax does not support stablehlo.if, they use stablehlo.case instead. + # It most likely makes sense for us to do the same. That said, note that + # stablehlo.case is implemented for Metal, but does not lower reliably. + # Reported in https://github.com/google/jax/issues/21601 + + regions = [on_false, on_true] + pred = convert(pred, Typespec.tensor({:s, 32}, {})) + op(func, "stablehlo.case", [pred], result_types, regions: regions) end def infeed(%Value{function: func} = token, typespecs) do diff --git a/exla/lib/exla/nif.ex b/exla/lib/exla/nif.ex index 023a0bcbd2..f86d513b17 100644 --- a/exla/lib/exla/nif.ex +++ b/exla/lib/exla/nif.ex @@ -67,6 +67,8 @@ defmodule EXLA.NIF do def get_tpu_client(), do: :erlang.nif_error(:undef) + def get_mps_client(), do: :erlang.nif_error(:undef) + def get_supported_platforms, do: :erlang.nif_error(:undef) def get_device_count(_client), diff --git a/exla/mix.exs b/exla/mix.exs index 7ce6727d2f..1f2e0c908f 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -52,9 +52,10 @@ defmodule EXLA.MixProject do cuda: [platform: :cuda], rocm: [platform: :rocm], tpu: [platform: :tpu], + mps: [platform: :mps], host: [platform: :host] ], - preferred_clients: [:cuda, :rocm, :tpu, :host] + preferred_clients: [:cuda, :rocm, :tpu, :mps, :host] ] ] end @@ -129,11 +130,31 @@ defmodule EXLA.MixProject do :ok -> File.write!(xla_snapshot_path, xla_archive_path) {:error, term} -> Mix.raise("failed to extract xla archive, reason: #{inspect(term)}") end + + # TODO should be packed into the XLA archive + 
download_metal_plugin!(xla_extension_path) end {:ok, []} end + defp download_metal_plugin!(xla_extension_path) do + plugin_path = Path.join(xla_extension_path, "lib/pjrt_plugin_metal.dylib") + + wheel_url = + "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl" + + wheel_path = Path.join(xla_extension_path, "jax_metal.whl") + + {_, 0} = System.shell("wget --output-document=#{wheel_path} #{wheel_url}") + {_, 0} = System.shell("unzip #{wheel_path} -d #{xla_extension_path}") + + wheel_plugin_path = + Path.join(xla_extension_path, "jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib") + + File.cp!(wheel_plugin_path, plugin_path) + end + defp cached_make(args) do force_rebuild_env_var = System.get_env("EXLA_FORCE_REBUILD", "") diff --git a/exla/test/exla/backend_test.exs b/exla/test/exla/backend_test.exs index 22e3c60850..84cae0ee35 100644 --- a/exla/test/exla/backend_test.exs +++ b/exla/test/exla/backend_test.exs @@ -27,8 +27,75 @@ defmodule EXLA.BackendTest do @skip_mac_arm [] end + if EXLA.Client.default_name() == :mps do + @skip_mps [ + # Missing support for "stablehlo.reduce_window". + # Reported in https://github.com/google/jax/issues/21387 + window_max: 3, + window_min: 3, + window_sum: 3, + window_product: 3, + window_reduce: 5, + window_scatter_min: 5, + window_scatter_max: 5, + window_mean: 3, + # (edge case) Argmax/argmin return wrong value in case of NaN. + # Reported in https://github.com/google/jax/issues/21821 + argmin: 2, + argmax: 2, + # Missing support for general "stablehlo.reduce". Some cases work + # because they are special-cased. + # Reported in https://github.com/google/jax/issues/21384 + reduce: 4, + # Missing support for "stablehlo.popcnt", "stablehlo.count_leading_zeros", + # "stablehlo.cbrt". 
+ # Reported in https://github.com/google/jax/issues/21389 + count_leading_zeros: 1, + population_count: 1, + cbrt: 1, + # Matrix multiplication for integers is not supported + dot: 2, + dot: 4, + dot: 6, + covariance: 3, + # (edge case) Put slice with overflowing slice, different behaviour. + # Reported in https://github.com/google/jax/issues/21392 + put_slice: 3, + # (edge case) Slice with overflowing index, different behaviour. + # Reported in https://github.com/google/jax/issues/21393 + slice: 4, + # (edge case) Top-k wrong behaviour with NaNs. + # Reported in https://github.com/google/jax/issues/21397 + top_k: 2, + # Missing support for complex numbers. + # Tracked in https://github.com/google/jax/issues/16416 + complex: 2, + conjugate: 1, + conv: 3, + fft: 2, + fft2: 2, + ifft: 2, + ifft2: 2, + imag: 1, + is_infinity: 1, + is_nan: 1, + phase: 1, + real: 1, + sigil_MAT: 2, + # Missing support for float-64. + # Tracked in https://github.com/google/jax/issues/20938 + iota: 2, + as_type: 2, + atan2: 2, + # Missing support for u2/s2 + bit_size: 1 + ] + else + @skip_mps [] + end + doctest Nx, - except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm + except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm ++ @skip_mps test "Nx.to_binary/1" do t = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend) @@ -199,6 +266,8 @@ defmodule EXLA.BackendTest do end describe "quantized types" do + # TODO missing support for s2 + @tag :skip test "s2" do tensor = Nx.s2(-1) assert <<-1::2-signed-native>> = Nx.to_binary(tensor) @@ -237,6 +306,8 @@ defmodule EXLA.BackendTest do assert 28 = Nx.bit_size(tensor) end + # TODO missing support for u2 + @tag :skip test "u2" do tensor = Nx.u2(1) assert <<1::2-native>> = Nx.to_binary(tensor) diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index 3162efd10c..22d3175e78 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -7941,6 +7941,16 @@ defmodule Nx do end end + # TODO remove this, or make it an optional callback + # (Metal does not support 
stablehlo.logistic yet) + def sigmoid(x) do + x + |> Nx.negate() + |> Nx.exp() + |> Nx.add(1) + |> then(&Nx.divide(1, &1)) + end + ## Unary ops @disallow_complex_type_unary_ops [:erf, :erfc, :erf_inv] diff --git a/nx/lib/nx/lin_alg/cholesky.ex b/nx/lib/nx/lin_alg/cholesky.ex index 8f79429cd1..5f3011575c 100644 --- a/nx/lib/nx/lin_alg/cholesky.ex +++ b/nx/lib/nx/lin_alg/cholesky.ex @@ -31,7 +31,10 @@ defmodule Nx.LinAlg.Cholesky do {l, _} = while {l = Nx.multiply(0.0, a), {a}}, i <- 0..(n - 1) do - {l, _} = + # TODO bring back the assignment (dynamic slice inside while causes a segfault) + # Reported in https://github.com/google/jax/issues/21552 and https://github.com/jax-ml/jax/issues/23931 + # {l, _} = + _result = while {l, {a, i, j = 0}}, j <= i do value = if i == j do diff --git a/nx/lib/nx/lin_alg/qr.ex b/nx/lib/nx/lin_alg/qr.ex index 9b428ea3cc..300a43e5d4 100644 --- a/nx/lib/nx/lin_alg/qr.ex +++ b/nx/lib/nx/lin_alg/qr.ex @@ -56,8 +56,10 @@ defmodule Nx.LinAlg.QR do base_h = Nx.eye({m, m}, type: type, vectorized_axes: a.vectorized_axes) column_iota = Nx.iota({Nx.axis_size(a, 0)}, vectorized_axes: a.vectorized_axes) + # TODO remove :unroll (dynamic slice inside while causes a segfault) + # Reported in https://github.com/google/jax/issues/21552 and https://github.com/jax-ml/jax/issues/23931 {{q, r}, _} = - while {{q = base_h, r = Nx.as_type(a, type)}, {column_iota}}, i <- 0..max_i//1 do + while {{q = base_h, r = Nx.as_type(a, type)}, {column_iota}}, i <- 0..max_i//1, unroll: true do x = r[[.., i]] x = Nx.select(column_iota < i, 0, x) h = householder_reflector(x, i, eps)