Skip to content

Commit

Permalink
Make it possible to list the compatible engine and to select the one …
Browse files Browse the repository at this point in the history
…used by the model.

PiperOrigin-RevId: 616134404
  • Loading branch information
achoum authored and copybara-github committed Mar 15, 2024
1 parent 1d8bb9c commit 89dae4f
Show file tree
Hide file tree
Showing 8 changed files with 166 additions and 55 deletions.
114 changes: 65 additions & 49 deletions yggdrasil_decision_forests/model/abstract_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1270,78 +1270,94 @@ absl::Status AbstractModel::Validate() const {

std::vector<std::unique_ptr<FastEngineFactory>>
AbstractModel::ListCompatibleFastEngines() const {
std::vector<std::unique_ptr<FastEngineFactory>> compatible_engines;
// Index the compatible engines.
struct Item {
std::unique_ptr<FastEngineFactory> factory;
absl::flat_hash_set<std::string> is_better_than;
};
std::vector<Item> items;

for (auto& factory : ListAllFastEngines()) {
if (!factory->IsCompatible(this)) {
continue;
}
compatible_engines.push_back(std::move(factory));
const auto is_better_than = factory->IsBetterThan();
items.push_back(Item{std::move(factory),
{is_better_than.begin(), is_better_than.end()}});
}

// Sort the engine by speed (fastest first).
std::sort(items.begin(), items.end(), [](const Item& a, const Item& b) {
if (a.is_better_than.find(b.factory->name()) != a.is_better_than.end()) {
// "a" is better than "b".
return true;
}
return false;
});

std::vector<std::unique_ptr<FastEngineFactory>> compatible_engines;
compatible_engines.reserve(items.size());
for (auto& item : items) {
compatible_engines.push_back(std::move(item.factory));
}
return compatible_engines;
}

// Returns the names of the fast engines compatible with this model, in the
// order produced by "ListCompatibleFastEngines".
std::vector<std::string> AbstractModel::ListCompatibleFastEngineNames() const {
  const auto factories = ListCompatibleFastEngines();
  std::vector<std::string> names;
  names.reserve(factories.size());
  for (const auto& factory : factories) {
    names.push_back(factory->name());
  }
  return names;
}

// Builds a fast inference engine for this model.
//
// If "force_engine_name" is set, the compatible engine with exactly this name
// is built. Otherwise, the first engine returned by
// "ListCompatibleFastEngines" (the expected fastest one) is built.
//
// Returns a NotFound error if fast engines are disabled ("allow_fast_engine_"
// is false), if the forced engine does not exist or is not compatible with the
// model, or if no engine is compatible with the model. A failure of the engine
// creation itself is logged and returned as is.
absl::StatusOr<std::unique_ptr<serving::FastEngine>>
AbstractModel::BuildFastEngine(
    const absl::optional<std::string>& force_engine_name) const {
  if (!allow_fast_engine_) {
    return absl::NotFoundError("allow_fast_engine is set to false.");
  }

  // List the compatible engines (expected fastest first).
  auto sorted_compatible_engines = ListCompatibleFastEngines();

  // How to create the engine.
  std::unique_ptr<FastEngineFactory> engine_factory;
  if (force_engine_name.has_value()) {
    for (auto& compatible_engine : sorted_compatible_engines) {
      if (compatible_engine->name() == *force_engine_name) {
        engine_factory = std::move(compatible_engine);
        break;
      }
    }
    if (!engine_factory) {
      return absl::NotFoundError(absl::StrCat(
          "The forced engine \"", *force_engine_name,
          "\" does not exist or is not compatible with the model"));
    }
  } else {
    if (sorted_compatible_engines.empty()) {
      return absl::NotFoundError(absl::Substitute(
          "No compatible engine available for model $0. 1) Make sure the "
          "corresponding engine is added as a dependency, 2) use the (slow) "
          "generic engine (i.e. \"model.Predict()\") or 3) use one of the fast "
          "non-generic engines available in ../serving.",
          name()));
    }

    // Select the best engine.
    engine_factory = std::move(sorted_compatible_engines.front());
  }

  auto engine_or = engine_factory->CreateEngine(this);
  if (!engine_or.ok()) {
    // The error is returned to the caller; the log only records which engine
    // failed.
    YDF_LOG(WARNING) << "The engine \"" << engine_factory->name()
                     << "\" is compatible but could not be created: "
                     << engine_or.status().message();
  } else {
    LOG_INFO_EVERY_N_SEC(
        10, _ << "Engine \"" << engine_factory->name() << "\" built");
  }
  return engine_or;
}
Expand Down
18 changes: 15 additions & 3 deletions yggdrasil_decision_forests/model/abstract_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,24 @@ class AbstractModel {
//
// Because "BuildFastEngine" uses virtual calls, this solution is slower
// than selecting directly the inference engine at compile time.
//
// If specified, "force_engine_name" is the name of the created engine.
// If "force_engine_name" is not specified, create the fastest compatible
// engine.
absl::StatusOr<std::unique_ptr<serving::FastEngine>> BuildFastEngine(
    const absl::optional<std::string>& force_engine_name = {}) const;

// Lists the factories of the fast engines compatible with the model.
// Engines are sorted by decreasing expected speed i.e., for the fastest
// inference, use the first one.
//
// The returned factories are owned by the caller.
std::vector<std::unique_ptr<FastEngineFactory>> ListCompatibleFastEngines()
const;

// Lists the names of the fast engines compatible with the model.
// Engines are sorted by decreasing expected speed i.e., for the fastest
// inference, use the first one.
//
// Any of the returned names can be passed as "force_engine_name" to
// "BuildFastEngine".
std::vector<std::string> ListCompatibleFastEngineNames() const;

// If set to "False", "BuildFastEngine" won't return an engine, even if one is
// available.
void SetAllowFastEngine(const bool allow_fast_engine) {
Expand Down
2 changes: 2 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ class GenericCCModel:
# Returns the variable importance sets of the model, indexed by a string key.
def VariableImportances(
self,
) -> Dict[str, abstract_model_pb2.VariableImportanceSet]: ...
# Sets the name of the engine used for inference. None restores the automatic
# selection of the fastest compatible engine.
def ForceEngine(self, engine_name: Optional[str]) -> None: ...
# Returns the names of the engines compatible with the model.
def ListCompatibleEngines(self) -> Sequence[str]: ...

class DecisionForestCCModel(GenericCCModel):
def num_trees(self) -> int: ...
Expand Down
26 changes: 26 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/model/generic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,5 +812,31 @@ def self_evaluation(self) -> metric.Evaluation:
"Self-evaluation is not available for this model type."
)

def list_compatible_engines(self) -> Sequence[str]:
  """Lists the inference engines compatible with the model.

  The engines are sorted from likely-fastest to likely-slowest.

  Returns:
    Names of the compatible engines.
  """
  return self._model.ListCompatibleEngines()

def force_engine(self, engine_name: Optional[str]) -> None:
  """Forces the engine used by the model.

  If set to None (the initial state of a model), the fastest compatible engine
  (i.e., the first value returned by "list_compatible_engines") is used for
  all model inferences (e.g., model.predict, model.evaluate).

  If passing a non-existing or non-compatible engine, the next model inference
  (e.g., model.predict, model.evaluate) will fail.

  Args:
    engine_name: Name of a compatible engine or None to automatically select
      the fastest engine.
  """
  self._model.ForceEngine(engine_name)


ModelType = TypeVar("ModelType", bound=GenericModel)
4 changes: 3 additions & 1 deletion yggdrasil_decision_forests/port/python/ydf/model/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ void init_model(py::module_& m) {
.def("Benchmark", WithStatusOr(&GenericCCModel::Benchmark),
py::arg("dataset"), py::arg("benchmark_duration"),
py::arg("warmup_duration"), py::arg("batch_size"))
.def("VariableImportances", &GenericCCModel::VariableImportances);
.def("VariableImportances", &GenericCCModel::VariableImportances)
.def("ForceEngine", &GenericCCModel::ForceEngine, py::arg("engine_name"))
.def("ListCompatibleEngines", &GenericCCModel::ListCompatibleEngines);

py::class_<BenchmarkInferenceCCResult>(m, "BenchmarkInferenceCCResult")
.def_readwrite("duration_per_example",
Expand Down
38 changes: 38 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/model/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import numpy.testing as npt
import pandas as pd

Expand Down Expand Up @@ -410,6 +411,43 @@ def test_empty_self_evaluation_rf(self):
model = model_lib.load_model(model_path)
self.assertIsNone(model.self_evaluation())

def test_gbt_list_compatible_engines(self):
  # The generic GBT engine must be listed for a GBT model.
  engine_names = self.adult_binary_class_gbdt.list_compatible_engines()
  expected_names = ["GradientBoostedTreesGeneric"]
  self.assertContainsSubsequence(engine_names, expected_names)

def test_rf_list_compatible_engines(self):
  # The generic RF engine must be listed for a RF model.
  engine_names = self.adult_binary_class_rf.list_compatible_engines()
  expected_names = ["RandomForestGeneric"]
  self.assertContainsSubsequence(engine_names, expected_names)

def test_gbt_force_compatible_engines(self):
  """Predictions must not depend on the engine computing them."""
  dataset_path = os.path.join(
      test_utils.ydf_test_data_path(), "dataset", "adult_test.csv"
  )
  test_df = pd.read_csv(dataset_path)
  model = self.adult_binary_class_gbdt

  # Automatically selected engine.
  default_predictions = model.predict(test_df)
  # Explicitly selected generic engine.
  model.force_engine("GradientBoostedTreesGeneric")
  forced_predictions = model.predict(test_df)
  # Back to the automatic selection.
  model.force_engine(None)
  restored_predictions = model.predict(test_df)

  for other_predictions in (forced_predictions, restored_predictions):
    np.testing.assert_allclose(
        default_predictions,
        other_predictions,
        rtol=1e-5,
        atol=1e-5,
    )


if __name__ == "__main__":
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ GenericCCModel::GetEngine() {
utils::concurrency::MutexLock lock(&engine_mutex_);
if (engine_ == nullptr || invalidate_engine_) {
RETURN_IF_ERROR(model_->Validate());
ASSIGN_OR_RETURN(engine_, model_->BuildFastEngine());
ASSIGN_OR_RETURN(engine_, model_->BuildFastEngine(force_engine_name_));
invalidate_engine_ = false;
}
return engine_;
Expand Down
17 changes: 16 additions & 1 deletion yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,29 @@ class GenericCCModel {

// Flags the cached engine as outdated. This only sets "invalidate_engine_";
// nothing is rebuilt by this call.
void invalidate_engine() { invalidate_engine_ = true; }

// Sets the name of the engine to use for inference. An empty optional removes
// any previously forced name.
//
// The name is not validated here: it is only stored, and the cached engine is
// flagged as outdated.
void ForceEngine(std::optional<std::string> engine_name) {
// TODO: Let the user configure inference without an engine.
// Hold the engine mutex while changing the forced engine name.
utils::concurrency::MutexLock lock(&engine_mutex_);
force_engine_name_ = engine_name;
invalidate_engine();
}

// Returns the names of the fast engines compatible with the wrapped model, as
// reported by the model's "ListCompatibleFastEngineNames".
std::vector<std::string> ListCompatibleEngines() const {
return model_->ListCompatibleFastEngineNames();
}

protected:
std::unique_ptr<model::AbstractModel> model_;
utils::concurrency::Mutex engine_mutex_;
std::shared_ptr<const serving::FastEngine> engine_ GUARDED_BY(engine_mutex_);

// If true, the "engine_" is outdated (e.g., the model was modified) and
// should be re-computed.
std::atomic_bool invalidate_engine_{false};

// If set, forces the creation of this specific engine. If not set, the fastest
// compatible engine is created.
std::optional<std::string> force_engine_name_;
};

} // namespace yggdrasil_decision_forests::port::python
Expand Down

0 comments on commit 89dae4f

Please sign in to comment.