Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Experimental) Integrate Metal PjRt plugin #1504

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
build_options.set_use_spmd_partitioning(use_spmd);

bool compile_portable_executable = false;
if (device_id >= 0) {

bool is_mps = (*client)->client()->platform_name() == "METAL";

if (device_id >= 0 && !is_mps) {
compile_portable_executable = true;
build_options.set_device_ordinal(device_id);
}
Expand Down Expand Up @@ -877,6 +880,16 @@ ERL_NIF_TERM get_tpu_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
return exla::nif::ok(env, exla::nif::make<exla::ExlaClient*>(env, client));
}

// NIF taking no arguments: obtains an ExlaClient from exla::GetMpsClient()
// and returns it wrapped with exla::nif::ok. Any other argument count is
// rejected with exla::nif::error; a failure from GetMpsClient() is handled
// by EXLA_ASSIGN_OR_RETURN_NIF.
ERL_NIF_TERM get_mps_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
  if (argc != 0) {
    return exla::nif::error(env, "Bad argument count.");
  }

  EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaClient* mps_client, exla::GetMpsClient(), env);

  // Build the resource term first, then tag it as a successful result.
  auto client_term = exla::nif::make<exla::ExlaClient*>(env, mps_client);
  return exla::nif::ok(env, client_term);
}

ERL_NIF_TERM get_c_api_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 1) {
return exla::nif::error(env, "Bad argument count.");
Expand Down Expand Up @@ -1065,6 +1078,7 @@ static ErlNifFunc exla_funcs[] = {
{"get_host_client", 0, get_host_client},
{"get_gpu_client", 2, get_gpu_client},
{"get_tpu_client", 0, get_tpu_client},
{"get_mps_client", 0, get_mps_client},
{"get_c_api_client", 1, get_c_api_client},
{"load_pjrt_plugin", 2, load_pjrt_plugin},
{"get_device_count", 1, get_device_count},
Expand Down
24 changes: 24 additions & 0 deletions exla/c_src/exla/exla_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,30 @@ xla::StatusOr<ExlaClient*> GetTpuClient() {
return new ExlaClient(std::move(client));
}

// Loads and initializes the Metal PjRt plugin, then wraps the resulting
// PjRt C API client in a heap-allocated ExlaClient. The first failing step
// returns its status to the caller.
xla::StatusOr<ExlaClient*> GetMpsClient() {
  constexpr char kDeviceType[] = "METAL";
  constexpr char kPluginLibrary[] = "pjrt_plugin_metal.dylib";

  // The plugin may be compiled for a different version of PjRt C API
  // than present in our XLA compilation. By default pjrt::LoadPjrtPlugin
  // raises if the version does not match. By setting this environment
  // variable, we relax this check to allow different versions, as long
  // as they satisfy compatibility constraints.
  //
  // See https://github.com/openxla/xla/blob/4e8e23f16bc925b6f27817de098a8e1e81296bb5/xla/pjrt/pjrt_api.cc
  setenv("ENABLE_PJRT_COMPATIBILITY", "1", /*overwrite=*/1);

  // Only the load status matters here; the returned PJRT_Api pointer is
  // not used by this function.
  auto plugin_api = pjrt::LoadPjrtPlugin(kDeviceType, kPluginLibrary);
  if (!plugin_api.ok()) {
    return plugin_api.status();
  }

  xla::Status init_status = pjrt::InitializePjrtPlugin(kDeviceType);
  if (!init_status.ok()) {
    return init_status;
  }

  EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> pjrt_client,
                        xla::GetCApiClient(kDeviceType));

  return new ExlaClient(std::move(pjrt_client));
}

xla::StatusOr<ExlaClient*> GetCApiClient(std::string device_type) {
EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
xla::GetCApiClient(device_type));
Expand Down
2 changes: 2 additions & 0 deletions exla/c_src/exla/exla_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ xla::StatusOr<ExlaClient*> GetGpuClient(double memory_fraction,

xla::StatusOr<ExlaClient*> GetTpuClient();

xla::StatusOr<ExlaClient*> GetMpsClient();

xla::StatusOr<ExlaClient*> GetCApiClient(std::string device_type);
} // namespace exla

Expand Down
3 changes: 3 additions & 0 deletions exla/lib/exla/client.ex
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ defmodule EXLA.Client do
:tpu ->
EXLA.NIF.get_tpu_client()

:mps ->
EXLA.NIF.get_mps_client()

_ ->
raise ArgumentError, "unknown EXLA platform: #{inspect(platform)}"
end
Expand Down
4 changes: 4 additions & 0 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,10 @@ defmodule EXLA.Defn do
) do
precision = state.precision

