From db7c294754850153499a10f9a6c1c88fe0a5585f Mon Sep 17 00:00:00 2001 From: Ryan Richard Date: Sun, 12 Jan 2025 21:47:10 -0600 Subject: [PATCH] DSL Supports Tensor (Finally) (#193) * scalar multiplication works for buffer * done --- include/tensorwrapper/buffer/eigen.hpp | 3 + include/tensorwrapper/detail_/dsl_base.hpp | 18 ++- include/tensorwrapper/detail_/dsl_base.ipp | 12 ++ include/tensorwrapper/dsl/pairwise_parser.hpp | 37 +++-- include/tensorwrapper/tensor/tensor_class.hpp | 61 +++++++- src/tensorwrapper/buffer/eigen.cpp | 24 ++++ src/tensorwrapper/tensor/tensor_class.cpp | 131 ++++++++++++++++++ .../unit_tests/tensorwrapper/buffer/eigen.cpp | 63 +++++++++ .../tensorwrapper/detail_/dsl_base.cpp | 19 ++- .../cxx/unit_tests/tensorwrapper/dsl/dsl.cpp | 18 ++- .../tensorwrapper/dsl/pairwise_parser.cpp | 35 +++-- .../tensorwrapper/tensor/tensor_class.cpp | 108 +++++++++++++++ .../unit_tests/tensorwrapper/testing/dsl.hpp | 29 ++-- 13 files changed, 510 insertions(+), 48 deletions(-) diff --git a/include/tensorwrapper/buffer/eigen.hpp b/include/tensorwrapper/buffer/eigen.hpp index 320e7f0d..d9aca697 100644 --- a/include/tensorwrapper/buffer/eigen.hpp +++ b/include/tensorwrapper/buffer/eigen.hpp @@ -199,6 +199,9 @@ class Eigen : public Replicated { dsl_reference permute_assignment_(label_type this_labels, const_labeled_reference rhs) override; + dsl_reference scalar_multiplication_(label_type this_labels, double scalar, + const_labeled_reference rhs) override; + /// Implements to_string typename polymorphic_base::string_type to_string_() const override; diff --git a/include/tensorwrapper/detail_/dsl_base.hpp b/include/tensorwrapper/detail_/dsl_base.hpp index 4e3bccff..ea6e8e1a 100644 --- a/include/tensorwrapper/detail_/dsl_base.hpp +++ b/include/tensorwrapper/detail_/dsl_base.hpp @@ -209,16 +209,18 @@ class DSLBase { * @tparam ScalarType The type of @p scalar. Assumed to be a floating- * point type. * - * This method is responsible for scaling @p *this by @p scalar. + * This method is responsible for scaling @p rhs by @p scalar and assigning + * it to *this. * * @note This method is templated on the scalar type to avoid limiting the * API. That said, at present the backend converts @p scalar to - * double precision. 
+     *        double precision, but we could use a variant or something similar
+     *        to avoid this.
      */
-    template<typename ScalarType>
-    dsl_reference scalar_multiplication(ScalarType&& scalar) {
-        return scalar_multiplication_(std::forward<ScalarType>(scalar));
-    }
+    template<typename LabelType, typename ScalarType>
+    dsl_reference scalar_multiplication(LabelType&& this_labels,
+                                        ScalarType&& scalar,
+                                        const_labeled_reference rhs);
 
 protected:
     /// Derived class should overwrite to implement addition_assignment
@@ -249,7 +251,9 @@ class DSLBase {
     }
 
     /// Derived class should overwrite to implement scalar_multiplication
-    dsl_reference scalar_multiplication_(double scalar) {
+    virtual dsl_reference scalar_multiplication_(label_type this_labels,
+                                                 double scalar,
+                                                 const_labeled_reference rhs) {
         throw std::runtime_error("Scalar multiplication NYI");
     }
 
diff --git a/include/tensorwrapper/detail_/dsl_base.ipp b/include/tensorwrapper/detail_/dsl_base.ipp
index 6e2db6a1..c84e436f 100644
--- a/include/tensorwrapper/detail_/dsl_base.ipp
+++ b/include/tensorwrapper/detail_/dsl_base.ipp
@@ -84,6 +84,18 @@ typename DSL_BASE::dsl_reference DSL_BASE::permute_assignment(
     return permute_assignment_(std::move(lhs_labels), rhs);
 }
 
+TPARAMS
+template<typename LabelType, typename FloatType>
+typename DSL_BASE::dsl_reference DSL_BASE::scalar_multiplication(
+  LabelType&& this_labels, FloatType&& scalar, const_labeled_reference rhs) {
+    assert_indices_match_rank_(rhs);
+
+    label_type lhs_labels(std::forward<LabelType>(this_labels));
+    assert_is_subset_(lhs_labels, rhs.labels());
+
+    return scalar_multiplication_(std::move(lhs_labels), scalar, rhs);
+}
+
 #undef DSL_BASE
 #undef TPARAMS
 
diff --git a/include/tensorwrapper/dsl/pairwise_parser.hpp b/include/tensorwrapper/dsl/pairwise_parser.hpp
index 0209dcab..e88a700d 100644
--- a/include/tensorwrapper/dsl/pairwise_parser.hpp
+++ b/include/tensorwrapper/dsl/pairwise_parser.hpp
@@ -59,11 +59,7 @@ class PairwiseParser {
      */
     template<typename LHSType, typename RHSType>
     void dispatch(LHSType&& lhs, const RHSType& rhs) {
-        if constexpr(std::is_floating_point_v<std::decay_t<RHSType>>) {
-            lhs.object().scalar_multiplication(rhs);
-        } else {
-            lhs.object().permute_assignment(lhs.labels(), rhs);
-        }
+        lhs.object().permute_assignment(lhs.labels(), rhs);
     }
 
     /** @brief Handles adding two expressions together.
@@ -130,14 +126,29 @@
      */
     template<typename LHSType, typename T, typename U>
     void dispatch(LHSType&& lhs, const utilities::dsl::Multiply<T, U>& rhs) {
-        auto pA = lhs.object().clone();
-        auto pB = lhs.object().clone();
-        auto labels = lhs.labels();
-        auto lA = (*pA)(labels);
-        auto lB = (*pB)(labels);
-        dispatch(lA, rhs.lhs());
-        dispatch(lB, rhs.rhs());
-        lhs.object().multiplication_assignment(labels, lA, lB);
+        constexpr bool t_is_float = std::is_floating_point_v<T>;
+        constexpr bool u_is_float = std::is_floating_point_v<U>;
+        static_assert(!(t_is_float && u_is_float), "Both can be float??");
+        if constexpr(t_is_float) {
+            auto pA = lhs.object().clone();
+            auto lA = (*pA)(lhs.labels());
+            dispatch(lA, rhs.rhs());
+            lhs.object().scalar_multiplication(lhs.labels(), rhs.lhs(), lA);
+        } else if constexpr(u_is_float) {
+            auto pA = lhs.object().clone();
+            auto lA = (*pA)(lhs.labels());
+            dispatch(lA, rhs.lhs());
+            lhs.object().scalar_multiplication(lhs.labels(), rhs.rhs(), lA);
+        } else {
+            auto pA = lhs.object().clone();
+            auto pB = lhs.object().clone();
+            auto labels = lhs.labels();
+            auto lA = (*pA)(labels);
+            auto lB = (*pB)(labels);
+            dispatch(lA, rhs.lhs());
+            dispatch(lB, rhs.rhs());
+            lhs.object().multiplication_assignment(labels, lA, lB);
+        }
     }
 };
diff --git a/include/tensorwrapper/tensor/tensor_class.hpp b/include/tensorwrapper/tensor/tensor_class.hpp
index 334b7047..57c2836f 100644
--- a/include/tensorwrapper/tensor/tensor_class.hpp
+++ b/include/tensorwrapper/tensor/tensor_class.hpp
@@ -15,7 +15,8 @@
  */
 #pragma once
-#include
+#include
+#include
 #include
 
 namespace tensorwrapper {
@@ -34,7 +35,8 @@ struct IsTuple> : std::true_type {};
  * The Tensor class is envisioned as being the most user-facing class of
  * TensorWrapper and forms the entry point into TensorWrapper's DSL.
  */
-class Tensor {
+class Tensor : public detail_::DSLBase<Tensor>,
+               public detail_::PolymorphicBase<Tensor> {
 private:
     /// Type of a helper class which collects the inputs needed to make a tensor
     using input_type = detail_::TensorInput;
@@ -53,6 +55,8 @@ class Tensor {
     using enable_if_no_tensors_t = std::enable_if_t>;
 
+    using polymorphic_base = detail_::PolymorphicBase<Tensor>;
+
 public:
     /// Type of the object implementing *this
     using pimpl_type = detail_::TensorPIMPL;
@@ -81,6 +85,9 @@ class Tensor {
     /// Type of a pointer to a read-only buffer
     using const_buffer_pointer = input_type::const_buffer_pointer;
 
+    /// Type used to convey rank
+    using rank_type = typename logical_layout_type::size_type;
+
     /// Type of an initializer list if *this is a scalar
     using scalar_il_type = double;
 
@@ -299,6 +306,19 @@ class Tensor {
      */
     const_buffer_reference buffer() const;
 
+    /** @brief Returns the logical rank of the tensor.
+     *
+     *  Most users interacting with a tensor will be thinking of it in terms of
+     *  its logical rank. This function is a convenience function for calling
+     *  `rank()` on the logical layout.
+     *
+     *  @return The rank of the tensor, logically.
+     *
+     *  @throw std::runtime_error if *this does not have a logical layout.
+     *                            Strong throw guarantee.
+ */ + rank_type rank() const; + // ------------------------------------------------------------------------- // -- Utility methods // ------------------------------------------------------------------------- @@ -344,6 +364,43 @@ class Tensor { */ bool operator!=(const Tensor& rhs) const noexcept; +protected: + /// Implements clone by calling copy ctor + polymorphic_base::base_pointer clone_() const override { + return std::make_unique(*this); + } + + /// Implements are_equal by calling are_equal_impl_ + bool are_equal_(const_base_reference rhs) const noexcept override { + return polymorphic_base::are_equal_impl_(rhs); + } + + /// Implements addition_assignment by calling addition_assignment on state + dsl_reference addition_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) override; + + /// Calls subtraction_assignment on each member + dsl_reference subtraction_assignment_(label_type this_labels, + const_labeled_reference lhs, + const_labeled_reference rhs) override; + + /// Calls multiplication_assignment on each member + dsl_reference multiplication_assignment_( + label_type this_labels, const_labeled_reference lhs, + const_labeled_reference rhs) override; + + /// Calls scalar_multiplication on each member + dsl_reference scalar_multiplication_(label_type this_labels, double scalar, + const_labeled_reference rhs) override; + + /// Calls permute_assignment on each member + dsl_reference permute_assignment_(label_type this_labels, + const_labeled_reference rhs) override; + + /// Implements to_string + typename polymorphic_base::string_type to_string_() const override; + private: /// All ctors ultimately dispatch to this ctor Tensor(pimpl_pointer pimpl) noexcept; diff --git a/src/tensorwrapper/buffer/eigen.cpp b/src/tensorwrapper/buffer/eigen.cpp index 81b726ff..9980ef5b 100644 --- a/src/tensorwrapper/buffer/eigen.cpp +++ b/src/tensorwrapper/buffer/eigen.cpp @@ -150,6 +150,30 @@ typename EIGEN::dsl_reference EIGEN::permute_assignment_( return *this; } +TPARAMS +typename EIGEN::dsl_reference EIGEN::scalar_multiplication_( + label_type this_labels, double scalar, const_labeled_reference rhs) { + BufferBase::permute_assignment_(this_labels, rhs); + + using allocator_type = allocator::Eigen; + const auto& rhs_downcasted = allocator_type::rebind(rhs.object()); + + const auto& rlabels = rhs.labels(); + + FloatType c(scalar); + + if(this_labels != rlabels) { // We need to permute rhs before assignment + auto r_to_l = rhs.labels().permutation(this_labels); + // Eigen wants int objects + std::vector r_to_l2(r_to_l.begin(), r_to_l.end()); + m_tensor_ = rhs_downcasted.value().shuffle(r_to_l2) * c; + } else { + m_tensor_ = rhs_downcasted.value() * c; + } + + return *this; +} + TPARAMS typename detail_::PolymorphicBase::string_type EIGEN::to_string_() const { diff --git a/src/tensorwrapper/tensor/tensor_class.cpp b/src/tensorwrapper/tensor/tensor_class.cpp index 559b37ad..f27c686e 100644 --- a/src/tensorwrapper/tensor/tensor_class.cpp +++ b/src/tensorwrapper/tensor/tensor_class.cpp @@ -75,6 +75,8 @@ const_buffer_reference Tensor::buffer() const { return m_pimpl_->buffer(); } +Tensor::rank_type Tensor::rank() const { return logical_layout().rank(); } + // -- Utility void Tensor::swap(Tensor& other) noexcept { m_pimpl_.swap(other.m_pimpl_); } @@ -89,6 +91,135 @@ bool Tensor::operator!=(const Tensor& rhs) const noexcept { return !(*this == rhs); } +// -- Protected methods + +Tensor::dsl_reference Tensor::addition_assignment_( + label_type this_labels, 
const_labeled_reference lhs, + const_labeled_reference rhs) { + const auto& lobject = lhs.object(); + const auto& llabels = lhs.labels(); + const auto& robject = rhs.object(); + const auto& rlabels = rhs.labels(); + + auto llayout = lobject.logical_layout(); + auto rlayout = robject.logical_layout(); + auto pthis_layout = llayout.clone_as(); + + pthis_layout->addition_assignment(this_labels, llayout(llabels), + rlayout(rlabels)); + + auto pthis_buffer = lobject.buffer().clone(); + auto lbuffer = lobject.buffer()(llabels); + auto rbuffer = robject.buffer()(rlabels); + pthis_buffer->addition_assignment(this_labels, lbuffer, rbuffer); + + auto new_pimpl = std::make_unique(std::move(pthis_layout), + std::move(pthis_buffer)); + new_pimpl.swap(m_pimpl_); + + return *this; +} + +Tensor::dsl_reference Tensor::subtraction_assignment_( + label_type this_labels, const_labeled_reference lhs, + const_labeled_reference rhs) { + const auto& lobject = lhs.object(); + const auto& llabels = lhs.labels(); + const auto& robject = rhs.object(); + const auto& rlabels = rhs.labels(); + + auto llayout = lobject.logical_layout(); + auto rlayout = robject.logical_layout(); + auto pthis_layout = llayout.clone_as(); + + pthis_layout->subtraction_assignment(this_labels, llayout(llabels), + rlayout(rlabels)); + + auto pthis_buffer = lobject.buffer().clone(); + auto lbuffer = lobject.buffer()(llabels); + auto rbuffer = robject.buffer()(rlabels); + pthis_buffer->subtraction_assignment(this_labels, lbuffer, rbuffer); + + auto new_pimpl = std::make_unique(std::move(pthis_layout), + std::move(pthis_buffer)); + new_pimpl.swap(m_pimpl_); + + return *this; +} + +Tensor::dsl_reference Tensor::multiplication_assignment_( + label_type this_labels, const_labeled_reference lhs, + const_labeled_reference rhs) { + const auto& lobject = lhs.object(); + const auto& llabels = lhs.labels(); + const auto& robject = rhs.object(); + const auto& rlabels = rhs.labels(); + + auto llayout = lobject.logical_layout(); + auto rlayout = robject.logical_layout(); + auto pthis_layout = llayout.clone_as(); + + pthis_layout->multiplication_assignment(this_labels, llayout(llabels), + rlayout(rlabels)); + + auto pthis_buffer = lobject.buffer().clone(); + auto lbuffer = lobject.buffer()(llabels); + auto rbuffer = robject.buffer()(rlabels); + pthis_buffer->multiplication_assignment(this_labels, lbuffer, rbuffer); + + auto new_pimpl = std::make_unique(std::move(pthis_layout), + std::move(pthis_buffer)); + new_pimpl.swap(m_pimpl_); + + return *this; +} + +Tensor::dsl_reference Tensor::scalar_multiplication_( + label_type this_labels, double scalar, const_labeled_reference rhs) { + const auto& robject = rhs.object(); + const auto& rlabels = rhs.labels(); + + auto rlayout = robject.logical_layout(); + auto pthis_layout = rlayout.clone_as(); + + pthis_layout->permute_assignment(this_labels, rlayout(rlabels)); + + auto pthis_buffer = robject.buffer().clone(); + auto rbuffer = robject.buffer()(rlabels); + pthis_buffer->scalar_multiplication(this_labels, scalar, rbuffer); + + auto new_pimpl = std::make_unique(std::move(pthis_layout), + std::move(pthis_buffer)); + new_pimpl.swap(m_pimpl_); + + return *this; +} + +Tensor::dsl_reference Tensor::permute_assignment_(label_type this_labels, + const_labeled_reference rhs) { + const auto& robject = rhs.object(); + const auto& rlabels = rhs.labels(); + + auto rlayout = robject.logical_layout(); + auto pthis_layout = rlayout.clone_as(); + + pthis_layout->permute_assignment(this_labels, rlayout(rlabels)); + + auto 
pthis_buffer = robject.buffer().clone(); + auto rbuffer = robject.buffer()(rlabels); + pthis_buffer->permute_assignment(this_labels, rbuffer); + + auto new_pimpl = std::make_unique(std::move(pthis_layout), + std::move(pthis_buffer)); + new_pimpl.swap(m_pimpl_); + + return *this; +} + +typename Tensor::polymorphic_base::string_type Tensor::to_string_() const { + return has_pimpl_() ? buffer().to_string() : ""; +} + // -- Private methods Tensor::Tensor(pimpl_pointer pimpl) noexcept : m_pimpl_(std::move(pimpl)) {} diff --git a/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp b/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp index 1ab164da..abfd5c8b 100644 --- a/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/buffer/eigen.cpp @@ -461,6 +461,69 @@ TEMPLATE_TEST_CASE("Eigen", "", float, double) { } } + SECTION("scalar_multiplication_") { + SECTION("scalar") { + auto scalar2 = testing::eigen_scalar(); + scalar2.value()() = 42.0; + + auto s = scalar(""); + auto pscalar2 = &(scalar2.scalar_multiplication("", 2.0, s)); + + auto corr = testing::eigen_scalar(); + corr.value()() = 20.0; + REQUIRE(pscalar2 == &scalar2); + REQUIRE(scalar2 == corr); + } + + SECTION("vector") { + auto vector2 = testing::eigen_vector(); + + auto vi = vector("i"); + auto pvector2 = &(vector2.scalar_multiplication("i", 2.0, vi)); + + auto corr = testing::eigen_vector(2); + corr.value()(0) = 20.0; + corr.value()(1) = 40.0; + + REQUIRE(pvector2 == &vector2); + REQUIRE(vector2 == corr); + } + + SECTION("matrix : no permutation") { + auto matrix2 = testing::eigen_matrix(); + + auto mij = matrix("i,j"); + auto p = &(matrix2.scalar_multiplication("i,j", 2.0, mij)); + + auto corr = testing::eigen_matrix(2, 3); + corr.value()(0, 0) = 20.0; + corr.value()(0, 1) = 40.0; + corr.value()(0, 2) = 60.0; + corr.value()(1, 0) = 80.0; + corr.value()(1, 1) = 100.0; + corr.value()(1, 2) = 120.0; + + REQUIRE(p == &matrix2); + REQUIRE(matrix2 == corr); + } + + SECTION("matrix: permutation") { + auto matrix2 = testing::eigen_matrix(); + auto mij = matrix("i,j"); + auto p = &(matrix2.scalar_multiplication("j,i", 2.0, mij)); + + auto corr = testing::eigen_matrix(3, 2); + corr.value()(0, 0) = 20.0; + corr.value()(1, 0) = 40.0; + corr.value()(2, 0) = 60.0; + corr.value()(0, 1) = 80.0; + corr.value()(1, 1) = 100.0; + corr.value()(2, 1) = 120.0; + REQUIRE(p == &matrix2); + compare_eigen(corr.value(), matrix2.value()); + } + } + SECTION("hadamard_") { SECTION("scalar") { scalar_buffer scalar2(eigen_scalar, scalar_layout); diff --git a/tests/cxx/unit_tests/tensorwrapper/detail_/dsl_base.cpp b/tests/cxx/unit_tests/tensorwrapper/detail_/dsl_base.cpp index 3f162775..5257eef8 100644 --- a/tests/cxx/unit_tests/tensorwrapper/detail_/dsl_base.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/detail_/dsl_base.cpp @@ -170,11 +170,24 @@ TEMPLATE_LIST_TEST_CASE("DSLBase", "", test_types) { } SECTION("scalar_multiplication") { + using error_t = std::runtime_error; // N.b., only tensor and buffer will override so here we're checking // that other objects throw - using error_t = std::runtime_error; + if constexpr(std::is_same_v) { + auto s = default_value(""); + auto sij = default_value("i,j"); - // Input's indices must match rank - REQUIRE_THROWS_AS(value.scalar_multiplication(1.0), error_t); + // Input's indices mush match rank + REQUIRE_THROWS_AS(value.scalar_multiplication("i,j", 2.0, sij), + error_t); + + // Output must have <= number of dummy indices + REQUIRE_THROWS_AS(value.scalar_multiplication("i,j", 2.0, s), + 
error_t); + + } else { + auto s = default_value(""); + REQUIRE_THROWS_AS(value.scalar_multiplication("", 1.0, s), error_t); + } } } \ No newline at end of file diff --git a/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp b/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp index f8f3a4ab..3565b558 100644 --- a/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/dsl/dsl.cpp @@ -62,6 +62,17 @@ TEMPLATE_LIST_TEST_CASE("DSL", "", testing::dsl_types) { value1.multiplication_assignment("i,j", value2("i,j"), value2("i,j")); REQUIRE(value1.are_equal(value0)); } + + SECTION("scalar_multiplication") { + if constexpr(std::is_same_v) { + } else { + // N.b., only tensor and buffer will override so here we're checking + // that other objects throw + using error_t = std::runtime_error; + + REQUIRE_THROWS_AS(value0("") = value0("") * 1.0, error_t); + } + } } // Since Eigen buffers are templated on the rank there isn't an easy way to @@ -109,9 +120,8 @@ TEST_CASE("DSLr : buffer::Eigen") { } SECTION("scalar_multiplication") { - // This should actually work. Will fix in a future PR - using error_t = std::runtime_error; - - REQUIRE_THROWS_AS(scalar0("") = scalar0("") * 1.0, error_t); + scalar0("") = scalar1("") * 1.0; + corr.scalar_multiplication("", 1.0, scalar1("")); + REQUIRE(corr.are_equal(scalar0)); } } \ No newline at end of file diff --git a/tests/cxx/unit_tests/tensorwrapper/dsl/pairwise_parser.cpp b/tests/cxx/unit_tests/tensorwrapper/dsl/pairwise_parser.cpp index 7478335d..746b8f33 100644 --- a/tests/cxx/unit_tests/tensorwrapper/dsl/pairwise_parser.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/dsl/pairwise_parser.cpp @@ -95,11 +95,29 @@ TEMPLATE_LIST_TEST_CASE("PairwiseParser", "", testing::dsl_types) { } SECTION("scalar_multiplication") { - // N.b., only tensor and buffer will override so here we're checking - // that other objects throw - using error_t = std::runtime_error; - - REQUIRE_THROWS_AS(p.dispatch(value0(""), value0("") * 1.0), error_t); + if constexpr(std::is_same_v) { + object_type rv(value1); + object_type corr(value1); + + SECTION("scalar") { + p.dispatch(rv(""), value0("") * 2.0); + corr.scalar_multiplication("", 2.0, value0("")); + REQUIRE(corr.are_equal(rv)); + } + SECTION("matrix") { + p.dispatch(rv("i,j"), value2("i,j") * 2.0); + corr.scalar_multiplication("i,j", 2.0, value2("i,j")); + REQUIRE(corr.are_equal(rv)); + } + + } else { + // N.b., only tensor and buffer will override so here we're checking + // that other objects throw + using error_t = std::runtime_error; + + REQUIRE_THROWS_AS(p.dispatch(value0(""), value0("") * 1.0), + error_t); + } } } @@ -150,9 +168,8 @@ TEST_CASE("PairwiseParser : buffer::Eigen") { } SECTION("scalar_multiplication") { - // This should actually work. 
Will fix in a future PR - using error_t = std::runtime_error; - - REQUIRE_THROWS_AS(p.dispatch(scalar0(""), scalar0("") * 1.0), error_t); + p.dispatch(scalar0(""), scalar1("") * 1.0); + corr.scalar_multiplication("", 1.0, scalar1("")); + REQUIRE(corr.are_equal(scalar0)); } } \ No newline at end of file diff --git a/tests/cxx/unit_tests/tensorwrapper/tensor/tensor_class.cpp b/tests/cxx/unit_tests/tensorwrapper/tensor/tensor_class.cpp index f6161410..6e1f0ea5 100644 --- a/tests/cxx/unit_tests/tensorwrapper/tensor/tensor_class.cpp +++ b/tests/cxx/unit_tests/tensorwrapper/tensor/tensor_class.cpp @@ -121,6 +121,13 @@ TEST_CASE("Tensor") { REQUIRE_THROWS_AS(const_defaulted.buffer(), std::runtime_error); } + SECTION("rank") { + REQUIRE(scalar.rank() == 0); + REQUIRE(vector.rank() == 1); + + REQUIRE_THROWS_AS(defaulted.rank(), std::runtime_error); + } + SECTION("swap") { Tensor scalar_copy(scalar); Tensor vector_copy(vector); @@ -163,4 +170,105 @@ TEST_CASE("Tensor") { REQUIRE_FALSE(scalar != other_scalar); REQUIRE(scalar != vector); } + + SECTION("addition_assignment") { + SECTION("scalar") { + Tensor rv; + Tensor s0(42.0); + auto prv = &(rv.addition_assignment("", s0(""), s0(""))); + REQUIRE(prv == &rv); + Tensor corr(84.0); + REQUIRE(rv == corr); + } + SECTION("vector") { + Tensor rv; + Tensor v0{0, 1, 2, 3, 4}; + auto prv = &(rv.addition_assignment("i", v0("i"), v0("i"))); + REQUIRE(prv == &rv); + REQUIRE(rv == Tensor{0, 2, 4, 6, 8}); + } + } + SECTION("subtraction_assignment") { + SECTION("scalar") { + Tensor rv; + Tensor s0(42.0); + auto prv = &(rv.subtraction_assignment("", s0(""), s0(""))); + REQUIRE(prv == &rv); + REQUIRE(rv == Tensor(0.0)); + } + SECTION("vector") { + Tensor rv; + Tensor v0{0, 1, 2, 3, 4}; + auto prv = &(rv.subtraction_assignment("i", v0("i"), v0("i"))); + REQUIRE(prv == &rv); + REQUIRE(rv == Tensor{0, 0, 0, 0, 0}); + } + } + SECTION("multiplication_assignment") { + SECTION("scalar") { + Tensor rv; + Tensor s0(42.0); + auto prv = &(rv.multiplication_assignment("", s0(""), s0(""))); + REQUIRE(prv == &rv); + Tensor corr(1764.0); + REQUIRE(rv == corr); + } + SECTION("vector") { + Tensor rv; + Tensor v0{0, 1, 2, 3, 4}; + auto prv = &(rv.multiplication_assignment("i", v0("i"), v0("i"))); + REQUIRE(prv == &rv); + REQUIRE(rv == Tensor{0, 1, 4, 9, 16}); + } + } + SECTION("scalar_multiplication") { + SECTION("scalar") { + Tensor rv; + Tensor s0(42.0); + auto prv = &(rv.scalar_multiplication("", 2.0, s0(""))); + REQUIRE(prv == &rv); + Tensor corr(84.0); + REQUIRE(rv == corr); + } + SECTION("vector") { + Tensor rv; + Tensor v0{0, 1, 2, 3, 4}; + auto prv = &(rv.scalar_multiplication("i", 2.0, v0("i"))); + REQUIRE(prv == &rv); + REQUIRE(rv == Tensor{0.0, 2.0, 4.0, 6.0, 8.0}); + } + SECTION("matrix") { + Tensor rv; + Tensor m0{{1, 2}, {3, 4}}; + auto prv = &(rv.scalar_multiplication("j,i", 2.0, m0("i,j"))); + REQUIRE(prv == &rv); + Tensor corr{{2.0, 6.0}, {4.0, 8.0}}; + REQUIRE(rv == corr); + } + } + SECTION("permute_assignment") { + SECTION("scalar") { + Tensor rv; + Tensor s0(42.0); + auto prv = &(rv.permute_assignment("", s0(""))); + REQUIRE(prv == &rv); + Tensor corr(42.0); + REQUIRE(rv == corr); + } + SECTION("vector") { + Tensor rv; + Tensor v0{0, 1, 2, 3, 4}; + auto prv = &(rv.permute_assignment("i", v0("i"))); + REQUIRE(prv == &rv); + REQUIRE(rv == Tensor{0, 1, 2, 3, 4}); + } + SECTION("matrix") { + Tensor rv; + Tensor m0{{1, 2}, {3, 4}}; + auto prv = &(rv.permute_assignment("j,i", m0("i,j"))); + REQUIRE(prv == &rv); + Tensor corr{{1, 3}, {2, 4}}; + REQUIRE(rv == corr); + } + 
} } diff --git a/tests/cxx/unit_tests/tensorwrapper/testing/dsl.hpp b/tests/cxx/unit_tests/tensorwrapper/testing/dsl.hpp index 4b7ffe51..54e4f33f 100644 --- a/tests/cxx/unit_tests/tensorwrapper/testing/dsl.hpp +++ b/tests/cxx/unit_tests/tensorwrapper/testing/dsl.hpp @@ -25,24 +25,33 @@ namespace tensorwrapper::testing { using dsl_types = std::tuple; + tensorwrapper::layout::Physical, tensorwrapper::Tensor>; inline auto scalar_values() { - return dsl_types{smooth_scalar(), tensorwrapper::symmetry::Group(0), - tensorwrapper::sparsity::Pattern(0), scalar_logical(), - scalar_physical()}; + return dsl_types{smooth_scalar(), + tensorwrapper::symmetry::Group(0), + tensorwrapper::sparsity::Pattern(0), + scalar_logical(), + scalar_physical(), + Tensor(42.0)}; } inline auto vector_values() { - return dsl_types{smooth_vector(), tensorwrapper::symmetry::Group(1), - tensorwrapper::sparsity::Pattern(1), vector_logical(), - vector_physical()}; + return dsl_types{smooth_vector(), + tensorwrapper::symmetry::Group(1), + tensorwrapper::sparsity::Pattern(1), + vector_logical(), + vector_physical(), + Tensor{1.0, 2.0, 3.0}}; } inline auto matrix_values() { - return dsl_types{smooth_matrix(), tensorwrapper::symmetry::Group(2), - tensorwrapper::sparsity::Pattern(2), matrix_logical(), - matrix_physical()}; + return dsl_types{smooth_matrix(), + tensorwrapper::symmetry::Group(2), + tensorwrapper::sparsity::Pattern(2), + matrix_logical(), + matrix_physical(), + Tensor{{1.0, 2.0}, {3.0, 4.0}}}; } } // namespace tensorwrapper::testing \ No newline at end of file
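Usage sketch: the snippet below is a minimal sketch of the path this patch enables, modeled on the tensor_class.cpp and pairwise_parser.cpp unit tests above. Tensor, the labeled call operator, and scalar_multiplication() come from the patch itself; the single include and the main() scaffolding are assumptions, and additional DSL headers may be needed for the operator* overloads.

// Minimal sketch; the header path mirrors include/tensorwrapper/tensor/tensor_class.hpp
// from this patch, but the exact set of headers needed is an assumption.
#include <tensorwrapper/tensor/tensor_class.hpp>

int main() {
    using tensorwrapper::Tensor;

    Tensor v{0.0, 1.0, 2.0, 3.0, 4.0};
    Tensor m{{1.0, 2.0}, {3.0, 4.0}};

    // Direct call, as in the tensor_class.cpp tests:
    // rv("i") = 2.0 * v("i") -> {0, 2, 4, 6, 8}
    Tensor rv;
    rv.scalar_multiplication("i", 2.0, v("i"));

    // Scale and transpose in one assignment:
    // rt("j,i") = 2.0 * m("i,j") -> {{2, 6}, {4, 8}}
    Tensor rt;
    rt.scalar_multiplication("j,i", 2.0, m("i,j"));

    // The same operation through the DSL; dsl::PairwiseParser::dispatch now
    // forwards a floating-point operand of operator* (on either side) to
    // scalar_multiplication().
    Tensor rd(v);
    rd("i") = v("i") * 2.0;

    return 0;
}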