Skip to content

Commit

Permalink
(Experimental) Integrate Metal PjRt plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Dec 10, 2024
1 parent b2fdb9a commit 5a50406
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 8 deletions.
16 changes: 15 additions & 1 deletion exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
build_options.set_use_spmd_partitioning(use_spmd);

bool compile_portable_executable = false;
if (device_id >= 0) {

bool is_mps = (*client)->client()->platform_name() == "METAL";

if (device_id >= 0 && !is_mps) {
compile_portable_executable = true;
build_options.set_device_ordinal(device_id);
}
Expand Down Expand Up @@ -877,6 +880,16 @@ ERL_NIF_TERM get_tpu_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
return exla::nif::ok(env, exla::nif::make<exla::ExlaClient*>(env, client));
}

// NIF entry point that creates an EXLA client for the Metal (MPS) platform.
//
// Takes no arguments. Returns an error term with "Bad argument count." when
// called with any, otherwise an ok-tagged term wrapping the ExlaClient
// pointer produced by exla::GetMpsClient().
ERL_NIF_TERM get_mps_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 0) {
return exla::nif::error(env, "Bad argument count.");
}

// NOTE(review): EXLA_ASSIGN_OR_RETURN_NIF presumably returns an error term
// built from the failed status when GetMpsClient() does not succeed --
// confirm against the macro definition.
EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaClient * client, exla::GetMpsClient(), env);

return exla::nif::ok(env, exla::nif::make<exla::ExlaClient*>(env, client));
}

ERL_NIF_TERM get_c_api_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 1) {
return exla::nif::error(env, "Bad argument count.");
Expand Down Expand Up @@ -1065,6 +1078,7 @@ static ErlNifFunc exla_funcs[] = {
{"get_host_client", 0, get_host_client},
{"get_gpu_client", 2, get_gpu_client},
{"get_tpu_client", 0, get_tpu_client},
{"get_mps_client", 0, get_mps_client},
{"get_c_api_client", 1, get_c_api_client},
{"load_pjrt_plugin", 2, load_pjrt_plugin},
{"get_device_count", 1, get_device_count},
Expand Down
24 changes: 24 additions & 0 deletions exla/c_src/exla/exla_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,30 @@ xla::StatusOr<ExlaClient*> GetTpuClient() {
return new ExlaClient(std::move(client));
}

// Creates an ExlaClient backed by the Metal PjRt plugin.
//
// Loads "pjrt_plugin_metal.dylib" under the "METAL" device type, initializes
// the plugin and builds a C API client on top of it. Loading and
// initialization are skipped when the plugin is already registered, so the
// function can be called more than once in the same OS process.
//
// Returns the non-ok status of whichever step failed, otherwise a newly
// allocated ExlaClient owned by the caller.
xla::StatusOr<ExlaClient*> GetMpsClient() {
  // The plugin may be compiled for a different version of PjRt C API
  // than present in our XLA compilation. By default pjrt::LoadPjrtPlugin
  // raises if the version does not match. By setting this environment
  // variable, we relax this check to allow different versions, as long
  // as they satisfy compatibility constraints.
  //
  // See https://github.com/openxla/xla/blob/4e8e23f16bc925b6f27817de098a8e1e81296bb5/xla/pjrt/pjrt_api.cc
  setenv("ENABLE_PJRT_COMPATIBILITY", "1", 1);

  // pjrt::LoadPjrtPlugin registers the API per device type and reports an
  // error when that device type is already registered, so only load the
  // library when no "METAL" API is known yet.
  if (!pjrt::PjrtApi("METAL").ok()) {
    xla::Status load_status =
        pjrt::LoadPjrtPlugin("METAL", "pjrt_plugin_metal.dylib").status();

    if (!load_status.ok()) {
      return load_status;
    }
  }

  // Likewise, pjrt::InitializePjrtPlugin rejects a second initialization.
  EXLA_ASSIGN_OR_RETURN(bool is_initialized,
                        pjrt::IsPjrtPluginInitialized("METAL"));

  if (!is_initialized) {
    xla::Status init_status = pjrt::InitializePjrtPlugin("METAL");

    if (!init_status.ok()) {
      return init_status;
    }
  }

  EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
                        xla::GetCApiClient("METAL"));

  return new ExlaClient(std::move(client));
}

xla::StatusOr<ExlaClient*> GetCApiClient(std::string device_type) {
EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
xla::GetCApiClient(device_type));
Expand Down
2 changes: 2 additions & 0 deletions exla/c_src/exla/exla_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ xla::StatusOr<ExlaClient*> GetGpuClient(double memory_fraction,

xla::StatusOr<ExlaClient*> GetTpuClient();

xla::StatusOr<ExlaClient*> GetMpsClient();

xla::StatusOr<ExlaClient*> GetCApiClient(std::string device_type);
} // namespace exla

Expand Down
3 changes: 3 additions & 0 deletions exla/lib/exla/client.ex
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ defmodule EXLA.Client do
:tpu ->
EXLA.NIF.get_tpu_client()

:mps ->
EXLA.NIF.get_mps_client()

_ ->
raise ArgumentError, "unknown EXLA platform: #{inspect(platform)}"
end
Expand Down
4 changes: 4 additions & 0 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,10 @@ defmodule EXLA.Defn do
) do
precision = state.precision

