diff --git a/CHANGELOG.md b/CHANGELOG.md
index ddcb0638..841a0d91 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,11 @@
 # Changelog
 
+## 0.1.3 - 2021-05-19
+
+### Features
+
+- Register new inference engines.
+
 ## 0.1.2 - 2021-05-18
 
 ### Features
diff --git a/README.md b/README.md
index 93d4c52c..c07ad0e2 100644
--- a/README.md
+++ b/README.md
@@ -95,8 +95,8 @@ Download one of the build, and then run `examples/beginner.{sh,bat}`.
 
 Target | Version | Link
 ------- | ------- | ----
-Linux | 0.1.2 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.2/cli_linux.zip)
-Windows | 0.1.2 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.2/cli_windows.zip)
+Linux | 0.1.0 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.0/cli_linux.zip)
+Windows | 0.1.0 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.0/cli_windows.zip)
 
 ## Installation from Source
 
diff --git a/yggdrasil_decision_forests/serving/decision_forest/register_engines.cc b/yggdrasil_decision_forests/serving/decision_forest/register_engines.cc
index 1eece911..6c1bac68 100644
--- a/yggdrasil_decision_forests/serving/decision_forest/register_engines.cc
+++ b/yggdrasil_decision_forests/serving/decision_forest/register_engines.cc
@@ -21,6 +21,7 @@
 #include "yggdrasil_decision_forests/model/fast_engine_factory.h"
 #include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h"
 #include "yggdrasil_decision_forests/serving/decision_forest/decision_forest.h"
+#include "yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.h"
 #include "yggdrasil_decision_forests/serving/example_set_model_wrapper.h"
 #include "yggdrasil_decision_forests/utils/compatibility.h"
 
