Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 449694529
  • Loading branch information
achoum authored and copybara-github committed May 19, 2022
1 parent f9dcf9b commit c4dd99f
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,9 @@ message DecisionTreeTrainingConfig {
// Score: (p-q)^2/q
// Categorical outcome only.
CHI_SQUARED = 2;

// Conservative estimate (lower bound) of the euclidean distance.
CONSERVATIVE_EUCLIDEAN_DISTANCE = 3;
}

optional SplitScore split_score = 2 [default = KULLBACK_LEIBLER];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ absl::Status GetGenericHyperParameterSpecification(
for (const auto& value :
{kHParamUpliftSplitScoreKL, kHParamUpliftSplitScoreKLAlt,
kHParamUpliftSplitScoreED, kHParamUpliftSplitScoreEDAlt,
kHParamUpliftSplitScoreCS, kHParamUpliftSplitScoreCSAlt}) {
kHParamUpliftSplitScoreCS, kHParamUpliftSplitScoreCSAlt,
kHParamUpliftSplitScoreCED, kHParamUpliftSplitScoreCEDAlt}) {
param->mutable_categorical()->add_possible_values(value);
}

Expand Down Expand Up @@ -685,6 +686,11 @@ absl::Status SetHyperParameters(
value == kHParamUpliftSplitScoreEDAlt) {
dt_config->mutable_uplift()->set_split_score(
proto::DecisionTreeTrainingConfig::Uplift::EUCLIDEAN_DISTANCE);
} else if (value == kHParamUpliftSplitScoreCED ||
value == kHParamUpliftSplitScoreCEDAlt) {
dt_config->mutable_uplift()->set_split_score(
proto::DecisionTreeTrainingConfig::Uplift::
CONSERVATIVE_EUCLIDEAN_DISTANCE);
} else {
return absl::InvalidArgumentError(
absl::StrFormat(R"(Unknown value "%s" for parameter "%s")", value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,11 @@ constexpr char kHParamUpliftMinExamplesInTreatment[] =
constexpr char kHParamUpliftSplitScoreKL[] = "KULLBACK_LEIBLER";
constexpr char kHParamUpliftSplitScoreED[] = "EUCLIDEAN_DISTANCE";
constexpr char kHParamUpliftSplitScoreCS[] = "CHI_SQUARED";
constexpr char kHParamUpliftSplitScoreCED[] = "CONSERVATIVE_EUCLIDEAN_DISTANCE";
constexpr char kHParamUpliftSplitScoreKLAlt[] = "KL";
constexpr char kHParamUpliftSplitScoreEDAlt[] = "ED";
constexpr char kHParamUpliftSplitScoreCSAlt[] = "CS";
constexpr char kHParamUpliftSplitScoreCEDAlt[] = "CED";

constexpr char kHParamHonest[] = "honest";
constexpr char kHParamHonestRatioLeafExamples[] = "honest_ratio_leaf_examples";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,19 +170,61 @@ struct UpliftLabelDistribution {
return response_treatment - response_control;
}

double UpliftSplitScore(const SplitScoreType score) const {
const double response_control = MeanOutcomePerTreatment(0);
const double response_treatment = MeanOutcomePerTreatment(1);
// Returns the lower bound of the 9.7% confidence interval of the uplift.
// Model the outcome as a normal distribution.
double ConservativeUplift() const {
// Only support binary treatment and single dimension outcome.
DCHECK_EQ(sum_weights_per_treatment_.size(), 2);
DCHECK_EQ(num_examples_per_treatment_.size(), 2);
DCHECK_EQ(sum_weights_per_treatment_and_outcome_.size(), 2);

if (response_treatment == 0) {
if (sum_weights_per_treatment_[0] == 0 ||
sum_weights_per_treatment_[1] == 0) {
return 0;
}

const double mean_c = MeanOutcomePerTreatment(0);
const double mean_t = MeanOutcomePerTreatment(1);
const auto var_c = mean_c * (1 - mean_c) / sum_weights_per_treatment_[0];
const auto var_t = mean_t * (1 - mean_t) / sum_weights_per_treatment_[1];
const auto mean_diff = mean_t - mean_c;
const auto var_diff = var_c + var_t;
const auto sd_diff = sqrt(var_diff);
// z-value for a ~9.7% confidence bound. This value was selected to give
// reasonable results on the train/test SimPTE dataset.
const double z = 1.3;

const double lb = mean_diff - z * sd_diff;
const double ub = mean_diff + z * sd_diff;

// Return the most conservative uplift value (i.e. the value closest to
// zero; i.e. with the smaller absolute value) in [lb, ub]. For example, if
// l=-0.1 and ub=0.3, return return 0.
if (lb > 0) {
return lb;
}
if (ub < 0) {
return ub;
}
return 0;
}

double UpliftSplitScore(const SplitScoreType score) const {
switch (score) {
case proto::DecisionTreeTrainingConfig::Uplift::EUCLIDEAN_DISTANCE:
case proto::DecisionTreeTrainingConfig::Uplift::EUCLIDEAN_DISTANCE: {
const double response_control = MeanOutcomePerTreatment(0);
const double response_treatment = MeanOutcomePerTreatment(1);

return (response_control - response_treatment) *
(response_control - response_treatment);
case proto::DecisionTreeTrainingConfig::Uplift::KULLBACK_LEIBLER:
}
case proto::DecisionTreeTrainingConfig::Uplift::KULLBACK_LEIBLER: {
const double response_control = MeanOutcomePerTreatment(0);
const double response_treatment = MeanOutcomePerTreatment(1);
if (response_treatment == 0) {
return 0;
}

if (response_control == 0) {
// The returned divergence should be infinite (or very high). However,
// this would essentially discard all the possible splits. Returning
Expand All @@ -193,7 +235,10 @@ struct UpliftLabelDistribution {
}
return response_treatment *
std::log(response_treatment / response_control);
case proto::DecisionTreeTrainingConfig::Uplift::CHI_SQUARED:
}
case proto::DecisionTreeTrainingConfig::Uplift::CHI_SQUARED: {
const double response_control = MeanOutcomePerTreatment(0);
const double response_treatment = MeanOutcomePerTreatment(1);
if (response_control == 0) {
// The returned divergence should be infinite (or very high). However,
// this would essentially discard all the possible splits. Returning
Expand All @@ -204,6 +249,12 @@ struct UpliftLabelDistribution {
}
return (response_treatment - response_control) *
(response_treatment - response_control) / response_control;
}
case proto::DecisionTreeTrainingConfig::Uplift::
CONSERVATIVE_EUCLIDEAN_DISTANCE: {
const auto u = ConservativeUplift();
return u * u;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,16 @@ TEST_F(RandomForestOnSimPTE, Honest) {
EXPECT_NEAR(metric::Qini(evaluation_), 0.106705, 0.002);
}

TEST_F(RandomForestOnSimPTE, LowerBound) {
auto* rf_config = train_config_.MutableExtension(
random_forest::proto::random_forest_config);
rf_config->mutable_decision_tree()->mutable_uplift()->set_split_score(
decision_tree::proto::DecisionTreeTrainingConfig::Uplift::
CONSERVATIVE_EUCLIDEAN_DISTANCE);
TrainAndEvaluateModel();
EXPECT_NEAR(metric::Qini(evaluation_), 0.10889, 0.002);
}

TEST(SampleTrainingExamples, WithReplacement) {
utils::RandomEngine random;
std::vector<dataset::VerticalDataset::row_t> examples;
Expand Down

0 comments on commit c4dd99f

Please sign in to comment.