diff --git a/yggdrasil_decision_forests/learner/decision_tree/splitter_accumulator.h b/yggdrasil_decision_forests/learner/decision_tree/splitter_accumulator.h index 2e5870d2..80aaa5a0 100644 --- a/yggdrasil_decision_forests/learner/decision_tree/splitter_accumulator.h +++ b/yggdrasil_decision_forests/learner/decision_tree/splitter_accumulator.h @@ -56,6 +56,44 @@ namespace yggdrasil_decision_forests { namespace model { namespace decision_tree { +namespace internal { + +// Bucket data containers. +// +// Bucket definitions are templated to facilitate code reuse. Since buckets are +// constructed many times, it's worth saving memory aggressively and only +// construct fields that will actually be used. C++ does not support 0-byte +// objects, hence the unused fields cannot be set to void. Empty structs would +// occupy 1 byte for the unused field. Combining two fields into a struct is +// therefore the most space-efficient alternative. +struct BooleanValueAndWeight { + bool value; + float weight; +}; + +struct BooleanValueOnly { + bool value; +}; + +struct IntegerValueAndWeight { + int value; + float weight; +}; + +struct IntegerValueOnly { + int value; +}; + +struct SumTruesAndWeights { + double sum_trues; + double sum_weights; +}; + +struct SumTruesOnly { + double sum_trues; +}; + +} // namespace internal // =============== // Feature Buckets @@ -756,7 +794,7 @@ struct LabelHessianNumericalScoreAccumulator { // Initialize and empty an accumulator. // void InitEmpty(ScoreAccumulator* acc) const; // -// Initialize an accumulator and set it to conain all the training examples. +// Initialize an accumulator and set it to contain all the training examples. // void InitFull(ScoreAccumulator* acc) const; // // Normalize the score of the bucket. The final split score is: @@ -774,7 +812,7 @@ struct LabelNumericalOneValueBucket { } void SubToScoreAcc(LabelNumericalScoreAccumulator* acc) const { - acc->label.Add(value, -weight); + acc->label.Sub(value, weight); } class Initializer { @@ -836,7 +874,7 @@ struct LabelNumericalOneValueBucket { template void SubDirectToScoreAcc(const ExampleIdx example_idx, LabelNumericalScoreAccumulator* acc) const { - acc->label.Add(label_[example_idx], -weights_[example_idx]); + acc->label.Sub(label_[example_idx], weights_[example_idx]); } template @@ -851,8 +889,8 @@ struct LabelNumericalOneValueBucket { void SubDirectToScoreAccWithDuplicates( const ExampleIdx example_idx, const int num_duplicates, LabelNumericalScoreAccumulator* acc) const { - acc->label.Add(label_[example_idx], - -weights_[example_idx] * num_duplicates); + acc->label.Sub(label_[example_idx], + weights_[example_idx] * num_duplicates); } template @@ -1019,9 +1057,12 @@ inline std::ostream& operator<<( return os; } +template struct LabelCategoricalOneValueBucket { - int value; - float weight; + typedef typename std::conditional_t + ValueAndMaybeWeight; + ValueAndMaybeWeight content; // Not called "kCount" because this is used as a template parameter and // expects the name to be `count` (in other such structs it is not a @@ -1029,11 +1070,19 @@ struct LabelCategoricalOneValueBucket { static constexpr int count = 1; // NOLINT void AddToScoreAcc(LabelCategoricalScoreAccumulator* acc) const { - acc->label.Add(value, weight); + if constexpr (weighted) { + acc->label.Add(content.value, content.weight); + } else { + acc->label.Add(content.value); + } } void SubToScoreAcc(LabelCategoricalScoreAccumulator* acc) const { - acc->label.Add(value, -weight); + if constexpr (weighted) { + acc->label.Sub(content.value, content.weight); + } else { + acc->label.Sub(content.value); + } } class Initializer { @@ -1071,51 +1120,77 @@ struct LabelCategoricalOneValueBucket { public: Filler(const std::vector& label, const std::vector& weights) : label_(label), weights_(weights) { - DCHECK_EQ(weights.size(), label.size()); + if constexpr (weighted) { + DCHECK_EQ(weights.size(), label.size()); + } else { + DCHECK(weights.empty()); + } } - void InitializeAndZero(LabelCategoricalOneValueBucket* acc) const {} + void InitializeAndZero( + LabelCategoricalOneValueBucket* bucket) const {} - void Finalize(LabelCategoricalOneValueBucket* acc) const {} + void Finalize(LabelCategoricalOneValueBucket* bucket) const {} - void ConsumeExample(const UnsignedExampleIdx example_idx, - LabelCategoricalOneValueBucket* acc) const { - acc->value = label_[example_idx]; - acc->weight = weights_[example_idx]; + void ConsumeExample( + const UnsignedExampleIdx example_idx, + LabelCategoricalOneValueBucket* bucket) const { + bucket->content.value = label_[example_idx]; + if constexpr (weighted) { + bucket->content.weight = weights_[example_idx]; + } } template void AddDirectToScoreAcc(const ExampleIdx example_idx, LabelCategoricalScoreAccumulator* acc) const { - acc->label.Add(label_[example_idx], weights_[example_idx]); + if constexpr (weighted) { + acc->label.Add(label_[example_idx], weights_[example_idx]); + } else { + acc->label.Add(label_[example_idx]); + } } template void SubDirectToScoreAcc(const ExampleIdx example_idx, LabelCategoricalScoreAccumulator* acc) const { - acc->label.Add(label_[example_idx], -weights_[example_idx]); + if constexpr (weighted) { + acc->label.Sub(label_[example_idx], weights_[example_idx]); + } else { + acc->label.Sub(label_[example_idx]); + } } template void AddDirectToScoreAccWithDuplicates( const ExampleIdx example_idx, const int num_duplicates, LabelCategoricalScoreAccumulator* acc) const { - acc->label.Add(label_[example_idx], - weights_[example_idx] * num_duplicates); + if constexpr (weighted) { + acc->label.Add(label_[example_idx], + weights_[example_idx] * num_duplicates); + } else { + acc->label.Add(label_[example_idx], num_duplicates); + } } template void SubDirectToScoreAccWithDuplicates( const ExampleIdx example_idx, const int num_duplicates, LabelCategoricalScoreAccumulator* acc) const { - acc->label.Add(label_[example_idx], - -weights_[example_idx] * num_duplicates); + if constexpr (weighted) { + acc->label.Sub(label_[example_idx], + weights_[example_idx] * num_duplicates); + } else { + acc->label.Sub(label_[example_idx], num_duplicates); + } } template void Prefetch(const ExampleIdx example_idx) const { PREFETCH(&label_[example_idx]); - PREFETCH(&weights_[example_idx]); + if constexpr (weighted) { + PREFETCH(&weights_[example_idx]); + } } private: @@ -1127,29 +1202,25 @@ struct LabelCategoricalOneValueBucket { const LabelCategoricalOneValueBucket& data); }; -inline std::ostream& operator<<(std::ostream& os, - const LabelCategoricalOneValueBucket& data) { - os << "value:" << data.value << " weight:" << data.weight +inline std::ostream& operator<<( + std::ostream& os, + const LabelCategoricalOneValueBucket& data) { + os << "value:" << data.content.value << " weight:" << data.content.weight << " count:" << data.count; return os; } +inline std::ostream& operator<<( + std::ostream& os, + const LabelCategoricalOneValueBucket& data) { + os << "value:" << data.content.value << " count:" << data.count; + return os; +} + template struct LabelBinaryCategoricalOneValueBucket { - // Since this object is constructed many times, it's worth saving memory - // aggressively. Since C++ does not support 0-byte objects, the conditional - // cannot be set to void. Setting the conditional to an empty struct would - // still use 1 byte. The construction below has size 1 if unweighted is true - // and size 8 (due to alignment) if unweighted is false. This is the same size - // as in a non-templated version that constructs two different buckets. - struct ValueAndWeight { - bool value; - float weight; - }; - struct ValueOnly { - bool value; - }; - typedef typename std::conditional_t + typedef typename std::conditional_t ValueAndMaybeWeight; ValueAndMaybeWeight content; @@ -1621,6 +1692,7 @@ inline std::ostream& operator<<(std::ostream& os, return os; } +template struct LabelCategoricalBucket { utils::IntegerDistributionDouble value; int64_t count; @@ -1699,7 +1771,11 @@ struct LabelCategoricalBucket { : label_(label), weights_(weights), num_classes_(label_distribution.NumClasses()) { - DCHECK_EQ(weights.size(), label.size()); + if constexpr (weighted) { + DCHECK_EQ(weights.size(), label.size()); + } else { + DCHECK(weights.empty()); + } } void InitializeAndZero(LabelCategoricalBucket* acc) const { @@ -1712,7 +1788,11 @@ struct LabelCategoricalBucket { void ConsumeExample(const UnsignedExampleIdx example_idx, LabelCategoricalBucket* acc) const { - acc->value.Add(label_[example_idx], weights_[example_idx]); + if constexpr (weighted) { + acc->value.Add(label_[example_idx], weights_[example_idx]); + } else { + acc->value.Add(label_[example_idx]); + } acc->count++; } @@ -1727,7 +1807,14 @@ struct LabelCategoricalBucket { }; inline std::ostream& operator<<(std::ostream& os, - const LabelCategoricalBucket& data) { + const LabelCategoricalBucket& data) { + os << "value:{obs:" << data.value.NumObservations() + << "} count:" << data.count; + return os; +} + +inline std::ostream& operator<<(std::ostream& os, + const LabelCategoricalBucket& data) { os << "value:{obs:" << data.value.NumObservations() << "} count:" << data.count; return os; @@ -1735,22 +1822,9 @@ inline std::ostream& operator<<(std::ostream& os, template struct LabelBinaryCategoricalBucket { - // Since this object is constructed many times, it's worth saving memory - // aggressively. Since C++ does not support 0-byte objects, the conditional - // cannot be set to void. Setting the conditional to an empty struct would - // still use 1 byte. The construction below has size 1 if unweighted is true - // and size 8 (due to alignment) if unweighted is false. This is the same size - // as in a non-templated version that constructs two different buckets. - struct SumTruesAndWeights { - double sum_trues; - double sum_weights; - }; - struct SumTruesOnly { - double sum_trues; - }; - typedef - typename std::conditional_t - SumTruesAndMaybeWeights; + typedef typename std::conditional_t + SumTruesAndMaybeWeights; SumTruesAndMaybeWeights content; int64_t count; diff --git a/yggdrasil_decision_forests/learner/decision_tree/splitter_scanner.h b/yggdrasil_decision_forests/learner/decision_tree/splitter_scanner.h index 24ea0340..32ced8d1 100644 --- a/yggdrasil_decision_forests/learner/decision_tree/splitter_scanner.h +++ b/yggdrasil_decision_forests/learner/decision_tree/splitter_scanner.h @@ -146,20 +146,51 @@ using FeatureIsMissingLabelHessianNumerical = ExampleBucketSet< // Label: Categorical. -using FeatureNumericalLabelCategoricalOneValue = ExampleBucketSet< - ExampleBucket>; +using LabelWeightedCategoricalOneValueBucket = + LabelCategoricalOneValueBucket; -using FeatureDiscretizedNumericalLabelCategorical = ExampleBucketSet< - ExampleBucket>; +using LabelWeightedCategoricalBucket = LabelCategoricalBucket; + +using FeatureNumericalLabelCategoricalOneValue = + ExampleBucketSet>; + +using FeatureDiscretizedNumericalLabelCategorical = + ExampleBucketSet>; using FeatureCategoricalLabelCategorical = ExampleBucketSet< - ExampleBucket>; + ExampleBucket>; using FeatureBooleanLabelCategorical = ExampleBucketSet< - ExampleBucket>; + ExampleBucket>; using FeatureIsMissingLabelCategorical = ExampleBucketSet< - ExampleBucket>; + ExampleBucket>; + +// Label: Unweighted Categorical. + +using LabelUnweightedCategoricalOneValueBucket = + LabelCategoricalOneValueBucket; + +using LabelUnweightedCategoricalBucket = LabelCategoricalBucket; + +using FeatureNumericalLabelUnweightedCategoricalOneValue = + ExampleBucketSet>; + +using FeatureDiscretizedNumericalLabelUnweightedCategorical = + ExampleBucketSet>; + +using FeatureCategoricalLabelUnweightedCategorical = ExampleBucketSet< + ExampleBucket>; + +using FeatureBooleanLabelUnweightedCategorical = ExampleBucketSet< + ExampleBucket>; + +using FeatureIsMissingLabelUnweightedCategorical = ExampleBucketSet< + ExampleBucket>; // Label: Binary Categorical. @@ -253,6 +284,13 @@ struct PerThreadCacheV2 { FeatureIsMissingLabelCategorical example_bucket_set_cat_3; FeatureBooleanLabelCategorical example_bucket_set_cat_4; + FeatureNumericalLabelUnweightedCategoricalOneValue example_bucket_set_ucat_1; + FeatureDiscretizedNumericalLabelUnweightedCategorical + example_bucket_set_ucat_5; + FeatureCategoricalLabelUnweightedCategorical example_bucket_set_ucat_2; + FeatureIsMissingLabelUnweightedCategorical example_bucket_set_ucat_3; + FeatureBooleanLabelUnweightedCategorical example_bucket_set_ucat_4; + FeatureNumericalLabelHessianNumericalOneValue example_bucket_set_hnum_1; FeatureDiscretizedNumericalLabelHessianNumerical example_bucket_set_hnum_5; FeatureCategoricalLabelHessianNumerical example_bucket_set_hnum_2; @@ -361,6 +399,25 @@ auto* GetCachedExampleBucketSet(PerThreadCacheV2* cache) { } else if constexpr (is_same_v) { return &cache->example_bucket_set_cat_4; + } else if constexpr ( + is_same_v) { + // Unweighted Categorical. + return &cache->example_bucket_set_ucat_1; + } else if constexpr ( + is_same_v) { + return &cache->example_bucket_set_ucat_5; + } else if constexpr (is_same_v< + ExampleBucketSet, + FeatureCategoricalLabelUnweightedCategorical>) { + return &cache->example_bucket_set_ucat_2; + } else if constexpr (is_same_v) { + return &cache->example_bucket_set_ucat_3; + } else if constexpr (is_same_v) { + return &cache->example_bucket_set_ucat_4; } else if constexpr (is_same_v< ExampleBucketSet, FeatureNumericalLabelBinaryCategoricalOneValue>) { @@ -1283,6 +1340,36 @@ constexpr auto FindBestSplit_LabelClassificationFeatureNACart = LabelCategoricalScoreAccumulator, /*require_label_sorting*/ false>; +// Label: Unweighted Classification. + +constexpr auto FindBestSplit_LabelUnweightedClassificationFeatureNumerical = + FindBestSplit; + +constexpr auto + FindBestSplit_LabelUnweightedClassificationFeatureDiscretizedNumerical = + FindBestSplit; + +constexpr auto + FindBestSplit_LabelUnweightedClassificationFeatureCategoricalCart = + FindBestSplit; + +constexpr auto FindBestSplit_LabelUnweightedClassificationFeatureBooleanCart = + FindBestSplit; + +constexpr auto FindBestSplit_LabelUnweightedClassificationFeatureNACart = + FindBestSplit; + // Label: Binary Classification. constexpr auto FindBestSplit_LabelBinaryClassificationFeatureNumerical = @@ -1312,6 +1399,8 @@ constexpr auto FindBestSplit_LabelBinaryClassificationFeatureNACart = LabelBinaryCategoricalScoreAccumulator, /*require_label_sorting*/ false>; +// Label: Unweighted Binary Classification. + constexpr auto FindBestSplit_LabelUnweightedBinaryClassificationFeatureNumerical = FindBestSplit::Filler label_filler( + labels, weights); + LabelCategoricalOneValueBucket::Initializer + initializer(label_distribution); - if (sorting_strategy == - proto::DecisionTreeTrainingConfig::Internal::PRESORTED || - sorting_strategy == - proto::DecisionTreeTrainingConfig::Internal::FORCE_PRESORTED) { - if (!internal_config.preprocessing) { - LOG(FATAL) << "Preprocessing missing for PRESORTED sorting " - "strategy"; + if (sorting_strategy == + proto::DecisionTreeTrainingConfig::Internal::PRESORTED || + sorting_strategy == + proto::DecisionTreeTrainingConfig::Internal::FORCE_PRESORTED) { + if (!internal_config.preprocessing) { + LOG(FATAL) << "Preprocessing missing for PRESORTED sorting " + "strategy"; + } + if (sorting_strategy == + proto::DecisionTreeTrainingConfig::Internal::FORCE_PRESORTED || + IsPresortingOnNumericalSplitMoreEfficient( + selected_examples.size(), + internal_config.preprocessing->num_examples())) { + const auto& sorted_attributes = + internal_config.preprocessing + ->presorted_numerical_features()[attribute_idx]; + return ScanSplitsPresortedSparse< + FeatureNumericalLabelUnweightedCategoricalOneValue, + LabelCategoricalScoreAccumulator>( + internal_config.preprocessing->num_examples(), selected_examples, + sorted_attributes.items, feature_filler, label_filler, + initializer, min_num_obs, attribute_idx, + internal_config.duplicated_selected_examples, condition, + &cache->cache_v2); + } } + + return FindBestSplit_LabelUnweightedClassificationFeatureNumerical( + selected_examples, feature_filler, label_filler, initializer, + min_num_obs, attribute_idx, condition, &cache->cache_v2); + } else { + LabelCategoricalOneValueBucket::Filler label_filler( + labels, weights); + LabelCategoricalOneValueBucket::Initializer + initializer(label_distribution); + if (sorting_strategy == - proto::DecisionTreeTrainingConfig::Internal::FORCE_PRESORTED || - IsPresortingOnNumericalSplitMoreEfficient( - selected_examples.size(), - internal_config.preprocessing->num_examples())) { - const auto& sorted_attributes = - internal_config.preprocessing - ->presorted_numerical_features()[attribute_idx]; - return ScanSplitsPresortedSparse< - FeatureNumericalLabelCategoricalOneValue, - LabelCategoricalScoreAccumulator>( - internal_config.preprocessing->num_examples(), selected_examples, - sorted_attributes.items, feature_filler, label_filler, initializer, - min_num_obs, attribute_idx, - internal_config.duplicated_selected_examples, condition, - &cache->cache_v2); + proto::DecisionTreeTrainingConfig::Internal::PRESORTED || + sorting_strategy == + proto::DecisionTreeTrainingConfig::Internal::FORCE_PRESORTED) { + if (!internal_config.preprocessing) { + LOG(FATAL) << "Preprocessing missing for PRESORTED sorting " + "strategy"; + } + if (sorting_strategy == + proto::DecisionTreeTrainingConfig::Internal::FORCE_PRESORTED || + IsPresortingOnNumericalSplitMoreEfficient( + selected_examples.size(), + internal_config.preprocessing->num_examples())) { + const auto& sorted_attributes = + internal_config.preprocessing + ->presorted_numerical_features()[attribute_idx]; + return ScanSplitsPresortedSparse< + FeatureNumericalLabelCategoricalOneValue, + LabelCategoricalScoreAccumulator>( + internal_config.preprocessing->num_examples(), selected_examples, + sorted_attributes.items, feature_filler, label_filler, + initializer, min_num_obs, attribute_idx, + internal_config.duplicated_selected_examples, condition, + &cache->cache_v2); + } } - } - return FindBestSplit_LabelClassificationFeatureNumerical( - selected_examples, feature_filler, label_filler, initializer, - min_num_obs, attribute_idx, condition, &cache->cache_v2); + return FindBestSplit_LabelClassificationFeatureNumerical( + selected_examples, feature_filler, label_filler, initializer, + min_num_obs, attribute_idx, condition, &cache->cache_v2); + } } } @@ -2052,13 +2092,25 @@ SplitSearchResult FindSplitLabelClassificationFeatureDiscretizedNumericalCart( } } else { // Multi-class classification. - LabelCategoricalBucket::Filler label_filler(labels, weights, - label_distribution); - LabelCategoricalBucket::Initializer initializer(label_distribution); + if (weights.empty()) { + LabelCategoricalBucket::Filler label_filler( + labels, weights, label_distribution); + LabelCategoricalBucket::Initializer initializer( + label_distribution); + + return FindBestSplit_LabelUnweightedClassificationFeatureDiscretizedNumerical( + selected_examples, feature_filler, label_filler, initializer, + min_num_obs, attribute_idx, condition, &cache->cache_v2); + } else { + LabelCategoricalBucket::Filler label_filler( + labels, weights, label_distribution); + LabelCategoricalBucket::Initializer initializer( + label_distribution); - return FindBestSplit_LabelClassificationFeatureDiscretizedNumerical( - selected_examples, feature_filler, label_filler, initializer, - min_num_obs, attribute_idx, condition, &cache->cache_v2); + return FindBestSplit_LabelClassificationFeatureDiscretizedNumerical( + selected_examples, feature_filler, label_filler, initializer, + min_num_obs, attribute_idx, condition, &cache->cache_v2); + } } } @@ -2391,13 +2443,25 @@ SplitSearchResult FindSplitLabelClassificationFeatureNA( } } else { // Multi-class classification. - LabelCategoricalBucket::Filler label_filler(labels, weights, - label_distribution); - LabelCategoricalBucket::Initializer initializer(label_distribution); + if (weights.empty()) { + LabelCategoricalBucket::Filler label_filler( + labels, weights, label_distribution); + LabelCategoricalBucket::Initializer initializer( + label_distribution); - return FindBestSplit_LabelClassificationFeatureNACart( - selected_examples, feature_filler, label_filler, initializer, - min_num_obs, attribute_idx, condition, &cache->cache_v2); + return FindBestSplit_LabelUnweightedClassificationFeatureNACart( + selected_examples, feature_filler, label_filler, initializer, + min_num_obs, attribute_idx, condition, &cache->cache_v2); + } else { + LabelCategoricalBucket::Filler label_filler( + labels, weights, label_distribution); + LabelCategoricalBucket::Initializer initializer( + label_distribution); + + return FindBestSplit_LabelClassificationFeatureNACart( + selected_examples, feature_filler, label_filler, initializer, + min_num_obs, attribute_idx, condition, &cache->cache_v2); + } } } @@ -2468,7 +2532,6 @@ SplitSearchResult FindSplitLabelClassificationFeatureBoolean( if (num_label_classes == 3) { // Binary classification. if (weights.empty()) { - // Unweighted classes LabelBinaryCategoricalBucket::Filler label_filler( labels, {}, label_distribution); @@ -2491,14 +2554,27 @@ SplitSearchResult FindSplitLabelClassificationFeatureBoolean( } } else { // Multi-class classification. - LabelCategoricalBucket::Filler label_filler(labels, weights, - label_distribution); + if (weights.empty()) { + LabelCategoricalBucket::Filler label_filler( + labels, weights, label_distribution); - LabelCategoricalBucket::Initializer initializer(label_distribution); + LabelCategoricalBucket::Initializer initializer( + label_distribution); - return FindBestSplit_LabelClassificationFeatureBooleanCart( - selected_examples, feature_filler, label_filler, initializer, - min_num_obs, attribute_idx, condition, &cache->cache_v2); + return FindBestSplit_LabelUnweightedClassificationFeatureBooleanCart( + selected_examples, feature_filler, label_filler, initializer, + min_num_obs, attribute_idx, condition, &cache->cache_v2); + } else { + LabelCategoricalBucket::Filler label_filler( + labels, weights, label_distribution); + + LabelCategoricalBucket::Initializer initializer( + label_distribution); + + return FindBestSplit_LabelClassificationFeatureBooleanCart( + selected_examples, feature_filler, label_filler, initializer, + min_num_obs, attribute_idx, condition, &cache->cache_v2); + } } } @@ -3243,12 +3319,22 @@ SplitSearchResult FindSplitLabelClassificationFeatureCategorical( } } else { // Multi-class classification. - return FindSplitLabelClassificationFeatureCategorical< - LabelCategoricalBucket, FeatureCategoricalLabelCategorical, - LabelCategoricalScoreAccumulator>( - selected_examples, weights, attributes, labels, num_attribute_classes, - num_label_classes, na_replacement, min_num_obs, dt_config, - label_distribution, attribute_idx, random, condition, cache); + if (weights.empty()) { + return FindSplitLabelClassificationFeatureCategorical< + LabelCategoricalBucket, + FeatureCategoricalLabelUnweightedCategorical, + LabelCategoricalScoreAccumulator>( + selected_examples, weights, attributes, labels, num_attribute_classes, + num_label_classes, na_replacement, min_num_obs, dt_config, + label_distribution, attribute_idx, random, condition, cache); + } else { + return FindSplitLabelClassificationFeatureCategorical< + LabelCategoricalBucket, + FeatureCategoricalLabelCategorical, LabelCategoricalScoreAccumulator>( + selected_examples, weights, attributes, labels, num_attribute_classes, + num_label_classes, na_replacement, min_num_obs, dt_config, + label_distribution, attribute_idx, random, condition, cache); + } } } diff --git a/yggdrasil_decision_forests/learner/distributed_decision_tree/label_accessor.h b/yggdrasil_decision_forests/learner/distributed_decision_tree/label_accessor.h index 43644f0d..ce5f0b50 100644 --- a/yggdrasil_decision_forests/learner/distributed_decision_tree/label_accessor.h +++ b/yggdrasil_decision_forests/learner/distributed_decision_tree/label_accessor.h @@ -50,7 +50,8 @@ class ClassificationLabelFiller { public: // How to represent a label value. typedef int16_t Label; - typedef decision_tree::LabelCategoricalBucket LabelBucket; + // TODO(b/225812418): Add special handling for unit weights. + typedef decision_tree::LabelCategoricalBucket LabelBucket; typedef decision_tree::LabelCategoricalScoreAccumulator Accumulator; typedef LabelBucket::Initializer AccumulatorInitializer; diff --git a/yggdrasil_decision_forests/learner/random_forest/random_forest.cc b/yggdrasil_decision_forests/learner/random_forest/random_forest.cc index 94df8f30..7f8edc34 100644 --- a/yggdrasil_decision_forests/learner/random_forest/random_forest.cc +++ b/yggdrasil_decision_forests/learner/random_forest/random_forest.cc @@ -419,15 +419,14 @@ RandomForestLearner::TrainWithStatus( // all the examples have the same weight. // // Currently, this feature is supported for: - // - Binary classification without oblique splits (default) and with local + // - Classification without oblique splits (default) and with local // imputation policy (default) to handle missing values. bool use_optimized_unit_weights = false; if (training_config().task() == model::proto::Task::CLASSIFICATION && rf_config.decision_tree().split_axis_case() != decision_tree::proto::DecisionTreeTrainingConfig:: kSparseObliqueSplit) { - // Only use optimized unit weights for binary classification for now. - if (config_link.num_label_classes() == 3) use_optimized_unit_weights = true; + use_optimized_unit_weights = true; } RETURN_IF_ERROR(dataset::GetWeights(train_dataset, config_link, &weights,