Skip to content

Commit

Permalink
refactor(mlir): use if regions (#1453)
Browse files Browse the repository at this point in the history
  • Loading branch information
polvalente authored Feb 23, 2024
1 parent 53ce7be commit bd346ab
Show file tree
Hide file tree
Showing 10 changed files with 278 additions and 303 deletions.
4 changes: 3 additions & 1 deletion exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -768,7 +768,9 @@ static ErlNifFunc exla_funcs[] = {
{"mlir_reduce", 5, mlir_reduce},
{"mlir_window_reduce", 9, mlir_window_reduce},
{"mlir_map", 4, mlir_map},
{"mlir_if", 6, mlir_if},
{"mlir_if", 3, mlir_if},
{"mlir_set_if_block", 3, mlir_set_if_block},
{"mlir_pop_region", 1, mlir_pop_region},
{"mlir_while", 4, mlir_while},
{"mlir_return", 2, mlir_return},
// XlaBuilder
Expand Down
289 changes: 133 additions & 156 deletions exla/c_src/exla/mlir/builder.cc

Large diffs are not rendered by default.

11 changes: 9 additions & 2 deletions exla/c_src/exla/mlir/builder.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
#ifndef EXLA_MLIR_BUILDER_H_
#define EXLA_MLIR_BUILDER_H_

#include <stack>

#include "../exla_nif_util.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/reference/Types.h"
#include "xla/shape.h"
#include "xla/types.h"
Expand Down Expand Up @@ -109,14 +112,16 @@ class MLIRFunction {
std::vector<mlir::Value> ReduceOp(MLIRFunction *function, std::vector<mlir::Value> init_values, std::vector<mlir::Value> inputs, std::vector<int64_t> dimensions);
std::vector<mlir::Value> WindowReduceOp(MLIRFunction *function, std::vector<mlir::Value> init_values, std::vector<mlir::Value> inputs, std::vector<int64_t> window_dimensions, std::vector<int64_t> window_strides, std::vector<int64_t> input_dilations, std::vector<int64_t> window_dilations, std::vector<std::pair<int64_t, int64_t>> padding);
mlir::Value MapOp(MLIRFunction *function, std::vector<mlir::Value> inputs, std::vector<int64_t> dimensions);
std::vector<mlir::Value> IfOp(mlir::Value pred, std::vector<xla::Shape> output_shape, std::vector<mlir::Value> implicit_args, MLIRFunction *on_true, MLIRFunction *on_false);
std::vector<mlir::Value> IfOp(mlir::Value pred, std::vector<xla::Shape> output_shape);
void SetIfOpBlock(mlir::Value node, bool true_or_false_branch);
ERL_NIF_TERM ConstantOp(mlir::Type type, ErlNifEnv *env, ERL_NIF_TERM value_ptr, std::optional<std::vector<int64_t>> dims = std::nullopt);
mlir::Value InfeedOp(mlir::Value token, xla::Shape *shape);
mlir::Value OutfeedOp(std::vector<mlir::Value> inputs, mlir::Value token);
std::vector<mlir::Value> CallOp(std::vector<mlir::Value> inputs, MLIRFunction *computation);
std::vector<mlir::Value> WhileOp(MLIRFunction *pred, MLIRFunction *body, std::vector<mlir::Value> initial);
std::vector<mlir::Value> ReturnOp(std::vector<mlir::Value> values);
int get_mlir_type(ErlNifEnv *env, ERL_NIF_TERM term, mlir::Type *type);
void PopRegion();

void Build(mlir::Value root);

Expand All @@ -128,7 +133,10 @@ class MLIRFunction {
std::shared_ptr<MLIRModule> module_;
std::unique_ptr<mlir::func::FuncOp> func_;

std::stack<mlir::Region *> regions;

void dump_mlir_module();
void setInsertionPoint();
};

class MLIRModule {
Expand All @@ -145,7 +153,6 @@ class MLIRModule {
mlir::OpBuilder *builder() { return builder_.get(); }
mlir::MLIRContext *context() { return context_.get(); }
void LowerPatterns();
void RemoveEmptyFunctions();

private:
std::unique_ptr<mlir::MLIRContext> context_;
Expand Down
50 changes: 37 additions & 13 deletions exla/c_src/exla/mlir/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
}

(*module)->LowerPatterns();
(*module)->RemoveEmptyFunctions();

build_options.set_num_replicas(num_replicas);
build_options.set_num_partitions(num_partitions);
Expand Down Expand Up @@ -865,15 +864,12 @@ ERL_NIF_TERM mlir_map(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
}

ERL_NIF_TERM mlir_if(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 6) {
if (argc != 3) {
return exla::nif::error(env, "Bad argument count.");
}

exla::MLIRFunction** function;
mlir::Value* pred;
std::vector<mlir::Value> implicit_args;
exla::MLIRFunction** on_true;
exla::MLIRFunction** on_false;
std::vector<xla::Shape> output_shapes;

if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
Expand All @@ -885,19 +881,47 @@ ERL_NIF_TERM mlir_if(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (!exla::nif::get_list<xla::Shape>(env, argv[2], output_shapes)) {
return exla::nif::error(env, "Unable to get output shapes.");
}
if (!exla::nif::get_list<mlir::Value>(env, argv[3], implicit_args)) {
return exla::nif::error(env, "Unable to get implicit_args.");

std::vector<mlir::Value> result = (*function)->IfOp(*pred, output_shapes);
return exla::nif::ok(env, exla::nif::make_list<mlir::Value>(env, result));
}

ERL_NIF_TERM mlir_set_if_block(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 3) {
return exla::nif::error(env, "Bad argument count.");
}

exla::MLIRFunction** function;
mlir::Value* node;
bool true_or_false_branch;

if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
return exla::nif::error(env, "Unable to get function.");
}
if (!exla::nif::get<mlir::Value>(env, argv[1], node)) {
return exla::nif::error(env, "Unable to get node.");
}
if (!exla::nif::get<exla::MLIRFunction*>(env, argv[4], on_true)) {
return exla::nif::error(env, "Unable to get on_true.");
if (!exla::nif::get(env, argv[2], &true_or_false_branch)) {
return exla::nif::error(env, "Unable to get true_or_false_branch.");
}
if (!exla::nif::get<exla::MLIRFunction*>(env, argv[5], on_false)) {
return exla::nif::error(env, "Unable to get on_false.");

(*function)->SetIfOpBlock(*node, true_or_false_branch);
return exla::nif::ok(env);
}

ERL_NIF_TERM mlir_pop_region(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 1) {
return exla::nif::error(env, "Bad argument count.");
}

std::vector<mlir::Value> result = (*function)->IfOp(*pred, output_shapes, implicit_args, *on_true, *on_false);
exla::MLIRFunction** function;

return exla::nif::ok(env, exla::nif::make_list<mlir::Value>(env, result));
if (!exla::nif::get<exla::MLIRFunction*>(env, argv[0], function)) {
return exla::nif::error(env, "Unable to get function.");
}

(*function)->PopRegion();
return exla::nif::ok(env);
}

ERL_NIF_TERM mlir_bitcast_convert(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
Expand Down
2 changes: 2 additions & 0 deletions exla/c_src/exla/mlir/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ DEFINE_NIF(mlir_reduce);
DEFINE_NIF(mlir_window_reduce);
DEFINE_NIF(mlir_map);
DEFINE_NIF(mlir_if);
DEFINE_NIF(mlir_set_if_block);
DEFINE_NIF(mlir_pop_region);
DEFINE_NIF(mlir_infeed);
DEFINE_NIF(mlir_outfeed);
DEFINE_NIF(mlir_call);
Expand Down
Loading

0 comments on commit bd346ab

Please sign in to comment.