diff --git a/exla/c_src/exla/custom_calls.cc b/exla/c_src/exla/custom_calls.cc
index eb2ce90d27..36abcec04a 100644
--- a/exla/c_src/exla/custom_calls.cc
+++ b/exla/c_src/exla/custom_calls.cc
@@ -1,10 +1,36 @@
 #include "custom_calls.h"
-#include "exla_nif_util.h"
-
-#include "xla/service/custom_call_target_registry.h"
 #include "Eigen/Dense"
+#include "Eigen/Eigenvalues"
 #include "Eigen/QR"
+#include "exla_nif_util.h"
+#include "xla/service/custom_call_target_registry.h"
+
+template <typename DataType>
+void single_matrix_eigh_cpu_custom_call(DataType *eigenvalues_out, DataType *eigenvectors_out, DataType *in, int64_t m, int64_t n) {
+  typedef Eigen::Matrix<DataType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> RowMajorMatrix;
+
+  // Map the input matrix
+  Eigen::Map<RowMajorMatrix> input(in, m, n);
+
+  // Compute the Eigenvalue decomposition
+  Eigen::SelfAdjointEigenSolver<RowMajorMatrix> eigensolver(input);
+
+  if (eigensolver.info() != Eigen::Success) {
+    std::cerr << "Eigenvalue decomposition failed!" << std::endl;
+    return;
+  }
+
+  // Get the eigenvalues and eigenvectors
+  Eigen::Matrix<DataType, Eigen::Dynamic, 1> eigenvalues = eigensolver.eigenvalues();
+  RowMajorMatrix eigenvectors = eigensolver.eigenvectors();
+
+  // Copy the eigenvalues to the output
+  std::memcpy(eigenvalues_out, eigenvalues.data(), m * sizeof(DataType));
+
+  // Copy the eigenvectors to the output
+  std::memcpy(eigenvectors_out, eigenvectors.data(), m * n * sizeof(DataType));
+}
 
 template <typename DataType>
 void single_matrix_qr_cpu_custom_call(DataType *q_out, DataType *r_out, DataType *in, int64_t m, int64_t k, int64_t n, bool complete) {
@@ -89,6 +115,50 @@ void qr_cpu_custom_call(void *out[], const void *in[]) {
   }
 }
 
+template <typename DataType>
+void eigh_cpu_custom_call(void *out[], const void *in[]) {
+  DataType *operand = (DataType *)in[0];
+
+  int64_t *dim_sizes = (int64_t *)in[1];
+  int64_t num_operand_dims = dim_sizes[0];
+  int64_t num_eigenvalues_dims = dim_sizes[1];
+  int64_t num_eigenvectors_dims = dim_sizes[2];
+
+  int64_t *operand_dims_ptr = (int64_t *)in[2];
+  std::vector<int64_t> operand_dims(operand_dims_ptr, operand_dims_ptr + num_operand_dims);
+
+  int64_t *eigenvalues_dims_ptr = (int64_t *)in[3];
+  std::vector<int64_t> eigenvalues_dims(eigenvalues_dims_ptr, eigenvalues_dims_ptr + num_eigenvalues_dims);
+
+  int64_t *eigenvectors_dims_ptr = (int64_t *)in[4];
+  std::vector<int64_t> eigenvectors_dims(eigenvectors_dims_ptr, eigenvectors_dims_ptr + num_eigenvectors_dims);
+
+  int64_t m = eigenvectors_dims[eigenvectors_dims.size() - 2];
+  int64_t n = eigenvectors_dims[eigenvectors_dims.size() - 1];
+
+  auto leading_dimensions = std::vector<int64_t>(operand_dims.begin(), operand_dims.end() - 2);
+
+  int64_t batch_items = 1;
+  for (int64_t i = 0; i < leading_dimensions.size(); i++) {
+    batch_items *= leading_dimensions[i];
+  }
+
+  DataType *eigenvalues = (DataType *)out[0];
+  DataType *eigenvectors = (DataType *)out[1];
+
+  int64_t eigenvalues_stride = eigenvalues_dims[eigenvalues_dims.size() - 1] * sizeof(DataType);
+  int64_t eigenvectors_stride = eigenvectors_dims[eigenvectors_dims.size() - 1] * eigenvectors_dims[eigenvectors_dims.size() - 2] * sizeof(DataType);
+  int64_t inner_stride = m * n * sizeof(DataType);
+
+  for (int64_t i = 0; i < batch_items; i++) {
+    single_matrix_eigh_cpu_custom_call(
+        eigenvalues + i * eigenvalues_stride,
+        eigenvectors + i * eigenvectors_stride,
+        operand + i * inner_stride / sizeof(DataType),
+        m, n);
+  }
+}
+
 void qr_cpu_custom_call_bf16(void *out[], const void *in[]) {
   qr_cpu_custom_call(out, in);
 }
@@ -105,7 +175,19 @@ void qr_cpu_custom_call_f64(void *out[], const void *in[]) {
   qr_cpu_custom_call<double>(out, in);
 }
 
+void eigh_cpu_custom_call_f32(void *out[], const void *in[]) {
+  eigh_cpu_custom_call<float>(out, in);
+}
+
+void eigh_cpu_custom_call_f64(void *out[], const void *in[]) {
+  eigh_cpu_custom_call<double>(out, in);
+}
+
 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f32", qr_cpu_custom_call_f32);
 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f64", qr_cpu_custom_call_f64);
 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_f16", qr_cpu_custom_call_f16);
 XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("qr_cpu_custom_call_bf16", qr_cpu_custom_call_bf16);
+
+
+XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f32", eigh_cpu_custom_call_f32);
+XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("eigh_cpu_custom_call_f64", eigh_cpu_custom_call_f64);
\ No newline at end of file
diff --git a/exla/c_src/exla/custom_calls.h b/exla/c_src/exla/custom_calls.h
index f00834d411..36cd7d9ac5 100644
--- a/exla/c_src/exla/custom_calls.h
+++ b/exla/c_src/exla/custom_calls.h
@@ -6,4 +6,7 @@
 void qr_cpu_custom_call_f16(void *out[], const void *in[]);
 void qr_cpu_custom_call_f32(void *out[], const void *in[]);
 void qr_cpu_custom_call_f64(void *out[], const void *in[]);
+void eigh_cpu_custom_call_f32(void *out[], const void *in[]);
+void eigh_cpu_custom_call_f64(void *out[], const void *in[]);
+
 #endif
diff --git a/exla/lib/exla/backend.ex b/exla/lib/exla/backend.ex
index 21c2f8ccd1..0193e14109 100644
--- a/exla/lib/exla/backend.ex
+++ b/exla/lib/exla/backend.ex
@@ -234,24 +234,16 @@ defmodule EXLA.Backend do
   @impl true
   def concatenate(out, tensors, axis) do
-    out = Nx.to_template(out)
-
-    expr_fun = fn tensors ->
-      Nx.Defn.Expr.concatenate(out, Tuple.to_list(tensors), axis)
-    end
-
-    jit([], expr_fun, tensors, [List.to_tuple(tensors)])
+    copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend))
+    result = Nx.BinaryBackend.concatenate(out, copied, axis)
+    Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)})
   end
 
   @impl true
   def stack(out, tensors, axis) do
-    out = Nx.to_template(out)
-
-    expr_fun = fn tensors ->
-      Nx.Defn.Expr.stack(out, Tuple.to_list(tensors), axis)
-    end
-
-    jit([], expr_fun, tensors, [List.to_tuple(tensors)])
+    copied = Enum.map(tensors, &Nx.backend_copy(&1, Nx.BinaryBackend))
+    result = Nx.BinaryBackend.stack(out, copied, axis)
+    Nx.backend_transfer(result, {EXLA.Backend, jit_opts([], tensors)})
   end
 
   @impl true
@@ -390,6 +382,10 @@ defmodule EXLA.Backend do
   defp jit(opts, fun, args), do: jit(opts, fun, args, args)
 
   defp jit(opts, fun, tensors, args) do
+    EXLA.jit_apply(fun, args, [on_conflict: :force] ++ jit_opts(opts, tensors))
+  end
+
+  defp jit_opts(opts, tensors) do
     {priority_client, priority_did, backup_client, backup_did} =
       for %T{data: %B{buffer: %EXLA.DeviceBuffer{client_name: client_name, device_id: device_id}}} <- tensors,
@@ -418,6 +414,6 @@ defmodule EXLA.Backend do
       opts[:device_id] || priority_did || backup_did ||
         EXLA.Client.fetch!(client).default_device_id
 
-    EXLA.jit_apply(fun, args, on_conflict: :force, client: client, device_id: device_id)
+    [client: client, device_id: device_id]
   end
 end
diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex
index 86cceb4b5d..9c9bbdbdbb 100644
--- a/exla/lib/exla/defn.ex
+++ b/exla/lib/exla/defn.ex
@@ -237,7 +237,7 @@ defmodule EXLA.Defn do
       output = wrap_tuple_result(acc, acc_typespec)
 
       outfeed = outfeed |> Outfeed.with_token(out_token) |> Outfeed.close(builder)
-      Value.return(builder, output)
+      Value.func_return(builder, output)
 
       {{input_typespecs, input_indexes}, outfeed}
     end
@@ -307,7 +307,7 @@ defmodule EXLA.Defn do
     {res, cache} =
       recur_flatten(expr, state, new_cache(outfeed))
 
     outfeed = cache |> get_outfeed() |> Outfeed.close(function)
-    Value.return(function, res)
+    Value.func_return(function, res)
     {:ok, outfeed}
   end
@@ -433,6 +433,15 @@ defmodule EXLA.Defn do
     comp_arg_typespecs =
       for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec
 
+    outputs =
+      if stream? do
+        # The computation returns the final accumulator value
+        {_chunk_result, acc} = outputs
+        acc
+      else
+        outputs
+      end
+
     out_typespecs =
       [outputs]
       |> Nx.Defn.Composite.flatten_list()
@@ -624,6 +633,45 @@ defmodule EXLA.Defn do
     {[q, r], cache}
   end
 
+  defp cached_recur_operator(
+         :optional,
+         %T{
+           data: %Expr{
+             args: [
+               %{data: %{op: :eigh, args: [tensor, _opts]}},
+               {eigenvecs_expr, eigenvals_expr},
+               _callback
+             ]
+           }
+         },
+         %{client: %EXLA.Client{platform: :host}, builder: %Function{}} = state,
+         cache
+       ) do
+    # We match only on platform: :host for MLIR, as we want to support
+    # eigh-on-cpu as a custom call only in this case
+    {tensor, cache} = recur_operator(tensor, state, cache) |> unwrap_single_tensor!()
+
+    # Convert to float and ensure we are using either f32 or f64, because the
+    # Eigen-based custom call only supports those types
+    out_type = Nx.Type.merge(Nx.Type.to_floating(eigenvecs_expr.type), {:f, 32})
+
+    tensor =
+      if op_type(tensor) != out_type do
+        to_type(tensor, out_type)
+      else
+        tensor
+      end
+
+    {eigenvecs, eigenvals} =
+      Value.eigh(
+        tensor,
+        expr_to_typespec(%{eigenvecs_expr | type: out_type}),
+        expr_to_typespec(%{eigenvals_expr | type: out_type})
+      )
+
+    {[to_type(eigenvecs, eigenvecs_expr.type), to_type(eigenvals, eigenvals_expr.type)], cache}
+  end
+
   defp cached_recur_operator(
          :optional,
          %T{
@@ -1669,9 +1717,9 @@ defmodule EXLA.Defn do
       {res, comp_cache} = recur_composite(expr, state, reset_token(cache, inner_token))
 
       if outer_token do
-        Value.return(function, [get_token(comp_cache) | List.flatten(res)])
+        Value.func_return(function, [get_token(comp_cache) | List.flatten(res)])
       else
-        Value.return(function, List.flatten(res))
+        Value.func_return(function, List.flatten(res))
       end
 
       {function, merge_outfeed(cache, comp_cache)}
diff --git a/exla/lib/exla/executable.ex b/exla/lib/exla/executable.ex
index a6a0c8cbdf..1c6e4df49a 100644
--- a/exla/lib/exla/executable.ex
+++ b/exla/lib/exla/executable.ex
@@ -22,6 +22,9 @@ defmodule EXLA.Executable do
     end
   end
 
+  @doc """
+  Serializes the executable to a binary.
+  """
   def serialize(%Executable{
         ref: executable,
         output_typespecs: output_typespecs,
@@ -36,6 +39,7 @@ defmodule EXLA.Executable do
       |> IO.iodata_to_binary()
 
     %{
+      version: 1,
       serialized: serialized_exec,
       output_typespecs: output_typespecs,
       num_replicas: num_replicas,
@@ -45,21 +49,35 @@ defmodule EXLA.Executable do
     |> :erlang.term_to_binary()
   end
 
+  @doc """
+  Deserializes a previously serialized executable.
+ """ def deserialize(client, binary) do case :erlang.binary_to_term(binary) do - %{serialized: serialized_exec} = exec_data -> + %{version: 1, serialized: serialized} = data -> + %{ + output_typespecs: output_typespecs, + num_replicas: num_replicas, + num_partitions: num_partitions, + device_id: device_id + } = data + ref = - serialized_exec + serialized |> then(&EXLA.NIF.deserialize_executable(client.ref, &1)) |> unwrap!() - exec_data - |> Map.put(:ref, ref) - |> Map.put(:client, client) - |> then(&struct(__MODULE__, &1)) + %EXLA.Executable{ + output_typespecs: output_typespecs, + num_replicas: num_replicas, + num_partitions: num_partitions, + device_id: device_id, + ref: ref, + client: client + } _other -> - raise "invalid serialized executable" + raise ArgumentError, "invalid serialized executable" end end diff --git a/exla/lib/exla/lib.ex b/exla/lib/exla/lib.ex index 2b54e3a5da..686d98d7df 100644 --- a/exla/lib/exla/lib.ex +++ b/exla/lib/exla/lib.ex @@ -34,7 +34,7 @@ defmodule EXLA.Lib do def argmax(builder, op, type, opts \\ []) def argmax(%Function{} = builder, %Value{} = op, type, opts) do - argmin_or_max(builder, op, false, type, opts) + argmin_or_max(builder, op, :max, type, opts) end @doc """ @@ -49,37 +49,43 @@ defmodule EXLA.Lib do def argmin(builder, op, type, opts \\ []) def argmin(%Function{} = builder, %Value{} = op, type, opts) do - argmin_or_max(builder, op, true, type, opts) + argmin_or_max(builder, op, :min, type, opts) end - defp argmin_or_max(builder, %Value{} = op, is_min?, type, opts) do + defp argmin_or_max(builder, %Value{} = op, variant, type, opts) do tie_break = opts[:tie_break] || :low keep_axis = opts[:keep_axis] || false + axis = opts[:axis] op_typespec = Value.get_typespec(op) + {op, op_typespec} = + if axis == nil and Nx.rank(op_typespec.shape) != 1 do + # When no axis is given, we flatten the tensor and reduce over + # the first axis + typespec = Typespec.to_shape(op_typespec, {Nx.size(op_typespec.shape)}) + {Value.reshape(op, typespec), typespec} + else + {op, op_typespec} + end + + axis = axis || 0 + init_value = - if is_min?, - do: max_number(builder, op_typespec.type), - else: min_number(builder, op_typespec.type) + case variant do + :min -> max_number(builder, op_typespec.type) + :max -> min_number(builder, op_typespec.type) + end - axis = opts[:axis] index_init_value = Value.constant(builder, [0], Typespec.tensor(type, {})) iota = iota(builder, axis, Typespec.to_type(op_typespec, type)) - reduction = create_min_max_computation(builder, op_typespec.type, type, is_min?, tie_break) + reduction = create_min_max_computation(builder, op_typespec.type, type, variant, tie_break) - dims = - if axis do - [axis] - else - Nx.axes(op_typespec.shape) - end - - shape = remove_axes(op_typespec.shape, dims) + shape = Tuple.delete_at(op_typespec.shape, axis) typespecs = [Typespec.tensor(op_typespec.type, shape), Typespec.tensor(type, shape)] [_, result] = - Value.reduce(reduction, [init_value, index_init_value], [op, iota], dims, typespecs) + Value.reduce(reduction, [init_value, index_init_value], [op, iota], [axis], typespecs) if keep_axis do Value.reshape(result, Typespec.tensor(type, put_elem(op_typespec.shape, axis, 1))) @@ -88,13 +94,7 @@ defmodule EXLA.Lib do end end - defp remove_axes(shape, axes) do - axes - |> Enum.reverse() - |> Enum.reduce(shape, &Tuple.delete_at(&2, &1)) - end - - defp create_min_max_computation(%Function{} = function, type, index_type, is_min?, tie_break) do + defp create_min_max_computation(%Function{} = function, type, index_type, 
variant, tie_break) do arg_typespecs = [ Typespec.tensor(type, {}), Typespec.tensor(index_type, {}), @@ -109,27 +109,42 @@ defmodule EXLA.Lib do value_typespec = Typespec.tensor(type, {}) idx_typespec = Typespec.tensor(index_type, {}) - cmp = - if is_min?, - do: Value.less_equal(lhs_value, rhs_value, pred_typespec), - else: Value.greater_equal(lhs_value, rhs_value, pred_typespec) + comparator = + case variant do + :min -> &Value.less/3 + :max -> &Value.greater/3 + end + + # Pick lhs if strictly before or if it is NaN + pick_lhs_value = + Value.bitwise_or( + comparator.(lhs_value, rhs_value, pred_typespec), + Value.is_nan(lhs_value, pred_typespec), + pred_typespec + ) - max = Value.select(cmp, lhs_value, rhs_value, value_typespec) - arg_max = Value.select(cmp, lhs_index, rhs_index, idx_typespec) + max = Value.select(pick_lhs_value, lhs_value, rhs_value, value_typespec) - arg_max = + idx_comparator = case tie_break do - :low -> - eq? = Value.equal(lhs_value, rhs_value, pred_typespec) - id = Value.min(lhs_index, rhs_index, idx_typespec) - Value.select(eq?, id, arg_max, idx_typespec) - - :high -> - eq? = Value.equal(lhs_value, rhs_value, pred_typespec) - id = Value.max(lhs_index, rhs_index, idx_typespec) - Value.select(eq?, id, arg_max, idx_typespec) + :low -> &Value.less/3 + :high -> &Value.greater/3 end + # If lhs and rhs are equal (and not NaN), then pick index based on tie_break + pick_lhs_idx = + Value.bitwise_or( + pick_lhs_value, + Value.bitwise_and( + Value.equal(lhs_value, rhs_value, pred_typespec), + idx_comparator.(lhs_index, rhs_index, pred_typespec), + pred_typespec + ), + pred_typespec + ) + + arg_max = Value.select(pick_lhs_idx, lhs_index, rhs_index, idx_typespec) + Value.return(function, [max, arg_max]) Function.pop_region(function) region diff --git a/exla/lib/exla/mlir/value.ex b/exla/lib/exla/mlir/value.ex index 62ec3ff100..b53795055d 100644 --- a/exla/lib/exla/mlir/value.ex +++ b/exla/lib/exla/mlir/value.ex @@ -64,15 +64,20 @@ defmodule EXLA.MLIR.Value do %{type: rhs_type} = get_typespec(rhs) comparison_type = - if Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) do - attr_comparison_type(:totalorder) - else - attr_comparison_type(:notype) + cond do + Nx.Type.complex?(lhs_type) or Nx.Type.complex?(rhs_type) -> + attr_comparison_type(:float) + + Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) -> + attr_comparison_type(:float) + + true -> + attr_comparison_type(:notype) end attributes = [ comparison_direction: attr_comparison_direction(direction), - comparison_type: comparison_type + compare_type: comparison_type ] result_types = typespecs_to_mlir_types([Typespec.to_type(typespec, {:pred, 8})]) @@ -125,57 +130,48 @@ defmodule EXLA.MLIR.Value do end end - def is_infinity(%Value{function: func} = operand, typespec) do + def is_infinity(%Value{function: func} = operand, out_typespec) do %{type: type} = get_typespec(operand) - typespec = Typespec.to_type(typespec, {:pred, 8}) + typespec = Typespec.to_type(out_typespec, {:pred, 8}) - cond do - Nx.Type.complex?(type) -> - float_typespec = Typespec.to_type(typespec, complex_part_type(type)) - real = real(operand, float_typespec) - imag = imag(operand, float_typespec) - is_inf_real = is_infinity(real, typespec) - is_inf_imag = is_infinity(imag, typespec) - bitwise_or(is_inf_real, is_inf_imag, typespec) - - Nx.Type.integer?(type) -> - # Integers are never infinity. 
We use inequality to make sure - # the operand is still a part of the computation - not_equal(operand, operand, typespec) + result = + cond do + Nx.Type.complex?(type) -> + float_typespec = Typespec.to_type(typespec, complex_part_type(type)) + real = real(operand, float_typespec) + imag = imag(operand, float_typespec) + is_inf_real = is_infinity(real, typespec) + is_inf_imag = is_infinity(imag, typespec) + bitwise_or(is_inf_real, is_inf_imag, typespec) + + Nx.Type.integer?(type) -> + # Integers are never infinity. We use inequality to make sure + # the operand is still a part of the computation + not_equal(operand, operand, typespec) + + true -> + result_types = typespecs_to_mlir_types([typespec]) + op(func, "chlo.is_inf", [operand], result_types) |> one!() + end - true -> - result_types = typespecs_to_mlir_types([typespec]) - op(func, "chlo.is_inf", [operand], result_types) |> one!() + if out_typespec.type == typespec.type do + result + else + convert(result, out_typespec) end end - def is_nan(%Value{function: func} = operand, typespec) do - %{type: type} = get_typespec(operand) - - typespec = Typespec.to_type(typespec, {:pred, 8}) + def is_nan(%Value{} = operand, out_typespec) do + typespec = Typespec.to_type(out_typespec, {:pred, 8}) - cond do - Nx.Type.complex?(type) -> - float_typespec = Typespec.to_type(typespec, complex_part_type(type)) - real = real(operand, float_typespec) - imag = imag(operand, float_typespec) - is_nan_real = is_nan(real, typespec) - is_nan_imag = is_nan(imag, typespec) - bitwise_or(is_nan_real, is_nan_imag, typespec) - - Nx.Type.integer?(type) -> - # Integers are never nan. We use inequality to make sure - # the operand is still a part of the computation - not_equal(operand, operand, typespec) + # Only NaN is not equal to itself + result = not_equal(operand, operand, typespec) - true -> - result_types = typespecs_to_mlir_types([typespec]) - is_inf = op(func, "chlo.is_inf", [operand], result_types) |> one!() - is_finite = op(func, "stablehlo.is_finite", [operand], result_types) |> one!() - is_not_inf = bitwise_not(is_inf, typespec) - is_not_finite = bitwise_not(is_finite, typespec) - bitwise_and(is_not_inf, is_not_finite, typespec) + if out_typespec.type == typespec.type do + result + else + convert(result, out_typespec) end end @@ -706,10 +702,66 @@ defmodule EXLA.MLIR.Value do op(func, "stablehlo.while", initial, result_types, regions: regions) end + def func_return(func, values) when is_list(values) do + op(func, "func.return", values, []) + end + def return(func, values) when is_list(values) do op(func, "stablehlo.return", values, []) end + def eigh(%Value{function: func} = value, eigenvecs_typespec, eigenvals_typespec) do + %{type: op_type, shape: op_shape} = get_typespec(value) + %{type: eigenvecs_type, shape: eigenvecs_shape} = eigenvecs_typespec + %{type: eigenvals_type, shape: eigenvals_shape} = eigenvals_typespec + + dim_sizes = [tuple_size(op_shape), tuple_size(eigenvecs_shape), tuple_size(eigenvals_shape)] + operand_dims = Tuple.to_list(op_shape) + eigenvecs_dims = Tuple.to_list(eigenvecs_shape) + eigenvals_dims = Tuple.to_list(eigenvals_shape) + + dim_sizes = constant(func, dim_sizes, Typespec.tensor({:s, 64}, {length(dim_sizes)})) + operand_dims = constant(func, operand_dims, Typespec.tensor({:s, 64}, {length(operand_dims)})) + + eigenvecs_dims = + constant(func, eigenvecs_dims, Typespec.tensor({:s, 64}, {length(eigenvecs_dims)})) + + eigenvals_dims = + constant(func, eigenvals_dims, Typespec.tensor({:s, 64}, {length(eigenvals_dims)})) + + operands 
= [value, dim_sizes, operand_dims, eigenvecs_dims, eigenvals_dims] + + eigenvecs_result_type = type_tensor(eigenvecs_type, eigenvecs_shape) + eigenvals_result_type = type_tensor(eigenvals_type, eigenvals_shape) + result_types = [type_tuple([eigenvecs_result_type, eigenvals_result_type])] + + call_target_name = + case op_type do + {:f, 32} -> + "eigh_cpu_custom_call_f32" + + {:f, 64} -> + "eigh_cpu_custom_call_f64" + + type -> + # Due to matching on EXLA.Defn, we are sure that the device here is always :host + raise "Eigh decomposition not supported on :host device for type #{inspect(type)}" + end + + attributes = [ + call_target_name: attr_string(call_target_name), + backend_config: attr_string("Host") + ] + + result = + op(func, "stablehlo.custom_call", operands, result_types, attributes: attributes) |> one!() + + eigenvecs = get_tuple_element(result, 0, eigenvecs_typespec) + eigenvals = get_tuple_element(result, 1, eigenvals_typespec) + + {eigenvecs, eigenvals} + end + def qr(%Value{function: func} = value, q_typespec, r_typespec) do %{type: op_type, shape: op_shape} = get_typespec(value) %{type: q_type, shape: q_shape} = q_typespec @@ -934,7 +986,7 @@ defmodule EXLA.MLIR.Value do defp attr_comparison_direction(value) when value in [:eq, :lt, :le, :gt, :ge, :ne], do: attr_enum("stablehlo", "comparison_direction", value) - defp attr_comparison_type(value) when value in [:totalorder, :notype], + defp attr_comparison_type(value) when value in [:float, :totalorder, :notype], do: attr_enum("stablehlo", "comparison_type", value) defp attr_precision(value) when value in [:default, :high, :highest], diff --git a/exla/mix.exs b/exla/mix.exs index 9c6e32e713..cf6105f2e4 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -66,7 +66,7 @@ defmodule EXLA.MixProject do # {:nx, "~> 0.7.1"}, {:nx, path: "../nx"}, {:telemetry, "~> 0.4.0 or ~> 1.0"}, - {:xla, "~> 0.7.0", runtime: false}, + {:xla, "~> 0.7.1", runtime: false}, {:elixir_make, "~> 0.6", runtime: false}, {:benchee, "~> 1.0", only: :dev}, {:ex_doc, "~> 0.29", only: :docs}, diff --git a/exla/mix.lock b/exla/mix.lock index 47649c6002..020904255b 100644 --- a/exla/mix.lock +++ b/exla/mix.lock @@ -3,7 +3,7 @@ "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, "deep_merge": {:hex, :deep_merge, "1.0.0", "b4aa1a0d1acac393bdf38b2291af38cb1d4a52806cf7a4906f718e1feb5ee961", [:mix], [], "hexpm", "ce708e5f094b9cd4e8f2be4f00d2f4250c4095be93f8cd6d018c753894885430"}, "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"}, - "elixir_make": {:hex, :elixir_make, "0.8.3", "d38d7ee1578d722d89b4d452a3e36bcfdc644c618f0d063b874661876e708683", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "5c99a18571a756d4af7a4d89ca75c28ac899e6103af6f223982f09ce44942cc9"}, + "elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"}, "ex_doc": {:hex, :ex_doc, "0.31.1", 
"8a2355ac42b1cc7b2379da9e40243f2670143721dd50748bf6c3b1184dae2089", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3178c3a407c557d8343479e1ff117a96fd31bafe52a039079593fb0524ef61b0"}, "makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, @@ -13,5 +13,5 @@ "nx": {:hex, :nx, "0.7.1", "5f6376e3d18408116e8a84b8f4ac851fb07dfe61764a5410ebf0b5dcb69c1b7e", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e3ddd6a3f2a9bac79c67b3933368c25bb5ec814a883fc68aba8fd8a236751777"}, "statistex": {:hex, :statistex, "1.0.0", "f3dc93f3c0c6c92e5f291704cf62b99b553253d7969e9a5fa713e5481cd858a5", [:mix], [], "hexpm", "ff9d8bee7035028ab4742ff52fc80a2aa35cece833cf5319009b52f1b5a86c27"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, - "xla": {:hex, :xla, "0.7.0", "413880fb8f665d93636908092a409e549545e190b38b91107832e78379190d93", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "8eb5c5510e6737fd9e4860bfb0d8cafb13ab94b1b4123edd347562a71e19ec27"}, + "xla": {:hex, :xla, "0.7.1", "cf188be6af8c794bf0eed1214cb7d0207b93eb4b3a3c1928fc812357b154cdfc", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "0028968ea5431ff9ac10e5eaaed595e139c65f4ff448e26783ba33783640855c"}, } diff --git a/exla/test/exla/executable_test.exs b/exla/test/exla/executable_test.exs index 28e276edfc..fe0e5c3f4d 100644 --- a/exla/test/exla/executable_test.exs +++ b/exla/test/exla/executable_test.exs @@ -11,7 +11,7 @@ defmodule EXLA.ExecutableTest do describe "run" do test "with no inputs and default options" do assert [a = %DeviceBuffer{}] = - run_one([], [], Typespec.tensor({:s, 32}, {}), fn b -> + run_one([], [], s32_typespec(), fn b -> [Value.constant(b, [1], s32_typespec())] end) @@ -19,8 +19,8 @@ defmodule EXLA.ExecutableTest do end test "with 2 inputs and default options" do - t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) - t2 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) + t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) + t2 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) assert [a = %DeviceBuffer{}] = run_one([t1, t2], [], [t1.typespec], fn _b, x, y -> @@ -34,7 +34,7 @@ defmodule EXLA.ExecutableTest do t1 = DeviceBuffer.place_on_device( <<1::32-native>>, - Typespec.tensor({:s, 32}, {}), 
+ s32_typespec(), client(), 0 ) @@ -42,7 +42,7 @@ defmodule EXLA.ExecutableTest do t2 = DeviceBuffer.place_on_device( <<1::32-native>>, - Typespec.tensor({:s, 32}, {}), + s32_typespec(), client(), 0 ) @@ -62,8 +62,8 @@ defmodule EXLA.ExecutableTest do end test "with data from a previous run" do - t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) - t2 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) + t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) + t2 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) exec = compile([t1.typespec, t2.typespec], [], [t1.typespec], fn _b, x, y -> @@ -80,12 +80,12 @@ defmodule EXLA.ExecutableTest do t1 = DeviceBuffer.place_on_device( <<1::32-native>>, - Typespec.tensor({:s, 32}, {}), + s32_typespec(), client(), 0 ) - t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {})) + t2 = BinaryBuffer.from_binary(<<2::32-native>>, s32_typespec()) assert [a = %DeviceBuffer{}] = run_one([t1, t2], [], [t1.typespec], fn _b, x, y -> @@ -96,8 +96,8 @@ defmodule EXLA.ExecutableTest do end test "with tuple return" do - t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) - t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {})) + t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) + t2 = BinaryBuffer.from_binary(<<2::32-native>>, s32_typespec()) assert [a = %DeviceBuffer{}, b = %DeviceBuffer{}] = run_one([t1, t2], [], [t1.typespec, t2.typespec], fn _b, x, y -> @@ -110,8 +110,8 @@ defmodule EXLA.ExecutableTest do @tag :multi_device test "runs on a specific device" do - t1 = BinaryBuffer.from_binary(<<1::32-native>>, Typespec.tensor({:s, 32}, {})) - t2 = BinaryBuffer.from_binary(<<2::32-native>>, Typespec.tensor({:s, 32}, {})) + t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) + t2 = BinaryBuffer.from_binary(<<2::32-native>>, s32_typespec()) assert [a = %DeviceBuffer{}, b = %DeviceBuffer{}, c = %DeviceBuffer{}] = run_one( @@ -138,6 +138,25 @@ defmodule EXLA.ExecutableTest do end end + describe "serialization" do + test "run" do + t1 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) + t2 = BinaryBuffer.from_binary(<<1::32-native>>, s32_typespec()) + + exec = + compile([s32_typespec(), s32_typespec()], [], [s32_typespec()], fn _, x, y -> + [Value.add(x, y, s32_typespec())] + end) + + binary = Executable.serialize(exec) + assert is_binary(binary) + exec = Executable.deserialize(client(), binary) + + assert [[a = %DeviceBuffer{}]] = EXLA.Executable.run(exec, [[t1, t2]], []) + assert <<2::32-native>> == DeviceBuffer.read(a) + end + end + defp s32_typespec(), do: Typespec.tensor({:s, 32}, {}) end @@ -160,7 +179,7 @@ defmodule EXLA.ExecutableFeedTest do assert res = Task.async(fn -> - run_one([], [], [Typespec.token()], fn b -> + run_one([], [], [t.typespec], fn b -> token = Value.create_token(b) {new_token, [val]} = Value.infeed(token, [t.typespec]) @@ -185,7 +204,7 @@ defmodule EXLA.ExecutableFeedTest do assert res = Task.async(fn -> - run_one([], [], [token_shape, t.typespec], fn b -> + run_one([], [], [t.typespec], fn b -> token = Value.create_token(b) arg_shapes = [token_shape, t.typespec] diff --git a/exla/test/support/exla_helpers.ex b/exla/test/support/exla_helpers.ex index db7689d144..971d097ab3 100644 --- a/exla/test/support/exla_helpers.ex +++ b/exla/test/support/exla_helpers.ex @@ -15,7 +15,7 @@ defmodule EXLAHelpers do fun |> apply([builder | params]) - |> 
then(&EXLA.MLIR.Value.return(builder, List.wrap(&1))) + |> then(&EXLA.MLIR.Value.func_return(builder, List.wrap(&1))) EXLA.MLIR.Module.compile( builder.module, diff --git a/nx/guides/exercises/exercises-1-20.livemd b/nx/guides/exercises/exercises-1-20.livemd index 812d3a5197..f3073f0efd 100644 --- a/nx/guides/exercises/exercises-1-20.livemd +++ b/nx/guides/exercises/exercises-1-20.livemd @@ -182,7 +182,7 @@ tensor = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
```elixir - Nx.iota({39}) + Nx.iota({40}) |> Nx.add(10) ``` diff --git a/nx/lib/nx.ex b/nx/lib/nx.ex index a8fcee7488..aca74c01c7 100644 --- a/nx/lib/nx.ex +++ b/nx/lib/nx.ex @@ -10008,6 +10008,15 @@ defmodule Nx do 1 > + If the tensor includes any NaNs, returns the index of any of them + (NaNs are not equal, hence tie-break does not apply): + + iex> Nx.argmax(Nx.tensor([2.0, :nan, 4.0])) + #Nx.Tensor< + s64 + 1 + > + ### Aggregating over an axis iex> t = Nx.tensor([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) @@ -10147,6 +10156,15 @@ defmodule Nx do 0 > + If the tensor includes any NaNs, returns the index of any of them + (NaNs are not equal, hence tie-break does not apply): + + iex> Nx.argmin(Nx.tensor([2.0, :nan, 4.0])) + #Nx.Tensor< + s64 + 1 + > + ### Aggregating over an axis iex> t = Nx.tensor([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) diff --git a/nx/lib/nx/binary_backend.ex b/nx/lib/nx/binary_backend.ex index 32279b8115..cbf1539e9b 100644 --- a/nx/lib/nx/binary_backend.ex +++ b/nx/lib/nx/binary_backend.ex @@ -1461,7 +1461,7 @@ defmodule Nx.BinaryBackend do bin, {i, cur_extreme_x, cur_extreme_i} -> x = binary_to_number(bin, type) - if cur_extreme_x == :first or comparator.(x, cur_extreme_x) do + if cur_extreme_x == :first or x == :nan or comparator.(x, cur_extreme_x) do {i, {i + 1, x, i}} else {cur_extreme_i, {i + 1, cur_extreme_x, cur_extreme_i}} diff --git a/nx/lib/nx/binary_backend/matrix.ex b/nx/lib/nx/binary_backend/matrix.ex index 9ee102e6fe..e3930440c0 100644 --- a/nx/lib/nx/binary_backend/matrix.ex +++ b/nx/lib/nx/binary_backend/matrix.ex @@ -177,7 +177,7 @@ defmodule Nx.BinaryBackend.Matrix do # QR iteration for eigenvalues and eigenvectors {eigenvals_diag, eigenvecs} = - Enum.reduce_while(1..max_iter, {h, q_h}, fn _, {a_old, q_old} -> + Enum.reduce_while(1..max_iter//1, {h, q_h}, fn _, {a_old, q_old} -> # QR decomposition {q_now, r_now} = qr_decomposition(a_old, n, n, eps) diff --git a/nx/lib/nx/defn/expr.ex b/nx/lib/nx/defn/expr.ex index 44027d1564..bb07ea30d7 100644 --- a/nx/lib/nx/defn/expr.ex +++ b/nx/lib/nx/defn/expr.ex @@ -612,7 +612,7 @@ defmodule Nx.Defn.Expr do {nil, range} unroll when is_integer(unroll) and unroll > 0 -> - {internal, external} = split_range(range, size - rem(size, unroll)) + {internal, external} = Range.split(range, size - rem(size, unroll)) {{internal, 0..(unroll - 1)//1}, external} unroll -> @@ -730,33 +730,6 @@ defmodule Nx.Defn.Expr do end) end - # TODO: Use Range.split/2 when we require Elixir v1.15+ - defp split_range(first..last//step = range, split) when is_integer(split) do - if split >= 0 do - split_range(first, last, step, split) - else - split_range(first, last, step, Range.size(range) + split) - end - end - - defp split_range(first, last, step, split) when first < last or (first == last and step > 0) do - if step > 0 do - mid = max(min(first + step * (split - 1), last), first - step) - {first..mid//step, (mid + step)..last//step} - else - {first..(first - step)//step, (last + step)..last//step} - end - end - - defp split_range(last, first, step, split) do - if step < 0 do - mid = min(max(last + step * (split - 1), first), last - step) - {last..mid//step, (mid + step)..first//step} - else - {last..(last - step)//step, (first + step)..first//step} - end - end - defp compatible_while!(file, line, initial, body) do if not Nx.compatible?(initial, body) do raise CompileError, diff --git a/nx/lib/nx/lin_alg/eigh.ex b/nx/lib/nx/lin_alg/eigh.ex index 388e1dc17d..05673ad2ba 100644 --- a/nx/lib/nx/lin_alg/eigh.ex +++ 
b/nx/lib/nx/lin_alg/eigh.ex @@ -26,7 +26,17 @@ defmodule Nx.LinAlg.Eigh do } end - defn eigh_matrix(a, opts \\ []) do + defnp eigh_matrix(a, opts \\ []) do + case Nx.shape(a) do + {1, 1} -> + {a, Nx.fill(a, 1)} + + {_, _} -> + eigh_2d(a, opts) + end + end + + defnp eigh_2d(a, opts \\ []) do # The input Hermitian matrix A reduced to Hessenberg matrix H by Householder transform. # Then, by using QR iteration it converges to AQ = QΛ, # where Λ is the diagonal matrix of eigenvalues and the columns of Q are the eigenvectors. @@ -56,7 +66,7 @@ defmodule Nx.LinAlg.Eigh do {eigenvals, eigenvecs} end - defn hessenberg_decomposition(matrix, eps) do + defnp hessenberg_decomposition(matrix, eps) do # The input Hermitian matrix A reduced to Hessenberg matrix H by Householder transform. # Then, by using QR iteration it converges to AQ = QΛ, # where Λ is the diagonal matrix of eigenvalues and the columns of Q are the eigenvectors. @@ -70,7 +80,7 @@ defmodule Nx.LinAlg.Eigh do {{hess, q}, _} = while {{hess = Nx.as_type(matrix, out_type), q = eye}, {eps, column_iota}}, - i <- 0..(n - 2) do + i <- 0..(n - 2)//1 do x = hess[[.., i]] x = Nx.select(column_iota <= i, 0, x) h = QR.householder_reflector(x, i, eps) diff --git a/nx/lib/nx/serving.ex b/nx/lib/nx/serving.ex index 9ef7c44d39..273e708ed4 100644 --- a/nx/lib/nx/serving.ex +++ b/nx/lib/nx/serving.ex @@ -216,12 +216,14 @@ defmodule Nx.Serving do This can be done by passing either `Nx.backend_transfer/1` or `Nx.backend_copy/1` as third argument: - Nx.Serving.batched_run(MyDistributedServing, input, &Nx.backend_copy/1) + Nx.Serving.batched_run(MyDistributedServing, input, &Nx.backend_copy(&1, Nx.BinaryBackend)) Use `backend_transfer/1` if you know the input will no longer be used. - Similarly, the serving has a `distributed_postprocessing` callback which can do - equivalent before sending the reply to the caller. + Similarly, the serving has a `distributed_postprocessing` callback which is + called on the remote machine before sending the reply to the caller. It can + be used to transfer resources to the binary backend before sending them over + the network. The servings are dispatched using Erlang Distribution. You can use `Node.connect/1` to manually connect nodes. 
In a production setup, this is @@ -768,8 +770,7 @@ defmodule Nx.Serving do end end) - # TODO: Use Process.monitor/2 on Elixir v1.15+ - {pid, :erlang.monitor(:process, pid, alias: :demonitor)} + {pid, Process.monitor(pid, alias: :demonitor)} end defp run_hook(ref, size, result, hook) do @@ -1038,8 +1039,7 @@ defmodule Nx.Serving do {preprocessed, info} = handle_preprocessing(preprocessing, input) - # TODO: Use Process.monitor/2 on Elixir v1.15+ - ref = :erlang.monitor(:process, pid, alias: :demonitor) + ref = Process.monitor(pid, alias: :demonitor) size_or_unknown = case preprocessed do @@ -1396,8 +1396,7 @@ defmodule Nx.Serving do @impl true def handle_info({__MODULE__, :proxy_monitor, pid, ref}, state) do - # TODO: Use Process.monitor/2 on Elixir v1.15+ - :erlang.monitor(:process, pid, tag: {:proxy, ref}) + Process.monitor(pid, tag: {:proxy, ref}) {:noreply, state} end @@ -1613,7 +1612,7 @@ defmodule Nx.Serving do send(ref, {ref, {:batch, {start, size, output, metadata}}}) for pid <- pids do - send(pid, {ref, size - start}) + send(pid, {ref, size}) end end diff --git a/nx/lib/nx/shape.ex b/nx/lib/nx/shape.ex index 537fe5b947..61d7eeb941 100644 --- a/nx/lib/nx/shape.ex +++ b/nx/lib/nx/shape.ex @@ -1856,7 +1856,7 @@ defmodule Nx.Shape do end # batch axes must be increasing starting from 0 - valid_batch_axes = Enum.to_list(0..(length(b1) - 1)) + valid_batch_axes = Enum.to_list(0..(length(b1) - 1)//1) # ensure normalized batch axis of left is valid value if left_batched? and b1 != valid_batch_axes do diff --git a/nx/mix.lock b/nx/mix.lock index 61aa6535b2..95b94cbdae 100644 --- a/nx/mix.lock +++ b/nx/mix.lock @@ -1,10 +1,10 @@ %{ "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"}, - "ex_doc": {:hex, :ex_doc, "0.31.1", "8a2355ac42b1cc7b2379da9e40243f2670143721dd50748bf6c3b1184dae2089", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3178c3a407c557d8343479e1ff117a96fd31bafe52a039079593fb0524ef61b0"}, - "makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"}, - "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, - "makeup_erlang": {:hex, :makeup_erlang, "0.1.4", "29563475afa9b8a2add1b7a9c8fb68d06ca7737648f28398e04461f008b69521", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f4ed47ecda66de70dd817698a703f8816daa91272e7e45812469498614ae8b29"}, + "ex_doc": {:hex, :ex_doc, 
"0.34.1", "9751a0419bc15bc7580c73fde506b17b07f6402a1e5243be9e0f05a68c723368", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "d441f1a86a235f59088978eff870de2e815e290e44a8bd976fe5d64470a4c9d2"}, + "makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"}, + "makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"}, + "makeup_erlang": {:hex, :makeup_erlang, "1.0.0", "6f0eff9c9c489f26b69b61440bf1b238d95badae49adac77973cbacae87e3c2e", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "ea7a9307de9d1548d2a72d299058d1fd2339e3d398560a0e46c27dab4891e4d2"}, "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, } diff --git a/nx/test/nx/serving_test.exs b/nx/test/nx/serving_test.exs index 20e1ec46b0..35d6768ffb 100644 --- a/nx/test/nx/serving_test.exs +++ b/nx/test/nx/serving_test.exs @@ -1086,7 +1086,25 @@ defmodule Nx.ServingTest do Task.await(t2, :infinity) end - test "with input streaming", config do + test "with limited concurrent pushing", config do + serving = Nx.Serving.new(Simple, self()) + simple_supervised!(config, batch_size: 4, serving: serving) + + assert Task.async_stream( + 1..4, + fn _ -> + data = Stream.map([Nx.Batch.stack([Nx.tensor([1, 2, 3])])], & &1) + + Nx.Serving.batched_run(config.test, data) + end, + # A bug only shows with limited concurrency + max_concurrency: 2 + ) + |> Enum.map(fn {:ok, results} -> results end) + |> Enum.to_list() == List.duplicate(Nx.tensor([[2, 4, 6]]), 4) + end + + test "with output streaming", config do serving = Nx.Serving.new(Simple, self()) |> Nx.Serving.streaming() simple_supervised!(config, batch_size: 2, serving: serving) stream = Stream.map([[1, 2], [3]], &Nx.Batch.concatenate([Nx.tensor(&1)])) @@ -1109,7 +1127,7 @@ defmodule Nx.ServingTest do refute_received {:DOWN, _, _, _, _} end - test "with input streaming and hooks", config do + test "with output streaming and hooks", config do serving = Nx.Serving.new(Simple, self()) |> Nx.Serving.streaming(hooks: [:foo, :bar]) simple_supervised!(config, batch_size: 2, serving: serving) stream = Stream.map([[1, 2], [3]], &Nx.Batch.concatenate([Nx.tensor(&1)])) diff --git a/nx/test/nx_test.exs b/nx/test/nx_test.exs index 2da1ca42ed..28ee017e0d 100644 --- a/nx/test/nx_test.exs +++ 
b/nx/test/nx_test.exs @@ -1443,7 +1443,7 @@ defmodule NxTest do [:nan, 0, 1] ]) - assert Nx.argmin(t, axis: 1) == Nx.tensor([0, 0, 0, 0, 2, 2, 1, 1, 0, 0, 0, 0]) + assert Nx.argmin(t, axis: 1) == Nx.tensor([0, 1, 0, 0, 2, 1, 1, 1, 0, 1, 0, 0]) end test "raises for invalid :tie_break option" do @@ -1475,7 +1475,7 @@ defmodule NxTest do [:nan, 0, 1] ]) - assert Nx.argmax(t, axis: 1) == Nx.tensor([1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0]) + assert Nx.argmax(t, axis: 1) == Nx.tensor([1, 1, 2, 2, 0, 1, 0, 0, 0, 1, 0, 0]) end end
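As a quick end-to-end check of the new CPU `eigh` lowering, here is a minimal sketch of how it is expected to be exercised from plain `Nx` code. This assumes EXLA's default `:host` client; on non-host clients `Nx.LinAlg.eigh/2` keeps falling back to the existing QR-iteration implementation, and the matrix values below are purely illustrative.

```elixir
# Route Nx operations through EXLA (host client by default).
Nx.default_backend(EXLA.Backend)

# eigh/1 expects a Hermitian (here: real symmetric) matrix.
a = Nx.tensor([[2.0, 1.0], [1.0, 2.0]], type: :f32)

# On the host client this should be lowered to the
# eigh_cpu_custom_call_f32 target registered above.
{eigenvals, eigenvecs} = Nx.LinAlg.eigh(a)

# Eigenvalues of [[2, 1], [1, 2]] are 1.0 and 3.0 (ordering may differ).
IO.inspect(eigenvals)

# The decomposition should reconstruct the input: Q * diag(λ) * Qᵀ ≈ A.
reconstructed =
  eigenvecs
  |> Nx.multiply(eigenvals)
  |> Nx.dot(Nx.LinAlg.adjoint(eigenvecs))

IO.inspect(Nx.all_close(reconstructed, a))
```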