Skip to content

Commit

Permalink
Fix crash in YDF distributed training.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 705142598
  • Loading branch information
achoum authored and copybara-github committed Dec 11, 2024
1 parent e6501b7 commit f59812e
Show file tree
Hide file tree
Showing 9 changed files with 133 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ cc_library_ydf(
"//yggdrasil_decision_forests/utils:distribution",
"//yggdrasil_decision_forests/utils:status_macros",
"//yggdrasil_decision_forests/utils:synchronization_primitives",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down Expand Up @@ -88,6 +90,7 @@ cc_test(
"//yggdrasil_decision_forests/utils/distribute/implementations/multi_thread",
"//yggdrasil_decision_forests/utils/distribute/implementations/multi_thread:multi_thread_cc_proto",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ absl::Status SortNumericalColumns(

// Receive and rename the results.
for (int result_idx = 0; result_idx < pending_requests; result_idx++) {
LOG_EVERY_N_SEC(INFO, 10) << "\tsorting numerical columns "
LOG_EVERY_N_SEC(INFO, 10) << "\tSorting numerical columns "
<< (result_idx + 1) << "/" << pending_requests;

ASSIGN_OR_RETURN(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
#include <utility>
#include <vector>

#include "absl/base/optimization.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
Expand Down Expand Up @@ -740,15 +742,16 @@ absl::Status SetLeafValue(
}

NodeRemapping TreeBuilder::CreateClosingNodeRemapping() const {
return NodeRemapping{open_nodes_.size(), {kClosedNode, kClosedNode}};
return NodeRemapping{{open_nodes_.size(), {kClosedNode, kClosedNode}}, 0};
}

absl::StatusOr<NodeRemapping> TreeBuilder::ApplySplitToTree(
const SplitPerOpenNode& splits) {
if (open_nodes_.size() != splits.size()) {
return absl::InternalError("Wrong number of internal nodes");
}
NodeRemapping remapping(open_nodes_.size());
NodeRemapping remapping;
remapping.mapping.resize(open_nodes_.size());
std::vector<decision_tree::NodeWithChildren*> new_open_nodes;
for (int split_idx = 0; split_idx < splits.size(); split_idx++) {
const auto& split = splits[split_idx];
Expand All @@ -765,7 +768,7 @@ absl::StatusOr<NodeRemapping> TreeBuilder::ApplySplitToTree(
// condition.
node.mutable_node()->set_num_pos_training_examples_without_weight(
split.condition.num_training_examples_without_weight());
remapping[split_idx] = {
remapping.mapping[split_idx] = {
static_cast<NodeIndex>(new_open_nodes.size()),
static_cast<NodeIndex>(new_open_nodes.size() + 1)};
new_open_nodes.push_back(node.mutable_neg_child());
Expand All @@ -781,7 +784,7 @@ absl::StatusOr<NodeRemapping> TreeBuilder::ApplySplitToTree(
node.FinalizeAsNonLeaf(true, true);
} else {
// Turning the node into a leaf.
remapping[split_idx] = {kClosedNode, kClosedNode};
remapping.mapping[split_idx] = {kClosedNode, kClosedNode};
node.FinalizeAsLeaf(true);
}
}
Expand All @@ -790,6 +793,7 @@ absl::StatusOr<NodeRemapping> TreeBuilder::ApplySplitToTree(
return absl::InvalidArgumentError("Maximum node limit exceeded");
}
open_nodes_ = new_open_nodes;
remapping.num_dst_nodes = new_open_nodes.size();
return remapping;
}

Expand Down Expand Up @@ -1277,8 +1281,9 @@ absl::Status UpdateExampleNodeMap(
const SplitPerOpenNode& splits,
const SplitEvaluationPerOpenNode& split_evaluation,
const NodeRemapping& node_remapping, ExampleToNodeMap* example_to_node,
utils::concurrency::ThreadPool* thread_pool) {
DCHECK_EQ(split_evaluation.size(), node_remapping.size());
utils::concurrency::ThreadPool* thread_pool,
NumExamplesPerNode* num_examples_per_node) {
DCHECK_EQ(split_evaluation.size(), node_remapping.mapping.size());
std::vector<utils::bitmap::BitReader> readers(split_evaluation.size());
for (int node_idx = 0; node_idx < split_evaluation.size(); node_idx++) {
const auto num_elements = static_cast<uint64_t>(
Expand All @@ -1287,6 +1292,8 @@ absl::Status UpdateExampleNodeMap(
readers[node_idx].Open(split_evaluation[node_idx].data(), num_elements);
}

num_examples_per_node->assign(node_remapping.num_dst_nodes, 0);

// TODO: In parallel.
for (ExampleIndex example_idx = 0; example_idx < example_to_node->size();
example_idx++) {
Expand All @@ -1298,13 +1305,14 @@ absl::Status UpdateExampleNodeMap(
DCHECK_GE(node_idx, 0);
DCHECK_LT(node_idx, split_evaluation.size());

if (node_remapping[node_idx].indices[0] == kClosedNode) {
if (node_remapping.mapping[node_idx].indices[0] == kClosedNode) {
// The example is in a node that is closed during this iteration.
node_idx = kClosedNode;
continue;
}
const bool evaluation = readers[node_idx].Read();
node_idx = node_remapping[node_idx].indices[evaluation];
node_idx = node_remapping.mapping[node_idx].indices[evaluation];
(*num_examples_per_node)[node_idx]++;
}

for (auto& reader : readers) {
Expand All @@ -1314,11 +1322,12 @@ absl::Status UpdateExampleNodeMap(
return absl::OkStatus();
}

absl::Status UpdateLabelStatistics(const SplitPerOpenNode& splits,
const NodeRemapping& node_remapping,
LabelStatsPerNode* label_stats) {
absl::Status UpdateLabelStatistics(
const SplitPerOpenNode& splits, const NodeRemapping& node_remapping,
const NumExamplesPerNode& num_examples_per_node,
LabelStatsPerNode* label_stats, const bool allow_statistics_correction) {
NodeIndex dst_num_nodes = 0;
for (const auto& mapping : node_remapping) {
for (const auto& mapping : node_remapping.mapping) {
for (const auto evaluation : {0, 1}) {
const auto dst_node_idx = mapping.indices[evaluation];
if (dst_node_idx != kClosedNode) {
Expand All @@ -1328,18 +1337,38 @@ absl::Status UpdateLabelStatistics(const SplitPerOpenNode& splits,
}
}
label_stats->assign(dst_num_nodes, {});
STATUS_CHECK_EQ(dst_num_nodes, node_remapping.num_dst_nodes);

for (int src_node_idx = 0; src_node_idx < splits.size(); src_node_idx++) {
for (const auto evaluation : {0, 1}) {
const auto dst_node_idx =
node_remapping[src_node_idx].indices[evaluation];
node_remapping.mapping[src_node_idx].indices[evaluation];
if (dst_node_idx == kClosedNode) {
continue;
}
DCHECK_GE(dst_node_idx, 0);
DCHECK_LT(dst_node_idx, dst_num_nodes);
(*label_stats)[dst_node_idx] =
splits[src_node_idx].label_statistics[evaluation];
auto& dst_stats = (*label_stats)[dst_node_idx];
dst_stats = splits[src_node_idx].label_statistics[evaluation];

if (ABSL_PREDICT_FALSE(dst_stats.num_examples() !=
num_examples_per_node[dst_node_idx])) {
// Note: Due to floating point approximations, it is rare (in most
// trainings this does not happen) but possible for label statistic
// counts to be wrong. In such cases, those counts need to be corrected.
// The same observation and logic are used in non-distributed training.
std::string message = absl::Substitute(
"The number of examples returned by the evaluator and "
"splittor don't match. $0 != $1.",
dst_stats.num_examples(), num_examples_per_node[dst_node_idx]);
if (allow_statistics_correction) {
LOG_FIRST_N(WARNING, 10)
<< "[Internal]" << message << " This is not an issue.";
} else {
return absl::InternalError(message);
}
dst_stats.set_num_examples(num_examples_per_node[dst_node_idx]);
}
}
}
return absl::OkStatus();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,13 @@ struct SplitNodeIndex {
// Negative (0) and positive (1) node indices.
NodeIndex indices[2];
};
typedef std::vector<SplitNodeIndex> NodeRemapping;
// Describes how the open nodes of a tree layer are re-indexed after the splits
// are applied.
struct NodeRemapping {
  // mapping[i].indices[j] is the new node index for the examples initially in
  // node i and for which the split condition evaluates to j (0: negative,
  // 1: positive). The index is kClosedNode if the examples end up in a closed
  // node.
  std::vector<SplitNodeIndex> mapping;
  // Number of destination nodes (excluding kClosedNode). Zero-initialized so
  // that a default-constructed remapping never exposes an indeterminate count.
  int num_dst_nodes = 0;
};

// Bitmap of the evaluation of a split i.e. evaluation of the boolean condition
// defined by the split.
Expand Down Expand Up @@ -338,16 +344,28 @@ absl::Status EvaluateSplitsPerBooleanFeature(
SplitEvaluationPerOpenNode* split_evaluation,
const dataset_cache::DatasetCacheReader* dataset);

// Update the node index of each example according to the split.
// Per-open-node example counts: element i is the number of examples in open
// node i. Shared between "UpdateExampleNodeMap" and "UpdateLabelStatistics".
using NumExamplesPerNode = std::vector<ExampleIndex>;

// Updates the node index of each example according to the split. Also
// populates "num_examples_per_node" with the number of examples in each node
// as computed during the split evaluation.
absl::Status UpdateExampleNodeMap(
const SplitPerOpenNode& splits,
const SplitEvaluationPerOpenNode& split_evaluation,
const NodeRemapping& node_remapping, ExampleToNodeMap* example_to_node,
utils::concurrency::ThreadPool* thread_pool);

absl::Status UpdateLabelStatistics(const SplitPerOpenNode& splits,
const NodeRemapping& node_remapping,
LabelStatsPerNode* label_stats);
utils::concurrency::ThreadPool* thread_pool,
NumExamplesPerNode* num_examples_per_node);

// Updates the label statistics for each node. If the number of examples in the
// split's label statistics and in "num_examples_per_node" don't match, returns
// an error (allow_statistics_correction=false) or prints a warning and
// overrides the count with the value from "num_examples_per_node"
// (allow_statistics_correction=true).
absl::Status UpdateLabelStatistics(
const SplitPerOpenNode& splits, const NodeRemapping& node_remapping,
const NumExamplesPerNode& num_examples_per_node,
LabelStatsPerNode* label_stats, bool allow_statistics_correction);

// Gets the number of valid splits.
int NumValidSplits(const SplitPerOpenNode& splits);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "gmock/gmock.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "yggdrasil_decision_forests/dataset/data_spec_inference.h"
Expand All @@ -38,6 +39,8 @@ namespace model {
namespace distributed_decision_tree {
namespace {

using test::StatusIs;

// Generic training loop. Displays all the intermediate results.
template <typename LabelAccessor, typename Tester>
void GenericTrainingLoop(LabelAccessor* label_accessor, Tester* tester,
Expand Down Expand Up @@ -115,9 +118,9 @@ void GenericTrainingLoop(LabelAccessor* label_accessor, Tester* tester,
// Add the found splits to the tree structure.
const auto node_remapping = tree_builder->ApplySplitToTree(splits).value();
LOG(INFO) << "Remapping:";
for (int i = 0; i < node_remapping.size(); i++) {
LOG(INFO) << "\t" << i << " -> " << node_remapping[i].indices[0] << " + "
<< node_remapping[i].indices[1];
for (int i = 0; i < node_remapping.mapping.size(); i++) {
LOG(INFO) << "\t" << i << " -> " << node_remapping.mapping[i].indices[0]
<< " + " << node_remapping.mapping[i].indices[1];
}

std::string description;
Expand All @@ -139,8 +142,10 @@ void GenericTrainingLoop(LabelAccessor* label_accessor, Tester* tester,
}

// Update the example->node map.
NumExamplesPerNode num_examples_per_node;
CHECK_OK(UpdateExampleNodeMap(splits, split_evaluation, node_remapping,
&example_to_node, &thread_pool));
&example_to_node, &thread_pool,
&num_examples_per_node));

LOG(INFO) << "Example to node map (first 10 values):";
ExampleIndex example_idx = 0;
Expand All @@ -154,8 +159,9 @@ void GenericTrainingLoop(LabelAccessor* label_accessor, Tester* tester,

// Update the label statistics.
const auto previous_label_stats_per_node_size = label_stats_per_node.size();
CHECK_OK(
UpdateLabelStatistics(splits, node_remapping, &label_stats_per_node));
CHECK_OK(UpdateLabelStatistics(splits, node_remapping,
num_examples_per_node, &label_stats_per_node,
/*allow_statistics_correction=*/false));
LOG(INFO) << "Update the number of open nodes "
<< previous_label_stats_per_node_size << " -> "
<< label_stats_per_node.size();
Expand Down Expand Up @@ -460,9 +466,10 @@ TEST_F(AdultClassificationDataset, ManualCheck) {
test::EqualsProto(expected_pos_statistics));

const auto node_remapping = tree_builder->ApplySplitToTree(splits).value();
EXPECT_EQ(node_remapping.size(), 1);
EXPECT_EQ(node_remapping.front().indices[0], 0);
EXPECT_EQ(node_remapping.front().indices[1], 1);
EXPECT_EQ(node_remapping.mapping.size(), 1);
EXPECT_EQ(node_remapping.mapping.front().indices[0], 0);
EXPECT_EQ(node_remapping.mapping.front().indices[1], 1);
EXPECT_EQ(node_remapping.num_dst_nodes, 2);
EXPECT_EQ(tree_builder->tree().NumNodes(), 3);
EXPECT_THAT(tree_builder->tree().root().node().condition(),
test::EqualsProto(expected_condition));
Expand All @@ -475,16 +482,34 @@ TEST_F(AdultClassificationDataset, ManualCheck) {
EXPECT_EQ(split_evaluation.front().size(), (22792 + 7) / 8);
EXPECT_EQ(utils::bitmap::ToStringBit(split_evaluation[0], 10), "1011101111");

NumExamplesPerNode num_examples_per_node;
CHECK_OK(UpdateExampleNodeMap(splits, split_evaluation, node_remapping,
&example_to_node, &thread_pool));
&example_to_node, &thread_pool,
&num_examples_per_node));
EXPECT_EQ(example_to_node.size(), dataset_->num_examples());
EXPECT_EQ(example_to_node[0], 1);
EXPECT_EQ(example_to_node[1], 0);
EXPECT_EQ(example_to_node[2], 1);
EXPECT_EQ(example_to_node[3], 1);

EXPECT_THAT(num_examples_per_node, ::testing::ElementsAre(5569, 17223));

// Corrupt the statistics on purpose to trigger the statistics correction.
splits[0].label_statistics[0].set_num_examples(10);

{
LabelStatsPerNode new_label_stats;
EXPECT_THAT(UpdateLabelStatistics(splits, node_remapping,
num_examples_per_node, &new_label_stats,
/*allow_statistics_correction=*/false),
StatusIs(absl::StatusCode::kInternal,
"The number of examples returned by"));
}

LabelStatsPerNode new_label_stats;
CHECK_OK(UpdateLabelStatistics(splits, node_remapping, &new_label_stats));
CHECK_OK(UpdateLabelStatistics(splits, node_remapping, num_examples_per_node,
&new_label_stats,
/*allow_statistics_correction=*/true));

const decision_tree::proto::LabelStatistics expected_new_label_stats_0 =
PARSE_TEST_PROTO(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1592,7 +1592,7 @@ absl::Status EmitStartTraining(
<< load_balancer->NumWorkers() << " [duration: " << absl::Now() - begin
<< "]";
}
LOG(INFO) << "Worker ready to train in " << absl::Now() - begin;
LOG(INFO) << "All workers ready to train in " << absl::Now() - begin;

monitoring->EndStage(internal::Monitoring::kStartTraining);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
#include "yggdrasil_decision_forests/dataset/vertical_dataset.h"
#include "yggdrasil_decision_forests/dataset/vertical_dataset_io.h"
#include "yggdrasil_decision_forests/dataset/weight.h"
#include "yggdrasil_decision_forests/learner/distributed_decision_tree/label_accessor.h"
#include "yggdrasil_decision_forests/learner/distributed_decision_tree/training.h"
#include "yggdrasil_decision_forests/learner/distributed_gradient_boosted_trees/common.h"
#include "yggdrasil_decision_forests/learner/distributed_gradient_boosted_trees/distributed_gradient_boosted_trees.pb.h"
#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees.h"
Expand All @@ -43,7 +45,6 @@
#include "yggdrasil_decision_forests/learner/gradient_boosted_trees/loss/loss_utils.h"
#include "yggdrasil_decision_forests/model/decision_tree/decision_tree.pb.h"
#include "yggdrasil_decision_forests/serving/example_set.h"
#include "yggdrasil_decision_forests/utils/bitmap.h"
#include "yggdrasil_decision_forests/utils/compatibility.h"
#include "yggdrasil_decision_forests/utils/protobuf.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"
Expand Down Expand Up @@ -944,13 +945,16 @@ absl::Status DistributedGradientBoostedTreesWorker::ShareSplits(
weak_model_idx, weak_models_.size(), &predictions_,
thread_pool_.get()));

distributed_decision_tree::NumExamplesPerNode num_examples_per_node;
RETURN_IF_ERROR(UpdateExampleNodeMap(
weak_model.last_splits, weak_model.last_split_evaluation,
node_remapping, &weak_model.example_to_node, thread_pool_.get()));
node_remapping, &weak_model.example_to_node, thread_pool_.get(),
&num_examples_per_node));

RETURN_IF_ERROR(UpdateLabelStatistics(weak_model.last_splits,
node_remapping,
&weak_model.label_stats_per_node));
RETURN_IF_ERROR(UpdateLabelStatistics(
weak_model.last_splits, node_remapping, num_examples_per_node,
&weak_model.label_stats_per_node,
/*allow_statistics_correction=*/true));
}
}
return absl::OkStatus();
Expand Down Expand Up @@ -1367,7 +1371,7 @@ absl::Status UpdateClosingNodesPredictions(
DCHECK_GE(node_idx, 0);
DCHECK_LT(node_idx, label_stats_per_node.size());

if (node_remapping[node_idx].indices[0] !=
if (node_remapping.mapping[node_idx].indices[0] !=
distributed_decision_tree::kClosedNode) {
// This example remains in an open node.
continue;
Expand Down
Loading

0 comments on commit f59812e

Please sign in to comment.