Skip to content

Commit

Permalink
Add unit test to test the concurrent compilation of models.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 572529090
  • Loading branch information
achoum authored and copybara-github committed Oct 11, 2023
1 parent d6ebc01 commit 76b7562
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
3 changes: 3 additions & 0 deletions yggdrasil_decision_forests/serving/decision_forest/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,17 @@ cc_test(
"//yggdrasil_decision_forests/model/decision_tree:decision_tree_cc_proto",
"//yggdrasil_decision_forests/model/gradient_boosted_trees",
"//yggdrasil_decision_forests/model/gradient_boosted_trees:gradient_boosted_trees_cc_proto",
"//yggdrasil_decision_forests/utils:concurrency",
"//yggdrasil_decision_forests/utils:csv",
"//yggdrasil_decision_forests/utils:distribution_cc_proto",
"//yggdrasil_decision_forests/utils:filesystem",
"//yggdrasil_decision_forests/utils:logging",
"//yggdrasil_decision_forests/utils:status_macros",
"//yggdrasil_decision_forests/utils:test",
"//yggdrasil_decision_forests/utils:test_utils",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_googletest//:gtest_main",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/flags/flag.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
Expand All @@ -36,16 +39,19 @@
#include "yggdrasil_decision_forests/model/abstract_model.h"
#include "yggdrasil_decision_forests/model/decision_tree/decision_tree.h"
#include "yggdrasil_decision_forests/model/decision_tree/decision_tree.pb.h"
#include "yggdrasil_decision_forests/model/fast_engine_factory.h"
#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.h"
#include "yggdrasil_decision_forests/model/gradient_boosted_trees/gradient_boosted_trees.pb.h"
#include "yggdrasil_decision_forests/model/model_library.h"
#include "yggdrasil_decision_forests/model/prediction.pb.h"
#include "yggdrasil_decision_forests/serving/decision_forest/quick_scorer_extended.h"
#include "yggdrasil_decision_forests/serving/decision_forest/register_engines.h"
#include "yggdrasil_decision_forests/utils/concurrency.h" // IWYU pragma: keep
#include "yggdrasil_decision_forests/utils/csv.h"
#include "yggdrasil_decision_forests/utils/distribution.pb.h"
#include "yggdrasil_decision_forests/utils/filesystem.h"
#include "yggdrasil_decision_forests/utils/logging.h"
#include "yggdrasil_decision_forests/utils/status_macros.h" // IWYU pragma: keep
#include "yggdrasil_decision_forests/utils/test.h"
#include "yggdrasil_decision_forests/utils/test_utils.h"

Expand Down Expand Up @@ -242,6 +248,44 @@ TEST_P(AllCompatibleEnginesTest, Automtac) {
CheckCompatibleEngine(*model, GetParam().expected_engines);
}

// Concurrently compiles 100 DF models on 10 threads for each engine.
TEST_P(AllCompatibleEnginesTest, Concurent) {
const std::unique_ptr<model::AbstractModel> model =
LoadModel(GetParam().model);
const dataset::VerticalDataset dataset = LoadDataset(
model->data_spec(), GetParam().dataset, GetParam().dataset_format);
const std::vector<std::unique_ptr<model::FastEngineFactory>> factories =
model->ListCompatibleFastEngines();

const auto compile_model =
[model = model.get()](
const model::FastEngineFactory* factory) -> absl::Status {
ASSIGN_OR_RETURN(const auto engine, factory->CreateEngine(model));
// Note: Discard the engine.
return absl::OkStatus();
};

const int num_repetitions = 100;

for (const auto& factory : factories) {
utils::concurrency::StreamProcessor<const model::FastEngineFactory*,
absl::Status>
processor("model compiler", /*num_threads=*/10, compile_model);
processor.StartWorkers();
for (int rep_idx = 0; rep_idx < num_repetitions; rep_idx++) {
processor.Submit(factory.get());
}
processor.CloseSubmits();

for (int rep_idx = 0; rep_idx < num_repetitions; rep_idx++) {
const absl::optional<absl::Status> result = processor.GetResult();
ASSERT_TRUE(result.has_value());
EXPECT_OK(*result);
}
processor.JoinAllAndStopThreads();
}
}

TEST_P(AllCompatibleEnginesTest, AutomtacForceCheckFail) {
const auto model = LoadModel(GetParam().model);

Expand Down

0 comments on commit 76b7562

Please sign in to comment.