From 483f7ab72726083b59c50b754754c0df77acc9b5 Mon Sep 17 00:00:00 2001 From: Paulo Valente <16843419+polvalente@users.noreply.github.com> Date: Wed, 8 Nov 2023 14:56:20 -0300 Subject: [PATCH 1/3] wip: torchx put slice --- torchx/c_src/torchx.cpp | 487 +++++++++++++---------------------- torchx/lib/torchx.ex | 2 +- torchx/lib/torchx/backend.ex | 88 +------ 3 files changed, 177 insertions(+), 400 deletions(-) diff --git a/torchx/c_src/torchx.cpp b/torchx/c_src/torchx.cpp index 6ad2c18d34..958f5919a5 100644 --- a/torchx/c_src/torchx.cpp +++ b/torchx/c_src/torchx.cpp @@ -1,28 +1,25 @@ #include #if defined(USING_TORCH_V1) - #include +#include #else - #include +#include #endif -#include #include +#include #include "nx_nif_utils.hpp" std::map dtypes = {{"byte", torch::kByte}, {"char", torch::kChar}, {"short", torch::kShort}, {"int", torch::kInt}, {"long", torch::kLong}, {"half", torch::kHalf}, {"brain", torch::kBFloat16}, {"float", torch::kFloat}, {"double", torch::kDouble}, {"bool", torch::kBool}, {"complex", at::ScalarType::ComplexFloat}, {"complex_double", at::ScalarType::ComplexDouble}}; std::map dtype_sizes = {{"byte", 1}, {"char", 1}, {"short", 2}, {"int", 4}, {"long", 8}, {"half", 2}, {"brain", 2}, {"float", 4}, {"double", 8}, {"complex", 8}, {"complex_double", 16}}; -inline torch::ScalarType string2type(const std::string &atom) -{ +inline torch::ScalarType string2type(const std::string &atom) { return dtypes[atom]; } -inline const std::string *type2string(const torch::ScalarType type) -{ - for (std::map::iterator i = dtypes.begin(); i != dtypes.end(); ++i) - { +inline const std::string *type2string(const torch::ScalarType type) { + for (std::map::iterator i = dtypes.begin(); i != dtypes.end(); ++i) { if (i->second == type) return &i->first; } @@ -30,14 +27,11 @@ inline const std::string *type2string(const torch::ScalarType type) } // the class instance to manage the refcount of Tensor -class TensorP -{ -public: - TensorP(ErlNifEnv *env, const ERL_NIF_TERM arg) : ptr(nullptr) - { +class TensorP { + public: + TensorP(ErlNifEnv *env, const ERL_NIF_TERM arg) : ptr(nullptr) { // setup - if (!enif_get_resource(env, arg, TENSOR_TYPE, (void **)&ptr)) - { + if (!enif_get_resource(env, arg, TENSOR_TYPE, (void **)&ptr)) { err = nx::nif::error(env, "Unable to get tensor param in NIF"); return; } @@ -45,62 +39,50 @@ class TensorP refcount = (std::atomic *)(ptr + 1); deleted = (std::atomic_flag *)(refcount + 1); - if (refcount->load() == 0) - { + if (refcount->load() == 0) { // already deallocated ptr = nullptr; err = nx::nif::error(env, "Tensor has been deallocated"); return; } - if (is_valid()) - { + if (is_valid()) { // increase reference count ++(*refcount); } } - ~TensorP() - { - if (is_valid()) - { + ~TensorP() { + if (is_valid()) { // decrease reference count - if (refcount->fetch_sub(1) == 0) - { + if (refcount->fetch_sub(1) == 0) { ptr->~Tensor(); } } } - bool deallocate() - { - if (is_valid() && atomic_flag_test_and_set(deleted) == false) - { + bool deallocate() { + if (is_valid() && atomic_flag_test_and_set(deleted) == false) { --(*refcount); return true; - } - else - { + } else { return false; } } - torch::Tensor *data() const - { + torch::Tensor *data() const { return ptr; } - bool is_valid() const - { + bool is_valid() const { return ptr != nullptr; } - ERL_NIF_TERM error() - { + ERL_NIF_TERM error() { return err; } -private: + private: torch::Tensor *ptr; std::atomic *refcount; std::atomic_flag *deleted; @@ -109,26 +91,21 @@ class TensorP #define NIF(NAME) ERL_NIF_TERM NAME(ErlNifEnv 
*env, int argc, const ERL_NIF_TERM argv[]) -#define SCALAR_PARAM(ARGN, VAR) \ - torch::Scalar VAR; \ - VAR.~Scalar(); \ - double double_##VAR; \ - std::vector complex_##VAR; \ - if (nx::nif::get_tuple(env, argv[ARGN], complex_##VAR)) \ - { \ - new (&VAR) torch::Scalar(c10::complex( \ - complex_##VAR[0], \ - complex_##VAR[1])); \ - } \ - else if (enif_get_double(env, argv[ARGN], &double_##VAR) == 0) \ - { \ - int64_t int64_##VAR; \ - enif_get_int64(env, argv[ARGN], (ErlNifSInt64 *)&int64_##VAR); \ - new (&VAR) torch::Scalar(int64_##VAR); \ - } \ - else \ - { \ - new (&VAR) torch::Scalar(double_##VAR); \ +#define SCALAR_PARAM(ARGN, VAR) \ + torch::Scalar VAR; \ + VAR.~Scalar(); \ + double double_##VAR; \ + std::vector complex_##VAR; \ + if (nx::nif::get_tuple(env, argv[ARGN], complex_##VAR)) { \ + new (&VAR) torch::Scalar(c10::complex( \ + complex_##VAR[0], \ + complex_##VAR[1])); \ + } else if (enif_get_double(env, argv[ARGN], &double_##VAR) == 0) { \ + int64_t int64_##VAR; \ + enif_get_int64(env, argv[ARGN], (ErlNifSInt64 *)&int64_##VAR); \ + new (&VAR) torch::Scalar(int64_##VAR); \ + } else { \ + new (&VAR) torch::Scalar(double_##VAR); \ } #define SHAPE_PARAM(ARGN, VAR) TUPLE_PARAM(ARGN, std::vector, VAR) @@ -146,26 +123,21 @@ class TensorP #define TENSOR_PARAM(ARGN, VAR) \ TensorP VAR##_tp(env, argv[ARGN]); \ torch::Tensor *VAR; \ - if (!VAR##_tp.is_valid()) \ - { \ + if (!VAR##_tp.is_valid()) { \ return VAR##_tp.error(); \ - } \ - else \ - { \ + } else { \ VAR = VAR##_tp.data(); \ } #define CATCH() \ - catch (c10::Error & error) \ - { \ + catch (c10::Error & error) { \ std::ostringstream msg; \ msg << error.msg() << " in NIF." << __func__ << "/" << argc; \ return nx::nif::error(env, msg.str().c_str()); \ } #define SCALAR(S) \ - try \ - { \ + try { \ if (c10::isFloatingType(S.type())) \ return nx::nif::ok(env, nx::nif::make(env, S.toDouble())); \ else \ @@ -174,15 +146,13 @@ class TensorP CATCH() #define TENSOR(T) \ - try \ - { \ + try { \ return nx::nif::ok(env, create_tensor_resource(env, T)); \ } \ CATCH() #define TENSOR_LIST(TL) \ - try \ - { \ + try { \ const std::vector &tl = TL; \ std::vector res_list; \ for (torch::Tensor t : tl) \ @@ -192,8 +162,7 @@ class TensorP CATCH() #define TENSOR_TUPLE(TT) \ - try \ - { \ + try { \ const std::tuple &tt = TT; \ std::vector res_list; \ for (torch::Tensor t : {std::get<0>(tt), std::get<1>(tt)}) \ @@ -203,8 +172,7 @@ class TensorP CATCH() #define TENSOR_TUPLE_3(TT) \ - try \ - { \ + try { \ const std::tuple &tt = TT; \ std::vector res_list; \ for (torch::Tensor t : {std::get<0>(tt), std::get<1>(tt), std::get<2>(tt)}) \ @@ -214,8 +182,7 @@ class TensorP CATCH() ERL_NIF_TERM -create_tensor_resource(ErlNifEnv *env, torch::Tensor tensor) -{ +create_tensor_resource(ErlNifEnv *env, torch::Tensor tensor) { ERL_NIF_TERM ret; torch::Tensor *tensorPtr; std::atomic *refcount; @@ -234,20 +201,17 @@ create_tensor_resource(ErlNifEnv *env, torch::Tensor tensor) return ret; } -NIF(delete_tensor) -{ +NIF(delete_tensor) { TensorP tensor(env, argv[0]); return tensor.deallocate() ? 
nx::nif::ok(env) : enif_make_badarg(env); } -uint64_t elem_count(std::vector shape) -{ +uint64_t elem_count(std::vector shape) { return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>{}); } -NIF(from_blob) -{ +NIF(from_blob) { BINARY_PARAM(0, blob); SHAPE_PARAM(1, shape); TYPE_PARAM(2, type); @@ -258,18 +222,14 @@ NIF(from_blob) auto tensor = torch::from_blob(blob.data, shape, torch::device(torch::kCPU).dtype(type)); - if (DEVICE(device).device().type() == torch::kCPU) - { + if (DEVICE(device).device().type() == torch::kCPU) { TENSOR(tensor.clone()); - } - else - { + } else { TENSOR(tensor.to(DEVICE(device))); } } -NIF(to_blob) -{ +NIF(to_blob) { ERL_NIF_TERM result; TENSOR_PARAM(0, t); size_t byte_size = t->nbytes(); @@ -277,8 +237,7 @@ NIF(to_blob) bool has_received_limit = (argc == 2); - if (has_received_limit) - { + if (has_received_limit) { PARAM(1, int64_t, param_limit); limit = param_limit; byte_size = limit * t->itemsize(); @@ -293,20 +252,15 @@ NIF(to_blob) torch::Tensor reshaped = (has_received_limit && byte_size < t->nbytes()) ? t->flatten().slice(0, 0, limit) : t->flatten(); void *data_ptr = reshaped.data_ptr(); - if (device.has_value() && device.value().type() == torch::kCPU && data_ptr == t->data_ptr()) - { + if (device.has_value() && device.value().type() == torch::kCPU && data_ptr == t->data_ptr()) { // case where we own the data_ptr and the data is in the CPU already return nx::nif::ok(env, enif_make_resource_binary(env, t, data_ptr, byte_size)); - } - else if (device.has_value() && device.value().type() == torch::kCPU) - { + } else if (device.has_value() && device.value().type() == torch::kCPU) { // case where we don't own the data_ptr but the data is in the CPU already void *result_data = (void *)enif_make_new_binary(env, byte_size, &result); memcpy(result_data, data_ptr, byte_size); return nx::nif::ok(env, result); - } - else - { + } else { // case where the data isn't in the CPU, therefore we don't own the data_ptr void *result_data = (void *)enif_make_new_binary(env, byte_size, &result); memcpy(result_data, reshaped.to(torch::kCPU).data_ptr(), byte_size); @@ -314,15 +268,13 @@ NIF(to_blob) } } -NIF(item) -{ +NIF(item) { TENSOR_PARAM(0, t); SCALAR(t->item()); } -NIF(scalar_type) -{ +NIF(scalar_type) { TENSOR_PARAM(0, t); const std::string *type_name = type2string(t->scalar_type()); @@ -333,8 +285,7 @@ NIF(scalar_type) return nx::nif::error(env, "Could not determine tensor type."); } -NIF(shape) -{ +NIF(shape) { TENSOR_PARAM(0, t); std::vector sizes; @@ -344,8 +295,7 @@ NIF(shape) return nx::nif::ok(env, enif_make_tuple_from_array(env, sizes.data(), sizes.size())); } -NIF(mps_is_available) -{ +NIF(mps_is_available) { #ifdef MAC_ARM64 bool has_mps = at::hasMPS(); #else @@ -354,78 +304,66 @@ NIF(mps_is_available) return nx::nif::make(env, has_mps); } -NIF(cuda_is_available) -{ +NIF(cuda_is_available) { return nx::nif::make(env, (bool)torch::cuda::is_available()); } -NIF(cuda_device_count) -{ +NIF(cuda_device_count) { return nx::nif::make(env, (int)torch::cuda::device_count()); } -NIF(nbytes) -{ +NIF(nbytes) { TENSOR_PARAM(0, t); return nx::nif::ok(env, enif_make_int64(env, t->nbytes())); } -NIF(split) -{ +NIF(split) { TENSOR_PARAM(0, t); PARAM(1, int64_t, batch_size); TENSOR_LIST(torch::split(*t, batch_size)); } -NIF(reshape) -{ +NIF(reshape) { TENSOR_PARAM(0, t); SHAPE_PARAM(1, shape); TENSOR(torch::reshape(*t, shape)); } -NIF(to_type) -{ +NIF(to_type) { TENSOR_PARAM(0, t); TYPE_PARAM(1, type); TENSOR(t->toType(type)); } -NIF(to_device) -{ 
+NIF(to_device) { TENSOR_PARAM(0, t); DEVICE_PARAM(1, device); TENSOR(t->to(DEVICE(device))); } -NIF(squeeze) -{ +NIF(squeeze) { TENSOR_PARAM(0, t); - if (argc == 2) - { + if (argc == 2) { PARAM(1, int64_t, dim); TENSOR(torch::squeeze(*t, dim)); - } - else + } else TENSOR(torch::squeeze(*t)); } -NIF(broadcast_to) -{ +NIF(broadcast_to) { TENSOR_PARAM(0, t); SHAPE_PARAM(1, shape); TENSOR(torch::broadcast_to(*t, shape).clone()); } -NIF(transpose) -{ +NIF(transpose) { TENSOR_PARAM(0, t); PARAM(1, int64_t, dim0); PARAM(2, int64_t, dim1); @@ -433,8 +371,7 @@ NIF(transpose) TENSOR(torch::transpose(*t, dim0, dim1)); } -NIF(narrow) -{ +NIF(narrow) { TENSOR_PARAM(0, t); PARAM(1, int64_t, dim); PARAM(2, int64_t, start); @@ -443,8 +380,7 @@ NIF(narrow) TENSOR(torch::narrow(*t, dim, start, length).clone()); } -NIF(as_strided) -{ +NIF(as_strided) { TENSOR_PARAM(0, t); SHAPE_PARAM(1, size); LIST_PARAM(2, std::vector, strides); @@ -453,8 +389,7 @@ NIF(as_strided) TENSOR(torch::as_strided(*t, size, strides, offset).clone()); } -NIF(concatenate) -{ +NIF(concatenate) { LIST_PARAM(0, std::vector, tensors); PARAM(1, int64_t, axis); @@ -462,8 +397,7 @@ NIF(concatenate) TENSOR(torch::cat(tensors, axis)); } -NIF(gather) -{ +NIF(gather) { TENSOR_PARAM(0, input); TENSOR_PARAM(1, indices); PARAM(2, int64_t, axis); @@ -478,7 +412,7 @@ NIF(index_put) { PARAM(3, bool, accumulate); c10::List> convertedList; - for (const torch::Tensor& tensor : indices) { + for (const torch::Tensor &tensor : indices) { convertedList.push_back(tensor); } @@ -490,15 +424,14 @@ NIF(index) { LIST_PARAM(1, std::vector, indices); c10::List> convertedList; - for (const torch::Tensor& tensor : indices) { + for (const torch::Tensor &tensor : indices) { convertedList.push_back(tensor); } TENSOR(torch::index(*input, convertedList)); } -NIF(argsort) -{ +NIF(argsort) { TENSOR_PARAM(0, input); PARAM(1, bool, stable); PARAM(2, int64_t, axis); @@ -507,24 +440,21 @@ NIF(argsort) TENSOR(torch::argsort(*input, stable, axis, is_descending)); } -NIF(top_k) -{ +NIF(top_k) { TENSOR_PARAM(0, input); PARAM(1, int64_t, k); TENSOR_TUPLE(at::topk(*input, k)); } -NIF(flip) -{ +NIF(flip) { TENSOR_PARAM(0, input); LIST_PARAM(1, std::vector, dims); TENSOR(torch::flip(*input, dims)); } -NIF(unfold) -{ +NIF(unfold) { TENSOR_PARAM(0, input); PARAM(1, int64_t, dim); PARAM(2, int64_t, size); @@ -533,17 +463,29 @@ NIF(unfold) TENSOR(at::native::unfold(*input, dim, size, step)); } -NIF(put) -{ +NIF(put) { TENSOR_PARAM(0, input); - TENSOR_PARAM(1, index); + LIST_PARAM(1, std::vector, indices); TENSOR_PARAM(2, source); - TENSOR(at::put(*input, *index, *source)); + torch::Tensor destination = input->clone(); + + auto source_shape = source->sizes(); + + size_t dim = 0; + for (dim = 0; dim < indices.size() - 1; dim++) { + auto start = indices[dim]; + // arguments are dimension, start index and NON-INCLUSIVE end index + destination = destination.slice(dim, start, start + source_shape[dim]); + } + + auto start = indices[dim]; + destination.slice(dim, start, start + source_shape[dim]) = *source; + + TENSOR(destination); } -NIF(permute) -{ +NIF(permute) { TENSOR_PARAM(0, t); LIST_PARAM(1, std::vector, dims); @@ -552,8 +494,7 @@ NIF(permute) /* Creation */ -NIF(scalar_tensor) -{ +NIF(scalar_tensor) { SCALAR_PARAM(0, scalar); TYPE_PARAM(1, type); DEVICE_PARAM(2, device); @@ -561,8 +502,7 @@ NIF(scalar_tensor) TENSOR(torch::scalar_tensor(scalar, OPTS(type, device))); } -NIF(randint) -{ +NIF(randint) { PARAM(0, int64_t, min); PARAM(1, int64_t, max); SHAPE_PARAM(2, shape); @@ -572,8 +512,7 
@@ NIF(randint) TENSOR(torch::randint(min, max, shape, OPTS(type, device))); } -NIF(rand) -{ +NIF(rand) { PARAM(0, double, min); PARAM(1, double, max); SHAPE_PARAM(2, shape); @@ -583,8 +522,7 @@ NIF(rand) TENSOR(min + torch::rand(shape, OPTS(type, device)) * (max - min)); } -NIF(normal) -{ +NIF(normal) { PARAM(0, double, mean); PARAM(1, double, std); SHAPE_PARAM(2, shape); @@ -594,27 +532,22 @@ NIF(normal) TENSOR(torch::normal(mean, std, shape, c10::nullopt, OPTS(type, device))); } -NIF(arange) -{ +NIF(arange) { PARAM(0, int64_t, start); PARAM(1, int64_t, end); PARAM(2, int64_t, step); TYPE_PARAM(3, type); DEVICE_PARAM(4, device); - if (argc == 6) - { + if (argc == 6) { SHAPE_PARAM(5, shape); TENSOR(torch::reshape(torch::arange((double)start, (double)end, (double)step, OPTS(type, device)), shape)); - } - else - { + } else { TENSOR(torch::arange((double)start, (double)end, (double)step, OPTS(type, device))); } } -NIF(ones) -{ +NIF(ones) { SHAPE_PARAM(0, shape); TYPE_PARAM(1, type); DEVICE_PARAM(2, device); @@ -622,8 +555,7 @@ NIF(ones) TENSOR(torch::ones(shape, OPTS(type, device))); } -NIF(eye) -{ +NIF(eye) { PARAM(0, int64_t, m); PARAM(1, int64_t, n); TYPE_PARAM(2, type); @@ -632,8 +564,7 @@ NIF(eye) TENSOR(torch::eye(m, n, OPTS(type, device))); } -NIF(full) -{ +NIF(full) { SHAPE_PARAM(0, shape); SCALAR_PARAM(1, scalar); TYPE_PARAM(2, type); @@ -647,8 +578,7 @@ NIF(full) #define BINARY_OP(OP) BINARY_OP2(OP, OP) #define BINARY_OP2(OP, NATIVE_OP) \ - NIF(OP) \ - { \ + NIF(OP) { \ TENSOR_PARAM(0, a); \ TENSOR_PARAM(1, b); \ \ @@ -656,8 +586,7 @@ NIF(full) } #define BINARY_OPB(OP) \ - NIF(OP) \ - { \ + NIF(OP) { \ TENSOR_PARAM(0, a); \ TENSOR_PARAM(1, b); \ \ @@ -667,8 +596,7 @@ NIF(full) #define UNARY_OP(OP) UNARY_OP2(OP, OP) #define UNARY_OP2(OP, NATIVE) \ - NIF(OP) \ - { \ + NIF(OP) { \ TENSOR_PARAM(0, a); \ TENSOR(torch::NATIVE(*a)); \ } @@ -701,22 +629,19 @@ BINARY_OP(atan2) BINARY_OP(min) BINARY_OP(max) -NIF(fmod) -{ +NIF(fmod) { TENSOR_PARAM(0, a); TENSOR_PARAM(1, b); TENSOR(at::fmod(*a, *b)); } -NIF(quotient) -{ +NIF(quotient) { TENSOR_PARAM(0, a); TENSOR_PARAM(1, b); TENSOR(torch::divide(*a, *b, "trunc")); } -NIF(tensordot) -{ +NIF(tensordot) { TENSOR_PARAM(0, t1); TENSOR_PARAM(1, t2); LIST_PARAM(2, std::vector, axes1); @@ -728,37 +653,31 @@ NIF(tensordot) torch::Tensor result; - if (is_batched) - { + if (is_batched) { // if any of the tensors is batched, we need to apply some transformations // on the inputs and on the result to wrap the batched APIs that torch exposes std::vector batch_dims1, batch_dims2; int64_t vmap_level = 0; - for (auto dim : batch_axes1) - { + for (auto dim : batch_axes1) { batch_dims1.push_back(at::BatchDim(vmap_level++, dim)); } torch::Tensor batched_1 = at::makeBatched(*t1, at::BatchDims(batch_dims1.begin(), batch_dims1.end())); vmap_level = 0; - for (auto dim : batch_axes2) - { + for (auto dim : batch_axes2) { batch_dims2.push_back(at::BatchDim(vmap_level++, dim)); } torch::Tensor batched_2 = at::makeBatched(*t2, at::BatchDims(batch_dims2.begin(), batch_dims2.end())); torch::Tensor batched_result = torch::tensordot(batched_1, batched_2, axes1, axes2); auto impl = at::maybeGetBatchedImpl(batched_result); - if (!impl) - { + if (!impl) { return nx::nif::error(env, "unable to get tensordot result"); } result = torch::clone(impl->value()); - } - else - { + } else { result = torch::tensordot(*t1, *t2, axes1, axes2); } @@ -799,29 +718,25 @@ UNARY_OP(erf) UNARY_OP(erfc) UNARY_OP2(erf_inv, erfinv) -NIF(view_as_real) -{ +NIF(view_as_real) { TENSOR_PARAM(0, tensor); 
TENSOR(torch::view_as_real(*tensor)); } -NIF(conjugate) -{ +NIF(conjugate) { TENSOR_PARAM(0, tensor); at::Tensor conjugated = tensor->conj(); TENSOR(conjugated.clone(conjugated.suggest_memory_format())); } -NIF(triangular_solve) -{ +NIF(triangular_solve) { TENSOR_PARAM(0, a); TENSOR_PARAM(1, b); PARAM(2, bool, transpose); PARAM(3, bool, upper); auto ts_a = *a; - if (transpose) - { + if (transpose) { auto num_dims = a->dim(); ts_a = torch::transpose(*a, num_dims - 2, num_dims - 1); upper = !upper; @@ -832,15 +747,13 @@ NIF(triangular_solve) TENSOR(result); } -NIF(determinant) -{ +NIF(determinant) { TENSOR_PARAM(0, t); TENSOR(t->det()); } -NIF(sort) -{ +NIF(sort) { TENSOR_PARAM(0, t); PARAM(1, bool, stable); PARAM(2, int64_t, axis); @@ -850,8 +763,7 @@ NIF(sort) TENSOR(std::get<0>(result)); } -NIF(clip) -{ +NIF(clip) { TENSOR_PARAM(0, t); TENSOR_PARAM(1, min); TENSOR_PARAM(2, max); @@ -859,8 +771,7 @@ NIF(clip) TENSOR(torch::clip(*t, *min, *max)); } -NIF(where) -{ +NIF(where) { TENSOR_PARAM(0, pred); TENSOR_PARAM(1, on_true); TENSOR_PARAM(2, on_false); @@ -870,8 +781,7 @@ NIF(where) /* Aggregates */ -NIF(sum) -{ +NIF(sum) { TENSOR_PARAM(0, t); LIST_PARAM(1, std::vector, dims); PARAM(2, bool, keep_dim); @@ -879,12 +789,10 @@ NIF(sum) TENSOR(torch::sum(*t, dims, keep_dim)); } -NIF(product) -{ +NIF(product) { TENSOR_PARAM(0, t); - if (argc == 1) - { + if (argc == 1) { TENSOR(torch::prod(*t)); } @@ -894,48 +802,36 @@ NIF(product) TENSOR(torch::prod(*t, dim, keep_dim)); } -NIF(argmax) -{ +NIF(argmax) { TENSOR_PARAM(0, t); PARAM(1, int64_t, dim); PARAM(2, bool, keep_dim); - if (dim == -1) - { + if (dim == -1) { TENSOR(torch::argmax(*t)); - } - else - { + } else { TENSOR(torch::argmax(*t, dim, keep_dim)); } } -NIF(argmin) -{ +NIF(argmin) { TENSOR_PARAM(0, t); PARAM(1, int64_t, dim); PARAM(2, bool, keep_dim); - if (dim == -1) - { + if (dim == -1) { TENSOR(torch::argmin(*t)); - } - else - { + } else { TENSOR(torch::argmin(*t, dim, keep_dim)); } } -NIF(cbrt) -{ +NIF(cbrt) { TENSOR_PARAM(0, tensor); - if (tensor->scalar_type() == torch::kDouble) - { + if (tensor->scalar_type() == torch::kDouble) { TENSOR(torch::pow(*tensor, 1.0 / 3)); - } - else - { + } else { TENSOR(torch::pow(*tensor, 1.0f / 3)); } } @@ -967,30 +863,24 @@ NIF(ifft2) { TENSOR(torch::fft::ifft2(*tensor, lengths, axes)); } -NIF(is_nan) -{ +NIF(is_nan) { TENSOR_PARAM(0, tensor); TENSOR(torch::isnan(*tensor)); } -NIF(is_infinity) -{ +NIF(is_infinity) { TENSOR_PARAM(0, tensor); TENSOR(torch::isinf(*tensor)); } -NIF(all) -{ +NIF(all) { TENSOR_PARAM(0, t); - if (argc == 1) - { + if (argc == 1) { TENSOR(torch::all(*t)); - } - else - { + } else { PARAM(1, int64_t, axis); PARAM(2, bool, keep_dim); @@ -998,16 +888,12 @@ NIF(all) } } -NIF(any) -{ +NIF(any) { TENSOR_PARAM(0, t); - if (argc == 1) - { + if (argc == 1) { TENSOR(torch::any(*t)); - } - else - { + } else { PARAM(1, int64_t, axis); PARAM(2, bool, keep_dim); @@ -1015,8 +901,7 @@ NIF(any) } } -NIF(all_close) -{ +NIF(all_close) { TENSOR_PARAM(0, a); TENSOR_PARAM(1, b); PARAM(2, double, rtol); @@ -1029,24 +914,21 @@ NIF(all_close) TENSOR(torch::scalar_tensor(all_close, init_opts)); } -NIF(cumulative_sum) -{ +NIF(cumulative_sum) { TENSOR_PARAM(0, t); PARAM(1, int64_t, axis); TENSOR(torch::cumsum(*t, axis)); } -NIF(cumulative_product) -{ +NIF(cumulative_product) { TENSOR_PARAM(0, t); PARAM(1, int64_t, axis); TENSOR(torch::cumprod(*t, axis)); } -NIF(cumulative_min) -{ +NIF(cumulative_min) { TENSOR_PARAM(0, t); PARAM(1, int64_t, axis); @@ -1054,8 +936,7 @@ NIF(cumulative_min) 
TENSOR(std::get<0>(tt)); } -NIF(cumulative_max) -{ +NIF(cumulative_max) { TENSOR_PARAM(0, t); PARAM(1, int64_t, axis); @@ -1063,26 +944,22 @@ NIF(cumulative_max) TENSOR(std::get<0>(tt)); } -NIF(cholesky) -{ +NIF(cholesky) { TENSOR_PARAM(0, t); bool upper = false; - if (argc == 2) - { + if (argc == 2) { GET(1, upper); } - if (upper) - { + if (upper) { TENSOR(torch::linalg::cholesky(*t).mH()); } TENSOR(torch::linalg::cholesky(*t)); } -NIF(pad) -{ +NIF(pad) { TENSOR_PARAM(0, tensor); TENSOR_PARAM(1, constant); LIST_PARAM(2, std::vector, config); @@ -1092,34 +969,29 @@ NIF(pad) /* Transformations */ -NIF(qr) -{ +NIF(qr) { TENSOR_PARAM(0, t); bool reduced = true; - if (argc == 2) - { + if (argc == 2) { GET(1, reduced); } TENSOR_TUPLE(torch::linalg_qr(*t, reduced ? "reduced" : "complete")); } -NIF(svd) -{ +NIF(svd) { TENSOR_PARAM(0, t); bool full_matrices = true; - if (argc == 2) - { + if (argc == 2) { GET(1, full_matrices); } TENSOR_TUPLE_3(torch::linalg_svd(*t, full_matrices)); } -NIF(lu) -{ +NIF(lu) { TENSOR_PARAM(0, t); std::tuple lu_result = torch::linalg::lu_factor(*t); @@ -1128,8 +1000,7 @@ NIF(lu) TENSOR_TUPLE_3(plu); } -NIF(amax) -{ +NIF(amax) { TENSOR_PARAM(0, tensor); LIST_PARAM(1, std::vector, axes); PARAM(2, bool, keep_axes); @@ -1137,8 +1008,7 @@ NIF(amax) TENSOR(at::amax(*tensor, axes, keep_axes)); } -NIF(amin) -{ +NIF(amin) { TENSOR_PARAM(0, tensor); LIST_PARAM(1, std::vector, axes); PARAM(2, bool, keep_axes); @@ -1146,23 +1016,20 @@ NIF(amin) TENSOR(at::amin(*tensor, axes, keep_axes)); } -NIF(eigh) -{ +NIF(eigh) { TENSOR_PARAM(0, tensor); TENSOR_TUPLE(torch::linalg_eigh(*tensor)); } -NIF(solve) -{ +NIF(solve) { TENSOR_PARAM(0, tensorA); TENSOR_PARAM(1, tensorB); TENSOR(torch::linalg_solve(*tensorA, *tensorB)); } -NIF(conv) -{ +NIF(conv) { TENSOR_PARAM(0, tensor); TENSOR_PARAM(1, kernel); @@ -1184,8 +1051,7 @@ NIF(conv) stride, padding, dilation, transposed, output_padding, groups)); } -NIF(max_pool_3d) -{ +NIF(max_pool_3d) { TENSOR_PARAM(0, tensor); LIST_PARAM(1, std::vector, kernel_size); LIST_PARAM(2, std::vector, strides); @@ -1195,14 +1061,12 @@ NIF(max_pool_3d) TENSOR(at::max_pool3d(*tensor, kernel_size, strides, padding, dilation)); } -void free_tensor(ErlNifEnv *env, void *obj) -{ +void free_tensor(ErlNifEnv *env, void *obj) { torch::Tensor *tensor = reinterpret_cast(obj); std::atomic *refcount = reinterpret_cast *>(tensor + 1); std::atomic_flag *deleted = reinterpret_cast(refcount + 1); - if (atomic_flag_test_and_set(deleted) == false) - { + if (atomic_flag_test_and_set(deleted) == false) { tensor->~Tensor(); } @@ -1211,8 +1075,7 @@ void free_tensor(ErlNifEnv *env, void *obj) } static int -open_resource_type(ErlNifEnv *env) -{ +open_resource_type(ErlNifEnv *env) { const char *name = "Tensor"; ErlNifResourceFlags flags = (ErlNifResourceFlags)(ERL_NIF_RT_CREATE | ERL_NIF_RT_TAKEOVER); @@ -1222,8 +1085,7 @@ open_resource_type(ErlNifEnv *env) return 0; } -int upgrade(ErlNifEnv *env, void **priv_data, void **old_priv_data, ERL_NIF_TERM load_info) -{ +int upgrade(ErlNifEnv *env, void **priv_data, void **old_priv_data, ERL_NIF_TERM load_info) { // Silence "unused var" warnings. 
   (void)(env);
   (void)(priv_data);
@@ -1233,8 +1095,7 @@ int upgrade(ErlNifEnv *env, void **priv_data, void **old_priv_data, ERL_NIF_TERM
   return 0;
 }
 
-int load(ErlNifEnv *env, void **priv_data, ERL_NIF_TERM load_info)
-{
+int load(ErlNifEnv *env, void **priv_data, ERL_NIF_TERM load_info) {
   if (open_resource_type(env) == -1)
     return -1;
 
diff --git a/torchx/lib/torchx.ex b/torchx/lib/torchx.ex
index 17c721d89d..59e53bcc7c 100644
--- a/torchx/lib/torchx.ex
+++ b/torchx/lib/torchx.ex
@@ -248,7 +248,7 @@ defmodule Torchx do
   deftensor argsort(tensor, axis, is_descending, stable)
   deftensor flip(tensor, axis)
   deftensor unfold(tensor, dimension, size, step)
-  deftensor put(tensor_input, tensor_index, tensor_source)
+  deftensor put(tensor_input, index, tensor_source)
   deftensor where(tensorA, tensorB, tensorC)
 
   ## Aggregation
diff --git a/torchx/lib/torchx/backend.ex b/torchx/lib/torchx/backend.ex
index 0d06086f30..8f877f0440 100644
--- a/torchx/lib/torchx/backend.ex
+++ b/torchx/lib/torchx/backend.ex
@@ -339,7 +339,7 @@ defmodule Torchx.Backend do
 
   @impl true
   def put_slice(out, input, start_indices_unbounded, slice) do
-    {device, _} = input_tx = from_nx(input)
+    input_tx = from_nx(input)
 
     slice_shape_list = Tuple.to_list(slice.shape)
 
@@ -351,30 +351,11 @@ defmodule Torchx.Backend do
         min(max(idx, 0), dim_size - len)
       end)
 
-    range_or_ranges =
-      [start_indices, slice_shape_list]
-      |> Enum.zip_with(fn [s, l] -> s..(s + l - 1)//1 end)
-      |> Enum.reverse()
-      |> Enum.reduce(fn range, acc -> for x <- range, y <- acc, do: List.flatten([x, y]) end)
-
-    # if below is needed for when the reduce receives a single-element list
-    linear_indices_tx =
-      if is_list(range_or_ranges) do
-        range_or_ranges
-        |> Nx.tensor(backend: {__MODULE__, device: device})
-        |> then(&as_torchx_linear_indices(input.shape, &1))
-      else
-        range_or_ranges
-        |> Enum.to_list()
-        |> Nx.tensor(backend: {__MODULE__, device: device})
-        |> Torchx.from_nx()
-      end
-
     slice_tx = slice |> from_nx() |> Torchx.to_type(to_torch_type(out.type))
 
     input_tx
     |> Torchx.to_type(to_torch_type(out.type))
-    |> Torchx.put(linear_indices_tx, slice_tx)
+    |> Torchx.put(start_indices, slice_tx)
    |> to_nx(out)
   end
 
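With this change, `put_slice/4` hands `Torchx.put/3` the clamped start indices and the slice tensor directly, and the NIF writes the slice through a chain of views instead of scattering by precomputed linear indices. Below is a minimal standalone libtorch sketch of that slicing strategy; the tensor values, shapes, and start offsets are made up for illustration and are not taken from the patch.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // Hypothetical 4x4 destination, 2x2 source block, start offsets {1, 2}.
      torch::Tensor input = torch::zeros({4, 4});
      torch::Tensor source = torch::ones({2, 2});
      std::vector<int64_t> start = {1, 2};

      // Clone so the caller's tensor is left untouched, as the NIF does.
      torch::Tensor output = input.clone();
      torch::Tensor view = output;  // same underlying storage as `output`

      // Narrow the view along every axis except the last one; slice() returns
      // a view, so writes through it land in `output`'s storage.
      for (size_t dim = 0; dim + 1 < start.size(); dim++) {
        view = view.slice(dim, start[dim], start[dim] + source.size(dim));
      }

      // Assigning a tensor to the temporary returned by slice() copies the
      // data into that region (libtorch's rvalue Tensor assignment calls copy_).
      size_t last = start.size() - 1;
      view.slice(last, start[last], start[last] + source.size(last)) = source;

      std::cout << output << std::endl;  // zeros with a 2x2 block of ones at (1, 2)
      return 0;
    }

The clamping done above in `put_slice/4` guarantees that each start index plus the slice extent fits inside the corresponding dimension, so the sliced regions in the sketch are always in bounds.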
@@ -534,71 +515,6 @@ defmodule Torchx.Backend do
     |> to_nx(out)
   end
 
-  defp as_torchx_linear_indices(shape, idx) do
-    # Nx provides indices as a tensor of shape {*, input_dims}
-    # However, torch expects indices to be a tensor of indices along a given axis.
-    # As such, we need to convert the indices tensor to linear indices.
-    # See the `linear_indices_offsets` function for an explanation on the offsets calculation.
-
-    # Index limit validation
-    ndims = tuple_size(shape)
-    flattened_idx = Nx.reshape(idx, {div(Nx.size(idx), ndims), ndims})
-    shape_tensor = shape |> Tuple.to_list() |> Nx.tensor()
-
-    upper_clamped_idx =
-      flattened_idx
-      |> Nx.greater_equal(shape_tensor)
-      |> Nx.select(Nx.subtract(shape_tensor, 1), flattened_idx)
-
-    lower_clamp_selector = Nx.less(upper_clamped_idx, 0)
-
-    fully_clamped_idx =
-      lower_clamp_selector |> Nx.select(0, upper_clamped_idx) |> Nx.reshape(idx.shape)
-
-    # Actual conversion algorithm
-    linear_indices_offsets =
-      shape
-      |> linear_indices_offsets()
-      |> from_nx()
-
-    lin_idx_num_elements =
-      idx.shape |> Tuple.delete_at(tuple_size(idx.shape) - 1) |> Tuple.product()
-
-    fully_clamped_idx
-    |> from_nx()
-    |> Torchx.tensordot(linear_indices_offsets, [tuple_size(idx.shape) - 1], [0])
-    |> Torchx.reshape({lin_idx_num_elements})
-  end
-
-  defp linear_indices_offsets(shape) do
-    # The offsets tensor calculated below follows a formula in which we
-    # multiply the index along each axis by the number of elements contained in all following axes
-    # For example, for a {3, 5, 7, 2} tensor, the offsets tensor is [70, 14, 2, 1]
-    # This offsets tensor is then applied to the indices tensor through matrix multiplication:
-    # indices = [[0, 2, 1, 0], [0, 0, 0, 1], [1, 4, 3, 2]]
-    # offsets = [70, 14, 2, 1]
-    # linear_indices = [14 * 2 + 2 * 1, 1 * 1, 70 * 1 + 14 * 4 + 2 * 3 + 1 * 2] = [30, 1, 134]
-    # By linear indices, we refer to the indices of a row-major representation of a tensor
-    # it's easy to see the expected values using Nx.iota(tensor), which will output a tensor
-    # which counts in exactly the same way, when provided no arguments. In effect, Nx.iota outputs
-    # the corresponding linear indices for a given tensor shape.
-
-    {offsets_list, _} =
-      shape
-      |> Tuple.to_list()
-      |> Enum.reverse()
-      |> Enum.reduce({[], 1}, fn x, {acc, multiplier} ->
-        {[multiplier | acc], multiplier * x}
-      end)
-
-    Nx.tensor(offsets_list, backend: __MODULE__)
-  end
-
   @impl true
   def take_along_axis(out, tensor, idx, axis) do
     idx_tx = idx |> from_nx() |> Torchx.to_type(:long)

From 53f6d2af917fce22e447bc19df397bf90f2108af Mon Sep 17 00:00:00 2001
From: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Date: Wed, 8 Nov 2023 15:00:30 -0300
Subject: [PATCH 2/3] fix: return outer ref

---
 torchx/c_src/torchx.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/torchx/c_src/torchx.cpp b/torchx/c_src/torchx.cpp
index 958f5919a5..1fe4ec389f 100644
--- a/torchx/c_src/torchx.cpp
+++ b/torchx/c_src/torchx.cpp
@@ -468,7 +468,8 @@ NIF(put) {
   LIST_PARAM(1, std::vector, indices);
   TENSOR_PARAM(2, source);
 
-  torch::Tensor destination = input->clone();
+  torch::Tensor output = input->clone();
+  torch::Tensor destination = output;
 
   auto source_shape = source->sizes();
 
@@ -482,7 +483,7 @@ NIF(put) {
   auto start = indices[dim];
   destination.slice(dim, start, start + source_shape[dim]) = *source;
 
-  TENSOR(destination);
+  TENSOR(output);
 }
 
 NIF(permute) {

From 27fb855b2127c107e6042674a94a9c74c3a7706e Mon Sep 17 00:00:00 2001
From: Paulo Valente <16843419+polvalente@users.noreply.github.com>
Date: Wed, 8 Nov 2023 15:01:37 -0300
Subject: [PATCH 3/3] chore: revert formatting

---
 torchx/c_src/torchx.cpp | 466 ++++++++++++++++++++++++++--------------
 1 file changed, 309 insertions(+), 157 deletions(-)

diff --git a/torchx/c_src/torchx.cpp b/torchx/c_src/torchx.cpp
index 1fe4ec389f..29b34bd1f0 100644
--- a/torchx/c_src/torchx.cpp
+++ b/torchx/c_src/torchx.cpp
@@ -1,25 +1,28 @@
 #include
 
 #if defined(USING_TORCH_V1)
-#include
+  #include
 #else
-#include
+  #include
 #endif
 
-#include
 #include
+#include
 
 #include "nx_nif_utils.hpp"
 
 std::map dtypes = {{"byte", torch::kByte}, {"char", torch::kChar}, {"short", torch::kShort}, {"int", torch::kInt}, {"long", torch::kLong}, {"half", torch::kHalf}, {"brain", torch::kBFloat16}, {"float", torch::kFloat}, {"double", torch::kDouble}, {"bool", torch::kBool}, {"complex", at::ScalarType::ComplexFloat}, {"complex_double", at::ScalarType::ComplexDouble}};
 std::map dtype_sizes = {{"byte", 1}, {"char", 1}, {"short", 2}, {"int", 4}, {"long", 8}, {"half", 2}, {"brain", 2}, {"float", 4}, {"double", 8}, {"complex", 8}, {"complex_double", 16}};
 
-inline torch::ScalarType string2type(const std::string &atom) {
+inline torch::ScalarType string2type(const std::string &atom)
+{
   return dtypes[atom];
 }
 
-inline const std::string *type2string(const torch::ScalarType type) {
-  for (std::map::iterator i = dtypes.begin(); i != dtypes.end(); ++i) {
+inline const std::string *type2string(const torch::ScalarType type)
+{
+  for (std::map::iterator i = dtypes.begin(); i != dtypes.end(); ++i)
+  {
     if (i->second == type)
       return &i->first;
   }
@@ -27,11 +30,14 @@ inline const std::string *type2string(const torch::ScalarType type) {
 }
 
 // the class instance to manage the refcount of Tensor
-class TensorP {
- public:
-  TensorP(ErlNifEnv *env, const ERL_NIF_TERM arg) : ptr(nullptr) {
+class TensorP
+{
+public:
+  TensorP(ErlNifEnv *env, const ERL_NIF_TERM arg) : ptr(nullptr)
+  {
     // setup
-    if (!enif_get_resource(env, arg, TENSOR_TYPE, (void **)&ptr)) {
+    if (!enif_get_resource(env, arg, TENSOR_TYPE, (void **)&ptr))
+    {
       err = nx::nif::error(env, "Unable to get tensor param in NIF");
       return;
     }
@@ -39,50 +45,62 @@ class TensorP {
    refcount = 
(std::atomic *)(ptr + 1); deleted = (std::atomic_flag *)(refcount + 1); - if (refcount->load() == 0) { + if (refcount->load() == 0) + { // already deallocated ptr = nullptr; err = nx::nif::error(env, "Tensor has been deallocated"); return; } - if (is_valid()) { + if (is_valid()) + { // increase reference count ++(*refcount); } } - ~TensorP() { - if (is_valid()) { + ~TensorP() + { + if (is_valid()) + { // decrease reference count - if (refcount->fetch_sub(1) == 0) { + if (refcount->fetch_sub(1) == 0) + { ptr->~Tensor(); } } } - bool deallocate() { - if (is_valid() && atomic_flag_test_and_set(deleted) == false) { + bool deallocate() + { + if (is_valid() && atomic_flag_test_and_set(deleted) == false) + { --(*refcount); return true; - } else { + } + else + { return false; } } - torch::Tensor *data() const { + torch::Tensor *data() const + { return ptr; } - bool is_valid() const { + bool is_valid() const + { return ptr != nullptr; } - ERL_NIF_TERM error() { + ERL_NIF_TERM error() + { return err; } - private: +private: torch::Tensor *ptr; std::atomic *refcount; std::atomic_flag *deleted; @@ -91,21 +109,26 @@ class TensorP { #define NIF(NAME) ERL_NIF_TERM NAME(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[]) -#define SCALAR_PARAM(ARGN, VAR) \ - torch::Scalar VAR; \ - VAR.~Scalar(); \ - double double_##VAR; \ - std::vector complex_##VAR; \ - if (nx::nif::get_tuple(env, argv[ARGN], complex_##VAR)) { \ - new (&VAR) torch::Scalar(c10::complex( \ - complex_##VAR[0], \ - complex_##VAR[1])); \ - } else if (enif_get_double(env, argv[ARGN], &double_##VAR) == 0) { \ - int64_t int64_##VAR; \ - enif_get_int64(env, argv[ARGN], (ErlNifSInt64 *)&int64_##VAR); \ - new (&VAR) torch::Scalar(int64_##VAR); \ - } else { \ - new (&VAR) torch::Scalar(double_##VAR); \ +#define SCALAR_PARAM(ARGN, VAR) \ + torch::Scalar VAR; \ + VAR.~Scalar(); \ + double double_##VAR; \ + std::vector complex_##VAR; \ + if (nx::nif::get_tuple(env, argv[ARGN], complex_##VAR)) \ + { \ + new (&VAR) torch::Scalar(c10::complex( \ + complex_##VAR[0], \ + complex_##VAR[1])); \ + } \ + else if (enif_get_double(env, argv[ARGN], &double_##VAR) == 0) \ + { \ + int64_t int64_##VAR; \ + enif_get_int64(env, argv[ARGN], (ErlNifSInt64 *)&int64_##VAR); \ + new (&VAR) torch::Scalar(int64_##VAR); \ + } \ + else \ + { \ + new (&VAR) torch::Scalar(double_##VAR); \ } #define SHAPE_PARAM(ARGN, VAR) TUPLE_PARAM(ARGN, std::vector, VAR) @@ -123,21 +146,26 @@ class TensorP { #define TENSOR_PARAM(ARGN, VAR) \ TensorP VAR##_tp(env, argv[ARGN]); \ torch::Tensor *VAR; \ - if (!VAR##_tp.is_valid()) { \ + if (!VAR##_tp.is_valid()) \ + { \ return VAR##_tp.error(); \ - } else { \ + } \ + else \ + { \ VAR = VAR##_tp.data(); \ } #define CATCH() \ - catch (c10::Error & error) { \ + catch (c10::Error & error) \ + { \ std::ostringstream msg; \ msg << error.msg() << " in NIF." 
<< __func__ << "/" << argc; \ return nx::nif::error(env, msg.str().c_str()); \ } #define SCALAR(S) \ - try { \ + try \ + { \ if (c10::isFloatingType(S.type())) \ return nx::nif::ok(env, nx::nif::make(env, S.toDouble())); \ else \ @@ -146,13 +174,15 @@ class TensorP { CATCH() #define TENSOR(T) \ - try { \ + try \ + { \ return nx::nif::ok(env, create_tensor_resource(env, T)); \ } \ CATCH() #define TENSOR_LIST(TL) \ - try { \ + try \ + { \ const std::vector &tl = TL; \ std::vector res_list; \ for (torch::Tensor t : tl) \ @@ -162,7 +192,8 @@ class TensorP { CATCH() #define TENSOR_TUPLE(TT) \ - try { \ + try \ + { \ const std::tuple &tt = TT; \ std::vector res_list; \ for (torch::Tensor t : {std::get<0>(tt), std::get<1>(tt)}) \ @@ -172,7 +203,8 @@ class TensorP { CATCH() #define TENSOR_TUPLE_3(TT) \ - try { \ + try \ + { \ const std::tuple &tt = TT; \ std::vector res_list; \ for (torch::Tensor t : {std::get<0>(tt), std::get<1>(tt), std::get<2>(tt)}) \ @@ -182,7 +214,8 @@ class TensorP { CATCH() ERL_NIF_TERM -create_tensor_resource(ErlNifEnv *env, torch::Tensor tensor) { +create_tensor_resource(ErlNifEnv *env, torch::Tensor tensor) +{ ERL_NIF_TERM ret; torch::Tensor *tensorPtr; std::atomic *refcount; @@ -201,17 +234,20 @@ create_tensor_resource(ErlNifEnv *env, torch::Tensor tensor) { return ret; } -NIF(delete_tensor) { +NIF(delete_tensor) +{ TensorP tensor(env, argv[0]); return tensor.deallocate() ? nx::nif::ok(env) : enif_make_badarg(env); } -uint64_t elem_count(std::vector shape) { +uint64_t elem_count(std::vector shape) +{ return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>{}); } -NIF(from_blob) { +NIF(from_blob) +{ BINARY_PARAM(0, blob); SHAPE_PARAM(1, shape); TYPE_PARAM(2, type); @@ -222,14 +258,18 @@ NIF(from_blob) { auto tensor = torch::from_blob(blob.data, shape, torch::device(torch::kCPU).dtype(type)); - if (DEVICE(device).device().type() == torch::kCPU) { + if (DEVICE(device).device().type() == torch::kCPU) + { TENSOR(tensor.clone()); - } else { + } + else + { TENSOR(tensor.to(DEVICE(device))); } } -NIF(to_blob) { +NIF(to_blob) +{ ERL_NIF_TERM result; TENSOR_PARAM(0, t); size_t byte_size = t->nbytes(); @@ -237,7 +277,8 @@ NIF(to_blob) { bool has_received_limit = (argc == 2); - if (has_received_limit) { + if (has_received_limit) + { PARAM(1, int64_t, param_limit); limit = param_limit; byte_size = limit * t->itemsize(); @@ -252,15 +293,20 @@ NIF(to_blob) { torch::Tensor reshaped = (has_received_limit && byte_size < t->nbytes()) ? 
t->flatten().slice(0, 0, limit) : t->flatten(); void *data_ptr = reshaped.data_ptr(); - if (device.has_value() && device.value().type() == torch::kCPU && data_ptr == t->data_ptr()) { + if (device.has_value() && device.value().type() == torch::kCPU && data_ptr == t->data_ptr()) + { // case where we own the data_ptr and the data is in the CPU already return nx::nif::ok(env, enif_make_resource_binary(env, t, data_ptr, byte_size)); - } else if (device.has_value() && device.value().type() == torch::kCPU) { + } + else if (device.has_value() && device.value().type() == torch::kCPU) + { // case where we don't own the data_ptr but the data is in the CPU already void *result_data = (void *)enif_make_new_binary(env, byte_size, &result); memcpy(result_data, data_ptr, byte_size); return nx::nif::ok(env, result); - } else { + } + else + { // case where the data isn't in the CPU, therefore we don't own the data_ptr void *result_data = (void *)enif_make_new_binary(env, byte_size, &result); memcpy(result_data, reshaped.to(torch::kCPU).data_ptr(), byte_size); @@ -268,13 +314,15 @@ NIF(to_blob) { } } -NIF(item) { +NIF(item) +{ TENSOR_PARAM(0, t); SCALAR(t->item()); } -NIF(scalar_type) { +NIF(scalar_type) +{ TENSOR_PARAM(0, t); const std::string *type_name = type2string(t->scalar_type()); @@ -285,7 +333,8 @@ NIF(scalar_type) { return nx::nif::error(env, "Could not determine tensor type."); } -NIF(shape) { +NIF(shape) +{ TENSOR_PARAM(0, t); std::vector sizes; @@ -295,7 +344,8 @@ NIF(shape) { return nx::nif::ok(env, enif_make_tuple_from_array(env, sizes.data(), sizes.size())); } -NIF(mps_is_available) { +NIF(mps_is_available) +{ #ifdef MAC_ARM64 bool has_mps = at::hasMPS(); #else @@ -304,66 +354,78 @@ NIF(mps_is_available) { return nx::nif::make(env, has_mps); } -NIF(cuda_is_available) { +NIF(cuda_is_available) +{ return nx::nif::make(env, (bool)torch::cuda::is_available()); } -NIF(cuda_device_count) { +NIF(cuda_device_count) +{ return nx::nif::make(env, (int)torch::cuda::device_count()); } -NIF(nbytes) { +NIF(nbytes) +{ TENSOR_PARAM(0, t); return nx::nif::ok(env, enif_make_int64(env, t->nbytes())); } -NIF(split) { +NIF(split) +{ TENSOR_PARAM(0, t); PARAM(1, int64_t, batch_size); TENSOR_LIST(torch::split(*t, batch_size)); } -NIF(reshape) { +NIF(reshape) +{ TENSOR_PARAM(0, t); SHAPE_PARAM(1, shape); TENSOR(torch::reshape(*t, shape)); } -NIF(to_type) { +NIF(to_type) +{ TENSOR_PARAM(0, t); TYPE_PARAM(1, type); TENSOR(t->toType(type)); } -NIF(to_device) { +NIF(to_device) +{ TENSOR_PARAM(0, t); DEVICE_PARAM(1, device); TENSOR(t->to(DEVICE(device))); } -NIF(squeeze) { +NIF(squeeze) +{ TENSOR_PARAM(0, t); - if (argc == 2) { + if (argc == 2) + { PARAM(1, int64_t, dim); TENSOR(torch::squeeze(*t, dim)); - } else + } + else TENSOR(torch::squeeze(*t)); } -NIF(broadcast_to) { +NIF(broadcast_to) +{ TENSOR_PARAM(0, t); SHAPE_PARAM(1, shape); TENSOR(torch::broadcast_to(*t, shape).clone()); } -NIF(transpose) { +NIF(transpose) +{ TENSOR_PARAM(0, t); PARAM(1, int64_t, dim0); PARAM(2, int64_t, dim1); @@ -371,7 +433,8 @@ NIF(transpose) { TENSOR(torch::transpose(*t, dim0, dim1)); } -NIF(narrow) { +NIF(narrow) +{ TENSOR_PARAM(0, t); PARAM(1, int64_t, dim); PARAM(2, int64_t, start); @@ -380,7 +443,8 @@ NIF(narrow) { TENSOR(torch::narrow(*t, dim, start, length).clone()); } -NIF(as_strided) { +NIF(as_strided) +{ TENSOR_PARAM(0, t); SHAPE_PARAM(1, size); LIST_PARAM(2, std::vector, strides); @@ -389,7 +453,8 @@ NIF(as_strided) { TENSOR(torch::as_strided(*t, size, strides, offset).clone()); } -NIF(concatenate) { +NIF(concatenate) +{ 
LIST_PARAM(0, std::vector, tensors); PARAM(1, int64_t, axis); @@ -397,7 +462,8 @@ NIF(concatenate) { TENSOR(torch::cat(tensors, axis)); } -NIF(gather) { +NIF(gather) +{ TENSOR_PARAM(0, input); TENSOR_PARAM(1, indices); PARAM(2, int64_t, axis); @@ -412,7 +478,7 @@ NIF(index_put) { PARAM(3, bool, accumulate); c10::List> convertedList; - for (const torch::Tensor &tensor : indices) { + for (const torch::Tensor& tensor : indices) { convertedList.push_back(tensor); } @@ -424,14 +490,15 @@ NIF(index) { LIST_PARAM(1, std::vector, indices); c10::List> convertedList; - for (const torch::Tensor &tensor : indices) { + for (const torch::Tensor& tensor : indices) { convertedList.push_back(tensor); } TENSOR(torch::index(*input, convertedList)); } -NIF(argsort) { +NIF(argsort) +{ TENSOR_PARAM(0, input); PARAM(1, bool, stable); PARAM(2, int64_t, axis); @@ -440,21 +507,24 @@ NIF(argsort) { TENSOR(torch::argsort(*input, stable, axis, is_descending)); } -NIF(top_k) { +NIF(top_k) +{ TENSOR_PARAM(0, input); PARAM(1, int64_t, k); TENSOR_TUPLE(at::topk(*input, k)); } -NIF(flip) { +NIF(flip) +{ TENSOR_PARAM(0, input); LIST_PARAM(1, std::vector, dims); TENSOR(torch::flip(*input, dims)); } -NIF(unfold) { +NIF(unfold) +{ TENSOR_PARAM(0, input); PARAM(1, int64_t, dim); PARAM(2, int64_t, size); @@ -486,7 +556,8 @@ NIF(put) { TENSOR(output); } -NIF(permute) { +NIF(permute) +{ TENSOR_PARAM(0, t); LIST_PARAM(1, std::vector, dims); @@ -495,7 +566,8 @@ NIF(permute) { /* Creation */ -NIF(scalar_tensor) { +NIF(scalar_tensor) +{ SCALAR_PARAM(0, scalar); TYPE_PARAM(1, type); DEVICE_PARAM(2, device); @@ -503,7 +575,8 @@ NIF(scalar_tensor) { TENSOR(torch::scalar_tensor(scalar, OPTS(type, device))); } -NIF(randint) { +NIF(randint) +{ PARAM(0, int64_t, min); PARAM(1, int64_t, max); SHAPE_PARAM(2, shape); @@ -513,7 +586,8 @@ NIF(randint) { TENSOR(torch::randint(min, max, shape, OPTS(type, device))); } -NIF(rand) { +NIF(rand) +{ PARAM(0, double, min); PARAM(1, double, max); SHAPE_PARAM(2, shape); @@ -523,7 +597,8 @@ NIF(rand) { TENSOR(min + torch::rand(shape, OPTS(type, device)) * (max - min)); } -NIF(normal) { +NIF(normal) +{ PARAM(0, double, mean); PARAM(1, double, std); SHAPE_PARAM(2, shape); @@ -533,22 +608,27 @@ NIF(normal) { TENSOR(torch::normal(mean, std, shape, c10::nullopt, OPTS(type, device))); } -NIF(arange) { +NIF(arange) +{ PARAM(0, int64_t, start); PARAM(1, int64_t, end); PARAM(2, int64_t, step); TYPE_PARAM(3, type); DEVICE_PARAM(4, device); - if (argc == 6) { + if (argc == 6) + { SHAPE_PARAM(5, shape); TENSOR(torch::reshape(torch::arange((double)start, (double)end, (double)step, OPTS(type, device)), shape)); - } else { + } + else + { TENSOR(torch::arange((double)start, (double)end, (double)step, OPTS(type, device))); } } -NIF(ones) { +NIF(ones) +{ SHAPE_PARAM(0, shape); TYPE_PARAM(1, type); DEVICE_PARAM(2, device); @@ -556,7 +636,8 @@ NIF(ones) { TENSOR(torch::ones(shape, OPTS(type, device))); } -NIF(eye) { +NIF(eye) +{ PARAM(0, int64_t, m); PARAM(1, int64_t, n); TYPE_PARAM(2, type); @@ -565,7 +646,8 @@ NIF(eye) { TENSOR(torch::eye(m, n, OPTS(type, device))); } -NIF(full) { +NIF(full) +{ SHAPE_PARAM(0, shape); SCALAR_PARAM(1, scalar); TYPE_PARAM(2, type); @@ -579,7 +661,8 @@ NIF(full) { #define BINARY_OP(OP) BINARY_OP2(OP, OP) #define BINARY_OP2(OP, NATIVE_OP) \ - NIF(OP) { \ + NIF(OP) \ + { \ TENSOR_PARAM(0, a); \ TENSOR_PARAM(1, b); \ \ @@ -587,7 +670,8 @@ NIF(full) { } #define BINARY_OPB(OP) \ - NIF(OP) { \ + NIF(OP) \ + { \ TENSOR_PARAM(0, a); \ TENSOR_PARAM(1, b); \ \ @@ -597,7 +681,8 @@ NIF(full) { #define 
UNARY_OP(OP) UNARY_OP2(OP, OP) #define UNARY_OP2(OP, NATIVE) \ - NIF(OP) { \ + NIF(OP) \ + { \ TENSOR_PARAM(0, a); \ TENSOR(torch::NATIVE(*a)); \ } @@ -630,19 +715,22 @@ BINARY_OP(atan2) BINARY_OP(min) BINARY_OP(max) -NIF(fmod) { +NIF(fmod) +{ TENSOR_PARAM(0, a); TENSOR_PARAM(1, b); TENSOR(at::fmod(*a, *b)); } -NIF(quotient) { +NIF(quotient) +{ TENSOR_PARAM(0, a); TENSOR_PARAM(1, b); TENSOR(torch::divide(*a, *b, "trunc")); } -NIF(tensordot) { +NIF(tensordot) +{ TENSOR_PARAM(0, t1); TENSOR_PARAM(1, t2); LIST_PARAM(2, std::vector, axes1); @@ -654,31 +742,37 @@ NIF(tensordot) { torch::Tensor result; - if (is_batched) { + if (is_batched) + { // if any of the tensors is batched, we need to apply some transformations // on the inputs and on the result to wrap the batched APIs that torch exposes std::vector batch_dims1, batch_dims2; int64_t vmap_level = 0; - for (auto dim : batch_axes1) { + for (auto dim : batch_axes1) + { batch_dims1.push_back(at::BatchDim(vmap_level++, dim)); } torch::Tensor batched_1 = at::makeBatched(*t1, at::BatchDims(batch_dims1.begin(), batch_dims1.end())); vmap_level = 0; - for (auto dim : batch_axes2) { + for (auto dim : batch_axes2) + { batch_dims2.push_back(at::BatchDim(vmap_level++, dim)); } torch::Tensor batched_2 = at::makeBatched(*t2, at::BatchDims(batch_dims2.begin(), batch_dims2.end())); torch::Tensor batched_result = torch::tensordot(batched_1, batched_2, axes1, axes2); auto impl = at::maybeGetBatchedImpl(batched_result); - if (!impl) { + if (!impl) + { return nx::nif::error(env, "unable to get tensordot result"); } result = torch::clone(impl->value()); - } else { + } + else + { result = torch::tensordot(*t1, *t2, axes1, axes2); } @@ -719,25 +813,29 @@ UNARY_OP(erf) UNARY_OP(erfc) UNARY_OP2(erf_inv, erfinv) -NIF(view_as_real) { +NIF(view_as_real) +{ TENSOR_PARAM(0, tensor); TENSOR(torch::view_as_real(*tensor)); } -NIF(conjugate) { +NIF(conjugate) +{ TENSOR_PARAM(0, tensor); at::Tensor conjugated = tensor->conj(); TENSOR(conjugated.clone(conjugated.suggest_memory_format())); } -NIF(triangular_solve) { +NIF(triangular_solve) +{ TENSOR_PARAM(0, a); TENSOR_PARAM(1, b); PARAM(2, bool, transpose); PARAM(3, bool, upper); auto ts_a = *a; - if (transpose) { + if (transpose) + { auto num_dims = a->dim(); ts_a = torch::transpose(*a, num_dims - 2, num_dims - 1); upper = !upper; @@ -748,13 +846,15 @@ NIF(triangular_solve) { TENSOR(result); } -NIF(determinant) { +NIF(determinant) +{ TENSOR_PARAM(0, t); TENSOR(t->det()); } -NIF(sort) { +NIF(sort) +{ TENSOR_PARAM(0, t); PARAM(1, bool, stable); PARAM(2, int64_t, axis); @@ -764,7 +864,8 @@ NIF(sort) { TENSOR(std::get<0>(result)); } -NIF(clip) { +NIF(clip) +{ TENSOR_PARAM(0, t); TENSOR_PARAM(1, min); TENSOR_PARAM(2, max); @@ -772,7 +873,8 @@ NIF(clip) { TENSOR(torch::clip(*t, *min, *max)); } -NIF(where) { +NIF(where) +{ TENSOR_PARAM(0, pred); TENSOR_PARAM(1, on_true); TENSOR_PARAM(2, on_false); @@ -782,7 +884,8 @@ NIF(where) { /* Aggregates */ -NIF(sum) { +NIF(sum) +{ TENSOR_PARAM(0, t); LIST_PARAM(1, std::vector, dims); PARAM(2, bool, keep_dim); @@ -790,10 +893,12 @@ NIF(sum) { TENSOR(torch::sum(*t, dims, keep_dim)); } -NIF(product) { +NIF(product) +{ TENSOR_PARAM(0, t); - if (argc == 1) { + if (argc == 1) + { TENSOR(torch::prod(*t)); } @@ -803,36 +908,48 @@ NIF(product) { TENSOR(torch::prod(*t, dim, keep_dim)); } -NIF(argmax) { +NIF(argmax) +{ TENSOR_PARAM(0, t); PARAM(1, int64_t, dim); PARAM(2, bool, keep_dim); - if (dim == -1) { + if (dim == -1) + { TENSOR(torch::argmax(*t)); - } else { + } + else + { TENSOR(torch::argmax(*t, 
dim, keep_dim)); } } -NIF(argmin) { +NIF(argmin) +{ TENSOR_PARAM(0, t); PARAM(1, int64_t, dim); PARAM(2, bool, keep_dim); - if (dim == -1) { + if (dim == -1) + { TENSOR(torch::argmin(*t)); - } else { + } + else + { TENSOR(torch::argmin(*t, dim, keep_dim)); } } -NIF(cbrt) { +NIF(cbrt) +{ TENSOR_PARAM(0, tensor); - if (tensor->scalar_type() == torch::kDouble) { + if (tensor->scalar_type() == torch::kDouble) + { TENSOR(torch::pow(*tensor, 1.0 / 3)); - } else { + } + else + { TENSOR(torch::pow(*tensor, 1.0f / 3)); } } @@ -864,24 +981,30 @@ NIF(ifft2) { TENSOR(torch::fft::ifft2(*tensor, lengths, axes)); } -NIF(is_nan) { +NIF(is_nan) +{ TENSOR_PARAM(0, tensor); TENSOR(torch::isnan(*tensor)); } -NIF(is_infinity) { +NIF(is_infinity) +{ TENSOR_PARAM(0, tensor); TENSOR(torch::isinf(*tensor)); } -NIF(all) { +NIF(all) +{ TENSOR_PARAM(0, t); - if (argc == 1) { + if (argc == 1) + { TENSOR(torch::all(*t)); - } else { + } + else + { PARAM(1, int64_t, axis); PARAM(2, bool, keep_dim); @@ -889,12 +1012,16 @@ NIF(all) { } } -NIF(any) { +NIF(any) +{ TENSOR_PARAM(0, t); - if (argc == 1) { + if (argc == 1) + { TENSOR(torch::any(*t)); - } else { + } + else + { PARAM(1, int64_t, axis); PARAM(2, bool, keep_dim); @@ -902,7 +1029,8 @@ NIF(any) { } } -NIF(all_close) { +NIF(all_close) +{ TENSOR_PARAM(0, a); TENSOR_PARAM(1, b); PARAM(2, double, rtol); @@ -915,21 +1043,24 @@ NIF(all_close) { TENSOR(torch::scalar_tensor(all_close, init_opts)); } -NIF(cumulative_sum) { +NIF(cumulative_sum) +{ TENSOR_PARAM(0, t); PARAM(1, int64_t, axis); TENSOR(torch::cumsum(*t, axis)); } -NIF(cumulative_product) { +NIF(cumulative_product) +{ TENSOR_PARAM(0, t); PARAM(1, int64_t, axis); TENSOR(torch::cumprod(*t, axis)); } -NIF(cumulative_min) { +NIF(cumulative_min) +{ TENSOR_PARAM(0, t); PARAM(1, int64_t, axis); @@ -937,7 +1068,8 @@ NIF(cumulative_min) { TENSOR(std::get<0>(tt)); } -NIF(cumulative_max) { +NIF(cumulative_max) +{ TENSOR_PARAM(0, t); PARAM(1, int64_t, axis); @@ -945,22 +1077,26 @@ NIF(cumulative_max) { TENSOR(std::get<0>(tt)); } -NIF(cholesky) { +NIF(cholesky) +{ TENSOR_PARAM(0, t); bool upper = false; - if (argc == 2) { + if (argc == 2) + { GET(1, upper); } - if (upper) { + if (upper) + { TENSOR(torch::linalg::cholesky(*t).mH()); } TENSOR(torch::linalg::cholesky(*t)); } -NIF(pad) { +NIF(pad) +{ TENSOR_PARAM(0, tensor); TENSOR_PARAM(1, constant); LIST_PARAM(2, std::vector, config); @@ -970,29 +1106,34 @@ NIF(pad) { /* Transformations */ -NIF(qr) { +NIF(qr) +{ TENSOR_PARAM(0, t); bool reduced = true; - if (argc == 2) { + if (argc == 2) + { GET(1, reduced); } TENSOR_TUPLE(torch::linalg_qr(*t, reduced ? 
"reduced" : "complete")); } -NIF(svd) { +NIF(svd) +{ TENSOR_PARAM(0, t); bool full_matrices = true; - if (argc == 2) { + if (argc == 2) + { GET(1, full_matrices); } TENSOR_TUPLE_3(torch::linalg_svd(*t, full_matrices)); } -NIF(lu) { +NIF(lu) +{ TENSOR_PARAM(0, t); std::tuple lu_result = torch::linalg::lu_factor(*t); @@ -1001,7 +1142,8 @@ NIF(lu) { TENSOR_TUPLE_3(plu); } -NIF(amax) { +NIF(amax) +{ TENSOR_PARAM(0, tensor); LIST_PARAM(1, std::vector, axes); PARAM(2, bool, keep_axes); @@ -1009,7 +1151,8 @@ NIF(amax) { TENSOR(at::amax(*tensor, axes, keep_axes)); } -NIF(amin) { +NIF(amin) +{ TENSOR_PARAM(0, tensor); LIST_PARAM(1, std::vector, axes); PARAM(2, bool, keep_axes); @@ -1017,20 +1160,23 @@ NIF(amin) { TENSOR(at::amin(*tensor, axes, keep_axes)); } -NIF(eigh) { +NIF(eigh) +{ TENSOR_PARAM(0, tensor); TENSOR_TUPLE(torch::linalg_eigh(*tensor)); } -NIF(solve) { +NIF(solve) +{ TENSOR_PARAM(0, tensorA); TENSOR_PARAM(1, tensorB); TENSOR(torch::linalg_solve(*tensorA, *tensorB)); } -NIF(conv) { +NIF(conv) +{ TENSOR_PARAM(0, tensor); TENSOR_PARAM(1, kernel); @@ -1052,7 +1198,8 @@ NIF(conv) { stride, padding, dilation, transposed, output_padding, groups)); } -NIF(max_pool_3d) { +NIF(max_pool_3d) +{ TENSOR_PARAM(0, tensor); LIST_PARAM(1, std::vector, kernel_size); LIST_PARAM(2, std::vector, strides); @@ -1062,12 +1209,14 @@ NIF(max_pool_3d) { TENSOR(at::max_pool3d(*tensor, kernel_size, strides, padding, dilation)); } -void free_tensor(ErlNifEnv *env, void *obj) { +void free_tensor(ErlNifEnv *env, void *obj) +{ torch::Tensor *tensor = reinterpret_cast(obj); std::atomic *refcount = reinterpret_cast *>(tensor + 1); std::atomic_flag *deleted = reinterpret_cast(refcount + 1); - if (atomic_flag_test_and_set(deleted) == false) { + if (atomic_flag_test_and_set(deleted) == false) + { tensor->~Tensor(); } @@ -1076,7 +1225,8 @@ void free_tensor(ErlNifEnv *env, void *obj) { } static int -open_resource_type(ErlNifEnv *env) { +open_resource_type(ErlNifEnv *env) +{ const char *name = "Tensor"; ErlNifResourceFlags flags = (ErlNifResourceFlags)(ERL_NIF_RT_CREATE | ERL_NIF_RT_TAKEOVER); @@ -1086,7 +1236,8 @@ open_resource_type(ErlNifEnv *env) { return 0; } -int upgrade(ErlNifEnv *env, void **priv_data, void **old_priv_data, ERL_NIF_TERM load_info) { +int upgrade(ErlNifEnv *env, void **priv_data, void **old_priv_data, ERL_NIF_TERM load_info) +{ // Silence "unused var" warnings. (void)(env); (void)(priv_data); @@ -1096,7 +1247,8 @@ int upgrade(ErlNifEnv *env, void **priv_data, void **old_priv_data, ERL_NIF_TERM return 0; } -int load(ErlNifEnv *env, void **priv_data, ERL_NIF_TERM load_info) { +int load(ErlNifEnv *env, void **priv_data, ERL_NIF_TERM load_info) +{ if (open_resource_type(env) == -1) return -1;