Commit

Introduced shap value logging

rvandewater committed Aug 16, 2024
1 parent d0ea1a2 commit 358654c
Showing 3 changed files with 46 additions and 4 deletions.
28 changes: 25 additions & 3 deletions icu_benchmarks/models/ml_models/xgboost.py
@@ -2,6 +2,8 @@
import logging

import gin
import numpy as np

from icu_benchmarks.contants import RunMode
from icu_benchmarks.models.wrappers import MLWrapper
import xgboost as xgb
@@ -10,6 +12,8 @@
import wandb
from statistics import mean
from optuna.integration import XGBoostPruningCallback
import shap

@gin.configurable
class XGBClassifier(MLWrapper):
    _supported_run_modes = [RunMode.classification]
@@ -29,11 +33,29 @@ def fit_model(self, train_data, train_labels, val_data, val_labels):
        if wandb.run is not None:
            callbacks.append(wandb_xgb())
        self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=False)
        logging.info(self.model.get_booster().get_score(importance_type='weight'))

        self.log_dict(self.model.get_booster().get_score(importance_type='weight'))
        # Fit a SHAP explainer on the trained booster and cache training-set SHAP values.
        self.explainer = shap.TreeExplainer(self.model)
        self.train_shap_values = self.explainer(train_data)
        # shap.summary_plot(shap_values, X_test, feature_names=features)
        # logging.info(self.model.get_booster().get_score(importance_type='weight'))
        # self.log_dict(self.model.get_booster().get_score(importance_type='weight'))
        return mean(self.model.evals_result_["validation_0"]["logloss"])  # , callbacks=callbacks)

    def test_step(self, dataset, _):
        test_rep, test_label = dataset
        test_rep, test_label = test_rep.squeeze().cpu().numpy(), test_label.squeeze().cpu().numpy()
        self.set_metrics(test_label)
        test_pred = self.predict(test_rep)
        # Cache test-set SHAP values when an explainer was fitted in fit_model.
        if self.explainer is not None:
            self.test_shap_values = self.explainer(test_rep)
            logging.info(f"Shap values: {self.test_shap_values}")
            # self.log("test/shap_values", self.test_shap_values, sync_dist=True)
        if self.mps:
            self.log("test/loss", np.float32(self.loss(test_label, test_pred)), sync_dist=True)
            self.log_metrics(np.float32(test_label), np.float32(test_pred), "test")
        else:
            self.log("test/loss", self.loss(test_label, test_pred), sync_dist=True)
            self.log_metrics(test_label, test_pred, "test")
        logging.debug(f"Test loss: {self.loss(test_label, test_pred)}")

    def set_model_args(self, model, *args, **kwargs):
        """The XGBoost signature does not include the hyperparameters, so we need to pass them manually."""
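As context for the change above: a minimal, self-contained sketch of the shap.TreeExplainer pattern this commit adopts. The toy data, hyperparameters, and variable names are illustrative, not taken from the repository.

    import numpy as np
    import shap
    import xgboost as xgb

    # Illustrative toy data standing in for the ICU features.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 5))
    y = (X[:, 0] + X[:, 1] > 0).astype(int)

    model = xgb.XGBClassifier(n_estimators=50, max_depth=3)
    model.fit(X, y)

    # TreeExplainer computes SHAP values efficiently for tree ensembles.
    explainer = shap.TreeExplainer(model)
    explanation = explainer(X)  # a shap.Explanation; .values has shape (n_samples, n_features)

    # Mean absolute SHAP value per feature gives a global importance ranking.
    print(np.abs(explanation.values).mean(axis=0))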
18 changes: 18 additions & 0 deletions icu_benchmarks/models/train.py
@@ -1,5 +1,6 @@
import os
import gin
import numpy as np
import torch
import logging
import polars as pl
@@ -192,9 +193,26 @@ def train_common(

    model.set_weight("balanced", train_dataset)
    test_loss = trainer.test(model, dataloaders=test_loader, verbose=verbose)[0]["test/loss"]
    persist_data(trainer, log_dir)
    save_config_file(log_dir)
    return test_loss

def persist_data(trainer, log_dir):
    """Persist cached SHAP values to parquet, one column per trained feature."""
    if trainer.lightning_module.test_shap_values is not None:
        shap_values = trainer.lightning_module.test_shap_values
        # sdf_test = pl.DataFrame({
        #     'features': trainer.lightning_module.trained_columns,
        #     'feature_value': np.transpose(shap_values.values.mean(axis=0)),
        # })
        shaps_test = pl.DataFrame(schema=trainer.lightning_module.trained_columns,
                                  data=np.transpose(shap_values.values))
        shaps_test.write_parquet(log_dir / "shap_values_test.parquet")
        logging.info(f"Saved shap values to {log_dir / 'shap_values_test.parquet'}")
    if trainer.lightning_module.train_shap_values is not None:
        shap_values = trainer.lightning_module.train_shap_values
        shaps_train = pl.DataFrame(schema=trainer.lightning_module.trained_columns,
                                   data=np.transpose(shap_values.values))
        shaps_train.write_parquet(log_dir / "shap_values_train.parquet")

def load_model(model, source_dir, pl_model=True):
    if source_dir.exists():
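Downstream, the parquet files written by persist_data can be read back with polars. Below is a sketch of one assumed use, aggregating per-sample SHAP values into a global feature-importance table; log_dir and the aggregation are illustrative, not part of the commit.

    from pathlib import Path

    import polars as pl

    log_dir = Path("logs/example_run")  # hypothetical log directory

    # One column per trained feature, one row per test sample.
    shaps_test = pl.read_parquet(log_dir / "shap_values_test.parquet")

    # Mean absolute SHAP per feature as a simple global importance summary.
    importance = (
        shaps_test.select(pl.all().abs().mean())
        .transpose(include_header=True, header_name="feature", column_names=["mean_abs_shap"])
        .sort("mean_abs_shap", descending=True)
    )
    print(importance)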
4 changes: 3 additions & 1 deletion icu_benchmarks/models/wrappers.py
@@ -442,7 +442,9 @@ def test_step(self, dataset, _):
        test_rep, test_label = test_rep.squeeze().cpu().numpy(), test_label.squeeze().cpu().numpy()
        self.set_metrics(test_label)
        test_pred = self.predict(test_rep)

        # if self.explainer is not None:
        #     self.test_shap_values = self.explainer(test_rep)
        #     logging.info(f"Shap values: {self.test_shap_values}")
        if self.mps:
            self.log("test/loss", np.float32(self.loss(test_label, test_pred)), sync_dist=True)
            self.log_metrics(np.float32(test_label), np.float32(test_pred), "test")
