diff --git a/CHANGELOG.md b/CHANGELOG.md index c9f1a4f2..ddcb0638 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Changelog +## 0.1.2 - 2021-05-18 + +### Features + +- Inference engines: QuickScorer Extended and Pred + ## 0.1.1 - 2021-05-17 ### Features diff --git a/README.md b/README.md index d8b1a14c..c07ad0e2 100644 --- a/README.md +++ b/README.md @@ -95,8 +95,8 @@ Download one of the build, and then run `examples/beginner.{sh,bat}`. Target | Version | Link ------- | ------- | ---- -Linux | 0.1.1 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.1/cli_linux.zip) -Windows | 0.1.1 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.1/cli_windows.zip) +Linux | 0.1.0 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.0/cli_linux.zip) +Windows | 0.1.0 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.0/cli_windows.zip) ## Installation from Source diff --git a/yggdrasil_decision_forests/cli/benchmark_inference.cc b/yggdrasil_decision_forests/cli/benchmark_inference.cc index 4a78cb09..148cc584 100644 --- a/yggdrasil_decision_forests/cli/benchmark_inference.cc +++ b/yggdrasil_decision_forests/cli/benchmark_inference.cc @@ -45,8 +45,9 @@ // batch_size : 100 num_runs : 20 // time/example(µs) time/batch(µs) method // ---------------------------------------- -// interface] 9.179 917.9 GradientBoostedTreesGeneric [virtual -// interface] 21.547 2154.8 Generic slow engine +// 0.79025 79.025 GradientBoostedTreesQuickScorerExtended +// 9.179 917.9 GradientBoostedTreesGeneric +// 21.547 2154.8 Generic slow engine // ---------------------------------------- // #include "absl/flags/flag.h" diff --git a/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees_test.cc b/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees_test.cc index 3af507ef..72726bde 100644 --- a/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees_test.cc +++ b/yggdrasil_decision_forests/learner/gradient_boosted_trees/gradient_boosted_trees_test.cc @@ -721,6 +721,28 @@ TEST_F(GradientBoostedTreesOnAdult, RandomCategorical) { EXPECT_TRUE(gbt_model->IsMissingValueConditionResultFollowGlobalImputation()); } +// Train and test a model on the adult dataset with too much nodes for the +// QuickScorer serving algorithm. +TEST_F(GradientBoostedTreesOnAdult, BaseNoQuickScorer) { + auto* gbt_config = train_config_.MutableExtension( + gradient_boosted_trees::proto::gradient_boosted_trees_config); + gbt_config->set_num_trees(100); + gbt_config->mutable_decision_tree()->set_max_depth(10); + gbt_config->set_shrinkage(0.1f); + gbt_config->set_subsample(0.9f); + TrainAndEvaluateModel(); + + // Note: Accuracy is similar as RF (see :random_forest_test). However logloss + // is significantly better (which is expected as, unlike RF, GBT is + // calibrated). + EXPECT_NEAR(metric::Accuracy(evaluation_), 0.8549, 0.015); + EXPECT_NEAR(metric::LogLoss(evaluation_), 0.320, 0.04); + + auto* gbt_model = + dynamic_cast(model_.get()); + EXPECT_TRUE(gbt_model->IsMissingValueConditionResultFollowGlobalImputation()); +} + // Train and test a model on the adult dataset. TEST_F(GradientBoostedTreesOnAdult, BaseConcurrentDeprecated) { auto* gbt_config = train_config_.MutableExtension( diff --git a/yggdrasil_decision_forests/serving/decision_forest/BUILD b/yggdrasil_decision_forests/serving/decision_forest/BUILD index 11d97053..5943a804 100644 --- a/yggdrasil_decision_forests/serving/decision_forest/BUILD +++ b/yggdrasil_decision_forests/serving/decision_forest/BUILD @@ -16,6 +16,7 @@ cc_library( ], deps = [ ":decision_forest", + ":quick_scorer_extended", "//yggdrasil_decision_forests/model:abstract_model", "//yggdrasil_decision_forests/model/gradient_boosted_trees", "//yggdrasil_decision_forests/serving:example_set_model_wrapper", @@ -49,6 +50,28 @@ cc_library( ], ) +cc_library( + name = "quick_scorer_extended", + srcs = [ + "quick_scorer_extended.cc", + ], + hdrs = [ + "quick_scorer_extended.h", + ], + deps = [ + ":utils", + "@com_google_absl//absl/base:config", + "@com_google_absl//absl/status", + "//yggdrasil_decision_forests/model/decision_tree", + "//yggdrasil_decision_forests/model/gradient_boosted_trees", + "//yggdrasil_decision_forests/model/gradient_boosted_trees:gradient_boosted_trees_cc_proto", + "//yggdrasil_decision_forests/serving:example_set", + "//yggdrasil_decision_forests/utils:bitmap", + "//yggdrasil_decision_forests/utils:compatibility", + "//yggdrasil_decision_forests/utils:usage", + ], +) + cc_library( name = "utils", srcs = [ @@ -84,6 +107,7 @@ cc_test( shard_count = 10, deps = [ ":decision_forest", + ":quick_scorer_extended", ":register_engines", "@com_google_googletest//:gtest_main", "@com_google_absl//absl/flags:flag", @@ -108,3 +132,15 @@ cc_test( "//yggdrasil_decision_forests/utils:test_utils", ], ) + +cc_test( + name = "quick_scorer_extended_test", + srcs = ["quick_scorer_extended_test.cc"], + deps = [ + ":quick_scorer_extended", + "@com_google_googletest//:gtest_main", + "//yggdrasil_decision_forests/model/decision_tree", + "//yggdrasil_decision_forests/model/gradient_boosted_trees", + "//yggdrasil_decision_forests/utils:test", + ], +) diff --git a/yggdrasil_decision_forests/serving/decision_forest/decision_forest.h b/yggdrasil_decision_forests/serving/decision_forest/decision_forest.h index b937b06a..41aaec78 100644 --- a/yggdrasil_decision_forests/serving/decision_forest/decision_forest.h +++ b/yggdrasil_decision_forests/serving/decision_forest/decision_forest.h @@ -214,6 +214,10 @@ struct FlatNodeModel { using NodeType = Node; using FeaturesDefinition = FeaturesDefinitionNumericalOrCategoricalFlat; + using ExampleSet = + ExampleSetNumericalOrCategoricalFlat, + ExampleFormat::FORMAT_EXAMPLE_MAJOR>; + const FeaturesDefinition& features() const { return internal_features; } FeaturesDefinition* mutable_features() { return &internal_features; } @@ -555,6 +559,14 @@ void Predict(const GradientBoostedTreesRankingNumericalAndCategorical& model, const std::vector& examples, int num_examples, std::vector* predictions); +template +void PredictWithExampleSet(const Model& model, + const typename Model::ExampleSet& examples, + int num_examples, std::vector* predictions) { + Predict(model, examples.InternalCategoricalAndNumericalValues(), num_examples, + predictions); +} + // Note: Requires for the number of trees to be a multiple of 8. void PredictOptimizedV1( const RandomForestBinaryClassificationNumericalFeatures& model, diff --git a/yggdrasil_decision_forests/serving/decision_forest/decision_forest_test.cc b/yggdrasil_decision_forests/serving/decision_forest/decision_forest_test.cc index 020c7150..dc7d72c1 100644 --- a/yggdrasil_decision_forests/serving/decision_forest/decision_forest_test.cc +++ b/yggdrasil_decision_forests/serving/decision_forest/decision_forest_test.cc @@ -38,6 +38,7 @@ #include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.pb.h" #include "yggdrasil_decision_forests/model/model_library.h" #include "yggdrasil_decision_forests/model/prediction.pb.h" +#include "yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.h" #include "yggdrasil_decision_forests/utils/csv.h" #include "yggdrasil_decision_forests/utils/distribution.pb.h" #include "yggdrasil_decision_forests/utils/filesystem.h" diff --git a/yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.cc b/yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.cc new file mode 100644 index 00000000..7dc7d7b6 --- /dev/null +++ b/yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.cc @@ -0,0 +1,1057 @@ +/* + * Copyright 2021 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.h" + +#include + +#include "absl/status/status.h" + +#ifdef __AVX2__ +#include +#endif + +#include "absl/base/config.h" +#include "yggdrasil_decision_forests/model/decision_tree/decision_tree.h" +#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h" +#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.pb.h" +#include "yggdrasil_decision_forests/utils/bitmap.h" +#include "yggdrasil_decision_forests/utils/compatibility.h" +#include "yggdrasil_decision_forests/utils/usage.h" + +namespace yggdrasil_decision_forests { +namespace serving { +namespace decision_forest { + +using dataset::proto::ColumnType; +using LeafMask = internal::QuickScorerExtendedModel::LeafMask; +using model::decision_tree::NodeWithChildren; +using model::decision_tree::proto::Condition; +using model::gradient_boosted_trees::proto::Loss; +using ::yggdrasil_decision_forests::utils::bitmap::ToStringBit; + +namespace { + +// Maximum stack size used by the model during inference +constexpr size_t kMaxStackUsageInBytes = 16 * 1024; + +namespace portable { +#ifdef __AVX2__ +void* aligned_alloc(std::size_t alignment, std::size_t size) { +#if defined(_WIN32) + // Visual Studio + return _aligned_malloc(/*size=*/size, /*alignment=*/alignment); +#else + return ::aligned_alloc(/*alignment=*/alignment, /*size=*/size); +#endif +} + +void aligned_free(void* mem) { +#if defined(_WIN32) + _aligned_free(mem); +#else + free(mem); +#endif +} +#endif +} // namespace portable + +// Returns the number of trailing 0-bits in x, starting at the least significant +// bit position. If x is 0, the result is undefined. +int FindLSBSetNonZero64(uint64_t n) { + return utils::CountTrailingZeroesNonzero64(n); +} + +// Activation function for binary classification GBDT trained with Binomial +// LogLikelihood loss. +float ActivationBinomialLogLikelihood(const float value) { + return utils::clamp(1.f / (1.f + std::exp(-value)), 0.f, 1.f); +} + +// Identity activation function. +float ActivationIdentity(const float value) { return value; } + +// Initialize the accumulator used to construct the quick scorer model +// representation. +// +// Note: This accumulator is discarded at the end of the model generation. +template +absl::Status InitializeAccumulator( + const AbstractModel& src, const internal::QuickScorerExtendedModel& dst, + internal::QuickScorerExtendedModel::BuildingAccumulator* accumulator) { + for (const auto& feature : dst.features().fixed_length_features()) { + const auto& feature_spec = src.data_spec().columns(feature.spec_idx); + + switch (feature.type) { + case ColumnType::CATEGORICAL: { + // Note: Initially, the bitmap is initially filled with 1s i.e. no leaf + // is filtered. + auto& feature_acc = + accumulator->categorical_contains_conditions[feature.spec_idx]; + feature_acc.internal_feature_idx = feature.internal_idx; + feature_acc.items.assign( + src.NumTrees() * + feature_spec.categorical().number_of_unique_values(), + ~internal::QuickScorerExtendedModel::kZeroLeafMask); + } break; + + case ColumnType::NUMERICAL: + case ColumnType::DISCRETIZED_NUMERICAL: + case ColumnType::BOOLEAN: { + // Note: Initially, the bitmap is initially filled with 1s i.e. no leaf + // is filtered. + auto& feature_acc = accumulator->is_higher_conditions[feature.spec_idx]; + feature_acc.internal_feature_idx = feature.internal_idx; + } break; + + default: + return absl::InternalError("Unexpected feature type"); + } + } + + for (const auto& feature : dst.features().categorical_set_features()) { + const auto& feature_spec = src.data_spec().columns(feature.spec_idx); + if (feature.type == ColumnType::CATEGORICAL_SET) { + auto& feature_acc = + accumulator->categoricalset_contains_conditions[feature.spec_idx]; + feature_acc.internal_feature_idx = feature.internal_idx; + feature_acc.masks.resize( + feature_spec.categorical().number_of_unique_values() + 1); + } else { + return absl::InternalError("Unexpected feature type"); + } + } + + return absl::OkStatus(); +} + +// Finalize the model. To be run once all the trees have been integrated to the +// quick scorer representation with the "FillQuickScorer" method. +absl::Status FinalizeModel( + const internal::QuickScorerExtendedModel::BuildingAccumulator& accumulator, + internal::QuickScorerExtendedModel* dst) { + // Copy the conditions from the accumulator index to the optimized model. + + // For "is_higher" conditions. + for (const auto& it_is_higher_condition : accumulator.is_higher_conditions) { + dst->is_higher_conditions.push_back(it_is_higher_condition.second); + auto& condition = dst->is_higher_conditions.back(); + // Sort in increasing threshold value. + std::sort( + condition.items.begin(), condition.items.end(), + [](const auto& a, const auto& b) { return a.threshold < b.threshold; }); + } + // Sort the condition by increasing feature index (for better locality when + // querying the examples). + std::sort(dst->is_higher_conditions.begin(), dst->is_higher_conditions.end(), + [](const auto& a, const auto& b) { + return a.internal_feature_idx < b.internal_feature_idx; + }); + + // For dense "contains" conditions. + for (const auto& it_contains_condition : + accumulator.categorical_contains_conditions) { + dst->categorical_contains_conditions.push_back( + it_contains_condition.second); + } + + // For sparse "contains" conditions. + for (const auto& it_contains_condition : + accumulator.categoricalset_contains_conditions) { + internal::QuickScorerExtendedModel::SparseContainsConditions condition; + condition.internal_feature_idx = + it_contains_condition.second.internal_feature_idx; + const auto& src_masks = it_contains_condition.second.masks; + condition.value_to_mask_range.reserve(src_masks.size()); + for (const auto& mask : src_masks) { + condition.value_to_mask_range.emplace_back(); + condition.value_to_mask_range.back().first = condition.mask_buffer.size(); + for (const auto& tree_mask : mask) { + if (tree_mask.second == + ~internal::QuickScorerExtendedModel::kZeroLeafMask) { + continue; + } + condition.mask_buffer.push_back(tree_mask); + } + condition.value_to_mask_range.back().second = + condition.mask_buffer.size(); + } + dst->categoricalset_contains_conditions.push_back(std::move(condition)); + } + + return absl::OkStatus(); +} + +// Adds the content of a node (and its children i.e. recursive visit) to the +// quick scorer tree structure. +template +absl::Status FillQuickScorerNode( + const AbstractModel& src, + const internal::QuickScorerExtendedModel::TreeIdx tree_idx, + const NodeWithChildren& src_node, internal::QuickScorerExtendedModel* dst, + int* leaf_idx, int* non_leaf_idx, + internal::QuickScorerExtendedModel::BuildingAccumulator* accumulator) { + if (src_node.IsLeaf()) { + // Store the lead value. + if (*leaf_idx >= internal::QuickScorerExtendedModel::kMaxLeafs) { + return absl::InternalError("Leaf idx too large"); + } + if (*leaf_idx >= dst->max_num_leafs_per_tree) { + return absl::InternalError("Leaf idx too large"); + } + const auto leaf_value_idx = + *leaf_idx + tree_idx * dst->max_num_leafs_per_tree; + if (leaf_value_idx >= dst->leaf_values.size()) { + return absl::InternalError("Leaf value idx too large"); + } + dst->leaf_values[leaf_value_idx] = src_node.node().regressor().top_value(); + (*leaf_idx)++; + } else { + // Index of the first leaf in the negative branch. + const auto begin_neg_leaf_idx = *leaf_idx; + + // Parse the negative branch. + RETURN_IF_ERROR(FillQuickScorerNode(src, tree_idx, *src_node.neg_child(), + dst, leaf_idx, non_leaf_idx, + accumulator)); + + // Index of the feature used by the node. + const int spec_feature_idx = src_node.node().condition().attribute(); + + // Compute the bitmap mask i.e. the bitmap that hide the leafs of the + // negative branch. + // + // Example: + // If begin_neg_leaf_idx=2 and end_neg_leaf_idx = 5, the mask will be: + // "1100011111" + 54 * "1" (lower bit on the left). + const auto end_neg_leaf_idx = *leaf_idx; + const auto start_leaf_mask = + (internal::QuickScorerExtendedModel::kOneLeafMask + << begin_neg_leaf_idx) - + 1; + const auto after_neg_mask = + (internal::QuickScorerExtendedModel::kOneLeafMask << end_neg_leaf_idx) - + 1; + internal::QuickScorerExtendedModel::LeafMask mask = + ~(after_neg_mask ^ start_leaf_mask); + + const auto& condition = src_node.node().condition().condition(); + // Branch to take is case of missing value. Can be ignored in the case of + // numerical and categorical features as the use "feature_missing_values" + // produce an equivalent (but more efficient) behavior. + const bool na_value = src_node.node().condition().na_value(); + const auto& attribute_spec = + src.data_spec().columns(src_node.node().condition().attribute()); + + auto set_numerical_higher = [&]() { + const auto threshold = condition.higher_condition().threshold(); + accumulator->is_higher_conditions[spec_feature_idx].items.push_back( + {/*.threshold =*/threshold, /*.tree_idx =*/tree_idx, + /*.leaf_mask =*/mask}); + }; + + auto set_boolean_is_true = [&]() { + accumulator->is_higher_conditions[spec_feature_idx].items.push_back( + {/*.threshold =*/0.5f, /*.tree_idx =*/tree_idx, + /*.leaf_mask =*/mask}); + }; + + auto set_discretized_numerical_higher = [&]() { + const auto discretized_threshold = + condition.discretized_higher_condition().threshold(); + const float threshold = attribute_spec.discretized_numerical().boundaries( + discretized_threshold - 1); + accumulator->is_higher_conditions[spec_feature_idx].items.push_back( + {/*.threshold = */ threshold, /*.tree_idx =*/tree_idx, + /*.leaf_mask =*/mask}); + }; + + auto set_categorical_contains = [&]() { + const auto elements = condition.contains_condition().elements(); + for (const auto feature_value : elements) { + accumulator->categorical_contains_conditions[spec_feature_idx] + .items[tree_idx + feature_value * dst->num_trees] &= mask; + } + }; + + auto set_categorical_bitmap_contains = [&]() { + const auto bitmap = + condition.contains_bitmap_condition().elements_bitmap(); + const int num_unique_values = + attribute_spec.categorical().number_of_unique_values(); + for (int feature_value = 0; feature_value < num_unique_values; + ++feature_value) { + if (utils::bitmap::GetValueBit(bitmap, feature_value)) { + accumulator->categorical_contains_conditions[spec_feature_idx] + .items[tree_idx + feature_value * dst->num_trees] &= mask; + } + } + }; + + auto set_categoricalset_contains = [&]() { + const auto elements = condition.contains_condition().elements(); + if (na_value) { + internal::AndMaskMap( + tree_idx, mask, + &accumulator->categoricalset_contains_conditions[spec_feature_idx] + .masks[0]); + } + for (const auto feature_value : elements) { + internal::AndMaskMap( + tree_idx, mask, + &accumulator->categoricalset_contains_conditions[spec_feature_idx] + .masks[feature_value + 1]); + } + }; + + auto set_categoricalset_bitmap_contains = [&]() { + if (na_value) { + internal::AndMaskMap( + tree_idx, mask, + &accumulator->categoricalset_contains_conditions[spec_feature_idx] + .masks[0]); + } + const auto bitmap = + condition.contains_bitmap_condition().elements_bitmap(); + const int num_unique_values = + attribute_spec.categorical().number_of_unique_values(); + for (int feature_value = 0; feature_value < num_unique_values; + ++feature_value) { + if (utils::bitmap::GetValueBit(bitmap, feature_value)) { + internal::AndMaskMap( + tree_idx, mask, + &accumulator->categoricalset_contains_conditions[spec_feature_idx] + .masks[feature_value + 1]); + } + } + }; + + // Process the node's condition. + switch (condition.type_case()) { + case Condition::TypeCase::kHigherCondition: + DCHECK_EQ(attribute_spec.type(), ColumnType::NUMERICAL); + set_numerical_higher(); + break; + + case Condition::TypeCase::kDiscretizedHigherCondition: + DCHECK_EQ(attribute_spec.type(), ColumnType::DISCRETIZED_NUMERICAL); + set_discretized_numerical_higher(); + break; + + case Condition::TypeCase::kTrueValueCondition: + DCHECK_EQ(attribute_spec.type(), ColumnType::BOOLEAN); + set_boolean_is_true(); + break; + + case Condition::TypeCase::kContainsCondition: + if (attribute_spec.type() == ColumnType::CATEGORICAL) { + set_categorical_contains(); + } else if (attribute_spec.type() == ColumnType::CATEGORICAL_SET) { + set_categoricalset_contains(); + } else { + return absl::InternalError("Unexpected type"); + } + break; + + case Condition::TypeCase::kContainsBitmapCondition: + if (attribute_spec.type() == ColumnType::CATEGORICAL) { + set_categorical_bitmap_contains(); + } else if (attribute_spec.type() == ColumnType::CATEGORICAL_SET) { + set_categoricalset_bitmap_contains(); + } else { + return absl::InternalError("Unexpected type"); + } + break; + + default: + return absl::InvalidArgumentError("Unsupported condition type."); + } + + ++(*non_leaf_idx); + + RETURN_IF_ERROR(FillQuickScorerNode(src, tree_idx, *src_node.pos_child(), + dst, leaf_idx, non_leaf_idx, + accumulator)); + } + return absl::OkStatus(); +} + +// Adds the content of the tree structures to the quick scorer structure. +template +absl::Status FillQuickScorer( + const AbstractModel& src, internal::QuickScorerExtendedModel* dst, + internal::QuickScorerExtendedModel::BuildingAccumulator* accumulator) { + RETURN_IF_ERROR(InitializeAccumulator(src, *dst, accumulator)); + + dst->initial_prediction = src.initial_predictions()[0]; + dst->num_trees = src.NumTrees(); + if (dst->num_trees > internal::QuickScorerExtendedModel::kMaxTrees) { + return absl::InvalidArgumentError( + absl::Substitute("The model contains trees with more than $0 trees", + internal::QuickScorerExtendedModel::kMaxTrees)); + } + + // Get the maximum number of leafs per trees. + dst->max_num_leafs_per_tree = 0; + int num_leafs = 0; + for (const auto& src_tree : src.decision_trees()) { + const auto num_leafs_in_tree = src_tree->NumLeafs(); + num_leafs += num_leafs_in_tree; + if (num_leafs_in_tree > dst->max_num_leafs_per_tree) { + dst->max_num_leafs_per_tree = num_leafs_in_tree; + } + } + + if (dst->max_num_leafs_per_tree > + internal::QuickScorerExtendedModel::kMaxLeafs) { + return absl::InvalidArgumentError( + absl::Substitute("The model contains trees with more than $0 leafs", + internal::QuickScorerExtendedModel::kMaxLeafs)); + } + + dst->leaf_values.assign(dst->max_num_leafs_per_tree * dst->num_trees, 0.f); + + for (internal::QuickScorerExtendedModel::TreeIdx tree_idx = 0; + tree_idx < src.decision_trees().size(); ++tree_idx) { + const auto& src_tree = src.decision_trees()[tree_idx]; + int leaf_idx = 0; + int non_leaf_idx = 0; + RETURN_IF_ERROR(FillQuickScorerNode(src, tree_idx, src_tree->root(), dst, + &leaf_idx, &non_leaf_idx, accumulator)); + } + + RETURN_IF_ERROR(FinalizeModel(*accumulator, dst)); + return absl::OkStatus(); +} + +// Tree inference without SIMD i.e. one example at a time. +// This method is used for the examples outside of the SIMD batch. +// +// "active_leaf_buffer" is a pre-allocated buffer of at least "num-trees" +// elements. +template +void PredictQuickScorerSequential( + const Model& model, + const std::vector& fixed_length_features, + const std::vector& categorical_set_begins_and_ends, + const std::vector& categorical_item_buffer, + const int begin_example_idx, const int end_example_idx, + const int major_feature_offset, std::vector* predictions, + internal::QuickScorerExtendedModel::LeafMask* active_leaf_buffer) { + const size_t active_leaf_buffer_size = model.num_trees * sizeof(LeafMask); + + const auto index = [&major_feature_offset](const int feature_idx, + const int example_idx) -> int { + return feature_idx * major_feature_offset + example_idx; + }; + + for (int example_idx = begin_example_idx; example_idx < end_example_idx; + ++example_idx) { + // Reset active node buffer. + std::memset(active_leaf_buffer, 0xFF, active_leaf_buffer_size); + + // Is higher conditions. + for (const auto& is_higher_condition : model.is_higher_conditions) { + const auto feature_value = + fixed_length_features[index(is_higher_condition.internal_feature_idx, + example_idx)] + .numerical_value; + + for (const auto& item : is_higher_condition.items) { + if (item.threshold > feature_value) { + break; + } + active_leaf_buffer[item.tree_idx] &= item.leaf_mask; + } + } + + // Dense contains conditions. + for (const auto& contains_condition : + model.categorical_contains_conditions) { + const auto feature_value = + fixed_length_features[index(contains_condition.internal_feature_idx, + example_idx)] + .categorical_value; + DCHECK_LE(model.num_trees * (feature_value + 1), + contains_condition.items.size()); + const auto* leaf_mask_stream = + &contains_condition.items[model.num_trees * feature_value]; + for (int tree_idx = 0; tree_idx < model.num_trees; ++tree_idx) { + active_leaf_buffer[tree_idx] &= *(leaf_mask_stream++); + } + } + + // Sparse contains conditions. + for (const auto& contains_condition : + model.categoricalset_contains_conditions) { + const auto& range_values = categorical_set_begins_and_ends + [contains_condition.internal_feature_idx * major_feature_offset + + example_idx]; + for (int value_idx = range_values.begin; value_idx < range_values.end; + value_idx++) { + const auto value = categorical_item_buffer[value_idx] + 1; + const auto& range_masks = contains_condition.value_to_mask_range[value]; + for (int mask_idx = range_masks.first; mask_idx < range_masks.second; + mask_idx++) { + const auto& mask = contains_condition.mask_buffer[mask_idx]; + active_leaf_buffer[mask.first] &= mask.second; + } + } + } + + // Get the active leaf. + auto* leaf_reader = &model.leaf_values[0]; + float output = model.initial_prediction; + for (int tree_idx = 0; tree_idx < model.num_trees; ++tree_idx) { + const auto shift_mask = active_leaf_buffer[tree_idx]; + const auto node_idx = FindLSBSetNonZero64(shift_mask); + output += leaf_reader[node_idx]; + leaf_reader += model.max_num_leafs_per_tree; + } + + (*predictions)[example_idx] = Activation(output); + } +} + +} // namespace + +// Apply the quick scorer algorithm. +// +// The examples are represented in the arguments "fixed_length_features", +// "categorical_item_buffer" and "categorical_set_begins_and_ands". These fields +// are made to be contained in the "ExampleSet" class. Refer to this class for +// their definition. +// +// "major_feature_offset" is the number of elements in between features blocks +// i.e. the j-th features of the i-th examples is "index = i + j * +// major_feature_offset". +// +template +void PredictQuickScorerMajorFeatureOffset( + const Model& model, + const std::vector& fixed_length_features, + const std::vector& categorical_set_begins_and_ends, + const std::vector& categorical_item_buffer, const int num_examples, + const int major_feature_offset, std::vector* predictions) { + utils::usage::OnInference(num_examples); + predictions->resize(num_examples); + + // "kNumParallelExamples" examples are treated in parallel using SIMD + // instructions. If the number of examples is not a multiple of + // "kNumParallelExamples", the remaining examples are treated with + // "PredictQuickScorerSequential". + constexpr int kNumParallelExamples = 4; + + const size_t active_leaf_buffer_size = + model.num_trees * kNumParallelExamples * sizeof(LeafMask); + const size_t alignment = 32 * 8; + + // Make sure the allocated chunk of memory is a multiple of "alignment". + size_t rounded_up_active_leaf_buffer_size = active_leaf_buffer_size; + if ((rounded_up_active_leaf_buffer_size % alignment) != 0) { + rounded_up_active_leaf_buffer_size += + alignment - rounded_up_active_leaf_buffer_size % alignment; + } + + // Note: Alloca was measured to be faster and more consistent (in terms of + // speed) than malloc or pre-allocated caches. + // + // The buffer must be aligned on a 32-byte boundary to work with _mm256 + // class of SIMD instructions (intrinsics). + LeafMask* active_leaf_buffer; + const bool active_leaf_buffer_uses_stack = + active_leaf_buffer_size <= kMaxStackUsageInBytes; + + if (active_leaf_buffer_uses_stack) { +#ifdef __AVX2__ + +#if defined(_WIN32) + void* non_aligned = alloca(rounded_up_active_leaf_buffer_size + alignment); + std::size_t space = rounded_up_active_leaf_buffer_size + alignment; + void* aligned = std::align(alignment, 1, non_aligned, space); +#else + void* aligned = __builtin_alloca_with_align( + rounded_up_active_leaf_buffer_size, alignment); +#endif + active_leaf_buffer = reinterpret_cast(aligned); + +#else + active_leaf_buffer = + reinterpret_cast(alloca(rounded_up_active_leaf_buffer_size)); +#endif + } else { +#ifdef __AVX2__ + active_leaf_buffer = reinterpret_cast( + portable::aligned_alloc(alignment, rounded_up_active_leaf_buffer_size)); +#else + active_leaf_buffer = reinterpret_cast( + std::malloc(rounded_up_active_leaf_buffer_size)); +#endif + } + + int example_idx = 0; + +#ifdef __AVX2__ + if (model.cpu_supports_avx2) { + auto* sample_reader = &fixed_length_features[0]; + float* prediction_reader = &(*predictions)[0]; + + // First run on sub-batches of kNumParallelExamples at a time. The + // remaining will be done sequentially below. + int num_remaining_iters = num_examples / kNumParallelExamples; + while (num_remaining_iters--) { + // Reset active node buffer. + std::memset(active_leaf_buffer, 0xFF, active_leaf_buffer_size); + + // Is higher conditions. + for (const auto& is_higher_condition : model.is_higher_conditions) { + const float* begin_example = + &sample_reader[0].numerical_value + + is_higher_condition.internal_feature_idx * major_feature_offset; + + const auto feature_values = _mm_loadu_ps(begin_example); + for (const auto& item : is_higher_condition.items) { + const auto threshold = _mm_set1_ps(item.threshold); + + const auto comparison = + _mm_castps_si128(_mm_cmpge_ps(feature_values, threshold)); + // Note: "comparison" is either 0x00000000 or 0xFFFFFFFF depending on + // the node condition value. + if (!_mm_test_all_zeros(comparison, comparison)) { + // The mask attached to the condition i.e. the mask to apply on the + // active node bitmap iif. the condition is true. + const auto mask = _mm256_set1_epi64x(item.leaf_mask); + auto* active_si256 = reinterpret_cast<__m256i*>( + &active_leaf_buffer[item.tree_idx * kNumParallelExamples]); + const auto active = _mm256_load_si256(active_si256); + + // Expand the comparison to 8 bytes. + const auto pd_comparison = _mm256_cvtepi32_epi64(comparison); + const auto mask_update = _mm256_andnot_si256(mask, pd_comparison); + const auto new_active = _mm256_andnot_si256(mask_update, active); + // new_active = (mask v not comparison) ^ active + // is equivalent to: + // new_active = not (not mask ^ comparison) ^ active + + _mm256_store_si256(active_si256, new_active); + } else { + break; + } + } + } + + // Dense contains conditions. + for (int sub_example_idx = 0; sub_example_idx < kNumParallelExamples; + ++sub_example_idx) { + for (const auto& contains_condition : + model.categorical_contains_conditions) { + const auto feature_value = + sample_reader[contains_condition.internal_feature_idx * + major_feature_offset + + sub_example_idx] + .categorical_value; + const auto* leaf_mask_stream = + &contains_condition.items[model.num_trees * feature_value]; + for (int tree_idx = 0; tree_idx < model.num_trees; ++tree_idx) { + active_leaf_buffer[tree_idx * kNumParallelExamples + + sub_example_idx] &= *(leaf_mask_stream++); + } + } + } + + // Sparse contains conditions. + for (int sub_example_idx = 0; sub_example_idx < kNumParallelExamples; + ++sub_example_idx) { + for (const auto& contains_condition : + model.categoricalset_contains_conditions) { + const auto& range_values = categorical_set_begins_and_ends + [contains_condition.internal_feature_idx * major_feature_offset + + sub_example_idx + example_idx]; + for (int value_idx = range_values.begin; value_idx < range_values.end; + value_idx++) { + const auto value = categorical_item_buffer[value_idx] + 1; + const auto& range_masks = + contains_condition.value_to_mask_range[value]; + for (int mask_idx = range_masks.first; + mask_idx < range_masks.second; mask_idx++) { + const auto& mask = contains_condition.mask_buffer[mask_idx]; + active_leaf_buffer[mask.first * kNumParallelExamples + + sub_example_idx] &= mask.second; + } + } + } + } + +#pragma loop unroll(full) + for (int sub_example_idx = 0; sub_example_idx < kNumParallelExamples; + ++sub_example_idx) { + prediction_reader[sub_example_idx] = model.initial_prediction; + } + + auto* leaf_reader = &model.leaf_values[0]; + for (int tree_idx = 0; tree_idx < model.num_trees; ++tree_idx) { +#pragma loop unroll(full) + for (int sub_example_idx = 0; sub_example_idx < kNumParallelExamples; + ++sub_example_idx) { + const auto shift_mask = + active_leaf_buffer[tree_idx * kNumParallelExamples + + sub_example_idx]; + const auto node_idx = FindLSBSetNonZero64(shift_mask); + prediction_reader[sub_example_idx] += leaf_reader[node_idx]; + } + leaf_reader += model.max_num_leafs_per_tree; + } + +// Note: The compiler should be able to remove the following loop when +// Activation == Identity. Tested with gcc9 and clang9. +#pragma loop unroll(full) + for (int sub_example_idx = 0; sub_example_idx < kNumParallelExamples; + ++sub_example_idx) { + prediction_reader[sub_example_idx] = + Activation(prediction_reader[sub_example_idx]); + } + + sample_reader += kNumParallelExamples; + prediction_reader += kNumParallelExamples; + example_idx += kNumParallelExamples; + } + } +#endif + + PredictQuickScorerSequential( + model, fixed_length_features, categorical_set_begins_and_ends, + categorical_item_buffer, example_idx, num_examples, major_feature_offset, + predictions, active_leaf_buffer); + + if (!active_leaf_buffer_uses_stack) { +#ifdef __AVX2__ + portable::aligned_free(active_leaf_buffer); +#else + free(active_leaf_buffer); +#endif + } +} + +template +void PredictQuickScorer( + const Model& model, + const std::vector& examples, int num_examples, + std::vector* predictions) { + PredictQuickScorerMajorFeatureOffset(model, examples, {}, {}, num_examples, + num_examples, predictions); +} + +// Version of Predict compatible with the ExampleSet signature. +template +void Predict(const Model& model, const typename Model::ExampleSet& examples, + int num_examples, std::vector* predictions) { + PredictQuickScorerMajorFeatureOffset( + model, examples.InternalCategoricalAndNumericalValues(), + examples.InternalCategoricalSetBeginAndEnds(), + examples.InternalCategoricalItemBuffer(), num_examples, + examples.NumberOfExamples(), predictions); +} + +template void +PredictQuickScorer( + const GradientBoostedTreesRegressionQuickScorerExtended& model, + const std::vector& examples, + const int num_examples, std::vector* predictions); + +template void Predict( + const GradientBoostedTreesRegressionQuickScorerExtended& model, + const GradientBoostedTreesRegressionQuickScorerExtended::ExampleSet& + examples, + const int num_examples, std::vector* predictions); + +template void Predict( + const GradientBoostedTreesRankingQuickScorerExtended& model, + const GradientBoostedTreesRankingQuickScorerExtended::ExampleSet& examples, + const int num_examples, std::vector* predictions); + +template <> +void Predict( + const GradientBoostedTreesBinaryClassificationQuickScorerExtended& model, + const GradientBoostedTreesBinaryClassificationQuickScorerExtended:: + ExampleSet& examples, + const int num_examples, std::vector* predictions) { + PredictQuickScorerMajorFeatureOffset< + GradientBoostedTreesBinaryClassificationQuickScorerExtended, + ActivationBinomialLogLikelihood>( + model, examples.InternalCategoricalAndNumericalValues(), + examples.InternalCategoricalSetBeginAndEnds(), + examples.InternalCategoricalItemBuffer(), num_examples, + examples.NumberOfExamples(), predictions); +} + +template +absl::Status BaseGenericToSpecializedModel(const AbstractModel& src, + CompiledModel* dst) { +#ifdef __AVX2__ +#if ABSL_HAVE_BUILTIN(__builtin_cpu_supports) + dst->cpu_supports_avx2 = __builtin_cpu_supports("avx2"); +#else + // We cannot detect if the CPU supports AVX2 instructions. If it does not, + // a fatal error will be raised. + dst->cpu_supports_avx2 = true; +#endif +#elif ABSL_HAVE_BUILTIN(__builtin_cpu_supports) + if (__builtin_cpu_supports("avx2")) { + LOG(INFO) << "The binary was compiled without AVX2 support, but your CPU" + "supports it. Enable it for faster model inference."; + } +#endif + + if (src.task() != CompiledModel::kTask) { + return absl::InvalidArgumentError("Wrong model class."); + } + + typename CompiledModel::BuildingAccumulator accumulator; + + // List the model input features. + std::vector all_input_features; + RETURN_IF_ERROR(GetInputFeatures(src, &all_input_features, nullptr)); + + RETURN_IF_ERROR( + dst->mutable_features()->Initialize(all_input_features, src.data_spec())); + + // Compile the model. + RETURN_IF_ERROR(FillQuickScorer(src, dst, &accumulator)); + + return absl::OkStatus(); +} + +template <> +absl::Status GenericToSpecializedModel( + const model::gradient_boosted_trees::GradientBoostedTreesModel& src, + GradientBoostedTreesRegressionQuickScorerExtended* dst) { + if (src.loss() != Loss::SQUARED_ERROR) { + return absl::InvalidArgumentError( + "The GBDT is not trained for regression with squared error loss."); + } + return BaseGenericToSpecializedModel(src, dst); +} + +template <> +absl::Status GenericToSpecializedModel( + const model::gradient_boosted_trees::GradientBoostedTreesModel& src, + GradientBoostedTreesRankingQuickScorerExtended* dst) { + if (src.loss() != Loss::LAMBDA_MART_NDCG5 && + src.loss() != Loss::XE_NDCG_MART) { + return absl::InvalidArgumentError( + "The GBDT is not trained for ranking with ranking loss."); + } + return BaseGenericToSpecializedModel(src, dst); +} + +template <> +absl::Status GenericToSpecializedModel( + const model::gradient_boosted_trees::GradientBoostedTreesModel& src, + GradientBoostedTreesBinaryClassificationQuickScorerExtended* dst) { + if (src.loss() != Loss::BINOMIAL_LOG_LIKELIHOOD || + src.initial_predictions().size() != 1) { + return absl::InvalidArgumentError( + "The GBDT is not trained for binary classification with binomial log " + "likelihood loss."); + } + return BaseGenericToSpecializedModel(src, dst); +} + +template +absl::Status CreateEmptyModel(const std::vector& input_features, + const DataSpecification& dataspec, + CompiledModel* dst) { + return dst->mutable_features()->Initialize(input_features, dataspec); +} + +template absl::Status +CreateEmptyModel( + const std::vector& input_features, const DataSpecification& dataspec, + GradientBoostedTreesRegressionQuickScorerExtended* dst); + +template +std::string DescribeQuickScorer(const Model& model, const bool detailed) { + std::string structure; + + // Global data. + absl::SubstituteAndAppend(&structure, + "Maximum number of leafs per trees: $0\n", + model.max_num_leafs_per_tree); + absl::SubstituteAndAppend(&structure, "Number of trees: $0\n", + model.num_trees); + absl::SubstituteAndAppend(&structure, "Initial prediction: $0\n", + model.initial_prediction); + + // List of input features. + absl::StrAppend(&structure, "Features (and missing replacement value):\n"); + for (const auto& feature : model.features().fixed_length_features()) { + absl::SubstituteAndAppend(&structure, "\t$0 [$1]", feature.name, + dataset::proto::ColumnType_Name(feature.type)); + switch (feature.type) { + case ColumnType::NUMERICAL: + case ColumnType::DISCRETIZED_NUMERICAL: + absl::SubstituteAndAppend( + &structure, "($0)\n", + model.features() + .fixed_length_na_replacement_values()[feature.internal_idx] + .numerical_value); + break; + case ColumnType::CATEGORICAL: + absl::SubstituteAndAppend( + &structure, "($0)\n", + model.features() + .fixed_length_na_replacement_values()[feature.internal_idx] + .categorical_value); + break; + default: + absl::StrAppend(&structure, "\n"); + break; + } + } + for (const auto& feature : model.features().categorical_set_features()) { + absl::SubstituteAndAppend(&structure, "\t$0 [CATEGORICAL_SET] (none)\n", + feature.name); + } + absl::StrAppend(&structure, "\n"); + + // Leafs. + absl::SubstituteAndAppend(&structure, "Output leaf values ($0):\n", + model.leaf_values.size()); + if (detailed) { + int leaf_idx = 0; + for (const auto& leaf_value : model.leaf_values) { + absl::SubstituteAndAppend(&structure, " $0", leaf_value); + ++leaf_idx; + } + absl::StrAppend(&structure, "\n\n"); + } + + // Condition "contains" for categorical features. + absl::SubstituteAndAppend(&structure, + "Conditions [categorical contains] ($0):\n", + model.categorical_contains_conditions.size()); + for (const auto& item : model.categorical_contains_conditions) { + absl::SubstituteAndAppend( + &structure, "\tfeature: $0 ($1) (num=$2)\n", item.internal_feature_idx, + model.features() + .fixed_length_features()[item.internal_feature_idx] + .name, + item.items.size()); + if (detailed) { + for (int item_idx = 0; item_idx < item.items.size(); ++item_idx) { + const auto bitmap_representation = ToStringBit( + std::string( + reinterpret_cast(&item.items[item_idx]), + sizeof(LeafMask)), + internal::QuickScorerExtendedModel::kMaxLeafs); + absl::SubstituteAndAppend( + &structure, "\t\ttree:$0 value:$1 mask : $2\n", + item_idx % model.num_trees, item_idx / model.num_trees, + bitmap_representation); + } + } + } + absl::StrAppend(&structure, "\n"); + + // Condition "contains" for categoricalset features. + absl::SubstituteAndAppend(&structure, + "Conditions [categorical set contains] ($0):\n", + model.categoricalset_contains_conditions.size()); + for (const auto& item : model.categoricalset_contains_conditions) { + absl::SubstituteAndAppend( + &structure, + "\tfeature: $0 ($1) (ranges=$2 masks=$3 mask/range=$4/$5)\n", + item.internal_feature_idx, + model.features() + .categorical_set_features()[item.internal_feature_idx] + .name, + item.value_to_mask_range.size(), item.mask_buffer.size(), + static_cast(item.mask_buffer.size()) / + item.value_to_mask_range.size(), + model.num_trees); + if (detailed) { + for (int value = 0; value < item.value_to_mask_range.size(); value++) { + absl::SubstituteAndAppend(&structure, "\tValue: $0:\n", value); + const auto& range = item.value_to_mask_range[value]; + for (int mask_idx = range.first; mask_idx < range.second; mask_idx++) { + const auto& mask = item.mask_buffer[mask_idx]; + const auto bitmap_representation = ToStringBit( + std::string(reinterpret_cast(&mask.second), + sizeof(LeafMask)), + internal::QuickScorerExtendedModel::kMaxLeafs); + absl::SubstituteAndAppend(&structure, "\t\ttree:$0 mask : $1\n", + mask.first, bitmap_representation); + } + } + } + } + absl::StrAppend(&structure, "\n"); + + // Conditions "is higher". + absl::SubstituteAndAppend(&structure, "Conditions [is_higher] ($0):\n", + model.is_higher_conditions.size()); + for (const auto& item : model.is_higher_conditions) { + int num_duplicates = 0; + for (int sub_item_idx = 0; sub_item_idx < item.items.size() - 1; + ++sub_item_idx) { + if (item.items[sub_item_idx].threshold == + item.items[sub_item_idx + 1].threshold) { + ++num_duplicates; + } + } + float duplicate_ratio = -1.f; + if (!item.items.empty()) { + duplicate_ratio = static_cast(num_duplicates) / item.items.size(); + } + + absl::SubstituteAndAppend( + &structure, "\tfeature: $0 ($1) (num=$2; duplicate=$3)\n", + item.internal_feature_idx, + model.features() + .fixed_length_features()[item.internal_feature_idx] + .name, + item.items.size(), duplicate_ratio); + if (detailed) { + for (const auto& sub_item : item.items) { + const auto bitmap_representation = ToStringBit( + std::string( + reinterpret_cast(&sub_item.leaf_mask), + sizeof(LeafMask)), + internal::QuickScorerExtendedModel::kMaxLeafs); + absl::SubstituteAndAppend(&structure, + "\t\tmask:$0 = $1 thre:$2 tree:$3\n", + sub_item.leaf_mask, bitmap_representation, + sub_item.threshold, sub_item.tree_idx); + } + } + } + absl::StrAppend(&structure, "\n"); + + return structure; +} + +template std::string +DescribeQuickScorer( + const GradientBoostedTreesRegressionQuickScorerExtended& model, + bool detailed); + +template std::string DescribeQuickScorer< + GradientBoostedTreesBinaryClassificationQuickScorerExtended>( + const GradientBoostedTreesBinaryClassificationQuickScorerExtended& model, + bool detailed); + +} // namespace decision_forest +} // namespace serving +} // namespace yggdrasil_decision_forests diff --git a/yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.h b/yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.h new file mode 100644 index 00000000..1fdfc5f8 --- /dev/null +++ b/yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.h @@ -0,0 +1,329 @@ +/* + * Copyright 2021 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// QuickScorer is inference algorithm for decision trees. +// This implementation extends the QuickScorer algorithms to categorical and +// categorical-set features. +// +// The central idea is to run the model inference per feature and per condition, +// instead of running it per tree and per node. While more complex in +// appearance, the algorithm generates less CPU branch misses and should overall +// run faster on modern CPUs. +// +// At its code, the algorithm manages a bitmap over the model leaves (called +// "active leaf bitmap" in the code). Each condition is attached to a mask +// (called "mask" in the code) of the same size which is applied (conjunction) +// over the leaves if the condition is true. After all the condition have been +// applied, the "first" active leaf (i.e. the leaf corresponding to the first +// non-zero bitmap value) is returned. +// +// The SIMD instructions are used to process multiple examples at the sametime. +// +// Important: This library works faster if AVX2 is enabled at computation: +// Add "--copt=-mavx2" to the build call. +// Add "requirements = {constraints = cpu_features.require(['avx2'])}" to your +// borgcfg. +// Add a "tricorder > builder > copt: '-mavx2'", in your METADATA for your +// forge tests. +// +// Note: Adding 'copts = ["-mavx2"],' to your binary configuration wont work as +// it will only be used to compile your main target. +// +// The current implementation support: +// - Regressive GBDTs. +// +// With the following constraints: +// - Maximum of 65k trees. +// - Maximum of 128 nodes per trees (e.g. max depth = 6). +// - Maximum of 65k unique input features. +// - Support categorical and numerical features. +// +// Unlike the other simpleML optimized inference engine, categorical are not +// restricted to have a maximum of 32 unique values. +// +// The algorithm are described in the following papers: +// +// Original paper: +// http://ecmlpkdd2017.ijs.si/papers/paperID718.pdf +// Extension with SMID instructions: +// http://pages.di.unipi.it/rossano/wp-content/uploads/sites/7/2016/07/SIGIR16a.pdf +// Categorical-Set Extension: +// https://arxiv.org/abs/2009.09991 +// +// The code can be benchmarked using the following command: +// +// bazel run -c opt --copt=-mavx2 :benchmark_nogpu -- \ +// --alsologtostderr \ +// --mtc=GRADIENT_BOOSTED_TREES_REGRESSION_NUMERICAL_AND_CATEGORICAL_32_ATTRIBUTES +// +// On 9.10.2019, this command returned the following results: +// +// ... +// Average serving time per example: +// gbdt_regression_num_and_cat32_attributes_single_thread_cpu_quick_scorer : +// 647.5ns gbdt_regression_num_and_cat32_attributes_single_thread_cpu_opt_v1 +// : 1.3857us gbdt_regression_num_and_cat32_attributes_single_thread_cpu +// : 1.4537us +// ... +// +// In the code "feature_idx" refers to features indexed according to the +// "dataspec" (i.e. the training dataset). "internal_feature_idx" refers to a +// internal feature indices from the point of view of the model (i.e. where the +// features used by the model are densely packaged). +// +// See the "Intrinsics Guide" for a definition of the used SIMD instructions +// (https://software.intel.com/sites/landingpage/IntrinsicsGuide/). + +#ifndef YGGDRASIL_DECISION_FORESTS_SERVING_DECISION_FOREST_QUICK_SCORER_EXTENDED_H_ +#define YGGDRASIL_DECISION_FORESTS_SERVING_DECISION_FOREST_QUICK_SCORER_EXTENDED_H_ + +#include + +#include + +#include "absl/status/status.h" +#include "yggdrasil_decision_forests/serving/decision_forest/utils.h" +#include "yggdrasil_decision_forests/serving/example_set.h" + +namespace yggdrasil_decision_forests { +namespace serving { +namespace decision_forest { + +namespace internal { + +// Base model representation compatible with the QuickScorer algorithm. +struct QuickScorerExtendedModel { + using ExampleSet = + ExampleSetNumericalOrCategoricalFlat; + // Backward compatibility. + using ValueType = NumericalOrCategoricalValue; + + // Definition of the input features of the model, and how they are represented + // in an input example set. + const ExampleSet::FeaturesDefinition& features() const { + return intern_features; + } + + ExampleSet::FeaturesDefinition* mutable_features() { + return &intern_features; + } + + ExampleSet::FeaturesDefinition intern_features; + + // Note: The following four fields will be integrated as template parameters + // in a future cl. + // Index of a tree in the model. Limits the number of trees of the model. + using TreeIdx = uint32_t; + // Bitmap over the leafs in a tree. Limits the number of leafs in a tree. + using LeafMask = uint64_t; + // The value of a leaf. + using LeafOutput = float; + + // Helper values. + static constexpr LeafMask kZeroLeafMask = static_cast(0); + static constexpr LeafMask kOneLeafMask = static_cast(1); + + // Maximum number of trees and number of leafs per trees. + static constexpr size_t kMaxTrees = std::numeric_limits::max(); + static constexpr size_t kMaxLeafs = sizeof(LeafMask) * 8; + + // Maximum number of leafs in each tree. + int max_num_leafs_per_tree; + + // Value (i.e. prediction) of each leaf. + // "leaf_values[i + j * max_num_leafs_per_tree]" is the value of the "i-th" + // leaf in the "j-th" tree. + std::vector leaf_values; + + // Number of trees in the model. + int num_trees; + + // Initial prediction / bias of the model. + float initial_prediction = 0.f; + +#ifdef __AVX2__ + // This flag is set during the compilation of the model and indicates if the + // CPU supports AVX2 instructions + bool cpu_supports_avx2 = true; +#endif + + // Data for "IsHigher" conditions i.e. condition of the form "feature >= t". + struct IsHigherConditionItem { + float threshold; + TreeIdx tree_idx; + LeafMask leaf_mask; + }; + + struct IsHigherConditions { + // Index of the feature in "model.features". + // See the definition of "internal_feature_idx" in the head comment. + int internal_feature_idx; + + // Thresholds ordered in ascending order. + std::vector items; + }; + + // Data for "Contains" conditions i.e. condition of the form "feature \in + // set". + struct ContainsConditions { + // Internal index of the feature. + int internal_feature_idx; + + // "Contains" type condition for each feature value. + // items[tree_idx + feature_value * num_trees] is the mask to apply on tree + // "tree_idx" when the feature value is "feature_value". + std::vector items; + }; + + // Similar to "ContainsConditions", but only index the trees impacted by each + // feature value. + struct SparseContainsConditions { + // Internal index of the feature. + int internal_feature_idx; + + // The "i-th" feature value maps to the masks "mask_buffer[j]" for "j" in + // "[value_to_mask_range[i).first, value_to_mask_range[i].second[". + std::vector> value_to_mask_range; + std::vector> mask_buffer; + }; + + std::vector is_higher_conditions; + std::vector categorical_contains_conditions; + std::vector categoricalset_contains_conditions; + + // Structure used during the compilation of the model and discarded at the + // end. + struct BuildingAccumulator { + struct SparseContainsConditions { + // Internal index of the feature. + int internal_feature_idx; + + // "masks[i][j]" is the mask for the "i-th" feature value on the "j-th" + // tree; + std::vector> masks; + }; + + // Similar to the fields of the same name above, but indexed by the dataspec + // feature index. + // + // Note: Absl hash map does not check at compile time the availability of + // AVX2. + std::unordered_map is_higher_conditions; + std::unordered_map categorical_contains_conditions; + std::unordered_map + categoricalset_contains_conditions; + }; +}; + +// ANDs a "mask" on a value contained in a map (specified by a key) i.e. +// "map[key] &= mask". If the map does not contains the key, set it to the +// "mask" value i.e. "map[key] = mask". +template +void AndMaskMap(const typename Map::key_type& key, + QuickScorerExtendedModel::LeafMask mask, Map* map) { + const auto insertion = map->insert({key, mask}); + if (!insertion.second) { + insertion.first->second &= mask; + } +} + +} // namespace internal + +// Specialization of quick scorer for GBDT regression model. +struct GradientBoostedTreesRegressionQuickScorerExtended + : internal::QuickScorerExtendedModel { + static constexpr model::proto::Task kTask = model::proto::Task::REGRESSION; +}; + +// Specialization of quick scorer for GBDT binary classification model. +struct GradientBoostedTreesBinaryClassificationQuickScorerExtended + : internal::QuickScorerExtendedModel { + static constexpr model::proto::Task kTask = + model::proto::Task::CLASSIFICATION; +}; + +// Specialization of quick scorer for GBDT ranking model. +struct GradientBoostedTreesRankingQuickScorerExtended + : internal::QuickScorerExtendedModel { + static constexpr model::proto::Task kTask = model::proto::Task::RANKING; +}; + +// Computes the model's prediction on a batch of examples. +// +// This method is thread safe. +// +// This method uses a significant amount of stack size. See +// "GenericToSpecializedModel" for more details. +// +// Args: +// - model: A quick scorer model (e.g. +// GradientBoostedTreesRegressionQuickScorer) initialized with +// "GenericToSpecializedModel". +// - examples: A batch of examples. The examples are stored FEATURE-WISE. +// - num_examples: Number of examples in the batch. +// - predictions: Output predictions. Does not need to be pre-allocated. +// + +template +void PredictQuickScorer( + const Model& model, + const std::vector& examples, int num_examples, + std::vector* predictions); + +// Version of PredictQuickScorer compatible with the ExampleSet signature. +template +void Predict(const Model& model, const typename Model::ExampleSet& examples, + int num_examples, std::vector* predictions); + +// Converts a generic GradientBoostedTreesModel with regression loss into a +// quick scorer compatible model. +// +// This method checks that the model inference (i.e. PredictQuickScorer) won't +// take more than 16kb of stack size. The stack size usage is +// defined by the number of trees and number of leafs of the model. For +// reference, the ML AOI model is taking 1.5kb of stack size. If your model is +// too large, contact us (gbm@) for the heap version of this method (~10% +// slower) or use the inference code in "decision_forest.h" (>2x slower, no +// model limit). +template +absl::Status GenericToSpecializedModel(const AbstractModel& src, + CompiledModel* dst); + +// Creates an empty model that returns a constant value (e.g. 0 for regression) +// but which consumes (and ignores) the input features specified at +// construction. +// +// This function can be used to create fake models to unit test the generation +// of ExampleSets. +template +absl::Status CreateEmptyModel(const std::vector& input_features, + const DataSpecification& dataspec, + CompiledModel* dst); + +// Generates a human readable text describing the internal of the quick scorer +// model. +// +// This description is intended for debugging or optimization purpose. For a ML +// development intended description of the model, use the "describe" method on +// the non-compiled model. +template +std::string DescribeQuickScorer(const Model& model, bool detailed = true); + +} // namespace decision_forest +} // namespace serving +} // namespace yggdrasil_decision_forests + +#endif // YGGDRASIL_DECISION_FORESTS_SERVING_DECISION_FOREST_QUICK_SCORER_EXTENDED_H_ diff --git a/yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended_test.cc b/yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended_test.cc new file mode 100644 index 00000000..9fe61509 --- /dev/null +++ b/yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended_test.cc @@ -0,0 +1,396 @@ +/* + * Copyright 2021 Google LLC. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "yggdrasil_decision_forests/model/decision_tree/decision_tree.h" +#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h" +#include "yggdrasil_decision_forests/utils/test.h" + +#include "yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.h" + +namespace yggdrasil_decision_forests { +namespace serving { +namespace decision_forest { +namespace { + +using model::decision_tree::DecisionTree; +using model::decision_tree::NodeWithChildren; +using model::decision_tree::proto::Condition; +using model::gradient_boosted_trees::GradientBoostedTreesModel; +using model::gradient_boosted_trees::proto::Loss; +using testing::ElementsAre; + +void BuildToyModelAndToyDataset(const model::proto::Task task, + const bool use_cateset_feature, + GradientBoostedTreesModel* model, + dataset::VerticalDataset* dataset, + const int duplicate_factor = 1) { + dataset::proto::DataSpecification dataspec = PARSE_TEST_PROTO(R"pb( + columns { type: NUMERICAL name: "a" } + columns { type: NUMERICAL name: "b" } + columns { + type: CATEGORICAL + name: "c" + categorical { is_already_integerized: true number_of_unique_values: 4 } + } + columns { + type: DISCRETIZED_NUMERICAL + name: "e" + numerical { mean: -1 } + discretized_numerical { + boundaries: 0.0 + boundaries: 0.1 + boundaries: 0.2 + boundaries: 0.3 + } + } + )pb"); + + if (use_cateset_feature) { + auto* d = dataspec.add_columns(); + d->set_name("d"); + d->set_type(dataset::proto::ColumnType::CATEGORICAL_SET); + d->mutable_categorical()->set_is_already_integerized(false); + d->mutable_categorical()->set_number_of_unique_values(4); + auto& items = *d->mutable_categorical()->mutable_items(); + items["v0"].set_index(0); + items["v1"].set_index(1); + items["v2"].set_index(2); + items["v3"].set_index(3); + } + + dataset->set_data_spec(dataspec); + CHECK_OK(dataset->CreateColumnsFromDataspec()); + + if (use_cateset_feature) { + dataset->AppendExample( + {{"a", "0.5"}, {"b", "0.5"}, {"c", "1"}, {"e", "1.5"}, {"d", "2 3"}}); + } else { + dataset->AppendExample( + {{"a", "0.5"}, {"b", "0.5"}, {"c", "1"}, {"e", "1.5"}}); + } + + struct NodeHelper { + Condition* condition; + NodeWithChildren* pos; + NodeWithChildren* neg; + NodeWithChildren* node; + }; + + const auto split_node = [](NodeWithChildren* node, + const int attribute) -> NodeHelper { + node->CreateChildren(); + node->mutable_node()->mutable_condition()->set_attribute(attribute); + return {/*.condition =*/ + node->mutable_node()->mutable_condition()->mutable_condition(), + /*.pos =*/node->mutable_pos_child(), + /*.neg =*/node->mutable_neg_child(), + /*.node =*/node}; + }; + + model->set_task(task); + model->set_label_col_idx(0); + model->set_data_spec(dataspec); + model->set_loss(Loss::SQUARED_ERROR); + model->mutable_initial_predictions()->push_back(duplicate_factor); + + for (int duplication_idx = 0; duplication_idx < duplicate_factor; + duplication_idx++) { + { + auto tree = absl::make_unique(); + tree->CreateRoot(); + auto n1 = split_node(tree->mutable_root(), 1); + n1.condition->mutable_higher_condition()->set_threshold(2.0f); + + auto n2 = split_node(n1.pos, 1); + n2.condition->mutable_higher_condition()->set_threshold(3.0f); + n2.pos->mutable_node()->mutable_regressor()->set_top_value(1.f); + n2.neg->mutable_node()->mutable_regressor()->set_top_value(2.f); + + auto n3 = split_node(n1.neg, 1); + n3.condition->mutable_higher_condition()->set_threshold(1.0f); + n3.pos->mutable_node()->mutable_regressor()->set_top_value(3.f); + n3.neg->mutable_node()->mutable_regressor()->set_top_value(4.f); + + model->mutable_decision_trees()->push_back(std::move(tree)); + } + + { + auto tree = absl::make_unique(); + tree->CreateRoot(); + auto n1 = split_node(tree->mutable_root(), 2); + n1.condition->mutable_contains_condition()->add_elements(1); + n1.condition->mutable_contains_condition()->add_elements(2); + + auto n2 = split_node(n1.pos, 1); + n2.condition->mutable_higher_condition()->set_threshold(2.5f); + n2.pos->mutable_node()->mutable_regressor()->set_top_value(10.f); + n2.neg->mutable_node()->mutable_regressor()->set_top_value(20.f); + + auto n3 = split_node(n1.neg, 1); + n3.condition->mutable_higher_condition()->set_threshold(1.5f); + n3.pos->mutable_node()->mutable_regressor()->set_top_value(30.f); + n3.neg->mutable_node()->mutable_regressor()->set_top_value(40.f); + + model->mutable_decision_trees()->push_back(std::move(tree)); + } + + { + auto tree = absl::make_unique(); + tree->CreateRoot(); + auto n1 = split_node(tree->mutable_root(), 1); + n1.condition->mutable_higher_condition()->set_threshold(10.0f); + + auto n2 = split_node(n1.pos, 2); + n2.condition->mutable_contains_bitmap_condition()->set_elements_bitmap( + "\x02"); // [1] + n2.pos->mutable_node()->mutable_regressor()->set_top_value(100.f); + n2.neg->mutable_node()->mutable_regressor()->set_top_value(200.f); + + auto n3 = split_node(n1.neg, 2); + n3.condition->mutable_contains_bitmap_condition()->set_elements_bitmap( + "\x0A"); // [1,3] + n3.pos->mutable_node()->mutable_regressor()->set_top_value(300.f); + n3.neg->mutable_node()->mutable_regressor()->set_top_value(400.f); + + model->mutable_decision_trees()->push_back(std::move(tree)); + } + + { + auto tree = absl::make_unique(); + tree->CreateRoot(); + auto n1 = split_node(tree->mutable_root(), 3); + n1.condition->mutable_discretized_higher_condition()->set_threshold( + 2); // value>=0.1f + n1.neg->mutable_node()->mutable_regressor()->set_top_value(10000.f); + n1.pos->mutable_node()->mutable_regressor()->set_top_value(20000.f); + model->mutable_decision_trees()->push_back(std::move(tree)); + } + + if (use_cateset_feature) { + auto tree = absl::make_unique(); + tree->CreateRoot(); + auto n1 = split_node(tree->mutable_root(), 4); + n1.condition->mutable_contains_condition()->add_elements(2); + n1.condition->mutable_contains_condition()->add_elements(3); + n1.node->mutable_node()->mutable_condition()->set_na_value(true); + + n1.neg->mutable_node()->mutable_regressor()->set_top_value(1000.f); + + auto n2 = split_node(n1.pos, 4); + n2.condition->mutable_contains_bitmap_condition()->set_elements_bitmap( + "\x08"); // [3] + n2.neg->mutable_node()->mutable_regressor()->set_top_value(2000.f); + n2.pos->mutable_node()->mutable_regressor()->set_top_value(3000.f); + n2.node->mutable_node()->mutable_condition()->set_na_value(false); + model->mutable_decision_trees()->push_back(std::move(tree)); + } + } +} + +TEST(QuickScorer, Compilation) { + GradientBoostedTreesModel model; + dataset::VerticalDataset dataset; + BuildToyModelAndToyDataset(model::proto::Task::REGRESSION, + /*use_cateset_feature=*/false, &model, &dataset); + GradientBoostedTreesRegressionQuickScorerExtended quick_scorer_model; + CHECK_OK(GenericToSpecializedModel(model, &quick_scorer_model)); + + const auto model_description = DescribeQuickScorer(quick_scorer_model); + LOG(INFO) << "Model:\n" << model_description; + + EXPECT_EQ(quick_scorer_model.features().input_features().size(), 3); + + // Examples in FORMAT_FEATURE_MAJOR, see decision_forest.h. + using V = NumericalOrCategoricalValue; + std::vector examples = { + // Feature 1 + V::Numerical(0.5f), + V::Numerical(1.0f), + V::Numerical(1.5f), + V::Numerical(2.5f), + V::Numerical(3.5f), + // Feature 2 + V::Categorical(0), + V::Categorical(1), + V::Categorical(2), + V::Categorical(0), + V::Categorical(1), + // Feature 3 + V::Numerical(0.00f), + V::Numerical(0.05f), + V::Numerical(0.10f), + V::Numerical(0.20f), + V::Numerical(0.30f), + }; + std::vector predictions; + PredictQuickScorer(quick_scorer_model, examples, 5, &predictions); + + EXPECT_THAT(predictions, + ElementsAre(1 + 4 + 40 + 400 + 10000, 1 + 3 + 20 + 300 + 10000, + 1 + 3 + 20 + 400 + 20000, 1 + 2 + 30 + 400 + 20000, + 1 + 1 + 10 + 300 + 20000)); +} + +TEST(QuickScorer, ExampleSet) { + GradientBoostedTreesModel model; + dataset::VerticalDataset dataset; + BuildToyModelAndToyDataset(model::proto::Task::REGRESSION, + /*use_cateset_feature=*/true, &model, &dataset); + GradientBoostedTreesRegressionQuickScorerExtended quick_scorer_model; + CHECK_OK(GenericToSpecializedModel(model, &quick_scorer_model)); + + const auto model_description = DescribeQuickScorer(quick_scorer_model); + LOG(INFO) << "Model:\n" << model_description; + + GradientBoostedTreesRegressionQuickScorerExtended::ExampleSet examples( + 5, quick_scorer_model); + examples.FillMissing(quick_scorer_model); + + EXPECT_EQ(quick_scorer_model.features().input_features().size(), 4); + + const auto feature_1 = + GradientBoostedTreesRegressionQuickScorerExtended::ExampleSet:: + GetNumericalFeatureId("b", quick_scorer_model) + .value(); + const auto feature_2 = + GradientBoostedTreesRegressionQuickScorerExtended::ExampleSet:: + GetCategoricalFeatureId("c", quick_scorer_model) + .value(); + const auto feature_3 = + GradientBoostedTreesRegressionQuickScorerExtended::ExampleSet:: + GetCategoricalSetFeatureId("d", quick_scorer_model) + .value(); + const auto feature_4 = + GradientBoostedTreesRegressionQuickScorerExtended::ExampleSet:: + GetNumericalFeatureId("e", quick_scorer_model) + .value(); + + examples.SetNumerical(0, feature_1, 0.5f, quick_scorer_model); + examples.SetNumerical(1, feature_1, 1.0f, quick_scorer_model); + examples.SetNumerical(2, feature_1, 1.5f, quick_scorer_model); + examples.SetNumerical(3, feature_1, 2.5f, quick_scorer_model); + examples.SetNumerical(4, feature_1, 3.5f, quick_scorer_model); + + examples.SetCategorical(0, feature_2, 0, quick_scorer_model); + examples.SetCategorical(1, feature_2, 1, quick_scorer_model); + examples.SetCategorical(2, feature_2, 2, quick_scorer_model); + examples.SetCategorical(3, feature_2, 0, quick_scorer_model); + examples.SetCategorical(4, feature_2, 1, quick_scorer_model); + + examples.SetCategoricalSet(0, feature_3, {"v1"}, quick_scorer_model); + examples.SetCategoricalSet(1, feature_3, {"v2"}, quick_scorer_model); + examples.SetCategoricalSet(2, feature_3, {"v3"}, quick_scorer_model); + examples.SetCategoricalSet(3, feature_3, std::vector{"v2", "v3"}, + quick_scorer_model); + examples.SetMissingCategoricalSet(4, feature_3, quick_scorer_model); + + examples.SetNumerical(0, feature_4, 0.00f, quick_scorer_model); + examples.SetNumerical(1, feature_4, 0.05f, quick_scorer_model); + examples.SetNumerical(2, feature_4, 0.10f, quick_scorer_model); + examples.SetNumerical(3, feature_4, 0.20f, quick_scorer_model); + examples.SetNumerical(4, feature_4, 0.30f, quick_scorer_model); + + std::vector predictions; + Predict(quick_scorer_model, examples, 5, &predictions); + + EXPECT_THAT(predictions, ElementsAre(1 + 4 + 40 + 400 + 1000 + 10000, + 1 + 3 + 20 + 300 + 2000 + 10000, + 1 + 3 + 20 + 400 + 3000 + 20000, + 1 + 2 + 30 + 400 + 3000 + 20000, + 1 + 1 + 10 + 300 + 2000 + 20000)); +} + +TEST(QuickScorer, ExceedStackBuffer) { + const int duplicate_factor = 200; + + GradientBoostedTreesModel model; + + dataset::VerticalDataset dataset; + BuildToyModelAndToyDataset(model::proto::Task::REGRESSION, + /*use_cateset_feature=*/true, &model, &dataset, + duplicate_factor); + GradientBoostedTreesRegressionQuickScorerExtended quick_scorer_model; + CHECK_OK(GenericToSpecializedModel(model, &quick_scorer_model)); + + const auto model_description = DescribeQuickScorer(quick_scorer_model); + LOG(INFO) << "Model:\n" << model_description; + + GradientBoostedTreesRegressionQuickScorerExtended::ExampleSet examples( + 5, quick_scorer_model); + examples.FillMissing(quick_scorer_model); + + EXPECT_EQ(quick_scorer_model.features().input_features().size(), 4); + + const auto feature_1 = + GradientBoostedTreesRegressionQuickScorerExtended::ExampleSet:: + GetNumericalFeatureId("b", quick_scorer_model) + .value(); + const auto feature_2 = + GradientBoostedTreesRegressionQuickScorerExtended::ExampleSet:: + GetCategoricalFeatureId("c", quick_scorer_model) + .value(); + const auto feature_3 = + GradientBoostedTreesRegressionQuickScorerExtended::ExampleSet:: + GetCategoricalSetFeatureId("d", quick_scorer_model) + .value(); + + const auto feature_4 = + GradientBoostedTreesRegressionQuickScorerExtended::ExampleSet:: + GetNumericalFeatureId("e", quick_scorer_model) + .value(); + + examples.SetNumerical(0, feature_1, 0.5f, quick_scorer_model); + examples.SetNumerical(1, feature_1, 1.0f, quick_scorer_model); + examples.SetNumerical(2, feature_1, 1.5f, quick_scorer_model); + examples.SetNumerical(3, feature_1, 2.5f, quick_scorer_model); + examples.SetNumerical(4, feature_1, 3.5f, quick_scorer_model); + + examples.SetCategorical(0, feature_2, 0, quick_scorer_model); + examples.SetCategorical(1, feature_2, 1, quick_scorer_model); + examples.SetCategorical(2, feature_2, 2, quick_scorer_model); + examples.SetCategorical(3, feature_2, 0, quick_scorer_model); + examples.SetCategorical(4, feature_2, 1, quick_scorer_model); + + examples.SetCategoricalSet(0, feature_3, {"v1"}, quick_scorer_model); + examples.SetCategoricalSet(1, feature_3, {"v2"}, quick_scorer_model); + examples.SetCategoricalSet(2, feature_3, {"v3"}, quick_scorer_model); + examples.SetCategoricalSet(3, feature_3, std::vector{"v2", "v3"}, + quick_scorer_model); + examples.SetMissingCategoricalSet(4, feature_3, quick_scorer_model); + + examples.SetNumerical(0, feature_4, 0.00f, quick_scorer_model); + examples.SetNumerical(1, feature_4, 0.05f, quick_scorer_model); + examples.SetNumerical(2, feature_4, 0.10f, quick_scorer_model); + examples.SetNumerical(3, feature_4, 0.20f, quick_scorer_model); + examples.SetNumerical(4, feature_4, 0.30f, quick_scorer_model); + + std::vector predictions; + Predict(quick_scorer_model, examples, 5, &predictions); + + EXPECT_THAT( + predictions, + ElementsAre((1 + 4 + 40 + 400 + 1000 + 10000) * duplicate_factor, + (1 + 3 + 20 + 300 + 2000 + 10000) * duplicate_factor, + (1 + 3 + 20 + 400 + 3000 + 20000) * duplicate_factor, + (1 + 2 + 30 + 400 + 3000 + 20000) * duplicate_factor, + (1 + 1 + 10 + 300 + 2000 + 20000) * duplicate_factor)); +} + +} // namespace +} // namespace decision_forest +} // namespace serving +} // namespace yggdrasil_decision_forests diff --git a/yggdrasil_decision_forests/serving/decision_forest/register_engines.cc b/yggdrasil_decision_forests/serving/decision_forest/register_engines.cc index 6b888c3e..1eece911 100644 --- a/yggdrasil_decision_forests/serving/decision_forest/register_engines.cc +++ b/yggdrasil_decision_forests/serving/decision_forest/register_engines.cc @@ -24,7 +24,6 @@ #include "yggdrasil_decision_forests/serving/example_set_model_wrapper.h" #include "yggdrasil_decision_forests/utils/compatibility.h" - namespace yggdrasil_decision_forests { namespace model { diff --git a/yggdrasil_decision_forests/serving/decision_forest/register_engines.h b/yggdrasil_decision_forests/serving/decision_forest/register_engines.h index 279d012a..a820db6c 100644 --- a/yggdrasil_decision_forests/serving/decision_forest/register_engines.h +++ b/yggdrasil_decision_forests/serving/decision_forest/register_engines.h @@ -26,10 +26,14 @@ namespace serving { namespace gradient_boosted_trees { constexpr char kGeneric[] = "GradientBoostedTreesGeneric"; +constexpr char kQuickScorerExtended[] = + "GradientBoostedTreesQuickScorerExtended"; +constexpr char kOptPred[] = "GradientBoostedTreesOptPred"; } // namespace gradient_boosted_trees namespace random_forest { constexpr char kGeneric[] = "RandomForestGeneric"; +constexpr char kOptPred[] = "RandomForestOptPred"; } // namespace random_forest } // namespace serving diff --git a/yggdrasil_decision_forests/serving/example_set.h b/yggdrasil_decision_forests/serving/example_set.h index 1efc8d14..64243fe7 100644 --- a/yggdrasil_decision_forests/serving/example_set.h +++ b/yggdrasil_decision_forests/serving/example_set.h @@ -25,14 +25,14 @@ // // // Initialize. // std::unique_ptr abstract_model = ...; -// GradientBoostedTreesBinaryClassificationNumericalAndCategorical model; +// GradientBoostedTreesBinaryClassificationQuickScorerExtended model; // GenericToSpecializedModel(abstract_model, &model); // auto feature_1 = model.GetNumericalFeatureId("feature_1"); // auto feature_2 = model.GetCategoricalFeatureId("feature_2"); // auto feature_3 = model.GetNumericalFeatureId("feature_3"); // // // Allocate 5 examples. -// GradientBoostedTreesBinaryClassificationNumericalAndCategorical::ExampleSet +// GradientBoostedTreesBinaryClassificationQuickScorerExtended::ExampleSet // examples(5); examples.FillMissing(model); // // // Set one examples and run the model.