Skip to content

Commit

Permalink
Add template argument to class DenseNonDecomposableStatisticView.
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-rapp committed Jan 21, 2025
1 parent 0ed9d0a commit fea1fab
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 124 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ namespace boosting {
* gradients and Hessians to be added to this vector
* @param row The index of the row to be added to this vector
*/
void add(const DenseNonDecomposableStatisticView& view, uint32 row);
void add(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row);

/**
* Adds all gradients and Hessians in a single row of a `DenseNonDecomposableStatisticView` to this vector.
Expand All @@ -175,7 +175,7 @@ namespace boosting {
* @param row The index of the row to be added to this vector
* @param weight The weight, the gradients and Hessians should be multiplied by
*/
void add(const DenseNonDecomposableStatisticView& view, uint32 row, StatisticType weight);
void add(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, StatisticType weight);

/**
* Removes all gradients and Hessians in a single row of a `DenseNonDecomposableStatisticView` from this
Expand All @@ -185,7 +185,7 @@ namespace boosting {
* gradients and Hessians to be removed from this vector
* @param row The index of the row to be removed from this vector
*/
void remove(const DenseNonDecomposableStatisticView& view, uint32 row);
void remove(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row);

/**
* Removes all gradients and Hessians in a single row of a `DenseNonDecomposableStatisticView` from this
Expand All @@ -196,7 +196,7 @@ namespace boosting {
* @param row The index of the row to be removed from this vector
* @param weight The weight, the gradients and Hessians should be multiplied by
*/
void remove(const DenseNonDecomposableStatisticView& view, uint32 row, StatisticType weight);
void remove(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, StatisticType weight);

/**
* Adds certain gradients and Hessians in another vector, whose positions are given as a
Expand All @@ -207,7 +207,7 @@ namespace boosting {
* @param row The index of the row to be added to this vector
* @param indices A reference to a `CompleteIndexVector` that provides access to the indices
*/
void addToSubset(const DenseNonDecomposableStatisticView& view, uint32 row,
void addToSubset(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row,
const CompleteIndexVector& indices);

/**
Expand All @@ -219,7 +219,7 @@ namespace boosting {
* @param row The index of the row to be added to this vector
* @param indices A reference to a `PartialIndexVector` that provides access to the indices
*/
void addToSubset(const DenseNonDecomposableStatisticView& view, uint32 row,
void addToSubset(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row,
const PartialIndexVector& indices);

/**
Expand All @@ -233,7 +233,7 @@ namespace boosting {
* @param indices A reference to a `CompleteIndexVector` that provides access to the indices
* @param weight The weight, the gradients and Hessians should be multiplied by
*/
void addToSubset(const DenseNonDecomposableStatisticView& view, uint32 row,
void addToSubset(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row,
const CompleteIndexVector& indices, StatisticType weight);

/**
Expand All @@ -247,7 +247,7 @@ namespace boosting {
* @param indices A reference to a `PartialIndexVector` that provides access to the indices
* @param weight The weight, the gradients and Hessians should be multiplied by
*/
void addToSubset(const DenseNonDecomposableStatisticView& view, uint32 row,
void addToSubset(const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row,
const PartialIndexVector& indices, StatisticType weight);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ namespace boosting {
/**
* Implements row-wise read and write access to the gradients and Hessians that have been calculated using a
* non-decomposable loss function and are stored in pre-allocated C-contiguous arrays.
*
* @tparam StatisticType The type of the gradients and Hessians
*/
template<typename StatisticType>
class MLRLBOOSTING_API DenseNonDecomposableStatisticView
: public CompositeMatrix<AllocatedCContiguousView<float64>, AllocatedCContiguousView<float64>> {
: public CompositeMatrix<AllocatedCContiguousView<StatisticType>, AllocatedCContiguousView<StatisticType>> {
public:

/**
Expand All @@ -27,34 +30,34 @@ namespace boosting {
/**
* @param other A reference to an object of type `DenseNonDecomposableStatisticView` that should be copied
*/
DenseNonDecomposableStatisticView(DenseNonDecomposableStatisticView&& other);
DenseNonDecomposableStatisticView(DenseNonDecomposableStatisticView<StatisticType>&& other);

virtual ~DenseNonDecomposableStatisticView() override {}

/**
* An iterator that provides read-only access to the gradients.
*/
typedef AllocatedCContiguousView<float64>::value_const_iterator gradient_const_iterator;
typedef typename AllocatedCContiguousView<StatisticType>::value_const_iterator gradient_const_iterator;

/**
* An iterator that provides access to the gradients and allows to modify them.
*/
typedef AllocatedCContiguousView<float64>::value_iterator gradient_iterator;
typedef typename AllocatedCContiguousView<StatisticType>::value_iterator gradient_iterator;

/**
* An iterator that provides read-only access to the Hessians.
*/
typedef AllocatedCContiguousView<float64>::value_const_iterator hessian_const_iterator;
typedef typename AllocatedCContiguousView<StatisticType>::value_const_iterator hessian_const_iterator;

/**
* An iterator that provides access to the Hessians and allows to modify them.
*/
typedef AllocatedCContiguousView<float64>::value_iterator hessian_iterator;
typedef typename AllocatedCContiguousView<StatisticType>::value_iterator hessian_iterator;

/**
* An iterator that provides read-only access to the Hessians that correspond to the diagonal of the matrix.
*/
typedef DiagonalIterator<const float64> hessian_diagonal_const_iterator;
typedef DiagonalIterator<const StatisticType> hessian_diagonal_const_iterator;

/**
* Returns a `gradient_const_iterator` to the beginning of the gradients at a specific row.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ namespace boosting {
* predicted scores
* @param statisticView A reference to an object of type `DenseNonDecomposableStatisticView` to be updated
*/
virtual void updateNonDecomposableStatistics(uint32 exampleIndex,
const CContiguousView<const uint8>& labelMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView& statisticView) const = 0;
virtual void updateNonDecomposableStatistics(
uint32 exampleIndex, const CContiguousView<const uint8>& labelMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView<float64>& statisticView) const = 0;

/**
* Updates the statistics of the example at a specific index.
Expand All @@ -44,9 +44,9 @@ namespace boosting {
* predicted scores
* @param statisticView A reference to an object of type `DenseNonDecomposableStatisticView` to be updated
*/
virtual void updateNonDecomposableStatistics(uint32 exampleIndex, const BinaryCsrView& labelMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView& statisticView) const = 0;
virtual void updateNonDecomposableStatistics(
uint32 exampleIndex, const BinaryCsrView& labelMatrix, const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView<float64>& statisticView) const = 0;
};

/**
Expand All @@ -69,10 +69,10 @@ namespace boosting {
* @param statisticView A reference to an object of type `DenseNonDecomposableStatisticView` to be
* updated
*/
virtual void updateNonDecomposableStatistics(uint32 exampleIndex,
const CContiguousView<const float32>& regressionMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView& statisticView) const = 0;
virtual void updateNonDecomposableStatistics(
uint32 exampleIndex, const CContiguousView<const float32>& regressionMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView<float64>& statisticView) const = 0;

/**
* Updates the statistics of the example at a specific index.
Expand All @@ -85,10 +85,10 @@ namespace boosting {
* @param statisticView A reference to an object of type `DenseNonDecomposableStatisticView` to be
* updated
*/
virtual void updateNonDecomposableStatistics(uint32 exampleIndex,
const CsrView<const float32>& regressionMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView& statisticView) const = 0;
virtual void updateNonDecomposableStatistics(
uint32 exampleIndex, const CsrView<const float32>& regressionMatrix,
const CContiguousView<float64>& scoreMatrix,
DenseNonDecomposableStatisticView<float64>& statisticView) const = 0;
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,49 +100,48 @@ namespace boosting {
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::add(const DenseNonDecomposableStatisticView& view,
uint32 row) {
void DenseNonDecomposableStatisticVector<StatisticType>::add(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row) {
util::addToView(this->gradients_begin(), view.gradients_cbegin(row), this->getNumGradients());
util::addToView(this->hessians_begin(), view.hessians_cbegin(row), this->getNumHessians());
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::add(const DenseNonDecomposableStatisticView& view,
uint32 row, StatisticType weight) {
void DenseNonDecomposableStatisticVector<StatisticType>::add(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, StatisticType weight) {
util::addToViewWeighted(this->gradients_begin(), view.gradients_cbegin(row), this->getNumGradients(), weight);
util::addToViewWeighted(this->hessians_begin(), view.hessians_cbegin(row), this->getNumHessians(), weight);
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::remove(const DenseNonDecomposableStatisticView& view,
uint32 row) {
void DenseNonDecomposableStatisticVector<StatisticType>::remove(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row) {
util::removeFromView(this->gradients_begin(), view.gradients_cbegin(row), this->getNumGradients());
util::removeFromView(this->hessians_begin(), view.hessians_cbegin(row), this->getNumHessians());
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::remove(const DenseNonDecomposableStatisticView& view,
uint32 row, StatisticType weight) {
void DenseNonDecomposableStatisticVector<StatisticType>::remove(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, StatisticType weight) {
util::removeFromViewWeighted(this->gradients_begin(), view.gradients_cbegin(row), this->getNumGradients(),
weight);
util::removeFromViewWeighted(this->hessians_begin(), view.hessians_cbegin(row), this->getNumHessians(), weight);
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(const DenseNonDecomposableStatisticView& view,
uint32 row,
const CompleteIndexVector& indices) {
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, const CompleteIndexVector& indices) {
util::addToView(this->gradients_begin(), view.gradients_cbegin(row), this->getNumGradients());
util::addToView(this->hessians_begin(), view.hessians_cbegin(row), this->getNumHessians());
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(const DenseNonDecomposableStatisticView& view,
uint32 row,
const PartialIndexVector& indices) {
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, const PartialIndexVector& indices) {
PartialIndexVector::const_iterator indexIterator = indices.cbegin();
util::addToView(this->gradients_begin(), view.gradients_cbegin(row), indexIterator, this->getNumGradients());
DenseNonDecomposableStatisticView::hessian_const_iterator hessiansBegin = view.hessians_cbegin(row);
typename DenseNonDecomposableStatisticView<StatisticType>::hessian_const_iterator hessiansBegin =
view.hessians_cbegin(row);

for (uint32 i = 0; i < this->getNumGradients(); i++) {
uint32 index = indexIterator[i];
Expand All @@ -152,21 +151,22 @@ namespace boosting {
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(const DenseNonDecomposableStatisticView& view,
uint32 row, const CompleteIndexVector& indices,
StatisticType weight) {
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, const CompleteIndexVector& indices,
StatisticType weight) {
util::addToViewWeighted(this->gradients_begin(), view.gradients_cbegin(row), this->getNumGradients(), weight);
util::addToViewWeighted(this->hessians_begin(), view.hessians_cbegin(row), this->getNumHessians(), weight);
}

template<typename StatisticType>
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(const DenseNonDecomposableStatisticView& view,
uint32 row, const PartialIndexVector& indices,
StatisticType weight) {
void DenseNonDecomposableStatisticVector<StatisticType>::addToSubset(
const DenseNonDecomposableStatisticView<StatisticType>& view, uint32 row, const PartialIndexVector& indices,
StatisticType weight) {
PartialIndexVector::const_iterator indexIterator = indices.cbegin();
util::addToViewWeighted(this->gradients_begin(), view.gradients_cbegin(row), indexIterator,
this->getNumGradients(), weight);
DenseNonDecomposableStatisticView::hessian_const_iterator hessiansBegin = view.hessians_cbegin(row);
typename DenseNonDecomposableStatisticView<StatisticType>::hessian_const_iterator hessiansBegin =
view.hessians_cbegin(row);

for (uint32 i = 0; i < this->getNumGradients(); i++) {
uint32 index = indexIterator[i];
Expand Down
Loading

0 comments on commit fea1fab

Please sign in to comment.