Add support for monotonic constraints.
PiperOrigin-RevId: 569096752
achoum authored and copybara-github committed Sep 28, 2023
1 parent 300cbe6 commit 39bebc1
Showing 21 changed files with 867 additions and 212 deletions.
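This commit adds monotonic constraints: the user can require the model output to be a non-decreasing (or non-increasing) function of selected numerical input features. Per the changelog, only the gradient boosted trees learner supports them at this point. As a usage illustration, here is a minimal sketch of a training configuration requesting such constraints, based on the proto fields added below; the learner key, label, and feature names are illustrative assumptions, not part of this commit.

#include "yggdrasil_decision_forests/learner/abstract_learner.pb.h"

namespace ydf = yggdrasil_decision_forests::model::proto;

// Request that "num_rooms" can only push the prediction up, and that
// "distance_to_center" can only push it down.
ydf::TrainingConfig MakeMonotonicConfig() {
  ydf::TrainingConfig config;
  config.set_learner("GRADIENT_BOOSTED_TREES");  // Assumed learner key.
  config.set_label("price");                     // Illustrative label.
  auto* increasing = config.add_monotonic_constraints();
  increasing->set_feature("num_rooms");
  increasing->set_direction(ydf::MonotonicConstraint::INCREASING);
  auto* decreasing = config.add_monotonic_constraints();
  decreasing->set_feature("distance_to_center");
  decreasing->set_direction(ydf::MonotonicConstraint::DECREASING);
  return config;
}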
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,8 @@

## HEAD

- Add support for monotonic constraints for gradient boosted trees.

## Fix

- Fix Windows compilation with Visual Studio 2019.
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/dataset/data_spec.cc
@@ -17,6 +17,7 @@

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>
#include <iterator>
#include <limits>
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/learner/BUILD
@@ -74,6 +74,7 @@ cc_library_ydf(
"//yggdrasil_decision_forests/utils:status_macros",
"//yggdrasil_decision_forests/utils:synchronization_primitives",
"//yggdrasil_decision_forests/utils:uid",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
57 changes: 56 additions & 1 deletion yggdrasil_decision_forests/learner/abstract_learner.cc
@@ -25,9 +25,11 @@
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "absl/time/time.h"
@@ -212,11 +214,56 @@ absl::Status AbstractLearner::LinkTrainingConfig(

*config_link->mutable_features() = {feature_idxs.begin(), feature_idxs.end()};

// Index numerical features.
config_link->clear_numerical_features();
absl::flat_hash_set<int> numerical_features;
for (const auto feature_idx : feature_idxs) {
if (data_spec.columns(feature_idx).type() == dataset::proto::NUMERICAL) {
config_link->add_numerical_features(feature_idx);
numerical_features.insert(feature_idx);
}
}

// Allocate the per-column array.
config_link->clear_per_columns();
for (int i = 0; i < data_spec.columns_size(); i++) {
config_link->add_per_columns();
}

// Monotonicity constraints
for (const auto& src : training_config.monotonic_constraints()) {
if (src.feature().empty()) {
return absl::InvalidArgumentError(
"Empty \"feature\" in a monotonicity constraint");
}
std::vector<int32_t> feature_idxs;
dataset::GetMultipleColumnIdxFromName({src.feature()}, data_spec,
&feature_idxs);
if (feature_idxs.empty()) {
return absl::InvalidArgumentError(
absl::StrCat(src.feature(), " does not match any input features"));
}
for (const int src_feature : feature_idxs) {
if (numerical_features.find(src_feature) == numerical_features.end()) {
// Build error message.
std::vector<std::string> str_numerical_features;
str_numerical_features.reserve(numerical_features.size());
for (const auto feature_idx : numerical_features) {
str_numerical_features.push_back(
data_spec.columns(feature_idx).name());
}

return absl::InvalidArgumentError(absl::Substitute(
"Feature \"$0\" matched by the regular expression \"$1\" is not a "
"numerical input feature of the model. Make sure this feature is also "
"defined as an input feature of the model, and that it is numerical. "
"The numerical input features are: [$2].",
data_spec.columns(src_feature).name(), src.feature(),
absl::StrJoin(str_numerical_features, ", ")));
}
auto* dst = config_link->mutable_per_columns(src_feature);
*dst->mutable_monotonic_constraint() = src;
}
}
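After linking, per_columns carries the constraint (if any) of every dataset column. A downstream consumer such as a tree splitter could read the direction back as in the following sketch; the helper and the +1/-1/0 convention are hypothetical (though the new splitter tests below pass a trailing 1 for an increasing constraint), and only the proto fields come from this commit.

#include "yggdrasil_decision_forests/learner/abstract_learner.pb.h"

namespace ydf = yggdrasil_decision_forests::model::proto;

// Returns +1 for INCREASING, -1 for DECREASING, and 0 if the column is
// unconstrained or if per_columns was left empty.
int MonotonicDirection(const ydf::TrainingConfigLinking& link, int column) {
  if (column >= link.per_columns_size()) return 0;
  const auto& per_column = link.per_columns(column);
  if (!per_column.has_monotonic_constraint()) return 0;
  return per_column.monotonic_constraint().direction() ==
                 ydf::MonotonicConstraint::DECREASING
             ? -1
             : +1;
}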

@@ -770,6 +817,14 @@ absl::Status AbstractLearner::CheckCapabilities() const {
training_config().learner()));
}

// Monotonic constraints
if (!capabilities.support_monotonic_constraints() &&
training_config().monotonic_constraints_size() > 0) {
return absl::InvalidArgumentError(absl::Substitute(
"The learner $0 does not support monotonic constraints.",
training_config().learner()));
}

return absl::OkStatus();
}
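With this capability check, a constrained configuration given to a learner that does not advertise support fails at validation time instead of silently training an unconstrained model. A hedged sketch of the failure path (the learner key is illustrative):

namespace ydf = yggdrasil_decision_forests::model::proto;

ydf::TrainingConfig config;
config.set_learner("CART");  // Assumed: a learner without the capability.
config.set_label("label");
config.add_monotonic_constraints()->set_feature("age");
// CheckCapabilities() now returns an InvalidArgumentError:
//   "The learner CART does not support monotonic constraints."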

36 changes: 36 additions & 0 deletions yggdrasil_decision_forests/learner/abstract_learner.proto
@@ -157,6 +157,9 @@ message TrainingConfig {
// serving speed or RAM usage of model serving.
optional bool pure_serving_model = 14 [default = false];

// Set of monotonic constraints between the model's input features and output.
repeated MonotonicConstraint monotonic_constraints = 15;

// Learner specific configuration/hyper-parameters.
// The message/extension is dependent on the "learner". For example, see
// "yggdrasil_decision_forests/learner/random_forest.proto" for the parameters
@@ -198,6 +201,11 @@ message TrainingConfigLinking {

// Index of the column matching "uplift_treatment" in the "TrainingConfig".
optional int32 uplift_treatment = 12 [default = -1];

// Data for specific dataset columns.
// This field is either empty, or contains exactly one value for each column
// in the dataset.
repeated PerColumn per_columns = 13;
}

// Returns a list of hyper-parameter sets that outperform the default
@@ -249,4 +257,32 @@ message LearnerCapabilities {
// If true, the algorithm supports training with a maximum model size
// (maximum_model_size_in_memory_in_bytes).
optional bool support_max_model_size_in_memory = 5 [default = false];

// If true, the algorithm supports monotonic constraints over numerical
// features.
optional bool support_monotonic_constraints = 6 [default = false];
}
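A learner opts in by setting the new field in its reported capabilities. A minimal sketch, assuming the usual Capabilities() override pattern of YDF learners (the signature itself is not shown in this diff):

model::proto::LearnerCapabilities Capabilities() const override {
  model::proto::LearnerCapabilities capabilities;
  capabilities.set_support_monotonic_constraints(true);
  return capabilities;
}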

// Monotonic constraint between the model's output and numerical input
// features.
message MonotonicConstraint {
// Regular expression over the input feature names.
optional string feature = 1;

optional Direction direction = 2 [default = INCREASING];

enum Direction {
// Ensure the model output is monotonically increasing (non-strict) with
// the feature.
INCREASING = 0;

// Ensure the model output is monotonically decreasing (non-strict) with
// the feature.
DECREASING = 1;
}
}

message PerColumn {
// If set, the attribute has a monotonic constraint.
// Note: monotonic_constraint.feature might not be set.
optional MonotonicConstraint monotonic_constraint = 1;
}
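For illustration, a hedged text-format instance of the new messages. Because feature is a regular expression, a single constraint can cover several numerical columns; all names below are made up.

#include "google/protobuf/text_format.h"
#include "yggdrasil_decision_forests/learner/abstract_learner.pb.h"

bool MakeExampleConfig(
    yggdrasil_decision_forests::model::proto::TrainingConfig* config) {
  return google::protobuf::TextFormat::ParseFromString(
      R"pb(
        learner: "GRADIENT_BOOSTED_TREES"
        label: "income"
        # "age.*" also matches e.g. "age_of_account".
        monotonic_constraints { feature: "age.*" direction: INCREASING }
        monotonic_constraints { feature: "fees" direction: DECREASING }
      )pb",
      config);
}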
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/learner/decision_tree/BUILD
@@ -119,6 +119,7 @@ cc_test(
"//yggdrasil_decision_forests/dataset:data_spec_cc_proto",
"//yggdrasil_decision_forests/dataset:data_spec_inference",
"//yggdrasil_decision_forests/dataset:example_cc_proto",
"//yggdrasil_decision_forests/dataset:types",
"//yggdrasil_decision_forests/dataset:vertical_dataset",
"//yggdrasil_decision_forests/dataset:vertical_dataset_io",
"//yggdrasil_decision_forests/learner:abstract_learner",
@@ -365,6 +365,18 @@ message DecisionTreeTrainingConfig {
// rank, the "false" method is better than the "true" one by a small but
// visible margin.
optional bool hessian_split_score_subtract_parent = 22 [default = false];

// If true, partially checks the monotonic constraints of trees after
// training. This option is used by unit tests: it checks that the value of
// a positive node is greater than the value of a negative node (in the case
// of an increasing monotonic constraint). If false and a monotonic
// constraint is not satisfied, the monotonic constraint is manually
// enforced.
//
// The current checking implementation might flag as non-monotonic trees
// that are in fact monotonic (i.e., false positives). However, on trees
// produced by the current algorithm used to enforce monotonic constraints,
// this check does not produce false positives.
optional bool check_monotonic_constraints = 23 [default = false];
}
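The check described above can be implemented conservatively by propagating the minimum and maximum leaf value of each subtree, which also shows where the false positives come from: the two children of a split cover different regions of the other features, so comparing their whole value ranges is stricter than necessary. A sketch under assumed node types; none of these names come from the commit.

#include <algorithm>
#include <optional>

// Minimal regression-tree node for the sketch.
struct SketchNode {
  bool is_leaf = false;
  float value = 0.f;                // Leaf output.
  int split_attribute = -1;         // Attribute tested by an internal node.
  const SketchNode* neg = nullptr;  // Child when the condition is false.
  const SketchNode* pos = nullptr;  // Child when the condition is true.
};

struct ValueRange {
  float min;
  float max;
};

// Returns the range of leaf values under "node", or nullopt if an increasing
// constraint on "attribute" is (conservatively) violated.
std::optional<ValueRange> CheckIncreasing(const SketchNode& node,
                                          int attribute) {
  if (node.is_leaf) return ValueRange{node.value, node.value};
  const auto neg = CheckIncreasing(*node.neg, attribute);
  const auto pos = CheckIncreasing(*node.pos, attribute);
  if (!neg.has_value() || !pos.has_value()) return std::nullopt;
  // For a split on the constrained attribute, every output reachable on the
  // positive ("higher") side must be at least every output reachable on the
  // negative side.
  if (node.split_attribute == attribute && pos->min < neg->max) {
    return std::nullopt;
  }
  return ValueRange{std::min(neg->min, pos->min),
                    std::max(neg->max, pos->max)};
}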

// Deprecated tag numbers.
113 changes: 105 additions & 8 deletions yggdrasil_decision_forests/learner/decision_tree/decision_tree_test.cc
@@ -34,6 +34,7 @@
#include "yggdrasil_decision_forests/dataset/data_spec.pb.h"
#include "yggdrasil_decision_forests/dataset/data_spec_inference.h"
#include "yggdrasil_decision_forests/dataset/example.pb.h"
#include "yggdrasil_decision_forests/dataset/types.h"
#include "yggdrasil_decision_forests/dataset/vertical_dataset.h"
#include "yggdrasil_decision_forests/dataset/vertical_dataset_io.h"
#include "yggdrasil_decision_forests/learner/abstract_learner.h"
@@ -685,7 +686,7 @@ TEST(DecisionTree, FindBestConditionClassification) {
PerThreadCache cache;
EXPECT_TRUE(FindBestCondition(dataset, selected_examples, weights, config,
config_link, dt_config, {}, tree.root().node(),
{}, {}, &condition, &random, &cache)
.value());

// We test that a condition was created on attribute 0 or 1 (non
@@ -2074,7 +2075,7 @@ TEST(DecisionTree, FindBestConditionConcurrentManager_NoFeatures) {

bool result = FindBestConditionConcurrentManager(
dataset, selected_examples, weights, config, config_link,
dt_config, setup, parent, internal_config, label_stats, {},
&best_condition, &random, &cache)
.value();

@@ -2112,7 +2113,7 @@ TEST(DecisionTree, FindBestConditionConcurrentManager_AlwaysInvalid) {
best_condition.set_split_score(0.f);
bool result = FindBestConditionConcurrentManager(
dataset, selected_examples, weights, config, config_link,
dt_config, setup, parent, internal_config, label_stats, {},
&best_condition, &random, &cache)
.value();

@@ -2156,7 +2157,7 @@ TEST(DecisionTree, FindBestConditionConcurrentManager_Multiplicative) {
best_condition.set_split_score(1000.f);
bool result = FindBestConditionConcurrentManager(
dataset, selected_examples, weights, config, config_link,
dt_config, setup, parent, internal_config, label_stats, {},
&best_condition, &random, &cache)
.value();

@@ -2167,7 +2168,7 @@
best_condition.set_split_score(0.f);
result = FindBestConditionConcurrentManager(
dataset, selected_examples, weights, config, config_link,
dt_config, setup, parent, internal_config, label_stats, {},
&best_condition, &random, &cache)
.value();

@@ -2209,7 +2210,7 @@ TEST(DecisionTree, FindBestConditionConcurrentManager_Alternate) {
best_condition.set_split_score(0.f);
bool result = FindBestConditionConcurrentManager(
dataset, selected_examples, weights, config, config_link,
dt_config, setup, parent, internal_config, label_stats, {},
&best_condition, &random, &cache)
.value();

@@ -2251,7 +2252,7 @@ TEST(DecisionTree, FindBestConditionConcurrentManagerScaled) {
best_condition.set_split_score(1000.f);
bool result = FindBestConditionConcurrentManager(
dataset, selected_examples, weights, config, config_link,
dt_config, setup, parent, internal_config, label_stats, {},
&best_condition, &random, &cache)
.value();

@@ -2264,7 +2265,7 @@
best_condition.set_split_score(0.f);
result = FindBestConditionConcurrentManager(
dataset, selected_examples, weights, config, config_link,
dt_config, setup, parent, internal_config, label_stats, {},
&best_condition, &random, &cache)
.value();
EXPECT_TRUE(result);
@@ -2701,6 +2702,102 @@ TEST(DecisionTree, MinNumExamplePerTreatment) {
}
}

TEST(Monotonic, FindSplitLabelHessianRegressionFeatureNumericalCart) {
std::vector<float> weights;
const std::vector<UnsignedExampleIdx> selected_examples{0, 1, 2, 3};
const std::vector<float> attributes{1, 2, 3, 4};
const std::vector<float> gradients{-10, -10, 10, 10};
const std::vector<float> hessians{1, 1, 1, 1};

proto::DecisionTreeTrainingConfig dt_config;
dt_config.mutable_internal()->set_sorting_strategy(
proto::DecisionTreeTrainingConfig::Internal::IN_NODE);
const double sum_gradient =
std::accumulate(gradients.begin(), gradients.end(), 0.);
const double sum_hessian =
std::accumulate(hessians.begin(), hessians.end(), 0.);
const double sum_weights = selected_examples.size();

proto::NodeCondition best_condition;
SplitterPerThreadCache cache;
EXPECT_EQ(FindSplitLabelHessianRegressionFeatureNumericalCart<false>(
selected_examples, weights, attributes, gradients, hessians, 2,
1, dt_config, sum_gradient, sum_hessian, sum_weights, -1, {},
{}, 0, &best_condition, &cache),
SplitSearchResult::kBetterSplitFound);

EXPECT_EQ(best_condition.condition().higher_condition().threshold(), 2.5f);
EXPECT_EQ(best_condition.num_training_examples_without_weight(), 4);
EXPECT_EQ(best_condition.num_pos_training_examples_without_weight(), 2);
EXPECT_EQ(best_condition.na_value(), false);
EXPECT_EQ(best_condition.num_training_examples_with_weight(), 4);
EXPECT_EQ(best_condition.num_pos_training_examples_with_weight(), 2);
EXPECT_NEAR(best_condition.split_score(), 10 * 10 * 4, TEST_PRECISION);
}
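The asserted score is consistent with the standard Newton (hessian) gain under the defaults exercised here, i.e. no l1/l2 regularization and a zero parent term (the gradients sum to zero):

$$\text{gain} = \frac{G_{\text{neg}}^2}{H_{\text{neg}}} + \frac{G_{\text{pos}}^2}{H_{\text{pos}}} = \frac{(-10-10)^2}{2} + \frac{(10+10)^2}{2} = 400 = 10 \cdot 10 \cdot 4.$$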

TEST(Monotonic,
FindSplitLabelHessianRegressionFeatureNumericalCartWithRangeConstraint) {
std::vector<float> weights;
const std::vector<UnsignedExampleIdx> selected_examples{0, 1, 2, 3};
const std::vector<float> attributes{1, 2, 3, 4};
const std::vector<float> gradients{-10, -10, 10, 10};
const std::vector<float> hessians{1, 1, 1, 1};
const NodeConstraints constraints = {
.min_max_output = NodeConstraints::MinMax{.min = -5, .max = 5}};

proto::DecisionTreeTrainingConfig dt_config;
dt_config.mutable_internal()->set_sorting_strategy(
proto::DecisionTreeTrainingConfig::Internal::IN_NODE);
const double sum_gradient =
std::accumulate(gradients.begin(), gradients.end(), 0.);
const double sum_hessian =
std::accumulate(hessians.begin(), hessians.end(), 0.);
const double sum_weights = selected_examples.size();

proto::NodeCondition best_condition;
SplitterPerThreadCache cache;
EXPECT_EQ(FindSplitLabelHessianRegressionFeatureNumericalCart<false>(
selected_examples, weights, attributes, gradients, hessians, 2,
1, dt_config, sum_gradient, sum_hessian, sum_weights, -1, {},
constraints, 0, &best_condition, &cache),
SplitSearchResult::kBetterSplitFound);

EXPECT_EQ(best_condition.condition().higher_condition().threshold(), 2.5f);
EXPECT_EQ(best_condition.num_training_examples_without_weight(), 4);
EXPECT_EQ(best_condition.num_pos_training_examples_without_weight(), 2);
EXPECT_EQ(best_condition.na_value(), false);
EXPECT_EQ(best_condition.num_training_examples_with_weight(), 4);
EXPECT_EQ(best_condition.num_pos_training_examples_with_weight(), 2);
EXPECT_NEAR(best_condition.split_score(), 5 * 20 / 2 * 2, TEST_PRECISION);
}

TEST(
Monotonic,
FindSplitLabelHessianRegressionFeatureNumericalCartWithMonotonicConstraint) {
std::vector<float> weights;
const std::vector<UnsignedExampleIdx> selected_examples{0, 1, 2};
const std::vector<float> attributes{1, 2, 3};
const std::vector<float> gradients{-1, 1, -10};
const std::vector<float> hessians{1, 1, 1};

proto::DecisionTreeTrainingConfig dt_config;
dt_config.mutable_internal()->set_sorting_strategy(
proto::DecisionTreeTrainingConfig::Internal::IN_NODE);
const double sum_gradient =
std::accumulate(gradients.begin(), gradients.end(), 0.);
const double sum_hessian =
std::accumulate(hessians.begin(), hessians.end(), 0.);
const double sum_weights = selected_examples.size();

proto::NodeCondition best_condition;
SplitterPerThreadCache cache;
EXPECT_EQ(FindSplitLabelHessianRegressionFeatureNumericalCart<false>(
selected_examples, weights, attributes, gradients, hessians, 2,
1, dt_config, sum_gradient, sum_hessian, sum_weights, -1, {},
{}, 1, &best_condition, &cache),
SplitSearchResult::kInvalidAttribute);
}

} // namespace
} // namespace decision_tree
} // namespace model