From ca346cb0c919e26ff19e44455dd3dfa0cfe46323 Mon Sep 17 00:00:00 2001 From: Richard Stotz Date: Mon, 2 Dec 2024 02:26:26 -0800 Subject: [PATCH] [YDF] Add sparse_oblique_max_num_features hyperparameter PiperOrigin-RevId: 701891284 --- .../learner/decision_tree/BUILD | 2 + .../learner/decision_tree/decision_tree.proto | 8 +++ .../decision_tree/generic_parameters.cc | 33 ++++++++++ .../decision_tree/generic_parameters.h | 3 + .../decision_tree/generic_parameters_test.cc | 6 ++ .../learner/decision_tree/oblique.cc | 28 +++++++++ .../learner/decision_tree/training_test.cc | 63 ++++++++++++++++++- .../isolation_forest/isolation_forest.cc | 1 + .../port/python/CHANGELOG.md | 1 + 9 files changed, 144 insertions(+), 1 deletion(-) diff --git a/yggdrasil_decision_forests/learner/decision_tree/BUILD b/yggdrasil_decision_forests/learner/decision_tree/BUILD index cd050939..5823469e 100644 --- a/yggdrasil_decision_forests/learner/decision_tree/BUILD +++ b/yggdrasil_decision_forests/learner/decision_tree/BUILD @@ -46,10 +46,12 @@ cc_library_ydf( "//yggdrasil_decision_forests/utils:random", "//yggdrasil_decision_forests/utils:status_macros", "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:btree", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/random", + "@com_google_absl//absl/random:distributions", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto b/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto index a1eed710..e0041d5d 100644 --- a/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto +++ b/yggdrasil_decision_forests/learner/decision_tree/decision_tree.proto @@ -277,6 +277,14 @@ message DecisionTreeTrainingConfig { // entire train dataset. 
MIN_MAX = 2; } + + // Maximum number of features in a projection. Set to -1 or not provided for + // no maximum. + // + // Use only if a hard maximum on the number of variables is needed, + // otherwise prefer `projection_density_factor` for controlling the number + // of features per projection. + optional int32 max_num_features = 11 [default = -1]; } message MHLDObliqueSplit { diff --git a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.cc b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.cc index 94f370e7..ed2a4974 100644 --- a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.cc +++ b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.cc @@ -398,6 +398,22 @@ integer weights i.e. `sparse_oblique_weights=INTEGER`. Minimum value of the weig integer weights i.e. `sparse_oblique_weights=INTEGER`. Maximum value of the weights)"); } + { + ASSIGN_OR_RETURN(auto param, + get_params(kHParamSplitAxisSparseObliqueMaxNumFeatures)); + param->mutable_integer()->set_default_value( + config.sparse_oblique_split().max_num_features()); + param->mutable_conditional()->set_control_field(kHParamSplitAxis); + param->mutable_integer()->set_minimum(-1); + param->mutable_documentation()->set_proto_field("max_num_features"); + param->mutable_conditional()->mutable_categorical()->add_values( + kHParamSplitAxisSparseOblique); + param->mutable_documentation()->set_description( + R"(For sparse oblique splits i.e. `split_axis=SPARSE_OBLIQUE`. +Controls the maximum number of features in a split. Set to -1 for no maximum. 
+Use only if a hard maximum on the number of variables is needed, otherwise prefer `projection_density_factor` for controlling the number of features per projection.)"); + } + { ASSIGN_OR_RETURN( auto param, get_params(kHParamSplitAxisSparseObliqueMaxNumProjections)); @@ -811,6 +827,23 @@ absl::Status SetHyperParameters( } } + { + const auto hparam = + generic_hyper_params->Get(kHParamSplitAxisSparseObliqueMaxNumFeatures); + if (hparam.has_value()) { + const auto hparam_value = hparam.value().value().integer(); + if (dt_config->has_sparse_oblique_split()) { + dt_config->mutable_sparse_oblique_split()->set_max_num_features( + hparam_value); + } else { + return absl::InvalidArgumentError( + absl::StrCat(kHParamSplitAxisSparseObliqueMaxNumFeatures, + " only works with sparse oblique trees " + "(split_axis=SPARSE_OBLIQUE)")); + } + } + } + { const auto hparam = generic_hyper_params->Get(kHParamSplitAxisSparseObliqueWeights); diff --git a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.h b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.h index 674b45ea..f3f4218d 100644 --- a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.h +++ b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters.h @@ -98,6 +98,9 @@ constexpr char kHParamSplitAxisMhldObliqueMaxNumAttributes[] = constexpr char kHParamSplitAxisMhldObliqueSampleAttributes[] = "mhld_oblique_sample_attributes"; +constexpr char kHParamSplitAxisSparseObliqueMaxNumFeatures[] = + "sparse_oblique_max_num_features"; + constexpr char kHParamSortingStrategy[] = "sorting_strategy"; constexpr char kHParamSortingStrategyInNode[] = "IN_NODE"; constexpr char kHParamSortingStrategyPresort[] = "PRESORT"; diff --git a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters_test.cc b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters_test.cc index 2ccc921c..de49c16e 100644 --- 
a/yggdrasil_decision_forests/learner/decision_tree/generic_parameters_test.cc +++ b/yggdrasil_decision_forests/learner/decision_tree/generic_parameters_test.cc @@ -70,6 +70,7 @@ TEST(GenericParameters, GiveValidAndInvalidHyperparameters) { kHParamHonest, kHParamHonestRatioLeafExamples, kHParamHonestFixedSeparation, + kHParamSplitAxisSparseObliqueMaxNumFeatures, kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent, kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent, kHParamSplitAxisSparseObliqueWeightsIntegerMinimum, @@ -115,6 +116,7 @@ TEST(GenericParameters, MissingValidHyperparameters) { kHParamHonest, kHParamHonestRatioLeafExamples, kHParamHonestFixedSeparation, + kHParamSplitAxisSparseObliqueMaxNumFeatures, kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent, kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent, kHParamSplitAxisSparseObliqueWeightsIntegerMinimum, @@ -158,6 +160,7 @@ TEST(GenericParameters, MissingInvalidHyperparameters) { kHParamHonest, kHParamHonestRatioLeafExamples, kHParamHonestFixedSeparation, + kHParamSplitAxisSparseObliqueMaxNumFeatures, kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent, kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent, kHParamSplitAxisSparseObliqueWeightsIntegerMinimum, @@ -202,6 +205,7 @@ TEST(GenericParameters, UnknownValidHyperparameter) { kHParamHonest, kHParamHonestRatioLeafExamples, kHParamHonestFixedSeparation, + kHParamSplitAxisSparseObliqueMaxNumFeatures, kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent, kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent, kHParamSplitAxisSparseObliqueWeightsIntegerMinimum, @@ -246,6 +250,7 @@ TEST(GenericParameters, UnknownInvalidHyperparameter) { kHParamHonest, kHParamHonestRatioLeafExamples, kHParamHonestFixedSeparation, + kHParamSplitAxisSparseObliqueMaxNumFeatures, kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent, kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent, 
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum, @@ -292,6 +297,7 @@ TEST(GenericParameters, ExistingHyperparameter) { kHParamHonest, kHParamHonestRatioLeafExamples, kHParamHonestFixedSeparation, + kHParamSplitAxisSparseObliqueMaxNumFeatures, kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent, kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent, kHParamSplitAxisSparseObliqueWeightsIntegerMinimum, diff --git a/yggdrasil_decision_forests/learner/decision_tree/oblique.cc b/yggdrasil_decision_forests/learner/decision_tree/oblique.cc index 363f5c13..40a571c6 100644 --- a/yggdrasil_decision_forests/learner/decision_tree/oblique.cc +++ b/yggdrasil_decision_forests/learner/decision_tree/oblique.cc @@ -23,9 +23,12 @@ #include #include #include +#include #include +#include "absl/container/btree_set.h" #include "absl/log/log.h" +#include "absl/random/distributions.h" #include "absl/random/random.h" #include "absl/status/status.h" #include "absl/status/statusor.h" @@ -746,6 +749,31 @@ void SampleProjection(const absl::Span& features, } else if (projection->size() == 1) { projection->front().weight = 1.f; } + + int max_num_features = dt_config.sparse_oblique_split().max_num_features(); + int cur_num_projections = projection->size(); + + if (max_num_features > 0 && cur_num_projections > max_num_features) { + internal::Projection resampled_projection; + resampled_projection.reserve(max_num_features); + // For a small number of features, a boolean vector is more efficient. + // Re-evaluate if this becomes a bottleneck. + absl::btree_set sampled_features; + // Floyd's sampling algorithm. + for (size_t j = cur_num_projections - max_num_features; + j < cur_num_projections; j++) { + size_t t = absl::Uniform(*random, 0, j + 1); + if (!sampled_features.insert(t).second) { + // t was already sampled, so insert j instead. + sampled_features.insert(j); + resampled_projection.push_back((*projection)[j]); + } else { + // t was not yet sampled. 
+ resampled_projection.push_back((*projection)[t]); + } + } + *projection = std::move(resampled_projection); + } } absl::Status SetCondition(const Projection& projection, const float threshold, diff --git a/yggdrasil_decision_forests/learner/decision_tree/training_test.cc b/yggdrasil_decision_forests/learner/decision_tree/training_test.cc index 6c4b5453..9f956dae 100644 --- a/yggdrasil_decision_forests/learner/decision_tree/training_test.cc +++ b/yggdrasil_decision_forests/learner/decision_tree/training_test.cc @@ -334,7 +334,7 @@ TEST(DecisionTreeTrainingTest, StatusIs(absl::StatusCode::kInvalidArgument)); } -TEST(SparaseOblique, Classification) { +TEST(SparseOblique, Classification) { const model::proto::TrainingConfig config; model::proto::TrainingConfigLinking config_link; config_link.set_label(0); @@ -395,6 +395,67 @@ TEST(SparaseOblique, Classification) { EXPECT_NEAR(best_condition.split_score(), 0.693, 0.001); } +TEST(SparseOblique, ClassificationMaxNumFeatures) { + const model::proto::TrainingConfig config; + model::proto::TrainingConfigLinking config_link; + config_link.set_label(0); + config_link.add_numerical_features(1); + config_link.add_numerical_features(2); + + proto::DecisionTreeTrainingConfig dt_config; + dt_config.mutable_sparse_oblique_split(); + dt_config.set_min_examples(1); + dt_config.mutable_internal()->set_sorting_strategy( + proto::DecisionTreeTrainingConfig::Internal::IN_NODE); + dt_config.mutable_sparse_oblique_split()->set_max_num_features(1); + + dataset::VerticalDataset dataset; + ASSERT_OK_AND_ASSIGN( + auto label_col, + dataset.AddColumn("l", dataset::proto::ColumnType::CATEGORICAL)); + label_col->mutable_categorical()->set_is_already_integerized(true); + label_col->mutable_categorical()->set_number_of_unique_values(3); + EXPECT_OK( + dataset.AddColumn("f1", dataset::proto::ColumnType::NUMERICAL).status()); + EXPECT_OK( + dataset.AddColumn("f2", dataset::proto::ColumnType::NUMERICAL).status()); + 
EXPECT_OK(dataset.CreateColumnsFromDataspec()); + + dataset.AppendExample({{"l", "1"}, {"f1", "0.1"}, {"f2", "0.1"}}); + dataset.AppendExample({{"l", "1"}, {"f1", "0.9"}, {"f2", "0.9"}}); + dataset.AppendExample({{"l", "2"}, {"f1", "0.1"}, {"f2", "0.15"}}); + dataset.AppendExample({{"l", "2"}, {"f1", "0.9"}, {"f2", "0.95"}}); + + ASSERT_OK_AND_ASSIGN(auto* label_data, + dataset.MutableColumnWithCastWithStatus< + dataset::VerticalDataset::CategoricalColumn>(0)); + + const std::vector selected_examples = {0, 1, 2, 3}; + const std::vector weights = {1.f, 1.f, 1.f, 1.f}; + + ClassificationLabelStats label_stats(label_data->values()); + label_stats.num_label_classes = 3; + label_stats.label_distribution.SetNumClasses(3); + for (const auto example_idx : selected_examples) { + label_stats.label_distribution.Add(label_data->values()[example_idx], + weights[example_idx]); + } + + proto::Node parent; + InternalTrainConfig internal_config; + proto::NodeCondition best_condition; + SplitterPerThreadCache cache; + utils::RandomEngine random; + const auto result = FindBestConditionOblique( + dataset, selected_examples, weights, config, + config_link, dt_config, parent, internal_config, + label_stats, 50, &best_condition, &random, &cache) + .value(); + EXPECT_TRUE(result); + EXPECT_EQ(best_condition.condition().oblique_condition().attributes_size(), + 1); +} + TEST(MHLDTOblique, Classification) { const model::proto::TrainingConfig config; model::proto::TrainingConfigLinking config_link; diff --git a/yggdrasil_decision_forests/learner/isolation_forest/isolation_forest.cc b/yggdrasil_decision_forests/learner/isolation_forest/isolation_forest.cc index 9278b7f2..4c16dabd 100644 --- a/yggdrasil_decision_forests/learner/isolation_forest/isolation_forest.cc +++ b/yggdrasil_decision_forests/learner/isolation_forest/isolation_forest.cc @@ -761,6 +761,7 @@ IsolationForestLearner::GetGenericHyperParameterSpecification() const { 
decision_tree::kHParamSplitAxisSparseObliqueProjectionDensityFactor, decision_tree::kHParamSplitAxisSparseObliqueNormalization, decision_tree::kHParamSplitAxisSparseObliqueWeights, + decision_tree::kHParamSplitAxisSparseObliqueMaxNumFeatures, decision_tree::kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent, decision_tree::kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent, decision_tree::kHParamSplitAxisSparseObliqueWeightsIntegerMinimum, diff --git a/yggdrasil_decision_forests/port/python/CHANGELOG.md b/yggdrasil_decision_forests/port/python/CHANGELOG.md index ce73b449..39b0623b 100644 --- a/yggdrasil_decision_forests/port/python/CHANGELOG.md +++ b/yggdrasil_decision_forests/port/python/CHANGELOG.md @@ -21,6 +21,7 @@ learner constructor argument. See the [feature selection tutorial]() for more details. - Add standalone prediction evaluation `ydf.evaluate_predictions()`. +- Add new hyperparameter `sparse_oblique_max_num_features`. - Add options "POWER_OF_TWO" and "INTEGER" for sparse oblique weights. ### Fix