# Ensure both have the same type
left = to_type(left, ans.type)
right = to_type(right, ans.type)

Value.dot_general(
left,
right,
Expand Down
12 changes: 9 additions & 3 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -672,9 +672,15 @@ defmodule EXLA.MLIR.Value do
typespecs
) do
result_types = typespecs_to_mlir_types(typespecs)
regions = [on_true, on_false]
pred = convert(pred, Typespec.tensor({:pred, 8}, {}))
op(func, "stablehlo.if", [pred], result_types, regions: regions)

# TODO Jax does not support stablehlo.if, they use stablehlo.case instead.
# It most likely makes sense for us to do the same. That said, note that
# stablehlo.case is implemented for Metal, but does not lower reliably.
# Reported in https://github.com/google/jax/issues/21601

regions = [on_false, on_true]
pred = convert(pred, Typespec.tensor({:s, 32}, {}))
op(func, "stablehlo.case", [pred], result_types, regions: regions)
end

def infeed(%Value{function: func} = token, typespecs) do
Expand Down
2 changes: 2 additions & 0 deletions exla/lib/exla/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ defmodule EXLA.NIF do

def get_tpu_client(), do: :erlang.nif_error(:undef)

def get_mps_client(), do: :erlang.nif_error(:undef)

def get_supported_platforms, do: :erlang.nif_error(:undef)

def get_device_count(_client),
Expand Down
23 changes: 22 additions & 1 deletion exla/mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ defmodule EXLA.MixProject do
cuda: [platform: :cuda],
rocm: [platform: :rocm],
tpu: [platform: :tpu],
mps: [platform: :mps],
host: [platform: :host]
],
preferred_clients: [:cuda, :rocm, :tpu, :host]
preferred_clients: [:cuda, :rocm, :tpu, :mps, :host]
]
]
end
Expand Down Expand Up @@ -129,11 +130,31 @@ defmodule EXLA.MixProject do
:ok -> File.write!(xla_snapshot_path, xla_archive_path)
{:error, term} -> Mix.raise("failed to extract xla archive, reason: #{inspect(term)}")
end

# TODO should be packed into the XLA archive
download_metal_plugin!(xla_extension_path)
end

{:ok, []}
end

# Downloads the jax-metal wheel into `xla_extension_path`, extracts it there
# and copies the bundled Metal PjRt plugin to `lib/pjrt_plugin_metal.dylib`.
#
# Proof of concept (per the review discussion on this change): the plugin is
# meant to be moved and downloaded as part of elixir-nx/xla instead.
#
# Raises `MatchError` when the download or the extraction exits with a
# non-zero status, and `File.CopyError` when the plugin cannot be copied.
defp download_metal_plugin!(xla_extension_path) do
  plugin_path = Path.join(xla_extension_path, "lib/pjrt_plugin_metal.dylib")

  wheel_url =
    "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl"

  wheel_path = Path.join(xla_extension_path, "jax_metal.whl")

  # curl rather than wget: wget isn't available on macOS by default, and the
  # default PATH (e.g. in Livebook) may not include Homebrew's bin directory.
  # Arguments are passed as a list, so no shell is involved and paths with
  # spaces or shell metacharacters are handled correctly.
  {_, 0} = System.cmd("curl", ["--fail", "--location", "--output", wheel_path, wheel_url])

  # A wheel is a zip archive. `-o` overwrites existing files without
  # prompting, so re-running the extraction does not block on user input.
  {_, 0} = System.cmd("unzip", ["-o", wheel_path, "-d", xla_extension_path])

  wheel_plugin_path =
    Path.join(xla_extension_path, "jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib")

  File.cp!(wheel_plugin_path, plugin_path)
end

defp cached_make(args) do
force_rebuild_env_var = System.get_env("EXLA_FORCE_REBUILD", "")

Expand Down
73 changes: 72 additions & 1 deletion exla/test/exla/backend_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,75 @@ defmodule EXLA.BackendTest do
@skip_mac_arm []
end

