diff --git a/yggdrasil_decision_forests/model/abstract_model.cc b/yggdrasil_decision_forests/model/abstract_model.cc index 13574680..61de571b 100644 --- a/yggdrasil_decision_forests/model/abstract_model.cc +++ b/yggdrasil_decision_forests/model/abstract_model.cc @@ -1270,78 +1270,94 @@ absl::Status AbstractModel::Validate() const { std::vector> AbstractModel::ListCompatibleFastEngines() const { - std::vector> compatible_engines; + // Index the compatible engines. + struct Item { + std::unique_ptr factory; + absl::flat_hash_set is_better_than; + }; + std::vector items; + for (auto& factory : ListAllFastEngines()) { if (!factory->IsCompatible(this)) { continue; } - compatible_engines.push_back(std::move(factory)); + const auto is_better_than = factory->IsBetterThan(); + items.push_back(Item{std::move(factory), + {is_better_than.begin(), is_better_than.end()}}); + } + + // Sort the engine by speed (fastest first). + std::sort(items.begin(), items.end(), [](const Item& a, const Item& b) { + if (a.is_better_than.find(b.factory->name()) != a.is_better_than.end()) { + // "a" is better than "b". + return true; + } + return false; + }); + + std::vector> compatible_engines; + compatible_engines.reserve(items.size()); + for (auto& item : items) { + compatible_engines.push_back(std::move(item.factory)); + } + return compatible_engines; +} + +std::vector AbstractModel::ListCompatibleFastEngineNames() const { + std::vector compatible_engines; + for (auto& factory : ListCompatibleFastEngines()) { + compatible_engines.push_back(factory->name()); } return compatible_engines; } absl::StatusOr> -AbstractModel::BuildFastEngine() const { +AbstractModel::BuildFastEngine( + const absl::optional& force_engine_name) const { if (!allow_fast_engine_) { return absl::NotFoundError("allow_fast_engine is set to false."); } // List the compatible engines. - auto compatible_engines = ListCompatibleFastEngines(); - - // Each engine in this set is slower than at least one other compatible - // engine. - absl::flat_hash_set all_is_better_than; - - for (auto& engine_factory : compatible_engines) { - const auto is_better_than = engine_factory->IsBetterThan(); - all_is_better_than.insert(is_better_than.begin(), is_better_than.end()); - } - - const auto no_compatible_engine_message = absl::Substitute( - "No compatible engine available for model $0. 1) Make sure the " - "corresponding engine is added as a dependency, 2) use the (slow) " - "generic engine (i.e. \"model.Predict()\") or 3) use one of the fast " - "non-generic engines available in ../serving.", - name()); - if (compatible_engines.empty()) { - return absl::NotFoundError(no_compatible_engine_message); - } - - // Select the best engine. - std::vector> best_engines; - for (auto& compatible_engine : compatible_engines) { - if (all_is_better_than.find(compatible_engine->name()) != - all_is_better_than.end()) { - // One other engine is better than this engine. - continue; + auto sorted_compatible_engines = ListCompatibleFastEngines(); + + // How to create the engine. + std::unique_ptr engine_factory; + if (force_engine_name.has_value()) { + for (auto& compatible_engine : sorted_compatible_engines) { + if (compatible_engine->name() == *force_engine_name) { + engine_factory = std::move(compatible_engine); + break; + } + } + if (!engine_factory) { + return absl::NotFoundError(absl::StrCat( + "The forced engine \"", *force_engine_name, + "\" does not exist or is not compatible with the model")); } - best_engines.push_back(std::move(compatible_engine)); - } - - std::unique_ptr best_engine; - if (best_engines.empty()) { - // No engine is better than all the other engines. - YDF_LOG(WARNING) << "Circular is_better relation between engines."; - best_engine = std::move(compatible_engines.front()); } else { - if (best_engines.size() > 1) { - // Multiple engines are "the best". - YDF_LOG(WARNING) - << "Non complete relation between engines. Cannot select the " - "best one. One engine selected randomly."; + if (sorted_compatible_engines.empty()) { + return absl::NotFoundError(absl::Substitute( + "No compatible engine available for model $0. 1)interresting Make " + "sure the " + "corresponding engine is added as a dependency, 2) use the (slow) " + "generic engine (i.e. \"model.Predict()\") or 3) use one of the fast " + "non-generic engines available in ../serving.", + name())); } - best_engine = std::move(best_engines.front()); + + // Select the best engine. + engine_factory = std::move(sorted_compatible_engines.front()); } - auto engine_or = best_engine->CreateEngine(this); + auto engine_or = engine_factory->CreateEngine(this); if (!engine_or.ok()) { - YDF_LOG(WARNING) << "The engine \"" << best_engine->name() + YDF_LOG(WARNING) << "The engine \"" << engine_factory->name() << "\" is compatible but could not be created: " << engine_or.status().message(); } else { - LOG_INFO_EVERY_N_SEC(10, - _ << "Engine \"" << best_engine->name() << "\" built"); + LOG_INFO_EVERY_N_SEC( + 10, _ << "Engine \"" << engine_factory->name() << "\" built"); } return engine_or; } diff --git a/yggdrasil_decision_forests/model/abstract_model.h b/yggdrasil_decision_forests/model/abstract_model.h index 2b2f3aa9..5334e78a 100644 --- a/yggdrasil_decision_forests/model/abstract_model.h +++ b/yggdrasil_decision_forests/model/abstract_model.h @@ -95,12 +95,24 @@ class AbstractModel { // // Because "BuildFastEngine" uses virtual calls, this solution is slower // than selecting directly the inference engine at compile time. - absl::StatusOr> BuildFastEngine() const; - - // List the fast engines compatible with the model. + // + // If specified, "force_engine_name" is the name of the created engine. + // If "force_engine_name" is not specified, create the fastest compatible + // engine. + absl::StatusOr> BuildFastEngine( + const absl::optional& force_engine_name = {}) const; + + // Lists the fast engines compatible with the model. + // Engines are sorted by decreasing expected speed i.e., for the fastest + // inference, use the first one. std::vector> ListCompatibleFastEngines() const; + // Lists the names of fast engines compatible with the model. + // Engines are sorted by decreasing expected speed i.e., for the fastest + // inference, use the first one. + std::vector ListCompatibleFastEngineNames() const; + // If set to "False", "BuildFastEngine" won't return an engine, even if one if // available. void SetAllowFastEngine(const bool allow_fast_engine) { diff --git a/yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi b/yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi index ef2e41dd..c3a8c4e0 100644 --- a/yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi +++ b/yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi @@ -144,6 +144,8 @@ class GenericCCModel: def VariableImportances( self, ) -> Dict[str, abstract_model_pb2.VariableImportanceSet]: ... + def ForceEngine(self, engine_name: Optional[str]) -> None: ... + def ListCompatibleEngines(self) -> Sequence[str]: ... class DecisionForestCCModel(GenericCCModel): def num_trees(self) -> int: ... diff --git a/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py b/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py index 18566848..7c75cc6f 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py +++ b/yggdrasil_decision_forests/port/python/ydf/model/generic_model.py @@ -812,5 +812,31 @@ def self_evaluation(self) -> metric.Evaluation: "Self-evaluation is not available for this model type." ) + def list_compatible_engines(self) -> Sequence[str]: + """Lists the inference engines compatible with the model. + + The engines are sorted to likely-fastest to likely-slowest. + + Returns: + List of compatible engines. + """ + return self._model.ListCompatibleEngines() + + def force_engine(self, engine_name: Optional[str]) -> None: + """Forces the engines used by the model. + + If not specified (i.e., None; default value), the fastest compatible engine + (i.e., the first value returned from "list_compatible_engines") is used for + all model inferences (e.g., model.predict, model.evaluate). + + If passing a non-existing or non-compatible engine, the next model inference + (e.g., model.predict, model.evaluate) will fail. + + Args: + engine_name: Name of a compatible engine or None to automatically select + the fastest engine. + """ + self._model.ForceEngine(engine_name) + ModelType = TypeVar("ModelType", bound=GenericModel) diff --git a/yggdrasil_decision_forests/port/python/ydf/model/model.cc b/yggdrasil_decision_forests/port/python/ydf/model/model.cc index 8c461c7b..00e35840 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/model.cc +++ b/yggdrasil_decision_forests/port/python/ydf/model/model.cc @@ -137,7 +137,9 @@ void init_model(py::module_& m) { .def("Benchmark", WithStatusOr(&GenericCCModel::Benchmark), py::arg("dataset"), py::arg("benchmark_duration"), py::arg("warmup_duration"), py::arg("batch_size")) - .def("VariableImportances", &GenericCCModel::VariableImportances); + .def("VariableImportances", &GenericCCModel::VariableImportances) + .def("ForceEngine", &GenericCCModel::ForceEngine, py::arg("engine_name")) + .def("ListCompatibleEngines", &GenericCCModel::ListCompatibleEngines); py::class_(m, "BenchmarkInferenceCCResult") .def_readwrite("duration_per_example", diff --git a/yggdrasil_decision_forests/port/python/ydf/model/model_test.py b/yggdrasil_decision_forests/port/python/ydf/model/model_test.py index 29caffb0..c7698529 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/model_test.py +++ b/yggdrasil_decision_forests/port/python/ydf/model/model_test.py @@ -23,6 +23,7 @@ from absl.testing import absltest from absl.testing import parameterized +import numpy as np import numpy.testing as npt import pandas as pd @@ -410,6 +411,43 @@ def test_empty_self_evaluation_rf(self): model = model_lib.load_model(model_path) self.assertIsNone(model.self_evaluation()) + def test_gbt_list_compatible_engines(self): + self.assertContainsSubsequence( + self.adult_binary_class_gbdt.list_compatible_engines(), + ["GradientBoostedTreesGeneric"], + ) + + def test_rf_list_compatible_engines(self): + self.assertContainsSubsequence( + self.adult_binary_class_rf.list_compatible_engines(), + ["RandomForestGeneric"], + ) + + def test_gbt_force_compatible_engines(self): + test_df = pd.read_csv( + os.path.join( + test_utils.ydf_test_data_path(), "dataset", "adult_test.csv" + ) + ) + p1 = self.adult_binary_class_gbdt.predict(test_df) + self.adult_binary_class_gbdt.force_engine("GradientBoostedTreesGeneric") + p2 = self.adult_binary_class_gbdt.predict(test_df) + self.adult_binary_class_gbdt.force_engine(None) + p3 = self.adult_binary_class_gbdt.predict(test_df) + + np.testing.assert_allclose( + p1, + p2, + rtol=1e-5, + atol=1e-5, + ) + np.testing.assert_allclose( + p1, + p3, + rtol=1e-5, + atol=1e-5, + ) + if __name__ == "__main__": absltest.main() diff --git a/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.cc b/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.cc index 98aa25be..c197b96e 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.cc +++ b/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.cc @@ -55,7 +55,7 @@ GenericCCModel::GetEngine() { utils::concurrency::MutexLock lock(&engine_mutex_); if (engine_ == nullptr || invalidate_engine_) { RETURN_IF_ERROR(model_->Validate()); - ASSIGN_OR_RETURN(engine_, model_->BuildFastEngine()); + ASSIGN_OR_RETURN(engine_, model_->BuildFastEngine(force_engine_name_)); invalidate_engine_ = false; } return engine_; diff --git a/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.h b/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.h index 04b72191..eacff9e3 100644 --- a/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.h +++ b/yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.h @@ -139,14 +139,29 @@ class GenericCCModel { void invalidate_engine() { invalidate_engine_ = true; } + void ForceEngine(std::optional engine_name) { + // TODO: Let the user configure inference without an engine. + utils::concurrency::MutexLock lock(&engine_mutex_); + force_engine_name_ = engine_name; + invalidate_engine(); + } + + std::vector ListCompatibleEngines() const { + return model_->ListCompatibleFastEngineNames(); + } + protected: std::unique_ptr model_; utils::concurrency::Mutex engine_mutex_; std::shared_ptr engine_ GUARDED_BY(engine_mutex_); - // If true, the "engine_mutex_" is outdated (e.g., the model was modified) and + // If true, the "engine_" is outdated (e.g., the model was modified) and // should be re-computed. std::atomic_bool invalidate_engine_{false}; + + // If set, for the creation of this specific engine. If non set, fastest + // compatible engine is created. + std::optional force_engine_name_; }; } // namespace yggdrasil_decision_forests::port::python