# Ensure both have the same type
left = to_type(left, ans.type)
right = to_type(right, ans.type)

Value.dot_general(
left,
right,
Expand Down
12 changes: 9 additions & 3 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -672,9 +672,15 @@ defmodule EXLA.MLIR.Value do
typespecs
) do
result_types = typespecs_to_mlir_types(typespecs)
regions = [on_true, on_false]
pred = convert(pred, Typespec.tensor({:pred, 8}, {}))
op(func, "stablehlo.if", [pred], result_types, regions: regions)

# TODO Jax does not support stablehlo.if, they use stablehlo.case instead.
# It most likely makes sense for us to do the same. That said, note that
# stablehlo.case is implemented for Metal, but does not lower reliably.
# Reported in https://github.com/google/jax/issues/21601

regions = [on_false, on_true]
pred = convert(pred, Typespec.tensor({:s, 32}, {}))
op(func, "stablehlo.case", [pred], result_types, regions: regions)
end

def infeed(%Value{function: func} = token, typespecs) do
Expand Down
2 changes: 2 additions & 0 deletions exla/lib/exla/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ defmodule EXLA.NIF do

def get_tpu_client(), do: :erlang.nif_error(:undef)

def get_mps_client(), do: :erlang.nif_error(:undef)

def get_supported_platforms, do: :erlang.nif_error(:undef)

def get_device_count(_client),
Expand Down
23 changes: 22 additions & 1 deletion exla/mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ defmodule EXLA.MixProject do
cuda: [platform: :cuda],
rocm: [platform: :rocm],
tpu: [platform: :tpu],
mps: [platform: :mps],
host: [platform: :host]
],
preferred_clients: [:cuda, :rocm, :tpu, :host]
preferred_clients: [:cuda, :rocm, :tpu, :mps, :host]
]
]
end
Expand Down Expand Up @@ -129,11 +130,31 @@ defmodule EXLA.MixProject do
:ok -> File.write!(xla_snapshot_path, xla_archive_path)
{:error, term} -> Mix.raise("failed to extract xla archive, reason: #{inspect(term)}")
end

# TODO should be packed into the XLA archive
download_metal_plugin!(xla_extension_path)
end

{:ok, []}
end

# Downloads the jax-metal wheel and installs its Metal PjRt plugin as
# `lib/pjrt_plugin_metal.dylib` under `xla_extension_path`.
#
# External tools are invoked with explicit argument lists instead of through
# a shell, so paths containing spaces or shell metacharacters are passed
# through verbatim. Any failing step aborts the build via `Mix.raise/1`.
defp download_metal_plugin!(xla_extension_path) do
  plugin_path = Path.join(xla_extension_path, "lib/pjrt_plugin_metal.dylib")

  wheel_url =
    "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl"

  wheel_path = Path.join(xla_extension_path, "jax_metal.whl")

  fetch_metal_wheel!(wheel_url, wheel_path)

  # A wheel is a zip archive. `-o` overwrites files left behind by a previous
  # extraction instead of prompting for confirmation.
  run_metal_tool!("unzip", ["-o", wheel_path, "-d", xla_extension_path])

  wheel_plugin_path =
    Path.join(xla_extension_path, "jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib")

  File.cp!(wheel_plugin_path, plugin_path)
end

# Downloads `url` to `path`. Prefers curl and falls back to wget, because
# wget usually has to be installed separately on macOS.
defp fetch_metal_wheel!(url, path) do
  cond do
    System.find_executable("curl") ->
      run_metal_tool!("curl", [
        "--fail",
        "--location",
        "--silent",
        "--show-error",
        "--output",
        path,
        url
      ])

    System.find_executable("wget") ->
      run_metal_tool!("wget", ["--output-document=#{path}", url])

    true ->
      Mix.raise("failed to download the Metal PjRt plugin, neither curl nor wget is available")
  end
end

# Runs `tool` with `args` and raises with the captured output unless it
# exits with status 0.
defp run_metal_tool!(tool, args) do
  if System.find_executable(tool) == nil do
    Mix.raise("failed to set up the Metal PjRt plugin, #{tool} is not available")
  end

  case System.cmd(tool, args, stderr_to_stdout: true) do
    {_output, 0} ->
      :ok

    {output, status} ->
      Mix.raise(
        "failed to set up the Metal PjRt plugin, #{tool} exited with status #{status}: " <>
          output
      )
  end
end

defp cached_make(args) do
force_rebuild_env_var = System.get_env("EXLA_FORCE_REBUILD", "")

Expand Down
73 changes: 72 additions & 1 deletion exla/test/exla/backend_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,75 @@ defmodule EXLA.BackendTest do
@skip_mac_arm []
end

