Skip to content

Commit

Permalink
fix: use singleton mlir context (#1454)
Browse files Browse the repository at this point in the history
Co-authored-by: José Valim <[email protected]>
  • Loading branch information
polvalente and josevalim authored Feb 25, 2024
1 parent bd346ab commit 7b087da
Show file tree
Hide file tree
Showing 21 changed files with 214 additions and 161 deletions.
9 changes: 7 additions & 2 deletions exla/c_src/exla/exla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,20 @@ static int open_resources(ErlNifEnv* env) {
if (!exla::nif::open_resource<exla::ExlaBuffer*>(env, mod, "ExlaBuffer", free_exla_buffer)) {
return -1;
}
// MLIR
if (!exla::nif::open_resource<exla::MLIRFunction*>(env, mod, "MLIRBlock")) {
return -1;
}
if (!exla::nif::open_resource<mlir::Value>(env, mod, "MLIRValue")) {
return -1;
}
// MLIR
if (!exla::nif::open_resource<exla::MLIRModule*>(env, mod, "ExlaMLIRModule")) {
return -1;
}

if (!exla::nif::open_resource<mlir::MLIRContext*>(env, mod, "MLIRContext")) {
return -1;
}
return 1;
}

Expand Down Expand Up @@ -670,7 +674,8 @@ ERL_NIF_TERM start_log_sink(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])

