Skip to content

Commit

Permalink
Add a model compiler for Numerical Ranking models.
Browse files Browse the repository at this point in the history
With the model compiler, YDF can generate pure C++ headers containing a model for inference with minimal binary impact.

PiperOrigin-RevId: 516481180
  • Loading branch information
rstz authored and copybara-github committed Mar 14, 2023
1 parent 3e177db commit 615c643
Show file tree
Hide file tree
Showing 16 changed files with 2,028 additions and 0 deletions.
16 changes: 16 additions & 0 deletions examples/model_compiler/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
load("//yggdrasil_decision_forests/utils:compile.bzl", "cc_library_ydf")

package(
default_visibility = ["//visibility:public"],
licenses = ["notice"],
)

# An example of a model generated by the model compiler.
cc_library_ydf(
name = "generated_model",
hdrs = ["generated_model.h"],
deps = [
"//yggdrasil_decision_forests/serving/decision_forest:decision_forest_serving",
"@com_google_absl//absl/strings",
],
)
1,288 changes: 1,288 additions & 0 deletions examples/model_compiler/generated_model.h

Large diffs are not rendered by default.

61 changes: 61 additions & 0 deletions yggdrasil_decision_forests/cli/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,18 @@ cc_binary_ydf(
],
)

cc_binary_ydf(
name = "compile_model",
srcs = ["compile_model.cc"],
deps = [
"//yggdrasil_decision_forests/serving/decision_forest:model_compiler",
"//yggdrasil_decision_forests/utils:logging",
"//yggdrasil_decision_forests/utils:status_macros",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status:statusor",
],
)

# Tests
# =====

Expand All @@ -225,3 +237,52 @@ sh_test(
"//yggdrasil_decision_forests/test_data",
],
)

genrule(
name = "compile_model_for_test",
testonly = 1,
srcs = [
":compile_model.cc",
"//yggdrasil_decision_forests/test_data",
],
outs = ["generated_model.h"],
cmd = """
TESTDATA_BASEDIR=$$(dirname $$(echo $(locations //yggdrasil_decision_forests/test_data) | cut -d ' ' -f1))/model/synthetic_ranking_gbdt_numerical/
$(location :compile_model) --model $$TESTDATA_BASEDIR --namespace test_model > $@
""",
exec_tools = [":compile_model"],
)

cc_library_ydf(
name = "compiled_model_for_test",
testonly = 1,
hdrs = ["generated_model.h"],
deps = [
"//yggdrasil_decision_forests/model:abstract_model_cc_proto",
"//yggdrasil_decision_forests/serving/decision_forest:decision_forest_serving",
"//yggdrasil_decision_forests/utils:protobuf",
"@com_google_absl//absl/strings",
],
)

cc_test(
name = "compile_model_test",
srcs = ["compile_model_test.cc"],
data = [
"//yggdrasil_decision_forests/test_data",
],
deps = [
":compiled_model_for_test",
"//yggdrasil_decision_forests/dataset:all_dataset_formats",
"//yggdrasil_decision_forests/dataset:vertical_dataset_io",
"//yggdrasil_decision_forests/model:abstract_model",
"//yggdrasil_decision_forests/model:model_library",
"//yggdrasil_decision_forests/model/gradient_boosted_trees",
"//yggdrasil_decision_forests/serving/decision_forest",
"//yggdrasil_decision_forests/utils:filesystem",
"//yggdrasil_decision_forests/utils:logging",
"//yggdrasil_decision_forests/utils:test",
"//yggdrasil_decision_forests/utils:testing_macros",
"@com_google_googletest//:gtest_main",
],
)
65 changes: 65 additions & 0 deletions yggdrasil_decision_forests/cli/compile_model.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2022 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Compile a YDF model into a small C++ header for very efficient serving and
// print this .h model on the standard output.
//
// This tool converts a YDF model into a C++ header that can be served with
// minimal dependencies. Doing so is only necessary for very specific use cases
// where binary size is very important. The majority of users should use the
// classic path, see predict.cc for details.
//
// An example of model header generated with this tool is available at
// examples/model_compiler/generated_model.h
//
// Supported models:
// - Ranking GBT with numerical only splits

