Skip to content

Commit

Permalink
(Experimental) Integrate Metal PjRt plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko committed Jun 4, 2024
1 parent 2f3c6ef commit 102e62b
Show file tree
Hide file tree
Showing 13 changed files with 220 additions and 30 deletions.
16 changes: 15 additions & 1 deletion exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
build_options.set_use_spmd_partitioning(use_spmd);

bool compile_portable_executable = false;
if (device_id >= 0) {

bool is_mps = (*client)->client()->platform_name() == "METAL";

if (device_id >= 0 && !is_mps) {
compile_portable_executable = true;
build_options.set_device_ordinal(device_id);
}
Expand Down Expand Up @@ -728,6 +731,16 @@ ERL_NIF_TERM get_tpu_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
return exla::nif::ok(env, exla::nif::make<exla::ExlaClient*>(env, client));
}

ERL_NIF_TERM get_mps_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 0) {
return exla::nif::error(env, "Bad argument count.");
}

EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaClient * client, exla::GetMpsClient(), env);

return exla::nif::ok(env, exla::nif::make<exla::ExlaClient*>(env, client));
}

ERL_NIF_TERM get_c_api_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 1) {
return exla::nif::error(env, "Bad argument count.");
Expand Down Expand Up @@ -915,6 +928,7 @@ static ErlNifFunc exla_funcs[] = {
{"get_host_client", 0, get_host_client},
{"get_gpu_client", 2, get_gpu_client},
{"get_tpu_client", 0, get_tpu_client},
{"get_mps_client", 0, get_mps_client},
{"get_c_api_client", 1, get_c_api_client},
{"load_pjrt_plugin", 2, load_pjrt_plugin},
{"get_device_count", 1, get_device_count},
Expand Down
24 changes: 24 additions & 0 deletions exla/c_src/exla/exla_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,30 @@ xla::StatusOr<ExlaClient*> GetTpuClient() {
return new ExlaClient(std::move(client));
}

xla::StatusOr<ExlaClient*> GetMpsClient() {
// The plugin may be compiled for a different version of PjRt C API
// than present in our XLA compilation. By default pjrt::LoadPjrtPlugin
// raises if the version does not match. By setting this environment
// variable, we relax this check to allow different versions, as long
// as they satisfy compatibility constraints.
//
// See https://github.com/openxla/xla/blob/4e8e23f16bc925b6f27817de098a8e1e81296bb5/xla/pjrt/pjrt_api.cc
setenv("ENABLE_PJRT_COMPATIBILITY", "1", 1);

EXLA_EFFECT_OR_RETURN(pjrt::LoadPjrtPlugin("METAL", "pjrt_plugin_metal.dylib"));

xla::Status status = pjrt::InitializePjrtPlugin("METAL");

if (!status.ok()) {
return status;
}

EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
xla::GetCApiClient("METAL"));

return new ExlaClient(std::move(client));
}

xla::StatusOr<ExlaClient*> GetCApiClient(std::string device_type) {
EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
xla::GetCApiClient(device_type));
Expand Down
2 changes: 2 additions & 0 deletions exla/c_src/exla/exla_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ xla::StatusOr<ExlaClient*> GetGpuClient(double memory_fraction,

xla::StatusOr<ExlaClient*> GetTpuClient();

xla::StatusOr<ExlaClient*> GetMpsClient();

xla::StatusOr<ExlaClient*> GetCApiClient(std::string device_type);
} // namespace exla

Expand Down
3 changes: 3 additions & 0 deletions exla/lib/exla/client.ex
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ defmodule EXLA.Client do
:tpu ->
EXLA.NIF.get_tpu_client()

:mps ->
EXLA.NIF.get_mps_client()

_ ->
raise ArgumentError, "unknown EXLA platform: #{inspect(platform)}"
end
Expand Down
60 changes: 43 additions & 17 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,9 @@ defmodule EXLA.Defn do
result =
Value.gather(
tensor,
indices,
# TODO remove conversion (unsigned indices fail)
# Reported in https://github.com/google/jax/issues/21547
to_type(indices, {:s, 32}),
index_vector_dim,
slice_sizes,
offset_dims,
Expand Down Expand Up @@ -871,6 +873,10 @@ defmodule EXLA.Defn do
) do
precision = state.precision

# Ensure both have the same type
left = to_type(left, ans.type)
right = to_type(right, ans.type)

Value.dot_general(
left,
right,
Expand Down Expand Up @@ -1291,6 +1297,9 @@ defmodule EXLA.Defn do
defp to_operator(:put_slice, [%Value{} = tensor, start_indices, slice], ans, _state) do
tensor = to_type(tensor, ans.type)
slice = to_type(slice, ans.type)
# TODO remove conversion (unsigned indices fail)
# Reported in https://github.com/google/jax/issues/21547
start_indices = Enum.map(start_indices, &to_type(&1, {:s, 32}))
Value.dynamic_update_slice(tensor, slice, start_indices, expr_to_typespec(ans))
end

Expand All @@ -1313,7 +1322,9 @@ defmodule EXLA.Defn do

Value.gather(
tensor,
indices,
# TODO remove conversion (unsigned indices fail)
# Reported in https://github.com/google/jax/issues/21547
to_type(indices, {:s, 32}),
index_vector_dim,
slice_sizes,
offset_dims,
Expand Down Expand Up @@ -1341,7 +1352,7 @@ defmodule EXLA.Defn do
defp to_operator(:sort, [%Value{} = tensor, opts], ans, state) do
dimension = opts[:axis]

op =
operator =
case opts[:direction] do
:asc -> :less
:desc -> :greater
Expand All @@ -1350,7 +1361,7 @@ defmodule EXLA.Defn do
arg_typespec = Typespec.tensor(ans.type, {})
arg_typespecs = [arg_typespec, arg_typespec]

comp = sort_computation(op, ans.type, arg_typespecs, state)
comp = sort_computation(operator, ans.type, arg_typespecs, state)

Value.sort([tensor], comp, dimension, opts[:stable] == true, [expr_to_typespec(ans)]) |> hd()
end
Expand Down Expand Up @@ -1530,30 +1541,45 @@ defmodule EXLA.Defn do

## Computation helpers

defp sort_computation(op, type, arg_typespecs, %{builder: %EXLA.MLIR.Function{} = function}) do
defp sort_computation(
operator,
type,
arg_typespecs,
%{builder: %EXLA.MLIR.Function{} = function}
) do
{region, [lhs, rhs | _]} = Function.push_region(function, arg_typespecs)

typespec = Typespec.tensor({:pred, 8}, {})

op =
cond do
Nx.Type.integer?(type) ->
apply(Value, op, [lhs, rhs, typespec])

op == :less ->
is_nan = Value.is_nan(rhs, typespec)
Value.bitwise_or(is_nan, Value.less(lhs, rhs, typespec), typespec)

op == :greater ->
is_nan = Value.is_nan(lhs, typespec)
Value.bitwise_or(is_nan, Value.greater(lhs, rhs, typespec), typespec)
{lhs, rhs} =
if Nx.Type.float?(type) do
{canonicalize_float_for_sort(lhs), canonicalize_float_for_sort(rhs)}
else
{lhs, rhs}
end

op = apply(Value, operator, [lhs, rhs, typespec, [total_order: true]])

Value.return(function, [op])
Function.pop_region(function)
region
end

defp canonicalize_float_for_sort(%Value{function: func} = op) do
# Standardize the representation of NaNs (-NaN, NaN) and zeros (-0, 0).
# See https://github.com/google/jax/blob/e81c82605f0e1813080cfe1037d043b27b38291d/jax/_src/lax/lax.py#L4248-L4253

op_typespec = Value.get_typespec(op)

zero = Value.constant(func, [0], Typespec.to_shape(op_typespec, {}))
zeros = Value.constant(func, [0], op_typespec)
nans = Value.constant(func, [:nan], op_typespec)

pred_typespec = Typespec.tensor({:pred, 8}, {})
op = Value.select(Value.equal(op, zero, pred_typespec), zeros, op, op_typespec)
Value.select(Value.is_nan(op, pred_typespec), nans, op, op_typespec)
end

defp op_computation(
op,
arg_typespecs,
Expand Down
30 changes: 23 additions & 7 deletions exla/lib/exla/mlir/value.ex
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,18 @@ defmodule EXLA.MLIR.Value do
}

for {op, direction} <- @bin_comparison_ops do
def unquote(op)(%Value{function: func} = lhs, %Value{function: func} = rhs, typespec) do
compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction))
def unquote(op)(
%Value{function: func} = lhs,
%Value{function: func} = rhs,
typespec,
opts \\ []
) do
opts = Keyword.validate!(opts, total_order: false)
compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction), opts[:total_order])
end
end

