Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update to use .index_max() #1

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mlpack/methods/adaboost/adaboost_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ void AdaBoost<WeakLearnerType, MatType>::Classify(
for (size_t i = 0; i < predictedLabels.n_cols; ++i)
{
probabilities.col(i) /= accu(probabilities.col(i));
probabilities.col(i).max(maxIndex);
maxIndex = probabilities.col(i).index_max();
predictedLabels(i) = maxIndex;
}
}
Expand Down
3 changes: 1 addition & 2 deletions src/mlpack/methods/approx_kfn/drusilla_select_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ void DrusillaSelect<MatType>::Train(
for (size_t i = 0; i < l; ++i)
{
// Pick best index.
arma::uword maxIndex = 0;
norms.max(maxIndex);
arma::uword maxIndex = norms.index_max();

arma::vec line(refCopy.col(maxIndex) / norm(refCopy.col(maxIndex)));

Expand Down
3 changes: 1 addition & 2 deletions src/mlpack/methods/decision_tree/decision_tree_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1159,8 +1159,7 @@ void DecisionTree<FitnessFunction,

// Now normalize into probabilities.
classProbabilities /= UseWeights ? sumWeights : labels.n_elem;
arma::uword maxIndex = 0;
classProbabilities.max(maxIndex);
arma::uword maxIndex = classProbabilities.index_max();
majorityClass = (size_t) maxIndex;
}

Expand Down
5 changes: 3 additions & 2 deletions src/mlpack/methods/hmm/hmm_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,13 +528,14 @@ double HMM<Distribution>::Predict(const arma::mat& dataSeq,
for (size_t j = 0; j < logTransition.n_rows; j++)
{
arma::vec prob = logStateProb.col(t - 1) + logTransition.row(j).t();
logStateProb(j, t) = prob.max(index) + logProbs(t, j);
index = prob.index_max();
logStateProb(j, t) = prob[index] + logProbs(t, j);
stateSeqBack(j, t) = index;
}
}

// Backtrack to find the most probable state sequence.
logStateProb.unsafe_col(dataSeq.n_cols - 1).max(index);
index = logStateProb.unsafe_col(dataSeq.n_cols - 1).index_max();
stateSeq[dataSeq.n_cols - 1] = index;
for (size_t t = 2; t <= dataSeq.n_cols; t++)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,10 +141,9 @@ void BinaryNumericSplit<FitnessFunction, ObservationType>::Split(
}

// Calculate the majority classes of the children.
arma::uword maxIndex;
counts.unsafe_col(0).max(maxIndex);
arma::uword maxIndex = counts.unsafe_col(0).index_max();
childMajorities[0] = size_t(maxIndex);
counts.unsafe_col(1).max(maxIndex);
maxIndex = counts.unsafe_col(1).index_max();
childMajorities[1] = size_t(maxIndex);

// Create the according SplitInfo object.
Expand All @@ -155,8 +154,7 @@ template<typename FitnessFunction, typename ObservationType>
size_t BinaryNumericSplit<FitnessFunction, ObservationType>::MajorityClass()
const
{
arma::uword maxIndex;
classCounts.max(maxIndex);
arma::uword maxIndex = classCounts.index_max();
return size_t(maxIndex);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ void HoeffdingCategoricalSplit<FitnessFunction>::Split(
childMajorities.set_size(sufficientStatistics.n_cols);
for (size_t i = 0; i < sufficientStatistics.n_cols; ++i)
{
arma::uword maxIndex = 0;
sufficientStatistics.unsafe_col(i).max(maxIndex);
arma::uword maxIndex = sufficientStatistics.unsafe_col(i).index_max();
childMajorities[i] = size_t(maxIndex);
}

Expand All @@ -79,8 +78,7 @@ size_t HoeffdingCategoricalSplit<FitnessFunction>::MajorityClass() const
// Calculate the class that we have seen the most of.
arma::Col<size_t> classCounts = sum(sufficientStatistics, 1);

arma::uword maxIndex = 0;
classCounts.max(maxIndex);
arma::uword maxIndex = classCounts.index_max();

return size_t(maxIndex);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ void HoeffdingNumericSplit<FitnessFunction, ObservationType>::Split(
childMajorities.set_size(sufficientStatistics.n_cols);
for (size_t i = 0; i < sufficientStatistics.n_cols; ++i)
{
arma::uword maxIndex = 0;
sufficientStatistics.unsafe_col(i).max(maxIndex);
arma::uword maxIndex = sufficientStatistics.unsafe_col(i).index_max();
childMajorities[i] = size_t(maxIndex);
}

Expand All @@ -144,8 +143,7 @@ size_t HoeffdingNumericSplit<FitnessFunction, ObservationType>::
for (size_t i = 0; i < samplesSeen; ++i)
classes[labels[i]]++;

arma::uword majorityClass;
classes.max(majorityClass);
arma::uword majorityClass = classes.index_max();
return size_t(majorityClass);
}
else
Expand All @@ -154,8 +152,7 @@ size_t HoeffdingNumericSplit<FitnessFunction, ObservationType>::
// statistics.
arma::Col<size_t> classCounts = sum(sufficientStatistics, 1);

arma::uword maxIndex = 0;
classCounts.max(maxIndex);
arma::uword maxIndex = classCounts.index_max();
return size_t(maxIndex);
}
}
Expand Down
3 changes: 1 addition & 2 deletions src/mlpack/methods/kmeans/max_variance_new_cluster_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ void MaxVarianceNewCluster::EmptyCluster(const MatType& data,
this->iteration = iteration;

// Now find the cluster with maximum variance.
arma::uword maxVarCluster = 0;
variances.max(maxVarCluster);
arma::uword maxVarCluster = variances.index_max();

// If the cluster with maximum variance has variance of 0, then we can't
// continue. All the points are the same.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,7 @@ void NaiveBayesClassifier<ModelMatType>::Classify(
// Now calculate maximum probabilities for each point.
for (size_t i = 0; i < data.n_cols; ++i)
{
arma::uword maxIndex = 0;
logLikelihoods.unsafe_col(i).max(maxIndex);
arma::uword maxIndex = logLikelihoods.unsafe_col(i).index_max();
predictions[i] = maxIndex;
}
}
Expand Down
8 changes: 4 additions & 4 deletions src/mlpack/methods/perceptron/perceptron_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ void Perceptron<
// Multiply for each variable and check whether the current weight vector
// correctly classifies this.
tempLabelMat = weights.t() * data.col(j) + biases;

tempLabelMat.max(maxIndexRow, maxIndexCol);
maxIndexRow = arma::ind2sub(arma::size(tempLabelMat), tempLabelMat.index_max())(0);

// Check whether prediction is correct.
if (maxIndexRow != labels(0, j))
Expand Down Expand Up @@ -289,7 +289,7 @@ size_t Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Classify(
arma::uword maxIndex = 0;

tempLabelVec = weights.t() * point + biases;
tempLabelVec.max(maxIndex);
maxIndex = tempLabelVec.index_max();

return size_t(maxIndex);
}
Expand Down Expand Up @@ -322,7 +322,7 @@ void Perceptron<LearnPolicy, WeightInitializationPolicy, MatType>::Classify(
for (size_t i = 0; i < test.n_cols; ++i)
{
tempLabelMat = weights.t() * test.col(i) + biases;
tempLabelMat.max(maxIndex);
maxIndex = tempLabelMat.index_max();
predictedLabels(i) = maxIndex;
}
}
Expand Down
3 changes: 1 addition & 2 deletions src/mlpack/methods/radical/radical_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ inline typename MatType::elem_type Radical::Apply2D(const MatType& matX,
values(i) = Vasicek(candidateY1, m) + Vasicek(candidateY2, m);
}

arma::uword indOpt = 0;
values.min(indOpt); // we ignore the return value; we don't care about it
arma::uword indOpt = values.index_min();
return (indOpt / (ElemType) angles) * M_PI / 2.0;
}

Expand Down
3 changes: 1 addition & 2 deletions src/mlpack/methods/random_forest/random_forest_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,7 @@ void RandomForest<

// Find maximum element after renormalizing probabilities.
probabilities /= trees.size();
arma::uword maxIndex = 0;
probabilities.max(maxIndex);
arma::uword maxIndex = probabilities.index_max();

// Set prediction.
prediction = (size_t) maxIndex;
Expand Down