#include <iostream>

#include "absl/flags/flag.h"
#include "absl/status/statusor.h"
#include "yggdrasil_decision_forests/serving/decision_forest/model_compiler.h"
#include "yggdrasil_decision_forests/utils/logging.h"
#include "yggdrasil_decision_forests/utils/status_macros.h"

ABSL_FLAG(std::string, model, "", "Model directory (required).");

ABSL_FLAG(std::string, namespace, "",
"Innermost namespace for the model (required), e.g. my_model. The "
"model will then be available as under "
"yggdrasil_decision_forests::compiled_model::my_model::GetModel()");

constexpr char kUsageMessage[] = "Compile a model into a C++ include.";
namespace yggdrasil_decision_forests {
namespace cli {
absl::StatusOr<std::string> CompileModel() {
// Check required flags.
STATUS_CHECK(!absl::GetFlag(FLAGS_model).empty());
STATUS_CHECK(!absl::GetFlag(FLAGS_namespace).empty());

return serving::decision_forest::CompileRankingNumericalOnly(
absl::GetFlag(FLAGS_model), absl::GetFlag(FLAGS_namespace));
}
} // namespace cli
} // namespace yggdrasil_decision_forests

int main(int argc, char** argv) {
InitLogging(kUsageMessage, &argc, &argv, true);
const auto model_file = yggdrasil_decision_forests::cli::CompileModel();
QCHECK_OK(model_file.status());
std::cout << model_file.value();
return 0;
}
139 changes: 139 additions & 0 deletions yggdrasil_decision_forests/cli/compile_model_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
* Copyright 2022 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <string>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "yggdrasil_decision_forests/cli/generated_model.h"
#include "yggdrasil_decision_forests/dataset/vertical_dataset_io.h"
#include "yggdrasil_decision_forests/model/abstract_model.h"
#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h"
#include "yggdrasil_decision_forests/model/model_library.h"
#include "yggdrasil_decision_forests/serving/decision_forest/decision_forest.h"
#include "yggdrasil_decision_forests/utils/filesystem.h"
#include "yggdrasil_decision_forests/utils/logging.h"
#include "yggdrasil_decision_forests/utils/test.h"
#include "yggdrasil_decision_forests/utils/testing_macros.h"

