Skip to content

Commit

Permalink
refactor(mlir): reduce tuple usage for returns in favor of variadic+f…
Browse files Browse the repository at this point in the history
…lat tuples (#1448)
  • Loading branch information
polvalente authored Feb 20, 2024
1 parent 51c7122 commit 321c711
Show file tree
Hide file tree
Showing 14 changed files with 244 additions and 277 deletions.
6 changes: 2 additions & 4 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ ERL_NIF_TERM deserialize_executable(ErlNifEnv* env, int argc, const ERL_NIF_TERM
}

EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaExecutable * executable,
(*client)->DeserializeExecutable(serialized), env);
(*client)->DeserializeExecutable(serialized), env);

return exla::nif::ok(env, exla::nif::make<exla::ExlaExecutable*>(env, executable));
}
Expand Down Expand Up @@ -690,7 +690,6 @@ static ErlNifFunc exla_funcs[] = {
{"mlir_less_equal", 3, mlir_less_equal},
{"mlir_greater", 3, mlir_greater},
{"mlir_greater_equal", 3, mlir_greater_equal},
{"mlir_build", 2, mlir_build},
{"dump_mlir_module", 1, dump_mlir_module},
{"mlir_get_shape", 1, mlir_get_shape},
{"mlir_convert", 3, mlir_convert},
Expand Down Expand Up @@ -922,7 +921,6 @@ static ErlNifFunc exla_funcs[] = {
{"start_log_sink", 1, start_log_sink},
// Serialization
{"serialize_executable", 1, serialize_executable},
{"deserialize_executable", 2, deserialize_executable}
};
{"deserialize_executable", 2, deserialize_executable}};

ERL_NIF_INIT(Elixir.EXLA.NIF, exla_funcs, &load, NULL, NULL, NULL);
5 changes: 3 additions & 2 deletions exla/c_src/exla/mlir/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1368,12 +1368,13 @@ mlir::Value MLIRFunction::OutfeedOp(std::vector<mlir::Value> inputs, mlir::Value
return builder->create<mlir::stablehlo::OutfeedOp>(builder->getUnknownLoc(), mlir::ValueRange(inputs), token);
}

