Commit
[YDF] Fix Squared error loss
When using squared error loss, the leaf values were incorrectly clamped (by default to [-5, 5]).
This clamping should not be applied by default for this loss.

PiperOrigin-RevId: 583376180
rstz authored and copybara-github committed Nov 17, 2023
1 parent 3592598 commit 7f04d57
Showing 5 changed files with 3,074 additions and 4 deletions.
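
For intuition, a minimal standalone sketch (not YDF code; the numbers are illustrative) of why the clamp breaks squared error regression: for this loss the gradient is (label - prediction) and the hessian is 1, so the Newton-Raphson leaf value is the shrinkage-scaled mean residual of the examples in the leaf. When labels are in the thousands, a [-5, 5] clamp caps every boosting step at 5.

#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

// Standalone illustration, not YDF code. For squared error loss, the Newton
// step for a leaf is the shrinkage-scaled mean residual.
float LeafValue(const std::vector<float>& residuals, float shrinkage,
                float l2, bool clamp_to_5) {
  const double sum_gradients =
      std::accumulate(residuals.begin(), residuals.end(), 0.0);
  const double sum_hessians = residuals.size();  // Hessian is 1 per example.
  float value =
      static_cast<float>(shrinkage * sum_gradients / (sum_hessians + l2));
  if (clamp_to_5) value = std::clamp(value, -5.f, 5.f);
  return value;
}

int main() {
  // Residuals of a freshly initialized model on labels in the thousands
  // (illustrative values).
  const std::vector<float> residuals = {3900.f, 4000.f, 4100.f};
  std::cout << LeafValue(residuals, 0.1f, 0.f, false) << "\n";  // 400
  std::cout << LeafValue(residuals, 0.1f, 0.f, true) << "\n";   // 5
}

With each tree contributing at most 5 in label units, the model cannot reach large targets in any reasonable number of trees, which matches the regression described above.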
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,12 @@
 # Changelog
 
+## HEAD
+
+### Fix
+
+A regression caused Mean Squared Error loss and Mean Average Error loss to
+incorrectly clamp the leaf values, leading to incorrect predictions.
+
 ## 1.7.0 - 2023-10-20
 
 ### Feature
@@ -1845,6 +1845,62 @@ TEST_F(AutotunedGradientBoostedTreesOnAdult,
   EXPECT_EQ(model_->hyperparameter_optimizer_logs()->steps_size(), 25);
 }
 
+// TODO - b/311636358 Refactor the GBT tests to be more cohesive and
+// comprehensive.
+struct RegressionEnd2EndTestParams {
+  const proto::Loss loss;
+  const float expected_rmse;
+};
+
+using RegressionEnd2EndTest =
+    testing::TestWithParam<RegressionEnd2EndTestParams>;
+
+TEST_P(RegressionEnd2EndTest, LargeValues) {
+  const std::string train_typed_path = absl::StrCat(
+      "csv:", file::JoinPath(DatasetDir(), "two_center_regression_train.csv"));
+  const std::string test_typed_path = absl::StrCat(
+      "csv:", file::JoinPath(DatasetDir(), "two_center_regression_test.csv"));
+  dataset::proto::DataSpecification data_spec;
+  dataset::proto::DataSpecificationGuide guide;
+  dataset::CreateDataSpec(train_typed_path, false, guide, &data_spec);
+  dataset::VerticalDataset train_dataset;
+  CHECK_OK(LoadVerticalDataset(train_typed_path, data_spec, &train_dataset));
+  dataset::VerticalDataset test_dataset;
+  CHECK_OK(LoadVerticalDataset(test_typed_path, data_spec, &test_dataset));
+  utils::RandomEngine random(1234);
+
+  model::proto::DeploymentConfig deployment_config;
+  model::proto::TrainingConfig train_config;
+  train_config.set_label("target");
+  train_config.set_learner(GradientBoostedTreesLearner::kRegisteredName);
+  train_config.set_task(model::proto::Task::REGRESSION);
+  auto* gbt_config = train_config.MutableExtension(
+      gradient_boosted_trees::proto::gradient_boosted_trees_config);
+  gbt_config->set_loss(GetParam().loss);
+  std::unique_ptr<model::AbstractLearner> learner;
+  CHECK_OK(model::GetLearner(train_config, &learner, deployment_config));
+
+  ASSERT_OK_AND_ASSIGN(auto model, learner->TrainWithStatus(train_dataset));
+  metric::proto::EvaluationOptions eval_options;
+  eval_options.set_task(model::proto::Task::REGRESSION);
+  eval_options.mutable_regression()->set_enable_regression_plots(false);
+  ASSERT_OK_AND_ASSIGN(auto eval, model->EvaluateWithStatus(
+                                      test_dataset, eval_options, &random));
+
+  EXPECT_NEAR(metric::RMSE(eval), GetParam().expected_rmse, 1.0);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    RegressionEnd2EndTest, RegressionEnd2EndTest,
+    testing::ValuesIn<RegressionEnd2EndTestParams>({
+        {proto::SQUARED_ERROR, 114.8},
+        {proto::MEAN_AVERAGE_ERROR, 4051.6},
+        {proto::POISSON, 114.8},
+    }),
+    [](const testing::TestParamInfo<RegressionEnd2EndTest::ParamType>& info) {
+      return proto::Loss_Name(info.param.loss);
+    });
+
 }  // namespace
 }  // namespace gradient_boosted_trees
 }  // namespace model
@@ -107,8 +107,11 @@ absl::Status SetLeafValueWithNewtonRaphsonStep(
   const double denominator =
       sum_weighted_hessian + gbt_config.l2_regularization();
   float value = gbt_config.shrinkage() * numerator / denominator;
-  value = utils::clamp(value, -gbt_config.clamp_leaf_logit(),
-                       gbt_config.clamp_leaf_logit());
+  // TODO - b/311636358: Move this information to the AbstractLoss class.
+  if (gbt_config.loss() != proto::SQUARED_ERROR) {
+    value = utils::clamp(value, -gbt_config.clamp_leaf_logit(),
+                         gbt_config.clamp_leaf_logit());
+  }
   reg->set_top_value(value);
   return absl::OkStatus();
 }
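
The bound itself is not arbitrary: the field is named clamp_leaf_logit, and for classification losses the leaf value is a log-odds update, where a magnitude of 5 is already near saturation. A standalone sketch of that intuition (illustrative, not YDF code):

#include <cmath>
#include <cstdio>

// Standalone illustration, not YDF code: +/-5 is a plausible bound when the
// leaf value is a logit, since sigmoid(5) is already ~0.9933, but it is
// meaningless for squared error, where leaf values live in label units.
int main() {
  const double logit = 5.0;  // The default clamp bound.
  std::printf("sigmoid(5) = %.4f\n", 1.0 / (1.0 + std::exp(-logit)));
}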
@@ -162,8 +165,10 @@ absl::Status SetLeafValueWithNewtonRaphsonStep(
       sum_gradients, gbt_config_.l1_regularization());
   const double denominator = sum_hessians + gbt_config_.l2_regularization();
   float value = gbt_config_.shrinkage() * numerator / denominator;
-  value = utils::clamp(value, -gbt_config_.clamp_leaf_logit(),
-                       gbt_config_.clamp_leaf_logit());
+  if (gbt_config_.loss() != proto::SQUARED_ERROR) {
+    value = utils::clamp(value, -gbt_config_.clamp_leaf_logit(),
+                         gbt_config_.clamp_leaf_logit());
+  }
   node->mutable_regressor()->set_top_value(value);
   return absl::OkStatus();
 }
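
The TODO above suggests moving this decision into the loss abstraction rather than comparing loss enums at each call site. A hypothetical shape for that refactor (the method name and classes below are invented for illustration; this is not the actual YDF API):

// Hypothetical sketch of the refactor hinted at by b/311636358.
class AbstractLoss {
 public:
  virtual ~AbstractLoss() = default;
  // Whether leaf values are logits that should be clamped to
  // [-clamp_leaf_logit, clamp_leaf_logit]. Losses whose leaf values are in
  // label units (e.g. squared error) would return false.
  virtual bool ClampLeafValues() const { return true; }
};

class MeanSquaredErrorLoss : public AbstractLoss {
 public:
  bool ClampLeafValues() const override { return false; }
};

The two SetLeafValueWithNewtonRaphsonStep call sites would then query the loss object instead of hard-coding proto::SQUARED_ERROR.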