From 03dec0130812d56a8097514be455eba588ec39b1 Mon Sep 17 00:00:00 2001 From: dhb <1084714805@qq.com> Date: Tue, 15 Oct 2024 15:44:09 +0800 Subject: [PATCH] backprop with softmax weights (invT=0.2) test f15 --- Rapfi/search/mcts/node.cpp | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/Rapfi/search/mcts/node.cpp b/Rapfi/search/mcts/node.cpp index 97048d1..28aa50b 100644 --- a/Rapfi/search/mcts/node.cpp +++ b/Rapfi/search/mcts/node.cpp @@ -146,7 +146,9 @@ void Node::updateStats() if (!edgeArray) return; - uint32_t nSum = 1; + constexpr float invTemp = 0.2f; + + float wSum = 1 * std::exp((utility - 1.0f) * invTemp); float qSum = utility; float qSqrSum = utility * utility; float dSum = drawRate; @@ -168,14 +170,19 @@ void Node::updateStats() float childQ = childNode->q.load(std::memory_order_relaxed); float childQSqr = childNode->qSqr.load(std::memory_order_relaxed); float childD = childNode->d.load(std::memory_order_relaxed); - nSum += childN; - qSum += childN * (-childQ); // Flip side for child's utility - qSqrSum += childN * childQSqr; - dSum += childN * childD; + + // Compute the weight of this child node using softmax + // Subtract the maximum Q value from childQ before exponentiating to avoid overflow. + float childW = childN * std::exp((childQ - 1.0f) * invTemp); + + wSum += childW; + qSum += childW * (-childQ); // Flip side for child's utility + qSqrSum += childW * childQSqr; + dSum += childW * childD; maxBound |= childNode->bound.load(std::memory_order_relaxed); } - float norm = 1.0f / nSum; + float norm = 1.0f / wSum; q.store(qSum * norm, std::memory_order_relaxed); qSqr.store(qSqrSum * norm, std::memory_order_relaxed); d.store(dSum * norm, std::memory_order_relaxed);