Skip to content

Commit

Permalink
Support for MAE loss for GBT (part2)
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571341495
  • Loading branch information
achoum authored and copybara-github committed Oct 6, 2023
1 parent 57fbcf3 commit 9c32afb
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 2 deletions.
11 changes: 9 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# Changelog

## Breaking changes
## Head

### Feature

- Mean average error loss for GBT.

## 1.6.0 - 2023-09-28

### Breaking changes

- The dependency to the distributed gradient boosted trees learner is renamed
from
Expand All @@ -15,7 +23,6 @@

### Feature


- Add support for monotonic constraints for gradient boosted trees.
- Improve speed of dataset reading and writing.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1273,6 +1273,15 @@ TEST_F(GradientBoostedTreesOnAbalone, PoissonLoss) {
YDF_TEST_METRIC(metric::RMSE(evaluation_), 2.1563, 0.0852, 2.1491);
}

TEST_F(GradientBoostedTreesOnAbalone, MAELoss) {
auto* gbt_config = train_config_.MutableExtension(
gradient_boosted_trees::proto::gradient_boosted_trees_config);
gbt_config->set_loss(proto::Loss::MEAN_AVERAGE_ERROR);
TrainAndEvaluateModel();
YDF_TEST_METRIC(metric::MAE(evaluation_), 1.5155, 0.0599, 1.4994);
YDF_TEST_METRIC(metric::RMSE(evaluation_), 2.2608, 0.1464, 2.2136);
}

class GradientBoostedTreesOnIris : public utils::TrainAndTestTester {
void SetUp() override {
train_config_.set_learner(GradientBoostedTreesLearner::kRegisteredName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ void GradientBoostedTreesModel::Predict(
dist->set_sum(1.f);
}
} break;
case proto::Loss::MEAN_AVERAGE_ERROR:
case proto::Loss::SQUARED_ERROR: {
double accumulator = initial_predictions_[0];
CallOnAllLeafs(dataset, row_idx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ enum Loss {
BINARY_FOCAL_LOSS = 6;
// Poisson log likelihood. Only valid for regression.
POISSON = 7;
// Mean average error (MAE).
MEAN_AVERAGE_ERROR = 8;
}

// Log of the training. This proto is generated during the training of the
Expand Down

0 comments on commit 9c32afb

Please sign in to comment.