if EXLA.Client.default_name() == :mps do
@skip_mps [
# Missing support for "stablehlo.reduce_window".
# Reported in https://github.com/google/jax/issues/21387
window_max: 3,
window_min: 3,
window_sum: 3,
window_product: 3,
window_reduce: 5,
window_scatter_min: 5,
window_scatter_max: 5,
window_mean: 3,
# (edge case) Argmax/argmin return wrong value in case of NaN.
# Reported in https://github.com/google/jax/issues/21821
argmin: 2,
argmax: 2,
# Missing support for general "stablehlo.reduce". Some cases work
# because they are special-cased.
# Reported in https://github.com/google/jax/issues/21384
reduce: 4,
# Missing support for "stablehlo.popcnt", "stablehlo.count_leading_zeros",
# "stablehlo.cbrt".
# Reported in https://github.com/google/jax/issues/21389
count_leading_zeros: 1,
population_count: 1,
cbrt: 1,
# Matrix multiplication for integers is not supported
dot: 2,
dot: 4,
dot: 6,
covariance: 3,
# (edge case) Put slice with overflowing slice, different behaviour.
# Reported in https://github.com/google/jax/issues/21392
put_slice: 3,
# (edge case) Slice with overflowing index, different behaviour.
# Reported in https://github.com/google/jax/issues/21393
slice: 4,
# (edge case) Top-k wrong behaviour with NaNs.
# Reported in https://github.com/google/jax/issues/21397
top_k: 2,
# Missing support for complex numbers.
# Tracked in https://github.com/google/jax/issues/16416
complex: 2,
conjugate: 1,
conv: 3,
fft: 2,
fft2: 2,
ifft: 2,
ifft2: 2,
imag: 1,
is_infinity: 1,
is_nan: 1,
phase: 1,
real: 1,
sigil_MAT: 2,
# Missing support for float-64.
# Tracked in https://github.com/google/jax/issues/20938
iota: 2,
as_type: 2,
atan2: 2,
# Missing support for u2/s2
bit_size: 1
]
else
@skip_mps []
end

doctest Nx,
except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm
except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm ++ @skip_mps

test "Nx.to_binary/1" do
t = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend)
Expand Down Expand Up @@ -199,6 +266,8 @@ defmodule EXLA.BackendTest do
end

describe "quantized types" do
# TODO missing support for s2
@tag :skip
test "s2" do
tensor = Nx.s2(-1)
assert <<-1::2-signed-native>> = Nx.to_binary(tensor)
Expand Down Expand Up @@ -237,6 +306,8 @@ defmodule EXLA.BackendTest do
assert 28 = Nx.bit_size(tensor)
end

# TODO missing support for u2
@tag :skip
test "u2" do
tensor = Nx.u2(1)
assert <<1::2-native>> = Nx.to_binary(tensor)
Expand Down
10 changes: 10 additions & 0 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -7941,6 +7941,16 @@ defmodule Nx do
end
end

# TODO remove this, or make it an optional callback
# (Metal does not support stablehlo.logistic yet)
#
# Computes 1 / (1 + exp(-x)) out of element-wise primitives.
def sigmoid(x) do
  denominator = Nx.add(Nx.exp(Nx.negate(x)), 1)
  Nx.divide(1, denominator)
end

## Unary ops
@disallow_complex_type_unary_ops [:erf, :erfc, :erf_inv]

Expand Down
5 changes: 4 additions & 1 deletion nx/lib/nx/lin_alg/cholesky.ex
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ defmodule Nx.LinAlg.Cholesky do

{l, _} =
while {l = Nx.multiply(0.0, a), {a}}, i <- 0..(n - 1) do
{l, _} =
# TODO bring back the assignment (dynamic slice inside while causes a segfault)
# Reported in https://github.com/google/jax/issues/21552 and https://github.com/jax-ml/jax/issues/23931
# {l, _} =
_result =
while {l, {a, i, j = 0}}, j <= i do
value =
if i == j do
Expand Down
4 changes: 3 additions & 1 deletion nx/lib/nx/lin_alg/qr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ defmodule Nx.LinAlg.QR do
base_h = Nx.eye({m, m}, type: type, vectorized_axes: a.vectorized_axes)
column_iota = Nx.iota({Nx.axis_size(a, 0)}, vectorized_axes: a.vectorized_axes)

# TODO remove :unroll (dynamic slice inside while causes a segfault)
# Reported in https://github.com/google/jax/issues/21552 and https://github.com/jax-ml/jax/issues/23931
{{q, r}, _} =
while {{q = base_h, r = Nx.as_type(a, type)}, {column_iota}}, i <- 0..max_i//1 do
while {{q = base_h, r = Nx.as_type(a, type)}, {column_iota}}, i <- 0..max_i//1, unroll: true do
x = r[[.., i]]
x = Nx.select(column_iota < i, 0, x)
h = householder_reflector(x, i, eps)
Expand Down

0 comments on commit 5a50406

Please sign in to comment.