From fba9efe1b85cfafc1514df7246a300d82b7e351f Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Tue, 7 May 2024 23:08:46 -0300 Subject: [PATCH] fix: get_tuple NIF helper + compilation configs and speedup (#1479) --- exla/Makefile | 2 +- exla/README.md | 2 ++ exla/c_src/exla/exla_nif_util.cc | 6 +++--- exla/mix.exs | 6 +++++- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/exla/Makefile b/exla/Makefile index 8efb4c1fc0..bd1b1da6ae 100644 --- a/exla/Makefile +++ b/exla/Makefile @@ -73,7 +73,7 @@ ifeq ($(NVCC_TEST),nvcc) NVCC := nvcc NVCCFLAGS += -DCUDA_ENABLED else - NVCC := g++ + NVCC := $(CXX) NVCCFLAGS = $(CFLAGS) endif diff --git a/exla/README.md b/exla/README.md index 302a829ec7..3091555796 100644 --- a/exla/README.md +++ b/exla/README.md @@ -59,6 +59,8 @@ mix deps.get mix test ``` +By default, EXLA passes `["-jN"]` as a Make argument, where `N` is `System.schedulers_online() - 2`, capped at `1`. `config :exla, :make_args, ...` can be used to override this default setting. + In order to run tests on a specific device, use the `EXLA_TARGET` environment variable, which is a dev-only variable for this project (it has no effect when using EXLA as a dependency). For example, `EXLA_TARGET=cuda` or `EXLA_TARGET=rocm`. Make sure to also specify `XLA_TARGET` to fetch or compile a proper version of the XLA binary. ### Building with Docker diff --git a/exla/c_src/exla/exla_nif_util.cc b/exla/c_src/exla/exla_nif_util.cc index 563a4cd8eb..d38785f6ed 100644 --- a/exla/c_src/exla/exla_nif_util.cc +++ b/exla/c_src/exla/exla_nif_util.cc @@ -1,10 +1,10 @@ #include "exla_nif_util.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "stablehlo/dialect/StablehloOps.h" #include "xla/primitive_util.h" #include "xla/shape_util.h" -#include "mlir/IR/Builders.h" -#include "stablehlo/dialect/StablehloOps.h" namespace exla { namespace nif { @@ -190,7 +190,7 @@ int get_tuple(ErlNifEnv* env, ERL_NIF_TERM tuple, std::vector& var) { var.reserve(length); for (int i = 0; i < length; i++) { - int data; + int64 data; if (!get(env, terms[i], &data)) return 0; var.push_back(data); } diff --git a/exla/mix.exs b/exla/mix.exs index 517e051772..f0b4291e43 100644 --- a/exla/mix.exs +++ b/exla/mix.exs @@ -5,6 +5,9 @@ defmodule EXLA.MixProject do @version "0.7.1" def project do + make_args = + Application.get_env(:exla, :make_args) || ["-j#{max(System.schedulers_online() - 2, 1)}"] + [ app: :exla, version: @version, @@ -34,7 +37,8 @@ defmodule EXLA.MixProject do "MIX_BUILD_EMBEDDED" => "#{Mix.Project.config()[:build_embedded]}", "CWD_RELATIVE_TO_PRIV_PATH" => cwd_relative_to_priv } - end + end, + make_args: make_args ] end