if EXLA.Client.default_name() == :mps do
@skip_mps [
# Missing support for "stablehlo.reduce_window".
# Reported in https://github.com/google/jax/issues/21387
window_max: 3,
window_min: 3,
window_sum: 3,
window_product: 3,
window_reduce: 5,
window_scatter_min: 5,
window_scatter_max: 5,
window_mean: 3,
# (edge case) Argmax/argmin return wrong value in case of NaN.
# Reported in https://github.com/google/jax/issues/21821
argmin: 2,
argmax: 2,
# Missing support for general "stablehlo.reduce". Some cases work
# because they are special-cased.
# Reported in https://github.com/google/jax/issues/21384
reduce: 4,
# Missing support for "stablehlo.popcnt", "stablehlo.count_leading_zeros",
# "stablehlo.cbrt".
# Reported in https://github.com/google/jax/issues/21389
count_leading_zeros: 1,
population_count: 1,
cbrt: 1,
# Matrix multiplication for integers is not supported
dot: 2,
dot: 4,
dot: 6,
covariance: 3,
# (edge case) Put slice with overflowing slice, different behaviour.
# Reported in https://github.com/google/jax/issues/21392
put_slice: 3,
# (edge case) Slice with overflowing index, different behaviour.
# Reported in https://github.com/google/jax/issues/21393
slice: 4,
# (edge case) Top-k wrong behaviour with NaNs.
# Reported in https://github.com/google/jax/issues/21397
top_k: 2,
# Missing support for complex numbers.
# Tracked in https://github.com/google/jax/issues/16416
complex: 2,
conjugate: 1,
conv: 3,
fft: 2,
fft2: 2,
ifft: 2,
ifft2: 2,
imag: 1,
is_infinity: 1,
is_nan: 1,
phase: 1,
real: 1,
sigil_MAT: 2,
# Missing support for float-64.
# Tracked in https://github.com/google/jax/issues/20938
iota: 2,
as_type: 2,
atan2: 2,
# Missing support for u2/s2
bit_size: 1
]
else
@skip_mps []
end

doctest Nx,
except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm
except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm ++ @skip_mps

test "Nx.to_binary/1" do
t = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend)
Expand Down Expand Up @@ -199,6 +266,8 @@ defmodule EXLA.BackendTest do
end

describe "quantized types" do
# TODO missing support for s2
@tag :skip
test "s2" do
tensor = Nx.s2(-1)
assert <<-1::2-signed-native>> = Nx.to_binary(tensor)
Expand Down Expand Up @@ -237,6 +306,8 @@ defmodule EXLA.BackendTest do
assert 28 = Nx.bit_size(tensor)
end

# TODO missing support for u2
@tag :skip
test "u2" do
tensor = Nx.u2(1)
assert <<1::2-native>> = Nx.to_binary(tensor)
Expand Down
10 changes: 10 additions & 0 deletions nx/lib/nx.ex
Original file line number Diff line number Diff line change
Expand Up @@ -7941,6 +7941,16 @@ defmodule Nx do
end
end

# TODO remove this, or make it an optional callback
# (Metal does not support stablehlo.logistic yet)
def sigmoid(x) do
  # Built from primitive ops as 1 / (exp(-x) + 1) rather than a dedicated
  # logistic operation.
  denominator = Nx.add(Nx.exp(Nx.negate(x)), 1)
  Nx.divide(1, denominator)
end

## Unary ops
@disallow_complex_type_unary_ops [:erf, :erfc, :erf_inv]

Expand Down
5 changes: 4 additions & 1 deletion nx/lib/nx/lin_alg/cholesky.ex
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ defmodule Nx.LinAlg.Cholesky do

{l, _} =
while {l = Nx.multiply(0.0, a), {a}}, i <- 0..(n - 1) do
{l, _} =
# TODO bring back the assignment (dynamic slice inside while causes a segfault)
# Reported in https://github.com/google/jax/issues/21552 and https://github.com/jax-ml/jax/issues/23931
# {l, _} =
_result =
while {l, {a, i, j = 0}}, j <= i do
value =
if i == j do
Expand Down
4 changes: 3 additions & 1 deletion nx/lib/nx/lin_alg/qr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,10 @@ defmodule Nx.LinAlg.QR do
base_h = Nx.eye({m, m}, type: type, vectorized_axes: a.vectorized_axes)
column_iota = Nx.iota({Nx.axis_size(a, 0)}, vectorized_axes: a.vectorized_axes)

# TODO remove :unroll (dynamic slice inside while causes a segfault)
# Reported in https://github.com/google/jax/issues/21552 and https://github.com/jax-ml/jax/issues/23931
{{q, r}, _} =
while {{q = base_h, r = Nx.as_type(a, type)}, {column_iota}}, i <- 0..max_i//1 do
while {{q = base_h, r = Nx.as_type(a, type)}, {column_iota}}, i <- 0..max_i//1, unroll: true do
x = r[[.., i]]
x = Nx.select(column_iota < i, 0, x)
h = householder_reflector(x, i, eps)
Expand Down
Loading