defp compare_and_return_bool(func, lhs, rhs, typespec, direction) do
defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do
%{type: lhs_type} = get_typespec(lhs)
%{type: rhs_type} = get_typespec(rhs)

Expand All @@ -69,7 +75,11 @@ defmodule EXLA.MLIR.Value do
attr_comparison_type(:float)

Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) ->
attr_comparison_type(:float)
if total_order? do
attr_comparison_type(:totalorder)
else
attr_comparison_type(:float)
end

true ->
attr_comparison_type(:notype)
Expand Down Expand Up @@ -663,9 +673,15 @@ defmodule EXLA.MLIR.Value do
typespecs
) do
result_types = typespecs_to_mlir_types(typespecs)
regions = [on_true, on_false]
pred = convert(pred, Typespec.tensor({:pred, 8}, {}))
op(func, "stablehlo.if", [pred], result_types, regions: regions)

# TODO Jax does not support stablehlo.if, they use stablhelo.case instead.
# It most likely makes sense for use to do the same. That said, note that
# stablehlo.case is implemented for Metal, but does not lower reliably.
# Reported in https://github.com/google/jax/issues/21601

regions = [on_false, on_true]
pred = convert(pred, Typespec.tensor({:s, 32}, {}))
op(func, "stablehlo.case", [pred], result_types, regions: regions)
end