mlir::Value MLIRFunction::CallOp(std::vector<mlir::Value> inputs, MLIRFunction *computation) {
std::vector<mlir::Value> MLIRFunction::CallOp(std::vector<mlir::Value> inputs, MLIRFunction *computation) {
auto builder = module_->builder();
builder->setInsertionPointToEnd(&func_->getBody().back());
auto call_op = builder->create<mlir::func::CallOp>(builder->getUnknownLoc(), *computation->function(), mlir::ValueRange(inputs));

return call_op.getResult(0);
mlir::Operation::result_range results = call_op.getResults();
return std::vector<mlir::Value>(results.begin(), results.end());
}

std::vector<mlir::Value> MLIRFunction::WhileOp(MLIRFunction *pred, MLIRFunction *body_function, std::vector<mlir::Value> initial) {
Expand Down
2 changes: 1 addition & 1 deletion exla/c_src/exla/mlir/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class MLIRFunction {
ERL_NIF_TERM ConstantOp(mlir::Type type, ErlNifEnv *env, ERL_NIF_TERM value_ptr, std::optional<std::vector<int64_t>> dims = std::nullopt);
mlir::Value InfeedOp(mlir::Value token, xla::Shape *shape);
mlir::Value OutfeedOp(std::vector<mlir::Value> inputs, mlir::Value token);
mlir::Value CallOp(std::vector<mlir::Value> inputs, MLIRFunction *computation);
std::vector<mlir::Value> CallOp(std::vector<mlir::Value> inputs, MLIRFunction *computation);
std::vector<mlir::Value> WhileOp(MLIRFunction *pred, MLIRFunction *body, std::vector<mlir::Value> initial);
std::vector<mlir::Value> ReturnOp(std::vector<mlir::Value> values);
int get_mlir_type(ErlNifEnv *env, ERL_NIF_TERM term, mlir::Type *type);
Expand Down
24 changes: 2 additions & 22 deletions exla/c_src/exla/mlir/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -673,26 +673,6 @@ ERL_NIF_TERM mlir_select(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
return exla::nif::ok(env, exla::nif::make<mlir::Value>(env, res));
}

ERL_NIF_TERM mlir_build(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 2) {
return exla::nif::error(env, "Bad argument count.");
}

exla::MLIRFunction** function;
mlir::Value* root;

if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
return exla::nif::error(env, "Unable to get function.");
}
if (!exla::nif::get<mlir::Value>(env, argv[1], root)) {
return exla::nif::error(env, "Unable to get root.");
}

(*function)->Build(*root);

return exla::nif::ok(env);
}

ERL_NIF_TERM mlir_convert(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 3) {
return exla::nif::error(env, "Bad argument count.");
Expand Down Expand Up @@ -1529,9 +1509,9 @@ ERL_NIF_TERM mlir_call(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
return exla::nif::error(env, "Unable to get computation.");
}

mlir::Value result = (*function)->CallOp(arguments, *computation);
std::vector<mlir::Value> result = (*function)->CallOp(arguments, *computation);

return exla::nif::ok(env, exla::nif::make<mlir::Value>(env, result));
return exla::nif::ok(env, exla::nif::make_list<mlir::Value>(env, result));
}

ERL_NIF_TERM mlir_while(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
Expand Down
1 change: 0 additions & 1 deletion exla/c_src/exla/mlir/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ DEFINE_NIF(mlir_constant_r0);
DEFINE_NIF(mlir_constant_from_binary);
DEFINE_NIF(mlir_dot_general);
DEFINE_NIF(mlir_select);
DEFINE_NIF(mlir_build);
DEFINE_NIF(mlir_convert);
DEFINE_NIF(mlir_top_k);
DEFINE_NIF(mlir_sort);
Expand Down
54 changes: 15 additions & 39 deletions exla/lib/exla/builder.ex
Original file line number Diff line number Diff line change
Expand Up @@ -10,44 +10,32 @@ defmodule EXLA.Builder do
@enforce_keys [:ref]
defstruct [:ref, :parent, :name]

def new(name, inputs, outputs, type, sub? \\ false, variadic_return? \\ false)

def new(name, _inputs, _outputs, :xla, _sub?, _variadic_return?) do
new(name)
end

def new(module_and_name, inputs, outputs, :mlir, sub?, variadic_return?) do
# TO-DO(mlir): this module shouldn't have to know about Nx
{_arg_names, arg_shapes} = Enum.unzip(inputs)

def new_mlir(module_and_name, arg_shapes, return_shape) do
{module, name, is_public} =
case module_and_name do
{%M{} = module, name} -> {module, name, false}
_name -> {M.new(), "main", true}
end

return_shape =
if sub? do
exla_shape(outputs, false)
else
out_types = [outputs] |> Nx.Defn.Composite.flatten_list()

if variadic_return? do
exla_shape(out_types, true)
else
out_types |> List.to_tuple() |> exla_shape(false)
end
end

M.create_function(
module,
name,
exla_shape(arg_shapes, false),
List.wrap(return_shape),
arg_shapes,
return_shape,
is_public
)
end

def new(name) when is_binary(name) do
{:ok, ref} = EXLA.NIF.new_builder(name)
%__MODULE__{ref: ref, parent: nil, name: name}
end

def new(builder = %__MODULE__{ref: ref}, name) when is_binary(name) do
{:ok, ref} = EXLA.NIF.create_sub_builder(ref, name)
%__MODULE__{ref: ref, parent: builder, name: name}
end

def exla_shape(tensors, flatten_tuple) when is_list(tensors) do
result = Enum.map(tensors, &exla_shape(&1, flatten_tuple))

Expand Down Expand Up @@ -93,16 +81,6 @@ defmodule EXLA.Builder do
shape
end

defp new(name) when is_binary(name) do
{:ok, ref} = EXLA.NIF.new_builder(name)
%__MODULE__{ref: ref, parent: nil, name: name}
end

def new(builder = %__MODULE__{ref: ref}, name) when is_binary(name) do
{:ok, ref} = EXLA.NIF.create_sub_builder(ref, name)
%__MODULE__{ref: ref, parent: builder, name: name}
end

def build(root)

def build(%Op{} = root) do
Expand All @@ -111,10 +89,8 @@ defmodule EXLA.Builder do
%Computation{ref: ref, output_shape: shape}
end

def build(%EXLA.MLIR.Value{function: function, ref: root_ref}) do
%EXLA.MLIR.Function{ref: function_ref} = function

:ok = EXLA.NIF.mlir_build(function_ref, root_ref)
def build(%EXLA.MLIR.Value{function: function} = value) do
EXLA.MLIR.Value.variadic_return([value])
function
end
end
4 changes: 1 addition & 3 deletions exla/lib/exla/computation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,11 @@ defmodule EXLA.Computation do
end

def compile(
%EXLA.MLIR.Function{module: module, return_shape: [return_shape]},
%EXLA.MLIR.Function{module: module, return_shape: return_shape},
client,
arg_shapes,
opts
) do
assert_output_shape!(%{output_shape: return_shape})

EXLA.MLIR.Module.compile(
module,
client,
Expand Down
Loading

0 comments on commit 321c711

Please sign in to comment.