[YDF] Add sparse_oblique_max_num_features hyperparameter
PiperOrigin-RevId: 701891284
rstz authored and copybara-github committed Dec 2, 2024
1 parent fd34b13 commit ca346cb
Showing 9 changed files with 144 additions and 1 deletion.
2 changes: 2 additions & 0 deletions yggdrasil_decision_forests/learner/decision_tree/BUILD
@@ -46,10 +46,12 @@ cc_library_ydf(
"//yggdrasil_decision_forests/utils:random",
"//yggdrasil_decision_forests/utils:status_macros",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/random",
"@com_google_absl//absl/random:distributions",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
@@ -277,6 +277,14 @@ message DecisionTreeTrainingConfig {
// entire train dataset.
MIN_MAX = 2;
}

// Maximum number of features in a projection. Set to -1 (or leave unset)
// for no maximum.
//
// Use only if a hard maximum on the number of variables is needed;
// otherwise, prefer `projection_density_factor` for controlling the number
// of features per projection.
optional int32 max_num_features = 11 [default = -1];
}

message MHLDObliqueSplit {
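For illustration, a minimal sketch (not part of this commit) of configuring the new field directly on the training config; includes and enclosing namespaces are elided, as in the tests further down:

proto::DecisionTreeTrainingConfig dt_config;
// Enable sparse oblique splits and cap each projection at 4 features.
// The default of -1 disables the cap.
dt_config.mutable_sparse_oblique_split()->set_max_num_features(4);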
@@ -398,6 +398,22 @@ integer weights i.e. `sparse_oblique_weights=INTEGER`. Minimum value of the weig
integer weights i.e. `sparse_oblique_weights=INTEGER`. Maximum value of the weights)");
}

{
ASSIGN_OR_RETURN(auto param,
get_params(kHParamSplitAxisSparseObliqueMaxNumFeatures));
param->mutable_integer()->set_default_value(
config.sparse_oblique_split().max_num_features());
param->mutable_conditional()->set_control_field(kHParamSplitAxis);
param->mutable_integer()->set_minimum(-1);
param->mutable_documentation()->set_proto_field("max_num_features");
param->mutable_conditional()->mutable_categorical()->add_values(
kHParamSplitAxisSparseOblique);
param->mutable_documentation()->set_description(
R"(For sparse oblique splits i.e. `split_axis=SPARSE_OBLIQUE`.
Controls the maximum number of features in a split.Set to -1 for no maximum.
Use only if a hard maximum on the number of variables is needed, otherwise prefer `projection_density_factor` for controlling the number of features per projection.)");
}

{
ASSIGN_OR_RETURN(
auto param, get_params(kHParamSplitAxisSparseObliqueMaxNumProjections));
@@ -811,6 +827,23 @@ absl::Status SetHyperParameters(
}
}

{
const auto hparam =
generic_hyper_params->Get(kHParamSplitAxisSparseObliqueMaxNumFeatures);
if (hparam.has_value()) {
const auto hparam_value = hparam.value().value().integer();
if (dt_config->has_sparse_oblique_split()) {
dt_config->mutable_sparse_oblique_split()->set_max_num_features(
hparam_value);
} else {
return absl::InvalidArgumentError(
absl::StrCat(kHParamSplitAxisSparseObliqueMaxNumFeatures,
" only works with sparse oblique trees "
"(split_axis=SPARSE_OBLIQUE)"));
}
}
}

{
const auto hparam =
generic_hyper_params->Get(kHParamSplitAxisSparseObliqueWeights);
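The same value can also arrive through the generic, string-keyed hyperparameters consumed by SetHyperParameters above. A hedged sketch, assuming the model::proto::GenericHyperParameters message whose fields/value/integer accessors appear in this diff (the add_fields/set_name spelling is the usual generated-proto API, not shown here):

model::proto::GenericHyperParameters hparams;
auto* field = hparams.add_fields();
// kHParamSplitAxisSparseObliqueMaxNumFeatures.
field->set_name("sparse_oblique_max_num_features");
field->mutable_value()->set_integer(4);
// SetHyperParameters() copies the value into
// sparse_oblique_split().max_num_features(), and returns
// InvalidArgumentError unless sparse oblique splits are enabled
// (split_axis=SPARSE_OBLIQUE).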
@@ -98,6 +98,9 @@ constexpr char kHParamSplitAxisMhldObliqueMaxNumAttributes[] =
constexpr char kHParamSplitAxisMhldObliqueSampleAttributes[] =
"mhld_oblique_sample_attributes";

constexpr char kHParamSplitAxisSparseObliqueMaxNumFeatures[] =
"sparse_oblique_max_num_features";

constexpr char kHParamSortingStrategy[] = "sorting_strategy";
constexpr char kHParamSortingStrategyInNode[] = "IN_NODE";
constexpr char kHParamSortingStrategyPresort[] = "PRESORT";
@@ -70,6 +70,7 @@ TEST(GenericParameters, GiveValidAndInvalidHyperparameters) {
kHParamHonest,
kHParamHonestRatioLeafExamples,
kHParamHonestFixedSeparation,
kHParamSplitAxisSparseObliqueMaxNumFeatures,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
@@ -115,6 +116,7 @@ TEST(GenericParameters, MissingValidHyperparameters) {
kHParamHonest,
kHParamHonestRatioLeafExamples,
kHParamHonestFixedSeparation,
kHParamSplitAxisSparseObliqueMaxNumFeatures,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
@@ -158,6 +160,7 @@ TEST(GenericParameters, MissingInvalidHyperparameters) {
kHParamHonest,
kHParamHonestRatioLeafExamples,
kHParamHonestFixedSeparation,
kHParamSplitAxisSparseObliqueMaxNumFeatures,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
@@ -202,6 +205,7 @@ TEST(GenericParameters, UnknownValidHyperparameter) {
kHParamHonest,
kHParamHonestRatioLeafExamples,
kHParamHonestFixedSeparation,
kHParamSplitAxisSparseObliqueMaxNumFeatures,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
@@ -246,6 +250,7 @@ TEST(GenericParameters, UnknownInvalidHyperparameter) {
kHParamHonest,
kHParamHonestRatioLeafExamples,
kHParamHonestFixedSeparation,
kHParamSplitAxisSparseObliqueMaxNumFeatures,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
@@ -292,6 +297,7 @@ TEST(GenericParameters, ExistingHyperparameter) {
kHParamHonest,
kHParamHonestRatioLeafExamples,
kHParamHonestFixedSeparation,
kHParamSplitAxisSparseObliqueMaxNumFeatures,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
28 changes: 28 additions & 0 deletions yggdrasil_decision_forests/learner/decision_tree/oblique.cc
@@ -23,9 +23,12 @@
#include <optional>
#include <random>
#include <type_traits>
#include <utility>
#include <vector>

#include "absl/container/btree_set.h"
#include "absl/log/log.h"
#include "absl/random/distributions.h"
#include "absl/random/random.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
@@ -746,6 +749,31 @@ void SampleProjection(const absl::Span<const int>& features,
} else if (projection->size() == 1) {
projection->front().weight = 1.f;
}

int max_num_features = dt_config.sparse_oblique_split().max_num_features();
int cur_num_features = projection->size();

if (max_num_features > 0 && cur_num_features > max_num_features) {
internal::Projection resampled_projection;
resampled_projection.reserve(max_num_features);
// A btree_set keeps the bookkeeping simple. For a small number of features,
// a boolean vector would be more efficient; re-evaluate if this becomes a
// bottleneck.
absl::btree_set<size_t> sampled_features;
// Floyd's sampling algorithm: draws `max_num_features` distinct indices
// from [0, cur_num_features) with one random draw per selected index.
for (int j = cur_num_features - max_num_features; j < cur_num_features; j++) {
const size_t t = absl::Uniform<size_t>(*random, 0, j + 1);
if (!sampled_features.insert(t).second) {
// t was already sampled, so insert j instead.
sampled_features.insert(j);
resampled_projection.push_back((*projection)[j]);
} else {
// t was not yet sampled.
resampled_projection.push_back((*projection)[t]);
}
}
*projection = std::move(resampled_projection);
}
}

absl::Status SetCondition(const Projection& projection, const float threshold,
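The resampling in SampleProjection above uses Floyd's algorithm, which draws k distinct indices from [0, n) with exactly k random draws and no rejection loop. A self-contained sketch of the same technique using only the standard library (the function name is illustrative):

#include <cstddef>
#include <random>
#include <set>

// Returns k distinct indices sampled uniformly from [0, n), for k <= n.
std::set<size_t> FloydSample(size_t n, size_t k, std::mt19937& rng) {
  std::set<size_t> sampled;
  for (size_t j = n - k; j < n; ++j) {
    // Draw t uniformly from [0, j].
    const size_t t = std::uniform_int_distribution<size_t>(0, j)(rng);
    // If t is already in the sample, insert j instead: j is new this
    // iteration, so it cannot collide, and each k-subset stays equally
    // likely.
    if (!sampled.insert(t).second) sampled.insert(j);
  }
  return sampled;  // Exactly k elements.
}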
@@ -334,7 +334,7 @@ TEST(DecisionTreeTrainingTest,
StatusIs(absl::StatusCode::kInvalidArgument));
}

-TEST(SparaseOblique, Classification) {
+TEST(SparseOblique, Classification) {
const model::proto::TrainingConfig config;
model::proto::TrainingConfigLinking config_link;
config_link.set_label(0);
@@ -395,6 +395,67 @@ TEST(SparaseOblique, Classification) {
EXPECT_NEAR(best_condition.split_score(), 0.693, 0.001);
}

TEST(SparseOblique, ClassificationMaxNumFeatures) {
const model::proto::TrainingConfig config;
model::proto::TrainingConfigLinking config_link;
config_link.set_label(0);
config_link.add_numerical_features(1);
config_link.add_numerical_features(2);

proto::DecisionTreeTrainingConfig dt_config;
dt_config.mutable_sparse_oblique_split();
dt_config.set_min_examples(1);
dt_config.mutable_internal()->set_sorting_strategy(
proto::DecisionTreeTrainingConfig::Internal::IN_NODE);
dt_config.mutable_sparse_oblique_split()->set_max_num_features(1);

dataset::VerticalDataset dataset;
ASSERT_OK_AND_ASSIGN(
auto label_col,
dataset.AddColumn("l", dataset::proto::ColumnType::CATEGORICAL));
label_col->mutable_categorical()->set_is_already_integerized(true);
label_col->mutable_categorical()->set_number_of_unique_values(3);
EXPECT_OK(
dataset.AddColumn("f1", dataset::proto::ColumnType::NUMERICAL).status());
EXPECT_OK(
dataset.AddColumn("f2", dataset::proto::ColumnType::NUMERICAL).status());
EXPECT_OK(dataset.CreateColumnsFromDataspec());

dataset.AppendExample({{"l", "1"}, {"f1", "0.1"}, {"f2", "0.1"}});
dataset.AppendExample({{"l", "1"}, {"f1", "0.9"}, {"f2", "0.9"}});
dataset.AppendExample({{"l", "2"}, {"f1", "0.1"}, {"f2", "0.15"}});
dataset.AppendExample({{"l", "2"}, {"f1", "0.9"}, {"f2", "0.95"}});

ASSERT_OK_AND_ASSIGN(auto* label_data,
dataset.MutableColumnWithCastWithStatus<
dataset::VerticalDataset::CategoricalColumn>(0));

const std::vector<UnsignedExampleIdx> selected_examples = {0, 1, 2, 3};
const std::vector<float> weights = {1.f, 1.f, 1.f, 1.f};

ClassificationLabelStats label_stats(label_data->values());
label_stats.num_label_classes = 3;
label_stats.label_distribution.SetNumClasses(3);
for (const auto example_idx : selected_examples) {
label_stats.label_distribution.Add(label_data->values()[example_idx],
weights[example_idx]);
}

proto::Node parent;
InternalTrainConfig internal_config;
proto::NodeCondition best_condition;
SplitterPerThreadCache cache;
utils::RandomEngine random;
const auto result = FindBestConditionOblique(
dataset, selected_examples, weights, config,
config_link, dt_config, parent, internal_config,
label_stats, 50, &best_condition, &random, &cache)
.value();
EXPECT_TRUE(result);
EXPECT_EQ(best_condition.condition().oblique_condition().attributes_size(),
1);
}

TEST(MHLDTOblique, Classification) {
const model::proto::TrainingConfig config;
model::proto::TrainingConfigLinking config_link;
@@ -761,6 +761,7 @@ IsolationForestLearner::GetGenericHyperParameterSpecification() const {
decision_tree::kHParamSplitAxisSparseObliqueProjectionDensityFactor,
decision_tree::kHParamSplitAxisSparseObliqueNormalization,
decision_tree::kHParamSplitAxisSparseObliqueWeights,
decision_tree::kHParamSplitAxisSparseObliqueMaxNumFeatures,
decision_tree::kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMinExponent,
decision_tree::kHParamSplitAxisSparseObliqueWeightsPowerOfTwoMaxExponent,
decision_tree::kHParamSplitAxisSparseObliqueWeightsIntegerMinimum,
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/port/python/CHANGELOG.md
@@ -21,6 +21,7 @@
learner constructor argument. See the [feature selection tutorial]() for
more details.
- Add standalone prediction evaluation `ydf.evaluate_predictions()`.
- Add new hyperparameter `sparse_oblique_max_num_features`.
- Add options "POWER_OF_TWO" and "INTEGER" for sparse oblique weights.

### Fix
