Adapt MSE loss for empty weights
PiperOrigin-RevId: 489140058
rstz authored and copybara-github committed Nov 17, 2022
1 parent cf36cbc commit a694e81
Showing 4 changed files with 178 additions and 91 deletions.
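The change rests on a convention plus a compile-time specialization: an empty weight vector now means "every weight is 1", and the hot loops are templated on a weighted flag so the unweighted path skips per-example weight lookups. Below is a minimal, self-contained sketch of that pattern (hypothetical names, not the actual YDF API; requires C++17 for if constexpr):

#include <cassert>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <vector>

// Computes the weighted mean of `labels`. The template parameter selects,
// at compile time, between the weighted loop and a fast unweighted path.
template <bool weighted>
double WeightedMean(const std::vector<float>& labels,
                    const std::vector<float>& weights) {
  double sum_weights = 0., sum_values = 0.;
  if constexpr (weighted) {
    assert(weights.size() == labels.size());
    for (std::size_t i = 0; i < labels.size(); ++i) {
      sum_weights += weights[i];
      sum_values += weights[i] * labels[i];
    }
  } else {
    assert(weights.empty());
    sum_weights = labels.size();  // All weights are implicitly 1.
    sum_values = std::accumulate(labels.begin(), labels.end(), 0.);
  }
  return sum_values / sum_weights;
}

int main() {
  const std::vector<float> labels = {1.f, 2.f, 3.f, 4.f};
  const std::vector<float> weights;  // Empty: unit weights.
  // The runtime branch happens once, outside the hot loop.
  const double mean = weights.empty() ? WeightedMean<false>(labels, weights)
                                      : WeightedMean<true>(labels, weights);
  std::cout << mean << "\n";  // Prints 2.5.
}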
@@ -68,10 +68,17 @@ absl::StatusOr<std::vector<float>> MeanSquaredErrorLoss::InitialPredictions(
dataset
.ColumnWithCastWithStatus<dataset::VerticalDataset::NumericalColumn>(
label_col_idx));
for (UnsignedExampleIdx example_idx = 0; example_idx < dataset.nrow();
example_idx++) {
sum_weights += weights[example_idx];
weighted_sum_values += weights[example_idx] * labels->values()[example_idx];
if (weights.empty()) {
sum_weights = dataset.nrow();
weighted_sum_values =
std::accumulate(labels->values().begin(), labels->values().end(), 0.);
} else {
for (UnsignedExampleIdx example_idx = 0; example_idx < dataset.nrow();
example_idx++) {
sum_weights += weights[example_idx];
weighted_sum_values +=
weights[example_idx] * labels->values()[example_idx];
}
}
// Note: Null and negative weights are detected by the dataspec
// computation.
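In other words, the initial prediction is the weighted label mean, which collapses to the plain mean when all weights are implicitly 1:

  y_0 = \frac{\sum_i w_i y_i}{\sum_i w_i} = \frac{1}{n} \sum_i y_i \quad \text{when } w_i \equiv 1,

so the empty-weights branch avoids reading a weight array entirely.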
@@ -127,57 +134,23 @@ absl::Status MeanSquaredErrorLoss::UpdatePredictions(
decision_tree::CreateSetLeafValueFunctor MeanSquaredErrorLoss::SetLeafFunctor(
const std::vector<float>& predictions,
const std::vector<GradientData>& gradients, const int label_col_idx) const {
return
[this, &predictions, label_col_idx](
const dataset::VerticalDataset& train_dataset,
const std::vector<UnsignedExampleIdx>& selected_examples,
const std::vector<float>& weights,
const model::proto::TrainingConfig& config,
const model::proto::TrainingConfigLinking& config_link,
decision_tree::NodeWithChildren* node) {
return SetLeaf(train_dataset, selected_examples, weights, config,
config_link, predictions, label_col_idx, node);
};
}

absl::Status MeanSquaredErrorLoss::SetLeaf(
const dataset::VerticalDataset& train_dataset,
const std::vector<UnsignedExampleIdx>& selected_examples,
const std::vector<float>& weights,
const model::proto::TrainingConfig& config,
const model::proto::TrainingConfigLinking& config_link,
const std::vector<float>& predictions, const int label_col_idx,
decision_tree::NodeWithChildren* node) const {
RETURN_IF_ERROR(decision_tree::SetRegressionLabelDistribution(
train_dataset, selected_examples, weights, config_link,
node->mutable_node()));

// Set the value of the leaf to be the residual:
// label[i] - prediction
ASSIGN_OR_RETURN(
const auto* labels,
train_dataset
.ColumnWithCastWithStatus<dataset::VerticalDataset::NumericalColumn>(
label_col_idx));
double sum_weighted_values = 0;
double sum_weights = 0;
for (const auto example_idx : selected_examples) {
const float label = labels->values()[example_idx];
const float prediction = predictions[example_idx];
sum_weighted_values += weights[example_idx] * (label - prediction);
sum_weights += weights[example_idx];
}
if (sum_weights <= 0) {
LOG(WARNING) << "Zero or negative weights in node";
sum_weights = 1.0;
}
// Note: The "sum_weights" terms carries an implicit 2x factor that is
// integrated in the shrinkage. We don't integrate this factor here not to
// change the behavior of existing training configurations.
node->mutable_node()->mutable_regressor()->set_top_value(
gbt_config_.shrinkage() * sum_weighted_values /
(sum_weights + gbt_config_.l2_regularization() / 2));
return absl::OkStatus();
return [this, &predictions, label_col_idx](
const dataset::VerticalDataset& train_dataset,
const std::vector<UnsignedExampleIdx>& selected_examples,
const std::vector<float>& weights,
const model::proto::TrainingConfig& config,
const model::proto::TrainingConfigLinking& config_link,
decision_tree::NodeWithChildren* node) {
if (weights.empty()) {
return SetLeaf</*weighted=*/false>(train_dataset, selected_examples,
weights, config, config_link,
predictions, label_col_idx, node);
} else {
return SetLeaf</*weighted=*/true>(train_dataset, selected_examples,
weights, config, config_link,
predictions, label_col_idx, node);
}
};
}

absl::StatusOr<decision_tree::SetLeafValueFromLabelStatsFunctor>
@@ -82,13 +82,59 @@ class MeanSquaredErrorLoss : public AbstractLoss {
const std::vector<GradientData>& gradients,
int label_col_idx) const override;

template <bool weighted>
absl::Status SetLeaf(const dataset::VerticalDataset& train_dataset,
const std::vector<UnsignedExampleIdx>& selected_examples,
const std::vector<float>& weights,
const model::proto::TrainingConfig& config,
const model::proto::TrainingConfigLinking& config_link,
const std::vector<float>& predictions, int label_col_idx,
decision_tree::NodeWithChildren* node) const;
const std::vector<float>& predictions,
const int label_col_idx,
decision_tree::NodeWithChildren* node) const {
if constexpr (weighted) {
STATUS_CHECK(weights.size() == train_dataset.nrow());
} else {
STATUS_CHECK(weights.empty());
}
// Initialize the distribution (the "top_value" is overridden right
// after).
RETURN_IF_ERROR(decision_tree::SetRegressionLabelDistribution(
train_dataset, selected_examples, weights, config_link,
node->mutable_node()));

// Set the value of the leaf to be the residual:
// label[i] - prediction
ASSIGN_OR_RETURN(
const auto* labels,
train_dataset.ColumnWithCastWithStatus<
dataset::VerticalDataset::NumericalColumn>(label_col_idx));
double sum_weighted_values = 0;
double sum_weights = 0;
if constexpr (!weighted) {
sum_weights = selected_examples.size();
}
for (const auto example_idx : selected_examples) {
const float label = labels->values()[example_idx];
const float prediction = predictions[example_idx];
if constexpr (weighted) {
sum_weighted_values += weights[example_idx] * (label - prediction);
sum_weights += weights[example_idx];
} else {
sum_weighted_values += label - prediction;
}
}
if (sum_weights <= 0) {
LOG(WARNING) << "Zero or negative weights in node";
sum_weights = 1.0;
}
// Note: The "sum_weights" terms carries an implicit 2x factor that is
// integrated in the shrinkage. We don't integrate this factor here not to
// change the behavior of existing training configurations.
node->mutable_node()->mutable_regressor()->set_top_value(
gbt_config_.shrinkage() * sum_weighted_values /
(sum_weights + gbt_config_.l2_regularization() / 2));
return absl::OkStatus();
}
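Spelled out, the leaf value set above is the shrunk, L2-regularized mean residual over the selected examples S:

  \text{top\_value} = \eta \cdot \frac{\sum_{i \in S} w_i (y_i - p_i)}{\sum_{i \in S} w_i + \lambda / 2},

where \eta is the shrinkage and \lambda the L2 regularization; in the unweighted instantiation every w_i is 1, so \sum_{i \in S} w_i is simply the number of selected examples.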

absl::StatusOr<decision_tree::SetLeafValueFromLabelStatsFunctor>
SetLeafFunctorFromLabelStatistics() const override;
@@ -55,22 +55,39 @@ absl::StatusOr<dataset::VerticalDataset> CreateToyDataset() {
return dataset;
}

TEST(MeanSquareErrorLossTest, InitialPredictions) {
class MeanSquareErrorLossTest : public testing::TestWithParam<bool> {};

TEST_P(MeanSquareErrorLossTest, InitialPredictions) {
ASSERT_OK_AND_ASSIGN(const dataset::VerticalDataset dataset,
CreateToyDataset());
std::vector<float> weights = {1.f, 1.f, 1.f, 1.f};
const bool weighted = GetParam();
std::vector<float> weights;
if (weighted) {
weights = {2.f, 4.f, 6.f, 8.f};
}

const MeanSquaredErrorLoss loss_imp({}, model::proto::Task::REGRESSION,
dataset.data_spec().columns(0));
ASSERT_OK_AND_ASSIGN(
const std::vector<float> init_pred,
loss_imp.InitialPredictions(dataset, /* label_col_idx= */ 0, weights));
EXPECT_THAT(init_pred, ElementsAre((1.f + 2.f + 3.f + 4.f) / 4.f)); // Mean.
if (weighted) {
EXPECT_THAT(init_pred,
ElementsAre((2.f + 8.f + 18.f + 32.f) / 20.f)); // Mean.
} else {
EXPECT_THAT(init_pred,
ElementsAre((1.f + 2.f + 3.f + 4.f) / 4.f)); // Mean.
}
}

TEST(MeanSquareErrorLossTest, UpdateGradients) {
TEST_P(MeanSquareErrorLossTest, UpdateGradients) {
ASSERT_OK_AND_ASSIGN(const dataset::VerticalDataset dataset,
CreateToyDataset());
std::vector<float> weights = {1.f, 1.f, 1.f, 1.f};
const bool weighted = GetParam();
std::vector<float> weights;
if (weighted) {
weights = {2.f, 4.f, 6.f, 8.f};
}

dataset::VerticalDataset gradient_dataset;
std::vector<GradientData> gradients;
@@ -97,14 +114,23 @@ TEST(MeanSquareErrorLossTest, UpdateGradients) {
&random));

ASSERT_THAT(gradients, Not(IsEmpty()));
if (weighted) {
EXPECT_THAT(gradients.front().gradient,
ElementsAre(1.f - 3.f, 2.f - 3.f, 3.f - 3.f, 4.f - 3.f));
} else {
EXPECT_THAT(gradients.front().gradient,
ElementsAre(1.f - 2.5f, 2.f - 2.5f, 3.f - 2.5f, 4.f - 2.5f));
}
}
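The expected gradients are the MSE pseudo-responses g_i = y_i - \hat{y}_0, with \hat{y}_0 the initial prediction: with weights {2, 4, 6, 8}, \hat{y}_0 = (2*1 + 4*2 + 6*3 + 8*4) / 20 = 60 / 20 = 3; without weights, \hat{y}_0 = 2.5.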

TEST(MeanSquareErrorLossTest, SetLabelDistribution) {
TEST_P(MeanSquareErrorLossTest, SetLabelDistribution) {
ASSERT_OK_AND_ASSIGN(const dataset::VerticalDataset dataset,
CreateToyDataset());
std::vector<float> weights = {1.f, 1.f, 1.f, 1.f};
const bool weighted = GetParam();
std::vector<float> weights;
if (weighted) {
weights = {2.f, 4.f, 6.f, 8.f};
}

proto::GradientBoostedTreesTrainingConfig gbt_config;
gbt_config.set_shrinkage(1.f);
@@ -128,40 +154,71 @@ TEST(MeanSquareErrorLossTest, SetLabelDistribution) {
config_link.set_label(2); // Gradient column.

decision_tree::NodeWithChildren node;
ASSERT_OK(loss_imp.SetLeaf(gradient_dataset, selected_examples, weights,
config, config_link, predictions,
/* label_col_idx= */ 0, &node));

EXPECT_EQ(node.node().regressor().top_value(), 2.5f); // Mean of the labels.
// Distribution of the gradients:
EXPECT_EQ(node.node().regressor().distribution().sum(), 0);
EXPECT_EQ(node.node().regressor().distribution().sum_squares(), 0);
// Same as the number of examples in the dataset.
EXPECT_EQ(node.node().regressor().distribution().count(), 4.);
if (weighted) {
ASSERT_OK(loss_imp.SetLeaf</*weighted=*/true>(
gradient_dataset, selected_examples, weights, config, config_link,
predictions,
/* label_col_idx= */ 0, &node));
// Top_value is the weighted mean of the labels
EXPECT_EQ(node.node().regressor().top_value(), 3.f);
// Distribution of the gradients:
EXPECT_EQ(node.node().regressor().distribution().sum(), 0);
EXPECT_EQ(node.node().regressor().distribution().sum_squares(), 0);
// Total weight in the dataset.
EXPECT_EQ(node.node().regressor().distribution().count(), 20.);
} else {
ASSERT_OK(loss_imp.SetLeaf</*weighted=*/false>(
gradient_dataset, selected_examples, weights, config, config_link,
predictions,
/* label_col_idx= */ 0, &node));
// Top value is the mean of the labels.
EXPECT_EQ(node.node().regressor().top_value(), 2.5f);
// Distribution of the gradients:
EXPECT_EQ(node.node().regressor().distribution().sum(), 0);
EXPECT_EQ(node.node().regressor().distribution().sum_squares(), 0);
// Same as the number of examples in the dataset.
EXPECT_EQ(node.node().regressor().distribution().count(), 4.);
}
}
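Both expectations follow from the leaf-value formula with shrinkage 1 and no L2 term: top_value reduces to the (weighted) mean of the labels, (2*1 + 4*2 + 6*3 + 8*4) / 20 = 3 weighted versus 10 / 4 = 2.5 unweighted, and the distribution count is the total weight (20) or the number of examples (4).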

TEST(MeanSquareErrorLossTest, ComputeClassificationLoss) {
TEST_P(MeanSquareErrorLossTest, ComputeClassificationLoss) {
ASSERT_OK_AND_ASSIGN(const dataset::VerticalDataset dataset,
CreateToyDataset());
std::vector<float> weights = {1.f, 1.f, 1.f, 1.f};
const bool weighted = GetParam();
std::vector<float> weights;
if (weighted) {
weights = {2.f, 4.f, 6.f, 8.f};
}

std::vector<float> predictions = {0.f, 0.f, 0.f, 0.f};
const MeanSquaredErrorLoss loss_imp({}, model::proto::Task::REGRESSION,
dataset.data_spec().columns(0));
ASSERT_OK_AND_ASSIGN(
LossResults loss_results,
loss_imp.Loss(dataset,
/* label_col_idx= */ 0, predictions, weights, nullptr));

EXPECT_NEAR(loss_results.loss, std::sqrt(30. / 4.), kTestPrecision);
// For classification, the only secondary metric is also RMSE.
EXPECT_THAT(loss_results.secondary_metrics,
ElementsAre(FloatNear(std::sqrt(30. / 4.), kTestPrecision)));
if (weighted) {
EXPECT_NEAR(loss_results.loss, std::sqrt(200. / 20.), kTestPrecision);
// For classification, the only secondary metric is also RMSE.
EXPECT_THAT(loss_results.secondary_metrics,
ElementsAre(FloatNear(std::sqrt(200. / 20.), kTestPrecision)));
} else {
EXPECT_NEAR(loss_results.loss, std::sqrt(30. / 4.), kTestPrecision);
// For classification, the only secondary metric is also RMSE.
EXPECT_THAT(loss_results.secondary_metrics,
ElementsAre(FloatNear(std::sqrt(30. / 4.), kTestPrecision)));
}
}
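The expected losses are the (weighted) RMSE of the labels against the all-zero predictions: weighted, sqrt((2*1^2 + 4*2^2 + 6*3^2 + 8*4^2) / 20) = sqrt(200 / 20); unweighted, sqrt((1 + 4 + 9 + 16) / 4) = sqrt(30 / 4).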

TEST(MeanSquareErrorLossTest, ComputeRankingLoss) {
TEST_P(MeanSquareErrorLossTest, ComputeRankingLoss) {
ASSERT_OK_AND_ASSIGN(const dataset::VerticalDataset dataset,
CreateToyDataset());
std::vector<float> weights = {1.f, 1.f, 1.f, 1.f};
const bool weighted = GetParam();
std::vector<float> weights;
if (weighted) {
weights = {2.f, 4.f, 6.f, 8.f};
}

std::vector<float> predictions = {0.f, 0.f, 0.f, 0.f};
const MeanSquaredErrorLoss loss_imp({}, model::proto::Task::RANKING,
dataset.data_spec().columns(0));
@@ -171,13 +228,21 @@ TEST(MeanSquareErrorLossTest, ComputeRankingLoss) {
LossResults loss_results,
loss_imp.Loss(dataset,
/* label_col_idx= */ 0, predictions, weights, &index));

EXPECT_NEAR(loss_results.loss, std::sqrt(30. / 4.), kTestPrecision);
// For ranking, first secondary metric is RMSE, second secondary metric is
// NDCG@5.
EXPECT_THAT(loss_results.secondary_metrics,
ElementsAre(FloatNear(std::sqrt(30. / 4.), kTestPrecision),
FloatNear(0.861909, kTestPrecision)));
if (weighted) {
EXPECT_NEAR(loss_results.loss, std::sqrt(200. / 20.), kTestPrecision);
// For ranking, first secondary metric is RMSE, second secondary metric is
// NDCG@5.
EXPECT_THAT(loss_results.secondary_metrics,
ElementsAre(FloatNear(std::sqrt(200. / 20.), kTestPrecision),
FloatNear(0.86291, kTestPrecision)));
} else {
EXPECT_NEAR(loss_results.loss, std::sqrt(30. / 4.), kTestPrecision);
// For ranking, first secondary metric is RMSE, second secondary metric is
// NDCG@5.
EXPECT_THAT(loss_results.secondary_metrics,
ElementsAre(FloatNear(std::sqrt(30. / 4.), kTestPrecision),
FloatNear(0.861909, kTestPrecision)));
}
}

TEST(MeanSquareErrorLossTest, SecondaryMetricNamesClassification) {
@@ -198,6 +263,9 @@ TEST(MeanSquareErrorLossTest, SecondaryMetricNamesRanking) {
ElementsAre("rmse", "NDCG@5"));
}

INSTANTIATE_TEST_SUITE_P(MeanSquareErrorLossTestWithWeights,
MeanSquareErrorLossTest, testing::Bool());

} // namespace
} // namespace gradient_boosted_trees
} // namespace model
@@ -170,7 +170,7 @@ double RankingGroupsIndices::NDCG(const std::vector<float>& predictions,
const std::vector<float>& weights,
const int truncation) const {
DCHECK_EQ(predictions.size(), num_items_);
DCHECK_EQ(weights.size(), num_items_);
DCHECK(weights.empty() || weights.size() == num_items_);

metric::NDCGCalculator ndcg_calculator(truncation);
std::vector<metric::RankingLabelAndPrediction> pred_and_label_relevance;
