Skip to content

Commit

Permalink
r2g
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanmrichard committed Jul 9, 2024
1 parent 5628e06 commit 6359e9b
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 230 deletions.
81 changes: 2 additions & 79 deletions include/tensorwrapper/layout/tiled.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#pragma once
#include <tensorwrapper/detail_/polymorphic_base.hpp>
#include <tensorwrapper/shape/shape_base.hpp>
#include <tensorwrapper/sparsity/pattern.hpp>
#include <tensorwrapper/symmetry/group.hpp>
Expand All @@ -24,7 +25,7 @@ namespace tensorwrapper::layout {
/** @brief Describes how the tensor is actually laid out.
*
*/
class Tiled {
class Tiled : public detail_::PolymorphicBase<Tiled> {
public:
/// Type all layouts derive from
using layout_base = Tiled;
Expand Down Expand Up @@ -91,14 +92,6 @@ class Tiled {
/// Defaulted polymorphic dtor
virtual ~Tiled() noexcept = default;

/** @brief Make a polymorphic deep copy of *this.
*
* @return A pointer to the deep copy of *this.
*
* @throw std::bad_alloc if allocating the copy fails.
*/
layout_pointer clone() const { return clone_(); }

// -------------------------------------------------------------------------
// -- State methods
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -143,26 +136,6 @@ class Tiled {
// -- Utility methods
// -------------------------------------------------------------------------

/** @brief Is *this polymorphically value equal to @p rhs?
*
* This method is used to compare *this to @p rhs polymorphically. More
* specifically both *this and @p rhs will be downcasted to their most
* derived class. If the most derived class of both *this and @p rhs is
* the same, then *this will be compared to @p rhs via the most derived
* class's value comparison operator. If *this and @p rhs do not have the
* same most derived class this method returns false.
*
* @param[in] rhs The object to compare to.
*
* @return True if *this is polymorphically value equal to @p rhs and
* false otherwise.
*
* @throw None No throw guarantee.
*/
bool are_equal(const layout_base& rhs) const noexcept {
return are_equal_(rhs) && rhs.are_equal_(*this);
}

/** @brief Is *this value equal to @p rhs?
*
* Two Tiled objects are value equal if they both don't have shapes or if
Expand Down Expand Up @@ -213,18 +186,6 @@ class Tiled {
m_symmetry_(other.m_symmetry_),
m_sparsity_(other.m_sparsity_) {}

/** @brief Implements clone
*
* Derived classes should override this method so that it makes a deep
* copy of *this via the derived class's copy ctor.
*
* @return A deep copy of *this.
*
* @throw std::bad_alloc if there is a problem allocating the copy. Strong
* throw guarantee.
*/
virtual layout_pointer clone_() const = 0;

/** @brief Implements tile_size.
*
* For now this is an abstract method. When tiling is actually supported
Expand All @@ -237,44 +198,6 @@ class Tiled {
*/
virtual size_type tile_size_() const noexcept = 0;

/** @brief Helps derived classes implement are_equal.
*
* @tparam DerivedClass Type of the class implementing are_equal. Must be
* provided by the caller.
*
* This method wraps the process of downcasting *this and @p rhs to objects
* of type @p DerivedClass and then comparing them. This method will also
* handle the logic for when the downcasts fail.
*
* @param[in] rhs The object to compare against.
*
* @return True if *this and @p rhs compare value equal when compared as
* objects of type @p DerivedClass and false otherwise.
*
* @throw None No throw guarantee.
*/
template<typename DerivedType>
bool are_equal_impl_(const layout_base& rhs) const noexcept {
auto plhs = dynamic_cast<const DerivedType*>(this);
auto prhs = dynamic_cast<const DerivedType*>(&rhs);
if(plhs == nullptr || prhs == nullptr) return false;
return (*plhs) == (*prhs);
}

/** @brief Implements are_equal
*
* Derived classes should override this method so that it calls
* are_equal_impl_.
*
* @param[in] rhs The object to compare to.
*
* @return True if *this compared via its most derived class is value equal
* to @p rhs as its most derived class and false otherwise.
*
* @throw None No throw guarantee
*/
virtual bool are_equal_(const layout_base& rhs) const noexcept = 0;

private:
/// Ctor all other value ctors dispatch to
Tiled(shape_pointer shape, symmetry_type symmetry, sparsity_type sparsity) :
Expand Down
91 changes: 2 additions & 89 deletions include/tensorwrapper/shape/shape_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#pragma once
#include <cstddef>
#include <memory>

#include <tensorwrapper/detail_/polymorphic_base.hpp>
namespace tensorwrapper::shape {

/** @brief Code factorization for the various types of shapes.
Expand All @@ -34,7 +34,7 @@ namespace tensorwrapper::shape {
* - get_rank_()
* - get_size_()
*/
class ShapeBase {
class ShapeBase : public detail_::PolymorphicBase<ShapeBase> {
public:
/// Type all shapes inherit from
using shape_base = ShapeBase;
Expand All @@ -54,15 +54,6 @@ class ShapeBase {
/// Defaulted polymorphic dtor
virtual ~ShapeBase() noexcept = default;

/** @brief Deep polymorphic copy of *this.
*
* @return A pointer to a deep copy of *this.
*
* @throw std::bad_alloc if there is a problem allocating the copy. Strong
* throw guarantee.
*/
base_pointer clone() const { return clone_(); }

/** @brief The total rank of of the tensor described by *this.
*
* In the simplest terms, the total rank of a tensor is the number of
Expand Down Expand Up @@ -92,44 +83,7 @@ class ShapeBase {
*/
size_type size() const noexcept { return get_size_(); }

/** @brief Polymorphic value comparison.
*
* This method is used to compare two ShapeBase objects polymorphically.
* The instances will be cast to their most derived type. If the most
* derived types are the same then the objects will be value compared as
* derived objects.
*
* @param[in] rhs The object to compare against.
*
* @return True if *this is polymorphically value equal to @p rhs and false
* otherwise.
*
* @throw None No throw guarantee.
*/
bool are_equal(const ShapeBase& rhs) const noexcept {
return are_equal_(rhs) && rhs.are_equal_(*this);
}

protected:
/** @brief Used to implement clone()
*
* Derived classes should override this method to implement clone. In
* general, if the derived class's copy ctor is a deep copy, then one
* simply needs to do:
*
* @code
* // Replace DerivedType with the actual type of the derived class
* return std::make_unique<DerivedType>(*this);
* @endcode
*
* to implement clone_.
*
* @return A deep copy of *this, done polymorphically.
*
* @throw std::bad_alloc if the copy fails. Strong throw guarantee.
*/
virtual base_pointer clone_() const = 0;

/** @brief Used to implement rank().
*
* The derived class is responsible for implementing this method so that
Expand All @@ -154,47 +108,6 @@ class ShapeBase {
* subject to a no-throw guarantee.
*/
virtual size_type get_size_() const noexcept = 0;

/** @brief Called by derived class to implement are_equal_
*
* @tparam DerivedType The type of the derived class for which are_equal_
* is being implemented. Derived class must provide
* this value.
*
* This method is a convience method for implementing are_equal_. Derived
* classes need only call this method from their overload of are_equal_
* to implement are_equal_.
*
* @param[in] rhs The shape to compare to.
*
* @return True if *this and @p rhs are convertible to DerivedType objects
* and if, when viewed as DerivedType objects, *this and @p rhs
* are value equal. False otherwise.
*
* @throw None No throw guarantee.
*/
template<typename DerivedType>
bool are_equal_impl_(const ShapeBase& rhs) const noexcept {
auto pthis = dynamic_cast<const DerivedType*>(this);
auto prhs = dynamic_cast<const DerivedType*>(&rhs);
if(pthis == nullptr || prhs == nullptr) return false;
return (*pthis) == (*prhs);
}

/** @brief Derived class overrides to implement are_equal.
*
* Derived classes should implement this method by calling are_equal_impl_.
* This assumes that the derived class has implemented a non-polymorphic
* value equality check via operator==.
*
* @param[in] rhs The shape to compare to.
*
* @return True if *this is value equal to @p rhs (when compared as objects
* of *this most derived type) and false otherwise.
*
* @throw None No throw guarantee.
*/
virtual bool are_equal_(const ShapeBase& rhs) const noexcept = 0;
};

} // namespace tensorwrapper::shape
64 changes: 2 additions & 62 deletions include/tensorwrapper/symmetry/operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once
#include <memory>
#include <tensorwrapper/detail_/polymorphic_base.hpp>

namespace tensorwrapper::symmetry {

Expand All @@ -25,7 +26,7 @@ namespace tensorwrapper::symmetry {
* API. This class defines that API. The Operation class itself models a
* transformation which when applied to a tensor leaves the tensor unchanged.
*/
class Operation {
class Operation : public detail_::PolymorphicBase<Operation> {
public:
/// Common base class for all symmetry operations
using base_type = Operation;
Expand All @@ -49,18 +50,6 @@ class Operation {
/// Defaulted no-throw dtor
virtual ~Operation() noexcept = default;

/** @brief Polymorphic copy constructor.
*
* Derived classes implement this method by overriding clone_
*
* @return A deep copy of the derived class, returned as a pointer to
* *this.
*
* @throw std::bad_alloc if there is a problem allocating the new state.
* Strong throw guarantee.
*/
base_pointer clone() const { return clone_(); }

// -------------------------------------------------------------------------
// - Properties
// -------------------------------------------------------------------------
Expand All @@ -71,58 +60,9 @@ class Operation {
// -- Utility methods
// -------------------------------------------------------------------------

/** @brief Determines if two Operation objects are polymorphically value
* equal.
*
* Two Operation objects @f$a@f$ and @f$b@f$ are polymorphically value
* equal if the most derived class of @f$a@f$, @f$A@f$ is the same as the
* most derived class of @f$b@f$ and if when compared as objects of typ
* @f$A@f$ @f$a@f$ anb @f$b@f$ are value equal.
*
* @param[in] rhs The object to compare to.
*
* @return True if *this is polymorphically value equal to @p rhs and false
* otherwise.
*
* @throw None No throw guarantee.
*/
bool are_equal(const_base_reference rhs) const noexcept {
return are_equal_(rhs) && rhs.are_equal_(*this);
}

protected:
/** @brief Derived class should call to implement are_equal_
*
* @tparam DerivedType The class we are implementing are_equal for.
*
* Assuming the derived class implements operator== for non-polymorphic
* comparison, then are_equal can be implemented generically given the
* type of the derived class. This method is that generic implementation
* and should be called by the derived class.
*
* @param[in] rhs The object to polymorphically compare to *this.
*
* @return True if @p other compares value equal to *this and false
* otherwise.
*
* @throw None No throw guarantee.
*/
template<typename DerivedType>
bool are_equal_impl_(const_base_reference rhs) const noexcept {
auto pthis = dynamic_cast<const DerivedType*>(this);
auto prhs = dynamic_cast<const DerivedType*>(&rhs);
if(pthis == nullptr || prhs == nullptr) return false;
return (*pthis) == (*prhs);
}

/// Derived class should overwrite to implement clone()
virtual base_pointer clone_() const = 0;

/// Derived class should overwrite to implement is_identity
virtual bool is_identity_() const noexcept = 0;

/// Derived class should overwrite to implement are_equal()
virtual bool are_equal_(const_base_reference rhs) const noexcept = 0;
};

} // namespace tensorwrapper::symmetry

0 comments on commit 6359e9b

Please sign in to comment.