Skip to content

Commit

Permalink
DSL Support for Layout and Buffer classes (#192)
Browse files Browse the repository at this point in the history
* backup [skip ci]

* backup [skip ci]

* works, let's see what GCC hates...

* Committing clang-format changes

* remove ambiguity in string_type

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
ryanmrichard and github-actions[bot] authored Jan 11, 2025
1 parent f36f194 commit 625f90c
Show file tree
Hide file tree
Showing 44 changed files with 1,850 additions and 281 deletions.
40 changes: 33 additions & 7 deletions include/tensorwrapper/buffer/buffer_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#pragma once
#include <tensorwrapper/detail_/dsl_base.hpp>
#include <tensorwrapper/detail_/polymorphic_base.hpp>
#include <tensorwrapper/dsl/labeled.hpp>
#include <tensorwrapper/layout/layout_base.hpp>
Expand All @@ -25,30 +26,33 @@ namespace tensorwrapper::buffer {
*
* All classes which wrap existing tensor libraries derive from this class.
*/
class BufferBase : public detail_::PolymorphicBase<BufferBase> {
class BufferBase : public detail_::PolymorphicBase<BufferBase>,
public detail_::DSLBase<BufferBase> {
private:
/// Type of *this
using my_type = BufferBase;

protected:
/// Type *this inherits from
using my_base_type = detail_::PolymorphicBase<my_type>;
using polymorphic_base = detail_::PolymorphicBase<my_type>;

public:
/// Type all buffers inherit from
using buffer_base_type = typename my_base_type::base_type;
using buffer_base_type = typename polymorphic_base::base_type;

/// Type of a mutable reference to a buffer_base_type object
using buffer_base_reference = typename my_base_type::base_reference;
using buffer_base_reference = typename polymorphic_base::base_reference;

/// Type of a read-only reference to a buffer_base_type object
using const_buffer_base_reference =
typename my_base_type::const_base_reference;
typename polymorphic_base::const_base_reference;

/// Type of a pointer to an object of type buffer_base_type
using buffer_base_pointer = typename my_base_type::base_pointer;
using buffer_base_pointer = typename polymorphic_base::base_pointer;

/// Type of a pointer to a read-only object of type buffer_base_type
using const_buffer_base_pointer = typename my_base_type::const_base_pointer;
using const_buffer_base_pointer =
typename polymorphic_base::const_base_pointer;

/// Type of the class describing the physical layout of the buffer
using layout_type = layout::LayoutBase;
Expand All @@ -59,6 +63,9 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
/// Type of a pointer to the layout
using layout_pointer = typename layout_type::layout_pointer;

/// Type used to represent the tensor's rank
using rank_type = typename layout_type::size_type;

// -------------------------------------------------------------------------
// -- Accessors
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -90,6 +97,10 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
return *m_layout_;
}

rank_type rank() const noexcept {
return has_layout() ? layout().rank() : 0;
}

// -------------------------------------------------------------------------
// -- Utility methods
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -191,6 +202,21 @@ class BufferBase : public detail_::PolymorphicBase<BufferBase> {
return *this;
}

dsl_reference addition_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

dsl_reference subtraction_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

dsl_reference multiplication_assignment_(
label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs) override;

dsl_reference permute_assignment_(label_type this_labels,
const_labeled_reference rhs) override;

private:
/// Throws std::runtime_error when there is no layout
void assert_layout_() const {
Expand Down
28 changes: 27 additions & 1 deletion include/tensorwrapper/buffer/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,36 @@ class Eigen : public Replicated {
return my_base_type::are_equal_impl_<my_type>(rhs);
}

/// Implements addition_assignment by calling addition_assignment on state
dsl_reference addition_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Calls subtraction_assignment on each member
dsl_reference subtraction_assignment_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Calls multiplication_assignment on each member
dsl_reference multiplication_assignment_(
label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs) override;

/// Calls permute_assignment on each member
dsl_reference permute_assignment_(label_type this_labels,
const_labeled_reference rhs) override;

/// Implements to_string
typename my_base_type::string_type to_string_() const override;
typename polymorphic_base::string_type to_string_() const override;

private:
dsl_reference hadamard_(label_type this_labels, const_labeled_reference lhs,
const_labeled_reference rhs);

dsl_reference contraction_(label_type this_labels,
const_labeled_reference lhs,
const_labeled_reference rhs);

/// The actual Eigen tensor
data_type m_tensor_;
};
Expand Down
94 changes: 94 additions & 0 deletions include/tensorwrapper/dsl/dummy_indices.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,51 @@ class DummyIndices
return true;
}

/** @brief Is a thruple of DummyIndices consistent with a pure element-wise
* product?
*
* In generalized Einstein notation a pure element-wise (also commonly
* termed Hadamard) product is denoted by *this, @p lhs, and @p rhs
* having the same ordered set of dummy indices, up to permutation.
* Additionally, the dummy indices associated with any given tensor may
* not include a repeated index.
*
* @param[in] lhs The dummy indices associated with the tensor to the
* left of the times operator.
* @param[in] rhs The dummy indices associated with the tensor to the
* right of the times operator.
*
* @return True If the dummy indices given by *this, @p lhs, and @p rhs
* are consistent with a purely element-wise product of the tensors
* that @p lhs and @p rhs label.
*
* @throw None No throw guarantee.
*/
bool is_hadamard_product(const DummyIndices& lhs,
const DummyIndices& rhs) const noexcept;

/** @brief Does a thruple of DummyIndices indicate a product is a pure
* contraction?
*
* In generalized Einstein notation a pure contraction is an operation
* where indices common to @p lhs and @p rhs are summed over and do NOT
* appear in the result, i.e., *this. Additionally, we stipulate that
* there must be at least one index summed over (if no index is summed over
* the operation is a pure direct-product).
*
* @param[in] lhs The dummy indices associated with the tensor to the
* left of the times operator.
* @param[in] rhs The dummy indices associated with the tensor to the
* right of the times operator.
*
* @return True if the indices associated with *this, @p lhs, and @p rhs
* are consistent with a contraction and false otherwise.
*
* @throw None No throw guarantee.
*/
bool is_contraction(const DummyIndices& lhs,
const DummyIndices& rhs) const noexcept;

/** @brief Computes the permutation needed to convert *this into @p other.
*
* Each DummyIndices object is viewed as an ordered set of objects. If
Expand Down Expand Up @@ -366,6 +411,31 @@ class DummyIndices
return rv;
}

/** @brief Returns the set difference of *this and @p other.
*
* The set difference of *this with @p other is the set of indices which
* appear in *this, but not in @p other. This method will return the set
* (indices which appear more than once in *this will only appear once
* in the result) which results from the set difference of *this with
* @p other.
*
* @param[in] other The set to remove from *this.
*
* @return The set difference of *this and @p rhs.
*
* @throw std::bad_alloc if there is a problem allocating the return.
* Strong throw guarantee.
*/
DummyIndices difference(const DummyIndices& other) const {
DummyIndices rv;
for(const auto& x : *this) {
if(other.count(x)) continue;
if(rv.count(x)) continue;
rv.m_dummy_indices_.push_back(x);
}
return rv;
}

protected:
/// Main ctor for setting the value, throws if any index is empty
explicit DummyIndices(split_string_type split_dummy_indices) :
Expand Down Expand Up @@ -401,4 +471,28 @@ class DummyIndices
split_string_type m_dummy_indices_;
};

template<typename StringType>
bool DummyIndices<StringType>::is_hadamard_product(
const DummyIndices& lhs, const DummyIndices& rhs) const noexcept {
if(has_repeated_indices()) return false;
if(lhs.has_repeated_indices()) return false;
if(rhs.has_repeated_indices()) return false;
if(!is_permutation(lhs)) return false;
if(!is_permutation(rhs)) return false;
return true;
}

template<typename StringType>
bool DummyIndices<StringType>::is_contraction(
const DummyIndices& lhs, const DummyIndices& rhs) const noexcept {
if(has_repeated_indices()) return false;
if(lhs.has_repeated_indices()) return false;
if(rhs.has_repeated_indices()) return false;
auto lhs_cap_rhs = lhs.intersection(rhs);
if(lhs_cap_rhs.empty()) return false; // No common indices
if(!intersection(lhs_cap_rhs).empty())
return false; // Common index not summed
return true;
}

} // namespace tensorwrapper::dsl
Loading

0 comments on commit 625f90c

Please sign in to comment.