diff --git a/src/analyzer/analyzer.cc b/src/analyzer/analyzer.cc index a588e5d..5d6c439 100644 --- a/src/analyzer/analyzer.cc +++ b/src/analyzer/analyzer.cc @@ -138,8 +138,8 @@ void Analyzer::RunOnePosition(const std::vector& moves) { // Fetch MCTS-agnostic per-move stats P and V. std::vector nodes; - for (Node* iter : tree.GetCurrentHead()->Children()) { - nodes.emplace_back(iter); + for (Node* child : tree.GetCurrentHead()->Children()) { + nodes.emplace_back(child); } std::sort(nodes.begin(), nodes.end(), [](const Node* a, const Node* b) { return a->GetNStarted() > b->GetNStarted(); @@ -261,4 +261,4 @@ void Analyzer::OnInfo(const ThinkingInfo& info) const { WriteToLog(res); } -} // namespace lczero \ No newline at end of file +} // namespace lczero diff --git a/src/mcts/node.cc b/src/mcts/node.cc index 49ebd12..49b48af 100644 --- a/src/mcts/node.cc +++ b/src/mcts/node.cc @@ -83,8 +83,8 @@ class Node::Pool { Node::Pool::FreeNode* Node::Pool::UnrollNodeTree(FreeNode* node) { if (!node->node.child_) return node; FreeNode* prev = node; - for (Node* iter = node->node.child_; iter; iter = iter->sibling_) { - FreeNode* next = reinterpret_cast(iter); + for (Node* child = node->node.child_; child; child = child->sibling_) { + FreeNode* next = reinterpret_cast(child); prev->next = next; prev = UnrollNodeTree(next); } @@ -278,8 +278,8 @@ void Node::UpdateMaxDepth(int depth) { bool Node::UpdateFullDepth(uint16_t* depth) { if (full_depth_ > *depth) return false; - for (Node* iter : Children()) { - if (*depth > iter->full_depth_) *depth = iter->full_depth_; + for (Node* child : Children()) { + if (*depth > child->full_depth_) *depth = child->full_depth_; } if (*depth >= full_depth_) { full_depth_ = ++*depth; @@ -309,8 +309,8 @@ V3TrainingData Node::GetV3TrainingData(GameResult game_result, float total_n = static_cast(n_ - 1); // First visit was expansion of it inself. 
std::memset(result.probabilities, 0, sizeof(result.probabilities)); - for (Node* iter : Children()) { - result.probabilities[iter->move_.as_nn_index()] = iter->n_ / total_n; + for (Node* child : Children()) { + result.probabilities[child->move_.as_nn_index()] = child->n_ / total_n; } // Populate planes. @@ -407,4 +407,4 @@ void NodeTree::DeallocateTree() { current_head_ = nullptr; } -} // namespace lczero \ No newline at end of file +} // namespace lczero diff --git a/src/mcts/node.h b/src/mcts/node.h index cb1967e..96a63af 100644 --- a/src/mcts/node.h +++ b/src/mcts/node.h @@ -63,7 +63,7 @@ class Node { // Returns move, with optional flip (false == player BEFORE the position). Move GetMove(bool flip) const; - // Returns sum of probabilities for visited children. + // Returns sum of policy priors which have had at least one playout. float GetVisitedPolicy() const; uint32_t GetN() const { return n_; } uint32_t GetNInFlight() const { return n_in_flight_; } @@ -80,11 +80,12 @@ class Node { return q_; } } - // Returns U / (Puct * N[parent]) + // Returns p / N, which is equal to U / (cpuct * sqrt(N[parent])) by the MCTS + // equation. So it's really more of a "reduced U" than raw U. float GetU() const { return p_ / (1 + n_ + n_in_flight_); } // Returns value of Value Head returned from the neural net. float GetV() const { return v_; } - // Returns value of Move probabilityreturned from the neural net. + // Returns value of Move probability returned from the neural net // (but can be changed by adding Dirichlet noise). float GetP() const { return p_; } // Returns whether the node is known to be draw/lose/win. @@ -219,4 +220,4 @@ class NodeTree { // Performs garbage collection on node pool. Thread safe. 
void GarbageCollectNodePool(); -} // namespace lczero \ No newline at end of file +} // namespace lczero diff --git a/src/mcts/search.cc b/src/mcts/search.cc index 47902d7..fdcc14e 100644 --- a/src/mcts/search.cc +++ b/src/mcts/search.cc @@ -119,8 +119,8 @@ void ApplyDirichletNoise(Node* node, float eps, double alpha) { std::vector noise; // TODO(mooskagh) remove this loop when we store number of children. - for (Node* iter : node->Children()) { - (void)iter; // Silence the unused variable warning. + for (Node* child : node->Children()) { + (void)child; // Silence the unused variable warning. float eta = Random::Get().GetGamma(alpha, 1.0); noise.emplace_back(eta); total += eta; @@ -129,8 +129,8 @@ void ApplyDirichletNoise(Node* node, float eps, double alpha) { if (total < std::numeric_limits::min()) return; int noise_idx = 0; - for (Node* iter : node->Children()) { - iter->SetP(iter->GetP() * (1 - eps) + eps * noise[noise_idx++] / total); + for (Node* child : node->Children()) { + child->SetP(child->GetP() * (1 - eps) + eps * noise[noise_idx++] / total); } } @@ -226,8 +226,8 @@ void Search::SendMovesStats() const { const float parent_q = -root_node_->GetQ(0, 0) - kFpuReduction * std::sqrt(root_node_->GetVisitedPolicy()); - for (Node* iter : root_node_->Children()) { - nodes.emplace_back(iter); + for (Node* child : root_node_->Children()) { + nodes.emplace_back(child); } std::sort(nodes.begin(), nodes.end(), [](const Node* a, const Node* b) { return a->GetN() < b->GetN(); }); @@ -487,8 +487,7 @@ void SearchWorker::GatherMinibatch() { // Node was never visited, extending. ExtendNode(node); - // If node turned out to be a terminal one, no need to send to NN for - // evaluation. + // Only send non-terminal nodes to neural network if (!node->IsTerminal()) { nodes_to_process_.back().nn_queried = true; AddNodeToCompute(node); @@ -496,9 +495,12 @@ void SearchWorker::GatherMinibatch() { } } -// Returns node and whether it should be processed. 
-// (false if it is a collision). +// Returns node and whether there's been a search collision on the node. SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend() { + // Starting from search_->root_node_, generate a playout, choosing a + // node at each level according to the MCTS formula. n_in_flight_ is + // incremented for each node in the playout (via TryStartScoreUpdate()). + Node* node = search_->root_node_; // Initialize position sequence with pre-move position. history_.Trim(search_->played_history_.GetLength()); @@ -513,17 +515,20 @@ SearchWorker::NodeToProcess SearchWorker::PickNodeToExtend() { // True on first iteration, false as we dive deeper. bool is_root_node = true; while (true) { + // First, terminate if we find collisions or leaf nodes. { SharedMutex::Lock lock(search_->nodes_mutex_); - // Check whether we are in the leave. + // n_in_flight_ is incremented. If the method returns false, then there is + // a search collision, and this node is already being expanded. if (!node->TryStartScoreUpdate()) return {node, true}; - // Found leave, and we are the the first to visit it. + // Unexamined leaf node. We've hit the end of this playout. if (!node->HasChildren()) return {node, false}; + // If we fall through, then n_in_flight_ has been incremented but this + // playout remains incomplete; we must go deeper. } - // Now we are not in leave, we need to go deeper. SharedMutex::SharedLock lock(search_->nodes_mutex_); - float factor = + float puct_mult = search_->kCpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u)); float best = -100.0f; int possible_moves = 0; float parent_q = (is_root_node && search_->kNoise) ?
-node->GetQ(0, search_->kExtraVirtualLoss) : -node->GetQ(0, search_->kExtraVirtualLoss) - search_->kFpuReduction * std::sqrt(node->GetVisitedPolicy()); - for (Node* iter : node->Children()) { + for (Node* child : node->Children()) { if (is_root_node) { - // If there's no chance to catch up the currently best node with - // remaining playouts, not consider it. - // best_move_node_ can change since best_node_n computation. + // If there's no chance to catch up to the current best node with + // remaining playouts, don't consider it. + // best_move_node_ could have changed since best_node_n was retrieved. // To ensure we have at least one node to expand, always include // current best node. - if (iter != search_->best_move_node_ && + if (child != search_->best_move_node_ && search_->remaining_playouts_ < - best_node_n - static_cast(iter->GetN())) { + best_node_n - static_cast(child->GetN())) { continue; } ++possible_moves; } - float Q = iter->GetQ(parent_q, search_->kExtraVirtualLoss); - if (search_->kVirtualLossBug && iter->GetN() == 0) { - Q = (Q * iter->GetParent()->GetN() - search_->kVirtualLossBug) / - (iter->GetParent()->GetN() + std::fabs(search_->kVirtualLossBug)); + float Q = child->GetQ(parent_q, search_->kExtraVirtualLoss); + if (search_->kVirtualLossBug && child->GetN() == 0) { + Q = (Q * child->GetParent()->GetN() - search_->kVirtualLossBug) / + (child->GetParent()->GetN() + std::fabs(search_->kVirtualLossBug)); } - const float score = factor * iter->GetU() + Q; + const float score = puct_mult * child->GetU() + Q; if (score > best) { best = score; - node = iter; + node = child; } } history_.Append(node->GetMove()); @@ -574,20 +579,20 @@ void SearchWorker::ExtendNode(Node* node) { const auto& board = history_.Last().GetBoard(); auto legal_moves = board.GenerateLegalMoves(); - // Check whether it's a draw/lose by rules. + // Check whether it's a draw/lose by position. Importantly, we must check + // these before doing the by-rule checks below. 
if (legal_moves.empty()) { - // Checkmate or stalemate. + // Could be a checkmate or a stalemate if (board.IsUnderCheck()) { - // Checkmate. node->MakeTerminal(GameResult::WHITE_WON); } else { - // Stalemate. node->MakeTerminal(GameResult::DRAW); } return; } - // If it's root node and we're asked to think, pretend there's no draw. + // We can shortcircuit these draws-by-rule only if they aren't root; + // if they are root, then thinking about them is the point. if (node != search_->root_node_) { if (!board.HasMatingMaterial()) { node->MakeTerminal(GameResult::DRAW); @@ -624,8 +629,8 @@ bool SearchWorker::AddNodeToCompute(Node* node, bool add_if_cached) { if (node->HasChildren()) { // Legal moves are known, using them. - for (Node* iter : node->Children()) { - moves.emplace_back(iter->GetMove().as_nn_index()); + for (Node* child : node->Children()) { + moves.emplace_back(child->GetMove().as_nn_index()); } } else { // Cache pseudolegal moves. A bit of a waste, but faster. @@ -676,28 +681,28 @@ int SearchWorker::PrefetchIntoCache(Node* node, int budget) { return 1; } - // If it's a node in progress of expansion or is terminal, not prefetching. + // If it's a node in process of expansion or is terminal, don't prefetch it. if (!node->HasChildren()) return 0; // Populate all subnodes and their scores. typedef std::pair ScoredNode; std::vector scores; - float factor = + float puct_mult = search_->kCpuct * std::sqrt(std::max(node->GetChildrenVisits(), 1u)); // FPU reduction is not taken into account. const float parent_q = -node->GetQ(0, search_->kExtraVirtualLoss); - for (Node* iter : node->Children()) { - if (iter->GetP() == 0.0f) continue; + for (Node* child : node->Children()) { + if (child->GetP() == 0.0f) continue; // Flipping sign of a score to be able to easily sort. 
- scores.emplace_back(-factor * iter->GetU() - - iter->GetQ(parent_q, search_->kExtraVirtualLoss), - iter); + scores.emplace_back(-puct_mult * child->GetU() - + child->GetQ(parent_q, search_->kExtraVirtualLoss), + child); } size_t first_unsorted_index = 0; int total_budget_spent = 0; - int budget_to_spend = budget; // Initializing for the case there's only - // on child. + int budget_to_spend = budget; // Initialize for the case where there's only + // one child. for (size_t i = 0; i < scores.size(); ++i) { if (budget <= 0) break; @@ -715,13 +720,14 @@ int SearchWorker::PrefetchIntoCache(Node* node, int budget) { Node* n = scores[i].second; // Last node gets the same budget as prev-to-last node. if (i != scores.size() - 1) { - // Sign of the score was flipped for sorting, flipping back. + // Sign of the score was flipped for sorting, so flip it back. const float next_score = -scores[i + 1].first; const float q = n->GetQ(-parent_q, search_->kExtraVirtualLoss); if (next_score > q) { budget_to_spend = std::min( budget, - int(n->GetP() * factor / (next_score - q) - n->GetNStarted()) + 1); + int(n->GetP() * puct_mult / (next_score - q) - n->GetNStarted()) + + 1); } else { budget_to_spend = budget; } @@ -738,7 +744,8 @@ int SearchWorker::PrefetchIntoCache(Node* node, int budget) { // 4. Run NN computation. // ~~~~~~~~~~~~~~~~~~~~~~ void SearchWorker::RunNNComputation() { - // Evaluate nodes through NN. + // This function is so small as to be silly, but its parent function is + // conceptually cleaner for it. if (computation_->GetBatchSize() != 0) computation_->ComputeBlocking(); } @@ -798,7 +805,7 @@ void SearchWorker::DoBackupUpdate() { float v = node->GetV(); // Maximum depth the node is explored. uint16_t depth = 0; - // If the node is terminal, mark it as fully explored to an infinite + // If the node is terminal, mark it as fully explored to an "infinite" // depth. uint16_t cur_full_depth = node->IsTerminal() ? 
999 : 0; bool full_depth_updated = true; @@ -830,8 +837,7 @@ void SearchWorker::DoBackupUpdate() { // 7. UpdateCounters() //~~~~~~~~~~~~~~~~~~~~ void SearchWorker::UpdateCounters() { - search_ - ->UpdateRemainingMoves(); // Update remaining moves using smart pruning. + search_->UpdateRemainingMoves(); // Updates smart pruning counters. search_->MaybeOutputInfo(); search_->MaybeTriggerStop();