static ErlNifFunc exla_funcs[] = {
// MLIR Builder
{"new_mlir_module", 0, new_mlir_module},
{"new_mlir_context", 0, new_mlir_context},
{"new_mlir_module", 1, new_mlir_module},
{"create_mlir_function", 5, create_mlir_function},
{"get_mlir_function_arguments", 1, get_mlir_function_arguments},
{"mlir_add", 3, mlir_add},
Expand Down
14 changes: 4 additions & 10 deletions exla/c_src/exla/mlir/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1115,16 +1115,10 @@ ERL_NIF_TERM MLIRFunction::ConstantOp(mlir::Type type, ErlNifEnv *env, ERL_NIF_T
return exla::nif::error(env, "invalid type received");
}

MLIRModule::MLIRModule() {
context_ = std::make_unique<mlir::MLIRContext>();

context_->loadDialect<mlir::func::FuncDialect>();
context_->loadDialect<mlir::stablehlo::StablehloDialect>();
context_->loadDialect<mlir::mhlo::MhloDialect>();
context_->loadDialect<mlir::chlo::ChloDialect>();

module_ = mlir::OwningOpRef<mlir::ModuleOp>(mlir::ModuleOp::create(mlir::UnknownLoc::get(context_.get())));
builder_ = std::make_unique<mlir::OpBuilder>(context_.get());
MLIRModule::MLIRModule(mlir::MLIRContext *context) {
context_ = context;
module_ = mlir::OwningOpRef<mlir::ModuleOp>(mlir::ModuleOp::create(mlir::UnknownLoc::get(context_)));
builder_ = std::make_unique<mlir::OpBuilder>(context_);
builder_->setInsertionPointToStart(module_->getBody());
}

Expand Down
10 changes: 6 additions & 4 deletions exla/c_src/exla/mlir/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class MLIRFunction {

class MLIRModule {
public:
MLIRModule();
MLIRModule(mlir::MLIRContext *context);

MLIRFunction *CreateFunction(
std::string name,
Expand All @@ -151,16 +151,18 @@ class MLIRModule {

mlir::ModuleOp module() { return module_.get(); }
mlir::OpBuilder *builder() { return builder_.get(); }
mlir::MLIRContext *context() { return context_.get(); }
mlir::MLIRContext *context() { return context_; }

void LowerPatterns();

private:
std::unique_ptr<mlir::MLIRContext> context_;
mlir::MLIRContext *context_;
mlir::OwningOpRef<mlir::ModuleOp> module_;
std::unique_ptr<mlir::OpBuilder> builder_;
};

mlir::Type TypeIntToMLIRType(mlir::OpBuilder *builder, xla::PrimitiveType type_int);
mlir::Type
TypeIntToMLIRType(mlir::OpBuilder *builder, xla::PrimitiveType type_int);

xla::PrimitiveType MLIRTypeToPrimitiveType(mlir::Type);
} // namespace exla
Expand Down
31 changes: 27 additions & 4 deletions exla/c_src/exla/mlir/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@

#include "../exla_client.h"
#include "../exla_nif_util.h"
#include "mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "stablehlo/dialect/ChloOps.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/shape_util.h"

// MLIR Builder Functions
Expand Down Expand Up @@ -45,8 +49,6 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
return exla::nif::error(env, "Unable to get device ID.");
}

(*module)->LowerPatterns();

build_options.set_num_replicas(num_replicas);
build_options.set_num_partitions(num_partitions);
build_options.set_use_spmd_partitioning(use_spmd);
Expand All @@ -63,12 +65,33 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
return exla::nif::ok(env, exla::nif::make<exla::ExlaExecutable*>(env, executable));
}

ERL_NIF_TERM new_mlir_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
ERL_NIF_TERM new_mlir_context(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 0) {
return exla::nif::error(env, "Bad argument count.");
}

exla::MLIRModule* module = new exla::MLIRModule();
mlir::MLIRContext* context = new mlir::MLIRContext();
context->getOrLoadDialect<mlir::func::FuncDialect>();
context->getOrLoadDialect<mlir::stablehlo::StablehloDialect>();
context->getOrLoadDialect<mlir::mhlo::MhloDialect>();
context->getOrLoadDialect<mlir::chlo::ChloDialect>();

auto ret = exla::nif::make<mlir::MLIRContext*>(env, context);
return exla::nif::ok(env, ret);
}

ERL_NIF_TERM new_mlir_module(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
if (argc != 1) {
return exla::nif::error(env, "Bad argument count.");
}

mlir::MLIRContext** ctx;

if (!exla::nif::get<mlir::MLIRContext*>(env, argv[0], ctx)) {
return exla::nif::error(env, "Unable to get context.");
}

exla::MLIRModule* module = new exla::MLIRModule(*ctx);

return exla::nif::ok(env, exla::nif::make<exla::MLIRModule*>(env, module));
}
Expand Down
1 change: 1 addition & 0 deletions exla/c_src/exla/mlir/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

DEFINE_NIF(mlir_compile);
DEFINE_NIF(new_mlir_module);
DEFINE_NIF(new_mlir_context);
DEFINE_NIF(create_mlir_function);
DEFINE_NIF(get_mlir_function_arguments);
DEFINE_NIF(mlir_tuple);
Expand Down
5 changes: 5 additions & 0 deletions exla/lib/exla/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ defmodule EXLA.Application do

children = [
EXLA.Logger,
{NimblePool,
worker: {EXLA.MLIR.ContextPool, :pool_state},
pool_size: System.schedulers_online(),
name: EXLA.MLIR.ContextPool,
lazy: true},
EXLA.Client,
EXLA.Defn.Lock,
EXLA.Defn.LockedCache,
Expand Down
17 changes: 0 additions & 17 deletions exla/lib/exla/builder.ex
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,10 @@ defmodule EXLA.Builder do

alias EXLA.Computation
alias EXLA.Op
alias EXLA.MLIR.Module, as: M

@enforce_keys [:ref]
defstruct [:ref, :parent, :name]

def new_mlir(module_and_name, arg_shapes, return_shape) do
{module, name, is_public} =
case module_and_name do
{%M{} = module, name} -> {module, name, false}
_name -> {M.new(), "main", true}
end

M.create_function(
module,
name,
arg_shapes,
return_shape,
is_public
)
end

def new(name) when is_binary(name) do
{:ok, ref} = EXLA.NIF.new_builder(name)
%__MODULE__{ref: ref, parent: nil, name: name}
Expand Down
83 changes: 41 additions & 42 deletions exla/lib/exla/defn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ defmodule EXLA.Defn do
%{module: module, name: name} = subbuilder(builder, "while-pred")
out_types = container_to_exla_shape(expr)

pred_fun = EXLA.Builder.new_mlir({module, name}, arg_shapes, out_types)
pred_fun = EXLA.MLIR.Module.add_function(module, name, arg_shapes, out_types)

[flag | _] = EXLA.MLIR.Function.get_arguments(pred_fun)

Expand All @@ -188,7 +188,7 @@ defmodule EXLA.Defn do

%{module: module, name: name} = subbuilder(builder, "while-body")

body_fun = EXLA.Builder.new_mlir({module, name}, arg_shapes, out_types)
body_fun = EXLA.MLIR.Module.add_function(module, name, arg_shapes, out_types)

[_flag, token | args] = EXLA.MLIR.Function.get_arguments(body_fun)

Expand Down Expand Up @@ -583,45 +583,44 @@ defmodule EXLA.Defn do

mode = options[:compiler_mode] || Application.get_env(:exla, :compiler_mode, :mlir)

{mod, builder} =
{mod, compile_fn} =
case mode do
:xla ->
{EXLA.Op, EXLA.Builder.new(inspect(key))}

:mlir ->
comp_arg_shapes =
for {i, shape} <- inputs_and_shapes, i >= used_buffers, do: shape

out_types =
[outputs]
|> Nx.Defn.Composite.flatten_list()
|> Enum.map(fn t ->
t
|> Nx.devectorize()
|> then(&EXLA.Shape.make_shape(&1.type, &1.shape))
end)

{Value, EXLA.Builder.new_mlir(inspect(key), comp_arg_shapes, out_types)}
:xla -> {EXLA.Op, fn _, _, fun -> fun.(EXLA.Builder.new(inspect(key))) end}
:mlir -> {Value, &EXLA.MLIR.Module.new/3}
end

outfeed =
outfeed
|> Outfeed.with_token(mod.create_token(builder))
|> Outfeed.add_infeeds(builder, reverse_infeeds)
comp_arg_shapes =
for {i, shape} <- inputs_and_shapes, i >= used_buffers, do: shape

expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1)
out_types =
[outputs]
|> Nx.Defn.Composite.flatten_list()
|> Enum.map(fn t ->
t
|> Nx.devectorize()
|> then(&EXLA.Shape.make_shape(&1.type, &1.shape))
end)

{computation, extra, outfeed} =
to_computation.(builder, expr, inputs_and_shapes, outfeed)
compile_fn.(comp_arg_shapes, out_types, fn builder ->
outfeed =
outfeed
|> Outfeed.with_token(mod.create_token(builder))
|> Outfeed.add_infeeds(builder, reverse_infeeds)

{xla_time, executable} =
:timer.tc(fn ->
shapes = for {i, shape} <- inputs_and_shapes, i >= used_buffers, do: shape
expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1)

EXLA.Computation.compile(computation, client, shapes, options)
end)
{computation, extra, outfeed} =
to_computation.(builder, expr, inputs_and_shapes, outfeed)

{xla_time, executable} =
:timer.tc(fn ->
shapes = for {i, shape} <- inputs_and_shapes, i >= used_buffers, do: shape

EXLA.Computation.compile(computation, client, shapes, options)
end)

{:ok, {xla_time, executable, extra, %{outfeed | infeeds: []}}}
{:ok, {xla_time, executable, extra, %{outfeed | infeeds: []}}}
end)
end)
end)

