From e45d3eb7dfeb5e5f6a41768a83d8a03d8ec1ec2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois-David=20Collin?= Date: Thu, 27 Aug 2020 12:37:56 +0200 Subject: [PATCH] Bug in oob error rate vs number of trees result --- src/ForestOnline.cpp | 4 ++++ src/ForestOnline.hpp | 2 +- src/ForestOnlineClassification.cpp | 6 ++++-- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/ForestOnline.cpp b/src/ForestOnline.cpp index 69e6942..7729b52 100644 --- a/src/ForestOnline.cpp +++ b/src/ForestOnline.cpp @@ -135,6 +135,9 @@ void ForestOnline::init(std::string dependent_variable_name, MemoryMode memory_m if (!prediction_mode && order_snps) { data->orderSnpLevels(dependent_variable_name, (importance_mode == IMP_GINI_CORRECTED)); } + + tree_order = std::vector(num_trees); + } void ForestOnline::run(bool verbose, bool compute_oob_error) { @@ -602,6 +605,7 @@ void ForestOnline::growTreesInThread(uint thread_idx, std::vector* varia trees[i]->predict(predict_data,false); predictInternal(i); mutex.lock(); + tree_order[progress] = i; ++progress; if (verbose_out) { #ifdef PYTHON_OUTPUT diff --git a/src/ForestOnline.hpp b/src/ForestOnline.hpp index c017c5a..466b53a 100644 --- a/src/ForestOnline.hpp +++ b/src/ForestOnline.hpp @@ -181,7 +181,7 @@ class ForestOnline { PredictionType prediction_type; uint num_random_splits; uint max_depth; - + std::vector tree_order; // MAXSTAT splitrule double alpha; double minprop; diff --git a/src/ForestOnlineClassification.cpp b/src/ForestOnlineClassification.cpp index 71d4bc7..baae0b7 100644 --- a/src/ForestOnlineClassification.cpp +++ b/src/ForestOnlineClassification.cpp @@ -172,7 +172,7 @@ void ForestOnlineClassification::calculateAfterGrow(size_t tree_idx, bool oob) { mutex_post.lock(); ++class_counts[sampleID][res]; mutex_post.unlock(); - if (!class_counts[sample_idx].empty()) + if (!class_counts[sampleID].empty()) to_add += (mostFrequentValue(class_counts[sampleID], random_number_generator) == data->get(sampleID,dependent_varID)) ? 0.0 : 1.0; } predictions[2][0][tree_idx] += to_add/static_cast(numOOB); @@ -237,7 +237,9 @@ void ForestOnlineClassification::computePredictionErrorInternal() for(auto sample_idx = 0; sample_idx < predict_data->getNumRows(); sample_idx++) { predictions[1][0][sample_idx] = mostFrequentValue(class_count[sample_idx], random_number_generator); } - + std::vector sort_oob_trees(num_trees); + for(auto i = 0; i < num_trees; i++) sort_oob_trees[i] = predictions[2][0][tree_order[i]]; + predictions[2][0] = sort_oob_trees; } // #nocov start