diff --git a/Rapfi/search/mcts/node.cpp b/Rapfi/search/mcts/node.cpp index 97048d1..704278d 100644 --- a/Rapfi/search/mcts/node.cpp +++ b/Rapfi/search/mcts/node.cpp @@ -146,7 +146,9 @@ void Node::updateStats() if (!edgeArray) return; - uint32_t nSum = 1; + constexpr float invTemp = 1.0f; + + float wSum = 1 * std::exp((utility - 1.0f) * invTemp); float qSum = utility; float qSqrSum = utility * utility; float dSum = drawRate; @@ -168,14 +170,19 @@ void Node::updateStats() float childQ = childNode->q.load(std::memory_order_relaxed); float childQSqr = childNode->qSqr.load(std::memory_order_relaxed); float childD = childNode->d.load(std::memory_order_relaxed); - nSum += childN; - qSum += childN * (-childQ); // Flip side for child's utility - qSqrSum += childN * childQSqr; - dSum += childN * childD; + + // Compute the weight of this child node using softmax + // We minus childQ by the maximum Q value to avoid overflow. + float childW = childN * std::exp((childQ - 1.0f) * invTemp); + + wSum += childW; + qSum += childW * (-childQ); // Flip side for child's utility + qSqrSum += childW * childQSqr; + dSum += childW * childD; maxBound |= childNode->bound.load(std::memory_order_relaxed); } - float norm = 1.0f / nSum; + float norm = 1.0f / wSum; q.store(qSum * norm, std::memory_order_relaxed); qSqr.store(qSqrSum * norm, std::memory_order_relaxed); d.store(dSum * norm, std::memory_order_relaxed);