Expand Down Expand Up @@ -2050,11 +2049,9 @@ defmodule EXLA.Defn do
arg_shapes = Enum.map(args, &EXLA.Shape.make_shape(&1.type, &1.shape))

function =
EXLA.Builder.new_mlir(
{module, name},
arg_shapes,
[EXLA.Shape.make_shape({:pred, 8}, {})]
)
EXLA.MLIR.Module.add_function(module, name, arg_shapes, [
EXLA.Shape.make_shape({:pred, 8}, {})
])

[lhs, rhs | _] = EXLA.MLIR.Function.get_arguments(function)

Expand Down Expand Up @@ -2113,7 +2110,7 @@ defmodule EXLA.Defn do
) do
%{module: module, name: name} = subbuilder(builder, Atom.to_string(op))

function = EXLA.Builder.new_mlir({module, name}, arg_shapes, out)
function = EXLA.MLIR.Module.add_function(module, name, arg_shapes, out)

args = EXLA.MLIR.Function.get_arguments(function)

Expand Down Expand Up @@ -2143,7 +2140,7 @@ defmodule EXLA.Defn do

out_type = container_to_exla_shape(expr)

function = EXLA.Builder.new_mlir({module, Atom.to_string(name)}, arg_shapes, out_type)
function = EXLA.MLIR.Module.add_function(module, Atom.to_string(name), arg_shapes, out_type)
mlir_args = EXLA.MLIR.Function.get_arguments(function)

arg_params = Enum.zip(args, mlir_args)
Expand Down Expand Up @@ -2199,7 +2196,7 @@ defmodule EXLA.Defn do
out_types
end

function = EXLA.Builder.new_mlir({module, name}, arg_shapes, out_types)
function = EXLA.MLIR.Module.add_function(module, name, arg_shapes, out_types)

[arg_token | arg_params] = EXLA.MLIR.Function.get_arguments(function)

Expand Down Expand Up @@ -2268,7 +2265,9 @@ defmodule EXLA.Defn do
out_shapes = container_to_exla_shape(expr)

function =
EXLA.Builder.new_mlir({module, name}, [token_shape | arg_shapes], [token_shape | out_shapes])
EXLA.MLIR.Module.add_function(module, name, [token_shape | arg_shapes], [
token_shape | out_shapes
])

[arg_token | tail] = EXLA.MLIR.Function.get_arguments(function)

Expand Down
5 changes: 3 additions & 2 deletions exla/lib/exla/lib.ex
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,9 @@ defmodule EXLA.Lib do
%{module: module, name: name} = subbuilder(builder, "min-max")

function =
EXLA.Builder.new_mlir(
{module, name},
EXLA.MLIR.Module.add_function(
module,
name,
[
EXLA.Shape.make_shape(type, {}),
EXLA.Shape.make_shape(index_type, {}),
Expand Down
37 changes: 37 additions & 0 deletions exla/lib/exla/mlir/context_pool.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
defmodule EXLA.MLIR.ContextPool do
@moduledoc false
# Internal pool for MLIRContext reference management
@behaviour NimblePool

def checkout(fun) when is_function(fun, 1) do
NimblePool.checkout!(
__MODULE__,
:checkout,
fn _pool, context -> {fun.(context), :ok} end,
:infinity
)
end

@impl NimblePool
def init_worker(pool_state) do
{:ok, context} = EXLA.NIF.new_mlir_context()
{:ok, context, pool_state}
end

@impl NimblePool
def handle_checkout(:checkout, _from, context, pool_state) do
{:ok, context, context, pool_state}
end

@impl NimblePool
def handle_checkin(:ok, _from, context, pool_state) do
# We just keep the references around and let them die out upon worker termination/GC
{:ok, context, pool_state}
end

@impl NimblePool
def terminate_worker(_reason, _context, pool_state) do
# GC will clean it up
{:ok, pool_state}
end
end
Loading

0 comments on commit 7b087da

Please sign in to comment.