def infeed(%Value{function: func} = token, typespecs) do
Expand Down
2 changes: 2 additions & 0 deletions exla/lib/exla/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ defmodule EXLA.NIF do

def get_tpu_client(), do: :erlang.nif_error(:undef)

def get_mps_client(), do: :erlang.nif_error(:undef)

def get_supported_platforms, do: :erlang.nif_error(:undef)

def get_device_count(_client),
Expand Down
23 changes: 22 additions & 1 deletion exla/mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ defmodule EXLA.MixProject do
cuda: [platform: :cuda],
rocm: [platform: :rocm],
tpu: [platform: :tpu],
mps: [platform: :mps],
host: [platform: :host]
],
preferred_clients: [:cuda, :rocm, :tpu, :host]
preferred_clients: [:cuda, :rocm, :tpu, :mps, :host]
]
]
end
Expand Down Expand Up @@ -128,11 +129,31 @@ defmodule EXLA.MixProject do
:ok -> File.write!(xla_snapshot_path, xla_archive_path)
{:error, term} -> Mix.raise("failed to extract xla archive, reason: #{inspect(term)}")
end

# TODO should be packed into the XLA archive
download_metal_plugin!(xla_extension_path)
end

{:ok, []}
end

defp download_metal_plugin!(xla_extension_path) do
plugin_path = Path.join(xla_extension_path, "lib/pjrt_plugin_metal.dylib")

wheel_url =
"https://files.pythonhosted.org/packages/d6/4f/f5d128a493b7387fbbe0e6906544214af2a6b86af30302dd6ffb9dc66a74/jax_metal-0.0.7-py3-none-macosx_13_0_arm64.whl"

wheel_path = Path.join(xla_extension_path, "jax_metal.whl")

{_, 0} = System.shell("wget --output-document=#{wheel_path} #{wheel_url}")
{_, 0} = System.shell("unzip #{wheel_path} -d #{xla_extension_path}")

wheel_plugin_path =
Path.join(xla_extension_path, "jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib")

File.cp!(wheel_plugin_path, plugin_path)
end

defp cached_make(_) do
contents =
for path <- Path.wildcard("c_src/**/*"),
Expand Down
67 changes: 66 additions & 1 deletion exla/test/exla/backend_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,73 @@ defmodule EXLA.BackendTest do
@skip_mac_arm []
end

if EXLA.Client.default_name() == :mps do
@skip_mps [
# Missing support for "stablehlo.reduce_window".
# Reported in https://github.com/google/jax/issues/21387
window_max: 3,
window_min: 3,
window_sum: 3,
window_product: 3,
window_reduce: 5,
window_scatter_min: 5,
window_scatter_max: 5,
window_mean: 3,
# Argmax/armin fail when a custom :type is passed.
# Reported in https://github.com/google/jax/issues/21577
argmin: 2,
argmax: 2,
# Missing support for general "stablehlo.reduce". Some cases work
# becuase they are special-cased.
# Reported in https://github.com/google/jax/issues/21384
reduce: 4,
# Missing support for "stablehlo.popcnt", "stablehlo.count_leading_zeros",
# "stablehlo.cbrt".
# Reported in https://github.com/google/jax/issues/21389
count_leading_zeros: 1,
population_count: 1,
cbrt: 1,
# Matrix multiplication for integers is not supported
dot: 2,
dot: 4,
dot: 6,
covariance: 3,
# (edge case) Put slice with overflowing slice, different behaviour.
# Reported in https://github.com/google/jax/issues/21392
put_slice: 3,
# (edge case) Slice with overflowing index, different behaviour.
# Reported in https://github.com/google/jax/issues/21393
slice: 4,
# (edge case) Top-k wrong behaviour with NaNs.
# Reported in https://github.com/google/jax/issues/21397
top_k: 2,
# Missing support for complex numbers.
# Tracked in https://github.com/google/jax/issues/16416
complex: 2,
conjugate: 1,
conv: 3,
fft: 2,
fft2: 2,
ifft: 2,
ifft2: 2,
imag: 1,
is_infinity: 1,
is_nan: 1,
phase: 1,
real: 1,
sigil_MAT: 2,
# Missing support for float-64.
# Tracked in https://github.com/google/jax/issues/20938
iota: 2,
as_type: 2,
atan2: 2
]
else
@skip_mps []
end

doctest Nx,
except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm
except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm ++ @skip_mps

test "Nx.to_binary/1" do
t = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend)
Expand Down
Loading

0 comments on commit 102e62b

Please sign in to comment.