Skip to content

Commit

Permalink
Fix templating issue
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 569470835
  • Loading branch information
rstz authored and copybara-github committed Sep 29, 2023
1 parent 4e61fbd commit b3b0ca2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
3 changes: 2 additions & 1 deletion yggdrasil_decision_forests/learner/decision_tree/oblique.cc
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ absl::StatusOr<SplitSearchResult> EvaluateProjection(
const InternalTrainConfig& internal_config, const int first_attribute_idx,
const NodeConstraints& constraints, int8_t monotonic_direction,
proto::NodeCondition* condition, SplitterPerThreadCache* cache) {
const int min_num_obs =
const UnsignedExampleIdx min_num_obs =
dt_config.in_split_min_examples_check() ? dt_config.min_examples() : 1;

// Projection are never missing.
Expand All @@ -386,6 +386,7 @@ absl::StatusOr<SplitSearchResult> EvaluateProjection(
#endif

// Find a good split in the current_projection.
// TODO: Why is internal_config not passed along below?
SplitSearchResult result;
if constexpr (is_same<LabelStats, ClassificationLabelStats>::value) {
result = FindSplitLabelClassificationFeatureNumericalCart(
Expand Down
24 changes: 24 additions & 0 deletions yggdrasil_decision_forests/learner/decision_tree/training.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2555,6 +2555,30 @@ SplitSearchResult FindSplitLabelHessianRegressionFeatureNumericalCart(
attribute_idx, condition, &cache->cache_v2);
}

template SplitSearchResult
FindSplitLabelHessianRegressionFeatureNumericalCart<true>(
const std::vector<UnsignedExampleIdx>& selected_examples,
const std::vector<float>& weights, const std::vector<float>& attributes,
const std::vector<float>& gradients, const std::vector<float>& hessians,
float na_replacement, UnsignedExampleIdx min_num_obs,
const proto::DecisionTreeTrainingConfig& dt_config, double sum_gradient,
double sum_hessian, double sum_weights, int32_t attribute_idx,
const InternalTrainConfig& internal_config,
const NodeConstraints& constraints, int8_t monotonic_direction,
proto::NodeCondition* condition, SplitterPerThreadCache* cache);

template SplitSearchResult
FindSplitLabelHessianRegressionFeatureNumericalCart<false>(
const std::vector<UnsignedExampleIdx>& selected_examples,
const std::vector<float>& weights, const std::vector<float>& attributes,
const std::vector<float>& gradients, const std::vector<float>& hessians,
float na_replacement, UnsignedExampleIdx min_num_obs,
const proto::DecisionTreeTrainingConfig& dt_config, double sum_gradient,
double sum_hessian, double sum_weights, int32_t attribute_idx,
const InternalTrainConfig& internal_config,
const NodeConstraints& constraints, int8_t monotonic_direction,
proto::NodeCondition* condition, SplitterPerThreadCache* cache);

template <bool weighted>
SplitSearchResult
FindSplitLabelHessianRegressionFeatureDiscretizedNumericalCart(
Expand Down

0 comments on commit b3b0ca2

Please sign in to comment.