Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MLIR create token op #1335

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,7 @@ static ErlNifFunc exla_funcs[] = {
{"mlir_pad", 6, mlir_pad},
{"mlir_fft", 4, mlir_fft},
{"mlir_convolution", 12, mlir_convolution},
{"mlir_create_token", 1, mlir_create_token},
// XlaBuilder
{"new_builder", 1, new_builder},
{"create_sub_builder", 2, create_sub_builder},
Expand Down
6 changes: 6 additions & 0 deletions exla/c_src/exla/mlir/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1032,4 +1032,10 @@ mlir::Value MLIRFunction::ConvOp(
nullptr);
}

mlir::Value MLIRFunction::CreateTokenOp() {
auto builder = module_->builder();
builder->setInsertionPointToEnd(&func_->getBody().back());
return builder->create<mlir::mhlo::CreateTokenOp>(builder->getUnknownLoc());
}

} // namespace exla
1 change: 1 addition & 0 deletions exla/c_src/exla/mlir/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ class MLIRFunction {
mlir::Value SelectAndScatterOp(mlir::Value target, mlir::Value source, mlir::Value init_value, bool gt_or_lt, std::vector<int64_t> window_dimensions, std::vector<int64_t> window_strides, std::vector<int64_t> padding);
mlir::Value FFTOp(mlir::Value tensor, bool forward_fft, std::vector<int64_t> fft_lenght);
mlir::Value ConvOp(mlir::Value tensor, mlir::Value kernel, std::vector<int64_t> window_strides, std::vector<int64_t> padding, std::vector<int64_t> tensor_dilation, std::vector<int64_t> kernel_dilation, xla::ConvolutionDimensionNumbers dimension_numbers, uint64_t feature_group_count, uint64_t batch_group_count, uint64_t precision_config, std::vector<int64_t> output_dims);
mlir::Value CreateTokenOp();
ERL_NIF_TERM ConstantOp(mlir::Type type, ErlNifEnv *env, ERL_NIF_TERM value_ptr, std::vector<int64_t> dims = {});
int get_mlir_type(ErlNifEnv *env, ERL_NIF_TERM term, mlir::Type *type);

Expand Down
16 changes: 16 additions & 0 deletions exla/c_src/exla/mlir/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1147,4 +1147,20 @@ ERL_NIF_TERM mlir_convolution(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[
output_dims);

return exla::nif::ok(env, exla::nif::make<mlir::Value>(env, res));
}

ERL_NIF_TERM mlir_create_token(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 1) {
return exla::nif::error(env, "Bad argument count.");
}

exla::MLIRFunction** function;

if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
return exla::nif::error(env, "Unable to get function.");
}

mlir::Value token = (*function)->CreateTokenOp();

return exla::nif::ok(env, exla::nif::make<mlir::Value>(env, token));
}
1 change: 1 addition & 0 deletions exla/c_src/exla/mlir/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,4 @@ ERL_NIF_TERM dump_mlir_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[
ERL_NIF_TERM mlir_scatter(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
ERL_NIF_TERM mlir_select_and_scatter(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
ERL_NIF_TERM mlir_fft(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
ERL_NIF_TERM mlir_create_token(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]);
2 changes: 2 additions & 0 deletions exla/lib/exla/nif.ex
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ defmodule EXLA.NIF do
),
do: :erlang.nif_error(:undef)

def mlir_create_token(_function), do: :erlang.nif_error(:undef)

def new_builder(_name),
do: :erlang.nif_error(:undef)

Expand Down