Skip to content

Commit

Permalink
Add template argument to class SparseDecomposableStatisticView.
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-rapp committed Jan 21, 2025
1 parent fea1fab commit 432e498
Show file tree
Hide file tree
Showing 14 changed files with 330 additions and 314 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,12 @@ namespace boosting {
* decomposable loss function in a C-contiguous array. For each element in the vector, a single gradient and
* Hessian, as well as the sums of the weights of the aggregated gradients and Hessians, is stored.
*
* @tparam WeightType The type of the weights
* @param StatisticType The type of the gradients and Hessians
* @tparam WeightType The type of the weights
*/
template<typename WeightType>
template<typename StatisticType, typename WeightType>
class SparseDecomposableStatisticVector final
: public VectorDecorator<AllocatedVector<SparseStatistic<float64, WeightType>>> {
: public VectorDecorator<AllocatedVector<SparseStatistic<StatisticType, WeightType>>> {
private:

/**
Expand All @@ -123,7 +124,7 @@ namespace boosting {
class ConstIterator final {
private:

typename View<SparseStatistic<float64, WeightType>>::const_iterator iterator_;
typename View<SparseStatistic<StatisticType, WeightType>>::const_iterator iterator_;

const WeightType sumOfWeights_;

Expand All @@ -134,7 +135,7 @@ namespace boosting {
* `SparseDecomposableStatisticVector`
* @param sumOfWeights The sum of the weights of all statistics that have been added to the vector
*/
ConstIterator(typename View<SparseStatistic<float64, WeightType>>::const_iterator iterator,
ConstIterator(typename View<SparseStatistic<StatisticType, WeightType>>::const_iterator iterator,
WeightType sumOfWeights);

/**
Expand All @@ -145,12 +146,12 @@ namespace boosting {
/**
* The type of the elements, the iterator provides access to.
*/
typedef const Statistic<float64> value_type;
typedef const Statistic<StatisticType> value_type;

/**
* The type of a pointer to an element, the iterator provides access to.
*/
typedef const Statistic<float64>* pointer;
typedef const Statistic<StatisticType>* pointer;

/**
* The type of a reference to an element, the iterator provides access to.
Expand Down Expand Up @@ -244,7 +245,8 @@ namespace boosting {
/**
* @param other A reference to an object of type `SparseDecomposableStatisticVector` to be copied
*/
SparseDecomposableStatisticVector(const SparseDecomposableStatisticVector& other);
SparseDecomposableStatisticVector(
const SparseDecomposableStatisticVector<StatisticType, WeightType>& other);

/**
* An iterator that provides read-only access to the elements in the vector.
Expand All @@ -271,7 +273,7 @@ namespace boosting {
* @param vector A reference to an object of type `SparseDecomposableStatisticVector` that stores the
* gradients and Hessians to be added to this vector
*/
void add(const SparseDecomposableStatisticVector<WeightType>& vector);
void add(const SparseDecomposableStatisticVector<StatisticType, WeightType>& vector);

/**
* Adds all gradients and Hessians in a single row of a `SparseSetView` to this vector.
Expand All @@ -280,7 +282,7 @@ namespace boosting {
* be added to this vector
* @param row The index of the row to be added to this vector
*/
void add(const SparseSetView<Statistic<float64>>& view, uint32 row);
void add(const SparseSetView<Statistic<StatisticType>>& view, uint32 row);

/**
* Adds all gradients and Hessians in a single row of a `SparseSetView` to this vector. The gradients and
Expand All @@ -291,7 +293,7 @@ namespace boosting {
* @param row The index of the row to be added to this vector
* @param weight The weight, the gradients and Hessians should be multiplied by
*/
void add(const SparseSetView<Statistic<float64>>& view, uint32 row, WeightType weight);
void add(const SparseSetView<Statistic<StatisticType>>& view, uint32 row, WeightType weight);

/**
* Removes all gradients and Hessians in a single row of a `SparseSetView` from this vector.
Expand All @@ -300,7 +302,7 @@ namespace boosting {
* be removed from this vector
* @param row The index of the row to be removed from this vector
*/
void remove(const SparseSetView<Statistic<float64>>& view, uint32 row);
void remove(const SparseSetView<Statistic<StatisticType>>& view, uint32 row);

/**
* Removes all gradients and Hessians in a single row of a `SparseSetView` from this vector. The gradients
Expand All @@ -311,7 +313,7 @@ namespace boosting {
* @param row The index of the row to be removed from this vector
* @param weight The weight, the gradients and Hessians should be multiplied by
*/
void remove(const SparseSetView<Statistic<float64>>& view, uint32 row, WeightType weight);
void remove(const SparseSetView<Statistic<StatisticType>>& view, uint32 row, WeightType weight);

/**
* Adds certain gradients and Hessians in a single row of a `SparseSetView`, whose positions are given as a
Expand All @@ -322,7 +324,7 @@ namespace boosting {
* @param row The index of the row to be added to this vector
* @param indices A reference to a `CompleteIndexVector' that provides access to the indices
*/
void addToSubset(const SparseSetView<Statistic<float64>>& view, uint32 row,
void addToSubset(const SparseSetView<Statistic<StatisticType>>& view, uint32 row,
const CompleteIndexVector& indices);

/**
Expand All @@ -334,7 +336,7 @@ namespace boosting {
* @param row The index of the row to be added to this vector
* @param indices A reference to a `PartialIndexVector' that provides access to the indices
*/
void addToSubset(const SparseSetView<Statistic<float64>>& view, uint32 row,
void addToSubset(const SparseSetView<Statistic<StatisticType>>& view, uint32 row,
const PartialIndexVector& indices);

/**
Expand All @@ -348,7 +350,7 @@ namespace boosting {
* @param indices A reference to a `CompleteIndexVector' that provides access to the indices
* @param weight The weight, the gradients and Hessians should be multiplied by
*/
void addToSubset(const SparseSetView<Statistic<float64>>& view, uint32 row,
void addToSubset(const SparseSetView<Statistic<StatisticType>>& view, uint32 row,
const CompleteIndexVector& indices, WeightType weight);

/**
Expand All @@ -362,7 +364,7 @@ namespace boosting {
* @param indices A reference to a `PartialIndexVector' that provides access to the indices
* @param weight The weight, the gradients and Hessians should be multiplied by
*/
void addToSubset(const SparseSetView<Statistic<float64>>& view, uint32 row,
void addToSubset(const SparseSetView<Statistic<StatisticType>>& view, uint32 row,
const PartialIndexVector& indices, WeightType weight);

/**
Expand All @@ -377,9 +379,9 @@ namespace boosting {
* @param second A reference to an object of type `SparseDecomposableStatisticVector` that stores the
* gradients and Hessians in the second vector
*/
void difference(const SparseDecomposableStatisticVector<WeightType>& first,
void difference(const SparseDecomposableStatisticVector<StatisticType, WeightType>& first,
const CompleteIndexVector& firstIndices,
const SparseDecomposableStatisticVector<WeightType>& second);
const SparseDecomposableStatisticVector<StatisticType, WeightType>& second);

/**
* Sets the gradients and Hessians in this vector to the difference `first - second` between the gradients
Expand All @@ -393,9 +395,9 @@ namespace boosting {
* @param second A reference to an object of type `SparseDecomposableStatisticVector` that stores the
* gradients and Hessians in the second vector
*/
void difference(const SparseDecomposableStatisticVector<WeightType>& first,
void difference(const SparseDecomposableStatisticVector<StatisticType, WeightType>& first,
const PartialIndexVector& firstIndices,
const SparseDecomposableStatisticVector<WeightType>& second);
const SparseDecomposableStatisticVector<StatisticType, WeightType>& second);

/**
* Sets all gradients and Hessians stored in this vector to zero.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,20 @@ namespace boosting {
const DenseDecomposableStatisticVector<float64>& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<uint32>>> create(
const SparseDecomposableStatisticVector<uint32>& statisticVector,
std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, uint32>>> create(
const SparseDecomposableStatisticVector<float64, uint32>& statisticVector,
const CompleteIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<uint32>>> create(
const SparseDecomposableStatisticVector<uint32>& statisticVector,
std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, uint32>>> create(
const SparseDecomposableStatisticVector<float64, uint32>& statisticVector,
const PartialIndexVector& indexVector) const override;

virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float32>>> create(
const SparseDecomposableStatisticVector<float32>& statisticVector,
virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, float32>>> create(
const SparseDecomposableStatisticVector<float64, float32>& statisticVector,
const CompleteIndexVector& indexVector) const override;

virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float32>>> create(
const SparseDecomposableStatisticVector<float32>& statisticVector,
virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, float32>>> create(
const SparseDecomposableStatisticVector<float64, float32>& statisticVector,
const PartialIndexVector& indexVector) const override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ namespace boosting {
const DenseDecomposableStatisticVector<float64>& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<uint32>>> create(
const SparseDecomposableStatisticVector<uint32>& statisticVector,
std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, uint32>>> create(
const SparseDecomposableStatisticVector<float64, uint32>& statisticVector,
const CompleteIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<uint32>>> create(
const SparseDecomposableStatisticVector<uint32>& statisticVector,
std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, uint32>>> create(
const SparseDecomposableStatisticVector<float64, uint32>& statisticVector,
const PartialIndexVector& indexVector) const override;

virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float32>>> create(
const SparseDecomposableStatisticVector<float32>& statisticVector,
virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, float32>>> create(
const SparseDecomposableStatisticVector<float64, float32>& statisticVector,
const CompleteIndexVector& indexVector) const override;

virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float32>>> create(
const SparseDecomposableStatisticVector<float32>& statisticVector,
virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, float32>>> create(
const SparseDecomposableStatisticVector<float64, float32>& statisticVector,
const PartialIndexVector& indexVector) const override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,20 @@ namespace boosting {
const DenseDecomposableStatisticVector<float64>& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<uint32>>> create(
const SparseDecomposableStatisticVector<uint32>& statisticVector,
std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, uint32>>> create(
const SparseDecomposableStatisticVector<float64, uint32>& statisticVector,
const CompleteIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<uint32>>> create(
const SparseDecomposableStatisticVector<uint32>& statisticVector,
std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, uint32>>> create(
const SparseDecomposableStatisticVector<float64, uint32>& statisticVector,
const PartialIndexVector& indexVector) const override;

virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float32>>> create(
const SparseDecomposableStatisticVector<float32>& statisticVector,
virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, float32>>> create(
const SparseDecomposableStatisticVector<float64, float32>& statisticVector,
const CompleteIndexVector& indexVector) const override;

virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float32>>> create(
const SparseDecomposableStatisticVector<float32>& statisticVector,
virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, float32>>> create(
const SparseDecomposableStatisticVector<float64, float32>& statisticVector,
const PartialIndexVector& indexVector) const override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,20 @@ namespace boosting {
const DenseDecomposableStatisticVector<float64>& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<uint32>>> create(
const SparseDecomposableStatisticVector<uint32>& statisticVector,
std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, uint32>>> create(
const SparseDecomposableStatisticVector<float64, uint32>& statisticVector,
const CompleteIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<uint32>>> create(
const SparseDecomposableStatisticVector<uint32>& statisticVector,
std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, uint32>>> create(
const SparseDecomposableStatisticVector<float64, uint32>& statisticVector,
const PartialIndexVector& indexVector) const override;

virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float32>>> create(
const SparseDecomposableStatisticVector<float32>& statisticVector,
virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, float32>>> create(
const SparseDecomposableStatisticVector<float64, float32>& statisticVector,
const CompleteIndexVector& indexVector) const override;

virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float32>>> create(
const SparseDecomposableStatisticVector<float32>& statisticVector,
virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, float32>>> create(
const SparseDecomposableStatisticVector<float64, float32>& statisticVector,
const PartialIndexVector& indexVector) const override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,20 @@ namespace boosting {
const DenseDecomposableStatisticVector<float64>& statisticVector,
const PartialIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<uint32>>> create(
const SparseDecomposableStatisticVector<uint32>& statisticVector,
std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, uint32>>> create(
const SparseDecomposableStatisticVector<float64, uint32>& statisticVector,
const CompleteIndexVector& indexVector) const override;

std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<uint32>>> create(
const SparseDecomposableStatisticVector<uint32>& statisticVector,
std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, uint32>>> create(
const SparseDecomposableStatisticVector<float64, uint32>& statisticVector,
const PartialIndexVector& indexVector) const override;

virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float32>>> create(
const SparseDecomposableStatisticVector<float32>& statisticVector,
virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, float32>>> create(
const SparseDecomposableStatisticVector<float64, float32>& statisticVector,
const CompleteIndexVector& indexVector) const override;

virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float32>>> create(
const SparseDecomposableStatisticVector<float32>& statisticVector,
virtual std::unique_ptr<IRuleEvaluation<SparseDecomposableStatisticVector<float64, float32>>> create(
const SparseDecomposableStatisticVector<float64, float32>& statisticVector,
const PartialIndexVector& indexVector) const override;
};

Expand Down
Loading

0 comments on commit 432e498

Please sign in to comment.