Skip to content

Commit

Permalink
Implement QuickSelect to speed-up MAE loss
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571423918
  • Loading branch information
achoum authored and copybara-github committed Oct 6, 2023
1 parent 9c32afb commit 672cb63
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 3 deletions.
7 changes: 7 additions & 0 deletions yggdrasil_decision_forests/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,12 @@ cc_library_ydf(

cc_library_ydf(
name = "math",
srcs = ["math.cc"],
hdrs = ["math.h"],
deps = [
":logging",
"@com_google_absl//absl/types:span",
],
)

cc_library_ydf(
Expand Down Expand Up @@ -1126,6 +1131,8 @@ cc_test(
srcs = ["math_test.cc"],
deps = [
":math",
":test",
"@com_google_absl//absl/random",
"@com_google_googletest//:gtest_main",
],
)
Expand Down
80 changes: 77 additions & 3 deletions yggdrasil_decision_forests/utils/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,80 @@

#include "yggdrasil_decision_forests/utils/math.h"

namespace yggdrasil_decision_forests {
namespace utils {} // namespace utils
} // namespace yggdrasil_decision_forests
#include <stddef.h>

#include <algorithm>
#include <limits>
#include <vector>

#include "absl/types/span.h"
#include "yggdrasil_decision_forests/utils/logging.h" // IWYU pragma: keep

namespace yggdrasil_decision_forests::utils {
namespace {

// Returns the "target_idx"-th smallest value in "values".
// "values" is a non-sorted array containing non infinite or nan values.
// The content of "values" is reordered during the computation.
// "values" cannot be empty.
float QuickSelect(std::vector<float>& values, size_t target_idx) {
// Boundaries of the search window.
size_t left = 0;
// Using a "right" instead of an "end" simplifies the code.
size_t right = values.size() - 1;

while (true) {
DCHECK_LE(left, right) << "The left index cannot move past the right index";
DCHECK_GE(target_idx, left) << "target_idx should be in [left, right]";
DCHECK_LE(target_idx, right) << "target_idx should be in [left, right]";

if (left == right) {
return values[left];
}

// Pivot the values around "pivot_selector".
// std::partition cannot be used as it does not guarantee that the pivoted
// value will be located at the output pivot index.
//
// Note: This code can be sped up using the Hoare algorithm.
using std::swap;
const float pivot_value = values[target_idx];
size_t pivot_idx = left;
swap(values[target_idx], values[right]);
for (size_t i = left; i < right; ++i) {
if (values[i] < pivot_value) {
swap(values[i], values[pivot_idx]);
pivot_idx++;
}
}
swap(values[pivot_idx], values[right]);

// Select the side containing the target index.
if (pivot_idx == target_idx) {
return values[pivot_idx];
} else if (target_idx < pivot_idx) {
right = pivot_idx - 1;
} else if (pivot_idx == 0) {
left = pivot_idx + 1;
}
}
}

} // namespace

float Median(const absl::Span<const float> values) {
if (values.empty()) {
return std::numeric_limits<float>::quiet_NaN();
}
std::vector<float> working_values = {values.begin(), values.end()};
const size_t half_size = working_values.size() / 2;
if (values.size() % 2 == 1) {
return QuickSelect(working_values, half_size);
} else {
return (QuickSelect(working_values, half_size) +
QuickSelect(working_values, half_size - 1)) /
2;
}
}

} // namespace yggdrasil_decision_forests::utils
10 changes: 10 additions & 0 deletions yggdrasil_decision_forests/utils/math.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

#include <type_traits>

#include "absl/types/span.h"

namespace yggdrasil_decision_forests {
namespace utils {

Expand All @@ -31,6 +33,14 @@ T CeilDiV(T x, T y) {
return (x + y - 1) / y;
}

// Computes the median of "values".
//
// Uses the Quick Select algorithm: The average time and space complexity is
// linear. If the number of values is event, return the average of the two
// median values. If empty, returns NaN. "values" should not contain NaNs or
// Infs.
float Median(absl::Span<const float> values);

} // namespace utils
} // namespace yggdrasil_decision_forests

Expand Down
56 changes: 56 additions & 0 deletions yggdrasil_decision_forests/utils/math_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@

#include "yggdrasil_decision_forests/utils/math.h"

#include <algorithm>
#include <cmath>
#include <random>
#include <vector>

#include "gtest/gtest.h"
#include "absl/random/random.h"

namespace yggdrasil_decision_forests {
namespace utils {
Expand All @@ -36,6 +42,56 @@ TEST(CeilDiV, Base) {
EXPECT_EQ(CeilDiV(16, 5), 4);
}

TEST(Median, Empty) { EXPECT_TRUE(std::isnan(Median({}))); }

TEST(Median, Base) {
EXPECT_EQ(Median({1.f}), 1.f);

EXPECT_EQ(Median({1.f, 2.f}), 1.5f);
EXPECT_EQ(Median({2.f, 1.f}), 1.5f);

EXPECT_EQ(Median({1.f, 2.f, 3.f}), 2.f);
EXPECT_EQ(Median({2.f, 3.f, 1.f}), 2.f);
EXPECT_EQ(Median({3.f, 1.f, 2.f}), 2.f);

EXPECT_EQ(Median({3.f, 4.f, 1.f, 2.f}), 2.5f);
}

TEST(Median, Duplicates) {
EXPECT_EQ(Median({2.f, 2.f}), 2.f);
EXPECT_EQ(Median({2.f, 2.f, 2.f}), 2.f);
EXPECT_EQ(Median({2.f, 2.f, 1.f}), 2.f);
EXPECT_EQ(Median({3.f, 2.f, 2.f}), 2.f);
}

class MedianRandomTest : public testing::TestWithParam<int> {};

TEST_P(MedianRandomTest, Base) {
const int n = GetParam();
EXPECT_GE(n, 1);

// Generate some data
absl::BitGen rnd;
std::vector<float> values(n);
std::generate(values.begin(), values.end(), [&rnd]() {
return std::uniform_real_distribution<float>()(rnd);
});

const float median = Median(values);

// Check median results against n log n algorithm.
std::sort(values.begin(), values.end());
if ((n % 2) == 0) {
// Event
EXPECT_EQ(median, (values[n / 2] + values[n / 2 - 1]) / 2);
} else {
// Odd
EXPECT_EQ(median, values[n / 2]);
}
}

INSTANTIATE_TEST_SUITE_P(Event, MedianRandomTest, testing::Values(2, 10, 50));
INSTANTIATE_TEST_SUITE_P(Odd, MedianRandomTest, testing::Values(1, 51, 101));
} // namespace
} // namespace utils
} // namespace yggdrasil_decision_forests

0 comments on commit 672cb63

Please sign in to comment.