-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #174 from NWChemEx/shape
Adds Smooth
- Loading branch information
Showing
7 changed files
with
505 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
/* | ||
* Copyright 2024 NWChemEx Community | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
#include <tensorwrapper/shape/shape_base.hpp> | ||
#include <tensorwrapper/shape/smooth.hpp> | ||
|
||
/** @brief Sublibrary focused on describing the geometry of the tensor. | ||
*/ | ||
namespace shape {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
/* | ||
* Copyright 2024 NWChemEx Community | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
#include <cstddef> | ||
|
||
namespace tensorwrapper::shape { | ||
|
||
/** @brief Code factorization for the various types of shapes. | ||
* | ||
* Full design details: | ||
* https://nwchemex.github.io/TensorWrapper/developer/design/shape.html | ||
* | ||
* All shapes posses a concept of: | ||
* - Total rank | ||
* - Total number of elements | ||
* | ||
* To respectively implement these features, classes derived from *this are | ||
* expected to implement: | ||
* - get_rank_() | ||
* - get_size_() | ||
*/ | ||
class ShapeBase { | ||
public: | ||
/// Type used to hold the rank of a tensor | ||
using rank_type = unsigned short; | ||
|
||
/// Type used to specify the number of elements in the shape | ||
using size_type = std::size_t; | ||
|
||
/// No-op for ShapeBase because ShapeBase has no state | ||
ShapeBase() noexcept = default; | ||
|
||
/// Defaulted polymorphic dtor | ||
virtual ~ShapeBase() noexcept = default; | ||
|
||
/** @brief The total rank of of the tensor described by *this. | ||
* | ||
* In the simplest terms, the total rank of a tensor is the number of | ||
* offsets needed to uniquely distinguish among scalar elements. For | ||
* example, a scalar is rank 0 (there is only a single element in the | ||
* tensor, so there is no offset needed). A column/row vector is rank 1 | ||
* because an offset for the row/column is needed. A matrix is rank 2 | ||
* because offsets for both the row and column are needed, etc. | ||
* | ||
* @return An object containing the rank of the tensor | ||
* associated with *this. | ||
* | ||
* @throw None No throw guarantee. | ||
*/ | ||
rank_type rank() const noexcept { return get_rank_(); } | ||
|
||
/** @brief The total number of elements in the tensor described by *this. | ||
* | ||
* Ultimately each tensor is simply a collection of scalar values arranged | ||
* into an array. This method is used to determine how many total scalars | ||
* are in this array. The total includes both implicit (for example zeros | ||
* in sparse data structures) and explicit elements. | ||
* | ||
* @return An object containing the number of elements in *this. | ||
* | ||
* @throw None No throw guarantee. | ||
*/ | ||
size_type size() const noexcept { return get_size_(); } | ||
|
||
protected: | ||
/** @brief Used to implement rank(). | ||
* | ||
* The derived class is responsible for implementing this method so that | ||
* it returns a `rank_type` object defining the rank of the derived class. | ||
* | ||
* @return The rank of the derived class. | ||
* | ||
* @throw None Derived classes are responsible for implementing this method | ||
* subject to a no-throw guarantee. | ||
*/ | ||
virtual rank_type get_rank_() const noexcept = 0; | ||
|
||
/** @brief Used to implement size(). | ||
* | ||
* The derived class is responsible for implementing this method so that | ||
* it returns a `size_type` object defining the total number of elements | ||
* in the derived class. | ||
* | ||
* @return The total number of elements in the derived class. | ||
* | ||
* @throw None Derived classes are responsible for implementing this method | ||
* subject to a no-throw guarantee. | ||
*/ | ||
virtual size_type get_size_() const noexcept = 0; | ||
}; | ||
|
||
} // namespace tensorwrapper::shape |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
/* | ||
* Copyright 2024 NWChemEx Community | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
#include <functional> | ||
#include <numeric> | ||
#include <tensorwrapper/shape/shape_base.hpp> | ||
#include <vector> | ||
|
||
namespace tensorwrapper::shape { | ||
|
||
/** @brief Describes the shape of a "traditional" tensor. | ||
* | ||
* Tensors are traditionally thought of as being (hyper-)rectangular arrays of | ||
* scalars. The geometry of such a shape is described by stating the | ||
* geometric dimension of the (hyper-)rectangle and the number of elements in | ||
* the array. | ||
*/ | ||
class Smooth : public ShapeBase { | ||
public: | ||
// Pull in base class's types | ||
using ShapeBase::rank_type; | ||
using ShapeBase::size_type; | ||
|
||
// ------------------------------------------------------------------------- | ||
// -- Ctors, assignment, and dtor | ||
// ------------------------------------------------------------------------- | ||
|
||
/** @brief Constructs *this with a statically specified number of extents. | ||
* | ||
* This ctor is used to create a Smooth object by explicitly providing | ||
* the extents. The number of extents must be known at compile time. For | ||
* a dynamic number of extents use the range ctor. | ||
* | ||
* @param[in] il The extents of the modes. | ||
* | ||
* @throw std::runtime_error if there is a problem allocating the internal | ||
* state. Strong throw guarantee. | ||
*/ | ||
Smooth(std::initializer_list<size_type> il) : | ||
Smooth(il.begin(), il.end()) {} | ||
|
||
/** @brief Range ctor. | ||
* | ||
* @tparam BeginItrType Expected to be a forward iterator which can be | ||
* dereferenced to an object of size_type. | ||
* @tparam EndItrType Expected to be a type which can be compared to an | ||
* object of type BeginItrType. | ||
* | ||
* This ctor is used to construct a Smooth object with the extent of each | ||
* mode provided by a pair of iterators. | ||
* | ||
* @param[in] begin An iterator pointing to the extent of mode 0. | ||
* @param[in] end An iterator pointing to just past the extent of | ||
* the last mode. | ||
* | ||
* @throw ??? If iterating, dereferencing the begin iterator, or comparing | ||
* the iterators throws. Same throw guarantee as the iterators | ||
* involved in the throw. | ||
* @throw std::bad_alloc if there is a problem allocating the internal | ||
* state. Strong throw guarantee. | ||
*/ | ||
template<typename BeginItrType, typename EndItrType> | ||
Smooth(BeginItrType&& begin, EndItrType&& end) : | ||
Smooth(extents_type(std::forward<BeginItrType>(begin), | ||
std::forward<EndItrType>(end))) {} | ||
|
||
/// Defaulted no-throw dtor. | ||
~Smooth() noexcept = default; | ||
|
||
// ------------------------------------------------------------------------- | ||
// -- Utility methods | ||
// ------------------------------------------------------------------------- | ||
|
||
/** @brief Exchanges the state in *this with that of @p other. | ||
* | ||
* @param[in,out] other The object to take the state from. After this | ||
* method is called @p other will have the same state that | ||
* *this previously had. | ||
* | ||
* @throw None No throw guarantee. | ||
*/ | ||
void swap(Smooth& other) noexcept { m_extents_.swap(other.m_extents_); } | ||
|
||
/** @brief Is *this the same shape as @p rhs? | ||
* | ||
* @note This is a non-polymorphic value comparison, i.e., any state in | ||
* *this or @p rhs that resides in derived classes is NOT considered | ||
* in this comparison. | ||
* | ||
* Two Smooth objects are value equal if they contain the same number of | ||
* modes and if their @f$i@f$-th modes have the same extent for all @f$i@f$ | ||
* in the range [0, rank()). | ||
* | ||
* @param[in] rhs The object to compare against. | ||
* | ||
* @return True if *this is value equal to @p rhs and false otherwise. | ||
* | ||
*/ | ||
bool operator==(const Smooth& rhs) const noexcept { | ||
return m_extents_ == rhs.m_extents_; | ||
} | ||
|
||
/** @brief Is *this different from @p rhs? | ||
* | ||
* @note This is a non-polymorphic value comparison, i.e., any state in | ||
* *this or @p rhs that resides in derived classes is NOT considered | ||
* in this comparison. | ||
* | ||
* This method defines "different" as not value equal. See `operator==` for | ||
* the definition of value equal. | ||
* | ||
* @param[in] rhs The object to compare to. | ||
* | ||
* @return False if *this is value equal to @p rhs and true otherwise. | ||
* | ||
* @throw None No throw guarantee. | ||
*/ | ||
bool operator!=(const Smooth& rhs) const noexcept { | ||
return !(*this == rhs); | ||
} | ||
|
||
protected: | ||
/// Implement rank by counting number of extents held by *this | ||
rank_type get_rank_() const noexcept override { | ||
return rank_type(m_extents_.size()); | ||
} | ||
|
||
/// Implement size by taking the product of the extents held by *this | ||
size_type get_size_() const noexcept override { | ||
return std::accumulate(m_extents_.begin(), m_extents_.end(), | ||
size_type(1), std::multiplies<size_type>()); | ||
} | ||
|
||
private: | ||
/// Type used to hold the extents of *this | ||
using extents_type = std::vector<size_type>; | ||
|
||
/// Constructs *this given an object of extents_type | ||
explicit Smooth(extents_type extents) : m_extents_(std::move(extents)) {} | ||
|
||
/// The length of each mode | ||
extents_type m_extents_; | ||
}; | ||
|
||
} // namespace tensorwrapper::shape |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,4 +15,5 @@ | |
*/ | ||
|
||
#pragma once | ||
#include <tensorwrapper/shape/shape.hpp> | ||
#include <tensorwrapper/tensor/tensor.hpp> |
Oops, something went wrong.