Skip to content

Commit

Permalink
refactor: remove tuples from EXLA (#1463)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Mar 5, 2024
1 parent 087845b commit bb056ce
Showing 13 changed files with 99 additions and 177 deletions.
2 changes: 0 additions & 2 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
@@ -603,8 +603,6 @@ static ErlNifFunc exla_funcs[] = {
{"get_mlir_function_arguments", 1, get_mlir_function_arguments},
{"mlir_add", 3, mlir_add},
{"mlir_subtract", 3, mlir_subtract},
{"mlir_tuple", 2, mlir_tuple},
{"mlir_get_tuple_element", 3, mlir_get_tuple_element},
{"mlir_multiply", 3, mlir_multiply},
{"mlir_min", 3, mlir_min},
{"mlir_max", 3, mlir_max},
44 changes: 0 additions & 44 deletions exla/c_src/exla/mlir/ops.cc
Original file line number Diff line number Diff line change
@@ -153,50 +153,6 @@ ERL_NIF_TERM get_mlir_function_arguments(ErlNifEnv* env, int argc, const ERL_NIF
return exla::nif::ok(env, enif_make_list_from_array(env, terms.data(), terms.size()));
}

ERL_NIF_TERM mlir_tuple(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 2) {
return exla::nif::error(env, "Bad argument count.");
}

exla::MLIRFunction** function;
std::vector<mlir::Value> vals;

if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
return exla::nif::error(env, "Unable to get function.");
}
if (!exla::nif::get_list<mlir::Value>(env, argv[1], vals)) {
return exla::nif::error(env, "Unable to get values.");
}

mlir::Value res = (*function)->TupleOp(vals);

return exla::nif::ok(env, exla::nif::make<mlir::Value>(env, res));
}

ERL_NIF_TERM mlir_get_tuple_element(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 3) {
return exla::nif::error(env, "Bad argument count.");
}

exla::MLIRFunction** function;
mlir::Value* tuple;
exla::int64 index;

if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
return exla::nif::error(env, "Unable to get function.");
}
if (!exla::nif::get<mlir::Value>(env, argv[1], tuple)) {
return exla::nif::error(env, "Unable to get tuple.");
}
if (!exla::nif::get(env, argv[2], &index)) {
return exla::nif::error(env, "Unable to get index.");
}

mlir::Value res = (*function)->GetTupleElementOp(*tuple, index);

return exla::nif::ok(env, exla::nif::make<mlir::Value>(env, res));
}

ERL_NIF_TERM mlir_binary_op(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[], std::function<mlir::Value(exla::MLIRFunction*, mlir::Value*, mlir::Value*)> op) {
if (argc != 3) {
return exla::nif::error(env, "Bad argument count.");
2 changes: 0 additions & 2 deletions exla/c_src/exla/mlir/ops.h
Original file line number Diff line number Diff line change
@@ -9,8 +9,6 @@ DEFINE_NIF(new_mlir_module);
DEFINE_NIF(new_mlir_context);
DEFINE_NIF(create_mlir_function);
DEFINE_NIF(get_mlir_function_arguments);
DEFINE_NIF(mlir_tuple);
DEFINE_NIF(mlir_get_tuple_element);

// Binary Ops
DEFINE_NIF(mlir_add);
Loading

0 comments on commit bb056ce

Please sign in to comment.