namespace yggdrasil_decision_forests {
namespace cli {
namespace {

std::string TestDataDir() {
return file::JoinPath(test::DataRootDirectory(),
"yggdrasil_decision_forests/test_data");
}

using test::EqualsProto;
using ::testing::ElementsAre;
using ::testing::FloatNear;
using ::testing::Pointwise;
using ::testing::SizeIs;

// Margin of error for numerical tests.
constexpr float kTestPrecision = 0.00001f;

// TODO: Add a test with a model with oblique weights.
TEST(CompileModelTest, DataSpec) {
ASSERT_OK_AND_ASSIGN(const auto compiled_model,
compiled_model::test_model::GetModel());

std::unique_ptr<model::AbstractModel> uncompiled_model;
ASSERT_OK(model::LoadModel(file::JoinPath(TestDataDir(), "model",
"synthetic_ranking_gbdt_numerical"),
&uncompiled_model));
EXPECT_THAT(compiled_model->internal_features.data_spec(),
EqualsProto(uncompiled_model->data_spec()));
}

TEST(CompileModelTest, Metadata) {
ASSERT_OK_AND_ASSIGN(const auto compiled_model,
compiled_model::test_model::GetModel());

std::unique_ptr<model::AbstractModel> uncompiled_model;
ASSERT_OK(model::LoadModel(file::JoinPath(TestDataDir(), "model",
"synthetic_ranking_gbdt_numerical"),
&uncompiled_model));
model::proto::Metadata uncompiled_model_metadata_proto;
uncompiled_model->metadata().Export(&uncompiled_model_metadata_proto);
EXPECT_THAT(compiled_model->metadata,
EqualsProto(uncompiled_model_metadata_proto));
}

TEST(CompileModelTest, GBTModelParameters) {
ASSERT_OK_AND_ASSIGN(const auto compiled_model,
compiled_model::test_model::GetModel());

std::unique_ptr<model::AbstractModel> uncompiled_model;
ASSERT_OK(model::LoadModel(file::JoinPath(TestDataDir(), "model",
"synthetic_ranking_gbdt_numerical"),
&uncompiled_model));

auto* gbt_model =
dynamic_cast<model::gradient_boosted_trees::GradientBoostedTreesModel*>(
uncompiled_model.get());
ASSERT_NE(gbt_model, nullptr);
EXPECT_THAT(compiled_model->root_offsets, SizeIs(gbt_model->NumTrees()));
EXPECT_THAT(gbt_model->initial_predictions(),
ElementsAre(compiled_model->initial_predictions));
}

TEST(CompiledModelTest, ModelPredictions) {
ASSERT_OK_AND_ASSIGN(const auto compiled_model,
compiled_model::test_model::GetModel());

std::unique_ptr<model::AbstractModel> uncompiled_model;
ASSERT_OK(model::LoadModel(file::JoinPath(TestDataDir(), "model",
"synthetic_ranking_gbdt_numerical"),
&uncompiled_model));

const auto& test_ds_path = absl::StrCat(
"csv:",
file::JoinPath(TestDataDir(), "dataset", "synthetic_ranking_test.csv"));
dataset::VerticalDataset dataset;
ASSERT_OK(dataset::LoadVerticalDataset(
test_ds_path, compiled_model->internal_features.data_spec(), &dataset));

std::vector<float> slow_engine_predictions;
slow_engine_predictions.resize(dataset.nrow());
for (dataset::VerticalDataset::row_t example_idx = 0;
example_idx < dataset.nrow(); example_idx++) {
model::proto::Prediction prediction;
uncompiled_model->Predict(dataset, example_idx, &prediction);
slow_engine_predictions[example_idx] = prediction.ranking().relevance();
}

std::vector<float> flat_examples;
auto feature_names =
FeatureNames(compiled_model->internal_features.fixed_length_features());
auto replacement_values =
compiled_model->internal_features.fixed_length_na_replacement_values();
ASSERT_OK(serving::decision_forest::LoadFlatBatchFromDataset(
dataset, 0, dataset.nrow(), feature_names, replacement_values,
&flat_examples, serving::ExampleFormat::FORMAT_EXAMPLE_MAJOR));

std::vector<float> compiled_model_predictions;
compiled_model_predictions.resize(dataset.nrow());
yggdrasil_decision_forests::serving::decision_forest::PredictOptimizedV1(
*compiled_model, flat_examples, dataset.nrow(),
&compiled_model_predictions);
EXPECT_THAT(compiled_model_predictions,
Pointwise(FloatNear(kTestPrecision), slow_engine_predictions));
}

} // namespace
} // namespace cli
} // namespace yggdrasil_decision_forests
21 changes: 21 additions & 0 deletions yggdrasil_decision_forests/serving/decision_forest/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,27 @@ package(
# Libraries
# =========

cc_library_ydf(
name = "model_compiler",
srcs = [
"model_compiler.cc",
],
hdrs = [
"model_compiler.h",
],
deps = [
":decision_forest",
"//yggdrasil_decision_forests/model:abstract_model",
"//yggdrasil_decision_forests/model:model_library",
"//yggdrasil_decision_forests/model/gradient_boosted_trees",
"//yggdrasil_decision_forests/utils:logging",
"//yggdrasil_decision_forests/utils:status_macros",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
],
)

cc_library_ydf(
name = "register_engines",
srcs = [
Expand Down
Loading

0 comments on commit 615c643

Please sign in to comment.