@@ -130,6 +131,212 @@ class GradientBoostedTreesGenericFastEngineFactory : public FastEngineFactory {
 REGISTER_FastEngineFactory(GradientBoostedTreesGenericFastEngineFactory,
                            serving::gradient_boosted_trees::kGeneric);
 
+// Factory for the "QuickScorer Extended" engine on gradient boosted trees.
+//
+// A model is accepted when it passes the global-imputation check, has at most
+// QuickScorerExtendedModel::kMaxTrees trees with at most kMaxLeafs leaves
+// each, and its task is binary classification (label column with 3 unique
+// values), regression or ranking. Declared better than the generic and
+// OptPred gradient boosted trees engines.
+class GradientBoostedTreesQuickScorerFastEngineFactory
+    : public FastEngineFactory {
+ public:
+  using SourceModel = gradient_boosted_trees::GradientBoostedTreesModel;
+
+  std::string name() const override {
+    return serving::gradient_boosted_trees::kQuickScorerExtended;
+  }
+
+  bool IsCompatible(const AbstractModel* const model) const override {
+    auto* gbt_model = dynamic_cast(model);
+    if (gbt_model == nullptr) {
+      return false;
+    }
+
+    if (!gbt_model->IsMissingValueConditionResultFollowGlobalImputation()) {
+      return false;
+    }
+
+    if (gbt_model->NumTrees() > serving::decision_forest::internal::
+                                    QuickScorerExtendedModel::kMaxTrees) {
+      return false;
+    }
+
+    for (const auto& src_tree : gbt_model->decision_trees()) {
+      if (src_tree->NumLeafs() > serving::decision_forest::internal::
+                                     QuickScorerExtendedModel::kMaxLeafs) {
+        return false;
+      }
+    }
+
+    switch (gbt_model->task()) {
+      case proto::CLASSIFICATION:
+        return gbt_model->label_col_spec()
+                   .categorical()
+                   .number_of_unique_values() == 3;
+      case proto::REGRESSION:
+      case proto::RANKING:
+        return true;
+      default:
+        return false;
+    }
+  }
+
+  std::vector IsBetterThan() const override {
+    return {serving::gradient_boosted_trees::kGeneric,
+            serving::gradient_boosted_trees::kOptPred};
+  }
+
+  utils::StatusOr> CreateEngine(
+      const AbstractModel* const model) const override {
+    auto* gbt_model = dynamic_cast(model);
+    if (!gbt_model) {
+      return absl::InvalidArgumentError("The model is not a GBDT.");
+    }
+
+    switch (gbt_model->task()) {
+      case proto::CLASSIFICATION:
+        if (gbt_model->label_col_spec()
+                .categorical()
+                .number_of_unique_values() == 3) {
+          // Binary classification.
+          auto engine = absl::make_unique>();
+          RETURN_IF_ERROR(engine->LoadModel(*gbt_model));
+          return engine;
+        } else {
+          return absl::InvalidArgumentError("Non supported GBDT model");
+        }
+
+      case proto::REGRESSION: {
+        auto engine = absl::make_unique>();
+        RETURN_IF_ERROR(engine->LoadModel(*gbt_model));
+        return engine;
+      }
+
+      case proto::RANKING: {
+        auto engine = absl::make_unique>();
+        RETURN_IF_ERROR(engine->LoadModel(*gbt_model));
+        return engine;
+      }
+
+      default:
+        return absl::InvalidArgumentError("Non supported GBDT model");
+    }
+  }
+};
+
+REGISTER_FastEngineFactory(
+    GradientBoostedTreesQuickScorerFastEngineFactory,
+    serving::gradient_boosted_trees::kQuickScorerExtended);
+
+// Factory for the "OptPred" engine on gradient boosted trees.
+//
+// A model is accepted when it passes the global-imputation check, no tree
+// exceeds the leaf-count limit, every input feature is NUMERICAL or
+// CATEGORICAL with at most 32 unique values, and its task is binary
+// classification (label column with 3 unique values), regression or ranking.
+// Declared better than the generic gradient boosted trees engine.
+class GradientBoostedTreesOptPredFastEngineFactory : public FastEngineFactory {
+ public:
+  using SourceModel = gradient_boosted_trees::GradientBoostedTreesModel;
+
+  std::string name() const override {
+    return serving::gradient_boosted_trees::kOptPred;
+  }
+
+  bool IsCompatible(const AbstractModel* const model) const override {
+    auto* gbt_model = dynamic_cast(model);
+    if (gbt_model == nullptr) {
+      return false;
+    }
+
+    if (!gbt_model->IsMissingValueConditionResultFollowGlobalImputation()) {
+      return false;
+    }
+
+    for (const auto& src_tree : gbt_model->decision_trees()) {
+      if (src_tree->NumLeafs() > std::numeric_limits::max()) {
+        return false;
+      }
+    }
+
+    for (const auto feature_idx : gbt_model->input_features()) {
+      const auto& feature = gbt_model->data_spec().columns(feature_idx);
+      switch (feature.type()) {
+        case dataset::proto::ColumnType::NUMERICAL:
+          break;
+        case dataset::proto::ColumnType::CATEGORICAL:
+          if (feature.categorical().number_of_unique_values() > 32) {
+            return false;
+          }
+          break;
+        default:
+          return false;
+      }
+    }
+
+    switch (gbt_model->task()) {
+      case proto::CLASSIFICATION:
+        return gbt_model->label_col_spec()
+                   .categorical()
+                   .number_of_unique_values() == 3;
+      case proto::REGRESSION:
+      case proto::RANKING:
+        return true;
+      default:
+        return false;
+    }
+  }
+
+  std::vector IsBetterThan() const override {
+    return {serving::gradient_boosted_trees::kGeneric};
+  }
+
+  utils::StatusOr> CreateEngine(
+      const AbstractModel* const model) const override {
+    auto* gbt_model = dynamic_cast(model);
+    if (!gbt_model) {
+      return absl::InvalidArgumentError("The model is not a GBDT.");
+    }
+
+    switch (gbt_model->task()) {
+      case proto::CLASSIFICATION:
+        if (gbt_model->label_col_spec()
+                .categorical()
+                .number_of_unique_values() == 3) {
+          // Binary classification.
+          auto engine = absl::make_unique>();
+          RETURN_IF_ERROR(engine->LoadModel(*gbt_model));
+          return engine;
+        } else {
+          return absl::InvalidArgumentError("Non supported GBDT model");
+        }
+
+      case proto::REGRESSION: {
+        auto engine = absl::make_unique>();
+        RETURN_IF_ERROR(engine->LoadModel(*gbt_model));
+        return engine;
+      }
+
+      case proto::RANKING: {
+        auto engine = absl::make_unique>();
+        RETURN_IF_ERROR(engine->LoadModel(*gbt_model));
+        return engine;
+      }
+
+      default:
+        return absl::InvalidArgumentError("Non supported GBDT model");
+    }
+  }
+};
+
+REGISTER_FastEngineFactory(GradientBoostedTreesOptPredFastEngineFactory,
+                           serving::gradient_boosted_trees::kOptPred);
+
 class RandomForestGenericFastEngineFactory : public model::FastEngineFactory {
  public:
   using SourceModel = random_forest::RandomForestModel;
@@ -222,5 +429,94 @@ class RandomForestGenericFastEngineFactory : public model::FastEngineFactory {
 REGISTER_FastEngineFactory(RandomForestGenericFastEngineFactory,
                            serving::random_forest::kGeneric);
 
+// Factory for the "OptPred" engine on random forests.
+//
+// A model is accepted when it passes the global-imputation check, no tree
+// exceeds the leaf-count limit, every input feature is NUMERICAL or
+// CATEGORICAL with at most 32 unique values, and its task is binary
+// classification (label column with 3 unique values) or regression. These
+// are the only tasks CreateEngine() can build an engine for. Declared better
+// than the generic random forest engine.
+class RandomForestOptPredFastEngineFactory : public model::FastEngineFactory {
+ public:
+  using SourceModel = random_forest::RandomForestModel;
+
+  std::string name() const override { return serving::random_forest::kOptPred; }
+
+  bool IsCompatible(const AbstractModel* const model) const override {
+    auto* rf_model = dynamic_cast(model);
+    // A null cast result means the model is not a random forest.
+    if (rf_model == nullptr) {
+      return false;
+    }
+    if (!rf_model->IsMissingValueConditionResultFollowGlobalImputation()) {
+      return false;
+    }
+
+    for (const auto& src_tree : rf_model->decision_trees()) {
+      if (src_tree->NumLeafs() > std::numeric_limits::max()) {
+        return false;
+      }
+    }
+
+    for (const auto feature_idx : rf_model->input_features()) {
+      const auto& feature = rf_model->data_spec().columns(feature_idx);
+      switch (feature.type()) {
+        case dataset::proto::ColumnType::NUMERICAL:
+          break;
+        case dataset::proto::ColumnType::CATEGORICAL:
+          if (feature.categorical().number_of_unique_values() > 32) {
+            return false;
+          }
+          break;
+        default:
+          return false;
+      }
+    }
+
+    switch (rf_model->task()) {
+      case proto::CLASSIFICATION:
+        return rf_model->label_col_spec()
+                   .categorical()
+                   .number_of_unique_values() == 3;
+      case proto::REGRESSION:
+        return true;
+      // RANKING falls to default: CreateEngine() has no RF ranking engine.
+      default:
+        return false;
+    }
+  }
+
+  std::vector IsBetterThan() const override {
+    return {serving::random_forest::kGeneric};
+  }
+
+  utils::StatusOr> CreateEngine(
+      const AbstractModel* const model) const override {
+    auto* rf_model = dynamic_cast(model);
+    if (!rf_model) {
+      return absl::InvalidArgumentError("The model is not a RF.");
+    }
+
+    switch (rf_model->task()) {
+      case model::proto::CLASSIFICATION: {
+        auto engine = absl::make_unique>();
+        RETURN_IF_ERROR(engine->LoadModel(*rf_model));
+        return engine;
+      }
+
+      case model::proto::REGRESSION: {
+        auto engine = absl::make_unique>();
+        RETURN_IF_ERROR(engine->LoadModel(*rf_model));
+        return engine;
+      }
+
+      default:
+        return absl::InvalidArgumentError("Non supported RF model");
+    }
+  }
+};
+
+REGISTER_FastEngineFactory(RandomForestOptPredFastEngineFactory,
+                           serving::random_forest::kOptPred);
+
 }  // namespace model
 }  // namespace yggdrasil_decision_forests