Skip to content

Commit

Permalink
Fix confusion matrix display with floating point weights.
Browse files Browse the repository at this point in the history
Also, add option for weighted evaluation.

PiperOrigin-RevId: 649624588
  • Loading branch information
achoum authored and copybara-github committed Jul 5, 2024
1 parent f036610 commit 4669c72
Show file tree
Hide file tree
Showing 12 changed files with 180 additions and 38 deletions.
5 changes: 5 additions & 0 deletions yggdrasil_decision_forests/port/python/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
- `model.to_tensorflow_saved_model` support preprocessing functions which have
a different signature than the YDF model.
- Improve error messages when feeding wrong size numpy arrays.
- Add option for weighted evaluation in `model.evaluate`.

### Fix

- Fix display of confusion matrix with floating point weights.

## 0.5.0 - 2024-06-17

Expand Down
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/port/python/ydf/cc/ydf.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class GenericCCModel:
self,
dataset: VerticalDataset,
options: metric_pb2.EvaluationOptions,
weighted: bool,
) -> metric_pb2.EvaluationResults: ...
def Analyze(
self,
Expand Down
54 changes: 54 additions & 0 deletions yggdrasil_decision_forests/port/python/ydf/learner/learner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,60 @@ def test_multidimensional_weights(self):
):
_ = learner.train(ds)

# Trains a random forest with per-example weights on synthetic data, then
# checks that weighted and non-weighted evaluations report the expected
# accuracies and agree with accuracies recomputed by hand from predictions.
def test_weighted_training_and_evaluation(self):

# Builds a synthetic binary classification dataset (dict of numpy arrays) with
# three uniform features, a per-example weight and a boolean label.
# The global numpy RNG is re-seeded on every call, so a given seed always
# produces the same dataset. NOTE(review): the pinned accuracies below
# presumably depend on the order of the np.random draws; keep it unchanged.
def gen_ds(seed, n=10000):
np.random.seed(seed)
f1 = np.random.uniform(size=n)
f2 = np.random.uniform(size=n)
f3 = np.random.uniform(size=n)
weights = np.random.uniform(size=n)
return {
"f1": f1,
"f2": f2,
"f3": f3,
"label": (
# Make the examples with high weights harder to predict.
# The uniform noise term is scaled by the example weight, so label noise
# grows with the weight.
f1 + f2 * 0.5 + f3 * 0.5 + np.random.uniform(size=n) * weights
>= 1.5
),
"weights": weights,
}

# Train on the seed-0 dataset, using the "weights" column as example weights.
model = specialized_learners.RandomForestLearner(
label="label",
weights="weights",
num_trees=300,
winner_take_all=False,
).train(gen_ds(0))

# Independent test set generated with a different seed.
test_ds = gen_ds(1)

# Evaluate the same test set with and without the example weights.
self_evaluation = model.self_evaluation()
non_weighted_evaluation = model.evaluate(test_ds, weighted=False)
weighted_evaluation = model.evaluate(test_ds, weighted=True)

# Pinned reference accuracies (tolerance 0.005). The weighted accuracy is
# expected to be lower than the non-weighted one, consistent with the
# high-weight examples carrying more label noise.
self.assertIsNotNone(self_evaluation)
self.assertAlmostEqual(self_evaluation.accuracy, 0.824501, delta=0.005)
self.assertAlmostEqual(
non_weighted_evaluation.accuracy, 0.8417, delta=0.005
)
self.assertAlmostEqual(weighted_evaluation.accuracy, 0.8172290, delta=0.005)
predictions = model.predict(test_ds)

# Cross-check: recompute both accuracies from the raw predictions, thresholded
# at 0.5, and compare them with the metrics reported by model.evaluate.
manual_non_weighted_evaluation = np.mean(
(predictions >= 0.5) == test_ds["label"]
)
# Weighted accuracy = sum of the weights of correctly classified examples
# divided by the sum of all weights.
manual_weighted_evaluation = np.sum(
((predictions >= 0.5) == test_ds["label"]) * test_ds["weights"]
) / np.sum(test_ds["weights"])
self.assertAlmostEqual(
manual_non_weighted_evaluation, non_weighted_evaluation.accuracy
)
self.assertAlmostEqual(
manual_weighted_evaluation, weighted_evaluation.accuracy
)

