diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index f609e23e5..7f35468b4 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -11,6 +11,11 @@ Adding allow_tf32
Adding overloads of Module.save() and Module.load() taking a 'Stream' argument.
Adding torch.softmax() and Tensor.softmax() as aliases for torch.special.softmax()
Adding torch.from_file()
+Adding a number of missing pointwise Tensor operations.
+Adding select_scatter, diagonal_scatter, and slice_scatter
+Adding torch.set_printoptions
+Adding torch.cartesian_prod, combinations, and cov.
+Adding torch.cdist, diag_embed, rot90, triu_indices, tril_indices
__Fixed Bugs__:
diff --git a/src/Native/LibTorchSharp/THSLinearAlgebra.cpp b/src/Native/LibTorchSharp/THSLinearAlgebra.cpp
index 026374365..d2d1bfd6c 100644
--- a/src/Native/LibTorchSharp/THSLinearAlgebra.cpp
+++ b/src/Native/LibTorchSharp/THSLinearAlgebra.cpp
@@ -47,6 +47,11 @@ Tensor THSLinalg_det(const Tensor tensor)
CATCH_TENSOR(torch::linalg::det(*tensor));
}
+Tensor THSTensor_logdet(const Tensor tensor)
+{
+ CATCH_TENSOR(torch::logdet(*tensor));
+}
+
Tensor THSLinalg_slogdet(const Tensor tensor, Tensor* logabsdet)
{
std::tuple res;
@@ -63,6 +68,13 @@ Tensor THSLinalg_eig(const Tensor tensor, Tensor* eigenvectors)
return ResultTensor(std::get<0>(res));
}
+Tensor THSTensor_geqrf(const Tensor tensor, Tensor* tau)
+{
+ std::tuple res;
+ CATCH(res = torch::geqrf(*tensor);)
+ *tau = ResultTensor(std::get<1>(res));
+ return ResultTensor(std::get<0>(res));
+}
#if 0
Tensor THSTensor_eig(const Tensor tensor, bool vectors, Tensor* eigenvectors)
@@ -98,6 +110,11 @@ Tensor THSLinalg_eigvalsh(const Tensor tensor, const char UPLO)
CATCH_TENSOR(torch::linalg::eigvalsh(*tensor, _uplo));
}
+Tensor THSLinalg_householder_product(const Tensor tensor, const Tensor tau)
+{
+ CATCH_TENSOR(torch::linalg::householder_product(*tensor, *tau));
+}
+
Tensor THSLinalg_inv(const Tensor tensor)
{
CATCH_TENSOR(torch::linalg::inv(*tensor));
diff --git a/src/Native/LibTorchSharp/THSTensor.cpp b/src/Native/LibTorchSharp/THSTensor.cpp
index 8d0b01d1b..f9f32b3d2 100644
--- a/src/Native/LibTorchSharp/THSTensor.cpp
+++ b/src/Native/LibTorchSharp/THSTensor.cpp
@@ -66,6 +66,12 @@ Tensor THSTensor_any_along_dimension(const Tensor tensor, const int64_t dim, boo
{
CATCH_TENSOR(tensor->any(dim, keepdim));
}
+
+Tensor THSTensor_adjoint(const Tensor tensor)
+{
+ CATCH_TENSOR(tensor->adjoint());
+}
+
Tensor THSTensor_argmax(const Tensor tensor)
{
CATCH_TENSOR(tensor->argmax());
@@ -86,6 +92,11 @@ Tensor THSTensor_argmin_along_dimension(const Tensor tensor, const int64_t dim,
CATCH_TENSOR(tensor->argmin(dim, keepdim));
}
+Tensor THSTensor_argwhere(const Tensor tensor)
+{
+ CATCH_TENSOR(tensor->argwhere());
+}
+
Tensor THSTensor_atleast_1d(const Tensor tensor)
{
CATCH_TENSOR(torch::atleast_1d(*tensor));
@@ -159,6 +170,11 @@ void THSTensor_vector_to_parameters(const Tensor vec, const Tensor* tensors, con
CATCH(torch::nn::utils::vector_to_parameters(*vec, toTensors((torch::Tensor**)tensors, length)););
}
+Tensor THSTensor_cartesian_prod(const Tensor* tensors, const int length)
+{
+ CATCH_TENSOR(torch::cartesian_prod(toTensors((torch::Tensor**)tensors, length)));
+}
+
double THSTensor_clip_grad_norm_(const Tensor* tensors, const int length, const double max_norm, const double norm_type)
{
double res = 0.0;
@@ -258,6 +274,11 @@ Tensor THSTensor_clone(const Tensor tensor)
CATCH_TENSOR(tensor->clone());
}
+Tensor THSTensor_combinations(const Tensor tensor, const int r, const bool with_replacement)
+{
+ CATCH_TENSOR(torch::combinations(*tensor, r, with_replacement));
+}
+
Tensor THSTensor_copy_(const Tensor input, const Tensor other, const bool non_blocking)
{
CATCH_TENSOR(input->copy_(*other, non_blocking));
@@ -285,6 +306,13 @@ int THSTensor_is_contiguous(const Tensor tensor)
return result;
}
+int64_t THSTensor_is_nonzero(const Tensor tensor)
+{
+ bool result = false;
+ CATCH(result = tensor->is_nonzero();)
+ return result;
+}
+
Tensor THSTensor_copysign(const Tensor input, const Tensor other)
{
CATCH_TENSOR(input->copysign(*other));
@@ -295,13 +323,6 @@ Tensor THSTensor_corrcoef(const Tensor tensor)
CATCH_TENSOR(tensor->corrcoef());
}
-Tensor THSTensor_cov(const Tensor input, int64_t correction, const Tensor fweights, const Tensor aweights)
-{
- c10::optional fw = (fweights == nullptr) ? c10::optional() : *fweights;
- c10::optional aw = (aweights == nullptr) ? c10::optional() : *aweights;
- CATCH_TENSOR(input->cov(correction, fw, aw));
-}
-
bool THSTensor_is_cpu(const Tensor tensor)
{
bool result = true;
@@ -402,6 +423,11 @@ int THSTensor_device_type(const Tensor tensor)
return (int)device.type();
}
+Tensor THSTensor_diag_embed(const Tensor tensor, const int64_t offset, const int64_t dim1, const int64_t dim2)
+{
+ CATCH_TENSOR(tensor->diag_embed(offset, dim1, dim2));
+}
+
Tensor THSTensor_diff(const Tensor tensor, const int64_t n, const int64_t dim, const Tensor prepend, const Tensor append)
{
c10::optional prep = prepend != nullptr ? *prepend : c10::optional(c10::nullopt);
@@ -473,6 +499,11 @@ Tensor THSTensor_repeat_interleave_int64(const Tensor tensor, const int64_t repe
CATCH_TENSOR(tensor->repeat_interleave(repeats, _dim, _output_size));
}
+int THSTensor_result_type(const Tensor left, const Tensor right)
+{
+ CATCH_RETURN_RES(int, -1, res = (int)torch::result_type(*left, *right));
+}
+
Tensor THSTensor_movedim(const Tensor tensor, const int64_t* src, const int src_len, const int64_t* dst, const int dst_len)
{
CATCH_TENSOR(tensor->movedim(at::ArrayRef(src, src_len), at::ArrayRef(dst, dst_len)));
@@ -1070,6 +1101,11 @@ Tensor THSTensor_outer(const Tensor left, const Tensor right)
CATCH_TENSOR(left->outer(*right));
}
+Tensor THSTensor_ormqr(const Tensor input, const Tensor tau, const Tensor other, bool left, bool transpose)
+{
+ CATCH_TENSOR(torch::ormqr(*input, *tau, *other, left, transpose));
+}
+
Tensor THSTensor_mH(const Tensor tensor)
{
CATCH_TENSOR(tensor->mH());
@@ -1161,6 +1197,11 @@ Tensor THSTensor_reshape(const Tensor tensor, const int64_t* shape, const int le
CATCH_TENSOR(tensor->reshape(at::ArrayRef(shape, length)));
}
+Tensor THSTensor_rot90(const Tensor tensor, const int64_t k, const int64_t dim1, const int64_t dim2)
+{
+ CATCH_TENSOR(tensor->rot90(k, { dim1, dim2 }));
+}
+
Tensor THSTensor_roll(const Tensor tensor, const int64_t* shifts, const int shLength, const int64_t* dims, const int dimLength)
{
CATCH_TENSOR(
@@ -1194,6 +1235,36 @@ Tensor THSTensor_scatter_(
CATCH_TENSOR(tensor->scatter_(dim, *index, *source));
}
+Tensor THSTensor_select_scatter(
+ const Tensor tensor,
+ const Tensor source,
+ const int64_t dim,
+ const int64_t index)
+{
+ CATCH_TENSOR(torch::select_scatter(*tensor, *source, dim, index));
+}
+
+Tensor THSTensor_diagonal_scatter(
+ const Tensor tensor,
+ const Tensor source,
+ const int64_t offset,
+ const int64_t dim1,
+ const int64_t dim2)
+{
+ CATCH_TENSOR(torch::diagonal_scatter(*tensor, *source, offset, dim1, dim2));
+}
+
+Tensor THSTensor_slice_scatter(
+ const Tensor tensor,
+ const Tensor source,
+ const int64_t dim,
+ const int64_t *start,
+ const int64_t *end,
+ const int64_t step)
+{
+ CATCH_TENSOR(torch::slice_scatter(*tensor, *source, dim, start == nullptr ? c10::optional() : c10::optional(*start), end == nullptr ? c10::optional() : c10::optional(*end), step));
+}
+
Tensor THSTensor_scatter_add(
const Tensor tensor,
const int64_t dim,
@@ -1762,6 +1833,23 @@ Tensor THSTensor_tril(const Tensor tensor, const int64_t diagonal)
CATCH_TENSOR(tensor->tril(diagonal));
}
+Tensor THSTensor_tril_indices(const int64_t row, const int64_t col, const int64_t offset, const int8_t scalar_type, const int device_type, const int device_index)
+{
+ auto options = at::TensorOptions()
+ .dtype(at::ScalarType(scalar_type))
+ .device(c10::Device((c10::DeviceType)device_type, (c10::DeviceIndex)device_index));
+ CATCH_TENSOR(torch::tril_indices(row, col, offset, options));
+}
+
+Tensor THSTensor_triu_indices(const int64_t row, const int64_t col, const int64_t offset, const int8_t scalar_type, const int device_type, const int device_index)
+{
+ auto options = at::TensorOptions()
+ .dtype(at::ScalarType(scalar_type))
+ .device(c10::Device((c10::DeviceType)device_type, (c10::DeviceIndex)device_index));
+ CATCH_TENSOR(torch::triu_indices(row, col, offset, options));
+}
+
+
Tensor THSTensor_transpose(const Tensor tensor, const int64_t dim1, const int64_t dim2)
{
CATCH_TENSOR(tensor->transpose(dim1, dim2));
diff --git a/src/Native/LibTorchSharp/THSTensor.h b/src/Native/LibTorchSharp/THSTensor.h
index 105ab9b4d..07e2ca2e1 100644
--- a/src/Native/LibTorchSharp/THSTensor.h
+++ b/src/Native/LibTorchSharp/THSTensor.h
@@ -55,6 +55,8 @@ EXPORT_API(Tensor) THSTensor_addr(const Tensor input, const Tensor mat1, const T
EXPORT_API(Tensor) THSTensor_addr_(const Tensor input, const Tensor mat1, const Tensor vec2, const float beta, const float alpha);
+EXPORT_API(Tensor) THSTensor_adjoint(const Tensor tensor);
+
EXPORT_API(Tensor) THSTensor_alias(const Tensor tensor);
EXPORT_API(int) THSTensor_allclose(const Tensor left, const Tensor right, double rtol, double atol, bool equal_nan);
@@ -103,6 +105,8 @@ EXPORT_API(Tensor) THSTensor_argmin_along_dimension(const Tensor tensor, const i
EXPORT_API(Tensor) THSTensor_argsort(const Tensor tensor, const int64_t dim, bool descending);
+EXPORT_API(Tensor) THSTensor_argwhere(const Tensor tensor);
+
EXPORT_API(Tensor) THSTensor_asin(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_asin_(const Tensor tensor);
@@ -205,10 +209,14 @@ EXPORT_API(void) THSTensor_broadcast_tensors(const Tensor* tensor, const int len
EXPORT_API(Tensor) THSTensor_bucketize(const Tensor tensor, const Tensor boundaries, const bool out_int32, const bool right);
+EXPORT_API(Tensor) THSTensor_cartesian_prod(const Tensor* tensor, const int length);
+
EXPORT_API(Tensor) THSTensor_cat(const Tensor* tensor, const int length, const int64_t dim);
EXPORT_API(Tensor) THSTensor_channel_shuffle(const Tensor tensor, const int64_t groups);
+EXPORT_API(Tensor) THSTensor_cdist(const Tensor x1, const Tensor x2, const double p, const int64_t compute_mode);
+
EXPORT_API(double) THSTensor_clip_grad_norm_(const Tensor* tensor, const int length, const double max_norm, const double norm_type);
EXPORT_API(void) THSTensor_clip_grad_value_(const Tensor* tensors, const int length, const double value);
@@ -219,6 +227,8 @@ EXPORT_API(void) THSTensor_vector_to_parameters(const Tensor vec, const Tensor*
EXPORT_API(Tensor) THSTensor_clone(const Tensor input);
+EXPORT_API(Tensor) THSTensor_combinations(const Tensor tensor, const int r, const bool with_replacement);
+
EXPORT_API(Tensor) THSTensor_contiguous(const Tensor input);
EXPORT_API(Tensor) THSTensor_ceil(const Tensor tensor);
@@ -255,12 +265,14 @@ EXPORT_API(Tensor) THSTensor_complex(const Tensor real, const Tensor imag);
EXPORT_API(Tensor) THSTensor_conj(const Tensor tensor);
-EXPORT_API(int64_t) THSTensor_is_conj(const Tensor tensor);
+EXPORT_API(int64_t) THSTensor_is_nonzero(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_conj_physical(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_conj_physical_(const Tensor tensor);
+EXPORT_API(int64_t) THSTensor_is_conj(const Tensor tensor);
+
EXPORT_API(Tensor) THSTensor_resolve_conj(const Tensor tensor);
@@ -360,6 +372,8 @@ EXPORT_API(int) THSTensor_device_index(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_diag(const Tensor tensor, const int64_t diagonal);
+EXPORT_API(Tensor) THSTensor_diag_embed(const Tensor tensor, const int64_t offset, const int64_t dim1, const int64_t dim2);
+
EXPORT_API(Tensor) THSTensor_trace(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_diagflat(const Tensor tensor, const int64_t offset);
@@ -478,6 +492,22 @@ EXPORT_API(Tensor) THSTensor_floor(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_floor_(const Tensor tensor);
+EXPORT_API(Tensor) THSTensor_floor_divide(const Tensor left, const Tensor right);
+
+EXPORT_API(Tensor) THSTensor_floor_divide_scalar(const Tensor left, const Scalar right);
+
+EXPORT_API(Tensor) THSTensor_floor_divide_(const Tensor left, const Tensor right);
+
+EXPORT_API(Tensor) THSTensor_floor_divide_scalar_(const Tensor left, const Scalar right);
+
+EXPORT_API(Tensor) THSTensor_true_divide(const Tensor left, const Tensor right);
+
+EXPORT_API(Tensor) THSTensor_true_divide_scalar(const Tensor left, const Scalar right);
+
+EXPORT_API(Tensor) THSTensor_true_divide_(const Tensor left, const Tensor right);
+
+EXPORT_API(Tensor) THSTensor_true_divide_scalar_(const Tensor left, const Scalar right);
+
EXPORT_API(Tensor) THSTensor_frac(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_frac_(const Tensor tensor);
@@ -879,6 +909,10 @@ EXPORT_API(Tensor) THSTensor_neg(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_neg_(const Tensor tensor);
+EXPORT_API(int64_t) THSTensor_is_neg(const Tensor tensor);
+
+EXPORT_API(Tensor) THSTensor_resolve_neg(const Tensor tensor);
+
EXPORT_API(Tensor) THSTensor_new(
void* data,
void (*deleter)(void*),
@@ -958,6 +992,8 @@ EXPORT_API(Tensor) THSTensor_ones_out(const int64_t* sizes, const int length, co
EXPORT_API(Tensor) THSTensor_ones_like(const Tensor input, const int8_t scalar_type, const int device_type, const int device_index, const bool requires_grad);
+EXPORT_API(Tensor) THSTensor_ormqr(const Tensor input, const Tensor tau, const Tensor other, bool left, bool transpose);
+
EXPORT_API(Tensor) THSTensor_outer(const Tensor left, const Tensor right);
EXPORT_API(Tensor) THSTensor_mT(const Tensor tensor);
@@ -1040,6 +1076,8 @@ EXPORT_API(Tensor) THSTensor_reshape(const Tensor tensor, const int64_t* shape,
EXPORT_API(Tensor) THSTensor_roll(const Tensor tensor, const int64_t* shifts, const int shLength, const int64_t* dims, const int dimLength);
+EXPORT_API(Tensor) THSTensor_rot90(const Tensor tensor, const int64_t k, const int64_t dim1, const int64_t dim2);
+
EXPORT_API(Tensor) THSTensor_round(const Tensor tensor, const int64_t decimals);
EXPORT_API(Tensor) THSTensor_round_(const Tensor tensor, const int64_t decimals);
@@ -1053,6 +1091,8 @@ EXPORT_API(Tensor) THSTensor_remainder_scalar_(const Tensor left, const Scalar r
EXPORT_API(void) THSTensor_retain_grad(const Tensor tensor);
+EXPORT_API(int) THSTensor_result_type(const Tensor left, const Tensor right);
+
EXPORT_API(Tensor) THSTensor_rsqrt(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_rsqrt_(const Tensor tensor);
@@ -1073,6 +1113,10 @@ EXPORT_API(Tensor) THSTensor_sign(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_sign_(const Tensor tensor);
+EXPORT_API(Tensor) THSTensor_sgn(const Tensor tensor);
+
+EXPORT_API(Tensor) THSTensor_sgn_(const Tensor tensor);
+
EXPORT_API(Tensor) THSTensor_signbit(const Tensor tensor);
EXPORT_API(Tensor) THSTensor_silu(const Tensor tensor);
@@ -1132,6 +1176,10 @@ EXPORT_API(void) THSTensor_save(const Tensor tensor, const char* location);
EXPORT_API(Tensor) THSTensor_scatter(const Tensor tensor, const int64_t dim, const Tensor index, const Tensor source);
EXPORT_API(Tensor) THSTensor_scatter_(const Tensor tensor, const int64_t dim, const Tensor index, const Tensor source);
+EXPORT_API(Tensor) THSTensor_diagonal_scatter(const Tensor tensor, const Tensor source, const int64_t offset, const int64_t dim1, const int64_t dim2);
+EXPORT_API(Tensor) THSTensor_select_scatter(const Tensor tensor, const Tensor source, const int64_t dim, const int64_t index);
+EXPORT_API(Tensor) THSTensor_slice_scatter(const Tensor tensor, const Tensor source, const int64_t dim, const int64_t* start, const int64_t* end, const int64_t step);
+
EXPORT_API(Tensor) THSTensor_scatter_add(const Tensor tensor, const int64_t dim, const Tensor index, const Tensor source);
EXPORT_API(Tensor) THSTensor_scatter_add_(const Tensor tensor, const int64_t dim, const Tensor index, const Tensor source);
@@ -1226,6 +1274,9 @@ EXPORT_API(Tensor) THSTensor_tril(const Tensor tensor, const int64_t diagonal);
EXPORT_API(Tensor) THSTensor_triu(const Tensor tensor, const int64_t diagonal);
+EXPORT_API(Tensor) THSTensor_tril_indices(const int64_t row, const int64_t col, const int64_t offset, const int8_t scalar_type, const int device_type, const int device_index);
+EXPORT_API(Tensor) THSTensor_triu_indices(const int64_t row, const int64_t col, const int64_t offset, const int8_t scalar_type, const int device_type, const int device_index);
+
EXPORT_API(Tensor) THSTensor_transpose(const Tensor tensor, const int64_t dim1, const int64_t dim2);
EXPORT_API(Tensor) THSTensor_transpose_(const Tensor tensor, const int64_t dim1, const int64_t dim2);
@@ -1236,6 +1287,9 @@ EXPORT_API(Tensor) THSTensor_cumulative_trapezoid_dx(const Tensor y, const doubl
EXPORT_API(Tensor) THSTensor_trapezoid_x(const Tensor y, const Tensor x, int64_t dim);
EXPORT_API(Tensor) THSTensor_trapezoid_dx(const Tensor y, const double dx, int64_t dim);
+EXPORT_API(Tensor) THSTensor_cumulative_trapezoid_x(const Tensor y, const Tensor x, int64_t dim);
+EXPORT_API(Tensor) THSTensor_cumulative_trapezoid_dx(const Tensor y, const double dx, int64_t dim);
+
EXPORT_API(Tensor) THSTensor_to_dense(Tensor tensor);
EXPORT_API(Tensor) THSTensor_to_device(const Tensor tensor, const int device_type, const int device_index, const bool copy);
@@ -1386,6 +1440,7 @@ EXPORT_API(Tensor) THSLinalg_cholesky_ex(const Tensor tensor, bool check_errors,
EXPORT_API(Tensor) THSLinalg_cross(const Tensor input, const Tensor other, const int64_t dim);
EXPORT_API(Tensor) THSLinalg_det(const Tensor tensor);
+EXPORT_API(Tensor) THSTensor_logdet(const Tensor tensor);
EXPORT_API(Tensor) THSLinalg_slogdet(const Tensor tensor, Tensor *logabsdet);
@@ -1397,6 +1452,10 @@ EXPORT_API(Tensor) THSTensor_eig(const Tensor tensor, bool vectors, Tensor* eige
EXPORT_API(Tensor) THSLinalg_eigvals(const Tensor tensor);
EXPORT_API(Tensor) THSLinalg_eigvalsh(const Tensor tensor, const char UPLO);
+EXPORT_API(Tensor) THSTensor_geqrf(const Tensor tensor, Tensor* tau);
+
+EXPORT_API(Tensor) THSLinalg_householder_product(const Tensor tensor, const Tensor tau);
+
EXPORT_API(Tensor) THSLinalg_inv(const Tensor tensor);
EXPORT_API(Tensor) THSLinalg_inv_ex(const Tensor tensor, bool check_errors, Tensor* info);
diff --git a/src/Native/LibTorchSharp/THSTensorMath.cpp b/src/Native/LibTorchSharp/THSTensorMath.cpp
index f74a492dc..1f72b07f3 100644
--- a/src/Native/LibTorchSharp/THSTensorMath.cpp
+++ b/src/Native/LibTorchSharp/THSTensorMath.cpp
@@ -241,6 +241,13 @@ Tensor THSTensor_bmm(const Tensor batch1, const Tensor batch2)
CATCH_TENSOR(batch1->bmm(*batch2));
}
+Tensor THSTensor_cdist(const Tensor x1, const Tensor x2, const double p, const int64_t compute_mode)
+{
+ CATCH_TENSOR(compute_mode == 0
+ ? torch::cdist(*x1, *x2, p)
+ : torch::cdist(*x1, *x2, p, compute_mode));
+}
+
Tensor THSTensor_ceil(const Tensor tensor)
{
CATCH_TENSOR(tensor->ceil());
@@ -258,7 +265,12 @@ Tensor THSTensor_conj(const Tensor tensor)
int64_t THSTensor_is_conj(const Tensor tensor)
{
- CATCH_RETURN_RES(int64_t, 0, res = tensor->is_conj();)
+ CATCH_RETURN_RES(int64_t, -1, res = tensor->is_conj();)
+}
+
+int64_t THSTensor_is_neg(const Tensor tensor)
+{
+ CATCH_RETURN_RES(int64_t, -1, res = tensor->is_neg();)
}
Tensor THSTensor_conj_physical(const Tensor tensor)
@@ -276,6 +288,11 @@ Tensor THSTensor_resolve_conj(const Tensor tensor)
CATCH_TENSOR(tensor->resolve_conj());
}
+Tensor THSTensor_resolve_neg(const Tensor tensor)
+{
+ CATCH_TENSOR(tensor->resolve_neg());
+}
+
Tensor THSTensor_cos(const Tensor tensor)
{
CATCH_TENSOR(tensor->cos());
@@ -296,6 +313,13 @@ Tensor THSTensor_cosh_(const Tensor tensor)
CATCH_TENSOR(tensor->cosh_());
}
+Tensor THSTensor_cov(const Tensor input, int64_t correction, const Tensor fweights, const Tensor aweights)
+{
+ c10::optional fw = (fweights == nullptr) ? c10::optional() : *fweights;
+ c10::optional aw = (aweights == nullptr) ? c10::optional() : *aweights;
+ CATCH_TENSOR(input->cov(correction, fw, aw));
+}
+
Tensor THSTensor_cross(const Tensor tensor, const Tensor other, const int64_t dim)
{
CATCH_TENSOR(tensor->cross(*other, dim));
@@ -430,6 +454,46 @@ Tensor THSTensor_floor_(const Tensor tensor)
CATCH_TENSOR(tensor->floor_());
}
+Tensor THSTensor_floor_divide(const Tensor left, const Tensor right)
+{
+ CATCH_TENSOR(left->floor_divide(*right));
+}
+
+Tensor THSTensor_floor_divide_scalar(const Tensor left, const Scalar right)
+{
+ CATCH_TENSOR(left->floor_divide(*right));
+}
+
+Tensor THSTensor_floor_divide_(const Tensor left, const Tensor right)
+{
+ CATCH_TENSOR(left->floor_divide_(*right));
+}
+
+Tensor THSTensor_floor_divide_scalar_(const Tensor left, const Scalar right)
+{
+ CATCH_TENSOR(left->floor_divide_(*right));
+}
+
+Tensor THSTensor_true_divide(const Tensor left, const Tensor right)
+{
+ CATCH_TENSOR(left->true_divide(*right));
+}
+
+Tensor THSTensor_true_divide_scalar(const Tensor left, const Scalar right)
+{
+ CATCH_TENSOR(left->true_divide(*right));
+}
+
+Tensor THSTensor_true_divide_(const Tensor left, const Tensor right)
+{
+ CATCH_TENSOR(left->true_divide_(*right));
+}
+
+Tensor THSTensor_true_divide_scalar_(const Tensor left, const Scalar right)
+{
+ CATCH_TENSOR(left->true_divide_(*right));
+}
+
Tensor THSTensor_fmax(const Tensor left, const Tensor right)
{
CATCH_TENSOR(left->fmax(*right));
@@ -856,6 +920,16 @@ Tensor THSTensor_sign_(const Tensor tensor)
CATCH_TENSOR(tensor->sign_());
}
+Tensor THSTensor_sgn(const Tensor tensor)
+{
+ CATCH_TENSOR(tensor->sgn());
+}
+
+Tensor THSTensor_sgn_(const Tensor tensor)
+{
+ CATCH_TENSOR(tensor->sgn_());
+}
+
Tensor THSTensor_signbit(const Tensor tensor)
{
CATCH_TENSOR(tensor->signbit());
diff --git a/src/Native/LibTorchSharp/THSTorch.cpp b/src/Native/LibTorchSharp/THSTorch.cpp
index fdeac851e..87da90699 100644
--- a/src/Native/LibTorchSharp/THSTorch.cpp
+++ b/src/Native/LibTorchSharp/THSTorch.cpp
@@ -11,7 +11,6 @@ void THSTorch_manual_seed(const int64_t seed)
Generator THSGenerator_manual_seed(const int64_t seed)
{
- torch::manual_seed(seed);
return THSGenerator_default_generator();
}
@@ -152,6 +151,37 @@ const char * THSTorch_get_and_reset_last_err()
return tmp;
}
+int THSTorch_get_num_threads()
+{
+ CATCH_RETURN_RES(int, -1, res = torch::get_num_threads());
+}
+
+void THSTorch_set_num_threads(const int threads)
+{
+ torch::set_num_threads(threads);
+}
+
+int THSTorch_get_num_interop_threads()
+{
+ CATCH_RETURN_RES(int, -1, res = torch::get_num_interop_threads());
+}
+
+void THSTorch_set_num_interop_threads(const int threads)
+{
+ torch::set_num_interop_threads(threads);
+}
+
+int THSTorch_can_cast(const int type1, const int type2)
+{
+ CATCH_RETURN_RES(int, -1, res = (int)torch::can_cast((c10::ScalarType)type1, (c10::ScalarType)type2));
+}
+
+int THSTorch_promote_types(const int type1, const int type2)
+{
+ CATCH_RETURN_RES(int, -1, res = (int)torch::promote_types((c10::ScalarType)type1, (c10::ScalarType)type2));
+}
+
+
Scalar THSTorch_int8_to_scalar(int8_t value)
{
return new torch::Scalar(value);
diff --git a/src/Native/LibTorchSharp/THSTorch.h b/src/Native/LibTorchSharp/THSTorch.h
index dde158829..9b2a31edc 100644
--- a/src/Native/LibTorchSharp/THSTorch.h
+++ b/src/Native/LibTorchSharp/THSTorch.h
@@ -41,10 +41,19 @@ EXPORT_API(void) THSBackend_cuda_set_enable_flash_sdp(const bool flag);
EXPORT_API(bool) THSBackend_cuda_get_enable_math_sdp();
EXPORT_API(void) THSBackend_cuda_set_enable_math_sdp(const bool flag);
+EXPORT_API(int) THSTorch_get_num_threads();
+EXPORT_API(void) THSTorch_set_num_threads(const int threads);
+
+EXPORT_API(int) THSTorch_get_num_interop_threads();
+EXPORT_API(void) THSTorch_set_num_interop_threads(const int threads);
+
// Returns the latest error. This is thread-local.
EXPORT_API(const char *) THSTorch_get_and_reset_last_err();
+EXPORT_API(int) THSTorch_can_cast(const int type1, const int type2);
+EXPORT_API(int) THSTorch_promote_types(const int type1, const int type2);
+
EXPORT_API(Scalar) THSTorch_int8_to_scalar(int8_t value);
EXPORT_API(Scalar) THSTorch_uint8_to_scalar(uint8_t value);
EXPORT_API(Scalar) THSTorch_int16_to_scalar(short value);
diff --git a/src/TorchSharp/LinearAlgebra.cs b/src/TorchSharp/LinearAlgebra.cs
index db542fcec..21a2347cf 100644
--- a/src/TorchSharp/LinearAlgebra.cs
+++ b/src/TorchSharp/LinearAlgebra.cs
@@ -198,6 +198,19 @@ public static Tensor eigvalsh(Tensor input, char UPLO = 'L')
return new Tensor(res);
}
+ ///
+ /// Computes the first n columns of a product of Householder matrices.
+ ///
+ /// tensor of shape (*, m, n) where * is zero or more batch dimensions.
+ /// tensor of shape (*, k) where * is zero or more batch dimensions.
+ public static Tensor householder_product(Tensor A, Tensor tau)
+ {
+ var res = THSLinalg_householder_product(A.Handle, tau.Handle);
+ if (res == IntPtr.Zero)
+ torch.CheckForErrors();
+ return new Tensor(res);
+ }
+
///
/// Computes the inverse of a square matrix if it exists.
///
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSAutograd.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSAutograd.cs
index ba883f72f..7da59a090 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSAutograd.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSAutograd.cs
@@ -8,6 +8,7 @@ namespace TorchSharp.PInvoke
internal static partial class LibTorchSharp
{
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSAutograd_isGradEnabled();
[DllImport("LibTorchSharp")]
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSCuda.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSCuda.cs
index 361e3b583..6ed8ddfba 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSCuda.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSCuda.cs
@@ -16,28 +16,33 @@ internal static partial class LibTorchSharp
internal static extern void THSCuda_synchronize(long device_index);
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSBackend_cublas_get_allow_tf32();
[DllImport("LibTorchSharp")]
- internal static extern void THSBackend_cublas_set_allow_tf32(bool flag);
+ internal static extern void THSBackend_cublas_set_allow_tf32([MarshalAs(UnmanagedType.U1)] bool flag);
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSBackend_cudnn_get_allow_tf32();
[DllImport("LibTorchSharp")]
- internal static extern void THSBackend_cudnn_set_allow_tf32(bool flag);
+ internal static extern void THSBackend_cudnn_set_allow_tf32([MarshalAs(UnmanagedType.U1)] bool flag);
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSBackend_cuda_get_allow_fp16_reduced_precision_reduction();
[DllImport("LibTorchSharp")]
- internal static extern void THSBackend_cuda_set_allow_fp16_reduced_precision_reduction(bool flag);
+ internal static extern void THSBackend_cuda_set_allow_fp16_reduced_precision_reduction([MarshalAs(UnmanagedType.U1)] bool flag);
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSBackend_cuda_get_enable_flash_sdp();
[DllImport("LibTorchSharp")]
- internal static extern void THSBackend_cuda_set_enable_flash_sdp(bool flag);
+ internal static extern void THSBackend_cuda_set_enable_flash_sdp([MarshalAs(UnmanagedType.U1)] bool flag);
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSBackend_cuda_get_enable_math_sdp();
[DllImport("LibTorchSharp")]
- internal static extern void THSBackend_cuda_set_enable_math_sdp(bool flag);
+ internal static extern void THSBackend_cuda_set_enable_math_sdp([MarshalAs(UnmanagedType.U1)] bool flag);
}
}
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSData.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSData.cs
index 695cb8de9..ee8ffe6c6 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSData.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSData.cs
@@ -11,18 +11,19 @@ internal static partial class LibTorchSharp
internal static extern IntPtr THSData_loaderMNIST(
[MarshalAs(UnmanagedType.LPStr)] string filename,
long batchSize,
- bool isTrain);
+ [MarshalAs(UnmanagedType.U1)] bool isTrain);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSData_loaderCIFAR10(
[MarshalAs(UnmanagedType.LPStr)] string path,
long batchSize,
- bool isTrain);
+ [MarshalAs(UnmanagedType.U1)] bool isTrain);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSData_current(IntPtr iterator, IntPtr data, IntPtr target);
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSData_moveNext(IntPtr iterator);
[DllImport("LibTorchSharp")]
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
index 307092b03..d878d6bca 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSJIT.cs
@@ -41,12 +41,13 @@ internal static partial class LibTorchSharp
internal static extern int THSJIT_Module_num_outputs(torch.nn.Module.HType module);
[DllImport("LibTorchSharp")]
- internal static extern void THSJIT_Module_train(torch.nn.Module.HType module, bool on);
+ internal static extern void THSJIT_Module_train(torch.nn.Module.HType module, [MarshalAs(UnmanagedType.U1)] bool on);
[DllImport("LibTorchSharp")]
internal static extern void THSJIT_Module_eval(torch.nn.Module.HType module);
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSJIT_Module_is_training(torch.nn.Module.HType module);
[DllImport("LibTorchSharp")]
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSLinalg.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSLinalg.cs
index 4940507a9..5c7cad92d 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSLinalg.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSLinalg.cs
@@ -38,6 +38,9 @@ internal static partial class LibTorchSharp
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSLinalg_eig(IntPtr tensor, out IntPtr pEigenvectors);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_geqrf(IntPtr tensor, out IntPtr tau);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSLinalg_eigh(IntPtr tensor, byte UPLO, out IntPtr pEigenvectors);
@@ -47,6 +50,9 @@ internal static partial class LibTorchSharp
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSLinalg_eigvalsh(IntPtr tensor, byte UPLO);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSLinalg_householder_product(IntPtr tensor, IntPtr tau);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSLinalg_inv(IntPtr tensor);
@@ -60,11 +66,11 @@ internal static partial class LibTorchSharp
internal static extern IntPtr THSLinalg_lstsq_rcond(IntPtr tensor, IntPtr other, double rcond, out IntPtr pResiduals, out IntPtr pRank, out IntPtr pSingularValues);
[DllImport("LibTorchSharp")]
- internal static extern IntPtr THSLinalg_ldl_factor(IntPtr A, bool hermitian, out IntPtr pivots);
+ internal static extern IntPtr THSLinalg_ldl_factor(IntPtr A, [MarshalAs(UnmanagedType.U1)] bool hermitian, out IntPtr pivots);
[DllImport("LibTorchSharp")]
- internal static extern IntPtr THSLinalg_ldl_factor_ex(IntPtr A, bool hermitian, bool check_errors, out IntPtr pivots, out IntPtr info);
+ internal static extern IntPtr THSLinalg_ldl_factor_ex(IntPtr A, [MarshalAs(UnmanagedType.U1)] bool hermitian, [MarshalAs(UnmanagedType.U1)] bool check_errors, out IntPtr pivots, out IntPtr info);
[DllImport("LibTorchSharp")]
- internal static extern IntPtr THSLinalg_ldl_solve(IntPtr LD, IntPtr pivots, IntPtr B, bool hermitian);
+ internal static extern IntPtr THSLinalg_ldl_solve(IntPtr LD, IntPtr pivots, IntPtr B, [MarshalAs(UnmanagedType.U1)] bool hermitian);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSLinalg_lu(IntPtr tensor, [MarshalAs(UnmanagedType.U1)] bool pivot, out IntPtr pL, out IntPtr pU);
@@ -84,6 +90,9 @@ internal static partial class LibTorchSharp
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSLinalg_matrix_rank_tensor(IntPtr tensor, IntPtr atol, IntPtr rtol, [MarshalAs(UnmanagedType.U1)] bool hermitian);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSLinalg_dot(IntPtr tensor, int len);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSLinalg_multi_dot(IntPtr tensor, int len);
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
index 1bf198380..9217498c5 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
@@ -252,6 +252,7 @@ internal static extern IntPtr THSNN_custom_module(
internal static extern void THSNN_Module_eval(torch.nn.Module.HType module);
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSNN_Module_is_training(torch.nn.Module.HType module);
[DllImport("LibTorchSharp")]
@@ -1191,19 +1192,19 @@ internal static extern IntPtr THSNN_custom_module(
internal static extern IntPtr THSNN_AvgPool1d_forward(IntPtr module, IntPtr tensor);
[DllImport("LibTorchSharp")]
- internal static extern IntPtr THSNN_AvgPool1d_ctor(IntPtr pkernelSize, IntPtr pstrides, IntPtr ppadding, bool ceil_mode, bool count_include_pad, long divisor_override, out IntPtr pBoxedModule);
+ internal static extern IntPtr THSNN_AvgPool1d_ctor(IntPtr pkernelSize, IntPtr pstrides, IntPtr ppadding, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override, out IntPtr pBoxedModule);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_AvgPool2d_forward(IntPtr module, IntPtr tensor);
[DllImport("LibTorchSharp")]
- internal static extern IntPtr THSNN_AvgPool2d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr ppadding, int paddingLength, bool ceil_mode, bool count_include_pad, long divisor_override, out IntPtr pBoxedModule);
+ internal static extern IntPtr THSNN_AvgPool2d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr ppadding, int paddingLength, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override, out IntPtr pBoxedModule);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_AvgPool3d_forward(IntPtr module, IntPtr tensor);
[DllImport("LibTorchSharp")]
- internal static extern IntPtr THSNN_AvgPool3d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr ppadding, int paddingLength, bool ceil_mode, bool count_include_pad, long divisor_override, out IntPtr pBoxedModule);
+ internal static extern IntPtr THSNN_AvgPool3d_ctor(IntPtr pkernelSize, int kernelSizeLength, IntPtr pstrides, int stridesLength, IntPtr ppadding, int paddingLength, [MarshalAs(UnmanagedType.U1)] bool ceil_mode, [MarshalAs(UnmanagedType.U1)] bool count_include_pad, long divisor_override, out IntPtr pBoxedModule);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSNN_FractionalMaxPool2d_forward(torch.nn.Module.HType module, IntPtr tensor);
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
index d49fbfa5e..2d8b4ccef 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
@@ -2,6 +2,7 @@
#nullable enable
using System;
using System.Runtime.InteropServices;
+using TorchSharp.Modules;
namespace TorchSharp.PInvoke
{
@@ -67,7 +68,7 @@ internal static extern IntPtr THSTensor_max_pool1d(IntPtr input,
IntPtr strides, int stridesLength,
IntPtr padding, int paddingLength,
IntPtr dilation, int dilationLength,
- bool ceil_mode);
+ [MarshalAs(UnmanagedType.U1)] bool ceil_mode);
[DllImport("LibTorchSharp")]
internal static extern void THSTensor_max_pool1d_with_indices(IntPtr input, AllocatePinnedArray allocator,
@@ -75,7 +76,7 @@ internal static extern void THSTensor_max_pool1d_with_indices(IntPtr input, Allo
IntPtr strides, int stridesLength,
IntPtr padding, int paddingLength,
IntPtr dilation, int dilationLength,
- bool ceil_mode);
+ [MarshalAs(UnmanagedType.U1)] bool ceil_mode);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_max_pool2d(IntPtr input,
@@ -83,7 +84,7 @@ internal static extern IntPtr THSTensor_max_pool2d(IntPtr input,
IntPtr strides, int stridesLength,
IntPtr padding, int paddingLength,
IntPtr dilation, int dilationLength,
- bool ceil_mode);
+ [MarshalAs(UnmanagedType.U1)] bool ceil_mode);
[DllImport("LibTorchSharp")]
internal static extern void THSTensor_max_pool2d_with_indices(IntPtr input, AllocatePinnedArray allocator,
@@ -91,7 +92,7 @@ internal static extern void THSTensor_max_pool2d_with_indices(IntPtr input, Allo
IntPtr strides, int stridesLength,
IntPtr padding, int paddingLength,
IntPtr dilation, int dilationLength,
- bool ceil_mode);
+ [MarshalAs(UnmanagedType.U1)] bool ceil_mode);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_max_pool3d(IntPtr input,
@@ -99,7 +100,7 @@ internal static extern IntPtr THSTensor_max_pool3d(IntPtr input,
IntPtr strides, int stridesLength,
IntPtr padding, int paddingLength,
IntPtr dilation, int dilationLength,
- bool ceil_mode);
+ [MarshalAs(UnmanagedType.U1)] bool ceil_mode);
[DllImport("LibTorchSharp")]
internal static extern void THSTensor_max_pool3d_with_indices(IntPtr input, AllocatePinnedArray allocator,
@@ -107,7 +108,7 @@ internal static extern void THSTensor_max_pool3d_with_indices(IntPtr input, Allo
IntPtr strides, int stridesLength,
IntPtr padding, int paddingLength,
IntPtr dilation, int dilationLength,
- bool ceil_mode);
+ [MarshalAs(UnmanagedType.U1)] bool ceil_mode);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_maxunpool3d(IntPtr input, IntPtr indices, IntPtr outputSize, int outputSizeLength, IntPtr strides, int stridesLength,
@@ -118,32 +119,32 @@ internal static extern IntPtr THSTensor_avg_pool1d(IntPtr input,
IntPtr kernelSize, int kernelSizeLength,
IntPtr strides, int stridesLength,
IntPtr padding, int paddingLength,
- bool ceil_mode,
- bool count_include_pad);
+ [MarshalAs(UnmanagedType.U1)] bool ceil_mode,
+ [MarshalAs(UnmanagedType.U1)] bool count_include_pad);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_avg_pool2d(IntPtr input,
IntPtr kernelSize, int kernelSizeLength,
IntPtr strides, int stridesLength,
IntPtr padding, int paddingLength,
- bool ceil_mode,
- bool count_include_pad);
+ [MarshalAs(UnmanagedType.U1)] bool ceil_mode,
+ [MarshalAs(UnmanagedType.U1)] bool count_include_pad);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_avg_pool3d(IntPtr input,
IntPtr kernelSize, int kernelSizeLength,
IntPtr strides, int stridesLength,
IntPtr padding, int paddingLength,
- bool ceil_mode,
- bool count_include_pad);
+ [MarshalAs(UnmanagedType.U1)] bool ceil_mode,
+ [MarshalAs(UnmanagedType.U1)] bool count_include_pad);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_avg_pool2d_backward(IntPtr gradOutput, IntPtr originalInput,
IntPtr kernelSize, int kernelSizeLength,
IntPtr strides, int stridesLength,
IntPtr padding, int paddingLength,
- bool ceil_mode,
- bool count_include_pad,
+ [MarshalAs(UnmanagedType.U1)] bool ceil_mode,
+ [MarshalAs(UnmanagedType.U1)] bool count_include_pad,
long divisorOverride);
[DllImport("LibTorchSharp")]
@@ -151,8 +152,8 @@ internal static extern IntPtr THSTensor_avg_pool3d_backward(IntPtr gradOutput, I
IntPtr kernelSize, int kernelSizeLength,
IntPtr strides, int stridesLength,
IntPtr padding, int paddingLength,
- bool ceil_mode,
- bool count_include_pad,
+ [MarshalAs(UnmanagedType.U1)] bool ceil_mode,
+ [MarshalAs(UnmanagedType.U1)] bool count_include_pad,
long divisorOverride);
[DllImport("LibTorchSharp")]
@@ -255,6 +256,7 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
internal static extern int THSTensor_device_type(IntPtr handle);
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSTensor_is_sparse(IntPtr handle);
[DllImport("LibTorchSharp")]
@@ -264,6 +266,7 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
internal static extern IntPtr THSTensor_save(IntPtr tensor, [MarshalAs(UnmanagedType.LPStr)] string location);
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSTensor_requires_grad(IntPtr handle);
[DllImport("LibTorchSharp")]
@@ -273,6 +276,10 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
internal static extern void THSTensor_retain_grad(IntPtr handle);
[DllImport("LibTorchSharp")]
+ internal static extern int THSTensor_result_type(IntPtr tensor1, IntPtr tensor2);
+
+ [DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSTensor_is_cpu(IntPtr handle);
[DllImport("LibTorchSharp")]
@@ -300,6 +307,7 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
internal static extern long THSTensor_sizes(IntPtr handle, AllocatePinnedArray allocator);
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSTensor_has_names(IntPtr handle);
[DllImport("LibTorchSharp")]
@@ -341,6 +349,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_clone(IntPtr handle);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_combinations(IntPtr handle, int r, [MarshalAs(UnmanagedType.U1)] bool with_replacement);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_copy_(IntPtr handle, IntPtr source, [MarshalAs(UnmanagedType.U1)] bool non_blocking);
@@ -410,6 +421,12 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_select(IntPtr tensor, long dim, long index);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_adjoint(IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_argwhere(IntPtr tensor);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_take(IntPtr tensor, IntPtr index);
@@ -491,9 +508,15 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_tril(IntPtr tensor, long diagonal);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_tril_indices(long row, long col, long offset, sbyte scalar_type, int device_type, int device_index);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_triu(IntPtr tensor, long diagonal);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_triu_indices(long row, long col, long offset, sbyte scalar_type, int device_type, int device_index);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_transpose_(IntPtr tensor, long dim1, long dim2);
@@ -680,6 +703,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_isnan(IntPtr tensor);
+ [DllImport("LibTorchSharp")]
+ internal static extern long THSTensor_is_nonzero(IntPtr handle);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_isreal(IntPtr tensor);
@@ -716,6 +742,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_bmm(IntPtr batch1, IntPtr batch2);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_cdist(IntPtr x1, IntPtr x2, double p, long compute_mode);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_bucketize(IntPtr input, IntPtr boundaries, [MarshalAs(UnmanagedType.U1)] bool out_int32, [MarshalAs(UnmanagedType.U1)] bool right);
@@ -758,6 +787,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_trace(IntPtr tensor);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_diag_embed(IntPtr tensor, long offset, long dim1, long dim2);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_diagflat(IntPtr tensor, long offset);
@@ -795,9 +827,11 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
internal static extern IntPtr THSTensor_eq_scalar_(IntPtr tensor, IntPtr trg);
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSTensor_equal(IntPtr tensor, IntPtr trg);
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSTensor_allclose(IntPtr tensor, IntPtr trg, double rtol, double atol, [MarshalAs(UnmanagedType.U1)] bool equal_nan);
[DllImport("LibTorchSharp")]
@@ -992,6 +1026,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_outer(IntPtr input, IntPtr vec2);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_ormqr(IntPtr input, IntPtr tau, IntPtr other, [MarshalAs(UnmanagedType.U1)] bool left, [MarshalAs(UnmanagedType.U1)] bool transpose);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_inner(IntPtr input, IntPtr vec2);
@@ -1175,6 +1212,15 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_scatter_add(IntPtr tensor, long dim, IntPtr index, IntPtr source);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_diagonal_scatter(IntPtr tensor, IntPtr source, long offset, long dim1, long dim2);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_select_scatter(IntPtr tensor, IntPtr source, long dim, long index);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_slice_scatter(IntPtr tensor, IntPtr source, long dim, IntPtr start, IntPtr end, long step);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_scatter_add_(IntPtr tensor, long dim, IntPtr index, IntPtr source);
@@ -1214,6 +1260,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_roll(IntPtr tensor, IntPtr shifts, int shLength, IntPtr dims, long dimLength);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_rot90(IntPtr tensor, long k, long dim1, long dim2);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_slice(IntPtr tensor, long dim, long start, long length, long step);
@@ -1496,6 +1545,9 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_cat(IntPtr tensor, int len, long dim);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_cartesian_prod(IntPtr tensor, int len);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_stack(IntPtr tensor, int len, long dim);
@@ -1619,6 +1671,12 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_resolve_conj(IntPtr tensor);
+ [DllImport("LibTorchSharp")]
+ internal static extern long THSTensor_is_neg(IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_resolve_neg(IntPtr tensor);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_bitwise_left_shift(IntPtr tensor, IntPtr other);
@@ -1667,6 +1725,18 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_floor_(IntPtr tensor);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_floor_divide(IntPtr left, IntPtr right);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_floor_divide_(IntPtr left, IntPtr right);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_floor_divide_scalar(IntPtr left, IntPtr right);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_floor_divide_scalar_(IntPtr left, IntPtr right);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_frexp(IntPtr tensor, out IntPtr exponent);
@@ -1853,6 +1923,15 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_sign(IntPtr tensor);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_sign_(IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_sgn(IntPtr tensor);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_sgn_(IntPtr tensor);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_signbit(IntPtr tensor);
@@ -1875,7 +1954,16 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
internal static extern IntPtr THSTensor_trapezoid_dx(IntPtr y, double dx, long dim);
[DllImport("LibTorchSharp")]
- internal static extern IntPtr THSTensor_sign_(IntPtr tensor);
+ internal static extern IntPtr THSTensor_true_divide(IntPtr left, IntPtr right);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_true_divide_(IntPtr left, IntPtr right);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_true_divide_scalar(IntPtr left, IntPtr right);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_true_divide_scalar_(IntPtr left, IntPtr right);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_trunc_(IntPtr tensor);
@@ -1940,6 +2028,12 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_vdot(IntPtr tensor, IntPtr target);
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_dot(IntPtr tensor, IntPtr target);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern IntPtr THSTensor_logdet(IntPtr tensor);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_lu(IntPtr tensor, [MarshalAs(UnmanagedType.U1)] bool pivot, [MarshalAs(UnmanagedType.U1)] bool get_infos, out IntPtr infos, out IntPtr pivots);
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs
index cdc6a4d90..3c256f5b6 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorch.cs
@@ -17,6 +17,12 @@ internal static partial class LibTorchSharp
[DllImport("LibTorchSharp")]
internal static extern byte THSTorch_scalar_type(IntPtr value);
+ [DllImport("LibTorchSharp")]
+ internal static extern int THSTorch_can_cast(int type1, int type2);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern int THSTorch_promote_types(int type1, int type2);
+
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTorch_uint8_to_scalar(byte value);
@@ -85,5 +91,17 @@ internal static partial class LibTorchSharp
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTorch_lstsq(IntPtr handle, IntPtr b, out IntPtr qr);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern int THSTorch_get_num_threads();
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSTorch_set_num_threads(int threads);
+
+ [DllImport("LibTorchSharp")]
+ internal static extern int THSTorch_get_num_interop_threads();
+
+ [DllImport("LibTorchSharp")]
+ internal static extern void THSTorch_set_num_interop_threads(int threads);
}
}
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorchCuda.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorchCuda.cs
index b9435d713..b39478e34 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTorchCuda.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTorchCuda.cs
@@ -7,9 +7,11 @@ namespace TorchSharp.PInvoke
internal static partial class LibTorchSharp
{
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSTorchCuda_is_available();
[DllImport("LibTorchSharp")]
+ [return: MarshalAs(UnmanagedType.U1)]
internal static extern bool THSTorchCuda_cudnn_is_available();
[DllImport("LibTorchSharp")]
diff --git a/src/TorchSharp/Tensor/Enums/compute_mode.cs b/src/TorchSharp/Tensor/Enums/compute_mode.cs
index 5e1a9c833..61e8fba63 100644
--- a/src/TorchSharp/Tensor/Enums/compute_mode.cs
+++ b/src/TorchSharp/Tensor/Enums/compute_mode.cs
@@ -1,10 +1,10 @@
-// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
namespace TorchSharp
{
public enum compute_mode
{
- use_mm_for_euclid_dist_if_necessary,
- use_mm_for_euclid_dist,
- donot_use_mm_for_euclid_dist
+ use_mm_for_euclid_dist_if_necessary = 0,
+ use_mm_for_euclid_dist = 1,
+ donot_use_mm_for_euclid_dist = 2
}
}
\ No newline at end of file
diff --git a/src/TorchSharp/Tensor/Tensor.Factories.cs b/src/TorchSharp/Tensor/Tensor.Factories.cs
index 7fd36f486..42956d2a3 100644
--- a/src/TorchSharp/Tensor/Tensor.Factories.cs
+++ b/src/TorchSharp/Tensor/Tensor.Factories.cs
@@ -2794,7 +2794,7 @@ public static Tensor sparse(Tensor indices, Tensor values, long[] size, ScalarTy
}
///
- /// onstructs a complex tensor with its real part equal to real and its imaginary part equal to imag.
+ /// Constructs a complex tensor with its real part equal to real and its imaginary part equal to imag.
///
public static Tensor complex(Tensor real, Tensor imag)
{
diff --git a/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs b/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs
index 09300e262..83378431a 100644
--- a/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs
+++ b/src/TorchSharp/Tensor/Tensor.LinearAlgebra.cs
@@ -68,6 +68,41 @@ public Tensor det()
return linalg.det(this);
}
+ ///
+ /// Calculates log determinant of a square matrix or batches of square matrices.
+ ///
+ ///
+ public Tensor logdet()
+ {
+ var shape = this.shape;
+ var len = shape.Length;
+ if (shape[len - 1] != shape[len - 2]) throw new ArgumentException("The input tensor is not square");
+
+ var res = THSTensor_logdet(Handle);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
+
+
+ ///
+ /// This is a low-level function for calling LAPACK’s geqrf directly.
+ /// This function returns a namedtuple (a, tau) as defined in LAPACK documentation for geqrf.
+ ///
+ ///
+ /// Computes a QR decomposition of input. Both Q and R matrices are stored in the same output tensor a.
+ /// The elements of R are stored on and above the diagonal. Elementary reflectors (or Householder vectors)
+ /// implicitly defining matrix Q are stored below the diagonal. The results of this function can be used
+ /// together with torch.linalg.householder_product() to obtain the Q matrix or with torch.ormqr(), which
+ /// uses an implicit representation of the Q matrix, for an efficient matrix-matrix multiplication.
+ ///
+ public (Tensor a, Tensor tau) geqrf()
+ {
+ var res = THSTensor_geqrf(Handle, out var tau);
+ if (res == IntPtr.Zero || tau == IntPtr.Zero)
+ torch.CheckForErrors();
+ return (new Tensor(res), new Tensor(tau));
+ }
+
///
/// Matrix product of two tensors.
///
@@ -138,7 +173,6 @@ public Tensor matrix_power(int n)
///
/// Computes the dot product of two 1D tensors.
///
- ///
///
///
/// The vdot(a, b) function handles complex numbers differently than dot(a, b).
@@ -152,6 +186,18 @@ public Tensor vdot(Tensor target)
return new Tensor(res);
}
+ ///
+ /// Computes the dot product of two 1D tensors.
+ ///
+ ///
+ public Tensor dot(Tensor target)
+ {
+ if (shape.Length != 1 || target.shape.Length != 1 || shape[0] != target.shape[0]) throw new InvalidOperationException("dot arguments must have the same shape.");
+ var res = THSTensor_dot(Handle, target.Handle);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
+
///
/// Computes the pseudoinverse (Moore-Penrose inverse) of a matrix.
///
@@ -166,6 +212,22 @@ public Tensor pinverse(double rcond = 1e-15, bool hermitian = false)
CheckForErrors();
return new Tensor(res);
}
+
+ ///
+ /// Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix.
+ ///
+ /// Tensor of shape (*, min(mn, k)) where * is zero or more batch dimensions.
+ /// Tensor of shape (*, m, n) where * is zero or more batch dimensions.
+ /// Controls the order of multiplication.
+ /// Controls whether the matrix Q is conjugate transposed or not.
+ ///
+ public Tensor ormqr(Tensor tau, Tensor other, bool left = true, bool transpose = false)
+ {
+ var res = THSTensor_ormqr(Handle, tau.handle, other.Handle, left, transpose);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
}
}
}
\ No newline at end of file
diff --git a/src/TorchSharp/Tensor/Tensor.Math.cs b/src/TorchSharp/Tensor/Tensor.Math.cs
index 9d0cb8bda..4ac2979c5 100644
--- a/src/TorchSharp/Tensor/Tensor.Math.cs
+++ b/src/TorchSharp/Tensor/Tensor.Math.cs
@@ -570,7 +570,7 @@ public Tensor conj_physical_()
public bool is_conj()
{
var res = THSTensor_is_conj(Handle);
- CheckForErrors();
+ if (res == -1) CheckForErrors();
return res != 0;
}
@@ -587,6 +587,29 @@ public Tensor resolve_conj()
return new Tensor(res);
}
+ ///
+ /// Returns true if the input's negative bit is set to True.
+ ///
+ public bool is_neg()
+ {
+ var res = THSTensor_is_neg(Handle);
+ if (res == -1) CheckForErrors();
+ return res != 0;
+ }
+
+ ///
+ /// Returns a new tensor with materialized negation if input’s negative bit is set to True, else returns input.
+ /// The output tensor will always have its negative bit set to False.
+ ///
+ ///
+ public Tensor resolve_neg()
+ {
+ var res = THSTensor_resolve_neg(Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
///
/// Returns a tuple (values, indices) where values is the cumulative maximum of elements of input in the dimension dim.
/// Indices is the index location of each maximum value found in the dimension dim.
@@ -825,6 +848,54 @@ public Tensor floor_()
return new Tensor(res);
}
+ ///
+ /// Computes input divided by other, elementwise, and floors the result.
+ ///
+ /// the divisor
+ public Tensor floor_divide(Tensor other)
+ {
+ var res = THSTensor_floor_divide(Handle, other.Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
+ ///
+ /// Computes input divided by other, elementwise, and floors the result.
+ ///
+ /// the divisor
+ public Tensor floor_divide(Scalar other)
+ {
+ var res = THSTensor_floor_divide_scalar(Handle, other.Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
+ ///
+ /// Computes input divided by other, elementwise, and floors the result, computation done in place.
+ ///
+ /// the divisor
+ public Tensor floor_divide_(Tensor other)
+ {
+ var res = THSTensor_floor_divide_(Handle, other.Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
+ ///
+ /// Computes input divided by other, elementwise, and floors the result, computation done in place.
+ ///
+ /// the divisor
+ public Tensor floor_divide_(Scalar other)
+ {
+ var res = THSTensor_floor_divide_scalar_(Handle, other.Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
///
/// Computes the element-wise remainder of division.
///
@@ -1540,6 +1611,36 @@ public Tensor sign_()
return new Tensor(res);
}
+ ///
+ /// This function is an extension of torch.sign() to complex tensors.
+ /// It computes a new tensor whose elements have the same angles as the corresponding
+ /// elements of input and absolute values (i.e. magnitudes) of one for complex tensors
+ /// and is equivalent to torch.sign() for non-complex tensors.
+ ///
+ ///
+ public Tensor sgn()
+ {
+ var res = THSTensor_sgn(Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
+ ///
+ /// This function is an extension of torch.sign() to complex tensors.
+ /// It computes a new tensor whose elements have the same angles as the corresponding
+ /// elements of input and absolute values (i.e. magnitudes) of one for complex tensors
+ /// and is equivalent to torch.sign() for non-complex tensors. In-place version.
+ ///
+ ///
+ public Tensor sgn_()
+ {
+ var res = THSTensor_sgn_(Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
///
/// Tests if each element of input has its sign bit set (is less than zero) or not.
///
@@ -1614,7 +1715,7 @@ public Tensor sub_(Scalar target)
///
public Tensor cumulative_trapezoid(double dx = 1, long dim = -1)
{
- IntPtr res = THSTensor_trapezoid_dx(Handle, dx, dim);
+ IntPtr res = THSTensor_cumulative_trapezoid_dx(Handle, dx, dim);
if (res == IntPtr.Zero) { CheckForErrors(); }
return new Tensor(res);
}
@@ -1628,7 +1729,7 @@ public Tensor cumulative_trapezoid(double dx = 1, long dim = -1)
///
public Tensor cumulative_trapezoid(Tensor x, long dim = -1)
{
- IntPtr res = THSTensor_trapezoid_x(Handle, x.Handle, dim);
+ IntPtr res = THSTensor_cumulative_trapezoid_x(Handle, x.Handle, dim);
if (res == IntPtr.Zero) { CheckForErrors(); }
return new Tensor(res);
}
@@ -1661,6 +1762,54 @@ public Tensor trapezoid(Tensor x, long dim = -1)
return new Tensor(res);
}
+ ///
+ /// Computes input divided by other, elementwise, and floors the result.
+ ///
+ /// the divisor
+ public Tensor true_divide(Tensor other)
+ {
+ var res = THSTensor_true_divide(Handle, other.Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
+ ///
+ /// Computes input divided by other, elementwise, and floors the result.
+ ///
+ /// the divisor
+ public Tensor true_divide(Scalar other)
+ {
+ var res = THSTensor_true_divide_scalar(Handle, other.Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
+ ///
+ /// Computes input divided by other, elementwise, and floors the result, computation done in place.
+ ///
+ /// the divisor
+ public Tensor true_divide_(Tensor other)
+ {
+ var res = THSTensor_true_divide_(Handle, other.Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
+ ///
+ /// Computes input divided by other, elementwise, and floors the result, computation done in place.
+ ///
+ /// the divisor
+ public Tensor true_divide_(Scalar other)
+ {
+ var res = THSTensor_true_divide_scalar_(Handle, other.Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
///
/// Returns a new tensor with the truncated integer values of the elements of input.
///
diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs
index 251a61bfc..512dc4c11 100644
--- a/src/TorchSharp/Tensor/Tensor.cs
+++ b/src/TorchSharp/Tensor/Tensor.cs
@@ -238,10 +238,30 @@ internal IntPtr MoveHandle()
public bool is_integral() => torch.is_integral(dtype);
+ ///
+ /// Returns True if the data type of input is a floating point data type.
+ ///
public bool is_floating_point() => torch.is_floating_point(dtype);
+ ///
+ /// Returns True if the data type of input is a complex data type i.e., one of torch.complex64, and torch.complex128.
+ ///
public bool is_complex() => torch.is_complex(dtype);
+ ///
+ /// Returns True if the input is a single element tensor which is not equal to zero after type conversions,
+ /// i.e. not equal to torch.tensor([0.]) or torch.tensor([0]) or torch.tensor([False]).
+ /// Throws an InvalidOperationException if torch.numel() != 1.
+ ///
+ public bool is_nonzero()
+ {
+ if (numel() != 1)
+ throw new InvalidOperationException("is_nonzero() called on non-singleton tensor");
+ var res = LibTorchSharp.THSTensor_is_nonzero(Handle);
+ CheckForErrors();
+ return res != 0;
+ }
+
public bool is_cuda => device.type == DeviceType.CUDA;
public bool is_meta => device.type == DeviceType.META;
@@ -1511,6 +1531,21 @@ public Tensor take(Tensor index)
return new Tensor(res);
}
+ ///
+ /// Returns a tensor containing the indices of all non-zero elements of input.
+ /// Each row in the result contains the indices of a non-zero element in input.
+ /// The result is sorted lexicographically, with the last index changing the fastest (C-style).
+ /// If input has n dimensions, then the resulting indices tensor out is of size (z×n), where
+ /// z is the total number of non-zero elements in the input tensor.
+ ///
+ public Tensor argwhere()
+ {
+ var res = LibTorchSharp.THSTensor_argwhere(Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
///
/// Selects values from input at the 1-dimensional indices from indices along the given dim.
///
@@ -1973,6 +2008,17 @@ public Tensor transpose(long dim0, long dim1)
return new Tensor(res);
}
+ ///
+ /// Returns a view of the tensor conjugated and with the last two dimensions transposed.
+ ///
+ public Tensor adjoint()
+ {
+ var res = LibTorchSharp.THSTensor_adjoint(Handle);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
///
/// Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices input, the other elements of the result tensor out are set to 0.
/// The lower triangular part of the matrix is defined as the elements on and below the diagonal.
@@ -3038,6 +3084,29 @@ public Tensor trace()
return new Tensor(res);
}
+ ///
+ /// Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) are filled by input.
+ /// To facilitate creating batched diagonal matrices, the 2D planes formed by the last two dimensions of the returned tensor are chosen by default.
+ ///
+ /// The argument offset controls which diagonal to consider:
+ /// If offset is equal to 0, it is the main diagonal.
+ /// If offset is greater than 0, it is above the main diagonal.
+ /// If offset is less than 0, it is below the main diagonal.
+ ///
+ /// The size of the new matrix will be calculated to make the specified diagonal of the size of the last input dimension.Note that for offset other than 0,
+ ///
+ /// the order of dim1 and dim2 matters.Exchanging them is equivalent to changing the sign of offset.
+ ///
+ /// Which diagonal to consider.
+ /// First dimension with respect to which to take diagonal.
+ /// Second dimension with respect to which to take diagonal
+ public Tensor diag_embed(long offset = 0L, long dim1 = -2L, long dim2 = -1L)
+ {
+ var res = LibTorchSharp.THSTensor_diag_embed(Handle, offset, dim1, dim2);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
+
///
/// If input is a vector (1-D tensor), then returns a 2-D square tensor with the elements of input as the diagonal.
/// If input is a matrix (2-D tensor), then returns a 2-D tensor with diagonal elements equal to a flattened input.
@@ -4103,6 +4172,13 @@ public Tensor outer(Tensor vec2)
return new Tensor(res);
}
+ ///
+ /// Outer product of input and vec2.
+ ///
+ /// 1-D input vector.
+ /// If input is a vector of size n and vec2 is a vector of size m, then out must be a matrix of size n×m.
+ public Tensor ger(Tensor vec2) => outer(vec2);
+
///
/// Computes the dot product for 1D tensors.
/// For higher dimensions, sums the product of elements from input and other along their last dimension.
@@ -5477,6 +5553,47 @@ public Tensor scatter_add_(long dim, Tensor index, Tensor src)
return new Tensor(res);
}
+
+ public Tensor diagonal_scatter(Tensor src, long offset = 0L, long dim1 = 0L, long dim2 = 1L)
+ {
+ var res = LibTorchSharp.THSTensor_diagonal_scatter(Handle, src.Handle, offset, dim1, dim2);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
+
+ ///
+ /// Embeds the values of the src tensor into input at the given index. This function returns a tensor with fresh storage; it does not create a view.
+ ///
+ /// The tensor to embed into 'this'
+ /// The dimension to insert the slice into
+ /// The index to select with
+ /// This function returns a tensor with fresh storage; it does not create a view.
+ public Tensor select_scatter(Tensor src, long dim, long index)
+ {
+ var res = LibTorchSharp.THSTensor_select_scatter(Handle, src.Handle, dim, index);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
+
+ ///
+ /// Embeds the values of the src tensor into input at the given dimension.
+ ///
+ /// The tensor to embed into 'this'.
+ /// The dimension to insert the slice into
+ /// The start index of where to insert the slice
+ /// The end index of where to insert the slice
+ /// How many elements to skip
+ public unsafe Tensor slice_scatter(Tensor src, long dim = 0L, long? start = null, long? end = null, long step = 1L)
+ {
+ var _start = start.HasValue ? new long[] { start.Value } : null;
+ var _end = end.HasValue ? new long[] { end.Value } : null;
+ fixed (long* pstart = _start, pend = _end) {
+ var res = LibTorchSharp.THSTensor_slice_scatter(Handle, src.Handle, dim, (IntPtr)pstart, (IntPtr)pend, step);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
+ }
+
///
/// Gathers values along an axis specified by dim.
///
@@ -5654,6 +5771,25 @@ public Tensor roll((long, long) shifts, (long, long) dims)
///
public Tensor roll(long[] shifts) => _roll(shifts, new long[] { 0 });
+ ///
+ /// Rotate a n-D tensor by 90 degrees in the plane specified by dims axis.
+ /// Rotation direction is from the first towards the second axis if k is greater than 0,
+ /// and from the second towards the first for k less than 0.
+ ///
+ /// The number of times to rotate.
+ /// Axes to rotate
+ public Tensor rot90(long k = 1, (long, long)? dims = null)
+ {
+ if (!dims.HasValue) {
+ dims = (0, 1);
+ }
+
+ var res =
+ LibTorchSharp.THSTensor_rot90(Handle, k, dims.Value.Item1, dims.Value.Item2);
+ if (res == IntPtr.Zero) { CheckForErrors(); }
+ return new Tensor(res);
+ }
+
///
/// Roll the tensor along the given dimension(s).
/// Elements that are shifted beyond the last position are re-introduced at the first position.
@@ -6029,10 +6165,10 @@ public static implicit operator Tensor(Scalar scalar)
///
///
public string ToString(bool disamb,
- string fltFormat = "g5",
- int width = 100,
+ string? fltFormat = null,
+ int? width = null,
CultureInfo? cultureInfo = null,
- string newLine = "") => disamb ? ToString(torch.TensorStringStyle, fltFormat, width, cultureInfo, newLine) : ToMetadataString();
+ string? newLine = null) => disamb ? ToString(torch.TensorStringStyle, fltFormat, width, cultureInfo, newLine) : ToMetadataString();
///
/// Tensor-specific ToString()
@@ -6046,11 +6182,15 @@ public string ToString(bool disamb,
/// The newline string to use, defaults to system default.
///
public string ToString(TensorStringStyle style,
- string fltFormat = "g5",
- int width = 100,
+ string? fltFormat = null,
+ int? width = null,
CultureInfo? cultureInfo = null,
- string newLine = "")
+ string? newLine = null)
{
+ var w = width.HasValue ? width.Value : torch.lineWidth;
+ var nl = newLine is null ? torch.newLine : newLine;
+ var fmt = fltFormat is null ? torch.floatFormat : fltFormat;
+
if (String.IsNullOrEmpty(newLine))
newLine = Environment.NewLine;
@@ -6058,10 +6198,10 @@ public string ToString(TensorStringStyle style,
return ToMetadataString();
return style switch {
- TensorStringStyle.Default => ToString(torch.TensorStringStyle, fltFormat, width, cultureInfo, newLine),
+ TensorStringStyle.Default => ToString(torch.TensorStringStyle, fltFormat, width, cultureInfo, nl),
TensorStringStyle.Metadata => ToMetadataString(),
- TensorStringStyle.Julia => ToJuliaString(fltFormat, width, cultureInfo, newLine),
- TensorStringStyle.Numpy => ToNumpyString(this, ndim, true, fltFormat, cultureInfo, newLine),
+ TensorStringStyle.Julia => ToJuliaString(fmt, w, cultureInfo, nl),
+ TensorStringStyle.Numpy => ToNumpyString(this, ndim, true, fmt, cultureInfo, nl),
_ => throw new InvalidEnumArgumentException($"Unsupported tensor string style: {style}")
};
}
@@ -6716,8 +6856,8 @@ public static bool is_complex(ScalarType type)
}
public static bool is_integral(Tensor t) => is_integral(t.dtype);
- public static bool is_floating_point(Tensor t) => is_floating_point(t.dtype);
- public static bool is_complex(Tensor t) => is_complex(t.dtype);
+ //public static bool is_floating_point(Tensor t) => is_floating_point(t.dtype);
+ //public static bool is_complex(Tensor t) => is_complex(t.dtype);
public static ScalarType @bool = ScalarType.Bool;
diff --git a/src/TorchSharp/Tensor/TensorExtensionMethods.cs b/src/TorchSharp/Tensor/TensorExtensionMethods.cs
index 464072e65..534ac25cd 100644
--- a/src/TorchSharp/Tensor/TensorExtensionMethods.cs
+++ b/src/TorchSharp/Tensor/TensorExtensionMethods.cs
@@ -35,10 +35,51 @@ public static TensorStringStyle TensorStringStyle {
}
}
+ ///
+ /// Set options for printing.
+ ///
+ /// Number of digits of precision for floating point output.
+ /// The number of characters per line for the purpose of inserting line breaks (default = 100).
+ /// The string to use to represent new-lines. Starts out as 'Environment.NewLine'
+ /// Enable scientific notation.
+ public static void set_printoptions(
+ int precision,
+ int linewidth = 100,
+ string newLine = "\n",
+ bool sci_mode = false)
+ {
+ torch.floatFormat = sci_mode ? $"E{precision}" : $"F{precision}";
+ torch.newLine = newLine;
+ torch.lineWidth = linewidth;
+ }
+
+ ///
+ /// Set options for printing.
+ ///
+ ///
+ /// The format string to use for floating point values.
+ /// See: https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings
+ ///
+ /// The number of characters per line for the purpose of inserting line breaks (default = 100).
+ /// The string to use to represent new-lines. Starts out as 'Environment.NewLine'
+ public static void set_printoptions(
+ string floatFormat = "g5",
+ int linewidth = 100,
+ string newLine = "\n")
+ {
+ torch.floatFormat = floatFormat;
+ torch.newLine = newLine;
+ torch.lineWidth = linewidth;
+ }
+
public const TensorStringStyle julia = TensorStringStyle.Julia;
public const TensorStringStyle numpy = TensorStringStyle.Numpy;
private static TensorStringStyle _style = TensorStringStyle.Julia;
+
+ internal static string floatFormat = "g5";
+ internal static string newLine = Environment.NewLine;
+ internal static int lineWidth = 100;
}
///
@@ -60,7 +101,10 @@ public static Modules.Parameter AsParameter(this Tensor tensor)
/// Get a string representation of the tensor.
///
/// The input tensor.
- /// The format string to use for floating point values.
+ ///
+ /// The format string to use for floating point values.
+ /// See: https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings
+ ///
/// The width of each line of the output string.
/// The newline string to use, defaults to system default.
/// The culture info to be used when formatting the numbers.
@@ -74,7 +118,7 @@ public static Modules.Parameter AsParameter(this Tensor tensor)
///
/// Primarily intended for use in interactive notebooks.
///
- public static string str(this Tensor tensor, string fltFormat = "g5", int width = 100, string newLine = "\n", CultureInfo? cultureInfo = null, TensorStringStyle style = TensorStringStyle.Default)
+ public static string str(this Tensor tensor, string? fltFormat = null, int? width = null, string? newLine = "\n", CultureInfo? cultureInfo = null, TensorStringStyle style = TensorStringStyle.Default)
{
return tensor.ToString(style, fltFormat, width, cultureInfo, newLine);
}
@@ -83,7 +127,10 @@ public static string str(this Tensor tensor, string fltFormat = "g5", int width
/// Get a Julia-style string representation of the tensor.
///
/// The input tensor.
- /// The format string to use for floating point values.
+ ///
+ /// The format string to use for floating point values.
+ /// See: https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings
+ ///
/// The width of each line of the output string.
/// The newline string to use, defaults to system default.
/// The culture info to be used when formatting the numbers.
@@ -95,7 +142,7 @@ public static string str(this Tensor tensor, string fltFormat = "g5", int width
///
/// Primarily intended for use in interactive notebooks.
///
- public static string jlstr(this Tensor tensor, string fltFormat = "g5", int width = 100, string newLine = "\n", CultureInfo? cultureInfo = null)
+ public static string jlstr(this Tensor tensor, string? fltFormat = null, int? width = null, string? newLine = "\n", CultureInfo? cultureInfo = null)
{
return tensor.ToString(TensorStringStyle.Julia, fltFormat, width, cultureInfo, newLine);
}
@@ -122,7 +169,10 @@ public static string metastr(this Tensor tensor)
/// Get a numpy-style string representation of the tensor.
///
/// The input tensor.
- /// The format string to use for floating point values.
+ ///
+ /// The format string to use for floating point values.
+ /// See: https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings
+ ///
/// The width of each line of the output string.
/// The newline string to use, defaults to system default.
/// The culture info to be used when formatting the numbers.
@@ -144,7 +194,10 @@ public static string npstr(this Tensor tensor, string fltFormat = "g5", int widt
/// interactive notebook use, primarily.
///
/// The input tensor.
- /// The format string to use for floating point values.
+ ///
+ /// The format string to use for floating point values.
+ /// See: https://learn.microsoft.com/en-us/dotnet/standard/base-types/standard-numeric-format-strings
+ ///
/// The width of each line of the output string.
/// The newline string to use, defaults to system default.
/// The culture info to be used when formatting the numbers.
diff --git a/src/TorchSharp/Tensor/torch.BlasAndLapackOperations.cs b/src/TorchSharp/Tensor/torch.BlasAndLapackOperations.cs
index 6cf7ea450..4d5ba9d67 100644
--- a/src/TorchSharp/Tensor/torch.BlasAndLapackOperations.cs
+++ b/src/TorchSharp/Tensor/torch.BlasAndLapackOperations.cs
@@ -127,8 +127,7 @@ public static Tensor addbmm_(Tensor input, Tensor batch1, Tensor batch2, float b
public static Tensor bmm(Tensor input, Tensor batch2) => input.bmm(batch2);
// https://pytorch.org/docs/stable/generated/torch.chain_matmul
- [Obsolete("not implemented")]
- public static Tensor chain_matmul(params Tensor[] matrices) => throw new NotImplementedException();
+ public static Tensor chain_matmul(params Tensor[] matrices) => torch.linalg.multi_dot(matrices);
// https://pytorch.org/docs/stable/generated/torch.cholesky
@@ -151,19 +150,38 @@ public static Tensor cholesky_solve(Tensor input, Tensor input2, bool upper = fa
=> input.cholesky_solve(input2, upper);
// https://pytorch.org/docs/stable/generated/torch.dot
- [Obsolete("not implemented", true)]
- public static Tensor dot(Tensor input, Tensor other) => throw new NotImplementedException();
+ ///
+ /// Computes the dot product of two 1D tensors.
+ ///
+ public static Tensor dot(Tensor input, Tensor other) => input.dot(other);
// https://pytorch.org/docs/stable/generated/torch.eig
+ [Obsolete("Method removed in Pytorch. Please use the `torch.linalg.eig` function instead.", true)]
public static (Tensor eigenvalues, Tensor eigenvectors) eig(Tensor input, bool eigenvectors = false) => throw new NotImplementedException();
// https://pytorch.org/docs/stable/generated/torch.geqrf
- [Obsolete("not implemented", true)]
- public static Tensor geqrf(Tensor input) => throw new NotImplementedException();
+ ///
+ /// This is a low-level function for calling LAPACK’s geqrf directly.
+ /// This function returns a namedtuple (a, tau) as defined in LAPACK documentation for geqrf.
+ ///
+ /// The input tensor.
+ ///
+ /// Computes a QR decomposition of input. Both Q and R matrices are stored in the same output tensor a.
+ /// The elements of R are stored on and above the diagonal. Elementary reflectors (or Householder vectors)
+ /// implicitly defining matrix Q are stored below the diagonal. The results of this function can be used
+ /// together with torch.linalg.householder_product() to obtain the Q matrix or with torch.ormqr(), which
+ /// uses an implicit representation of the Q matrix, for an efficient matrix-matrix multiplication.
+ ///
+ public static (Tensor a, Tensor tau) geqrf(Tensor input) => input.geqrf();
// https://pytorch.org/docs/stable/generated/torch.ger
- [Obsolete("not implemented", true)]
- public static Tensor ger(Tensor input, Tensor vec2) => throw new NotImplementedException();
+ ///
+ /// Outer product of input and vec2.
+ ///
+ /// The input vector.
+ /// 1-D input vector.
+ /// If input is a vector of size n and vec2 is a vector of size m, then out must be a matrix of size n×m.
+ public static Tensor ger(Tensor input, Tensor vec2) => input.ger(vec2);
// https://pytorch.org/docs/stable/generated/torch.inner
///
@@ -186,12 +204,10 @@ public static Tensor cholesky_solve(Tensor input, Tensor input2, bool upper = fa
public static Tensor det(Tensor input) => input.det();
// https://pytorch.org/docs/stable/generated/torch.logdet
- [Obsolete("not implemented", true)]
- public static Tensor logdet(Tensor input) => throw new NotImplementedException();
+ public static Tensor logdet(Tensor input) => input.logdet();
// https://pytorch.org/docs/stable/generated/torch.slogdet
- [Obsolete("not implemented", true)]
- public static (Tensor res, Tensor logabsdet) slogdet(Tensor A) => throw new NotImplementedException();
+ public static (Tensor res, Tensor logabsdet) slogdet(Tensor A) => torch.linalg.slogdet(A);
// https://pytorch.org/docs/stable/generated/torch.lstsq
///
@@ -285,7 +301,7 @@ public static (Tensor P, Tensor? L, Tensor? U) lu_unpack(Tensor LU_data, Tensor
public static Tensor matrix_power(Tensor input, int n) => input.matrix_power(n);
// https://pytorch.org/docs/stable/generated/torch.matrix_rank
- [Obsolete("not implemented", true)]
+ [Obsolete("This function was deprecated since version 1.9 and is now removed. Please use the 'torch.linalg.matrix_rank' function instead.", true)]
public static Tensor matrix_rank(Tensor input, float? tol = null, bool symmetric = false) => throw new NotImplementedException();
// https://pytorch.org/docs/stable/generated/torch.matrix_exp
@@ -310,12 +326,23 @@ public static (Tensor P, Tensor? L, Tensor? U) lu_unpack(Tensor LU_data, Tensor
public static Tensor mv(Tensor input, Tensor target) => input.mv(target);
// https://pytorch.org/docs/stable/generated/torch.orgqr
- [Obsolete("not implemented", true)]
- public static Tensor orgqr(Tensor input, Tensor tau) => throw new NotImplementedException();
+ ///
+ /// Computes the first n columns of a product of Householder matrices.
+ ///
+ /// tensor of shape (*, m, n) where * is zero or more batch dimensions.
+ /// tensor of shape (*, k) where * is zero or more batch dimensions.
+ public static Tensor orgqr(Tensor input, Tensor tau) => linalg.householder_product(input, tau);
// https://pytorch.org/docs/stable/generated/torch.ormqr
- [Obsolete("not implemented", true)]
- public static Tensor ormqr(Tensor input, Tensor tau, Tensor other, bool left=true, bool transpose=false) => throw new NotImplementedException();
+ ///
+ /// Computes the matrix-matrix multiplication of a product of Householder matrices with a general matrix.
+ ///
+ /// Tensor of shape (*, mn, k) where * is zero or more batch dimensions and mn equals to m or n depending on the left.
+ /// Tensor of shape (*, min(mn, k)) where * is zero or more batch dimensions.
+ /// Tensor of shape (*, m, n) where * is zero or more batch dimensions.
+ /// Controls the order of multiplication.
+ /// Controls whether the matrix Q is conjugate transposed or not.
+ public static Tensor ormqr(Tensor input, Tensor tau, Tensor other, bool left=true, bool transpose=false) => input.ormqr(tau, other, left, transpose);
// https://pytorch.org/docs/stable/generated/torch.outer
///
@@ -337,23 +364,25 @@ public static (Tensor P, Tensor? L, Tensor? U) lu_unpack(Tensor LU_data, Tensor
public static Tensor pinverse(Tensor input, double rcond = 1e-15, bool hermitian = false) => input.pinverse(rcond, hermitian);
// https://pytorch.org/docs/stable/generated/torch.qr
- [Obsolete("not implemented", true)]
- public static Tensor qr(Tensor input, bool some=true) => throw new NotImplementedException();
+ [Obsolete("torch.qr() is deprecated in favor of torch.linalg.qr() and will be removed in a future PyTorch release.", true)]
+ public static Tensor qr(Tensor input, bool some = true) => throw new NotImplementedException();
// https://pytorch.org/docs/stable/generated/torch.svd
- [Obsolete("not implemented", true)]
+ [Obsolete("torch.qr() is deprecated in favor of torch.linalg.svd() and will be removed in a future PyTorch release.", true)]
public static Tensor svd(Tensor input, bool some=true, bool compute_uv=true) => throw new NotImplementedException();
// https://pytorch.org/docs/stable/generated/torch.svd_lowrank
+ // NOTE TO SELF: there's no native method for this. PyTorch implements it in Python.
[Obsolete("not implemented", true)]
public static Tensor svd_lowrank(Tensor A, int q=6, int niter=2,Tensor? M=null) => throw new NotImplementedException();
// https://pytorch.org/docs/stable/generated/torch.pca_lowrank
+ // NOTE TO SELF: there's no native method for this. PyTorch implements it in Python.
[Obsolete("not implemented", true)]
- public static Tensor pca_lowrank(Tensor A, int? q=null, bool center=true, int niter=2) => throw new NotImplementedException();
+ public static Tensor pca_lowrank(Tensor A, int q=6, bool center=true, int niter=2) => throw new NotImplementedException();
// https://pytorch.org/docs/stable/generated/torch.symeig
- [Obsolete("not implemented", true)]
+ [Obsolete("torch.symeig() is deprecated in favor of torch.linalg.eigh() and will be removed in a future PyTorch release", true)]
public static Tensor symeig(Tensor input, bool eigenvectors = false, bool upper = true) => throw new NotImplementedException();
// https://pytorch.org/docs/stable/generated/torch.lobpcg
@@ -385,25 +414,64 @@ public static Tensor softmax(Tensor input, int dim, ScalarType? dtype = null)
=> torch.special.softmax(input, dim, dtype);
// https://pytorch.org/docs/stable/generated/torch.trapz
- [Obsolete("not implemented", true)]
- public static Tensor trapz(Tensor input, Tensor x, long dim = -1) => throw new NotImplementedException();
+ ///
+ /// Computes the trapezoidal rule along dim. By default the spacing between elements is assumed
+ /// to be 1, but dx can be used to specify a different constant spacing, and x can be used to specify arbitrary spacing along dim.
+ ///
+ /// Values to use when computing the trapezoidal rule.
+ /// Defines spacing between values as specified above.
+ /// The dimension along which to compute the trapezoidal rule. The last (inner-most) dimension by default.
+ public static Tensor trapz(Tensor y, Tensor x, long dim = -1) => trapezoid(y, x, dim);
- [Obsolete("not implemented", true)]
- public static Tensor trapz(Tensor input, double dx = 1, long dim = -1) => throw new NotImplementedException();
+ ///
+ /// Computes the trapezoidal rule along dim. By default the spacing between elements is assumed
+ /// to be 1, but dx can be used to specify a different constant spacing, and x can be used to specify arbitrary spacing along dim.
+ ///
+ /// Values to use when computing the trapezoidal rule.
+ /// Constant spacing between values.
+ /// The dimension along which to compute the trapezoidal rule. The last (inner-most) dimension by default.
+ public static Tensor trapz(Tensor y, double dx = 1, long dim = -1) => trapezoid(y, dx, dim);
// https://pytorch.org/docs/stable/generated/torch.trapezoid
- [Obsolete("not implemented", true)]
- public static Tensor trapezoid(Tensor input, Tensor x, long dim = -1) => throw new NotImplementedException();
+ ///
+ /// Computes the trapezoidal rule along dim. By default the spacing between elements is assumed
+ /// to be 1, but dx can be used to specify a different constant spacing, and x can be used to specify arbitrary spacing along dim.
+ ///
+ /// Values to use when computing the trapezoidal rule.
+ /// Defines spacing between values as specified above.
+ /// The dimension along which to compute the trapezoidal rule. The last (inner-most) dimension by default.
+ public static Tensor trapezoid(Tensor y, Tensor x, long dim = -1) => y.trapezoid(x, dim);
- [Obsolete("not implemented", true)]
- public static Tensor trapezoid(Tensor input, double dx = 1, long dim = -1) => throw new NotImplementedException();
+ ///
+ /// Computes the trapezoidal rule along dim. By default the spacing between elements is assumed
+ /// to be 1, but dx can be used to specify a different constant spacing, and x can be used to specify arbitrary spacing along dim.
+ ///
+ /// Values to use when computing the trapezoidal rule.
+ /// Constant spacing between values.
+ /// The dimension along which to compute the trapezoidal rule. The last (inner-most) dimension by default.
+ public static Tensor trapezoid(Tensor y, double dx = 1, long dim = -1) => y.trapezoid(dx, dim);
// https://pytorch.org/docs/stable/generated/torch.cumulative_trapezoid
- [Obsolete("not implemented", true)]
- public static Tensor cumulative_trapezoid(Tensor input, Tensor x, long dim = -1) => throw new NotImplementedException();
+ ///
+ /// Cumulatively computes the trapezoidal rule along dim. By default the spacing between elements is assumed
+ /// to be 1, but dx can be used to specify a different constant spacing, and x can be used to specify arbitrary spacing along dim.
+ ///
+ /// Values to use when computing the trapezoidal rule.
+ /// Defines spacing between values as specified above.
+ /// The dimension along which to compute the trapezoidal rule. The last (inner-most) dimension by default.
+ public static Tensor cumulative_trapezoid(Tensor y, Tensor x, long dim = -1) => y.cumulative_trapezoid(x, dim);
+
+ ///
+ /// Cumulatively computes the trapezoidal rule along dim. By default the spacing between elements is assumed
+ /// to be 1, but dx can be used to specify a different constant spacing, and x can be used to specify arbitrary spacing along dim.
+ ///
+ /// Values to use when computing the trapezoidal rule.
+ /// Constant spacing between values.
+ /// The dimension along which to compute the trapezoidal rule. The last (inner-most) dimension by default.
+ public static Tensor cumulative_trapezoid(Tensor y, double dx = 1, long dim = -1) => y.cumulative_trapezoid(dx, dim);
// https://pytorch.org/docs/stable/generated/torch.triangular_solve
- [Obsolete("not implemented", true)]
+ [Obsolete("torch.triangular_solve() is deprecated in favor of torch.linalg.solve_triangular() and will be removed in a future PyTorch release.", true)]
static Tensor triangular_solve(
Tensor b,
Tensor A,
diff --git a/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs b/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs
index 499a77934..5bde58340 100644
--- a/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs
+++ b/src/TorchSharp/Tensor/torch.IndexingSlicingJoiningMutatingOps.cs
@@ -1,4 +1,4 @@
-// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
#nullable enable
using System;
using System.Collections.Generic;
@@ -12,12 +12,21 @@ namespace TorchSharp
public static partial class torch
{
// https://pytorch.org/docs/stable/generated/torch.adjoint
- [Obsolete("not implemented", true)]
- public static Tensor adjoint(Tensor input) => throw new NotImplementedException();
+ ///
+ /// Returns a view of the tensor conjugated and with the last two dimensions transposed.
+ ///
+ /// The input tensor
+ public static Tensor adjoint(Tensor input) => input.adjoint();
// https://pytorch.org/docs/stable/generated/torch.argwhere
- [Obsolete("not implemented", true)]
- public static Tensor argwhere(Tensor input) => throw new NotImplementedException();
+ ///
+ /// Returns a tensor containing the indices of all non-zero elements of input.
+ /// Each row in the result contains the indices of a non-zero element in input.
+ /// The result is sorted lexicographically, with the last index changing the fastest (C-style).
+ /// If input has n dimensions, then the resulting indices tensor out is of size (z×n), where
+ /// z is the total number of non-zero elements in the input tensor.
+ ///
+ public static Tensor argwhere(Tensor input) => input.argwhere();
// https://pytorch.org/docs/stable/generated/torch.cat
///
@@ -25,7 +34,6 @@ public static partial class torch
///
/// A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
/// The dimension over which the tensors are concatenated
- ///
/// All tensors must either have the same shape (except in the concatenating dimension) or be empty.
public static Tensor cat(IList tensors, long dim = 0)
{
@@ -44,9 +52,13 @@ public static Tensor cat(IList tensors, long dim = 0)
}
// https://pytorch.org/docs/stable/generated/torch.concat
- [Obsolete("not implemented", true)]
- public static Tensor concat(IList tensors, long dim = 0)
- => throw new NotImplementedException();
+ ///
+ /// Alias of torch.cat()
+ ///
+ /// A sequence of tensors of the same type. Non-empty tensors provided must have the same shape, except in the cat dimension.
+ /// The dimension over which the tensors are concatenated
+ /// All tensors must either have the same shape (except in the concatenating dimension) or be empty.
+ public static Tensor concat(IList tensors, long dim = 0) => torch.cat(tensors, dim);
// https://pytorch.org/docs/stable/generated/torch.conj
///
@@ -83,11 +95,6 @@ public static Tensor[] dsplit(Tensor input, (long, long, long) indices_or_sectio
public static Tensor[] dsplit(Tensor input, (long, long, long, long) indices_or_sections)
=> input.dsplit(indices_or_sections);
- // https://pytorch.org/docs/stable/generated/torch.column_stack
- [Obsolete("not implemented", true)]
- public static Tensor column_stack(params Tensor[] tensors)
- => throw new NotImplementedException();
-
// https://pytorch.org/docs/stable/generated/torch.dstack
///
/// Stack tensors in sequence depthwise (along third axis).
@@ -377,10 +384,6 @@ public static Tensor narrow(Tensor input, long dim, long start, long length)
/// The new tensor shape.
public static Tensor reshape(Tensor input, params long[] shape) => input.reshape(shape);
- // https://pytorch.org/docs/stable/generated/torch.row_stack
- public static Tensor row_stack(params Tensor[] tensors)
- => throw new NotImplementedException();
-
// https://pytorch.org/docs/stable/generated/torch.select
public static Tensor select(Tensor input, long dim, long index)
=> input.select(dim, index);
@@ -404,19 +407,39 @@ public static Tensor scatter_(Tensor input, long dim, Tensor index, Tensor src)
=> input.scatter_(dim, index, src);
// https://pytorch.org/docs/stable/generated/torch.diagonal_scatter
- [Obsolete("not implemented", true)]
- public static Tensor diagonal_scatter(Tensor input, Tensor src, long offset = 0L, long dim1 = 0L, long dim2 = 1L)
- => throw new NotImplementedException();
+ ///
+ /// Embeds the values of the src tensor into input along the diagonal elements of input, with respect to dim1 and dim2.
+ ///
+ /// The input tensor.
+ /// The tensor to embed into 'this'.
+ /// Which diagonal to consider. Default: main diagonal.
+ /// First dimension with respect to which to take diagonal.
+ /// Second dimension with respect to which to take diagonal.
+ public static Tensor diagonal_scatter(Tensor input, Tensor src, long offset = 0L, long dim1 = 0L, long dim2 = 1L) => input.diagonal_scatter(src, offset, dim1, dim2);
// https://pytorch.org/docs/stable/generated/torch.select_scatter
- [Obsolete("not implemented", true)]
- public static Tensor select_scatter(Tensor input, Tensor src, long dim, long index)
- => throw new NotImplementedException();
+ ///
+ /// Embeds the values of the src tensor into input at the given index. This function returns a tensor with fresh storage; it does not create a view.
+ ///
+ /// The input tensor.
+ /// The tensor to embed into 'this'
+ /// The dimension to insert the slice into
+ /// The index to select with
+ /// This function returns a tensor with fresh storage; it does not create a view.
+ public static Tensor select_scatter(Tensor input, Tensor src, long dim, long index) => input.select_scatter(src, dim, index);
// https://pytorch.org/docs/stable/generated/torch.slice_scatter
- [Obsolete("not implemented", true)]
- public static Tensor slice_scatter(Tensor input, Tensor src, long dim=0L, long? start=null, long? end=null, long step=1L)
- => throw new NotImplementedException();
+ ///
+ /// Embeds the values of the src tensor into input at the given dimension.
+ ///
+ /// The input tensor.
+ /// The tensor to embed into 'this'.
+ /// The dimension to insert the slice into
+ /// The start index of where to insert the slice
+ /// The end index of where to insert the slice
+ /// How many elements to skip
+ public static Tensor slice_scatter(Tensor input, Tensor src, long dim = 0L, long? start = null, long? end = null, long step = 1L)
+ => input.slice_scatter(src, dim, start, end, step);
// https://pytorch.org/docs/stable/generated/torch.scatter_add
///
diff --git a/src/TorchSharp/Tensor/torch.OtherOperations.cs b/src/TorchSharp/Tensor/torch.OtherOperations.cs
index 4eeb18b02..0ad763284 100644
--- a/src/TorchSharp/Tensor/torch.OtherOperations.cs
+++ b/src/TorchSharp/Tensor/torch.OtherOperations.cs
@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using TorchSharp.PInvoke;
using static TorchSharp.PInvoke.LibTorchSharp;
namespace TorchSharp
@@ -119,34 +120,106 @@ public static Tensor bucketize(Tensor input, Tensor boundaries, bool outInt32 =
=> input.bucketize(boundaries, outInt32, right);
// https://pytorch.org/docs/stable/generated/torch.cartesian_prod
- [Obsolete("not implemented", true)]
- public static Tensor cartesian_prod(params Tensor[] tensors)
- => throw new NotImplementedException();
+ ///
+ /// Do cartesian product of the given sequence of tensors.
+ ///
+ ///
+ public static Tensor cartesian_prod(IList tensors)
+ {
+ using var parray = new PinnedArray();
+ IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
+
+ var res = THSTensor_cartesian_prod(tensorsRef, parray.Array.Length);
+ if (res == IntPtr.Zero) { torch.CheckForErrors(); }
+ return new Tensor(res);
+ }
+
+ // https://pytorch.org/docs/stable/generated/torch.cartesian_prod
+ ///
+ /// Do cartesian product of the given sequence of tensors.
+ ///
+ ///
+ public static Tensor cartesian_prod(params Tensor[] tensors) => cartesian_prod((IList)tensors);
// https://pytorch.org/docs/stable/generated/torch.cdist
- [Obsolete("not implemented", true)]
- static Tensor cdist(
+ ///
+ /// Computes batched the p-norm distance between each pair of the two collections of row vectors.
+ ///
+ /// Input tensor of shape BxPxM
+ /// Input tensor of shape BxRxM
+ /// p value for the p-norm distance to calculate between each vector (p > 0)
+ ///
+ /// use_mm_for_euclid_dist_if_necessary - will use matrix multiplication approach to calculate euclidean distance (p = 2) if P > 25 or R > 25
+ /// use_mm_for_euclid_dist - will always use matrix multiplication approach to calculate euclidean distance (p = 2)
+ /// donot_use_mm_for_euclid_dist - will never use matrix multiplication approach to calculate euclidean distance (p = 2)
+ ///
+ ///
+ public static Tensor cdist(
Tensor x1,
Tensor x2,
double p = 2.0,
compute_mode compute_mode = compute_mode.use_mm_for_euclid_dist_if_necessary)
- => throw new NotImplementedException();
+ {
+ if (p < 0)
+ throw new ArgumentException($"p must be non-negative");
+
+ var res = THSTensor_cdist(x1.Handle, x2.Handle, p, (long)compute_mode);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
// https://pytorch.org/docs/stable/generated/torch.clone
public static Tensor clone(Tensor input) => input.clone();
// https://pytorch.org/docs/stable/generated/torch.combinations
- [Obsolete("not implemented", true)]
- public static IEnumerable combinations(Tensor input, long r = 2L, bool with_replacement = false)
- => throw new NotImplementedException();
+ ///
+ /// Compute combinations of length r of the given tensor
+ ///
+ /// 1D vector.
+ /// Number of elements to combine
+ /// Whether to allow duplication in combination
+ ///
+ public static Tensor combinations(Tensor input, int r = 2, bool with_replacement = false)
+ {
+ if (input.ndim != 1)
+ throw new ArgumentException($"Expected a 1D vector, but got one with {input.ndim} dimensions.");
+ if (r < 0)
+ throw new ArgumentException($"r must be non-negative");
+
+ var res = THSTensor_combinations(input.Handle, r, with_replacement);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
+
+
// https://pytorch.org/docs/stable/generated/torch.corrcoef
public static Tensor corrcoef(Tensor input) => input.corrcoef();
// https://pytorch.org/docs/stable/generated/torch.cov
- [Obsolete("not implemented", true)]
+ ///
+ /// Estimates the covariance matrix of the variables given by the input matrix, where rows are the variables and columns are the observations.
+ ///
+ /// The input tensor
+ ///
+ /// Difference between the sample size and sample degrees of freedom.
+ /// Defaults to Bessel’s correction, correction = 1 which returns the unbiased estimate,
+ /// even if both fweights and aweights are specified.
+ /// Correction = 0 will return the simple average.
+ ///
+ ///
+ /// A Scalar or 1D tensor of observation vector frequencies representing the number of times each observation should be repeated.
+ /// Its numel must equal the number of columns of input.
+ /// Must have integral dtype.
+ /// A Scalar or 1D array of observation vector weights.
+ /// These relative weights are typically large for observations considered “important” and smaller for
+ /// observations considered less “important”.
+ /// Its numel must equal the number of columns of input.
+ /// Must have floating point dtype.
public static Tensor cov(Tensor input, long correction = 1, Tensor? fweights = null, Tensor? aweights = null)
- => throw new NotImplementedException();
+ => input.cov(correction, fweights, aweights);
// https://pytorch.org/docs/stable/generated/torch.cross
///
@@ -189,9 +262,25 @@ public static Tensor cov(Tensor input, long correction = 1, Tensor? fweights = n
public static Tensor diag(Tensor input, long diagonal = 0) => input.diag(diagonal);
// https://pytorch.org/docs/stable/generated/torch.diag_embed
- [Obsolete("not implemented", true)]
+ ///
+ /// Creates a tensor whose diagonals of certain 2D planes (specified by dim1 and dim2) are filled by input.
+ /// To facilitate creating batched diagonal matrices, the 2D planes formed by the last two dimensions of the returned tensor are chosen by default.
+ ///
+ /// The argument offset controls which diagonal to consider:
+ /// If offset is equal to 0, it is the main diagonal.
+ /// If offset is greater than 0, it is above the main diagonal.
+ /// If offset is less than 0, it is below the main diagonal.
+ ///
+ /// The size of the new matrix will be calculated to make the specified diagonal of the size of the last input dimension.Note that for offset other than 0,
+ ///
+ /// the order of dim1 and dim2 matters.Exchanging them is equivalent to changing the sign of offset.
+ ///
+ /// The input tensor.
+ /// Which diagonal to consider.
+ /// First dimension with respect to which to take diagonal.
+ /// Second dimension with respect to which to take diagonal
public static Tensor diag_embed(Tensor input, long offset = 0L, long dim1 = -2L, long dim2 = -1L)
- => throw new NotImplementedException();
+ => input.diag_embed(offset, dim1, dim2);
// https://pytorch.org/docs/stable/generated/torch.diagflat
///
@@ -295,8 +384,15 @@ public static Tensor einsum(string equation, params Tensor[] tensors)
public static Tensor kron(Tensor input, Tensor other) => input.kron(other);
// https://pytorch.org/docs/stable/generated/torch.rot90
- [Obsolete("not implemented", true)]
- public static Tensor rot90(Tensor input, long k, params long[] dims) => throw new NotImplementedException();
+ ///
+ /// Rotate a n-D tensor by 90 degrees in the plane specified by dims axis.
+ /// Rotation direction is from the first towards the second axis if k is greater than 0,
+ /// and from the second towards the first for k less than 0.
+ ///
+ /// The input tensor
+ /// The number of times to rotate.
+ /// Axes to rotate
+ public static Tensor rot90(Tensor input, long k = 1, (long, long)? dims = null) => input.rot90(k, dims);
// https://pytorch.org/docs/stable/generated/torch.gcd
///
@@ -395,8 +491,7 @@ static Tensor histogram(
/// All tensors need to be of the same size.
static IEnumerable meshgrid(IEnumerable tensors, indexing indexing = indexing.ij)
{
- var idx = indexing switch
- {
+ var idx = indexing switch {
indexing.ij => "ij",
indexing.xy => "xy",
_ => throw new ArgumentOutOfRangeException()
@@ -534,6 +629,7 @@ public static Tensor[] meshgrid(IEnumerable tensors, string indexing = "
public static Tensor roll(Tensor input, ReadOnlySpan shifts, ReadOnlySpan dims = default) => input.roll(shifts, dims);
// https://pytorch.org/docs/stable/generated/torch.searchsorted
+ [Obsolete("not implemented", true)]
static Tensor searchsorted(
Tensor sorted_sequence,
Tensor values,
@@ -545,7 +641,7 @@ static Tensor searchsorted(
// https://pytorch.org/docs/stable/generated/torch.tensordot
[Obsolete("not implemented", true)]
- public static Tensor tensordot(Tensor a, Tensor b, long dims=2) => throw new NotImplementedException();
+ public static Tensor tensordot(Tensor a, Tensor b, long dims = 2) => throw new NotImplementedException();
// https://pytorch.org/docs/stable/generated/torch.trace
///
@@ -559,28 +655,49 @@ static Tensor searchsorted(
public static Tensor tril(Tensor input, long diagonal = 0) => input.tril(diagonal);
// https://pytorch.org/docs/stable/generated/torch.tril_indices
- [Obsolete("not implemented", true)]
- static Tensor tril_indices(
+ public static Tensor tril_indices(
long row,
long col,
long offset = 0L,
ScalarType dtype = ScalarType.Int64,
- Device? device = null,
- layout layout = layout.strided)
- => throw new NotImplementedException();
+ Device? device = null)
+ {
+ if (!torch.is_integral(dtype))
+ throw new ArgumentException("dtype must be integral.");
+
+ if (device == null) {
+ device = torch.CPU;
+ }
+
+ var res = LibTorchSharp.THSTensor_tril_indices(row, col, offset, (sbyte)dtype, (int)device.type, device.index);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
// https://pytorch.org/docs/stable/generated/torch.triu
public static Tensor triu(Tensor input, long diagonal = 0L) => input.triu(diagonal);
// https://pytorch.org/docs/stable/generated/torch.triu_indices
- static Tensor triu_indices(
+ public static Tensor triu_indices(
long row,
long col,
long offset = 0L,
- ScalarType dtype = ScalarType.Float64,
- Device? device = null,
- layout layout = layout.strided)
- => throw new NotImplementedException();
+ ScalarType dtype = ScalarType.Int64,
+ Device? device = null)
+ {
+ if (!torch.is_integral(dtype))
+ throw new ArgumentException("dtype must be integral.");
+
+ if (device == null) {
+ device = torch.CPU;
+ }
+
+ var res = LibTorchSharp.THSTensor_triu_indices(row, col, offset, (sbyte)dtype, (int)device.type, device.index);
+ if (res == IntPtr.Zero)
+ CheckForErrors();
+ return new Tensor(res);
+ }
// https://pytorch.org/docs/stable/generated/torch.vander
public static Tensor vander(Tensor x, long N = -1, bool increasing = false) => x.vander(N, increasing);
@@ -610,7 +727,15 @@ static Tensor triu_indices(
public static Tensor resolve_conj(Tensor input) => input.resolve_conj();
// https://pytorch.org/docs/stable/generated/torch.resolve_neg
- [Obsolete("not implemented", true)]
- public static Tensor resolve_neg(Tensor input) => throw new NotImplementedException();
+ ///
+ /// Returns a new tensor with materialized negation if input’s negative bit is set to True, else returns input.
+ /// The output tensor will always have its negative bit set to False.
+ ///
+ public static Tensor resolve_neg(Tensor input) => input.resolve_neg();
+
+ ///
+ /// Returns true if the input's negative bit is set to True.
+ ///
+ public static Tensor is_neg(Tensor input) => input.is_neg();
}
}
\ No newline at end of file
diff --git a/src/TorchSharp/Tensor/torch.Parallelism.cs b/src/TorchSharp/Tensor/torch.Parallelism.cs
index b9bb9dac8..916ba73ae 100644
--- a/src/TorchSharp/Tensor/torch.Parallelism.cs
+++ b/src/TorchSharp/Tensor/torch.Parallelism.cs
@@ -1,27 +1,57 @@
-// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
#nullable enable
using System;
using System.Diagnostics.Contracts;
+using static TorchSharp.PInvoke.LibTorchSharp;
+
namespace TorchSharp
{
// https://pytorch.org/docs/stable/torch#parallelism
public static partial class torch
{
// https://pytorch.org/docs/stable/generated/torch.get_num_threads
- [Pure, Obsolete("not implemented", true)]
- public static int get_num_threads() => throw new NotImplementedException();
+ ///
+ /// Returns the number of threads used for parallelizing CPU operations
+ ///
+ public static int get_num_threads()
+ {
+ var res = THSTorch_get_num_threads();
+ if (res == -1) CheckForErrors();
+ return res;
+ }
// https://pytorch.org/docs/stable/generated/torch.set_num_threads
- [Obsolete("not implemented", true)]
- public static void set_num_threads(int num) => throw new NotImplementedException();
+ ///
+ /// Sets the number of threads used for parallelizing CPU operations
+ ///
+ /// The number of threads to use.
+ public static void set_num_threads(int num)
+ {
+ THSTorch_set_num_threads(num);
+ CheckForErrors();
+ }
// https://pytorch.org/docs/stable/generated/torch.get_num_interop_threads
- [Pure, Obsolete("not implemented", true)]
- public static int get_num_interop_threads() => throw new NotImplementedException();
+ ///
+ /// Returns the number of threads used for inter-op parallelism on CPU (e.g. in JIT interpreter)
+ ///
+ public static int get_num_interop_threads()
+ {
+ var res = THSTorch_get_num_interop_threads();
+ if (res == -1) CheckForErrors();
+ return res;
+ }
// https://pytorch.org/docs/stable/generated/torch.set_num_interop_threads
- [Obsolete("not implemented", true)]
- public static void set_num_interop_threads(int num) => throw new NotImplementedException();
+ ///
+ /// Sets the number of threads used for inter-op parallelism on CPU (e.g. in JIT interpreter)
+ ///
+ /// The number of threads to use.
+ public static void set_num_interop_threads(int num)
+ {
+ THSTorch_set_num_interop_threads(num);
+ CheckForErrors();
+ }
}
}
\ No newline at end of file
diff --git a/src/TorchSharp/Tensor/torch.PointwiseOps.cs b/src/TorchSharp/Tensor/torch.PointwiseOps.cs
index fc734a88d..0fccbd8ce 100644
--- a/src/TorchSharp/Tensor/torch.PointwiseOps.cs
+++ b/src/TorchSharp/Tensor/torch.PointwiseOps.cs
@@ -1,8 +1,9 @@
-// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
#nullable enable
using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
+using ICSharpCode.SharpZipLib.BZip2;
namespace TorchSharp
{
@@ -795,7 +796,7 @@ public static Tensor fake_quantize_per_tensor_affine(Tensor input, Tensor scale,
/// Replaces each element with the floor of the input, the largest integer less than or equal to each element.
///
/// The input tensor.
- public static Tensor floor_(Tensor input) => input.exp_();
+ public static Tensor floor_(Tensor input) => input.floor_();
// https://pytorch.org/docs/stable/generated/torch.floor_divide
///
@@ -805,10 +806,18 @@ public static Tensor fake_quantize_per_tensor_affine(Tensor input, Tensor scale,
/// the dividend
/// the divisor
/// the output tensor
- ///
- [Pure, Obsolete("not implemented", true)]
- public static Tensor floor_divide(Tensor input, Tensor other)
- => throw new NotImplementedException();
+ [Pure]
+ public static Tensor floor_divide(Tensor input, Tensor other) => input.floor_divide(other);
+
+ // https://pytorch.org/docs/stable/generated/torch.floor_divide
+ ///
+ /// Computes input divided by other, elementwise, and floors the result.
+ /// Supports broadcasting to a common shape, type promotion, and integer and float inputs.
+ ///
+ /// the dividend
+ /// the divisor
+ /// the output tensor
+ public static Tensor floor_divide_(Tensor input, Tensor other) => input.floor_divide_(other);
// https://pytorch.org/docs/stable/generated/torch.fmod
///
@@ -1441,6 +1450,13 @@ public static Tensor quantized_max_pool2d(Tensor input, long[] kernel_size, long
/// The input tensor.
[Pure]public static Tensor sign(Tensor input) => input.sign();
+ // https://pytorch.org/docs/stable/generated/torch.sign
+ ///
+ /// Returns a new tensor with the signs (-1, 0, 1) of the elements of input.
+ ///
+ /// The input tensor.
+ [Pure] public static Tensor sign_(Tensor input) => input.sign_();
+
// https://pytorch.org/docs/stable/generated/torch.sgn
///
/// This function is an extension of torch.sign() to complex tensors.
@@ -1450,8 +1466,20 @@ public static Tensor quantized_max_pool2d(Tensor input, long[] kernel_size, long
///
/// the input tensor.
/// the output tensor.
- [Pure, Obsolete("not implemented", true)]
- public static Tensor sgn(Tensor input) => throw new NotImplementedException();
+ [Pure]
+ public static Tensor sgn(Tensor input) => input.sgn();
+
+ // https://pytorch.org/docs/stable/generated/torch.sgn
+ ///
+ /// This function is an extension of torch.sign() to complex tensors.
+ /// It computes a new tensor whose elements have the same angles as the corresponding
+ /// elements of input and absolute values (i.e. magnitudes) of one for complex tensors
+ /// and is equivalent to torch.sign() for non-complex tensors.
+ ///
+ /// the input tensor.
+ /// the output tensor.
+ [Pure]
+ public static Tensor sgn_(Tensor input) => input.sgn_();
// https://pytorch.org/docs/stable/generated/torch.signbit
///
@@ -1596,7 +1624,7 @@ public static Tensor quantized_max_pool2d(Tensor input, long[] kernel_size, long
// https://pytorch.org/docs/stable/generated/torch.tan
///
- /// Computes the tangent of the elements of input.
+ /// Computes the tangent of the elements of input. In-place version.
///
///
public static Tensor tan_(Tensor input) => input.tan_();
@@ -1617,10 +1645,17 @@ public static Tensor quantized_max_pool2d(Tensor input, long[] kernel_size, long
public static Tensor tanh_(Tensor input) => input.tanh_();
// https://pytorch.org/docs/stable/generated/torch.true_divide
- // TODO: implement true_divide
- [Pure, Obsolete("not implemented", true)]
- public static Tensor true_divide(Tensor dividend, Tensor divisor)
- => throw new NotImplementedException();
+ ///
+ /// Alias for torch.div() with rounding_mode=None.
+ ///
+ [Pure]
+ public static Tensor true_divide(Tensor dividend, Tensor divisor) => dividend.true_divide(divisor);
+
+ // https://pytorch.org/docs/stable/generated/torch.true_divide
+ ///
+ /// Alias for torch.div_() with rounding_mode=None.
+ ///
+ public static Tensor true_divide_(Tensor dividend, Tensor divisor) => dividend.true_divide_(divisor);
// https://pytorch.org/docs/stable/generated/torch.trunc
///
diff --git a/src/TorchSharp/Tensor/torch.RandomSampling.cs b/src/TorchSharp/Tensor/torch.RandomSampling.cs
index b0e481496..e74fcdd6f 100644
--- a/src/TorchSharp/Tensor/torch.RandomSampling.cs
+++ b/src/TorchSharp/Tensor/torch.RandomSampling.cs
@@ -1,4 +1,4 @@
-// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
#nullable enable
using System;
using System.Diagnostics.Contracts;
@@ -10,24 +10,36 @@ namespace TorchSharp
public static partial class torch
{
// https://pytorch.org/docs/stable/generated/torch.seed
- [Obsolete("not implemented", true)]
- public static int seed() => throw new NotImplementedException();
+ ///
+ /// Sets the seed for generating random numbers to a non-deterministic random number. Returns a 64 bit number used to seed the RNG.
+ ///
+ public static long seed() => torch.random.seed();
// https://pytorch.org/docs/stable/generated/torch.manual_seed
- [Obsolete("not implemented", true)]
- public static Generator manual_seed(long seed) => throw new NotImplementedException();
+ ///
+ /// Sets the seed for generating random numbers. Returns a torch.Generator object.
+ ///
+ /// The desired seed.
+ public static Generator manual_seed(long seed) => torch.random.manual_seed(seed);
// https://pytorch.org/docs/stable/generated/torch.initial_seed
- [Obsolete("not implemented", true)]
- public static long initial_seed() => throw new NotImplementedException();
+ ///
+ /// Returns the initial seed for generating random numbers.
+ ///
+ public static long initial_seed() => torch.random.initial_seed();
// https://pytorch.org/docs/stable/generated/torch.get_rng_state
- [Obsolete("not implemented", true)]
- public static Tensor get_rng_state() => throw new NotImplementedException();
+ ///
+ /// Returns the random number generator state as a torch.ByteTensor.
+ ///
+ public static Tensor get_rng_state() => torch.random.get_rng_state();
// https://pytorch.org/docs/stable/generated/torch.set_rng_state
- [Obsolete("not implemented", true)]
- public static void set_rng_state(Tensor new_state) => throw new NotImplementedException();
+ ///
+ /// Sets the random number generator state.
+ ///
+ /// The desired state
+ public static void set_rng_state(Tensor new_state) => torch.random.set_rng_state(new_state);
// https://pytorch.org/docs/stable/generated/torch.bernoulli
///
diff --git a/src/TorchSharp/Tensor/torch.ReductionOps.cs b/src/TorchSharp/Tensor/torch.ReductionOps.cs
index 82fc36073..c413776ff 100644
--- a/src/TorchSharp/Tensor/torch.ReductionOps.cs
+++ b/src/TorchSharp/Tensor/torch.ReductionOps.cs
@@ -1,7 +1,11 @@
-// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
#nullable enable
using System;
+using System.Collections.Generic;
using System.Diagnostics.Contracts;
+using System.Linq;
+
+using static TorchSharp.PInvoke.LibTorchSharp;
namespace TorchSharp
{
diff --git a/src/TorchSharp/Tensor/torch.Tensors.cs b/src/TorchSharp/Tensor/torch.Tensors.cs
index d4d5888b2..55cb1bf19 100644
--- a/src/TorchSharp/Tensor/torch.Tensors.cs
+++ b/src/TorchSharp/Tensor/torch.Tensors.cs
@@ -1,4 +1,4 @@
-// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
#nullable enable
using System;
using System.Diagnostics.Contracts;
@@ -15,20 +15,27 @@ public static partial class torch
[Pure]public static bool is_storage(object obj) => obj is Storage;
// https://pytorch.org/docs/stable/generated/torch.is_complex
- [Pure, Obsolete("not implemented", true)]
- public static bool is_complex(object input) => throw new NotImplementedException();
-
- // https://pytorch.org/docs/stable/generated/torch.is_conj
- [Pure, Obsolete("not implemented", true)]
- public static bool is_conj(object input) => throw new NotImplementedException();
+ ///
+ /// Returns True if the data type of input is a complex data type i.e., one of torch.complex64, and torch.complex128.
+ ///
+ /// The input tensor
+ public static bool is_complex(Tensor input) => is_complex(input.dtype);
// https://pytorch.org/docs/stable/generated/torch.is_floating_point
- [Pure, Obsolete("not implemented", true)]
- public static bool is_floating_point(object input) => throw new NotImplementedException();
+ ///
+ /// Returns True if the data type of input is a floating point data type.
+ ///
+ /// The input tensor
+ public static bool is_floating_point(Tensor input) => is_floating_point(input.dtype);
// https://pytorch.org/docs/stable/generated/torch.is_nonzero
- [Pure, Obsolete("not implemented", true)]
- public static bool is_nonzero(object input) => throw new NotImplementedException();
+ ///
+ /// Returns True if the input is a single element tensor which is not equal to zero after type conversions,
+ /// i.e. not equal to torch.tensor([0.]) or torch.tensor([0]) or torch.tensor([False]).
+ /// Throws an InvalidOperationException if torch.numel() != 1.
+ ///
+ /// The input tensor
+ public static bool is_nonzero(Tensor input) => input.is_nonzero();
// https://pytorch.org/docs/stable/generated/torch.set_default_dtype
///
@@ -56,15 +63,5 @@ public static partial class torch
/// Get the number of elements in the input tensor.
///
[Pure]public static long numel(Tensor input) => input.numel();
-
- // https://pytorch.org/docs/stable/generated/torch.set_printoptions
- [Obsolete("not implemented", true)]
- public static void set_printoptions(
- int precision = 4,
- int threshold = 1000,
- int edgeitems = 3,
- int linewidth = 80,
- PrintOptionsProfile profile = PrintOptionsProfile.@default,
- bool? sci_mode = null) => throw new NotImplementedException();
}
}
\ No newline at end of file
diff --git a/src/TorchSharp/Tensor/torch.Utilities.cs b/src/TorchSharp/Tensor/torch.Utilities.cs
index 601017cd6..32d8053c0 100644
--- a/src/TorchSharp/Tensor/torch.Utilities.cs
+++ b/src/TorchSharp/Tensor/torch.Utilities.cs
@@ -1,7 +1,8 @@
-// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
#nullable enable
using System;
using System.Diagnostics.Contracts;
+using static TorchSharp.PInvoke.LibTorchSharp;
namespace TorchSharp
{
@@ -13,16 +14,28 @@ public static partial class torch
public static bool compiled_with_cxx11_abi() => throw new NotImplementedException();
// https://pytorch.org/docs/stable/generated/torch.result_type
- [Pure, Obsolete("not implemented", true)]
- public static ScalarType result_type(Tensor tensor1, Tensor tensor2) => throw new NotImplementedException();
+ public static ScalarType result_type(Tensor tensor1, Tensor tensor2)
+ {
+ var res = THSTensor_result_type(tensor1.Handle, tensor2.Handle);
+ if (res == -1) CheckForErrors();
+ return (ScalarType)res;
+ }
// https://pytorch.org/docs/stable/generated/torch.can_cast
- [Pure, Obsolete("not implemented", true)]
- public static bool can_cast(ScalarType from, ScalarType to) => throw new NotImplementedException();
+ public static bool can_cast(ScalarType from, ScalarType to)
+ {
+ var res = THSTorch_can_cast((int)from, (int)to);
+ if (res == -1) CheckForErrors();
+ return res != 0;
+ }
// https://pytorch.org/docs/stable/generated/torch.promote_types
- [Obsolete("not implemented", true)]
- public static bool promote_types(ScalarType type1, ScalarType type2) => throw new NotImplementedException();
+ public static ScalarType promote_types(ScalarType type1, ScalarType type2)
+ {
+ var res = THSTorch_promote_types((int)type1, (int)type2);
+ if (res == -1) CheckForErrors();
+ return (ScalarType)res;
+ }
// https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms
[Obsolete("not implemented", true)]
diff --git a/src/TorchSharp/Tensor/torch.cs b/src/TorchSharp/Tensor/torch.cs
index f8e027aa9..1a70dba88 100644
--- a/src/TorchSharp/Tensor/torch.cs
+++ b/src/TorchSharp/Tensor/torch.cs
@@ -65,6 +65,14 @@ public static Tensor column_stack(IList tensors)
return new Tensor(res);
}
+ ///
+ /// Creates a new tensor by horizontally stacking the input tensors.
+ ///
+ /// A list of input tensors.
+ ///
+ /// Equivalent to torch.hstack(tensors), except each zero or one dimensional tensor t in tensors is first reshaped into a (t.numel(), 1) column before being stacked horizontally.
+ public static Tensor column_stack(params Tensor[] tensors) => column_stack((IList)tensors);
+
///
/// Stack tensors in sequence vertically (row wise).
///
@@ -80,6 +88,13 @@ public static Tensor row_stack(IList tensors)
return new Tensor(res);
}
+ ///
+ /// Stack tensors in sequence vertically (row wise).
+ ///
+ ///
+ ///
+ public static Tensor row_stack(params Tensor[] tensors) => row_stack((IList)tensors);
+
///
/// Removes a tensor dimension.
///
@@ -165,12 +180,6 @@ public static Tensor _sample_dirichlet(Tensor input, Generator? generator = null
/// The input tensor.
public static bool is_conj(Tensor input) => input.is_conj();
- ///
- /// Replaces each element with the signs (-1, 0, 1) of the elements of input.
- ///
- /// The input tensor.
- public static Tensor sign_(Tensor input) => input.sign_();
-
///
/// Calculates the standard deviation and mean of all elements in the tensor.
///
diff --git a/src/TorchSharp/Torch.cs b/src/TorchSharp/Torch.cs
index 167feca3d..915f5683b 100644
--- a/src/TorchSharp/Torch.cs
+++ b/src/TorchSharp/Torch.cs
@@ -288,6 +288,16 @@ public static Device InitializeDevice(Device? device)
public static partial class random
{
+ ///
+ /// Sets the seed for generating random numbers to a non-deterministic random number. Returns a 64 bit number used to seed the RNG.
+ ///
+ public static long seed() => Generator.Default.seed();
+
+ ///
+ /// Returns the initial seed for generating random numbers.
+ ///
+ public static long initial_seed() => Generator.Default.initial_seed();
+
///
/// Sets the seed for generating random numbers. Returns a torch.Generator object.
///
@@ -301,6 +311,23 @@ public static Generator manual_seed(long seed)
CheckForErrors();
return new Generator(res);
}
+
+ ///
+ /// Returns the random number generator state as a torch.ByteTensor.
+ ///
+ ///
+ public static Tensor get_rng_state()
+ {
+ return Generator.Default.get_state();
+ }
+ ///
+ /// Sets the random number generator state.
+ ///
+ /// The desired state
+ public static void set_rng_state(Tensor new_state)
+ {
+ Generator.Default.set_state(new_state);
+ }
}
public static partial class nn
diff --git a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj
index d7a7458f6..8bbc5d293 100644
--- a/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj
+++ b/test/TorchSharpTest.WithCudaBinaries/TorchSharpTest.WithCudaBinaries.csproj
@@ -20,7 +20,9 @@
Always
+
+
diff --git a/test/TorchSharpTest/LinearAlgebra.cs b/test/TorchSharpTest/LinearAlgebra.cs
new file mode 100644
index 000000000..ce15e8c48
--- /dev/null
+++ b/test/TorchSharpTest/LinearAlgebra.cs
@@ -0,0 +1,761 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+using System;
+using System.IO;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Collections.Generic;
+using System.Globalization;
+using Xunit;
+using Xunit.Sdk;
+using static TorchSharp.torch;
+
+#nullable enable
+
+namespace TorchSharp
+{
+#if NET472_OR_GREATER
+ [Collection("Sequential")]
+#endif // NET472_OR_GREATER
+ public class LinearAlgebra
+ {
+
+
+ [Fact]
+ [TestOf(nameof(torch.lu))]
+ public void TestLUSolve()
+ {
+ var A = torch.randn(2, 3, 3);
+ var b = torch.randn(2, 3, 1);
+
+ {
+ var (A_LU, pivots, infos) = torch.lu(A);
+
+ Assert.NotNull(A_LU);
+ Assert.NotNull(pivots);
+ Assert.Null(infos);
+
+ Assert.Equal(new long[] { 2, 3, 3 }, A_LU.shape);
+ Assert.Equal(new long[] { 2, 3 }, pivots.shape);
+
+ var x = torch.lu_solve(b, A_LU, pivots);
+ Assert.Equal(new long[] { 2, 3, 1 }, x.shape);
+
+ var y = torch.norm(torch.bmm(A, x) - b);
+ Assert.Empty(y.shape);
+ }
+
+ {
+ var (A_LU, pivots, infos) = torch.lu(A, get_infos: true);
+
+ Assert.NotNull(A_LU);
+ Assert.NotNull(pivots);
+ Assert.NotNull(infos);
+
+ Assert.Equal(new long[] { 2, 3, 3 }, A_LU.shape);
+ Assert.Equal(new long[] { 2, 3 }, pivots.shape);
+ Assert.Equal(new long[] { 2 }, infos.shape);
+
+ var x = torch.lu_solve(b, A_LU, pivots);
+ Assert.Equal(new long[] { 2, 3, 1 }, x.shape);
+
+ var y = torch.norm(torch.bmm(A, x) - b);
+ Assert.Empty(y.shape);
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.lu_unpack))]
+ public void TestLUUnpack()
+ {
+ var A = torch.randn(2, 3, 3);
+
+ {
+ var (A_LU, pivots, infos) = torch.lu(A);
+
+ Assert.NotNull(A_LU);
+ Assert.NotNull(pivots);
+ Assert.Null(infos);
+
+ var (P, A_L, A_U) = torch.lu_unpack(A_LU, pivots);
+
+ Assert.NotNull(P);
+ Assert.NotNull(A_L);
+ Assert.NotNull(A_U);
+
+ Assert.Equal(new long[] { 2, 3, 3 }, P.shape);
+ Assert.Equal(new long[] { 2, 3, 3 }, A_L!.shape);
+ Assert.Equal(new long[] { 2, 3, 3 }, A_U!.shape);
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor.mul))]
+ public void TestMul()
+ {
+ var x = torch.ones(new long[] { 100, 100 });
+
+ var y = x.mul(0.5f.ToScalar());
+
+ var ydata = y.data();
+ var xdata = x.data();
+
+ for (int i = 0; i < 100; i++) {
+ for (int j = 0; j < 100; j++) {
+ Assert.Equal(ydata[i + j], xdata[i + j] * 0.5f);
+ }
+ }
+ }
+
+ void TestMmGen(Device device)
+ {
+ {
+ var x1 = torch.ones(new long[] { 1, 2 }, device: device);
+ var x2 = torch.ones(new long[] { 2, 1 }, device: device);
+
+ var y = x1.mm(x2).to(DeviceType.CPU);
+
+ var ydata = y.data();
+
+ Assert.Equal(2.0f, ydata[0]);
+ }
+ //System.Runtime.InteropServices.ExternalException : addmm for CUDA tensors only supports floating - point types.Try converting the tensors with.float() at C:\w\b\windows\pytorch\aten\src\THC / generic / THCTensorMathBlas.cu:453
+ if (device.type == DeviceType.CPU) {
+ var x1 = torch.ones(new long[] { 1, 2 }, int64, device: device);
+ var x2 = torch.ones(new long[] { 2, 1 }, int64, device: device);
+
+ var y = x1.mm(x2).to(DeviceType.CPU);
+
+ var ydata = y.data();
+
+ Assert.Equal(2L, ydata[0]);
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.CPU))]
+ public void TestMmCpu()
+ {
+ TestMmGen(torch.CPU);
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.CUDA))]
+ public void TestMmCuda()
+ {
+ if (torch.cuda.is_available()) {
+ TestMmGen(torch.CUDA);
+ }
+ }
+
+ void TestMVGen(Device device)
+ {
+ {
+ var mat1 = torch.ones(new long[] { 4, 3 }, device: device);
+ var vec1 = torch.ones(new long[] { 3 }, device: device);
+
+ var y = mat1.mv(vec1).to(DeviceType.CPU);
+
+ Assert.Equal(4, y.shape[0]);
+ }
+ }
+
+ void TestAddMVGen(Device device)
+ {
+ {
+ var x1 = torch.ones(new long[] { 4 }, device: device);
+ var mat1 = torch.ones(new long[] { 4, 3 }, device: device);
+ var vec1 = torch.ones(new long[] { 3 }, device: device);
+
+ var y = x1.addmv(mat1, vec1).to(DeviceType.CPU);
+
+ Assert.Equal(4, y.shape[0]);
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.CPU))]
+ public void TestMVCpu()
+ {
+ TestMVGen(torch.CPU);
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.CUDA))]
+ public void TestMVCuda()
+ {
+ if (torch.cuda.is_available()) {
+ TestMVGen(torch.CUDA);
+ }
+ }
+
+ [Fact]
+ public void TestAddMVCpu()
+ {
+ TestAddMVGen(torch.CPU);
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.CUDA))]
+ public void TestAddMVCuda()
+ {
+ if (torch.cuda.is_available()) {
+ TestAddMVGen(torch.CUDA);
+ }
+ }
+
+ void TestAddRGen(Device device)
+ {
+ {
+ var x1 = torch.ones(new long[] { 4, 3 }, device: device);
+ var vec1 = torch.ones(new long[] { 4 }, device: device);
+ var vec2 = torch.ones(new long[] { 3 }, device: device);
+
+ var y = x1.addr(vec1, vec2).to(DeviceType.CPU);
+
+ Assert.Equal(new long[] { 4, 3 }, y.shape);
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.CPU))]
+ public void TestAddRCpu()
+ {
+ TestAddRGen(torch.CPU);
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.CUDA))]
+ public void TestAddRCuda()
+ {
+ if (torch.cuda.is_available()) {
+ TestAddRGen(torch.CUDA);
+ }
+ }
+
+
+
+ [Fact]
+ [TestOf(nameof(Tensor.vdot))]
+ public void VdotTest()
+ {
+ var a = new float[] { 1.0f, 2.0f, 3.0f };
+ var b = new float[] { 1.0f, 2.0f, 3.0f };
+ var expected = torch.tensor(a.Zip(b).Select(x => x.First * x.Second).Sum());
+ var res = torch.tensor(a).vdot(torch.tensor(b));
+ Assert.True(res.allclose(expected));
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor.vander))]
+ public void VanderTest()
+ {
+ var x = torch.tensor(new int[] { 1, 2, 3, 5 });
+ {
+ var res = x.vander();
+ var expected = torch.tensor(new long[] { 1, 1, 1, 1, 8, 4, 2, 1, 27, 9, 3, 1, 125, 25, 5, 1 }, 4, 4);
+ Assert.Equal(expected, res);
+ }
+ {
+ var res = x.vander(3);
+ var expected = torch.tensor(new long[] { 1, 1, 1, 4, 2, 1, 9, 3, 1, 25, 5, 1 }, 4, 3);
+ Assert.Equal(expected, res);
+ }
+ {
+ var res = x.vander(3, true);
+ var expected = torch.tensor(new long[] { 1, 1, 1, 1, 2, 4, 1, 3, 9, 1, 5, 25 }, 4, 3);
+ Assert.Equal(expected, res);
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.linalg.vander))]
+ public void LinalgVanderTest()
+ {
+ var x = torch.tensor(new int[] { 1, 2, 3, 5 });
+ {
+ var res = torch.linalg.vander(x);
+ var expected = torch.tensor(new long[] { 1, 1, 1, 1, 1, 2, 4, 8, 1, 3, 9, 27, 1, 5, 25, 125 }, 4, 4);
+ Assert.Equal(expected, res);
+ }
+ {
+ var res = torch.linalg.vander(x, 3);
+ var expected = torch.tensor(new long[] { 1, 1, 1, 1, 2, 4, 1, 3, 9, 1, 5, 25 }, 4, 3);
+ Assert.Equal(expected, res);
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.cholesky))]
+ public void CholeskyTest()
+ {
+ var a = torch.randn(new long[] { 3, 2, 2 }, float64);
+ a = a.matmul(a.swapdims(-2, -1)); // Worked this in to get it tested. Alias for 'transpose'
+ var l = linalg.cholesky(a);
+
+ Assert.True(a.allclose(l.matmul(l.swapaxes(-2, -1)))); // Worked this in to get it tested. Alias for 'transpose'
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.cholesky_ex))]
+ public void CholeskyExTest()
+ {
+ var a = torch.randn(new long[] { 3, 2, 2 }, float64);
+ a = a.matmul(a.swapdims(-2, -1)); // Worked this in to get it tested. Alias for 'transpose'
+ var (l, info) = linalg.cholesky_ex(a);
+
+ Assert.True(a.allclose(l.matmul(l.swapaxes(-2, -1))));
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.inv))]
+ public void InvTest()
+ {
+ var a = torch.randn(new long[] { 3, 2, 2 }, float64);
+ var l = linalg.inv(a);
+
+ Assert.Equal(a.shape, l.shape);
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.inv_ex))]
+ public void InvExTest()
+ {
+ var a = torch.randn(new long[] { 3, 2, 2 }, float64);
+ var (l, info) = linalg.inv_ex(a);
+
+ Assert.Equal(a.shape, l.shape);
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.cond))]
+ public void CondTestF64()
+ {
+ {
+ var a = torch.randn(new long[] { 3, 3, 3 }, float64);
+ // The following mostly checks that the runtime interop doesn't blow up.
+ _ = linalg.cond(a);
+ _ = linalg.cond(a, "fro");
+ _ = linalg.cond(a, "nuc");
+ _ = linalg.cond(a, 1);
+ _ = linalg.cond(a, -1);
+ _ = linalg.cond(a, 2);
+ _ = linalg.cond(a, -2);
+ _ = linalg.cond(a, Double.PositiveInfinity);
+ _ = linalg.cond(a, Double.NegativeInfinity);
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.cond))]
+ public void CondTestCF64()
+ {
+ {
+ var a = torch.randn(new long[] { 3, 3, 3 }, complex128);
+ // The following mostly checks that the runtime interop doesn't blow up.
+ _ = linalg.cond(a);
+ _ = linalg.cond(a, "fro");
+ _ = linalg.cond(a, "nuc");
+ _ = linalg.cond(a, 1);
+ _ = linalg.cond(a, -1);
+ _ = linalg.cond(a, 2);
+ _ = linalg.cond(a, -2);
+ _ = linalg.cond(a, Double.PositiveInfinity);
+ _ = linalg.cond(a, Double.NegativeInfinity);
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.qr))]
+ public void QRTest()
+ {
+ var a = torch.randn(new long[] { 4, 25, 25 });
+
+ var l = linalg.qr(a);
+
+ Assert.Equal(a.shape, l.Q.shape);
+ Assert.Equal(a.shape, l.R.shape);
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.solve))]
+ public void SolveTest()
+ {
+ var A = torch.randn(3, 3);
+ var b = torch.randn(3);
+ var x = torch.linalg.solve(A, b);
+ Assert.True(A.matmul(x).allclose(b, rtol: 1e-03, atol: 1e-06));
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.svd))]
+ public void SVDTest()
+ {
+ var a = torch.randn(new long[] { 4, 25, 15 });
+
+ var l = linalg.svd(a);
+
+ Assert.Equal(new long[] { 4, 25, 25 }, l.U.shape);
+ Assert.Equal(new long[] { 4, 15 }, l.S.shape);
+ Assert.Equal(new long[] { 4, 15, 15 }, l.Vh.shape);
+
+ l = linalg.svd(a, fullMatrices: false);
+
+ Assert.Equal(a.shape, l.U.shape);
+ Assert.Equal(new long[] { 4, 15 }, l.S.shape);
+ Assert.Equal(new long[] { 4, 15, 15 }, l.Vh.shape);
+ }
+
+
+ [Fact]
+ [TestOf(nameof(linalg.svdvals))]
+ public void SVDValsTest()
+ {
+ var a = torch.tensor(new double[] { -1.3490, -0.1723, 0.7730,
+ -1.6118, -0.3385, -0.6490,
+ 0.0908, 2.0704, 0.5647,
+ -0.6451, 0.1911, 0.7353,
+ 0.5247, 0.5160, 0.5110}, 5, 3);
+
+ var l = linalg.svdvals(a);
+ Assert.True(l.allclose(torch.tensor(new double[] { 2.5138929972840613, 2.1086555338402455, 1.1064930672223237 }), rtol: 1e-04, atol: 1e-07));
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.lstsq))]
+ public void LSTSQTest()
+ {
+ var a = torch.randn(new long[] { 4, 25, 15 });
+ var b = torch.randn(new long[] { 4, 25, 10 });
+
+ var l = linalg.lstsq(a, b);
+
+ Assert.Equal(new long[] { 4, 15, 10 }, l.Solution.shape);
+ Assert.Equal(0, l.Residuals.shape[0]);
+ Assert.Equal(new long[] { 4 }, l.Rank.shape);
+ Assert.Equal(new long[] { 4, 15, 10 }, l.Solution.shape);
+ Assert.Equal(0, l.SingularValues.shape[0]);
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.lu))]
+ public void LUTest()
+ {
+ var A = torch.randn(2, 3, 3);
+ var A_factor = torch.linalg.lu(A);
+ // For right now, pretty much just checking that it's not blowing up.
+ Assert.Multiple(
+ () => Assert.NotNull(A_factor.P),
+ () => Assert.NotNull(A_factor.L),
+ () => Assert.NotNull(A_factor.U)
+ );
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.lu_factor))]
+ public void LUFactorTest()
+ {
+ var A = torch.randn(2, 3, 3);
+ var A_factor = torch.linalg.lu_factor(A);
+ // For right now, pretty much just checking that it's not blowing up.
+ Assert.Multiple(
+ () => Assert.NotNull(A_factor.LU),
+ () => Assert.NotNull(A_factor.Pivots)
+ );
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.ldl_factor))]
+ public void LDLFactorTest()
+ {
+ var A = torch.randn(2, 3, 3);
+ var A_factor = torch.linalg.ldl_factor(A);
+ // For right now, pretty much just checking that it's not blowing up.
+ Assert.Multiple(
+ () => Assert.NotNull(A_factor.LU),
+ () => Assert.NotNull(A_factor.Pivots)
+ );
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.ldl_factor))]
+ public void LDLFactorExTest()
+ {
+ var A = torch.randn(2, 3, 3);
+ var A_factor = torch.linalg.ldl_factor_ex(A);
+ // For right now, pretty much just checking that it's not blowing up.
+ Assert.Multiple(
+ () => Assert.NotNull(A_factor.LU),
+ () => Assert.NotNull(A_factor.Pivots),
+ () => Assert.NotNull(A_factor.Info)
+ );
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor.matrix_power))]
+ public void MatrixPowerTest()
+ {
+ var a = torch.randn(new long[] { 25, 25 });
+ var b = a.matrix_power(3);
+ Assert.Equal(new long[] { 25, 25 }, b.shape);
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor.matrix_exp))]
+ public void MatrixExpTest1()
+ {
+ var a = torch.randn(new long[] { 25, 25 });
+ var b = a.matrix_exp();
+ Assert.Equal(new long[] { 25, 25 }, b.shape);
+
+ var c = torch.matrix_exp(a);
+ Assert.Equal(new long[] { 25, 25 }, c.shape);
+ }
+
+ [Fact]
+ [TestOf(nameof(torch.matrix_exp))]
+ public void MatrixExpTest2()
+ {
+ var a = torch.randn(new long[] { 16, 25, 25 });
+ var b = a.matrix_exp();
+ Assert.Equal(new long[] { 16, 25, 25 }, b.shape);
+ var c = torch.matrix_exp(a);
+ Assert.Equal(new long[] { 16, 25, 25 }, c.shape);
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.matrix_rank))]
+ public void MatrixRankTest()
+ {
+ var mr1 = torch.linalg.matrix_rank(torch.randn(4, 3, 2));
+ Assert.Equal(new long[] { 4 }, mr1.shape);
+
+ var mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2));
+ Assert.Equal(new long[] { 2, 4 }, mr2.shape);
+
+ // Really just testing that it doesn't blow up in interop for the following lines:
+
+ mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2), atol: 1.0);
+ Assert.Equal(new long[] { 2, 4 }, mr2.shape);
+
+ mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2), atol: 1.0, rtol: 0.0);
+ Assert.Equal(new long[] { 2, 4 }, mr2.shape);
+
+ mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2), atol: torch.tensor(1.0));
+ Assert.Equal(new long[] { 2, 4 }, mr2.shape);
+
+ mr2 = torch.linalg.matrix_rank(torch.randn(2, 4, 3, 2), atol: torch.tensor(1.0), rtol: torch.tensor(0.0));
+ Assert.Equal(new long[] { 2, 4 }, mr2.shape);
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.multi_dot))]
+ public void MultiDotTest()
+ {
+ var a = torch.randn(new long[] { 25, 25 });
+ var b = torch.randn(new long[] { 25, 25 });
+ var c = torch.randn(new long[] { 25, 25 });
+ var d = torch.linalg.multi_dot(new Tensor[] { a, b, c });
+ Assert.Equal(new long[] { 25, 25 }, d.shape);
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.det))]
+ public void DeterminantTest()
+ {
+ {
+ var a = torch.tensor(
+ new float[] { 0.9478f, 0.9158f, -1.1295f,
+ 0.9701f, 0.7346f, -1.8044f,
+ -0.2337f, 0.0557f, 0.6929f }, 3, 3);
+ var l = linalg.det(a);
+ Assert.True(l.allclose(torch.tensor(0.09335048f)));
+ }
+ {
+ var a = torch.tensor(
+ new float[] { 0.9254f, -0.6213f, -0.5787f, 1.6843f, 0.3242f, -0.9665f,
+ 0.4539f, -0.0887f, 1.1336f, -0.4025f, -0.7089f, 0.9032f }, 3, 2, 2);
+ var l = linalg.det(a);
+ Assert.True(l.allclose(torch.tensor(new float[] { 1.19910491f, 0.4099378f, 0.7385352f })));
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.matrix_norm))]
+ public void MatrixNormTest()
+ {
+ {
+ var a = torch.arange(9, float32).view(3, 3);
+
+ var b = linalg.matrix_norm(a);
+ var c = linalg.matrix_norm(a, ord: -1);
+
+ Assert.Equal(14.282857f, b.item());
+ Assert.Equal(9.0f, c.item());
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.vector_norm))]
+ public void VectorNormTest()
+ {
+ {
+ var a = torch.tensor(
+ new float[] { -4.0f, -3.0f, -2.0f, -1.0f, 0, 1.0f, 2.0f, 3.0f, 4.0f });
+
+ var b = linalg.vector_norm(a, ord: 3.5);
+ var c = linalg.vector_norm(a.view(3, 3), ord: 3.5);
+
+ Assert.Equal(5.4344883f, b.item());
+ Assert.Equal(5.4344883f, c.item());
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.pinv))]
+ public void PinvTest()
+ {
+ var mr1 = torch.linalg.pinv(torch.randn(4, 3, 5));
+ Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape);
+
+ // Really just testing that it doesn't blow up in interop for the following lines:
+
+ mr1 = torch.linalg.pinv(torch.randn(4, 3, 5), atol: 1.0);
+ Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape);
+
+ mr1 = torch.linalg.pinv(torch.randn(4, 3, 5), atol: 1.0, rtol: 0.0);
+ Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape);
+
+ mr1 = torch.linalg.pinv(torch.randn(4, 3, 5), atol: torch.tensor(1.0));
+ Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape);
+
+ mr1 = torch.linalg.pinv(torch.randn(4, 3, 5), atol: torch.tensor(1.0), rtol: torch.tensor(0.0));
+ Assert.Equal(new long[] { 4, 5, 3 }, mr1.shape);
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.eig))]
+ public void EigTest32()
+ {
+ {
+ var a = torch.tensor(
+ new float[] { 2.8050f, -0.3850f, -0.3850f, 3.2376f, -1.0307f, -2.7457f, -2.7457f, -1.7517f, 1.7166f }, 3, 3);
+
+ var expected = torch.tensor(
+ new (float, float)[] { (3.44288778f, 0.0f), (2.17609453f, 0.0f), (-2.128083f, 0.0f) });
+
+ {
+ var (values, vectors) = linalg.eig(a);
+ Assert.NotNull(vectors);
+ Assert.True(values.allclose(expected));
+ }
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.eigvals))]
+ public void EighvalsTest32()
+ {
+ {
+ var a = torch.tensor(
+ new float[] { 2.8050f, -0.3850f, -0.3850f, 3.2376f, -1.0307f, -2.7457f, -2.7457f, -1.7517f, 1.7166f }, 3, 3);
+ var expected = torch.tensor(
+ new (float, float)[] { (3.44288778f, 0.0f), (2.17609453f, 0.0f), (-2.128083f, 0.0f) });
+ var l = linalg.eigvals(a);
+ Assert.True(l.allclose(expected));
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.eigvals))]
+ public void EighvalsTest64()
+ {
+ // TODO: (Skip = "Not working on MacOS (note: may now be working, we need to recheck)")
+ if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) {
+ var a = torch.tensor(
+ new double[] { 2.8050f, -0.3850f, -0.3850f, 3.2376f, -1.0307f, -2.7457f, -2.7457f, -1.7517f, 1.7166f }, 3, 3);
+ var expected = torch.tensor(
+ new System.Numerics.Complex[] { new System.Numerics.Complex(3.44288778f, 0.0f), new System.Numerics.Complex(2.17609453f, 0.0f), new System.Numerics.Complex(-2.128083f, 0.0f) });
+ var l = linalg.eigvals(a);
+ Assert.True(l.allclose(expected));
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.eigvalsh))]
+ public void EighvalshTest32()
+ {
+ // TODO: (Skip = "Not working on MacOS (note: may now be working, we need to recheck)")
+ if (!RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) {
+ var a = torch.tensor(
+ new float[] { 2.8050f, -0.3850f, -0.3850f, 3.2376f, -1.0307f, -2.7457f,
+ -2.7457f, -1.7517f, 1.7166f, 2.2207f, 2.2207f, -2.0898f }, 3, 2, 2);
+ var expected = torch.tensor(
+ new float[] { 2.5797f, 3.46290016f, -4.16046524f, 1.37806475f, -3.11126733f, 2.73806715f }, 3, 2);
+ var l = linalg.eigvalsh(a);
+ Assert.True(l.allclose(expected));
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.eigvalsh))]
+ public void EighvalshTest64()
+ {
+ {
+ var a = torch.tensor(
+ new double[] { 2.8050, -0.3850, -0.3850, 3.2376, -1.0307, -2.7457,
+ -2.7457, -1.7517, 1.7166, 2.2207, 2.2207, -2.0898 }, 3, 2, 2);
+ var expected = torch.tensor(
+ new double[] { 2.5797, 3.46290016, -4.16046524, 1.37806475, -3.11126733, 2.73806715 }, 3, 2);
+ var l = linalg.eigvalsh(a);
+ Assert.True(l.allclose(expected));
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(linalg.norm))]
+ public void LinalgNormTest()
+ {
+ {
+ var a = torch.tensor(
+ new float[] { -4.0f, -3.0f, -2.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f });
+ var b = a.reshape(3, 3);
+
+ Assert.True(linalg.norm(a).allclose(torch.tensor(7.7460f)));
+ Assert.True(linalg.norm(b).allclose(torch.tensor(7.7460f)));
+ Assert.True(linalg.norm(b, "fro").allclose(torch.tensor(7.7460f)));
+
+ Assert.True(linalg.norm(a, float.PositiveInfinity).allclose(torch.tensor(4.0f)));
+ Assert.True(linalg.norm(b, float.PositiveInfinity).allclose(torch.tensor(9.0f)));
+ Assert.True(linalg.norm(a, float.NegativeInfinity).allclose(torch.tensor(0.0f)));
+ Assert.True(linalg.norm(b, float.NegativeInfinity).allclose(torch.tensor(2.0f)));
+
+ Assert.True(linalg.norm(a, 1).allclose(torch.tensor(20.0f)));
+ Assert.True(linalg.norm(b, 1).allclose(torch.tensor(7.0f)));
+ Assert.True(linalg.norm(a, -1).allclose(torch.tensor(0.0f)));
+ Assert.True(linalg.norm(b, -1).allclose(torch.tensor(6.0f)));
+
+ Assert.True(linalg.norm(a, 2).allclose(torch.tensor(7.7460f)));
+ Assert.True(linalg.norm(b, 2).allclose(torch.tensor(7.3485f)));
+ Assert.True(linalg.norm(a, 3).allclose(torch.tensor(5.8480f)));
+ Assert.True(linalg.norm(a, -2).allclose(torch.tensor(0.0f)));
+ Assert.True(linalg.norm(a, -3).allclose(torch.tensor(0.0f)));
+ }
+ }
+
+ [Fact]
+ public void TestTrilIndex()
+ {
+ var a = torch.tril_indices(3, 3);
+ var expected = new long[] { 0, 1, 1, 2, 2, 2, 0, 0, 1, 0, 1, 2 };
+ Assert.Equal(expected, a.data().ToArray());
+ }
+
+ [Fact]
+ public void TestTriuIndex()
+ {
+ var a = torch.triu_indices(3, 3);
+ var expected = new long[] { 0, 0, 0, 1, 1, 2, 0, 1, 2, 1, 2, 2 };
+ Assert.Equal(expected, a.data().ToArray());
+ }
+ }
+}
diff --git a/test/TorchSharpTest/PointwiseTensorMath.cs b/test/TorchSharpTest/PointwiseTensorMath.cs
new file mode 100644
index 000000000..4041f3553
--- /dev/null
+++ b/test/TorchSharpTest/PointwiseTensorMath.cs
@@ -0,0 +1,961 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
+using System;
+using System.IO;
+using System.Linq;
+using System.Runtime.InteropServices;
+using System.Collections.Generic;
+using System.Globalization;
+using Xunit;
+using Xunit.Sdk;
+using static TorchSharp.torch;
+
+#nullable enable
+
+namespace TorchSharp
+{
+#if NET472_OR_GREATER
+ [Collection("Sequential")]
+#endif // NET472_OR_GREATER
+ public class PointwiseTensorMath
+ {
+ [Fact]
+ [TestOf(nameof(Tensor))]
+ public void TestArithmeticOperatorsFloat16()
+ {
+ // Float16 arange_cuda not available on cuda in LibTorch 1.8.0
+ // Float16 arange_cpu not available on cuda in LibTorch 1.8.0
+ foreach (var device in new Device[] { torch.CPU, torch.CUDA }) {
+ if (device.type != DeviceType.CUDA || torch.cuda.is_available()) {
+ var c1 = torch.ones(new long[] { 10, 10 }, float16, device: device);
+ var c2 = torch.ones(new long[] { 10, 10 }, float16, device: device);
+ var c3 = torch.ones(new long[] { 10, 10 }, float16, device: device);
+ Func getFunc = (tt, i, j) => tt[i, j].ToSingle();
+ // scalar-tensor operators
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5f, a => a + 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f + a, a => 0.5f + a);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5f, a => a - 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5f, a => a * 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f * a, a => 0.5f * a);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5f, a => a / 0.5f);
+
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5f), a => a + 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5f), a => a - 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5f), a => a * 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5f), a => a / 0.5f);
+
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5f), a => a + 0.5f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5f), a => a - 0.5f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5f), a => a * 0.5f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5f), a => a / 0.5f);
+
+ // tensor-tensor operators
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b);
+
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b);
+
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b);
+ }
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor))]
+ public void TestArithmeticOperatorsBFloat16()
+ {
+ // BFloat16 arange_cuda not available on cuda in LibTorch 1.8.0
+ // BFloat16 arange_cpu not available on cuda in LibTorch 1.8.0
+ foreach (var device in new Device[] { torch.CPU, torch.CUDA }) {
+ if (device.type != DeviceType.CUDA || torch.cuda.is_available()) {
+ var c1 = torch.ones(new long[] { 10, 10 }, bfloat16, device: device);
+ var c2 = torch.ones(new long[] { 10, 10 }, bfloat16, device: device);
+ var c3 = torch.ones(new long[] { 10, 10 }, bfloat16, device: device);
+ Func getFunc = (tt, i, j) => tt[i, j].ToSingle();
+ // scalar-tensor operators
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5f, a => a + 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f + a, a => 0.5f + a);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5f, a => a - 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5f, a => a * 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f * a, a => 0.5f * a);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5f, a => a / 0.5f);
+
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5f), a => a + 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5f), a => a - 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5f), a => a * 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5f), a => a / 0.5f);
+
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5f), a => a + 0.5f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5f), a => a - 0.5f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5f), a => a * 0.5f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5f), a => a / 0.5f);
+
+ // tensor-tensor operators
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b);
+
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b);
+
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b);
+ }
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor))]
+ public void TestArithmeticOperatorsFloat32()
+ {
+ foreach (var device in new Device[] { torch.CPU, torch.CUDA }) {
+ if (device.type != DeviceType.CUDA || torch.cuda.is_available()) {
+ var c1 = torch.arange(0, 10, float32, device: device).expand(new long[] { 10, 10 });
+ var c2 = torch.arange(10, 0, -1, float32, device: device).expand(new long[] { 10, 10 });
+ var c3 = torch.ones(new long[] { 10, 10 }, float32, device: device);
+ Func getFunc = (tt, i, j) => tt[i, j].ToSingle();
+ // scalar-tensor operators
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5f, a => a + 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f + a, a => 0.5f + a);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5f, a => a - 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5f, a => a * 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5f * a, a => 0.5f * a);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5f, a => a / 0.5f);
+
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5f), a => a + 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5f), a => a - 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5f), a => a * 0.5f);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5f), a => a / 0.5f);
+
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5f), a => a + 0.5f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5f), a => a - 0.5f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5f), a => a * 0.5f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5f), a => a / 0.5f);
+
+ // tensor-tensor operators
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b);
+
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b);
+
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b);
+ }
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor))]
+ public void TestArithmeticOperatorsFloat64()
+ {
+ foreach (var device in new Device[] { torch.CPU, torch.CUDA }) {
+ if (device.type != DeviceType.CUDA || torch.cuda.is_available()) {
+ var c1 = torch.arange(0, 10, float64, device: device).expand(new long[] { 10, 10 });
+ var c2 = torch.arange(10, 0, -1, float64, device: device).expand(new long[] { 10, 10 });
+ var c3 = torch.ones(new long[] { 10, 10 }, float64, device: device);
+ Func getFunc = (tt, i, j) => tt[i, j].ToDouble();
+ // scalar-tensor operators
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5, a => a + 0.5);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5 + a, a => 0.5 + a);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5, a => a - 0.5);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5, a => a * 0.5);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5 * a, a => 0.5 * a);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5, a => a / 0.5);
+
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5), a => a + 0.5);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5), a => a - 0.5);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5), a => a * 0.5);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5), a => a / 0.5);
+
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5), a => a + 0.5);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5), a => a - 0.5);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5), a => a * 0.5);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5), a => a / 0.5);
+
+ // tensor-tensor operators
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b);
+
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b);
+
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b);
+ }
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor))]
+ public void TestArithmeticOperatorsComplexFloat64()
+ {
+ foreach (var device in new Device[] { torch.CPU, torch.CUDA }) {
+ if (device.type != DeviceType.CUDA || torch.cuda.is_available()) {
+ var c1 = torch.arange(0, 10, complex128, device: device).expand(new long[] { 10, 10 });
+ var c2 = torch.arange(10, 0, -1, complex128, device: device).expand(new long[] { 10, 10 });
+ var c3 = torch.ones(new long[] { 10, 10 }, complex128, device: device);
+ Func getFunc = (tt, i, j) => tt[i, j].ToComplexFloat64();
+ // scalar-tensor operators
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a + 0.5, a => a + 0.5);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5 + a, a => 0.5 + a);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a - 0.5, a => a - 0.5);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a * 0.5, a => a * 0.5);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => 0.5 * a, a => 0.5 * a);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a / 0.5, a => a / 0.5);
+
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.add(0.5), a => a + 0.5);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.sub(0.5), a => a - 0.5);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.mul(0.5), a => a * 0.5);
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a.div(0.5), a => a / 0.5);
+
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.add_(0.5), a => a + 0.5);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.sub_(0.5), a => a - 0.5);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.mul_(0.5), a => a * 0.5);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.div_(0.5), a => a / 0.5);
+
+ // tensor-tensor operators
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a + b, (a, b) => a + b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a - b, (a, b) => a - b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a * b, (a, b) => a * b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a / b, (a, b) => a / b);
+
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.add(b), (a, b) => a + b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.sub(b), (a, b) => a - b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.mul(b), (a, b) => a * b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a.div(b), (a, b) => a / b);
+
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.add_(b), (a, b) => a + b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.sub_(b), (a, b) => a - b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.mul_(b), (a, b) => a * b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.div_(b), (a, b) => a / b);
+ }
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor))]
+ public void TestComparisonOperatorsFloat32()
+ {
+ foreach (var device in new Device[] { torch.CPU, torch.CUDA }) {
+ if (device.type != DeviceType.CUDA || torch.cuda.is_available()) {
+ var c1 = torch.arange(0, 10, float32, device: device).expand(new long[] { 10, 10 });
+ var c2 = torch.arange(10, 0, -1, float32, device: device).expand(new long[] { 10, 10 });
+ var c3 = torch.ones(new long[] { 10, 10 }, float32, device: device);
+ Func getFunc = (tt, i, j) => tt[i, j].ToSingle();
+ Func getFuncBool = (tt, i, j) => tt[i, j].ToBoolean();
+ // scalar-tensor operators
+ TestOneTensor(c1, c2, getFunc, getFuncBool, a => a == 5.0f, a => a == 5.0f);
+ TestOneTensor(c1, c2, getFunc, getFuncBool, a => a != 5.0f, a => a != 5.0f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.eq_(5.0f), a => a == 5.0f ? 1.0f : 0.0f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.ne_(5.0f), a => a != 5.0f ? 1.0f : 0.0f);
+
+ TestOneTensor(c1, c2, getFunc, getFuncBool, a => a < 5.0f, a => a < 5.0f);
+ TestOneTensor(c1, c2, getFunc, getFuncBool, a => 5.0f < a, a => 5.0f < a);
+ TestOneTensor(c1, c2, getFunc, getFuncBool, a => a <= 5.0f, a => a <= 5.0f);
+ TestOneTensor(c1, c2, getFunc, getFuncBool, a => 5.0f <= a, a => 5.0f <= a);
+ TestOneTensor(c1, c2, getFunc, getFuncBool, a => a > 5.0f, a => a > 5.0f);
+ TestOneTensor(c1, c2, getFunc, getFuncBool, a => 5.0f > a, a => 5.0f > a);
+ TestOneTensor(c1, c2, getFunc, getFuncBool, a => a >= 5.0f, a => a >= 5.0f);
+ TestOneTensor(c1, c2, getFunc, getFuncBool, a => 5.0f >= a, a => 5.0f >= a);
+
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.lt_(5.0f), a => a < 5.0f ? 1.0f : 0.0f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.le_(5.0f), a => a <= 5.0f ? 1.0f : 0.0f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.gt_(5.0f), a => a > 5.0f ? 1.0f : 0.0f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.ge_(5.0f), a => a >= 5.0f ? 1.0f : 0.0f);
+
+ TestOneTensor(c1, c2, getFunc, getFunc, a => a % 5.0f, a => a % 5.0f);
+ TestOneTensorInPlace(c1, c2, getFunc, a => a.remainder_(5.0f), a => a % 5.0f);
+
+ // tensor-tensor operators
+ TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a == b, (a, b) => a == b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a != b, (a, b) => a != b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.eq_(b), (a, b) => a == b ? 1.0f : 0.0f);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.ne_(b), (a, b) => a != b ? 1.0f : 0.0f);
+
+ TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a < b, (a, b) => a < b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a <= b, (a, b) => a <= b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a > b, (a, b) => a > b);
+ TestTwoTensor(c1, c2, c3, getFunc, getFuncBool, (a, b) => a >= b, (a, b) => a >= b);
+
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.lt_(b), (a, b) => a < b ? 1.0f : 0.0f);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.le_(b), (a, b) => a <= b ? 1.0f : 0.0f);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.gt_(b), (a, b) => a > b ? 1.0f : 0.0f);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.ge_(b), (a, b) => a >= b ? 1.0f : 0.0f);
+
+ TestTwoTensor(c1, c2, c3, getFunc, getFunc, (a, b) => a % b, (a, b) => a % b);
+ TestTwoTensorInPlace(c1, c2, c3, getFunc, (a, b) => a.remainder_(b), (a, b) => a % b);
+ }
+ }
+ }
+
+ private void TestOneTensor(
+ Tensor c1,
+ Tensor c2,
+ Func getFuncIn,
+ Func getFuncOut,
+ Func tensorFunc,
+ Func scalarFunc)
+ {
+ var x = c1 * c2;
+ var y = tensorFunc(x);
+
+ for (int i = 0; i < 10; i++) {
+ for (int j = 0; j < 10; j++) {
+ var xv = getFuncIn(x, i, j);
+ var yv = getFuncOut(y, i, j);
+ Assert.Equal(yv, scalarFunc(xv));
+ }
+ }
+ }
+
+ private void TestOneTensorInPlace(
+ Tensor c1,
+ Tensor c2,
+ Func getFuncIn,
+ Func tensorFunc,
+ Func scalarFunc)
+ {
+
+ var x = c1 * c2;
+ var xClone = x.clone();
+ var y = tensorFunc(x);
+
+ for (int i = 0; i < 10; i++) {
+ for (int j = 0; j < 10; j++) {
+ var xClonev = getFuncIn(xClone, i, j);
+ var xv = getFuncIn(x, i, j);
+ var yv = getFuncIn(y, i, j);
+ Assert.Equal(yv, scalarFunc(xClonev));
+ Assert.Equal(yv, xv);
+ }
+ }
+ }
+
+ private void TestTwoTensor(
+ Tensor c1,
+ Tensor c2,
+ Tensor c3,
+ Func getFuncIn,
+ Func getFuncOut,
+ Func tensorFunc,
+ Func scalarFunc)
+ {
+
+ var x = c1 * c3;
+ var y = c2 * c3;
+
+ var z = tensorFunc(x, y);
+
+ for (int i = 0; i < 10; i++) {
+ for (int j = 0; j < 10; j++) {
+ var xv = getFuncIn(x, i, j);
+ var yv = getFuncIn(y, i, j);
+ var zv = getFuncOut(z, i, j);
+ Assert.Equal(zv, scalarFunc(xv, yv));
+ }
+ }
+ }
+
+ private void TestTwoTensorInPlace(
+ Tensor c1,
+ Tensor c2,
+ Tensor c3,
+ Func getFuncIn,
+ Func tensorFunc,
+ Func scalarFunc) where Tin : unmanaged
+ {
+
+ var x = c1 * c3;
+ var xClone = x.clone();
+ var y = c2 * c3;
+
+ var z = tensorFunc(x, y);
+
+ if (x.device_type == DeviceType.CPU) {
+ var xData = x.data();
+ var yData = y.data();
+ var zData = z.data();
+
+ Assert.True(xData == zData);
+ }
+
+ for (int i = 0; i < 10; i++) {
+ for (int j = 0; j < 10; j++) {
+ var xClonev = getFuncIn(xClone, i, j);
+ var xv = getFuncIn(x, i, j);
+ var yv = getFuncIn(y, i, j);
+ var zv = getFuncIn(z, i, j);
+ Assert.Equal(zv, scalarFunc(xClonev, yv));
+ Assert.Equal(zv, xv);
+ }
+ }
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor.eq))]
+ [TestOf(nameof(Tensor.ne))]
+ [TestOf(nameof(Tensor.lt))]
+ [TestOf(nameof(Tensor.gt))]
+ [TestOf(nameof(Tensor.le))]
+ public void TestComparison()
+ {
+ var A = torch.tensor(new float[] { 1.2f, 3.4f, 1.4f, 3.3f }).reshape(2, 2);
+ var B = torch.tensor(new float[] { 1.3f, 3.3f });
+ Assert.Equal(new bool[] { false, false, false, true }, A.eq(B).data().ToArray());
+ Assert.Equal(new bool[] { false, false, false, true }, torch.eq(A, B).data