Skip to content

Commit

Permalink
DSL Supports Tensor (Finally) (#193)
Browse files Browse the repository at this point in the history
* scalar multiplication works for buffer

* done
  • Loading branch information
ryanmrichard authored Jan 13, 2025
1 parent 625f90c commit db7c294
Show file tree
Hide file tree
Showing 13 changed files with 510 additions and 48 deletions.
3 changes: 3 additions & 0 deletions include/tensorwrapper/buffer/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,9 @@ class Eigen : public Replicated {
dsl_reference permute_assignment_(label_type this_labels,
const_labeled_reference rhs) override;

dsl_reference scalar_multiplication_(label_type this_labels, double scalar,
const_labeled_reference rhs) override;

/// Implements to_string
typename polymorphic_base::string_type to_string_() const override;

Expand Down
18 changes: 11 additions & 7 deletions include/tensorwrapper/detail_/dsl_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,18 @@ class DSLBase {
* @tparam ScalarType The type of @p scalar. Assumed to be a floating-
* point type.
*
 * This method is responsible for scaling @p rhs by @p scalar and assigning
 * it to *this.
*
* @note This method is templated on the scalar type to avoid limiting the
* API. That said, at present the backend converts @p scalar to
* double precision.
* double precision, but we could use a variant or something similar
* to avoid this
*/
template<typename ScalarType>
dsl_reference scalar_multiplication(ScalarType&& scalar) {
return scalar_multiplication_(std::forward<ScalarType>(scalar));
}
template<typename LabelType, typename ScalarType>
dsl_reference scalar_multiplication(LabelType&& this_labels,
ScalarType&& scalar,
const_labeled_reference rhs);

protected:
/// Derived class should overwrite to implement addition_assignment
Expand Down Expand Up @@ -249,7 +251,9 @@ class DSLBase {
}

/// Derived class should overwrite to implement scalar_multiplication
dsl_reference scalar_multiplication_(double scalar) {
virtual dsl_reference scalar_multiplication_(label_type this_labels,
double scalar,
const_labeled_reference rhs) {
throw std::runtime_error("Scalar multiplication NYI");
}

Expand Down
12 changes: 12 additions & 0 deletions include/tensorwrapper/detail_/dsl_base.ipp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,18 @@ typename DSL_BASE::dsl_reference DSL_BASE::permute_assignment(
return permute_assignment_(std::move(lhs_labels), rhs);
}

TPARAMS
template<typename LabelType, typename FloatType>
typename DSL_BASE::dsl_reference DSL_BASE::scalar_multiplication(
    LabelType&& this_labels, FloatType&& scalar, const_labeled_reference rhs) {
    // Scaling can not change the rank, so the label count must agree with rhs.
    assert_indices_match_rank_(rhs);

    label_type labels(std::forward<LabelType>(this_labels));

    // Scaling may permute modes but can not introduce new ones, so the
    // output labels must be drawn from rhs's labels.
    assert_is_subset_(labels, rhs.labels());

    // N.b. scalar is implicitly converted to double by the virtual backend.
    return scalar_multiplication_(std::move(labels), scalar, rhs);
}

#undef DSL_BASE
#undef TPARAMS

Expand Down
37 changes: 24 additions & 13 deletions include/tensorwrapper/dsl/pairwise_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,7 @@ class PairwiseParser {
*/
template<typename LHSType, typename RHSType>
void dispatch(LHSType&& lhs, const RHSType& rhs) {
if constexpr(std::is_floating_point_v<std::decay_t<RHSType>>) {
lhs.object().scalar_multiplication(rhs);
} else {
lhs.object().permute_assignment(lhs.labels(), rhs);
}
lhs.object().permute_assignment(lhs.labels(), rhs);
}

/** @brief Handles adding two expressions together.
Expand Down Expand Up @@ -130,14 +126,29 @@ class PairwiseParser {
*/
template<typename LHSType, typename T, typename U>
void dispatch(LHSType&& lhs, const utilities::dsl::Multiply<T, U>& rhs) {
auto pA = lhs.object().clone();
auto pB = lhs.object().clone();
auto labels = lhs.labels();
auto lA = (*pA)(labels);
auto lB = (*pB)(labels);
dispatch(lA, rhs.lhs());
dispatch(lB, rhs.rhs());
lhs.object().multiplication_assignment(labels, lA, lB);
constexpr bool t_is_float = std::is_floating_point_v<T>;
constexpr bool u_is_float = std::is_floating_point_v<U>;
static_assert(!(t_is_float && u_is_float), "Both can be float??");
if constexpr(t_is_float) {
auto pA = lhs.object().clone();
auto lA = (*pA)(lhs.labels());
dispatch(lA, rhs.rhs());
lhs.object().scalar_multiplication(lhs.labels(), rhs.lhs(), lA);
} else if constexpr(u_is_float) {
auto pA = lhs.object().clone();
auto lA = (*pA)(lhs.labels());
dispatch(lA, rhs.lhs());
lhs.object().scalar_multiplication(lhs.labels(), rhs.rhs(), lA);
} else {
auto pA = lhs.object().clone();
auto pB = lhs.object().clone();
auto labels = lhs.labels();
auto lA = (*pA)(labels);
auto lB = (*pB)(labels);
dispatch(lA, rhs.lhs());
dispatch(lB, rhs.rhs());
lhs.object().multiplication_assignment(labels, lA, lB);
}
}
};

Expand Down
61 changes: 59 additions & 2 deletions include/tensorwrapper/tensor/tensor_class.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
*/

#pragma once
#include <tensorwrapper/dsl/labeled.hpp>
#include <tensorwrapper/detail_/dsl_base.hpp>
#include <tensorwrapper/detail_/polymorphic_base.hpp>
#include <tensorwrapper/tensor/detail_/tensor_input.hpp>

namespace tensorwrapper {
Expand All @@ -34,7 +35,8 @@ struct IsTuple<std::tuple<Args...>> : std::true_type {};
* The Tensor class is envisioned as being the most user-facing class of
* TensorWrapper and forms the entry point into TensorWrapper's DSL.
*/
class Tensor {
class Tensor : public detail_::DSLBase<Tensor>,
public detail_::PolymorphicBase<Tensor> {
private:
/// Type of a helper class which collects the inputs needed to make a tensor
using input_type = detail_::TensorInput;
Expand All @@ -53,6 +55,8 @@ class Tensor {
using enable_if_no_tensors_t =
std::enable_if_t<!are_any_tensors_v<Args...>>;

using polymorphic_base = detail_::PolymorphicBase<Tensor>;

public:
/// Type of the object implementing *this
using pimpl_type = detail_::TensorPIMPL;
Expand Down Expand Up @@ -81,6 +85,9 @@ class Tensor {
/// Type of a pointer to a read-only buffer
using const_buffer_pointer = input_type::const_buffer_pointer;

/// Type used to convey rank
using rank_type = typename logical_layout_type::size_type;

/// Type of an initializer list if *this is a scalar
using scalar_il_type = double;

Expand Down Expand Up @@ -299,6 +306,19 @@ class Tensor {
*/
const_buffer_reference buffer() const;

/** @brief Returns the logical rank of the tensor.
*
* Most users interacting with a tensor will be thinking of it in terms of
* its logical rank. This function is a convenience function for calling
* `rank()` on the logical layout.
*
* @return The rank of the tensor, logically.
*
* @throw std::runtime_error if *this does not have a logical layout.
* Strong throw guarantee.
*/
rank_type rank() const;

// -------------------------------------------------------------------------
// -- Utility methods
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -344,6 +364,43 @@ class Tensor {
*/
bool operator!=(const Tensor& rhs) const noexcept;

protected:
/// Implements clone by calling copy ctor
polymorphic_base::base_pointer clone_() const override {
return std::make_unique<Tensor>(*this);
}

/// Implements are_equal by calling are_equal_impl_
bool are_equal_(const_base_reference rhs) const noexcept override {
return polymorphic_base::are_equal_impl_<Tensor>(rhs);
}

/// Implements addition_assignment by calling addition_assignment on state
dsl_reference addition_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Calls subtraction_assignment on each member
dsl_reference subtraction_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Calls multiplication_assignment on each member
dsl_reference multiplication_assignment_(
label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Calls scalar_multiplication on each member
dsl_reference scalar_multiplication_(label_type this_labels, double scalar,
const_labeled_reference rhs) override;

/// Calls permute_assignment on each member
dsl_reference permute_assignment_(label_type this_labels,
const_labeled_reference rhs) override;

/// Implements to_string
typename polymorphic_base::string_type to_string_() const override;

private:
/// All ctors ultimately dispatch to this ctor
Tensor(pimpl_pointer pimpl) noexcept;
Expand Down
24 changes: 24 additions & 0 deletions src/tensorwrapper/buffer/eigen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,30 @@ typename EIGEN::dsl_reference EIGEN::permute_assignment_(
return *this;
}

TPARAMS
typename EIGEN::dsl_reference EIGEN::scalar_multiplication_(
    label_type this_labels, double scalar, const_labeled_reference rhs) {
    // Delegate to the base class first so any non-element state it manages
    // is updated for the (possibly permuted) assignment; the element data
    // in m_tensor_ is overwritten below. NOTE(review): assumes the base call
    // does not touch m_tensor_ -- confirm against BufferBase.
    BufferBase::permute_assignment_(this_labels, rhs);

    // Downcast rhs to the concrete Eigen buffer with matching FloatType/Rank.
    using allocator_type = allocator::Eigen<FloatType, Rank>;
    const auto& rhs_downcasted = allocator_type::rebind(rhs.object());

    const auto& rlabels = rhs.labels();

    // Convert the scalar to the buffer's element type before multiplying.
    FloatType c(scalar);

    if(this_labels != rlabels) { // We need to permute rhs before assignment
        // Mode map taking rhs's label order to this_labels' order.
        // NOTE(review): assumes permutation() returns the rhs->lhs direction
        // Eigen's shuffle expects -- TODO confirm.
        auto r_to_l = rhs.labels().permutation(this_labels);
        // Eigen wants int objects
        std::vector<int> r_to_l2(r_to_l.begin(), r_to_l.end());
        m_tensor_ = rhs_downcasted.value().shuffle(r_to_l2) * c;
    } else {
        m_tensor_ = rhs_downcasted.value() * c;
    }

    return *this;
}

TPARAMS
typename detail_::PolymorphicBase<BufferBase>::string_type EIGEN::to_string_()
const {
Expand Down
131 changes: 131 additions & 0 deletions src/tensorwrapper/tensor/tensor_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ const_buffer_reference Tensor::buffer() const {
return m_pimpl_->buffer();
}

Tensor::rank_type Tensor::rank() const { return logical_layout().rank(); }

// -- Utility

void Tensor::swap(Tensor& other) noexcept { m_pimpl_.swap(other.m_pimpl_); }
Expand All @@ -89,6 +91,135 @@ bool Tensor::operator!=(const Tensor& rhs) const noexcept {
return !(*this == rhs);
}

// -- Protected methods

/// Implements addition assignment by delegating to the layout and the buffer.
Tensor::dsl_reference Tensor::addition_assignment_(
    label_type this_labels, const_labeled_reference lhs,
    const_labeled_reference rhs) {
    const auto& lhs_tensor = lhs.object();
    const auto& rhs_tensor = rhs.object();
    const auto& lhs_labels = lhs.labels();
    const auto& rhs_labels = rhs.labels();

    // Determine the result's layout from the operands' layouts.
    auto lhs_layout    = lhs_tensor.logical_layout();
    auto rhs_layout    = rhs_tensor.logical_layout();
    auto result_layout = lhs_layout.clone_as<logical_layout_type>();
    result_layout->addition_assignment(this_labels, lhs_layout(lhs_labels),
                                       rhs_layout(rhs_labels));

    // Determine the result's elements from the operands' buffers.
    auto result_buffer = lhs_tensor.buffer().clone();
    result_buffer->addition_assignment(this_labels,
                                       lhs_tensor.buffer()(lhs_labels),
                                       rhs_tensor.buffer()(rhs_labels));

    // Commit the fully-built state into *this (old state is released here).
    auto result_pimpl = std::make_unique<pimpl_type>(std::move(result_layout),
                                                     std::move(result_buffer));
    result_pimpl.swap(m_pimpl_);

    return *this;
}

/// Implements subtraction assignment by delegating to the layout and buffer.
Tensor::dsl_reference Tensor::subtraction_assignment_(
    label_type this_labels, const_labeled_reference lhs,
    const_labeled_reference rhs) {
    const auto& lhs_tensor = lhs.object();
    const auto& rhs_tensor = rhs.object();
    const auto& lhs_labels = lhs.labels();
    const auto& rhs_labels = rhs.labels();

    // Determine the result's layout from the operands' layouts.
    auto lhs_layout    = lhs_tensor.logical_layout();
    auto rhs_layout    = rhs_tensor.logical_layout();
    auto result_layout = lhs_layout.clone_as<logical_layout_type>();
    result_layout->subtraction_assignment(this_labels, lhs_layout(lhs_labels),
                                          rhs_layout(rhs_labels));

    // Determine the result's elements from the operands' buffers.
    auto result_buffer = lhs_tensor.buffer().clone();
    result_buffer->subtraction_assignment(this_labels,
                                          lhs_tensor.buffer()(lhs_labels),
                                          rhs_tensor.buffer()(rhs_labels));

    // Commit the fully-built state into *this (old state is released here).
    auto result_pimpl = std::make_unique<pimpl_type>(std::move(result_layout),
                                                     std::move(result_buffer));
    result_pimpl.swap(m_pimpl_);

    return *this;
}

/// Implements multiplication assignment by delegating to the layout and
/// buffer (the semantics of the product are theirs to define).
Tensor::dsl_reference Tensor::multiplication_assignment_(
    label_type this_labels, const_labeled_reference lhs,
    const_labeled_reference rhs) {
    const auto& lhs_tensor = lhs.object();
    const auto& rhs_tensor = rhs.object();
    const auto& lhs_labels = lhs.labels();
    const auto& rhs_labels = rhs.labels();

    // Determine the result's layout from the operands' layouts.
    auto lhs_layout    = lhs_tensor.logical_layout();
    auto rhs_layout    = rhs_tensor.logical_layout();
    auto result_layout = lhs_layout.clone_as<logical_layout_type>();
    result_layout->multiplication_assignment(
      this_labels, lhs_layout(lhs_labels), rhs_layout(rhs_labels));

    // Determine the result's elements from the operands' buffers.
    auto result_buffer = lhs_tensor.buffer().clone();
    result_buffer->multiplication_assignment(this_labels,
                                             lhs_tensor.buffer()(lhs_labels),
                                             rhs_tensor.buffer()(rhs_labels));

    // Commit the fully-built state into *this (old state is released here).
    auto result_pimpl = std::make_unique<pimpl_type>(std::move(result_layout),
                                                     std::move(result_buffer));
    result_pimpl.swap(m_pimpl_);

    return *this;
}

/// Implements scalar multiplication by scaling rhs's buffer into *this.
Tensor::dsl_reference Tensor::scalar_multiplication_(
    label_type this_labels, double scalar, const_labeled_reference rhs) {
    const auto& rhs_tensor = rhs.object();
    const auto& rhs_labels = rhs.labels();

    // Scaling does not change the shape, so the result's layout is just a
    // (possibly permuted) copy of rhs's layout.
    auto rhs_layout    = rhs_tensor.logical_layout();
    auto result_layout = rhs_layout.clone_as<logical_layout_type>();
    result_layout->permute_assignment(this_labels, rhs_layout(rhs_labels));

    // Scale (and, if the labels differ, permute) the elements.
    auto result_buffer = rhs_tensor.buffer().clone();
    result_buffer->scalar_multiplication(this_labels, scalar,
                                         rhs_tensor.buffer()(rhs_labels));

    // Commit the fully-built state into *this (old state is released here).
    auto result_pimpl = std::make_unique<pimpl_type>(std::move(result_layout),
                                                     std::move(result_buffer));
    result_pimpl.swap(m_pimpl_);

    return *this;
}

/// Implements permute assignment by permuting rhs's layout and buffer into
/// *this.
Tensor::dsl_reference Tensor::permute_assignment_(label_type this_labels,
                                                  const_labeled_reference rhs) {
    const auto& rhs_tensor = rhs.object();
    const auto& rhs_labels = rhs.labels();

    // The result's layout is a (possibly permuted) copy of rhs's layout.
    auto rhs_layout    = rhs_tensor.logical_layout();
    auto result_layout = rhs_layout.clone_as<logical_layout_type>();
    result_layout->permute_assignment(this_labels, rhs_layout(rhs_labels));

    // Likewise for the elements.
    auto result_buffer = rhs_tensor.buffer().clone();
    result_buffer->permute_assignment(this_labels,
                                      rhs_tensor.buffer()(rhs_labels));

    // Commit the fully-built state into *this (old state is released here).
    auto result_pimpl = std::make_unique<pimpl_type>(std::move(result_layout),
                                                     std::move(result_buffer));
    result_pimpl.swap(m_pimpl_);

    return *this;
}

typename Tensor::polymorphic_base::string_type Tensor::to_string_() const {
return has_pimpl_() ? buffer().to_string() : "";
}

// -- Private methods

Tensor::Tensor(pimpl_pointer pimpl) noexcept : m_pimpl_(std::move(pimpl)) {}
Expand Down
Loading

0 comments on commit db7c294

Please sign in to comment.