def test_learn_and_predict_when_label_is_not_last_column(self):
label = "age"
learner = specialized_learners.RandomForestLearner(
Expand Down
28 changes: 27 additions & 1 deletion yggdrasil_decision_forests/port/python/ydf/metric/metric_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,28 @@ def test_str(self):
"""),
)

# Checks the text rendering of a confusion matrix whose cells are
# floating point values (as produced by a weighted evaluation).
def test_with_weights(self):
# 2x2 matrix over classes "a" and "b" with non-integer entries of very
# different magnitudes.
c = metric.ConfusionMatrix(
classes=(
"a",
"b",
),
matrix=np.array([[1.123456, 12345.6789], [3.456789, 4.5678901]]),
)
# The expected table shows each value with 6 significant digits
# (e.g. 1.123456 -> 1.12346 and 12345.6789 -> 12345.7).
self.assertEqual(
str(c),
"""\
label (row) \\ prediction (col)
+---------+---------+---------+
| | a | b |
+---------+---------+---------+
| a | 1.12346 | 12345.7 |
+---------+---------+---------+
| b | 3.45679 | 4.56789 |
+---------+---------+---------+
""",
)


class SafeDivTest(absltest.TestCase):

Expand Down Expand Up @@ -151,7 +173,11 @@ def test_classification(self):
print(evaluation)
dict_eval = evaluation.to_dict()
self.assertLessEqual(
{"accuracy": (1 + 4) / (1 + 2 + 3 + 4), "loss": 2.0, "num_examples": 1}.items(),
{
"accuracy": (1 + 4) / (1 + 2 + 3 + 4),
"loss": 2.0,
"num_examples": 1,
}.items(),
dict_eval.items(),
)

Expand Down
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/port/python/ydf/model/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ pybind_library(
"@ydf_cc//yggdrasil_decision_forests/dataset:data_spec_cc_proto",
"@ydf_cc//yggdrasil_decision_forests/dataset:types",
"@ydf_cc//yggdrasil_decision_forests/dataset:vertical_dataset",
"@ydf_cc//yggdrasil_decision_forests/dataset:weight",
"@ydf_cc//yggdrasil_decision_forests/metric:metric_cc_proto",
"@ydf_cc//yggdrasil_decision_forests/model:abstract_model",
"@ydf_cc//yggdrasil_decision_forests/model:abstract_model_cc_proto",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,9 @@ def predict(self, data: dataset.InputDataset) -> np.ndarray:
def evaluate(
self,
data: dataset.InputDataset,
*,
bootstrapping: Union[bool, int] = False,
weighted: bool = False,
) -> metric.Evaluation:
"""Evaluates the quality of a model on a dataset.
Expand Down Expand Up @@ -400,6 +402,9 @@ def evaluate(
to an integer, it specifies the number of bootstrapping samples to use.
In this case, if the number is less than 100, an error is raised as
bootstrapping will not yield useful results.
weighted: If true, the evaluation is weighted according to the training
weights. If false, the evaluation is non-weighted. b/351279797: Change
default to weighted=True.
Returns:
Model evaluation.
Expand All @@ -426,7 +431,9 @@ def evaluate(
task=self.task()._to_proto_type(), # pylint: disable=protected-access
)

evaluation_proto = self._model.Evaluate(ds._dataset, options_proto) # pylint: disable=protected-access
evaluation_proto = self._model.Evaluate(
ds._dataset, options_proto, weighted=weighted
) # pylint: disable=protected-access
return metric.Evaluation(evaluation_proto)

def analyze_prediction(
Expand Down
2 changes: 1 addition & 1 deletion yggdrasil_decision_forests/port/python/ydf/model/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ void init_model(py::module_& m) {
py::arg("dataset"))
// WARNING: This method releases the Global Interpreter Lock.
.def("Evaluate", WithStatusOr(&GenericCCModel::Evaluate),
py::arg("dataset"), py::arg("options"))
py::arg("dataset"), py::arg("options"), py::arg("weighted"))
// WARNING: This method releases the Global Interpreter Lock.
.def("Analyze", WithStatusOr(&GenericCCModel::Analyze),
py::arg("dataset"), py::arg("options"))
Expand Down
16 changes: 13 additions & 3 deletions yggdrasil_decision_forests/port/python/ydf/model/model_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "absl/types/span.h"
#include "yggdrasil_decision_forests/dataset/types.h"
#include "yggdrasil_decision_forests/dataset/vertical_dataset.h"
#include "yggdrasil_decision_forests/dataset/weight.h"
#include "yggdrasil_decision_forests/metric/metric.pb.h"
#include "yggdrasil_decision_forests/model/abstract_model.h"
#include "yggdrasil_decision_forests/model/describe.h"
Expand Down Expand Up @@ -115,12 +116,21 @@ absl::StatusOr<py::array_t<float>> GenericCCModel::Predict(

absl::StatusOr<metric::proto::EvaluationResults> GenericCCModel::Evaluate(
const dataset::VerticalDataset& dataset,
const metric::proto::EvaluationOptions& options) {
const metric::proto::EvaluationOptions& options, const bool weighted) {
py::gil_scoped_release release;

auto effective_options = options;
if (weighted && model_->weights().has_value()) {
ASSIGN_OR_RETURN(*effective_options.mutable_weights(),
dataset::GetUnlinkedWeightDefinition(
model_->weights().value(), model_->data_spec()));
}

ASSIGN_OR_RETURN(const auto engine, GetEngine());
utils::RandomEngine rnd;
ASSIGN_OR_RETURN(const auto evaluation,
model_->EvaluateWithEngine(*engine, dataset, options, &rnd));
ASSIGN_OR_RETURN(
const auto evaluation,
model_->EvaluateWithEngine(*engine, dataset, effective_options, &rnd));
return evaluation;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class GenericCCModel {

absl::StatusOr<metric::proto::EvaluationResults> Evaluate(
const dataset::VerticalDataset& dataset,
const metric::proto::EvaluationOptions& options);
const metric::proto::EvaluationOptions& options, bool weighted);

absl::StatusOr<utils::model_analysis::proto::StandaloneAnalysisResult>
Analyze(const dataset::VerticalDataset& dataset,
Expand Down
1 change: 1 addition & 0 deletions yggdrasil_decision_forests/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ cc_library_ydf(
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:str_format",
],
)

Expand Down
70 changes: 39 additions & 31 deletions yggdrasil_decision_forests/utils/distribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "yggdrasil_decision_forests/dataset/data_spec.h"
#include "yggdrasil_decision_forests/dataset/data_spec.pb.h"
Expand Down Expand Up @@ -524,62 +525,69 @@ absl::Status IntegersConfusionMatrix<T>::AppendTextReport(
STATUS_CHECK_EQ(column_labels.size(), ncol());
STATUS_CHECK_EQ(row_labels.size(), nrow());

// Minimum margin (expressed in spaces) between displayed elements.
// Margin, in spaces, between items.
const int margin = 2;

// Maximum length of the values' string representation i.e. maximum of
// "row_labels".
int max_row_label_length = 0;
// Maximum string length of the row labels.
int max_length_row_labels = 0;
for (const auto& label : row_labels) {
if (label.size() > max_row_label_length) {
max_row_label_length = label.size();
if (label.size() > max_length_row_labels) {
max_length_row_labels = label.size();
}
}

// Maximum string length of the elements in each column.
// Converts a value (double) in the confusion matrix into a string to display.
const auto value_to_str = [](const double value) {
return absl::StrFormat("%.16g", value);
};

// Maximum string length of items per column.
std::vector<int> max_length_per_col(ncol_);
for (int col = begin_col_idx; col < ncol_; col++) {
// Counts.
T max_value = 1;
// Values
int max_length = 0;
for (int row = begin_row_idx; row < nrow_; row++) {
const auto value = at(row, col);
if (value > max_value) {
max_value = value;
const auto value_length = value_to_str(at(row, col)).size();
if (value_length > max_length) {
max_length = value_length;
}
}
// Column header.
max_length_per_col[col] =
std::max(static_cast<int>(column_labels[col].size()),
static_cast<int>(std::floor(std::log10(max_value))) + 1);
}

// Print "value" to the end of "result" using a left margin (similar to
// "std::setw").
const auto print_string = [&](int length, absl::string_view value) {
const int preceding_spaces =
std::max(length - static_cast<int>(value.size()), 0);
absl::StrAppend(result, std::string(preceding_spaces, ' '), value);
// Column label
max_length =
std::max(max_length, static_cast<int>(column_labels[col].size()));

max_length_per_col[col] = max_length + margin;
}

// Appends "value" to the "result" using a left margin (similar to std::setw).
const auto append_string = [&](int length, absl::string_view value) {
const int num_spaces = std::max(length - static_cast<int>(value.size()), 0);
absl::StrAppend(result, std::string(num_spaces, ' '), value);
};
const auto print_value = [&](int length, T value) {
print_string(length, absl::StrCat(value));
const auto append_value = [&](int length, T value) {
append_string(length, value_to_str(value));
};

// Print header.
print_string(max_row_label_length, "");

// Empty top-left cell
append_string(max_length_row_labels, "");

// Column labels
for (int col = begin_col_idx; col < ncol_; col++) {
print_string(max_length_per_col[col] + margin, column_labels[col]);
append_string(max_length_per_col[col], column_labels[col]);
}
absl::StrAppend(result, "\n");

// Print body.
for (int row = begin_row_idx; row < nrow_; row++) {
print_string(max_row_label_length, row_labels[row]);
append_string(max_length_row_labels, row_labels[row]);
for (int col = begin_col_idx; col < ncol_; col++) {
print_value(max_length_per_col[col] + margin, at(row, col));
append_value(max_length_per_col[col], at(row, col));
}
absl::StrAppend(result, "\n");
}
absl::StrAppend(result, "Total: ", sum_, "\n");
absl::StrAppend(result, "Total: ", value_to_str(sum_), "\n");
return absl::OkStatus();
}

Expand Down
29 changes: 29 additions & 0 deletions yggdrasil_decision_forests/utils/distribution_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,35 @@ Total: 2436
)");
}

// Checks that AppendTextReport renders a confusion matrix holding non-integer
// (double) values that span several orders of magnitude.
TEST(Distribution, IntegersConfusionMatrix_AppendTextReport_with_floats) {
IntegersConfusionMatrixDouble confusion;
confusion.SetSize(4, 4);
// Categorical column spec with four dictionary items ("a", "bb", "ccc",
// "dddd") used to label the rows and columns of the report.
dataset::proto::Column column;
column.mutable_categorical()->set_is_already_integerized(false);
column.mutable_categorical()->set_number_of_unique_values(4);
auto& items = *column.mutable_categorical()->mutable_items();
items["a"].set_index(0);
items["bb"].set_index(1);
items["ccc"].set_index(2);
items["dddd"].set_index(3);
// Cell (row, col) receives 0.00000123456789 * 10^(2 * (row + col)), i.e.
// values ranging from 1.23456789e-06 up to 1234567.89.
for (int col = 0; col < confusion.ncol(); col++) {
for (int row = 0; row < confusion.nrow(); row++) {
double value = 0.00000123456789 * std::pow(10, (col + row) * 2);
confusion.Add(row, col, value);
}
}
// The report is compared verbatim, including the final "Total" line.
std::string representation;
CHECK_OK(confusion.AppendTextReport(column, &representation));
EXPECT_EQ(representation, R"(truth\prediction
a bb ccc dddd
a 1.23456789e-06 0.000123456789 0.0123456789 1.23456789
bb 0.000123456789 0.0123456789 1.23456789 123.456789
ccc 0.0123456789 1.23456789 123.456789 12345.6789
dddd 1.23456789 123.456789 12345.6789 1234567.89
Total: 1259634.593723745
)");
}

TEST(Distribution, IntegersConfusionMatrix_AppendTextReportAlreadyIntegerized) {
IntegersConfusionMatrixDouble confusion;
confusion.SetSize(4, 4);
Expand Down

0 comments on commit 4669c72

Please sign in to comment.