Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 374448559
  • Loading branch information
achoum committed May 18, 2021
1 parent e133b45 commit fe034a5
Show file tree
Hide file tree
Showing 13 changed files with 1,870 additions and 7 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## 0.1.2 - 2021-05-18

### Features

- Inference engines: QuickScorer Extended and Pred

## 0.1.1 - 2021-05-17

### Features
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ Download one of the build, and then run `examples/beginner.{sh,bat}`.

Target | Version | Link
------- | ------- | ----
Linux | 0.1.1 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.1/cli_linux.zip)
Windows | 0.1.1 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.1/cli_windows.zip)
Linux | 0.1.0 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.0/cli_linux.zip)
Windows | 0.1.0 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.0/cli_windows.zip)

## Installation from Source

Expand Down
5 changes: 3 additions & 2 deletions yggdrasil_decision_forests/cli/benchmark_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@
// batch_size : 100 num_runs : 20
// time/example(µs) time/batch(µs) method
// ----------------------------------------
// interface] 9.179 917.9 GradientBoostedTreesGeneric [virtual
// interface] 21.547 2154.8 Generic slow engine
// 0.79025 79.025 GradientBoostedTreesQuickScorerExtended
// 9.179 917.9 GradientBoostedTreesGeneric
// 21.547 2154.8 Generic slow engine
// ----------------------------------------
//
#include "absl/flags/flag.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,28 @@ TEST_F(GradientBoostedTreesOnAdult, RandomCategorical) {
EXPECT_TRUE(gbt_model->IsMissingValueConditionResultFollowGlobalImputation());
}

// Train and test a model on the adult dataset with too much nodes for the
// QuickScorer serving algorithm.
TEST_F(GradientBoostedTreesOnAdult, BaseNoQuickScorer) {
auto* gbt_config = train_config_.MutableExtension(
gradient_boosted_trees::proto::gradient_boosted_trees_config);
gbt_config->set_num_trees(100);
gbt_config->mutable_decision_tree()->set_max_depth(10);
gbt_config->set_shrinkage(0.1f);
gbt_config->set_subsample(0.9f);
TrainAndEvaluateModel();

// Note: Accuracy is similar as RF (see :random_forest_test). However logloss
// is significantly better (which is expected as, unlike RF, GBT is
// calibrated).
EXPECT_NEAR(metric::Accuracy(evaluation_), 0.8549, 0.015);
EXPECT_NEAR(metric::LogLoss(evaluation_), 0.320, 0.04);

auto* gbt_model =
dynamic_cast<const GradientBoostedTreesModel*>(model_.get());
EXPECT_TRUE(gbt_model->IsMissingValueConditionResultFollowGlobalImputation());
}

// Train and test a model on the adult dataset.
TEST_F(GradientBoostedTreesOnAdult, BaseConcurrentDeprecated) {
auto* gbt_config = train_config_.MutableExtension(
Expand Down
36 changes: 36 additions & 0 deletions yggdrasil_decision_forests/serving/decision_forest/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ cc_library(
],
deps = [
":decision_forest",
":quick_scorer_extended",
"//yggdrasil_decision_forests/model:abstract_model",
"//yggdrasil_decision_forests/model/gradient_boosted_trees",
"//yggdrasil_decision_forests/serving:example_set_model_wrapper",
Expand Down Expand Up @@ -49,6 +50,28 @@ cc_library(
],
)

cc_library(
name = "quick_scorer_extended",
srcs = [
"quick_scorer_extended.cc",
],
hdrs = [
"quick_scorer_extended.h",
],
deps = [
":utils",
"@com_google_absl//absl/base:config",
"@com_google_absl//absl/status",
"//yggdrasil_decision_forests/model/decision_tree",
"//yggdrasil_decision_forests/model/gradient_boosted_trees",
"//yggdrasil_decision_forests/model/gradient_boosted_trees:gradient_boosted_trees_cc_proto",
"//yggdrasil_decision_forests/serving:example_set",
"//yggdrasil_decision_forests/utils:bitmap",
"//yggdrasil_decision_forests/utils:compatibility",
"//yggdrasil_decision_forests/utils:usage",
],
)

cc_library(
name = "utils",
srcs = [
Expand Down Expand Up @@ -84,6 +107,7 @@ cc_test(
shard_count = 10,
deps = [
":decision_forest",
":quick_scorer_extended",
":register_engines",
"@com_google_googletest//:gtest_main",
"@com_google_absl//absl/flags:flag",
Expand All @@ -108,3 +132,15 @@ cc_test(
"//yggdrasil_decision_forests/utils:test_utils",
],
)

cc_test(
name = "quick_scorer_extended_test",
srcs = ["quick_scorer_extended_test.cc"],
deps = [
":quick_scorer_extended",
"@com_google_googletest//:gtest_main",
"//yggdrasil_decision_forests/model/decision_tree",
"//yggdrasil_decision_forests/model/gradient_boosted_trees",
"//yggdrasil_decision_forests/utils:test",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ struct FlatNodeModel {
using NodeType = Node;
using FeaturesDefinition = FeaturesDefinitionNumericalOrCategoricalFlat;

using ExampleSet =
ExampleSetNumericalOrCategoricalFlat<FlatNodeModel<Node, Value>,
ExampleFormat::FORMAT_EXAMPLE_MAJOR>;

const FeaturesDefinition& features() const { return internal_features; }

FeaturesDefinition* mutable_features() { return &internal_features; }
Expand Down Expand Up @@ -555,6 +559,14 @@ void Predict(const GradientBoostedTreesRankingNumericalAndCategorical& model,
const std::vector<NumericalOrCategoricalValue>& examples,
int num_examples, std::vector<float>* predictions);

template <typename Model>
void PredictWithExampleSet(const Model& model,
const typename Model::ExampleSet& examples,
int num_examples, std::vector<float>* predictions) {
Predict(model, examples.InternalCategoricalAndNumericalValues(), num_examples,
predictions);
}

// Note: Requires for the number of trees to be a multiple of 8.
void PredictOptimizedV1(
const RandomForestBinaryClassificationNumericalFeatures& model,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.pb.h"
#include "yggdrasil_decision_forests/model/model_library.h"
#include "yggdrasil_decision_forests/model/prediction.pb.h"
#include "yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.h"
#include "yggdrasil_decision_forests/utils/csv.h"
#include "yggdrasil_decision_forests/utils/distribution.pb.h"
#include "yggdrasil_decision_forests/utils/filesystem.h"
Expand Down
Loading

0 comments on commit fe034a5

Please sign in to comment.