Skip to content

Commit

Permalink
Merge pull request #5 from google/test_374585820
Browse files Browse the repository at this point in the history
Release of v0.1.3
  • Loading branch information
achoum authored May 19, 2021
2 parents 4903e46 + 31932e6 commit e2c9a49
Show file tree
Hide file tree
Showing 3 changed files with 304 additions and 2 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## 0.1.3 - 2021-05-19

### Features

- Register new inference engines.

## 0.1.2 - 2021-05-18

### Features
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ Download one of the build, and then run `examples/beginner.{sh,bat}`.

Target | Version | Link
------- | ------- | ----
Linux | 0.1.2 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.2/cli_linux.zip)
Windows | 0.1.2 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.2/cli_windows.zip)
Linux | 0.1.0 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.0/cli_linux.zip)
Windows | 0.1.0 | [CLI](https://github.com/google/yggdrasil-decision-forests/releases/download/0.1.0/cli_windows.zip)

## Installation from Source

Expand Down
296 changes: 296 additions & 0 deletions yggdrasil_decision_forests/serving/decision_forest/register_engines.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "yggdrasil_decision_forests/model/fast_engine_factory.h"
#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h"
#include "yggdrasil_decision_forests/serving/decision_forest/decision_forest.h"
#include "yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.h"
#include "yggdrasil_decision_forests/serving/example_set_model_wrapper.h"
#include "yggdrasil_decision_forests/utils/compatibility.h"

Expand Down Expand Up @@ -130,6 +131,212 @@ class GradientBoostedTreesGenericFastEngineFactory : public FastEngineFactory {
REGISTER_FastEngineFactory(GradientBoostedTreesGenericFastEngineFactory,
serving::gradient_boosted_trees::kGeneric);

class GradientBoostedTreesQuickScorerFastEngineFactory
: public FastEngineFactory {
public:
using SourceModel = gradient_boosted_trees::GradientBoostedTreesModel;

std::string name() const override {
return serving::gradient_boosted_trees::kQuickScorerExtended;
}

bool IsCompatible(const AbstractModel* const model) const override {
auto* gbt_model = dynamic_cast<const SourceModel*>(model);
if (gbt_model == nullptr) {
return false;
}

if (!gbt_model->IsMissingValueConditionResultFollowGlobalImputation()) {
return false;
}

if (gbt_model->NumTrees() > serving::decision_forest::internal::
QuickScorerExtendedModel::kMaxTrees) {
return false;
}

for (const auto& src_tree : gbt_model->decision_trees()) {
if (src_tree->NumLeafs() > serving::decision_forest::internal::
QuickScorerExtendedModel::kMaxLeafs) {
return false;
}
}

switch (gbt_model->task()) {
case proto::CLASSIFICATION:
return gbt_model->label_col_spec()
.categorical()
.number_of_unique_values() == 3;
case proto::REGRESSION:
case proto::RANKING:
return true;
default:
return false;
}
}

std::vector<std::string> IsBetterThan() const override {
return {serving::gradient_boosted_trees::kGeneric,
serving::gradient_boosted_trees::kOptPred};
}

utils::StatusOr<std::unique_ptr<serving::FastEngine>> CreateEngine(
const AbstractModel* const model) const override {
auto* gbt_model = dynamic_cast<const SourceModel*>(model);
if (!gbt_model) {
return absl::InvalidArgumentError("The model is not a GBDT.");
}

switch (gbt_model->task()) {
case proto::CLASSIFICATION:
if (gbt_model->label_col_spec()
.categorical()
.number_of_unique_values() == 3) {
// Binary classification.
auto engine = absl::make_unique<serving::ExampleSetModelWrapper<
serving::decision_forest::
GradientBoostedTreesBinaryClassificationQuickScorerExtended,
serving::decision_forest::Predict>>();
RETURN_IF_ERROR(engine->LoadModel<SourceModel>(*gbt_model));
return engine;
} else {
return absl::InvalidArgumentError("Non supported GBDT model");
}

case proto::REGRESSION: {
auto engine = absl::make_unique<serving::ExampleSetModelWrapper<
serving::decision_forest::
GradientBoostedTreesRegressionQuickScorerExtended,
serving::decision_forest::Predict>>();
RETURN_IF_ERROR(engine->LoadModel<SourceModel>(*gbt_model));
return engine;
}

case proto::RANKING: {
auto engine = absl::make_unique<serving::ExampleSetModelWrapper<
serving::decision_forest::
GradientBoostedTreesRankingQuickScorerExtended,
serving::decision_forest::Predict>>();
RETURN_IF_ERROR(engine->LoadModel<SourceModel>(*gbt_model));
return engine;
}

default:
return absl::InvalidArgumentError("Non supported GBDT model");
}
}
};

REGISTER_FastEngineFactory(
GradientBoostedTreesQuickScorerFastEngineFactory,
serving::gradient_boosted_trees::kQuickScorerExtended);

class GradientBoostedTreesOptPredFastEngineFactory : public FastEngineFactory {
public:
using SourceModel = gradient_boosted_trees::GradientBoostedTreesModel;

std::string name() const override {
return serving::gradient_boosted_trees::kOptPred;
}

bool IsCompatible(const AbstractModel* const model) const override {
auto* gbt_model = dynamic_cast<const SourceModel*>(model);
if (gbt_model == nullptr) {
return false;
}

if (!gbt_model->IsMissingValueConditionResultFollowGlobalImputation()) {
return false;
}

for (const auto& src_tree : gbt_model->decision_trees()) {
if (src_tree->NumLeafs() > std::numeric_limits<uint16_t>::max()) {
return false;
}
}

for (const auto feature_idx : gbt_model->input_features()) {
const auto& feature = gbt_model->data_spec().columns(feature_idx);
switch (feature.type()) {
case dataset::proto::ColumnType::NUMERICAL:
break;
case dataset::proto::ColumnType::CATEGORICAL:
if (feature.categorical().number_of_unique_values() > 32) {
return false;
}
break;
default:
return false;
}
}

switch (gbt_model->task()) {
case proto::CLASSIFICATION:
return gbt_model->label_col_spec()
.categorical()
.number_of_unique_values() == 3;
case proto::REGRESSION:
case proto::RANKING:
return true;
default:
return false;
}
}

std::vector<std::string> IsBetterThan() const override {
return {serving::gradient_boosted_trees::kGeneric};
}

utils::StatusOr<std::unique_ptr<serving::FastEngine>> CreateEngine(
const AbstractModel* const model) const override {
auto* gbt_model = dynamic_cast<const SourceModel*>(model);
if (!gbt_model) {
return absl::InvalidArgumentError("The model is not a GBDT.");
}

switch (gbt_model->task()) {
case proto::CLASSIFICATION:
if (gbt_model->label_col_spec()
.categorical()
.number_of_unique_values() == 3) {
// Binary classification.
auto engine = absl::make_unique<serving::ExampleSetModelWrapper<
serving::decision_forest::
GradientBoostedTreesBinaryClassificationNumericalAndCategorical,
serving::decision_forest::PredictWithExampleSet>>();
RETURN_IF_ERROR(engine->LoadModel<SourceModel>(*gbt_model));
return engine;
} else {
return absl::InvalidArgumentError("Non supported GBDT model");
}

case proto::REGRESSION: {
auto engine = absl::make_unique<serving::ExampleSetModelWrapper<
serving::decision_forest::
GradientBoostedTreesRegressionNumericalAndCategorical,
serving::decision_forest::PredictWithExampleSet>>();
RETURN_IF_ERROR(engine->LoadModel<SourceModel>(*gbt_model));
return engine;
}

case proto::RANKING: {
auto engine = absl::make_unique<serving::ExampleSetModelWrapper<
serving::decision_forest::
GradientBoostedTreesRankingNumericalAndCategorical,
serving::decision_forest::PredictWithExampleSet>>();
RETURN_IF_ERROR(engine->LoadModel<SourceModel>(*gbt_model));
return engine;
}

default:
return absl::InvalidArgumentError("Non supported GBDT model");
}
}
};

REGISTER_FastEngineFactory(GradientBoostedTreesOptPredFastEngineFactory,
serving::gradient_boosted_trees::kOptPred);

class RandomForestGenericFastEngineFactory : public model::FastEngineFactory {
public:
using SourceModel = random_forest::RandomForestModel;
Expand Down Expand Up @@ -222,5 +429,94 @@ class RandomForestGenericFastEngineFactory : public model::FastEngineFactory {
REGISTER_FastEngineFactory(RandomForestGenericFastEngineFactory,
serving::random_forest::kGeneric);

class RandomForestOptPredFastEngineFactory : public model::FastEngineFactory {
public:
using SourceModel = random_forest::RandomForestModel;

std::string name() const override { return serving::random_forest::kOptPred; }

bool IsCompatible(const AbstractModel* const model) const override {
auto* rf_model = dynamic_cast<const SourceModel*>(model);
// This implementation is the most generic and least efficient engine.
if (rf_model == nullptr) {
return false;
}
if (!rf_model->IsMissingValueConditionResultFollowGlobalImputation()) {
return false;
}

for (const auto& src_tree : rf_model->decision_trees()) {
if (src_tree->NumLeafs() > std::numeric_limits<uint16_t>::max()) {
return false;
}
}

for (const auto feature_idx : rf_model->input_features()) {
const auto& feature = rf_model->data_spec().columns(feature_idx);
switch (feature.type()) {
case dataset::proto::ColumnType::NUMERICAL:
break;
case dataset::proto::ColumnType::CATEGORICAL:
if (feature.categorical().number_of_unique_values() > 32) {
return false;
}
break;
default:
return false;
}
}

switch (rf_model->task()) {
case proto::CLASSIFICATION:
return rf_model->label_col_spec()
.categorical()
.number_of_unique_values() == 3;
case proto::REGRESSION:
case proto::RANKING:
return true;
default:
return false;
}
}

std::vector<std::string> IsBetterThan() const override {
return {serving::random_forest::kGeneric};
}

utils::StatusOr<std::unique_ptr<serving::FastEngine>> CreateEngine(
const AbstractModel* const model) const override {
auto* rf_model = dynamic_cast<const SourceModel*>(model);
if (!rf_model) {
return absl::InvalidArgumentError("The model is not a RF.");
}

switch (rf_model->task()) {
case model::proto::CLASSIFICATION: {
auto engine = absl::make_unique<serving::ExampleSetModelWrapper<
serving::decision_forest::
RandomForestBinaryClassificationNumericalAndCategoricalFeatures,
serving::decision_forest::PredictWithExampleSet>>();
RETURN_IF_ERROR(engine->LoadModel<SourceModel>(*rf_model));
return engine;
}

case model::proto::REGRESSION: {
auto engine = absl::make_unique<serving::ExampleSetModelWrapper<
serving::decision_forest::
RandomForestRegressionNumericalAndCategorical,
serving::decision_forest::PredictWithExampleSet>>();
RETURN_IF_ERROR(engine->LoadModel<SourceModel>(*rf_model));
return engine;
}

default:
return absl::InvalidArgumentError("Non supported RF model");
}
}
};

REGISTER_FastEngineFactory(RandomForestOptPredFastEngineFactory,
serving::random_forest::kOptPred);

} // namespace model
} // namespace yggdrasil_decision_forests

0 comments on commit e2c9a49

Please sign in to comment.