Skip to content

Commit

Permalink
feat: window op padding and strides, and small fixes (#46)
Browse files Browse the repository at this point in the history
* feat: support padding and strides

* fix: various small fixes

* chore: unskip more tests
  • Loading branch information
polvalente authored Nov 27, 2024
1 parent dba3985 commit 57afab5
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 91 deletions.
20 changes: 15 additions & 5 deletions c_src/emlx_nif.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,20 +329,28 @@ NIF(from_blob) {
}

// Builds a 0-d MLX array from an Elixir scalar term.
//
// argv[0] — the scalar: a float, an integer, or a {re, im} tuple for a
//           complex number. SCALAR_PARAM populates either `scalar`
//           (real case) or `complex_scalar`, and sets `is_complex`.
// argv[1] — the target dtype atom, decoded by TYPE_PARAM.
// argv[2] — device (currently unused here; a 0-d constant needs no
//           device placement at construction time).
//
// NOTE(review): the scraped diff interleaved the pre-commit single-path
// body with the post-commit complex-aware body; this is the post-commit
// version reconstructed (deleted line precedes its replacement in a
// unified diff) — confirm against commit 57afab5.
NIF(scalar_tensor) {
  SCALAR_PARAM(0, scalar, is_complex);
  TYPE_PARAM(1, type);
  // DEVICE_PARAM(2, device);

  if (is_complex) {
    // complex_scalar is declared by SCALAR_PARAM as std::complex<float>.
    TENSOR(mlx::core::array(complex_scalar, type))
  } else {
    TENSOR(mlx::core::array(scalar, type))
  }
}

// Creates an MLX array of the given shape with every element set to the
// supplied scalar (mlx::core::full).
//
// argv[0] — fill value: float, integer, or {re, im} complex tuple;
//           SCALAR_PARAM sets `is_complex` and fills either `scalar`
//           or `complex_scalar` accordingly.
// argv[1] — shape, decoded by SHAPE_PARAM.
// argv[2] — dtype atom, decoded by TYPE_PARAM.
// argv[3] — device atom, decoded by DEVICE_PARAM.
//
// NOTE(review): the scraped diff interleaved old and new bodies; this is
// the reconstructed post-commit version (complex branch added) — confirm
// against commit 57afab5.
NIF(full) {
  SCALAR_PARAM(0, scalar, is_complex);
  SHAPE_PARAM(1, shape);
  TYPE_PARAM(2, type);
  DEVICE_PARAM(3, device);

  if (is_complex) {
    TENSOR(mlx::core::full(shape, complex_scalar, type, device));
  } else {
    TENSOR(mlx::core::full(shape, scalar, type, device));
  }
}

NIF(arange) {
Expand Down Expand Up @@ -732,7 +740,7 @@ NIF(bitwise_not) {
DEVICE_PARAM(1, device);

auto dtype = (*a).dtype();
auto mask = mlx::core::full({1}, 0xFFFFFFFFFFFFFFFF, dtype, device);
auto mask = mlx::core::full({}, 0xFFFFFFFFFFFFFFFF, dtype, device);
TENSOR(mlx::core::subtract(mask, *a, device));
}
BINARY_OP(left_shift)
Expand Down Expand Up @@ -961,6 +969,8 @@ static ErlNifFunc nif_funcs[] = {{"strides", 1, strides},
{"sigmoid", 2, sigmoid},
{"asin", 2, asin},
{"asinh", 2, asinh},
{"acos", 2, acos},
{"acosh", 2, acosh},
{"cos", 2, cos},
{"cosh", 2, cosh},
{"atan", 2, atan},
Expand Down
13 changes: 9 additions & 4 deletions c_src/nx_nif_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,16 @@ ErlNifResourceType *TENSOR_TYPE;
ATOM_PARAM(ARGN, VAR##_atom) \
mlx::core::Device VAR = string2device(VAR##_atom)

#define SCALAR_PARAM(ARGN, VAR) \
#define SCALAR_PARAM(ARGN, VAR, IS_COMPLEX_VAR) \
bool IS_COMPLEX_VAR = false; \
double VAR; \
std::vector<double> complex_##VAR; \
if (nx::nif::get_tuple<double>(env, argv[ARGN], complex_##VAR)) { \
return nx::nif::error(env, "Complex numbers are not supported in MLX"); \
std::complex<float> complex_##VAR; \
std::vector<double> complex_reader_##VAR; \
if (nx::nif::get_tuple<double>(env, argv[ARGN], complex_reader_##VAR)) { \
complex_##VAR = \
std::complex<float>(static_cast<float>(complex_reader_##VAR[0]), \
static_cast<float>(complex_reader_##VAR[1])); \
IS_COMPLEX_VAR = true; \
} else if (enif_get_double(env, argv[ARGN], &VAR) == 0) { \
int64_t int64_##VAR; \
if (!enif_get_int64(env, argv[ARGN], (ErlNifSInt64 *)&int64_##VAR)) \
Expand Down
Loading

0 comments on commit 57afab5

Please sign in to comment.