diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index acc5ed1f..38e2c473 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,7 +33,7 @@ jobs: # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics # the GitHub editor is 127 chars wide - flake8 . --count --max-complexity=20 --max-line-length=127 --statistics + flake8 . --count --max-complexity=30 --max-line-length=127 --statistics # - name: Test with pytest # run: python -m pytest ./tests/recipes # If we want to test running the tool later on diff --git a/configs/prediction_models/BRFClassifier.gin b/configs/prediction_models/BRFClassifier.gin new file mode 100644 index 00000000..3682bb8c --- /dev/null +++ b/configs/prediction_models/BRFClassifier.gin @@ -0,0 +1,18 @@ +# Settings for ImbLearn Balanced Random Forest Classifier. + +# Common settings for ML models +include "configs/prediction_models/common/MLCommon.gin" + +# Train params +train_common.model = @BRFClassifier + +model/hyperparameter.class_to_tune = @BRFClassifier +model/hyperparameter.n_estimators = [50, 100, 250, 500, 750,1000,1500] +model/hyperparameter.max_depth = [3, 5, 10, 15] +model/hyperparameter.min_samples_split = (2, 5, 10) +model/hyperparameter.min_samples_leaf = (1, 2, 4) +model/hyperparameter.max_features = ['sqrt', 'log2', 1.0] +model/hyperparameter.bootstrap = [True, False] +model/hyperparameter.class_weight = [None, 'balanced'] + + diff --git a/configs/prediction_models/CBClassifier.gin b/configs/prediction_models/CBClassifier.gin new file mode 100644 index 00000000..e9abbecd --- /dev/null +++ b/configs/prediction_models/CBClassifier.gin @@ -0,0 +1,15 @@ +# Settings for Catboost classifier. + +# Common settings for ML models +include "configs/prediction_models/common/MLCommon.gin" + +# Train params +train_common.model = @CBClassifier + +model/hyperparameter.class_to_tune = @CBClassifier +model/hyperparameter.learning_rate = (1e-4, 0.5, "log") +model/hyperparameter.num_trees = [50, 100, 250, 500, 750,1000,1500] +model/hyperparameter.depth = [3, 5, 10, 15] +model/hyperparameter.scale_pos_weight = [1, 5, 10, 25, 50, 75, 99, 100, 1000] +model/hyperparameter.border_count = [5, 10, 20, 50, 100, 200] +model/hyperparameter.l2_leaf_reg = [1, 3, 5, 7, 9] \ No newline at end of file diff --git a/configs/prediction_models/GRU.gin b/configs/prediction_models/GRU.gin index d2a28a79..43cb0218 100644 --- a/configs/prediction_models/GRU.gin +++ b/configs/prediction_models/GRU.gin @@ -9,11 +9,11 @@ train_common.model = @GRUNet # Optimizer params optimizer/hyperparameter.class_to_tune = @Adam optimizer/hyperparameter.weight_decay = 1e-6 -optimizer/hyperparameter.lr = (1e-5, 3e-4) +optimizer/hyperparameter.lr = (1e-6, 1e-4, "log") # Encoder params model/hyperparameter.class_to_tune = @GRUNet model/hyperparameter.num_classes = %NUM_CLASSES -model/hyperparameter.hidden_dim = (32, 256, "log-uniform", 2) -model/hyperparameter.layer_dim = (1, 3) +model/hyperparameter.hidden_dim = (32, 512, "log") +model/hyperparameter.layer_dim = (1, 10) diff --git a/configs/prediction_models/RFClassifier.gin b/configs/prediction_models/RFClassifier.gin index 72d03e66..61d627d6 100644 --- a/configs/prediction_models/RFClassifier.gin +++ b/configs/prediction_models/RFClassifier.gin @@ -8,11 +8,11 @@ train_common.model = @RFClassifier model/hyperparameter.class_to_tune = @RFClassifier model/hyperparameter.n_estimators = (10, 50, 100, 200, 500) -model/hyperparameter.max_depth = (None, 
5, 10, 20)
+model/hyperparameter.max_depth = (5, 10, 20)
 model/hyperparameter.min_samples_split = (2, 5, 10)
 model/hyperparameter.min_samples_leaf = (1, 2, 4)
-model/hyperparameter.max_features = ('sqrt', 'log2', None)
-model/hyperparameter.bootstrap = (True, False)
-model/hyperparameter.class_weight = (None, 'balanced')
+model/hyperparameter.max_features = ['sqrt', 'log2', None]
+model/hyperparameter.bootstrap = [True, False]
+model/hyperparameter.class_weight = [None, 'balanced']
diff --git a/configs/prediction_models/RUSBClassifier.gin b/configs/prediction_models/RUSBClassifier.gin
new file mode 100644
index 00000000..e8f17722
--- /dev/null
+++ b/configs/prediction_models/RUSBClassifier.gin
@@ -0,0 +1,14 @@
+# Settings for ImbLearn RUSBoost Classifier.
+
+# Common settings for ML models
+include "configs/prediction_models/common/MLCommon.gin"
+
+# Train params
+train_common.model = @RUSBClassifier
+
+model/hyperparameter.class_to_tune = @RUSBClassifier
+model/hyperparameter.n_estimators = (10, 50, 100, 200, 500)
+model/hyperparameter.learning_rate = (0.005, 1, "log")
+model/hyperparameter.sampling_strategy = "auto"
+
+
diff --git a/configs/prediction_models/TCN.gin b/configs/prediction_models/TCN.gin
index c6b314db..d1cb748a 100644
--- a/configs/prediction_models/TCN.gin
+++ b/configs/prediction_models/TCN.gin
@@ -9,12 +9,12 @@ train_common.model = @TemporalConvNet
 # Optimizer params
 optimizer/hyperparameter.class_to_tune = @Adam
 optimizer/hyperparameter.weight_decay = 1e-6
-optimizer/hyperparameter.lr = (1e-5, 3e-4)
+optimizer/hyperparameter.lr = (1e-6, 3e-4)

 # Encoder params
 model/hyperparameter.class_to_tune = @TemporalConvNet
 model/hyperparameter.num_classes = %NUM_CLASSES
 model/hyperparameter.max_seq_length = %HORIZON
-model/hyperparameter.num_channels = (32, 256, "log-uniform", 2)
-model/hyperparameter.kernel_size = (2, 32, "log-uniform", 2)
+model/hyperparameter.num_channels = (32, 256, "log")
+model/hyperparameter.kernel_size = (2, 128, "log")
 model/hyperparameter.dropout = (0.0, 0.4)
diff --git a/configs/prediction_models/Transformer.gin b/configs/prediction_models/Transformer.gin
index 2767fd37..69f31e51 100644
--- a/configs/prediction_models/Transformer.gin
+++ b/configs/prediction_models/Transformer.gin
@@ -8,17 +8,17 @@ train_common.model = @Transformer
 optimizer/hyperparameter.class_to_tune = @Adam
 optimizer/hyperparameter.weight_decay = 1e-6
-optimizer/hyperparameter.lr = (1e-5, 3e-4)
+optimizer/hyperparameter.lr = (1e-6, 1e-4)

 # Encoder params
 model/hyperparameter.class_to_tune = @Transformer
-model/hyperparameter.ff_hidden_mult = 2
-model/hyperparameter.l1_reg = 0.0
+model/hyperparameter.ff_hidden_mult = (2,4,6,8)
+model/hyperparameter.l1_reg = (0.0,1.0)
 model/hyperparameter.num_classes = %NUM_CLASSES
-model/hyperparameter.hidden = (32, 256, "log-uniform", 2)
-model/hyperparameter.heads = (1, 8, "log-uniform", 2)
+model/hyperparameter.hidden = (32, 512, "log")
+model/hyperparameter.heads = (1, 8, "log")
 model/hyperparameter.depth = (1, 3)
-model/hyperparameter.dropout = (0.0, 0.4)
-model/hyperparameter.dropout_att = (0.0, 0.4)
+model/hyperparameter.dropout = 0 # no improvement (0.0, 0.4)
+model/hyperparameter.dropout_att = (0.0, 1.0)
diff --git a/configs/prediction_models/XGBClassifier.gin b/configs/prediction_models/XGBClassifier.gin
new file mode 100644
index 00000000..f1070672
--- /dev/null
+++ b/configs/prediction_models/XGBClassifier.gin
@@ -0,0 +1,17 @@
+# Settings for XGBoost classifier.
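+# Note: scale_pos_weight (tuned below) counteracts class imbalance; a common heuristic is sum(negative cases) / sum(positive cases).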
+ +# Common settings for ML models +include "configs/prediction_models/common/MLCommon.gin" + +# Train params +train_common.model = @XGBClassifier + +model/hyperparameter.class_to_tune = @XGBClassifier +model/hyperparameter.learning_rate = (0.01, 0.1, "log") +model/hyperparameter.n_estimators = [50, 100, 250, 500, 750, 1000,1500,2000] +model/hyperparameter.max_depth = [3, 5, 10, 15] +model/hyperparameter.scale_pos_weight = [1, 5, 10, 15, 20, 25, 30, 35, 40, 50, 75, 99, 100, 1000] +model/hyperparameter.min_child_weight = [1, 0.5] +model/hyperparameter.max_delta_step = [0, 1, 2, 3, 4, 5, 10] +model/hyperparameter.colsample_bytree = [0.1, 0.25, 0.5, 0.75, 1.0] +model/hyperparameter.eval_metric = "aucpr" \ No newline at end of file diff --git a/configs/prediction_models/common/DLCommon.gin b/configs/prediction_models/common/DLCommon.gin index c220e6ab..9d790775 100644 --- a/configs/prediction_models/common/DLCommon.gin +++ b/configs/prediction_models/common/DLCommon.gin @@ -3,7 +3,9 @@ # Imports to register the models import gin.torch.external_configurables import icu_benchmarks.models.wrappers -import icu_benchmarks.models.dl_models +import icu_benchmarks.models.dl_models.rnn +import icu_benchmarks.models.dl_models.transformer +import icu_benchmarks.models.dl_models.tcn import icu_benchmarks.models.utils # Do not generate features from dynamic data @@ -12,7 +14,7 @@ base_regression_preprocessor.generate_features = False # Train params train_common.optimizer = @Adam -train_common.epochs = 1000 +train_common.epochs = 50 train_common.batch_size = 64 train_common.patience = 10 train_common.min_delta = 1e-4 diff --git a/configs/prediction_models/common/DLTuning.gin b/configs/prediction_models/common/DLTuning.gin index b4d13e12..0d71c2f8 100644 --- a/configs/prediction_models/common/DLTuning.gin +++ b/configs/prediction_models/common/DLTuning.gin @@ -2,4 +2,4 @@ tune_hyperparameters.scopes = ["model", "optimizer"] tune_hyperparameters.n_initial_points = 5 tune_hyperparameters.n_calls = 30 -tune_hyperparameters.folds_to_tune_on = 2 \ No newline at end of file +tune_hyperparameters.folds_to_tune_on = 5 \ No newline at end of file diff --git a/configs/prediction_models/common/MLCommon.gin b/configs/prediction_models/common/MLCommon.gin index 460bceba..4d26b8c7 100644 --- a/configs/prediction_models/common/MLCommon.gin +++ b/configs/prediction_models/common/MLCommon.gin @@ -3,7 +3,11 @@ # Imports to register the models import gin.torch.external_configurables import icu_benchmarks.models.wrappers -import icu_benchmarks.models.ml_models +import icu_benchmarks.models.ml_models.sklearn +import icu_benchmarks.models.ml_models.lgbm +import icu_benchmarks.models.ml_models.xgboost +import icu_benchmarks.models.ml_models.imblearn +import icu_benchmarks.models.ml_models.catboost import icu_benchmarks.models.utils # Patience for early stopping diff --git a/configs/prediction_models/common/MLTuning.gin b/configs/prediction_models/common/MLTuning.gin index 92fa0f0c..9df38c47 100644 --- a/configs/prediction_models/common/MLTuning.gin +++ b/configs/prediction_models/common/MLTuning.gin @@ -1,5 +1,5 @@ # Hyperparameter tuner settings for classical Machine Learning. 
tune_hyperparameters.scopes = ["model"] -tune_hyperparameters.n_initial_points = 10 -tune_hyperparameters.n_calls = 3 -tune_hyperparameters.folds_to_tune_on = 1 \ No newline at end of file +tune_hyperparameters.n_initial_points = 5 +tune_hyperparameters.n_calls = 30 +tune_hyperparameters.folds_to_tune_on = 5 \ No newline at end of file diff --git a/configs/tasks/BinaryClassification.gin b/configs/tasks/BinaryClassification.gin index 492a12eb..f86436a4 100644 --- a/configs/tasks/BinaryClassification.gin +++ b/configs/tasks/BinaryClassification.gin @@ -19,11 +19,10 @@ DLPredictionWrapper.loss = @cross_entropy # SELECTING PREPROCESSOR preprocess.preprocessor = @base_classification_preprocessor +preprocess.modality_mapping = %modality_mapping preprocess.vars = %vars preprocess.use_static = True # SELECTING DATASET -PredictionDataset.vars = %vars -PredictionDataset.ram_cache = True - +include "configs/tasks/common/Dataloader.gin" diff --git a/configs/tasks/DatasetImputation.gin b/configs/tasks/DatasetImputation.gin index ddbd56a2..55914adc 100644 --- a/configs/tasks/DatasetImputation.gin +++ b/configs/tasks/DatasetImputation.gin @@ -22,6 +22,6 @@ preprocess.file_names = { preprocess.preprocessor = @base_imputation_preprocessor preprocess.vars = %vars -ImputationDataset.vars = %vars -ImputationDataset.ram_cache = True + +include "configs/tasks/common/Dataloader.gin" diff --git a/configs/tasks/Regression.gin b/configs/tasks/Regression.gin index 5cf3f8d9..c2c54174 100644 --- a/configs/tasks/Regression.gin +++ b/configs/tasks/Regression.gin @@ -28,6 +28,5 @@ base_regression_preprocessor.outcome_min = 0 base_regression_preprocessor.outcome_max = 15 # SELECTING DATASET -PredictionDataset.vars = %vars -PredictionDataset.ram_cache = True +include "configs/tasks/common/Dataloader.gin" diff --git a/configs/tasks/common/Dataloader.gin b/configs/tasks/common/Dataloader.gin new file mode 100644 index 00000000..6bed1b7e --- /dev/null +++ b/configs/tasks/common/Dataloader.gin @@ -0,0 +1,8 @@ +# Prediction +PredictionPandasDataset.vars = %vars +PredictionPandasDataset.ram_cache = True +PredictionPolarsDataset.vars = %vars +PredictionPolarsDataset.ram_cache = True +# Imputation +ImputationPandasDataset.vars = %vars +ImputationPandasDataset.ram_cache = True \ No newline at end of file diff --git a/configs/tasks/common/PredictionTaskVariables.gin b/configs/tasks/common/PredictionTaskVariables.gin index 6e38638e..d5006041 100644 --- a/configs/tasks/common/PredictionTaskVariables.gin +++ b/configs/tasks/common/PredictionTaskVariables.gin @@ -15,4 +15,12 @@ vars = { "methb", "mg", "na", "neut", "o2sat", "pco2", "ph", "phos", "plt", "po2", "ptt", "resp", "sbp", "temp", "tnt", "urine", "wbc"], "STATIC": ["age", "sex", "height", "weight"], +} + +modality_mapping = { + "DYNAMIC": ["alb", "alp", "alt", "ast", "be", "bicar", "bili", "bili_dir", "bnd", "bun", "ca", "cai", "ck", "ckmb", "cl", + "crea", "crp", "dbp", "fgn", "fio2", "glu", "hgb", "hr", "inr_pt", "k", "lact", "lymph", "map", "mch", "mchc", "mcv", + "methb", "mg", "na", "neut", "o2sat", "pco2", "ph", "phos", "plt", "po2", "ptt", "resp", "sbp", "temp", "tnt", "urine", + "wbc"], + "STATIC": ["age", "sex", "height", "weight"], } \ No newline at end of file diff --git a/docs/adding_model/RNN.gin b/docs/adding_model/RNN.gin new file mode 100644 index 00000000..531aeff6 --- /dev/null +++ b/docs/adding_model/RNN.gin @@ -0,0 +1,18 @@ +# Settings for Recurrent Neural Network (RNN) models. 
+
+# Common settings for DL models
+include "configs/prediction_models/common/DLCommon.gin"
+
+# Train params
+train_common.model = @RNNet
+
+# Optimizer params
+optimizer/hyperparameter.class_to_tune = @Adam
+optimizer/hyperparameter.weight_decay = 1e-6
+optimizer/hyperparameter.lr = (1e-5, 3e-4)
+
+# Encoder params
+model/hyperparameter.class_to_tune = @RNNet
+model/hyperparameter.num_classes = %NUM_CLASSES
+model/hyperparameter.hidden_dim = (32, 256, "log-uniform", 2)
+model/hyperparameter.layer_dim = (1, 3)
diff --git a/docs/adding_model/instructions.md b/docs/adding_model/instructions.md
new file mode 100644
index 00000000..0896fedd
--- /dev/null
+++ b/docs/adding_model/instructions.md
@@ -0,0 +1,190 @@
+# Adding new models to YAIB
+## Example
+We refer to the page [adding a new model](https://github.com/rvandewater/YAIB/wiki/Adding-a-new-model) for detailed instructions on adding new models.
+We allow prediction models to be easily added and integrated into a Pytorch Lightning module. This
+incorporates advanced logging and debugging capabilities, as well as
+built-in parallelism. Our interface derives from the [`BaseModule`](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html).
+
+Adding a model consists of three steps:
+1. Add a model through the existing `MLPredictionWrapper` or `DLPredictionWrapper`.
+2. Add a GIN config file to bind hyperparameters.
+3. Execute YAIB using a simple command.
+
+This folder contains everything you need to add a model to YAIB.
+Putting the `RNN.gin` file in `configs/prediction_models` and the `rnn.py` file into `icu_benchmarks/models` allows you to fully run the model.
+
+```
+icu-benchmarks train \
+    -d demo_data/mortality24/mimic_demo \ # Insert cohort dataset here
+    -n mimic_demo \
+    -t BinaryClassification \ # Insert task name here
+    -tn Mortality24 \
+    --log-dir ../yaib_logs/ \
+    -m RNN \ # Insert model here
+    -s 2222 \
+    -l ../yaib_logs/ \
+    --tune
+```
+# Adding more models
+## Regular ML
+For standard Scikit-Learn type models (e.g., LGBM), one can
+simply wrap the model class with the `MLPredictionWrapper`, with minimal code
+overhead. Many ML (and some DL) models can be incorporated this way, requiring minimal code additions. See below.
+
+``` {#code:ml-model-definition frame="single" style="pycharm" caption="\\textit{Example ML model definition}" label="code:ml-model-definition" columns="fullflexible"}
+@gin.configurable
+class RFClassifier(MLWrapper):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.model = self.model_args()
+
+    @gin.configurable(module="RFClassifier")
+    def model_args(self, *args, **kwargs):
+        return RandomForestClassifier(*args, **kwargs)
+```
+## Adding DL models
+It is relatively straightforward to add new Pytorch models to YAIB. We first provide a standard RNN model, which needs no extra components. Then, we show the implementation of the Temporal Fusion Transformer model.
+
+### Standard RNN-model
+A DL model is defined by creating a subclass of the `DLPredictionWrapper`,
+which inherits the standard methods needed for training DL models.
+Pytorch Lightning significantly reduces the code overhead.
+ + +``` {#code:dl-model-definition frame="single" style="pycharm" caption="\\textit{Example DL model definition}" label="code:dl-model-definition" columns="fullflexible"} +@gin.configurable +class RNNet(DLPredictionWrapper): + """Torch standard RNN model""" + + def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs + ) + self.hidden_dim = hidden_dim + self.layer_dim = layer_dim + self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True) + self.logit = nn.Linear(hidden_dim, num_classes) + + def init_hidden(self, x): + h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + return h0 + + def forward(self, x): + h0 = self.init_hidden(x) + out, hn = self.rnn(x, h0) + pred = self.logit(out) + return pred +``` +### Adding a SOTA model: Temporal Fusion Transformer +There are two main questions when you want to add a more complex model: + +* _Do you want to manually define the model or use an existing library?_ This might require adapting the `DLPredictionWrapper`. +* _Does the model expect the data to be in a certain format?_ This might require adapting the `PredictionDataset`. + +By adapting, we mean creating a new subclass that inherits most functionality to avoid code duplication, is future-proof, and follows good coding practices. + +First, you can add modules to `models/layers.py` to use them for your model. +``` {#code:building blocks frame="single" style="pycharm" caption="\\textit{Example building block}" label="code: layers" columns="fullflexible"} +class StaticCovariateEncoder(nn.Module): + """ + Network to produce 4 context vectors to enrich static variables + Variable selection Network --> GRNs + """ + + def __init__(self, num_static_vars, hidden, dropout): + super().__init__() + self.vsn = VariableSelectionNetwork(hidden, dropout, num_static_vars) + self.context_grns = nn.ModuleList([GRN(hidden, hidden, dropout=dropout) for _ in range(4)]) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + variable_ctx, sparse_weights = self.vsn(x) + + # Context vectors: + # variable selection context + # enrichment context + # state_c context + # state_h context + cs, ce, ch, cc = [m(variable_ctx) for m in self.context_grns] + + return cs, ce, ch, cc +``` +Note that we can create modules out of modules as well. + +### Adapting the `DLPredictionWrapper` +The next step is to use the building blocks defined in layers.py or modules from an existing library to add to the model in `models/dl_models.py`. 
In this case, we use the Pytorch-forecasting library (https://github.com/jdb78/pytorch-forecasting):
+
+``` {#code:dl-model-definition frame="single" style="pycharm" caption="\\textit{Example DL model definition}" label="code:dl-model-definition" columns="fullflexible"}
+class TFTpytorch(DLPredictionWrapper):
+
+    _supported_run_modes = [RunMode.classification, RunMode.regression]
+
+    def __init__(self, dataset, hidden, dropout, n_heads, dropout_att, lr, optimizer, num_classes, *args, **kwargs):
+        super().__init__(lr=lr, optimizer=optimizer, *args, **kwargs)
+        self.model = TemporalFusionTransformer.from_dataset(dataset=dataset)
+        self.logit = nn.Linear(7, num_classes)
+
+    def forward(self, x):
+        out = self.model(x)
+        pred = self.logit(out["prediction"])
+        return pred
+```
+
+### Adapting the `PredictionDataset`
+Some models require an adjusted dataloader, for example to facilitate explainability methods. In this case, changes need to be made to the `data/loader.py` file to ensure the data loader returns the data in the correct format.
+This can be done by creating a class that inherits from `PredictionDataset` and overriding its `__getitem__` method.
+``` {#code:dataset frame="single" style="pycharm" caption="\\textit{Example custom dataset definition}" label="code: dataset" columns="fullflexible"}
+@gin.configurable("PredictionDatasetTFT")
+class PredictionDatasetTFT(PredictionDataset):
+    def __init__(self, *args, ram_cache: bool = True, **kwargs):
+        super().__init__(*args, ram_cache=True, **kwargs)
+
+    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
+        """Function to sample from the data split of choice. Used for TFT.
+        The data needs to be given to the model in the following order:
+        [static categorical, static continuous, known categorical, known continuous, observed categorical, observed continuous, target, id]
+```
+Then, you must check `models/wrappers.py`, particularly the `step_fn` method, to ensure the data is correctly transferred to the device.
+
+## Adding the model config GIN file
+To define hyperparameters for each model in a standardized manner, we use GIN-config. We need to specify a GIN file that binds the parameters used to train and optimize this model from a choice of hyperparameters. Note that we can use modifiers for the optimizer (e.g., the Adam optimizer) and ranges that we specify in round brackets "()". Square brackets, "[]", result in a random choice, with the variable sampled uniformly from the listed values.
+```
+# Hyperparameters for TFT model.
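+# Note on tuning syntax: round brackets such as (1e-5, 3e-4) define a range to sample from
+# (append "log" for log-scale sampling), while square brackets such as [3, 5, 10, 15] define
+# a discrete set from which a value is chosen uniformly at random.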
+ +# Common settings for DL models +include "configs/prediction_models/common/DLCommon.gin" + +# Optimizer params +train_common.model = @TFT + +optimizer/hyperparameter.class_to_tune = @Adam +optimizer/hyperparameter.weight_decay = 1e-6 +optimizer/hyperparameter.lr = (1e-5, 3e-4) + +# Encoder params +model/hyperparameter.class_to_tune = @TFT +model/hyperparameter.encoder_length = 24 +model/hyperparameter.hidden = 256 +model/hyperparameter.num_classes = %NUM_CLASSES +model/hyperparameter.dropout = (0.0, 0.4) +model/hyperparameter.dropout_att = (0.0, 0.4) +model/hyperparameter.n_heads =4 +model/hyperparameter.example_length=25 +``` +## Training the model +After these steps, your model should be trainable with the following command: + +``` +icu-benchmarks train \ + -d demo_data/mortality24/mimic_demo \ # Insert cohort dataset here + -n mimic_demo \ + -t BinaryClassification \ # Insert task name here + -tn Mortality24 \ + --log-dir ../yaib_logs/ \ + -m TFT \ # Insert model here + -s 2222 \ + -l ../yaib_logs/ \ + --tune +``` diff --git a/docs/adding_model/rnn.py b/docs/adding_model/rnn.py new file mode 100644 index 00000000..d2215627 --- /dev/null +++ b/docs/adding_model/rnn.py @@ -0,0 +1,30 @@ +import gin +import torch.nn as nn +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.wrappers import DLPredictionWrapper + + +@gin.configurable +class RNNet(DLPredictionWrapper): + """Torch standard RNN model""" + + _supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs + ) + self.hidden_dim = hidden_dim + self.layer_dim = layer_dim + self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True) + self.logit = nn.Linear(hidden_dim, num_classes) + + def init_hidden(self, x): + h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + return h0 + + def forward(self, x): + h0 = self.init_hidden(x) + out, hn = self.rnn(x, h0) + pred = self.logit(out) + return pred diff --git a/environment.yml b/environment.yml index abba2126..405d9b47 100644 --- a/environment.yml +++ b/environment.yml @@ -3,7 +3,7 @@ channels: - conda-forge dependencies: - python=3.10 - - pip=24.0 + - pip>=24.0 - flake8=7.1.0 # - pip: # - -r requirements.txt diff --git a/icu_benchmarks/contants.py b/icu_benchmarks/constants.py similarity index 100% rename from icu_benchmarks/contants.py rename to icu_benchmarks/constants.py diff --git a/icu_benchmarks/cross_validation.py b/icu_benchmarks/cross_validation.py index 95e44e1a..e1563f9e 100644 --- a/icu_benchmarks/cross_validation.py +++ b/icu_benchmarks/cross_validation.py @@ -11,7 +11,7 @@ from icu_benchmarks.models.train import train_common from icu_benchmarks.models.utils import JsonResultLoggingEncoder from icu_benchmarks.run_utils import log_full_line -from icu_benchmarks.contants import RunMode +from icu_benchmarks.constants import RunMode @gin.configurable @@ -37,7 +37,7 @@ def execute_repeated_cv( cpu: bool = False, verbose: bool = False, wandb: bool = False, - complete_train: bool = False + complete_train: bool = False, ) -> float: """Preprocesses data and trains a model for each fold. 
@@ -81,8 +81,9 @@ def execute_repeated_cv( else: logging.info(f"Starting nested CV with {cv_repetitions_to_train} repetitions of {cv_folds_to_train} folds.") - + # Train model for each repetition (a manner of splitting the folds) for repetition in range(cv_repetitions_to_train): + # Train model for each fold configuration (i.e, one fold is test fold and the rest are train/val folds) for fold_index in range(cv_folds_to_train): repetition_fold_dir = log_dir / f"repetition_{repetition}" / f"fold_{fold_index}" repetition_fold_dir.mkdir(parents=True, exist_ok=True) @@ -101,9 +102,8 @@ def execute_repeated_cv( fold_index=fold_index, pretrained_imputation_model=pretrained_imputation_model, runmode=mode, - complete_train=complete_train + complete_train=complete_train, ) - preprocess_time = datetime.now() - start_time start_time = datetime.now() agg_loss += train_common( @@ -118,7 +118,7 @@ def execute_repeated_cv( cpu=cpu, verbose=verbose, use_wandb=wandb, - train_only=complete_train + train_only=complete_train, ) train_time = datetime.now() - start_time @@ -133,7 +133,10 @@ def execute_repeated_cv( if wandb: wandb_log({"Iteration": repetition * cv_folds_to_train + fold_index}) if repetition * cv_folds_to_train + fold_index > 1: - aggregate_results(log_dir) + try: + aggregate_results(log_dir) + except Exception as e: + logging.error(f"Failed to aggregate results: {e}") log_full_line(f"FINISHED CV REPETITION {repetition}", level=logging.INFO, char="=", num_newlines=3) return agg_loss / (cv_repetitions_to_train * cv_folds_to_train) diff --git a/icu_benchmarks/data/loader.py b/icu_benchmarks/data/loader.py index 3c7a9280..b227dbf2 100644 --- a/icu_benchmarks/data/loader.py +++ b/icu_benchmarks/data/loader.py @@ -1,3 +1,4 @@ +import warnings from typing import List from pandas import DataFrame import gin @@ -6,13 +7,183 @@ from torch.utils.data import Dataset import logging from typing import Dict, Tuple - +import polars as pl from icu_benchmarks.imputation.amputations import ampute_data from .constants import DataSegment as Segment from .constants import DataSplit as Split -class CommonDataset(Dataset): +@gin.configurable("CommonPolarsDataset") +class CommonPolarsDataset(Dataset): + def __init__( + self, + data: dict, + split: str = Split.train, + vars: Dict[str, str] = gin.REQUIRED, + grouping_segment: str = Segment.outcome, + mps: bool = False, + name: str = "", + *args, + **kwargs, + ): + # super().__init__(*args, **kwargs) + self.split = split + self.vars = vars + self.grouping_df = data[split][grouping_segment] # .set_index(self.vars["GROUP"]) + # logging.info(f"data split: {data[split]}") + # self.features_df = ( + # data[split][Segment.features].set_index(self.vars["GROUP"]).drop(labels=self.vars["SEQUENCE"], axis=1) + # ) + # Get the row indicators for the data to be able to match predicted labels + if "SEQUENCE" in self.vars and self.vars["SEQUENCE"] in data[split][Segment.features].columns: + # We have a time series dataset + self.row_indicators = data[split][Segment.features][self.vars["GROUP"], self.vars["SEQUENCE"]] + self.row_indicators = self.row_indicators.with_columns(pl.col(self.vars["SEQUENCE"]).dt.total_hours()) + self.features_df = data[split][Segment.features] + self.features_df = self.features_df.sort([self.vars["GROUP"], self.vars["SEQUENCE"]]) + self.features_df = self.features_df.drop(self.vars["SEQUENCE"]) + else: + # We have a static dataset + logging.info("Using static dataset") + self.row_indicators = data[split][Segment.features][self.vars["GROUP"]] + self.features_df = 
data[split][Segment.features] + # calculate basic info for the data + self.num_stays = self.grouping_df[self.vars["GROUP"]].unique().shape[0] + self.maxlen = self.features_df.group_by([self.vars["GROUP"]]).len().max().item(0, 1) + self.mps = mps + self.name = name + + def ram_cache(self, cache: bool = True): + self._cached_dataset = None + if cache: + logging.info(f"Caching {self.split} dataset in ram.") + self._cached_dataset = [self[i] for i in range(len(self))] + + def __len__(self) -> int: + """Returns number of stays in the data. + + Returns: + number of stays in the data + """ + return self.num_stays + + def get_feature_names(self) -> List[str]: + return self.features_df.columns + + def to_tensor(self) -> List[Tensor]: + values = [] + for entry in self: + for i, value in enumerate(entry): + if len(values) <= i: + values.append([]) + values[i].append(value.unsqueeze(0)) + return [cat(value, dim=0) for value in values] + + +@gin.configurable("PredictionPolarsDataset") +class PredictionPolarsDataset(CommonPolarsDataset): + """Subclass of common dataset for prediction tasks. + + Args: + ram_cache (bool, optional): Whether the complete dataset should be stored in ram. Defaults to True. + """ + + def __init__(self, *args, ram_cache: bool = True, **kwargs): + super().__init__(*args, **kwargs) + self.outcome_df = self.grouping_df + self.ram_cache(ram_cache) + + def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]: + """Function to sample from the data split of choice. Used for deep learning implementations. + + Args: + idx: A specific row index to sample. + + Returns: + A sample from the data, consisting of data, labels and padding mask. + """ + if self._cached_dataset is not None: + return self._cached_dataset[idx] + + pad_value = 0.0 + # stay_id = self.outcome_df.index.unique()[idx] # [self.vars["GROUP"]] + stay_id = self.outcome_df[self.vars["GROUP"]].unique()[idx] # [self.vars["GROUP"]] + + # slice to make sure to always return a DF + # window = self.features_df.loc[stay_id:stay_id].to_numpy() + # labels = self.outcome_df.loc[stay_id:stay_id][self.vars["LABEL"]].to_numpy(dtype=float) + window = self.features_df.filter(pl.col(self.vars["GROUP"]) == stay_id).to_numpy() + labels = self.outcome_df.filter(pl.col(self.vars["GROUP"]) == stay_id)[self.vars["LABEL"]].to_numpy().astype(float) + + if len(labels) == 1: + # only one label per stay, align with window + labels = np.concatenate([np.empty(window.shape[0] - 1) * np.nan, labels], axis=0) + + length_diff = self.maxlen - window.shape[0] + pad_mask = np.ones(window.shape[0]) + + # Padding the array to fulfill size requirement + if length_diff > 0: + # window shorter than the longest window in dataset, pad to same length + window = np.concatenate([window, np.ones((length_diff, window.shape[1])) * pad_value], axis=0) + labels = np.concatenate([labels, np.ones(length_diff) * pad_value], axis=0) + pad_mask = np.concatenate([pad_mask, np.zeros(length_diff)], axis=0) + + not_labeled = np.argwhere(np.isnan(labels)) + if len(not_labeled) > 0: + labels[not_labeled] = -1 + pad_mask[not_labeled] = 0 + + pad_mask = pad_mask.astype(bool) + labels = labels.astype(np.float32) + data = window.astype(np.float32) + + return from_numpy(data), from_numpy(labels), from_numpy(pad_mask) + + def get_balance(self) -> list: + """Return the weight balance for the split of interest. + + Returns: + Weights for each label. 
+ """ + counts = self.outcome_df[self.vars["LABEL"]].value_counts(parallel=True).get_columns()[1] + counts = counts.to_numpy() + weights = list((1 / counts) * np.sum(counts) / counts.shape[0]) + return weights + + def get_data_and_labels(self) -> Tuple[np.array, np.array, np.array]: + """Function to return all the data and labels aligned at once. + + We use this function for the ML methods which don't require an iterator. + + Returns: + A Tuple containing data points and label for the split. + """ + labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(float) + rep = self.features_df + + if len(labels) == self.num_stays: + # order of groups could be random, we make sure not to change it + # rep = rep.groupby(level=self.vars["GROUP"], sort=False).last() + rep = rep.group_by(self.vars["GROUP"]).last() + else: + # Adding segment count for each stay id and timestep. + rep = rep.with_columns(pl.col(self.vars["GROUP"]).cum_count().over(self.vars["GROUP"]).alias("counter")) + rep = rep.to_numpy().astype(float) + logging.debug(f"rep shape: {rep.shape}") + logging.debug(f"labels shape: {labels.shape}") + return rep, labels, self.row_indicators.to_numpy() + + def to_tensor(self) -> Tuple[Tensor, Tensor, Tensor]: + data, labels, row_indicators = self.get_data_and_labels() + if self.mps: + return from_numpy(data).to(float32), from_numpy(labels).to(float32) + else: + return from_numpy(data), from_numpy(labels), row_indicators + + +@gin.configurable("CommonPandasDataset") +class CommonPandasDataset(Dataset): """Common dataset: subclass of Torch Dataset that represents the data to learn on. Args: data: Dict of the different splits of the data. split: Either 'train','val' or 'test'. vars: Contains the names of @@ -29,9 +200,11 @@ def __init__( mps: bool = False, name: str = "", ): + warnings.warn("CommonPandasDataset is deprecated. Use CommonPolarsDataset instead.", DeprecationWarning, stacklevel=2) self.split = split self.vars = vars self.grouping_df = data[split][grouping_segment].set_index(self.vars["GROUP"]) + # logging.info(f"data split: {data[split]}") self.features_df = ( data[split][Segment.features].set_index(self.vars["GROUP"]).drop(labels=self.vars["SEQUENCE"], axis=1) ) @@ -56,10 +229,10 @@ def __len__(self) -> int: """ return self.num_stays - def get_feature_names(self): + def get_feature_names(self) -> List[str]: return self.features_df.columns - def to_tensor(self): + def to_tensor(self) -> List[Tensor]: values = [] for entry in self: for i, value in enumerate(entry): @@ -69,8 +242,8 @@ def to_tensor(self): return [cat(value, dim=0) for value in values] -@gin.configurable("PredictionDataset") -class PredictionDataset(CommonDataset): +@gin.configurable("PredictionPandasDataset") +class PredictionPandasDataset(CommonPandasDataset): """Subclass of common dataset for prediction tasks. Args: @@ -133,6 +306,7 @@ def get_balance(self) -> list: Weights for each label. 
""" counts = self.outcome_df[self.vars["LABEL"]].value_counts() + # weights = list((1 / counts) * np.sum(counts) / counts.shape[0]) return list((1 / counts) * np.sum(counts) / counts.shape[0]) def get_data_and_labels(self) -> Tuple[np.array, np.array]: @@ -160,8 +334,8 @@ def to_tensor(self): return from_numpy(data), from_numpy(labels) -@gin.configurable("ImputationDataset") -class ImputationDataset(CommonDataset): +@gin.configurable("ImputationPandasDataset") +class ImputationPandasDataset(CommonPandasDataset): """Subclass of Common Dataset that contains data for imputation models.""" def __init__( diff --git a/icu_benchmarks/data/pooling.py b/icu_benchmarks/data/pooling.py index 7eeb2949..e8ee8ba6 100644 --- a/icu_benchmarks/data/pooling.py +++ b/icu_benchmarks/data/pooling.py @@ -3,7 +3,7 @@ import pandas as pd from sklearn.model_selection import train_test_split from .constants import DataSegment as Segment, VarType as Var -from icu_benchmarks.contants import RunMode +from icu_benchmarks.constants import RunMode import pyarrow.parquet as pq @@ -16,16 +16,17 @@ class PooledDataset: class PooledData: - def __init__(self, - data_dir, - vars, - datasets, - file_names, - shuffle=False, - stratify=None, - runmode=RunMode.classification, - save_test=True, - ): + def __init__( + self, + data_dir, + vars, + datasets, + file_names, + shuffle=False, + stratify=None, + runmode=RunMode.classification, + save_test=True, + ): """ Generate pooled data from existing datasets. Args: @@ -48,10 +49,10 @@ def __init__(self, self.save_test = save_test def generate( - self, - datasets, - samples=10000, - seed=42, + self, + datasets, + samples=10000, + seed=42, ): """ Generate pooled data from existing datasets. @@ -65,8 +66,7 @@ def generate( if folder.is_dir(): if folder.name in datasets: data[folder.name] = { - f: pq.read_table(folder / self.file_names[f]).to_pandas(self_destruct=True) for f in - self.file_names.keys() + f: pq.read_table(folder / self.file_names[f]).to_pandas(self_destruct=True) for f in self.file_names } data = self._pool_datasets( datasets=data, @@ -101,15 +101,15 @@ def _save_pooled_data(self, data_dir, data, datasets, file_names, samples=10000) logging.info(f"Saved pooled data at {save_dir}") def _pool_datasets( - self, - datasets={}, - samples=10000, - vars=[], - seed=42, - shuffle=True, - runmode=RunMode.classification, - data_dir=Path("data"), - save_test=True, + self, + datasets=None, + samples=10000, + vars=None, + seed=42, + shuffle=True, + runmode=RunMode.classification, + data_dir=Path("data"), + save_test=True, ): """ Pool datasets into a single dataset. 
@@ -125,6 +125,10 @@ def _pool_datasets( Returns: pooled dataset """ + if datasets is None: + datasets = {} + if vars is None: + vars = [] if len(datasets) == 0: raise ValueError("No datasets supplied.") pooled_data = {Segment.static: [], Segment.dynamic: [], Segment.outcome: []} @@ -144,8 +148,9 @@ def _pool_datasets( # If we have more outcomes than stays, check max label value per stay id labels = outcome.groupby(id).max()[vars[Var.label]].reset_index(drop=True) # if pd.Series(outcome[id].unique()) is outcome[id]): - selected_stays = train_test_split(stays, stratify=labels, shuffle=shuffle, random_state=seed, - train_size=samples) + selected_stays = train_test_split( + stays, stratify=labels, shuffle=shuffle, random_state=seed, train_size=samples + ) else: selected_stays = train_test_split(stays, shuffle=shuffle, random_state=seed, train_size=samples) # Select only stays that are in the selected_stays diff --git a/icu_benchmarks/data/preprocessor.py b/icu_benchmarks/data/preprocessor.py index 69b90f5b..9f6103dd 100644 --- a/icu_benchmarks/data/preprocessor.py +++ b/icu_benchmarks/data/preprocessor.py @@ -6,12 +6,14 @@ import gin import pandas as pd +import polars as pl from recipys.recipe import Recipe from recipys.selector import all_numeric_predictors, all_outcomes, has_type, all_of from recipys.step import ( StepScale, StepImputeFastForwardFill, StepImputeFastZeroFill, + StepImputeFill, StepSklearn, StepHistorical, Accumulator, @@ -27,9 +29,9 @@ import abc -class Preprocessor: +class Preprocessor(abc.ABC): @abc.abstractmethod - def apply(self, data, vars, save_cache=False, load_cache=None): + def apply(self, data, vars, save_cache=False, load_cache=None, vars_to_exclude=None): return data @abc.abstractmethod @@ -43,7 +45,242 @@ def set_imputation_model(self, imputation_model): @gin.configurable("base_classification_preprocessor") -class DefaultClassificationPreprocessor(Preprocessor): +class PolarsClassificationPreprocessor(Preprocessor): + def __init__( + self, + generate_features: bool = False, + scaling: bool = True, + use_static_features: bool = True, + save_cache=None, + load_cache=None, + vars_to_exclude=None, + ): + """ + Args: + generate_features: Generate features for dynamic data. + scaling: Scaling of dynamic and static data. + use_static_features: Use static features. + save_cache: Save recipe cache from this path. + load_cache: Load recipe cache from this path. + vars_to_exclude: Variables to exclude from missing indicator/ feature generation. + Returns: + Preprocessed data. + """ + self.generate_features = generate_features + self.scaling = scaling + self.use_static_features = use_static_features + self.imputation_model = None + self.save_cache = save_cache + self.load_cache = load_cache + self.vars_to_exclude = vars_to_exclude + + def apply(self, data, vars) -> dict[dict[pl.DataFrame]]: + """ + Args: + data: Train, validation and test data dictionary. Further divided in static, dynamic, and outcome. + vars: Variables for static, dynamic, outcome. + Returns: + Preprocessed data. 
+ """ + # Check if dynamic features are present + if ( + self.use_static_features + and all(Segment.static in value for value in data.values()) + and len(vars[Segment.static]) > 0 + ): + logging.info("Preprocessing static features.") + data = self._process_static(data, vars) + else: + self.use_static_features = False + + if all(Segment.dynamic in value for value in data.values()): + logging.info("Preprocessing dynamic features.") + logging.info(data.keys()) + data = self._process_dynamic(data, vars) + if self.use_static_features: + # Join static and dynamic data. + data[Split.train][Segment.dynamic] = data[Split.train][Segment.dynamic].join( + data[Split.train][Segment.static], on=vars["GROUP"] + ) + data[Split.val][Segment.dynamic] = data[Split.val][Segment.dynamic].join( + data[Split.val][Segment.static], on=vars["GROUP"] + ) + data[Split.test][Segment.dynamic] = data[Split.test][Segment.dynamic].join( + data[Split.test][Segment.static], on=vars["GROUP"] + ) + + # Remove static features from splits + data[Split.train][Segment.features] = data[Split.train].pop(Segment.static) + data[Split.val][Segment.features] = data[Split.val].pop(Segment.static) + data[Split.test][Segment.features] = data[Split.test].pop(Segment.static) + + # Create feature splits + data[Split.train][Segment.features] = data[Split.train].pop(Segment.dynamic) + data[Split.val][Segment.features] = data[Split.val].pop(Segment.dynamic) + data[Split.test][Segment.features] = data[Split.test].pop(Segment.dynamic) + elif self.use_static_features: + data[Split.train][Segment.features] = data[Split.train].pop(Segment.static) + data[Split.val][Segment.features] = data[Split.val].pop(Segment.static) + data[Split.test][Segment.features] = data[Split.test].pop(Segment.static) + else: + raise Exception(f"No recognized data segments data to preprocess. 
Available: {data.keys()}") + logging.debug("Data head") + logging.debug(data[Split.train][Segment.features].head()) + logging.debug(data[Split.train][Segment.outcome]) + for split in [Split.train, Split.val, Split.test]: + if vars["SEQUENCE"] in data[split][Segment.outcome] and len(data[split][Segment.features]) != len( + data[split][Segment.outcome] + ): + raise Exception( + f"Data and outcome length mismatch in {split} split: " + f"features: {len(data[split][Segment.features])}, outcome: {len(data[split][Segment.outcome])}" + ) + data[Split.train][Segment.features] = data[Split.train][Segment.features].unique() + data[Split.val][Segment.features] = data[Split.val][Segment.features].unique() + data[Split.test][Segment.features] = data[Split.test][Segment.features].unique() + + logging.info(f"Generate features: {self.generate_features}") + return data + + def _process_static(self, data, vars): + sta_rec = Recipe(data[Split.train][Segment.static], [], vars[Segment.static]) + sta_rec.add_step(StepSklearn(MissingIndicator(features="all"), sel=all_of(vars[Segment.static]), in_place=False)) + if self.scaling: + sta_rec.add_step(StepScale()) + sta_rec.add_step(StepImputeFill(sel=all_numeric_predictors(), strategy="zero")) + # sta_rec.add_step(StepImputeFastZeroFill(sel=all_numeric_predictors())) + # if len(data[Split.train][Segment.static].select_dtypes(include=["object"]).columns) > 0: + types = ["String", "Object", "Categorical"] + sel = has_type(types) + if len(sel(sta_rec.data)) > 0: + # if len(data[Split.train][Segment.static].select(cs.by_dtype(types)).columns) > 0: + sta_rec.add_step(StepSklearn(SimpleImputer(missing_values=None, strategy="most_frequent"), sel=has_type(types))) + sta_rec.add_step(StepSklearn(LabelEncoder(), sel=has_type(types), columnwise=True)) + + data = apply_recipe_to_splits(sta_rec, data, Segment.static, self.save_cache, self.load_cache) + + return data + + def _model_impute(self, data, group=None): + dataset = ImputationPredictionDataset(data, group, self.imputation_model.trained_columns) + input_data = torch.cat([data_point.unsqueeze(0) for data_point in dataset], dim=0) + self.imputation_model.eval() + with torch.no_grad(): + logging.info(f"Imputing with {self.imputation_model.__class__.__name__}.") + imputation = self.imputation_model.predict(input_data) + logging.info("Imputation done.") + assert imputation.isnan().sum() == 0 + data = data.copy() + data.loc[:, self.imputation_model.trained_columns] = imputation.flatten(end_dim=1).to("cpu") + if group is not None: + data.drop(columns=group, inplace=True) + return data + + def _process_dynamic(self, data, vars): + dyn_rec = Recipe(data[Split.train][Segment.dynamic], [], vars[Segment.dynamic], vars["GROUP"], vars["SEQUENCE"]) + if self.scaling: + dyn_rec.add_step(StepScale()) + if self.imputation_model is not None: + dyn_rec.add_step(StepImputeModel(model=self.model_impute, sel=all_of(vars[Segment.dynamic]))) + if self.vars_to_exclude is not None: + # Exclude vars_to_exclude from missing indicator/ feature generation + vars_to_apply = list(set(vars[Segment.dynamic]) - set(self.vars_to_exclude)) + else: + vars_to_apply = vars[Segment.dynamic] + dyn_rec.add_step(StepSklearn(MissingIndicator(features="all"), sel=all_of(vars_to_apply), in_place=False)) + # dyn_rec.add_step(StepImputeFastForwardFill()) + dyn_rec.add_step(StepImputeFill(strategy="forward")) + # dyn_rec.add_step(StepImputeFastZeroFill()) + dyn_rec.add_step(StepImputeFill(strategy="zero")) + if self.generate_features: + dyn_rec = 
self._dynamic_feature_generation(dyn_rec, all_of(vars_to_apply)) + data = apply_recipe_to_splits(dyn_rec, data, Segment.dynamic, self.save_cache, self.load_cache) + return data + + def _dynamic_feature_generation(self, data, dynamic_vars): + logging.debug("Adding dynamic feature generation.") + data.add_step(StepHistorical(sel=dynamic_vars, fun=Accumulator.MIN, suffix="min_hist")) + data.add_step(StepHistorical(sel=dynamic_vars, fun=Accumulator.MAX, suffix="max_hist")) + data.add_step(StepHistorical(sel=dynamic_vars, fun=Accumulator.COUNT, suffix="count_hist")) + data.add_step(StepHistorical(sel=dynamic_vars, fun=Accumulator.MEAN, suffix="mean_hist")) + return data + + def to_cache_string(self): + return ( + super().to_cache_string() + + f"_classification_{self.generate_features}_{self.scaling}_{self.imputation_model.__class__.__name__}" + ) + + +@gin.configurable("base_regression_preprocessor") +class PolarsRegressionPreprocessor(PolarsClassificationPreprocessor): + # Override base classification preprocessor + def __init__( + self, + generate_features: bool = False, + scaling: bool = True, + use_static_features: bool = True, + outcome_max=None, + outcome_min=None, + save_cache=None, + load_cache=None, + ): + """ + Args: + generate_features: Generate features for dynamic data. + scaling: Scaling of dynamic and static data. + use_static_features: Use static features. + max_range: Maximum value in outcome. + min_range: Minimum value in outcome. + save_cache: Save recipe cache. + load_cache: Load recipe cache. + Returns: + Preprocessed data. + """ + super().__init__(generate_features, scaling, use_static_features, save_cache, load_cache) + self.outcome_max = outcome_max + self.outcome_min = outcome_min + + def apply(self, data, vars): + """ + Args: + data: Train, validation and test data dictionary. Further divided in static, dynamic, and outcome. + vars: Variables for static, dynamic, outcome. + Returns: + Preprocessed data. + """ + for split in [Split.train, Split.val, Split.test]: + data = self._process_outcome(data, vars, split) + + data = super().apply(data, vars) + return data + + def _process_outcome(self, data, vars, split): + logging.debug(f"Processing {split} outcome values.") + outcome_rec = Recipe(data[split][Segment.outcome], vars["LABEL"], [], vars["GROUP"]) + # If the range is predefined, use predefined transformation function + if self.outcome_max is not None and self.outcome_min is not None: + if self.outcome_max == self.outcome_min: + logging.warning("outcome_max equals outcome_min. 
Skipping outcome scaling.") + else: + outcome_rec.add_step( + StepSklearn( + sklearn_transformer=FunctionTransformer( + func=lambda x: ((x - self.outcome_min) / (self.outcome_max - self.outcome_min)) + ), + sel=all_outcomes(), + ) + ) + else: + # If the range is not predefined, use MinMaxScaler + outcome_rec.add_step(StepSklearn(MinMaxScaler(), sel=all_outcomes())) + outcome_rec.prep() + data[split][Segment.outcome] = outcome_rec.bake() + return data + + +@gin.configurable("pandas_classification_preprocessor") +class PandasClassificationPreprocessor(Preprocessor): def __init__( self, generate_features: bool = True, @@ -109,6 +346,11 @@ def apply(self, data, vars) -> dict[dict[pd.DataFrame]]: data[Split.train][Segment.features] = data[Split.train].pop(Segment.dynamic) data[Split.val][Segment.features] = data[Split.val].pop(Segment.dynamic) data[Split.test][Segment.features] = data[Split.test].pop(Segment.dynamic) + + logging.debug("Data head") + logging.debug(data[Split.train][Segment.features].head()) + logging.debug(data[Split.train][Segment.outcome].head()) + logging.info(f"Generate features: {self.generate_features}") return data def _process_static(self, data, vars): @@ -117,8 +359,9 @@ def _process_static(self, data, vars): sta_rec.add_step(StepScale()) sta_rec.add_step(StepImputeFastZeroFill(sel=all_numeric_predictors())) - sta_rec.add_step(StepSklearn(SimpleImputer(missing_values=None, strategy="most_frequent"), sel=has_type("object"))) - sta_rec.add_step(StepSklearn(LabelEncoder(), sel=has_type("object"), columnwise=True)) + if len(data[Split.train][Segment.static].select_dtypes(include=["object"]).columns) > 0: + sta_rec.add_step(StepSklearn(SimpleImputer(missing_values=None, strategy="most_frequent"), sel=has_type("object"))) + sta_rec.add_step(StepSklearn(LabelEncoder(), sel=has_type("object"), columnwise=True)) data = apply_recipe_to_splits(sta_rec, data, Segment.static, self.save_cache, self.load_cache) @@ -168,8 +411,8 @@ def to_cache_string(self): ) -@gin.configurable("base_regression_preprocessor") -class DefaultRegressionPreprocessor(DefaultClassificationPreprocessor): +@gin.configurable("pandas_regression_preprocessor") +class PandasRegressionPreprocessor(PandasClassificationPreprocessor): # Override base classification preprocessor def __init__( self, @@ -233,7 +476,7 @@ def _process_outcome(self, data, vars, split): @gin.configurable("base_imputation_preprocessor") -class DefaultImputationPreprocessor(Preprocessor): +class PandasImputationPreprocessor(Preprocessor): def __init__( self, scaling: bool = True, @@ -294,6 +537,7 @@ def apply_recipe_to_splits( recipe: Recipe, data: dict[dict[pd.DataFrame]], type: str, save_cache=None, load_cache=None ) -> dict[dict[pd.DataFrame]]: """Fits and transforms the training features, then transforms the validation and test features with the recipe. + Works with both Polars and Pandas versions of recipys. Args: load_cache: Load recipe from cache, for e.g. transfer learning. 
diff --git a/icu_benchmarks/data/split_process_data.py b/icu_benchmarks/data/split_process_data.py index ee14d90f..a47c9729 100644 --- a/icu_benchmarks/data/split_process_data.py +++ b/icu_benchmarks/data/split_process_data.py @@ -1,17 +1,19 @@ import copy import logging +import os + import gin import json import hashlib import pandas as pd -import pyarrow.parquet as pq +import polars as pl from pathlib import Path import pickle - +from timeit import default_timer as timer from sklearn.model_selection import StratifiedKFold, KFold, StratifiedShuffleSplit, ShuffleSplit - -from icu_benchmarks.data.preprocessor import Preprocessor, DefaultClassificationPreprocessor -from icu_benchmarks.contants import RunMode +from icu_benchmarks.data.preprocessor import Preprocessor, PandasClassificationPreprocessor, PolarsClassificationPreprocessor +from icu_benchmarks.constants import RunMode +from icu_benchmarks.run_utils import check_required_keys from .constants import DataSplit as Split, DataSegment as Segment, VarType as Var @@ -19,9 +21,11 @@ def preprocess_data( data_dir: Path, file_names: dict[str] = gin.REQUIRED, - preprocessor: Preprocessor = DefaultClassificationPreprocessor, + preprocessor: Preprocessor = PolarsClassificationPreprocessor, use_static: bool = True, vars: dict[str] = gin.REQUIRED, + modality_mapping: dict[str] = {}, + selected_modalities: list[str] = "all", seed: int = 42, debug: bool = False, cv_repetitions: int = 5, @@ -34,7 +38,10 @@ def preprocess_data( pretrained_imputation_model: str = None, complete_train: bool = False, runmode: RunMode = RunMode.classification, -) -> dict[dict[pd.DataFrame]]: + label: str = None, + required_var_types=["GROUP", "SEQUENCE", "LABEL"], + required_segments=[Segment.static, Segment.dynamic, Segment.outcome], +) -> dict[dict[pl.DataFrame]] or dict[dict[pd.DataFrame]]: """Perform loading, splitting, imputing and normalising of task data. Args: @@ -62,19 +69,41 @@ def preprocess_data( """ cache_dir = data_dir / "cache" - + check_required_keys(vars, required_var_types) + check_required_keys(file_names, required_segments) if not use_static: file_names.pop(Segment.static) vars.pop(Segment.static) - + if isinstance(vars[Var.label], list) and len(vars[Var.label]) > 1: + if label is not None: + vars[Var.label] = [label] + else: + logging.debug(f"Multiple labels found and no value provided. 
Using first label: {vars[Var.label]}") + vars[Var.label] = vars[Var.label][0] + logging.info(f"Using label: {vars[Var.label]}") + if not vars[Var.label]: + raise ValueError("No label selected after filtering.") dumped_file_names = json.dumps(file_names, sort_keys=True) dumped_vars = json.dumps(vars, sort_keys=True) cache_filename = f"s_{seed}_r_{repetition_index}_f_{fold_index}_t_{train_size}_d_{debug}" logging.log(logging.INFO, f"Using preprocessor: {preprocessor.__name__}") - preprocessor = preprocessor(use_static_features=use_static, save_cache=data_dir / "preproc" / (cache_filename + "_recipe")) - if isinstance(preprocessor, DefaultClassificationPreprocessor): + vars_to_exclude = ( + modality_mapping.get("cat_clinical_notes") + modality_mapping.get("cat_med_embeddings_map") + if ( + modality_mapping.get("cat_clinical_notes") is not None + and modality_mapping.get("cat_med_embeddings_map") is not None + ) + else None + ) + + preprocessor = preprocessor( + use_static_features=use_static, + save_cache=data_dir / "preproc" / (cache_filename + "_recipe"), + vars_to_exclude=vars_to_exclude, + ) + if isinstance(preprocessor, PandasClassificationPreprocessor): preprocessor.set_imputation_model(pretrained_imputation_model) hash_config = hashlib.md5(f"{preprocessor.to_cache_string()}{dumped_file_names}{dumped_vars}".encode("utf-8")) @@ -91,9 +120,26 @@ def preprocess_data( # Read parquet files into pandas dataframes and remove the parquet file from memory logging.info(f"Loading data from directory {data_dir.absolute()}") - data = {f: pq.read_table(data_dir / file_names[f]).to_pandas(self_destruct=True) for f in file_names.keys()} + data = { + f: pl.read_parquet(data_dir / file_names[f]) for f in file_names.keys() if os.path.exists(data_dir / file_names[f]) + } + logging.info(f"Loaded data: {list(data.keys())}") + data = check_sanitize_data(data, vars) + + if not (Segment.dynamic in data.keys()): + logging.warning("No dynamic data found, using only static data.") + + logging.debug(f"Modality mapping: {modality_mapping}") + if len(modality_mapping) > 0: + # Optional modality selection + if selected_modalities not in [None, "all", ["all"]]: + data, vars = modality_selection(data, modality_mapping, selected_modalities, vars) + else: + logging.info("Selecting all modalities.") + # Generate the splits logging.info("Generating splits.") + # complete_train = True if not complete_train: data = make_single_split( data, @@ -109,10 +155,31 @@ def preprocess_data( ) else: # If full train is set, we use all data for training/validation - data = make_train_val(data, vars, train_size=0.8, seed=seed, debug=debug, runmode=runmode) + data = make_train_val(data, vars, train_size=None, seed=seed, debug=debug, runmode=runmode) # Apply preprocessing + + start = timer() data = preprocessor.apply(data, vars) + end = timer() + logging.info(f"Preprocessing took {end - start:.2f} seconds.") + logging.info(f"Checking for NaNs and nulls in {data.keys()}.") + for dict in data.values(): + for key, val in dict.items(): + logging.debug(f"Data type: {key}") + logging.debug("Is NaN:") + sel = dict[key].select(pl.selectors.numeric().is_nan().max()) + logging.debug(sel.select(col.name for col in sel if col.item(0))) + # logging.info(dict[key].select(pl.all().has_nulls()).sum_horizontal()) + logging.debug("Has nulls:") + sel = dict[key].select(pl.all().has_nulls()) + logging.debug(sel.select(col.name for col in sel if col.item(0))) + # dict[key] = val[:, [not (s.null_count() > 0) for s in val]] + dict[key] = 
val.fill_null(strategy="zero") + dict[key] = val.fill_nan(0) + logging.debug("Dropping columns with nulls") + sel = dict[key].select(pl.all().has_nulls()) + logging.debug(sel.select(col.name for col in sel if col.item(0))) # Generate cache if generate_cache: @@ -125,6 +192,59 @@ def preprocess_data( return data +def check_sanitize_data(data, vars): + """Check for duplicates in the loaded data and remove them.""" + group = vars[Var.group] if Var.group in vars.keys() else None + sequence = vars[Var.sequence] if Var.sequence in vars.keys() else None + keep = "last" + if Segment.static in data.keys(): + old_len = len(data[Segment.static]) + data[Segment.static] = data[Segment.static].unique(subset=group, keep=keep, maintain_order=True) + logging.warning(f"Removed {old_len - len(data[Segment.static])} duplicates from static data.") + if Segment.dynamic in data.keys(): + old_len = len(data[Segment.dynamic]) + data[Segment.dynamic] = data[Segment.dynamic].unique(subset=[group, sequence], keep=keep, maintain_order=True) + logging.warning(f"Removed {old_len - len(data[Segment.dynamic])} duplicates from dynamic data.") + if Segment.outcome in data.keys(): + old_len = len(data[Segment.outcome]) + if sequence in data[Segment.outcome].columns: + # We have a dynamic outcome with group and sequence + data[Segment.outcome] = data[Segment.outcome].unique(subset=[group, sequence], keep=keep, maintain_order=True) + else: + data[Segment.outcome] = data[Segment.outcome].unique(subset=[group], keep=keep, maintain_order=True) + logging.warning(f"Removed {old_len - len(data[Segment.outcome])} duplicates from outcome data.") + return data + + +def modality_selection( + data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars +) -> dict[pl.DataFrame]: + logging.info(f"Selected modalities: {selected_modalities}") + selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping.keys()] + if not any(col in modality_mapping.keys() for col in selected_modalities): + raise ValueError("None of the selected modalities found in modality mapping.") + if selected_columns == []: + logging.info("No columns selected. Using all columns.") + return data, vars + selected_columns = sum(selected_columns, []) + selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]]) + old_columns = [] + # Update vars dict + for key, value in vars.items(): + if key not in [Var.group, Var.label, Var.sequence]: + old_columns.extend(value) + vars[key] = [col for col in value if col in selected_columns] + # -3 because of standard columns + logging.info(f"Selected columns: {len(selected_columns) - 3}, old columns: {len(old_columns)}") + logging.debug(f"Difference: {set(old_columns) - set(selected_columns)}") + # Update data dict + for key in data.keys(): + sel_col = [col for col in data[key].columns if col in selected_columns] + data[key] = data[key].select(sel_col) + logging.debug(f"Selected columns in {key}: {len(data[key].columns)}") + return data, vars + + def make_train_val( data: dict[pd.DataFrame], vars: dict[str], @@ -132,7 +252,8 @@ def make_train_val( seed: int = 42, debug: bool = False, runmode: RunMode = RunMode.classification, -) -> dict[dict[pd.DataFrame]]: + polars: bool = True, +) -> dict[dict[pl.DataFrame]]: """Randomly split the data into training and validation sets for fitting a full model. 
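# A minimal sketch of the polars deduplication applied by check_sanitize_data above.
# The toy frame and the "stay_id"/"age" column names are illustrative assumptions.
import polars as pl

static = pl.DataFrame({"stay_id": [1, 1, 2], "age": [64, 65, 71]})
# keep="last" retains the most recently loaded row per key; maintain_order keeps the
# remaining rows in their original order, mirroring the call in check_sanitize_data.
deduped = static.unique(subset="stay_id", keep="last", maintain_order=True)
assert len(deduped) == 2  # the duplicated stay_id 1 collapses to its last row (age 65)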
Args: @@ -147,40 +268,78 @@ def make_train_val( # ID variable id = vars[Var.group] - # Get stay IDs from outcome segment - stays = pd.Series(data[Segment.outcome][id].unique(), name=id) - if debug: # Only use 1% of the data - stays = stays.sample(frac=0.01, random_state=seed) + logging.info("Using only 1% of the data for debugging. Note that this might lead to errors for small datasets.") + if polars: + data[Segment.outcome] = data[Segment.outcome].sample(fraction=0.01, seed=seed) + else: + data[Segment.outcome] = data[Segment.outcome].sample(frac=0.01, random_state=seed) + + # Get stay IDs from outcome segment + stays = _get_stays(data, id, polars) # If there are labels, and the task is classification, use stratified k-fold if Var.label in vars and runmode is RunMode.classification: # Get labels from outcome data (takes the highest value (or True) in case seq2seq classification) - labels = data[Segment.outcome].groupby(id).max()[vars[Var.label]].reset_index(drop=True) - if train_size: - train_val = StratifiedShuffleSplit(train_size=train_size, random_state=seed, n_splits=1) + labels = _get_labels(data, id, vars, polars) + train_val = StratifiedShuffleSplit(train_size=train_size, random_state=seed, n_splits=1) train, val = list(train_val.split(stays, labels))[0] + else: # If there are no labels, use random split train_val = ShuffleSplit(train_size=train_size, random_state=seed) train, val = list(train_val.split(stays))[0] - split = {Split.train: stays.iloc[train], Split.val: stays.iloc[val]} + if polars: + split = { + Split.train: stays[train].cast(pl.datatypes.Int64).to_frame(), + Split.val: stays[val].cast(pl.datatypes.Int64).to_frame(), + } + else: + split = {Split.train: stays.iloc[train], Split.val: stays.iloc[val]} data_split = {} for fold in split.keys(): # Loop through splits (train / val / test) # Loop through segments (DYNAMIC / STATIC / OUTCOME) # set sort to true to make sure that IDs are reordered after scrambling earlier - data_split[fold] = { - data_type: data[data_type].merge(split[fold], on=id, how="right", sort=True) for data_type in data.keys() - } + if polars: + data_split[fold] = { + data_type: split[fold] + .join(data[data_type].with_columns(pl.col(id).cast(pl.datatypes.Int64)), on=id, how="left") + .sort(by=id) + for data_type in data.keys() + } + else: + data_split[fold] = { + data_type: data[data_type].merge(split[fold], on=id, how="right", sort=True) for data_type in data.keys() + } + # Maintain compatibility with test split data_split[Split.test] = copy.deepcopy(data_split[Split.val]) return data_split +def _get_stays(data, id, polars): + return ( + pl.Series(name=id, values=data[Segment.outcome][id].unique()) + if polars + else pd.Series(data[Segment.outcome][id].unique(), name=id) + ) + + +def _get_labels(data, id, vars, polars): + # Get labels from outcome data (takes the highest value (or True) in case seq2seq classification) + if polars: + return data[Segment.outcome].group_by(id).max()[vars[Var.label]] + else: + return data[Segment.outcome].groupby(id).max()[vars[Var.label]].reset_index(drop=True) + + +# Use these helper functions in both make_train_val and make_single_split + + def make_single_split( data: dict[pd.DataFrame], vars: dict[str], @@ -192,7 +351,8 @@ def make_single_split( seed: int = 42, debug: bool = False, runmode: RunMode = RunMode.classification, -) -> dict[dict[pd.DataFrame]]: + polars: bool = True, +) -> dict[dict[pl.DataFrame]]: """Randomly split the data into training, validation, and test set. 
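# Sketch of the polars branch used above to materialise a fold: integer indices from a
# sklearn splitter select stay IDs, which are cast to Int64 and left-joined back onto
# each data segment. Toy values and the "stay_id" column name are assumptions.
import numpy as np
import polars as pl

stays = pl.Series("stay_id", [7, 3, 5])
outcome = pl.DataFrame({"stay_id": [3, 5, 7], "label": [0, 1, 0]})

train_idx = np.array([0, 2])  # e.g. produced by StratifiedShuffleSplit
train_ids = stays[train_idx].cast(pl.Int64).to_frame()
train_fold = (
    train_ids.join(outcome.with_columns(pl.col("stay_id").cast(pl.Int64)), on="stay_id", how="left")
    .sort(by="stay_id")
)
# train_fold now holds only the rows belonging to the selected stays, reordered by ID.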
Args: @@ -216,19 +376,34 @@ def make_single_split( if debug: # Only use 1% of the data logging.info("Using only 1% of the data for debugging. Note that this might lead to errors for small datasets.") - data[Segment.outcome] = data[Segment.outcome].sample(frac=0.01, random_state=seed) + if polars: + data[Segment.outcome] = data[Segment.outcome].sample(fraction=0.01, seed=seed) + else: + data[Segment.outcome] = data[Segment.outcome].sample(frac=0.01, random_state=seed) # Get stay IDs from outcome segment - stays = pd.Series(data[Segment.outcome][id].unique(), name=id) + if polars: + stays = pl.Series(name=id, values=data[Segment.outcome][id].unique()) + else: + stays = pd.Series(data[Segment.outcome][id].unique(), name=id) # If there are labels, and the task is classification, use stratified k-fold if Var.label in vars and runmode is RunMode.classification: # Get labels from outcome data (takes the highest value (or True) in case seq2seq classification) - labels = data[Segment.outcome].groupby(id).max()[vars[Var.label]].reset_index(drop=True) - if labels.value_counts().min() < cv_folds: - raise Exception( - f"The smallest amount of samples in a class is: {labels.value_counts().min()}, " - f"but {cv_folds} folds are requested. Reduce the number of folds or use more data." - ) + if polars: + labels = data[Segment.outcome].group_by(id).max()[vars[Var.label]] + if labels.value_counts().min().item(0, 1) < cv_folds: + raise Exception( + f"The smallest amount of samples in a class is: {labels.value_counts().min()}, " + f"but {cv_folds} folds are requested. Reduce the number of folds or use more data." + ) + else: + labels = data[Segment.outcome].groupby(id).max()[vars[Var.label]].reset_index(drop=True) + if labels.value_counts().min() < cv_folds: + raise Exception( + f"The smallest amount of samples in a class is: {labels.value_counts().min()}, " + f"but {cv_folds} folds are requested. Reduce the number of folds or use more data." + ) + if train_size: outer_cv = StratifiedShuffleSplit(cv_repetitions, train_size=train_size) else: @@ -236,8 +411,12 @@ def make_single_split( inner_cv = StratifiedKFold(cv_folds, shuffle=True, random_state=seed) dev, test = list(outer_cv.split(stays, labels))[repetition_index] - dev_stays = stays.iloc[dev] - train, val = list(inner_cv.split(dev_stays, labels.iloc[dev]))[fold_index] + if polars: + dev_stays = stays[dev] + train, val = list(inner_cv.split(dev_stays, labels[dev]))[fold_index] + else: + dev_stays = stays.iloc[dev] + train, val = list(inner_cv.split(dev_stays, labels.iloc[dev]))[fold_index] else: # If there are no labels, or the task is regression, use regular k-fold. 
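# Sketch of the class-count guard used in the polars branch above: value_counts() yields
# a two-column frame (value, count), and .min().item(0, 1) reads the smallest class count.
import polars as pl

labels = pl.Series("label", [0, 0, 0, 0, 1])
smallest_class = labels.value_counts().min().item(0, 1)
cv_folds = 5
if smallest_class < cv_folds:
    raise Exception(f"The smallest class has {smallest_class} samples, but {cv_folds} folds are requested.")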
if train_size: @@ -247,23 +426,40 @@ def make_single_split( inner_cv = KFold(cv_folds, shuffle=True, random_state=seed) dev, test = list(outer_cv.split(stays))[repetition_index] - dev_stays = stays.iloc[dev] + if polars: + dev_stays = stays[dev] + else: + dev_stays = stays.iloc[dev] train, val = list(inner_cv.split(dev_stays))[fold_index] - - split = { - Split.train: dev_stays.iloc[train], - Split.val: dev_stays.iloc[val], - Split.test: stays.iloc[test], - } + if polars: + split = { + Split.train: dev_stays[train].cast(pl.datatypes.Int64).to_frame(), + Split.val: dev_stays[val].cast(pl.datatypes.Int64).to_frame(), + Split.test: stays[test].cast(pl.datatypes.Int64).to_frame(), + } + else: + split = { + Split.train: dev_stays.iloc[train], + Split.val: dev_stays.iloc[val], + Split.test: stays.iloc[test], + } data_split = {} for fold in split.keys(): # Loop through splits (train / val / test) # Loop through segments (DYNAMIC / STATIC / OUTCOME) # set sort to true to make sure that IDs are reordered after scrambling earlier - data_split[fold] = { - data_type: data[data_type].merge(split[fold], on=id, how="right", sort=True) for data_type in data.keys() - } - + if polars: + data_split[fold] = { + data_type: split[fold] + .join(data[data_type].with_columns(pl.col(id).cast(pl.datatypes.Int64)), on=id, how="left") + .sort(by=id) + for data_type in data.keys() + } + else: + data_split[fold] = { + data_type: data[data_type].merge(split[fold], on=id, how="right", sort=True) for data_type in data.keys() + } + logging.info(f"Data split: {data_split}") return data_split diff --git a/icu_benchmarks/models/constants.py b/icu_benchmarks/models/constants.py index 45af8271..43843db8 100644 --- a/icu_benchmarks/models/constants.py +++ b/icu_benchmarks/models/constants.py @@ -26,6 +26,7 @@ MAE, JSD, BinaryFairnessWrapper, + confusion_matrix, ) @@ -36,13 +37,15 @@ class MLMetrics: "PR": average_precision_score, "PR_Curve": precision_recall_curve, "RO_Curve": roc_curve, + "Confusion_Matrix": confusion_matrix, } MULTICLASS_CLASSIFICATION = { "Accuracy": accuracy_score, "AUC": roc_auc_score, "Balanced_Accuracy": balanced_accuracy_score, - "PR": average_precision_score, + # "PR": average_precision_score, + "Confusion_Matrix": confusion_matrix, } REGRESSION = { diff --git a/icu_benchmarks/models/custom_metrics.py b/icu_benchmarks/models/custom_metrics.py index ddb5d37e..eb0a5d23 100644 --- a/icu_benchmarks/models/custom_metrics.py +++ b/icu_benchmarks/models/custom_metrics.py @@ -2,7 +2,8 @@ from typing import Callable import numpy as np from ignite.metrics import EpochMetric -from sklearn.metrics import balanced_accuracy_score, mean_absolute_error +from numpy import ndarray +from sklearn.metrics import balanced_accuracy_score, mean_absolute_error, confusion_matrix as sk_confusion_matrix from sklearn.calibration import calibration_curve from scipy.spatial.distance import jensenshannon from torchmetrics.classification import BinaryFairness @@ -130,3 +131,15 @@ def feature_helper(self, trainer, step_prefix): else: feature_names = trainer.test_dataloaders.dataset.features return feature_names + + +def confusion_matrix(y_true: ndarray, y_pred: ndarray, normalize=False) -> torch.tensor: + y_pred = np.rint(y_pred).astype(int) + confusion = sk_confusion_matrix(y_true, y_pred) + if normalize: + confusion = confusion / confusion.sum() + confusion_dict = {} + for i in range(confusion.shape[0]): + for j in range(confusion.shape[1]): + confusion_dict[f"class_{i}_pred_{j}"] = confusion[i][j] + return confusion_dict diff --git 
a/icu_benchmarks/models/dl_models.py b/icu_benchmarks/models/dl_models.py deleted file mode 100644 index 0fb1b0d2..00000000 --- a/icu_benchmarks/models/dl_models.py +++ /dev/null @@ -1,282 +0,0 @@ -import gin -from numbers import Integral -import numpy as np -import torch.nn as nn -from icu_benchmarks.contants import RunMode -from icu_benchmarks.models.layers import TransformerBlock, LocalBlock, TemporalBlock, PositionalEncoding -from icu_benchmarks.models.wrappers import DLPredictionWrapper - - -@gin.configurable -class RNNet(DLPredictionWrapper): - """Torch standard RNN model""" - - _supported_run_modes = [RunMode.classification, RunMode.regression] - - def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): - super().__init__( - input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs - ) - self.hidden_dim = hidden_dim - self.layer_dim = layer_dim - self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True) - self.logit = nn.Linear(hidden_dim, num_classes) - - def init_hidden(self, x): - h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) - return h0 - - def forward(self, x): - h0 = self.init_hidden(x) - out, hn = self.rnn(x, h0) - pred = self.logit(out) - return pred - - -@gin.configurable -class LSTMNet(DLPredictionWrapper): - """Torch standard LSTM model.""" - - _supported_run_modes = [RunMode.classification, RunMode.regression] - - def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): - super().__init__( - input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs - ) - self.hidden_dim = hidden_dim - self.layer_dim = layer_dim - self.rnn = nn.LSTM(input_size[2], hidden_dim, layer_dim, batch_first=True) - self.logit = nn.Linear(hidden_dim, num_classes) - - def init_hidden(self, x): - h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) - c0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) - return [t for t in (h0, c0)] - - def forward(self, x): - h0, c0 = self.init_hidden(x) - out, h = self.rnn(x, (h0, c0)) - pred = self.logit(out) - return pred - - -@gin.configurable -class GRUNet(DLPredictionWrapper): - """Torch standard GRU model.""" - - _supported_run_modes = [RunMode.classification, RunMode.regression] - - def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): - super().__init__( - input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs - ) - self.hidden_dim = hidden_dim - self.layer_dim = layer_dim - self.rnn = nn.GRU(input_size[2], hidden_dim, layer_dim, batch_first=True) - self.logit = nn.Linear(hidden_dim, num_classes) - - def init_hidden(self, x): - h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) - return h0 - - def forward(self, x): - h0 = self.init_hidden(x) - out, hn = self.rnn(x, h0) - pred = self.logit(out) - - return pred - - -@gin.configurable -class Transformer(DLPredictionWrapper): - """Transformer model as defined by the HiRID-Benchmark (https://github.com/ratschlab/HIRID-ICU-Benchmark).""" - - _supported_run_modes = [RunMode.classification, RunMode.regression] - - def __init__( - self, - input_size, - hidden, - heads, - ff_hidden_mult, - depth, - num_classes, - *args, - dropout=0.0, - l1_reg=0, - pos_encoding=True, - dropout_att=0.0, - **kwargs, - ): - super().__init__( - input_size=input_size, - hidden=hidden, - heads=heads, - ff_hidden_mult=ff_hidden_mult, - depth=depth, 
- num_classes=num_classes, - *args, - dropout=dropout, - l1_reg=l1_reg, - pos_encoding=pos_encoding, - dropout_att=dropout_att, - **kwargs, - ) - hidden = hidden if hidden % 2 == 0 else hidden + 1 # Make sure hidden is even - self.input_embedding = nn.Linear(input_size[2], hidden) # This acts as a time-distributed layer by defaults - if pos_encoding: - self.pos_encoder = PositionalEncoding(hidden) - else: - self.pos_encoder = None - - tblocks = [] - for i in range(depth): - tblocks.append( - TransformerBlock( - emb=hidden, - hidden=hidden, - heads=heads, - mask=True, - ff_hidden_mult=ff_hidden_mult, - dropout=dropout, - dropout_att=dropout_att, - ) - ) - - self.tblocks = nn.Sequential(*tblocks) - self.logit = nn.Linear(hidden, num_classes) - self.l1_reg = l1_reg - - def forward(self, x): - x = self.input_embedding(x) - if self.pos_encoder is not None: - x = self.pos_encoder(x) - x = self.tblocks(x) - pred = self.logit(x) - - return pred - - -@gin.configurable -class LocalTransformer(DLPredictionWrapper): - _supported_run_modes = [RunMode.classification, RunMode.regression] - - def __init__( - self, - input_size, - hidden, - heads, - ff_hidden_mult, - depth, - num_classes, - *args, - dropout=0.0, - l1_reg=0, - pos_encoding=True, - local_context=1, - dropout_att=0.0, - **kwargs, - ): - super().__init__( - input_size=input_size, - hidden=hidden, - heads=heads, - ff_hidden_mult=ff_hidden_mult, - depth=depth, - num_classes=num_classes, - *args, - dropout=dropout, - l1_reg=l1_reg, - pos_encoding=pos_encoding, - local_context=local_context, - dropout_att=dropout_att, - **kwargs, - ) - - hidden = hidden if hidden % 2 == 0 else hidden + 1 # Make sure hidden is even - self.input_embedding = nn.Linear(input_size[2], hidden) # This acts as a time-distributed layer by defaults - if pos_encoding: - self.pos_encoder = PositionalEncoding(hidden) - else: - self.pos_encoder = None - - tblocks = [] - for i in range(depth): - tblocks.append( - LocalBlock( - emb=hidden, - hidden=hidden, - heads=heads, - mask=True, - ff_hidden_mult=ff_hidden_mult, - local_context=local_context, - dropout=dropout, - dropout_att=dropout_att, - ) - ) - - self.tblocks = nn.Sequential(*tblocks) - self.logit = nn.Linear(hidden, num_classes) - self.l1_reg = l1_reg - - def forward(self, x): - x = self.input_embedding(x) - if self.pos_encoder is not None: - x = self.pos_encoder(x) - x = self.tblocks(x) - pred = self.logit(x) - - return pred - - -@gin.configurable -class TemporalConvNet(DLPredictionWrapper): - """Temporal Convolutional Network. Adapted from TCN original paper https://github.com/locuslab/TCN""" - - _supported_run_modes = [RunMode.classification, RunMode.regression] - - def __init__(self, input_size, num_channels, num_classes, *args, max_seq_length=0, kernel_size=2, dropout=0.0, **kwargs): - super().__init__( - input_size=input_size, - num_channels=num_channels, - num_classes=num_classes, - *args, - max_seq_length=max_seq_length, - kernel_size=kernel_size, - dropout=dropout, - **kwargs, - ) - layers = [] - - # We compute automatically the depth based on the desired seq_length. 
- if isinstance(num_channels, Integral) and max_seq_length: - num_channels = [num_channels] * int(np.ceil(np.log(max_seq_length / 2) / np.log(kernel_size))) - elif isinstance(num_channels, Integral) and not max_seq_length: - raise Exception("a maximum sequence length needs to be provided if num_channels is int") - - num_levels = len(num_channels) - for i in range(num_levels): - dilation_size = 2**i - in_channels = input_size[2] if i == 0 else num_channels[i - 1] - out_channels = num_channels[i] - layers += [ - TemporalBlock( - in_channels, - out_channels, - kernel_size, - stride=1, - dilation=dilation_size, - padding=(kernel_size - 1) * dilation_size, - dropout=dropout, - ) - ] - - self.network = nn.Sequential(*layers) - self.logit = nn.Linear(num_channels[-1], num_classes) - - def forward(self, x): - x = x.permute(0, 2, 1) # Permute to channel first - o = self.network(x) - o = o.permute(0, 2, 1) # Permute to channel last - pred = self.logit(o) - return pred diff --git a/icu_benchmarks/models/dl_models/__init__.py b/icu_benchmarks/models/dl_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/icu_benchmarks/models/layers.py b/icu_benchmarks/models/dl_models/layers.py similarity index 100% rename from icu_benchmarks/models/layers.py rename to icu_benchmarks/models/dl_models/layers.py diff --git a/icu_benchmarks/models/dl_models/rnn.py b/icu_benchmarks/models/dl_models/rnn.py new file mode 100644 index 00000000..4f0c65bc --- /dev/null +++ b/icu_benchmarks/models/dl_models/rnn.py @@ -0,0 +1,84 @@ +import gin +from torch import nn as nn + +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.wrappers import DLPredictionWrapper + + +@gin.configurable +class RNNet(DLPredictionWrapper): + """Torch standard RNN model""" + + _supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + *args, input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, **kwargs + ) + self.hidden_dim = hidden_dim + self.layer_dim = layer_dim + self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True) + self.logit = nn.Linear(hidden_dim, num_classes) + + def init_hidden(self, x): + h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + return h0 + + def forward(self, x): + h0 = self.init_hidden(x) + out, hn = self.rnn(x, h0) + pred = self.logit(out) + return pred + + +@gin.configurable +class LSTMNet(DLPredictionWrapper): + """Torch standard LSTM model.""" + + _supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + *args, input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, **kwargs + ) + self.hidden_dim = hidden_dim + self.layer_dim = layer_dim + self.rnn = nn.LSTM(input_size[2], hidden_dim, layer_dim, batch_first=True) + self.logit = nn.Linear(hidden_dim, num_classes) + + def init_hidden(self, x): + h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + c0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + return [t for t in (h0, c0)] + + def forward(self, x): + h0, c0 = self.init_hidden(x) + out, h = self.rnn(x, (h0, c0)) + pred = self.logit(out) + return pred + + +@gin.configurable +class GRUNet(DLPredictionWrapper): + """Torch standard GRU model.""" + + _supported_run_modes = [RunMode.classification, RunMode.regression] + + 
def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs): + super().__init__( + *args, input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, **kwargs + ) + self.hidden_dim = hidden_dim + self.layer_dim = layer_dim + self.rnn = nn.GRU(input_size[2], hidden_dim, layer_dim, batch_first=True) + self.logit = nn.Linear(hidden_dim, num_classes) + + def init_hidden(self, x): + h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim) + return h0 + + def forward(self, x): + h0 = self.init_hidden(x) + out, hn = self.rnn(x, h0) + pred = self.logit(out) + return pred diff --git a/icu_benchmarks/models/dl_models/tcn.py b/icu_benchmarks/models/dl_models/tcn.py new file mode 100644 index 00000000..8be71fea --- /dev/null +++ b/icu_benchmarks/models/dl_models/tcn.py @@ -0,0 +1,62 @@ +from numbers import Integral + +import gin +import numpy as np +from torch import nn as nn + +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.dl_models.layers import TemporalBlock +from icu_benchmarks.models.wrappers import DLPredictionWrapper + + +@gin.configurable +class TemporalConvNet(DLPredictionWrapper): + """Temporal Convolutional Network. Adapted from TCN original paper https://github.com/locuslab/TCN""" + + _supported_run_modes = [RunMode.classification, RunMode.regression] + + def __init__(self, input_size, num_channels, num_classes, *args, max_seq_length=0, kernel_size=2, dropout=0.0, **kwargs): + super().__init__( + *args, + input_size=input_size, + num_channels=num_channels, + num_classes=num_classes, + max_seq_length=max_seq_length, + kernel_size=kernel_size, + dropout=dropout, + **kwargs, + ) + layers = [] + + # We compute automatically the depth based on the desired seq_length. + if isinstance(num_channels, Integral) and max_seq_length: + num_channels = [num_channels] * int(np.ceil(np.log(max_seq_length / 2) / np.log(kernel_size))) + elif isinstance(num_channels, Integral) and not max_seq_length: + raise Exception("a maximum sequence length needs to be provided if num_channels is int") + + num_levels = len(num_channels) + for i in range(num_levels): + dilation_size = 2**i + in_channels = input_size[2] if i == 0 else num_channels[i - 1] + out_channels = num_channels[i] + layers += [ + TemporalBlock( + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=dilation_size, + padding=(kernel_size - 1) * dilation_size, + dropout=dropout, + ) + ] + + self.network = nn.Sequential(*layers) + self.logit = nn.Linear(num_channels[-1], num_classes) + + def forward(self, x): + x = x.permute(0, 2, 1) # Permute to channel first + o = self.network(x) + o = o.permute(0, 2, 1) # Permute to channel last + pred = self.logit(o) + return pred diff --git a/icu_benchmarks/models/dl_models/transformer.py b/icu_benchmarks/models/dl_models/transformer.py new file mode 100644 index 00000000..ed7d4a2c --- /dev/null +++ b/icu_benchmarks/models/dl_models/transformer.py @@ -0,0 +1,81 @@ +import gin +from torch import nn as nn + +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.dl_models.layers import PositionalEncoding, TransformerBlock, LocalBlock +from icu_benchmarks.models.wrappers import DLPredictionWrapper + + +class BaseTransformer(DLPredictionWrapper): + _supported_run_modes = [RunMode.classification, RunMode.regression] + """Refactored Transformer model as defined by the HiRID-Benchmark (https://github.com/ratschlab/HIRID-ICU-Benchmark).""" + + def __init__( + self, + block_class, + input_size, + hidden, + 
heads, + ff_hidden_mult, + depth, + num_classes, + dropout=0.0, + l1_reg=0, + pos_encoding=True, + dropout_att=0.0, + local_context=None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + if local_context is not None and self._get_name() == "Transformer": + raise ValueError("Local context is only supported for LocalTransformer") + hidden = hidden if hidden % 2 == 0 else hidden + 1 # Make sure hidden is even + self.input_embedding = nn.Linear(input_size[2], hidden) # This acts as a time-distributed layer by defaults + if pos_encoding: + self.pos_encoder = PositionalEncoding(hidden) + else: + self.pos_encoder = None + + t_blocks = [] + for _ in range(depth): + t_blocks.append( + block_class( + emb=hidden, + hidden=hidden, + heads=heads, + mask=True, + ff_hidden_mult=ff_hidden_mult, + dropout=dropout, + dropout_att=dropout_att, + **({"local_context": local_context} if local_context is not None else {}), + ) + ) + + self.t_blocks = nn.Sequential(*t_blocks) + self.logit = nn.Linear(hidden, num_classes) + self.l1_reg = l1_reg + + def forward(self, x): + x = self.input_embedding(x) + if self.pos_encoder is not None: + x = self.pos_encoder(x) + x = self.t_blocks(x) + pred = self.logit(x) + return pred + + +@gin.configurable +class Transformer(BaseTransformer): + """Transformer model.""" + + def __init__(self, *kwargs, **args): + super().__init__(TransformerBlock, *kwargs, **args) + + +@gin.configurable +class LocalTransformer(BaseTransformer): + """Transformer model with local context.""" + + def __init__(self, *kwargs, **args): + super().__init__(LocalBlock, *kwargs, **args) diff --git a/icu_benchmarks/models/ml_models/__init__.py b/icu_benchmarks/models/ml_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/icu_benchmarks/models/ml_models/catboost.py b/icu_benchmarks/models/ml_models/catboost.py new file mode 100644 index 00000000..4bbebea1 --- /dev/null +++ b/icu_benchmarks/models/ml_models/catboost.py @@ -0,0 +1,26 @@ +import gin +import catboost as cb +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.wrappers import MLWrapper + + +@gin.configurable +class CBClassifier(MLWrapper): + _supported_run_modes = [RunMode.classification] + + def __init__(self, task_type="CPU", *args, **kwargs): + model_kwargs = {"task_type": task_type, **kwargs} + self.model = self.set_model_args(cb.CatBoostClassifier, *args, **model_kwargs) + super().__init__(*args, **kwargs) + + def predict(self, features): + """ + Predicts class probabilities for the given features. + + Args: + features: Input features for prediction. + + Returns: + numpy.ndarray: Predicted probabilities for each class. 
+ """ + return self.model.predict_proba(features) diff --git a/icu_benchmarks/models/ml_models/imblearn.py b/icu_benchmarks/models/ml_models/imblearn.py new file mode 100644 index 00000000..d1db0703 --- /dev/null +++ b/icu_benchmarks/models/ml_models/imblearn.py @@ -0,0 +1,22 @@ +from imblearn.ensemble import BalancedRandomForestClassifier, RUSBoostClassifier +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.wrappers import MLWrapper +import gin + + +@gin.configurable +class BRFClassifier(MLWrapper): + _supported_run_modes = [RunMode.classification] + + def __init__(self, *args, **kwargs): + self.model = self.set_model_args(BalancedRandomForestClassifier, *args, **kwargs) + super().__init__(*args, **kwargs) + + +@gin.configurable +class RUSBClassifier(MLWrapper): + _supported_run_modes = [RunMode.classification] + + def __init__(self, *args, **kwargs): + self.model = self.set_model_args(RUSBoostClassifier, *args, **kwargs) + super().__init__(*args, **kwargs) diff --git a/icu_benchmarks/models/ml_models/lgbm.py b/icu_benchmarks/models/ml_models/lgbm.py new file mode 100644 index 00000000..c2207555 --- /dev/null +++ b/icu_benchmarks/models/ml_models/lgbm.py @@ -0,0 +1,57 @@ +import gin +import lightgbm as lgbm +import numpy as np +import wandb +from wandb.integration.lightgbm import wandb_callback as wandb_lgbm + +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.wrappers import MLWrapper + + +class LGBMWrapper(MLWrapper): + def fit_model(self, train_data, train_labels, val_data, val_labels): + """Fitting function for LGBM models.""" + self.model.set_params(random_state=np.random.get_state()[1][0]) + callbacks = [lgbm.early_stopping(self.hparams.patience, verbose=True), lgbm.log_evaluation(period=-1)] + + if wandb.run is not None: + callbacks.append(wandb_lgbm()) + + self.model = self.model.fit( + train_data, + train_labels, + eval_set=(val_data, val_labels), + callbacks=callbacks, + ) + val_loss = list(self.model.best_score_["valid_0"].values())[0] + return val_loss + + +@gin.configurable +class LGBMClassifier(LGBMWrapper): + _supported_run_modes = [RunMode.classification] + + def __init__(self, *args, **kwargs): + self.model = self.set_model_args(lgbm.LGBMClassifier, *args, **kwargs) + super().__init__(*args, **kwargs) + + def predict(self, features): + """ + Predicts class probabilities for the given features. + + Args: + features: Input features for prediction. + + Returns: + numpy.ndarray: Predicted probabilities for each class. 
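# Sketch of the fit pattern wrapped by LGBMWrapper above: early stopping on a validation
# set, silenced evaluation logging, and the best validation score read from best_score_.
# The synthetic data and parameter values are assumptions.
import lightgbm as lgbm
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((200, 5))
y = (X[:, 0] > 0.5).astype(int)
X_val, y_val = X[:50], y[:50]

clf = lgbm.LGBMClassifier(n_estimators=500)
clf.fit(
    X, y,
    eval_set=[(X_val, y_val)],
    callbacks=[lgbm.early_stopping(10, verbose=True), lgbm.log_evaluation(period=-1)],
)
val_loss = list(clf.best_score_["valid_0"].values())[0]  # first metric on the validation set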
+ """ + return self.model.predict_proba(features) + + +@gin.configurable +class LGBMRegressor(LGBMWrapper): + _supported_run_modes = [RunMode.regression] + + def __init__(self, *args, **kwargs): + self.model = self.set_model_args(lgbm.LGBMRegressor, *args, **kwargs) + super().__init__(*args, **kwargs) diff --git a/icu_benchmarks/models/ml_models.py b/icu_benchmarks/models/ml_models/sklearn.py similarity index 59% rename from icu_benchmarks/models/ml_models.py rename to icu_benchmarks/models/ml_models/sklearn.py index 5e52921b..1fe7a87b 100644 --- a/icu_benchmarks/models/ml_models.py +++ b/icu_benchmarks/models/ml_models/sklearn.py @@ -1,59 +1,9 @@ import gin -import lightgbm as lgbm -import numpy as np -import wandb -from sklearn import linear_model -from sklearn import ensemble -from sklearn import neural_network -from sklearn import svm +from sklearn import linear_model, ensemble, svm, neural_network +from icu_benchmarks.constants import RunMode from icu_benchmarks.models.wrappers import MLWrapper -from icu_benchmarks.contants import RunMode -from wandb.integration.lightgbm import wandb_callback -class LGBMWrapper(MLWrapper): - def fit_model(self, train_data, train_labels, val_data, val_labels): - """Fitting function for LGBM models.""" - self.model.set_params(random_state=np.random.get_state()[1][0]) - callbacks = [lgbm.early_stopping(self.hparams.patience, verbose=True), lgbm.log_evaluation(period=-1)] - - if wandb.run is not None: - callbacks.append(wandb_callback()) - - self.model = self.model.fit( - train_data, - train_labels, - eval_set=(val_data, val_labels), - verbose=True, - callbacks=callbacks, - ) - val_loss = list(self.model.best_score_["valid_0"].values())[0] - return val_loss - - -@gin.configurable -class LGBMClassifier(LGBMWrapper): - _supported_run_modes = [RunMode.classification] - - def __init__(self, *args, **kwargs): - self.model = self.set_model_args(lgbm.LGBMClassifier, *args, **kwargs) - super().__init__(*args, **kwargs) - - def predict(self, features): - """Predicts labels for the given features.""" - return self.model.predict_proba(features) - - -@gin.configurable -class LGBMRegressor(LGBMWrapper): - _supported_run_modes = [RunMode.regression] - - def __init__(self, *args, **kwargs): - self.model = self.set_model_args(lgbm.LGBMRegressor, *args, **kwargs) - super().__init__(*args, **kwargs) - - -# Scikit-learn models @gin.configurable class LogisticRegression(MLWrapper): __supported_run_modes = [RunMode.classification] diff --git a/icu_benchmarks/models/ml_models/xgboost.py b/icu_benchmarks/models/ml_models/xgboost.py new file mode 100644 index 00000000..5ca738ac --- /dev/null +++ b/icu_benchmarks/models/ml_models/xgboost.py @@ -0,0 +1,74 @@ +import inspect +import logging +from statistics import mean + +import gin +import shap +import wandb +import xgboost as xgb +from xgboost.callback import EarlyStopping +from wandb.integration.xgboost import wandb_callback as wandb_xgb + +from icu_benchmarks.constants import RunMode +from icu_benchmarks.models.wrappers import MLWrapper + + +# Uncomment if needed in the future +# from optuna.integration import XGBoostPruningCallback + + +@gin.configurable +class XGBClassifier(MLWrapper): + _supported_run_modes = [RunMode.classification] + _explain_values = False + + def __init__(self, *args, **kwargs): + self.model = self.set_model_args(xgb.XGBClassifier, *args, **kwargs, device="cpu") + super().__init__(*args, **kwargs) + + def predict(self, features): + """ + Predicts class probabilities for the given features. 
+ + Args: + features: Input features for prediction. + + Returns: + numpy.ndarray: Predicted probabilities for each class. + """ + return self.model.predict_proba(features) + + def fit_model(self, train_data, train_labels, val_data, val_labels): + """Fit the model to the training data (default SKlearn syntax)""" + callbacks = [EarlyStopping(self.hparams.patience)] + + if wandb.run is not None: + callbacks.append(wandb_xgb()) + logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}") + logging.info(train_labels) + self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=False) + if self._explain_values: + self.explainer = shap.TreeExplainer(self.model) + self.train_shap_values = self.explainer(train_data) + # shap.summary_plot(shap_values, X_test, feature_names=features) + # logging.info(self.model.get_booster().get_score(importance_type='weight')) + # self.log_dict(self.model.get_booster().get_score(importance_type='weight')) + # Return the first metric we use for validation + eval_score = mean(next(iter(self.model.evals_result_["validation_0"].values()))) + return eval_score # , callbacks=callbacks) + + def set_model_args(self, model, *args, **kwargs): + """XGBoost signature does not include the hyperparams so we need to pass them manually.""" + signature = inspect.signature(model.__init__).parameters + valid_params = signature.keys() + + # Filter out invalid arguments + valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} + + logging.debug(f"Creating model with: {valid_kwargs}.") + return model(**valid_kwargs) + + def get_feature_importance(self): + if not hasattr(self.model, "feature_importances_"): + raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.") + return self.model.feature_importances_ diff --git a/icu_benchmarks/models/train.py b/icu_benchmarks/models/train.py index db7aabda..7bc8e45a 100644 --- a/icu_benchmarks/models/train.py +++ b/icu_benchmarks/models/train.py @@ -1,8 +1,9 @@ import os import gin +import numpy as np import torch import logging -import pandas as pd +import polars as pl from joblib import load from torch.optim import Adam from torch.utils.data import DataLoader @@ -10,9 +11,9 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar, LearningRateMonitor from pathlib import Path -from icu_benchmarks.data.loader import PredictionDataset, ImputationDataset +from icu_benchmarks.data.loader import PredictionPandasDataset, ImputationPandasDataset, PredictionPolarsDataset from icu_benchmarks.models.utils import save_config_file, JSONMetricsLogger -from icu_benchmarks.contants import RunMode +from icu_benchmarks.constants import RunMode from icu_benchmarks.data.constants import DataSplit as Split cpu_core_count = len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else os.cpu_count() @@ -26,7 +27,7 @@ def assure_minimum_length(dataset): @gin.configurable("train_common") def train_common( - data: dict[str, pd.DataFrame], + data: dict[str, pl.DataFrame], log_dir: Path, eval_only: bool = False, load_weights: bool = False, @@ -37,8 +38,8 @@ def train_common( weight: str = None, optimizer: type = Adam, precision=32, - batch_size=64, - epochs=1000, + batch_size=1, + epochs=100, patience=20, min_delta=1e-5, test_on: str = Split.test, @@ -50,6 +51,8 @@ def train_common( pl_model=True, train_only=False, num_workers: int = min(cpu_core_count, torch.cuda.device_count() * 8 * 
int(torch.cuda.is_available()), 32), + polars=True, + persistent_workers=None, ): """Common wrapper to train all benchmarked models. @@ -79,11 +82,17 @@ def train_common( """ logging.info(f"Training model: {model.__name__}.") - dataset_class = ImputationDataset if mode == RunMode.imputation else PredictionDataset - + # todo: add support for polars versions of datasets + dataset_classes = { + RunMode.imputation: ImputationPandasDataset, + RunMode.classification: PredictionPolarsDataset if polars else PredictionPandasDataset, + RunMode.regression: PredictionPolarsDataset if polars else PredictionPandasDataset, + } + dataset_class = dataset_classes[mode] + + logging.info(f"Using dataset class: {dataset_class.__name__}.") logging.info(f"Logging to directory: {log_dir}.") save_config_file(log_dir) # We save the operative config before and also after training - train_dataset = dataset_class(data, split=Split.train, ram_cache=ram_cache, name=dataset_names["train"]) val_dataset = dataset_class(data, split=Split.val, ram_cache=ram_cache, name=dataset_names["val"]) train_dataset, val_dataset = assure_minimum_length(train_dataset), assure_minimum_length(val_dataset) @@ -95,30 +104,29 @@ def train_common( f" {len(val_dataset)} samples." ) logging.info(f"Using {num_workers} workers for data loading.") - train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, - pin_memory=True, drop_last=True, + persistent_workers=persistent_workers, ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, - pin_memory=True, drop_last=True, + persistent_workers=persistent_workers, ) data_shape = next(iter(train_loader))[0].shape if load_weights: - model = load_model(model, source_dir, pl_model=pl_model) + model = load_model(model, source_dir, pl_model=pl_model, cpu=cpu) else: - model = model(optimizer=optimizer, input_size=data_shape, epochs=epochs, run_mode=mode) + model = model(optimizer=optimizer, input_size=data_shape, epochs=epochs, run_mode=mode, cpu=cpu) model.set_weight(weight, train_dataset) model.set_trained_columns(train_dataset.get_feature_names()) @@ -137,6 +145,7 @@ def train_common( trainer = Trainer( max_epochs=epochs if model.requires_backprop else 1, + min_epochs=1, # We need at least one epoch to get results. callbacks=callbacks, precision=precision, accelerator="auto" if not cpu else "cpu", @@ -145,7 +154,7 @@ def train_common( benchmark=not reproducible, enable_progress_bar=verbose, logger=loggers, - num_sanity_val_steps=-1, + num_sanity_val_steps=2, # Helps catch errors in the validation loop before training begins. 
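# Worked example of the num_workers default declared above: the middle term zeroes out
# when CUDA is unavailable, so data loading then falls back to the main process.
import os
import torch

cpu_core_count = len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else os.cpu_count()
# One visible GPU:  min(cpu_core_count, 1 * 8 * 1, 32) -> at most 8 workers.
# No CUDA device:   min(cpu_core_count, 0, 32)         -> 0 workers.
num_workers = min(cpu_core_count, torch.cuda.device_count() * 8 * int(torch.cuda.is_available()), 32)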
log_every_n_steps=5, ) if not eval_only: @@ -162,7 +171,7 @@ def train_common( logging.info("Finished training full model.") save_config_file(log_dir) return 0 - test_dataset = dataset_class(data, split=test_on, name=dataset_names["test"]) + test_dataset = dataset_class(data, split=test_on, name=dataset_names["test"], ram_cache=ram_cache) test_dataset = assure_minimum_length(test_dataset) logging.info(f"Testing on {test_dataset.name} with {len(test_dataset)} samples.") test_loader = ( @@ -173,6 +182,7 @@ def train_common( num_workers=num_workers, pin_memory=True, drop_last=True, + persistent_workers=persistent_workers, ) if model.requires_backprop else DataLoader([test_dataset.to_tensor()], batch_size=1) @@ -180,10 +190,35 @@ def train_common( model.set_weight("balanced", train_dataset) test_loss = trainer.test(model, dataloaders=test_loader, verbose=verbose)[0]["test/loss"] + persist_shap_data(trainer, log_dir) save_config_file(log_dir) return test_loss +def persist_shap_data(trainer: Trainer, log_dir: Path): + """ + Persist shap values to disk. + Args: + trainer: Pytorch lightning trainer object + log_dir: Log directory + """ + try: + if trainer.lightning_module.test_shap_values is not None: + shap_values = trainer.lightning_module.test_shap_values + shaps_test = pl.DataFrame(schema=trainer.lightning_module.trained_columns, data=np.transpose(shap_values.values)) + with (log_dir / "shap_values_test.parquet").open("wb") as f: + shaps_test.write_parquet(f) + logging.info(f"Saved shap values to {log_dir / 'test_shap_values.parquet'}") + if trainer.lightning_module.train_shap_values is not None: + shap_values = trainer.lightning_module.train_shap_values + shaps_train = pl.DataFrame(schema=trainer.lightning_module.trained_columns, data=np.transpose(shap_values.values)) + with (log_dir / "shap_values_train.parquet").open("wb") as f: + shaps_train.write_parquet(f) + + except Exception as e: + logging.error(f"Failed to save shap values: {e}") + + def load_model(model, source_dir, pl_model=True): if source_dir.exists(): if model.requires_backprop: diff --git a/icu_benchmarks/models/utils.py b/icu_benchmarks/models/utils.py index 6c944ae7..fc5b5506 100644 --- a/icu_benchmarks/models/utils.py +++ b/icu_benchmarks/models/utils.py @@ -11,6 +11,7 @@ from pytorch_lightning.loggers.logger import Logger from pytorch_lightning.utilities import rank_zero_only +from sklearn.metrics import average_precision_score from torch.nn import Module from torch.optim import Optimizer, Adam, SGD, RAdam from typing import Optional, Union @@ -188,3 +189,96 @@ def version(self): @rank_zero_only def log_hyperparams(self, params): pass + + +class scorer_wrapper: + """ + Wrapper that flattens the binary classification input such that we can use a broader range of sklearn metrics. 
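# Sketch of what scorer_wrapper (defined here) does for binary targets: a two-column
# probability output is reduced to predicted class labels before the wrapped sklearn
# scorer is called. Toy labels and probabilities are assumptions.
import numpy as np
from sklearn.metrics import average_precision_score

y_true = np.array([0, 1, 1, 0])
y_pred = np.array([[0.8, 0.2], [0.3, 0.7], [0.4, 0.6], [0.9, 0.1]])  # per-class probabilities
if len(np.unique(y_true)) <= 2 and y_pred.ndim > 1:
    y_pred = np.argmax(y_pred, axis=1)
score = average_precision_score(y_true, y_pred)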
+ """ + + def __init__(self, scorer=average_precision_score): + self.scorer = scorer + + def __call__(self, y_true, y_pred): + if len(np.unique(y_true)) <= 2 and y_pred.ndim > 1: + y_pred_argmax = np.argmax(y_pred, axis=1) + return self.scorer(y_true, y_pred_argmax) + else: + return self.scorer(y_true, y_pred) + + def __name__(self): + return "scorer_wrapper" + + +# Source: https://github.com/ratschlab/tls +@gin.configurable("get_smoothed_labels") +def get_smoothed_labels( + label, event, smoothing_fn=gin.REQUIRED, h_true=gin.REQUIRED, h_min=gin.REQUIRED, h_max=gin.REQUIRED, delta_h=12, gamma=0.1 +): + diffs = np.concatenate([np.zeros(1), event[1:] - event[:-1]], axis=-1) + pos_event_change_full = np.where((diffs == 1) & (event == 1))[0] + + multihorizon = isinstance(h_true, list) + if multihorizon: + label_for_event = label[0] + h_for_event = h_true[0] + else: + label_for_event = label + h_for_event = h_true + diffs_label = np.concatenate([np.zeros(1), label_for_event[1:] - label_for_event[:-1]], axis=-1) + + # Event that occurred after the end of the stay for M3B. + # In that case event are equal to the number of hours after the end of stay when the event occured. + pos_event_change_delayed = np.where((diffs >= 1) & (event > 1))[0] + if len(pos_event_change_delayed) > 0: + delays = event[pos_event_change_delayed] - 1 + pos_event_change_delayed += delays.astype(int) + pos_event_change_full = np.sort(np.concatenate([pos_event_change_full, pos_event_change_delayed])) + + last_know_label = label_for_event[np.where(label_for_event != -1)][-1] + last_know_idx = np.where(label_for_event == last_know_label)[0][-1] + + # Need to handle the case where the ts was truncatenated at 2016 for HiB + if ((last_know_label == 1) and (len(pos_event_change_full) == 0)) or ( + (last_know_label == 1) and (last_know_idx >= pos_event_change_full[-1]) + ): + last_know_event = 0 + if len(pos_event_change_full) > 0: + last_know_event = pos_event_change_full[-1] + + last_known_stable = 0 + known_stable = np.where(label_for_event == 0)[0] + if len(known_stable) > 0: + last_known_stable = known_stable[-1] + + pos_change = np.where((diffs_label >= 1) & (label_for_event == 1))[0] + last_pos_change = pos_change[np.where(pos_change > max(last_know_event, last_known_stable))][0] + pos_event_change_full = np.concatenate([pos_event_change_full, [last_pos_change + h_for_event]]) + + # No event case + if len(pos_event_change_full) == 0: + pos_event_change_full = np.array([np.inf]) + + time_array = np.arange(len(label)) + dist = pos_event_change_full.reshape(-1, 1) - time_array + dte = np.where(dist > 0, dist, np.inf).min(axis=0) + if multihorizon: + smoothed_labels = [] + for k in range(label.shape[-1]): + smoothed_labels.append( + np.array( + list( + map( + lambda x: smoothing_fn( + x, h_true=h_true[k], h_min=h_min[k], h_max=h_max[k], delta_h=delta_h, gamma=gamma + ), + dte, + ) + ) + ) + ) + return np.stack(smoothed_labels, axis=-1) + else: + return np.array( + list(map(lambda x: smoothing_fn(x, h_true=h_true, h_min=h_min, h_max=h_max, delta_h=delta_h, gamma=gamma), dte)) + ) diff --git a/icu_benchmarks/models/wrappers.py b/icu_benchmarks/models/wrappers.py index e8595a70..310f9fe6 100644 --- a/icu_benchmarks/models/wrappers.py +++ b/icu_benchmarks/models/wrappers.py @@ -1,9 +1,9 @@ import logging from abc import ABC from typing import Dict, Any, List, Optional, Union - +from pathlib import Path import torchmetrics -from sklearn.metrics import log_loss, mean_squared_error +from sklearn.metrics import log_loss, 
mean_squared_error, average_precision_score, roc_auc_score import torch from torch.nn import MSELoss, CrossEntropyLoss @@ -16,12 +16,13 @@ import numpy as np from ignite.exceptions import NotComputableError from icu_benchmarks.models.constants import ImputationInit +from icu_benchmarks.models.custom_metrics import confusion_matrix from icu_benchmarks.models.utils import create_optimizer, create_scheduler from joblib import dump from pytorch_lightning import LightningModule from icu_benchmarks.models.constants import MLMetrics, DLMetrics -from icu_benchmarks.contants import RunMode +from icu_benchmarks.constants import RunMode gin.config.external_configurable(nn.functional.nll_loss, module="torch.nn.functional") gin.config.external_configurable(nn.functional.cross_entropy, module="torch.nn.functional") @@ -29,6 +30,9 @@ gin.config.external_configurable(mean_squared_error, module="sklearn.metrics") gin.config.external_configurable(log_loss, module="sklearn.metrics") +gin.config.external_configurable(average_precision_score, module="sklearn.metrics") +gin.config.external_configurable(roc_auc_score, module="sklearn.metrics") +# gin.config.external_configurable(scorer_wrapper, module="icu_benchmarks.models.utils") @gin.configurable("BaseModule") @@ -42,6 +46,8 @@ class BaseModule(LightningModule): trained_columns = None # Type of run mode run_mode = None + debug = False + explain_features = False def forward(self, *args, **kwargs): raise NotImplementedError() @@ -58,8 +64,14 @@ def set_metrics(self, *args, **kwargs): def set_trained_columns(self, columns: List[str]): self.trained_columns = columns - def set_weight(self, weight, *args, **kwargs): - pass + def set_weight(self, weight, dataset): + """Set the weight for the loss function.""" + + if isinstance(weight, list): + weight = FloatTensor(weight).to(self.device) + elif weight == "balanced": + weight = FloatTensor(dataset.get_balance()).to(self.device) + self.loss_weights = weight def training_step(self, batch, batch_idx): return self.step_fn(batch, "train") @@ -249,15 +261,6 @@ def __init__( self.output_transform = None self.loss_weights = None - def set_weight(self, weight, dataset): - """Set the weight for the loss function.""" - - if isinstance(weight, list): - weight = FloatTensor(weight).to(self.device) - elif weight == "balanced": - weight = FloatTensor(dataset.get_balance()).to(self.device) - self.loss_weights = weight - def set_metrics(self, *args): """Set the evaluation metrics for the prediction model.""" @@ -372,6 +375,7 @@ def __init__(self, *args, run_mode=RunMode.classification, loss=log_loss, patien self.loss = loss self.patience = patience self.mps = mps + self.loss_weight = None def set_metrics(self, labels): if self.run_mode == RunMode.classification: @@ -401,8 +405,8 @@ def set_metrics(self, labels): def fit(self, train_dataset, val_dataset): """Fit the model to the training data.""" - train_rep, train_label = train_dataset.get_data_and_labels() - val_rep, val_label = val_dataset.get_data_and_labels() + train_rep, train_label, row_indicators = train_dataset.get_data_and_labels() + val_rep, val_label, row_indicators = val_dataset.get_data_and_labels() self.set_metrics(train_label) @@ -414,6 +418,7 @@ def fit(self, train_dataset, val_dataset): train_pred = self.predict(train_rep) logging.debug(f"Model:{self.model}") + self.log("train/loss", self.loss(train_label, train_pred), sync_dist=True) logging.debug(f"Train loss: {self.loss(train_label, train_pred)}") self.log("val/loss", val_loss, sync_dist=True) @@ -427,7 +432,7 @@ 
def fit_model(self, train_data, train_labels, val_data, val_labels): return val_loss def validation_step(self, val_dataset, _): - val_rep, val_label = val_dataset.get_data_and_labels() + val_rep, val_label, row_indicators = val_dataset.get_data_and_labels() val_rep, val_label = torch.from_numpy(val_rep).to(self.device), torch.from_numpy(val_label).to(self.device) self.set_metrics(val_label) @@ -438,11 +443,18 @@ def validation_step(self, val_dataset, _): self.log_metrics(val_label, val_pred, "val") def test_step(self, dataset, _): - test_rep, test_label = dataset - test_rep, test_label = test_rep.squeeze().cpu().numpy(), test_label.squeeze().cpu().numpy() + test_rep, test_label, pred_indicators = dataset + test_rep, test_label, pred_indicators = ( + test_rep.squeeze().cpu().numpy(), + test_label.squeeze().cpu().numpy(), + pred_indicators.squeeze().cpu().numpy(), + ) self.set_metrics(test_label) test_pred = self.predict(test_rep) - + if self.debug: + self._save_model_outputs(pred_indicators, test_pred, test_label) + if self.explain_features: + self.explain_model(test_rep, test_label) if self.mps: self.log("test/loss", np.float32(self.loss(test_label, test_pred)), sync_dist=True) self.log_metrics(np.float32(test_label), np.float32(test_pred), "test") @@ -459,21 +471,36 @@ def predict(self, features): def log_metrics(self, label, pred, metric_type): """Log metrics to the PL logs.""" - + if "Confusion_Matrix" in self.metrics: + self.log_dict(confusion_matrix(self.label_transform(label), self.output_transform(pred)), sync_dist=True) self.log_dict( { - # MPS dependent type casting - f"{metric_type}/{name}": metric(self.label_transform(label), self.output_transform(pred)) - if not self.mps - else metric(self.label_transform(label), self.output_transform(pred)) - # Fore very metric + f"{metric_type}/{name}": (metric(self.label_transform(label), self.output_transform(pred))) + # For every metric for name, metric in self.metrics.items() # Filter out metrics that return a tuple (e.g. 
precision_recall_curve) if not isinstance(metric(self.label_transform(label), self.output_transform(pred)), tuple) + and name != "Confusion_Matrix" }, sync_dist=True, ) + def _explain_model(self, test_rep, test_label): + if self.explainer is not None: + self.test_shap_values = self.explainer(test_rep) + else: + logging.warning("No explainer or explain_features values set.") + + def _save_model_outputs(self, pred_indicators, test_pred, test_label): + if len(pred_indicators.shape) > 1 and len(test_pred.shape) > 1 and pred_indicators.shape[1] == test_pred.shape[1]: + pred_indicators = np.hstack((pred_indicators, test_label.reshape(-1, 1))) + pred_indicators = np.hstack((pred_indicators, test_pred)) + # Save as: id, time (hours), ground truth, prediction 0, prediction 1 + np.savetxt(Path(self.logger.save_dir) / "pred_indicators.csv", pred_indicators, delimiter=",") + logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}") + else: + logging.warning("Could not save row indicators.") + def configure_optimizers(self): return None @@ -498,6 +525,7 @@ def set_model_args(self, model, *args, **kwargs): # Get passed keyword arguments arguments = locals()["kwargs"] # Get valid hyperparameters + logging.debug(f"Possible hps: {possible_hps}") hyperparams = {key: value for key, value in arguments.items() if key in possible_hps} logging.debug(f"Creating model with: {hyperparams}.") return model(**hyperparams) diff --git a/icu_benchmarks/run.py b/icu_benchmarks/run.py index b0bd5a31..1c5479a0 100644 --- a/icu_benchmarks/run.py +++ b/icu_benchmarks/run.py @@ -18,8 +18,9 @@ setup_logging, import_preprocessor, name_datasets, + get_config_files, ) -from icu_benchmarks.contants import RunMode +from icu_benchmarks.constants import RunMode @gin.configurable("Run") @@ -31,6 +32,7 @@ def get_mode(mode: gin.REQUIRED): def main(my_args=tuple(sys.argv[1:])): args, _ = build_parser().parse_known_args(my_args) + # Set arguments for wandb sweep if args.wandb_sweep: args = apply_wandb_sweep(args) set_wandb_experiment_name(args, "run") @@ -48,10 +50,22 @@ def main(my_args=tuple(sys.argv[1:])): evaluate = args.eval experiment = args.experiment source_dir = args.source_dir + modalities = args.modalities + if modalities: + logging.debug(f"Binding modalities: {modalities}") + gin.bind_parameter("preprocess.selected_modalities", modalities) + if args.label: + logging.debug(f"Binding label: {args.label}") + gin.bind_parameter("preprocess.label", args.label) + tasks, models = get_config_files(Path("configs")) + if task not in tasks or model not in models: + raise ValueError( + f"Invalid task or model. Task: {task} {'not ' if task not in tasks else ''} found. " + f"Model: {model} {'not ' if model not in models else ''}found." 
+ ) # Load task config gin.parse_config_file(f"configs/tasks/{task}.gin") mode = get_mode() - # Set arguments for wandb sweep # Set experiment name if name is None: @@ -68,9 +82,9 @@ def main(my_args=tuple(sys.argv[1:])): # Log imputation model to wandb update_wandb_config( { - "pretrained_imputation_model": pretrained_imputation_model.__class__.__name__ - if pretrained_imputation_model is not None - else "None" + "pretrained_imputation_model": ( + pretrained_imputation_model.__class__.__name__ if pretrained_imputation_model is not None else "None" + ) } ) @@ -118,7 +132,7 @@ def main(my_args=tuple(sys.argv[1:])): name_datasets(args.name, args.name, args.name) hp_checkpoint = log_dir / args.hp_checkpoint if args.hp_checkpoint else None model_path = ( - Path("configs") / ("imputation_models" if mode == RunMode.imputation else "prediction_models") / f"{model}.gin" + Path("configs") / ("imputation_models" if mode == RunMode.imputation else "prediction_models") / f"{model}.gin" ) gin_config_files = ( [Path(f"configs/experiments/{args.experiment}.gin")] @@ -129,10 +143,10 @@ def main(my_args=tuple(sys.argv[1:])): log_full_line(f"Data directory: {data_dir.resolve()}", level=logging.INFO) run_dir = create_run_dir(log_dir) choose_and_bind_hyperparameters_optuna( - args.tune, - data_dir, - run_dir, - args.seed, + do_tune=args.tune, + data_dir=data_dir, + log_dir=run_dir, + seed=args.seed, run_mode=mode, checkpoint=hp_checkpoint, debug=args.debug, @@ -175,7 +189,11 @@ def main(my_args=tuple(sys.argv[1:])): log_full_line("FINISHED TRAINING", level=logging.INFO, char="=", num_newlines=3) execution_time = datetime.now() - start_time log_full_line(f"DURATION: {execution_time}", level=logging.INFO, char="") - aggregate_results(run_dir, execution_time) + try: + aggregate_results(run_dir, execution_time) + except Exception as e: + logging.error(f"Failed to aggregate results: {e}") + logging.debug("Error details:", exc_info=True) if args.plot: plot_aggregated_results(run_dir, "aggregated_test_metrics.json") diff --git a/icu_benchmarks/run_utils.py b/icu_benchmarks/run_utils.py index 85b676ac..97b50f8c 100644 --- a/icu_benchmarks/run_utils.py +++ b/icu_benchmarks/run_utils.py @@ -15,6 +15,7 @@ from statistics import mean, pstdev from icu_benchmarks.models.utils import JsonResultLoggingEncoder from icu_benchmarks.wandb_utils import wandb_log +import polars as pl def build_parser() -> ArgumentParser: @@ -52,6 +53,13 @@ def build_parser() -> ArgumentParser: parser.add_argument("-sn", "--source-name", type=Path, help="Name of the source dataset.") parser.add_argument("--source-dir", type=Path, help="Directory containing gin and model weights.") parser.add_argument("-sa", "--samples", type=int, default=None, help="Number of samples to use for evaluation.") + parser.add_argument( + "-mo", + "--modalities", + nargs="+", + help="Optional modality selection to use. Specify multiple modalities separated by spaces.", + ) + parser.add_argument("--label", type=str, help="Label to use for evaluation in case of multiple labels.", default=None) return parser @@ -68,9 +76,8 @@ def create_run_dir(log_dir: Path, randomly_searched_params: str = None) -> Path: Returns: Path to the created run log directory. 
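# Usage sketch for get_config_files (defined further below in run_utils) together with
# the task/model validation added above; the task and model names are illustrative values.
from pathlib import Path
from icu_benchmarks.run_utils import get_config_files

tasks, models = get_config_files(Path("configs"))
task, model = "BinaryClassification", "LGBMClassifier"
if task not in tasks or model not in models:
    raise ValueError(f"Invalid task or model. Task: {task}, Model: {model}")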
""" - if not (log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))).exists(): - log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) - else: + log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + while log_dir_run.exists(): log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")) log_dir_run.mkdir(parents=True) if randomly_searched_params: @@ -99,6 +106,7 @@ def aggregate_results(log_dir: Path, execution_time: timedelta = None): execution_time: Overall execution time. """ aggregated = {} + shap_values_test = [] for repetition in log_dir.iterdir(): if repetition.is_dir(): aggregated[repetition.name] = {} @@ -117,7 +125,18 @@ def aggregate_results(log_dir: Path, execution_time: timedelta = None): with open(fold_iter / "durations.json", "r") as f: result = json.load(f) aggregated[repetition.name][fold_iter.name].update(result) + if (fold_iter / "test_shap_values.parquet").is_file(): + shap_values_test.append(pl.read_parquet(fold_iter / "test_shap_values.parquet")) + + if shap_values_test: + shap_values = pl.concat(shap_values_test) + shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet") + try: + shap_values = pl.concat(shap_values_test) + shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet") + except Exception as e: + logging.error(f"Error aggregating or writing SHAP values: {e}") # Aggregate results per metric list_scores = {} for repetition, folds in aggregated.items(): @@ -235,3 +254,46 @@ def setup_logging(date_format, log_format, verbose): for logger in loggers: logging.getLogger(logger).setLevel(logging.DEBUG) warnings.filterwarnings("default") + + +def get_config_files(config_dir: Path): + """ + Get all task and model config files in the specified directory. + Args: + config_dir: Name of the directory containing the config gin files. + + Returns: + tasks: List of task names + models: List of model names + """ + try: + tasks = list((config_dir / "tasks").glob("*")) + models = list((config_dir / "prediction_models").glob("*")) + tasks = [task.stem for task in tasks if task.is_file()] + models = [model.stem for model in models if model.is_file()] + except Exception as e: + logging.error(f"Error retrieving config files: {e}") + return [], [] + if "common" in tasks: + tasks.remove("common") + if "common" in models: + models.remove("common") + logging.info(f"Found tasks: {tasks}") + logging.info(f"Found models: {models}") + return tasks, models + + +def check_required_keys(vars, required_keys): + """ + Checks if all required keys are present in the vars dictionary. + + Args: + vars (dict): The dictionary to check. + required_keys (list): The list of required keys. + + Raises: + KeyError: If any required key is missing. 
+ """ + missing_keys = [key for key in required_keys if key not in vars] + if missing_keys: + raise KeyError(f"Missing required keys in vars: {', '.join(missing_keys)}") diff --git a/icu_benchmarks/tuning/hyperparameters.py b/icu_benchmarks/tuning/hyperparameters.py index b1f5a73c..c879bd62 100644 --- a/icu_benchmarks/tuning/hyperparameters.py +++ b/icu_benchmarks/tuning/hyperparameters.py @@ -2,6 +2,7 @@ import gin import logging from logging import NOTSET +import matplotlib.pyplot as plt import numpy as np from pathlib import Path from skopt import gp_minimize @@ -12,15 +13,16 @@ from icu_benchmarks.cross_validation import execute_repeated_cv from icu_benchmarks.run_utils import log_full_line from icu_benchmarks.tuning.gin_utils import get_gin_hyperparameters, bind_gin_params -from icu_benchmarks.contants import RunMode +from icu_benchmarks.constants import RunMode from icu_benchmarks.wandb_utils import wandb_log +from optuna.visualization import plot_param_importances, plot_optimization_history TUNE = 25 logging.addLevelName(25, "TUNE") @gin.configurable("tune_hyperparameters_deprecated") -def choose_and_bind_hyperparameters( +def choose_and_bind_hyperparameters_scikit_optimize( do_tune: bool, data_dir: Path, log_dir: Path, @@ -61,6 +63,9 @@ def choose_and_bind_hyperparameters( Raises: ValueError: If checkpoint is not None and the checkpoint does not exist. """ + logging.warning( + "This function is deprecated and will be removed in the future. " "Use choose_and_bind_hyperparameters_optuna instead." + ) hyperparams = {} if len(scopes) == 0 or folds_to_tune_on is None: @@ -189,11 +194,13 @@ def choose_and_bind_hyperparameters_optuna( debug: bool = False, verbose: bool = False, wandb: bool = False, + plot: bool = True, ): """Choose hyperparameters to tune and bind them to gin. Uses Optuna for hyperparameter optimization. Args: - sampler: + plot: Whether to plot hyperparameter importances. + sampler: The sampler to use for hyperparameter optimization. wandb: Whether we use wandb or not. load_cache: Load cached data if available. generate_cache: Generate cache data. @@ -227,44 +234,36 @@ def choose_and_bind_hyperparameters_optuna( logging.info("No hyperparameters to tune, skipping tuning.") return - # Attempt checkpoint loading - configuration, evaluation = None, None - if checkpoint: - return NotImplementedError("Checkpoint loading is not implemented for Optuna yet.") - # checkpoint_path = checkpoint / checkpoint_file - # if not checkpoint_path.exists(): - # logging.warning(f"Hyperparameter checkpoint {checkpoint_path} does not exist.") - # logging.info("Attempting to find latest checkpoint file.") - # checkpoint_path = find_checkpoint(log_dir.parent, checkpoint_file) - # # Check if we found a checkpoint file - # if checkpoint_path: - # n_calls, configuration, evaluation = load_checkpoint(checkpoint_path, n_calls) - # # # Check if we surpassed maximum tuning iterations - # # if n_calls <= 0: - # # logging.log(TUNE, "No more hyperparameter tuning iterations left, skipping tuning.") - # # logging.info("Training with these hyperparameters:") - # # bind_gin_params(hyperparams_names, configuration[np.argmin(evaluation)]) # bind best hyperparameters - # # return - # else: - # logging.warning("No checkpoint file found, starting from scratch.") - # Function that trains the model with the given hyperparameters. 
header = ["ITERATION"] + hyperparams_names + ["LOSS AT ITERATION"] # Optuna objective function - def objective(trail, hyperparams_bounds, hyperparams_names): + def objective(trial, hyperparams_bounds, hyperparams_names): # Optuna objective function hyperparams = {} logging.info(f"Bounds: {hyperparams_bounds}, Names: {hyperparams_names}") for name, value in zip(hyperparams_names, hyperparams_bounds): if isinstance(value, tuple): + + def suggest_int_param(trial, name, value): + return trial.suggest_int(name, value[0], value[1], log=value[2] == "log" if len(value) == 3 else False) + + def suggest_float_param(trial, name, value): + return trial.suggest_float(name, value[0], value[1], log=value[2] == "log" if len(value) == 3 else False) + + def suggest_categorical_param(trial, name, value): + return trial.suggest_categorical(name, value) + + # Then in the objective function: if isinstance(value[0], int) and isinstance(value[1], int): - hyperparams[name] = trail.suggest_int(name, value[0], value[1], log=len(value) > 2 and value[2] == "log") + hyperparams[name] = suggest_int_param(trial, name, value) + elif isinstance(value[0], (int, float)) and isinstance(value[1], (int, float)): + hyperparams[name] = suggest_float_param(trial, name, value) else: - hyperparams[name] = trail.suggest_float(name, value[0], value[1], log=len(value) > 2 and value[2] == "log") + hyperparams[name] = suggest_categorical_param(trial, name, value) else: - hyperparams[name] = trail.suggest_categorical(name, value) + hyperparams[name] = trial.suggest_categorical(name, value) return bind_params_and_train(hyperparams) def tune_step_callback(study: optuna.study.Study, trial: optuna.trial.FrozenTrial): @@ -272,7 +271,7 @@ def tune_step_callback(study: optuna.study.Study, trial: optuna.trial.FrozenTria highlight = study.trials[-1] == study.best_trial # highlight if best so far log_table_row(header, TUNE) log_table_row(table_cells, TUNE, align=Align.RIGHT, header=header, highlight=highlight) - wandb_log({"hp-iteration": len(study.trials)}) + wandb_log({"HP-optimization-iteration": len(study.trials)}) if do_tune: log_full_line("STARTING TUNING", level=TUNE, char="=") @@ -283,12 +282,17 @@ def tune_step_callback(study: optuna.study.Study, trial: optuna.trial.FrozenTria log_table_row(header, TUNE) else: logging.log(TUNE, "Hyperparameter tuning disabled") - if configuration: + if checkpoint: + study = optuna.load_study(study_name="tuning", storage="sqlite:///" + str(checkpoint)) + configuration = study.best_params # We have loaded a checkpoint, use the best hyperparameters. logging.info("Training with the best hyperparameters from loaded checkpoint:") - bind_gin_params(hyperparams_names, configuration[np.argmin(evaluation)]) + bind_gin_params(configuration) + return else: - logging.log(TUNE, "Choosing hyperparameters randomly from bounds.") + logging.log( + TUNE, "Choosing hyperparameters randomly from bounds using hp tuning as no earlier checkpoint " "supplied." 
             n_initial_points = 1
             n_calls = 1
@@ -318,28 +322,50 @@ def bind_params_and_train(hyperparams):
         sampler = sampler(seed=seed, n_startup_trials=n_initial_points, deterministic_objective=True)
     else:
         sampler = sampler(seed=seed)
-
+    pruner = optuna.pruners.HyperbandPruner()
     # Optuna study
-    study = optuna.create_study(
-        sampler=sampler,
-        storage="sqlite:///" + str(log_dir / checkpoint_file),
-        study_name=str(data_dir) + str(seed),
-        pruner=optuna.pruners.HyperbandPruner(),
-    )
+    # Attempt checkpoint loading
+    if checkpoint and checkpoint.exists():
+        logging.info(f"Loading checkpoint at {checkpoint}")
+        study = optuna.load_study(study_name="tuning", storage="sqlite:///" + str(checkpoint), sampler=sampler, pruner=pruner)
+        n_calls = n_calls - len(study.trials)
+    else:
+        if checkpoint:
+            logging.warning("Checkpoint path given as flag but not found, starting from scratch.")
+        study = optuna.create_study(
+            sampler=sampler,
+            storage="sqlite:///" + str(log_dir / checkpoint_file),
+            study_name="tuning",
+            pruner=pruner,
+            load_if_exists=True,
+        )
+
     callbacks = [tune_step_callback]
     if wandb:
         wandb_kwargs = {
             "config": {"sampler": sampler},
+            "allow_val_change": True,
         }
         wandbc = WeightsAndBiasesCallback(metric_name="loss", wandb_kwargs=wandb_kwargs)
         callbacks.append(wandbc)
-    logging.info(f"Starting Optuna study with {n_calls} trials and callbacks: {callbacks}.")
-    study.optimize(
-        lambda trail: objective(trail, hyperparams_bounds, hyperparams_names),
-        n_trials=n_calls,
-        callbacks=callbacks,
-        gc_after_trial=True,
-    )
+
+    logging.info(f"Starting or resuming Optuna study with {n_calls} trials and callbacks: {callbacks}.")
+    if n_calls > 0:
+        study.optimize(
+            lambda trial: objective(trial, hyperparams_bounds, hyperparams_names),
+            n_trials=n_calls,
+            callbacks=callbacks,
+            gc_after_trial=True,
+        )
+    else:
+        logging.info("No more hyperparameter tuning iterations left, skipping tuning.")
+        logging.info("Training with these hyperparameters:")
+        bind_gin_params(study.best_params)
+        return
 
     logging.disable(level=NOTSET)
     if do_tune:
@@ -348,6 +374,16 @@ def bind_params_and_train(hyperparams):
         logging.info("Training with these hyperparameters:")
         bind_gin_params(study.best_params)
+    if plot:
+        try:
+            logging.info("Plotting hyperparameter importances and optimization history.")
+            ax = plot_param_importances(study)
+            ax.figure.savefig(log_dir / "param_importances.png")
+            ax = plot_optimization_history(study)
+            ax.figure.savefig(log_dir / "optimization_history.png")
+        except Exception as e:
+            logging.error(f"Failed to plot hyperparameter importances: {e}")
+
 
 
 def collect_bound_hyperparameters(hyperparams, scopes):
     for scope in scopes:
@@ -358,17 +394,6 @@
     return hyperparams_bounds, hyperparams_names
 
 
-def load_optuna_checkpoint(checkpoint_path, n_calls):
-    logging.info(f"Loading checkpoint at {checkpoint_path}")
-    with open(checkpoint_path, "r") as f:
-        data = json.loads(f.read())
-    x0 = data["x_iters"]
-    y0 = data["func_vals"]
-    n_calls -= len(x0)
-    logging.log(TUNE, f"Checkpoint contains {len(x0)} points.")
-    return n_calls, x0, y0
-
-
 def load_checkpoint(checkpoint_path, n_calls):
     logging.info(f"Loading checkpoint at {checkpoint_path}")
     with open(checkpoint_path, "r") as f:
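Aside: the gin model configs (RFClassifier.gin, GRU.gin, and the new boosting configs) encode search spaces as numeric tuples, optionally with a "log" flag, and as lists for categorical choices; the objective above translates these into Optuna suggestions. A self-contained sketch of that convention, with invented parameter names and a dummy loss:

```python
import optuna


def suggest_from_bound(trial, name, bound):
    # Numeric tuples become ranges (optionally log-scaled); anything else is categorical.
    if isinstance(bound, tuple) and isinstance(bound[0], (int, float)) and isinstance(bound[1], (int, float)):
        log = len(bound) == 3 and bound[2] == "log"
        if isinstance(bound[0], int) and isinstance(bound[1], int):
            return trial.suggest_int(name, bound[0], bound[1], log=log)
        return trial.suggest_float(name, bound[0], bound[1], log=log)
    return trial.suggest_categorical(name, list(bound))


def objective(trial):
    lr = suggest_from_bound(trial, "lr", (1e-6, 1e-4, "log"))                  # float range, log scale
    depth = suggest_from_bound(trial, "max_depth", (3, 15))                    # int range
    feats = suggest_from_bound(trial, "max_features", ["sqrt", "log2", None])  # categorical
    return lr * depth if feats else lr                                         # dummy loss for the sketch


study = optuna.create_study(sampler=optuna.samplers.TPESampler(seed=0))
study.optimize(objective, n_trials=5)
print(study.best_params)
```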
diff --git a/icu_benchmarks/wandb_utils.py b/icu_benchmarks/wandb_utils.py
index 6148af5d..2ea06b57 100644
--- a/icu_benchmarks/wandb_utils.py
+++ b/icu_benchmarks/wandb_utils.py
@@ -30,7 +30,7 @@ def apply_wandb_sweep(args: Namespace) -> Namespace:
     Returns:
         Namespace: arguments with sweep configuration applied (some are applied via hyperparams)
     """
-    wandb.init(allow_val_change=True)
+    wandb.init(allow_val_change=True, dir=args.log_dir)
     sweep_config = wandb.config
     args.__dict__.update(sweep_config)
     if args.hyperparams is None:
@@ -62,7 +62,8 @@ def set_wandb_experiment_name(args, mode):
         data_dir = Path(args.data_dir)
         args.name = data_dir.name
     run_name = f"{mode}_{args.model}_{args.name}"
-
+    if args.modalities:
+        run_name += f"_mods_{args.modalities}"
     if args.fine_tune:
         run_name += f"_source_{args.source_name}_fine-tune_{args.fine_tune}_samples"
     elif args.eval:
diff --git a/requirements.txt b/requirements.txt
index 96d9ea12..e48f4d06 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,19 +1,22 @@
 --extra-index-url https://download.pytorch.org/whl/cu118
 black==24.3.0
 coverage==7.2.3
-flake8==5.0.4
+flake8>=7.0.0
 matplotlib==3.7.1
 gin-config==0.5.0
 pytorch-ignite==0.5.0.post2
 # Note: versioning of Pytorch might be dependent on compatible CUDA version.
 # Please check yourself if your Pytorch installation supports cuda (for gpu acceleration)
-torch==2.3.1
-lightning==2.3.3
+torch==2.4
+lightning==2.4.0
 torchmetrics==1.0.3
-#pytorch-cuda==11.8
-lightgbm==3.3.5
+lightgbm==4.4.0
+xgboost==2.1.0
+imbalanced-learn==0.12.3
+catboost==1.2.5
 numpy==1.24.3
-pandas==2.0.0
+pandas==2.2.2
+polars==1.9.0
 pyarrow==14.0.1
 pytest==7.3.1
 scikit-learn==1.5.0
@@ -21,11 +24,12 @@
 tensorboard==2.12.2
 tqdm==4.66.3
 einops==0.6.1
 hydra-core==1.3
-optuna==3.6.1
-optuna-integration==3.6.0
-wandb==0.17.2
-recipies==0.1.3
+optuna==4.0.0
+optuna-integration==4.0.0
+wandb==0.17.3
+recipies==1.0
 #Fixed version because of NumPy incompatibility and stale development status.
 scikit-optimize-fix==0.9.1
 hydra-submitit-launcher==1.2.0
-pytest-runner==6.0.1
\ No newline at end of file
+pytest-runner==6.0.1
+
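Aside: with the modality flag wired through, the W&B run name now carries the selected modalities. A rough sketch of the resulting name, with invented argument values (note that the list is interpolated via its repr):

```python
from argparse import Namespace

args = Namespace(model="RFClassifier", name="demo_dataset", modalities=["vitals", "labs"])
run_name = f"train_{args.model}_{args.name}"  # mode assumed to be "train" here
if args.modalities:
    run_name += f"_mods_{args.modalities}"
print(run_name)  # train_RFClassifier_demo_dataset_mods_['vitals', 'labs']
```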