Version release based on practice Cassandra use #155

Merged Oct 17, 2024 · 94 commits (changes shown from 81 commits)

Commits
8a9efb8
Cass changes
rvandewater Jun 18, 2024
e3955a1
temporary logging
rvandewater Jun 19, 2024
96baca6
testing timesnet
rvandewater Jun 24, 2024
0044a6f
Merge branch 'refs/heads/development' into cass
rvandewater Jun 25, 2024
4e2223a
Check for task/model
rvandewater Jun 26, 2024
e1e1549
refactor models: split into file per architecture for DL and per libr…
rvandewater Jun 26, 2024
29f1abe
Improve hyperparameter tuning parsing
rvandewater Jun 26, 2024
f1841d8
deleted dl_models.py
rvandewater Jun 26, 2024
51b9ba8
gpu support for ml models
rvandewater Jun 26, 2024
1722d95
new model
rvandewater Jun 26, 2024
c753039
new model
rvandewater Jun 26, 2024
8a7fbdf
hyperparameter tuning fixes: loading seems to work now
rvandewater Jun 27, 2024
7afe2bd
revert to normal dataloader cores
rvandewater Jun 27, 2024
bce11c5
fixes and extensions in hyperparameter tuning configs
rvandewater Jun 27, 2024
7b1ad9d
hyperparameter bound changes to enable a better search space
rvandewater Jul 4, 2024
ef7374d
catboost gpu leads to errors
rvandewater Jul 4, 2024
17d2268
logging
rvandewater Jul 4, 2024
01e13f8
logging
rvandewater Jul 4, 2024
b7ff826
xgboost model configuration
rvandewater Jul 4, 2024
42ed128
cleaned up hp tuning
rvandewater Jul 4, 2024
c4dffb0
cpu submission
rvandewater Jul 4, 2024
105d5ae
increased mem, time for slurm job
rvandewater Jul 4, 2024
d7252c1
Polars version of the dataloader
rvandewater Jul 4, 2024
a64e79a
polars version of split_process_data.py (backwards compatible with pa…
rvandewater Jul 5, 2024
3b21838
requirements update
rvandewater Jul 5, 2024
b7c80c4
preprocessing using polars partially working
rvandewater Jul 5, 2024
b93a923
preprocessing working experimentally
rvandewater Jul 5, 2024
f074b8c
ordering missing indicator back because of nan/none fix
rvandewater Jul 5, 2024
3aac4f1
preproc timing
rvandewater Jul 8, 2024
5b37115
catching errors
rvandewater Jul 9, 2024
af285e0
Merge branch 'refs/heads/cass' into polars
rvandewater Jul 9, 2024
06cf6c0
Cast to get consistent join key
rvandewater Jul 9, 2024
957c499
refactor and cleanup. Added support for regression
rvandewater Jul 9, 2024
6d15904
binary classification
rvandewater Jul 9, 2024
c226a12
cleanup
rvandewater Jul 9, 2024
bdbbff5
cass related changes
rvandewater Jul 9, 2024
b1ffc31
cleaning up requirements with release candidate recipys
rvandewater Jul 9, 2024
7eab87d
complete train for polars
rvandewater Jul 9, 2024
d966fc9
Complete train extended
rvandewater Jul 9, 2024
08ea3b4
Merge pull request #152 from rvandewater/polars
rvandewater Jul 9, 2024
0f7b366
debugging looped dl training
rvandewater Jul 10, 2024
2e921ea
classification task
rvandewater Jul 10, 2024
2ad630f
fix dataloader hanging due to not caching
rvandewater Jul 11, 2024
af74e29
Modality selection
rvandewater Jul 12, 2024
148843d
flake for ci
rvandewater Jul 12, 2024
2d14647
modality mapping change to include all by default
rvandewater Jul 24, 2024
b81346d
modality mapping bug fixes
rvandewater Jul 24, 2024
a15db4f
modality mapping enhancements and naming clash
rvandewater Jul 29, 2024
d00f26f
Reduce logging "spam"
rvandewater Jul 30, 2024
db07daf
Multiple label support
rvandewater Aug 14, 2024
56eca4f
Multiple label support
rvandewater Aug 14, 2024
d0ea1a2
Fix for endpoint checking
rvandewater Aug 16, 2024
358654c
Introduced shap value logging
rvandewater Aug 16, 2024
a8a06ca
xgboost adjustments
rvandewater Aug 19, 2024
d7ed831
added option to use more sklearn metrics
rvandewater Aug 19, 2024
61f5f28
failsafe if not using shap
rvandewater Aug 19, 2024
8243dca
concatenating shap
rvandewater Aug 19, 2024
93f680b
add support for excluding variables from feature generation
rvandewater Aug 21, 2024
90b0b05
Added auprc loss and prediction indicators to inspect model predictions
rvandewater Sep 17, 2024
b9aad2b
Optuna update
rvandewater Sep 17, 2024
32d1707
Confusion matrix
rvandewater Sep 17, 2024
babaaed
Added to constants
rvandewater Sep 17, 2024
561f3e7
prediction indicators in dataloader
rvandewater Sep 17, 2024
487ed81
todo
rvandewater Sep 17, 2024
0592cdb
hyperparameter ranges with scale_pos_weight
rvandewater Sep 17, 2024
6365a09
Added smoothed labels code (not implemented yet)
rvandewater Sep 17, 2024
b9456df
Added smoothed labels code (not implemented yet)
rvandewater Sep 17, 2024
640c8da
experiments
rvandewater Oct 15, 2024
f93d9fb
More changes in experiments
rvandewater Oct 15, 2024
786e256
debugging
rvandewater Oct 15, 2024
125e4f0
Compatibility with static (non temporal) datasets
rvandewater Oct 15, 2024
b307474
Experiment related tuning
rvandewater Oct 15, 2024
9777418
revert to log loss
rvandewater Oct 15, 2024
808eb9f
Saving pred indicators to evaluate predictions
rvandewater Oct 15, 2024
34e0573
Data sanitization
rvandewater Oct 15, 2024
ab3ab4b
Preprocessing improvements and failsafe when using just static features
rvandewater Oct 15, 2024
9905abc
Remove cass specific files
rvandewater Oct 17, 2024
860bf3a
update recipys to 1.0
rvandewater Oct 17, 2024
4a4ab0c
Major refactor
rvandewater Oct 17, 2024
9547bfa
Major refactoring round 2
rvandewater Oct 17, 2024
4ed54cf
Merge branch 'development' into version_release
rvandewater Oct 17, 2024
5d3ee69
remove cass files
rvandewater Oct 17, 2024
99ce855
remove cass files
rvandewater Oct 17, 2024
151d68f
Added polars
rvandewater Oct 17, 2024
47c594a
More major refactoring
rvandewater Oct 17, 2024
d4e8845
Linting
rvandewater Oct 17, 2024
9a4ab8e
added checks
rvandewater Oct 17, 2024
5004ae4
Increase allowed complexity
rvandewater Oct 17, 2024
619937c
docs
rvandewater Oct 17, 2024
0133dbd
Transformer refactor to reduce code duplication
rvandewater Oct 17, 2024
b3d7ad0
transformer fix
rvandewater Oct 17, 2024
23d2c68
optuna integration
rvandewater Oct 17, 2024
18175d9
removed superfluous test step (integrated in main wrapper now)
rvandewater Oct 17, 2024
5372971
xgboost linting
rvandewater Oct 17, 2024
18 changes: 18 additions & 0 deletions configs/prediction_models/BRFClassifier.gin
@@ -0,0 +1,18 @@
# Settings for ImbLearn Balanced Random Forest Classifier.

# Common settings for ML models
include "configs/prediction_models/common/MLCommon.gin"

# Train params
train_common.model = @BRFClassifier

model/hyperparameter.class_to_tune = @BRFClassifier
model/hyperparameter.n_estimators = [50, 100, 250, 500, 750, 1000, 1500]
model/hyperparameter.max_depth = [3, 5, 10, 15]
model/hyperparameter.min_samples_split = (2, 5, 10)
model/hyperparameter.min_samples_leaf = (1, 2, 4)
model/hyperparameter.max_features = ['sqrt', 'log2', 1.0]
model/hyperparameter.bootstrap = [True, False]
model/hyperparameter.class_weight = [None, 'balanced']


15 changes: 15 additions & 0 deletions configs/prediction_models/CBClassifier.gin
@@ -0,0 +1,15 @@
# Settings for Catboost classifier.

# Common settings for ML models
include "configs/prediction_models/common/MLCommon.gin"

# Train params
train_common.model = @CBClassifier

model/hyperparameter.class_to_tune = @CBClassifier
model/hyperparameter.learning_rate = (1e-4, 0.5, "log")
model/hyperparameter.num_trees = [50, 100, 250, 500, 750, 1000, 1500]
model/hyperparameter.depth = [3, 5, 10, 15]
model/hyperparameter.scale_pos_weight = [1, 5, 10, 25, 50, 75, 99, 100, 1000]
model/hyperparameter.border_count = [5, 10, 20, 50, 100, 200]
model/hyperparameter.l2_leaf_reg = [1, 3, 5, 7, 9]
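
The search-space notation in these configs is compact: bracketed lists enumerate discrete choices, while tuples give a (low, high) range, optionally tagged "log" for log-scale sampling (replacing the older skopt-style "log-uniform" syntax). Below is a minimal sketch of how such specs could map onto Optuna, which this PR integrates; sample_param is a hypothetical name for illustration, not YAIB's actual API.

    import optuna

    def sample_param(trial: optuna.Trial, name: str, spec):
        if isinstance(spec, list):                       # discrete choices
            return trial.suggest_categorical(name, spec)
        if isinstance(spec, tuple):                      # (low, high[, "log"]) range
            low, high = spec[0], spec[1]
            log = len(spec) > 2 and spec[2] == "log"
            if isinstance(low, int) and isinstance(high, int):
                return trial.suggest_int(name, low, high, log=log)
            return trial.suggest_float(name, low, high, log=log)
        return spec                                      # fixed value, nothing to tune

Under this reading, learning_rate = (1e-4, 0.5, "log") becomes a log-scale float suggestion and depth = [3, 5, 10, 15] a categorical choice.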
6 changes: 3 additions & 3 deletions configs/prediction_models/GRU.gin
@@ -9,11 +9,11 @@ train_common.model = @GRUNet
 # Optimizer params
 optimizer/hyperparameter.class_to_tune = @Adam
 optimizer/hyperparameter.weight_decay = 1e-6
-optimizer/hyperparameter.lr = (1e-5, 3e-4)
+optimizer/hyperparameter.lr = (1e-6, 1e-4, "log")

 # Encoder params
 model/hyperparameter.class_to_tune = @GRUNet
 model/hyperparameter.num_classes = %NUM_CLASSES
-model/hyperparameter.hidden_dim = (32, 256, "log-uniform", 2)
-model/hyperparameter.layer_dim = (1, 3)
+model/hyperparameter.hidden_dim = (32, 512, "log")
+model/hyperparameter.layer_dim = (1, 10)

8 changes: 4 additions & 4 deletions configs/prediction_models/RFClassifier.gin
@@ -8,11 +8,11 @@ train_common.model = @RFClassifier

 model/hyperparameter.class_to_tune = @RFClassifier
 model/hyperparameter.n_estimators = (10, 50, 100, 200, 500)
-model/hyperparameter.max_depth = (None, 5, 10, 20)
+model/hyperparameter.max_depth = (5, 10, 20)
 model/hyperparameter.min_samples_split = (2, 5, 10)
 model/hyperparameter.min_samples_leaf = (1, 2, 4)
-model/hyperparameter.max_features = ('sqrt', 'log2', None)
-model/hyperparameter.bootstrap = (True, False)
-model/hyperparameter.class_weight = (None, 'balanced')
+model/hyperparameter.max_features = ['sqrt', 'log2', None]
+model/hyperparameter.bootstrap = [True, False]
+model/hyperparameter.class_weight = [None, 'balanced']

14 changes: 14 additions & 0 deletions configs/prediction_models/RUSBClassifier.gin
@@ -0,0 +1,14 @@
# Settings for ImbLearn RUSBoost classifier.

# Common settings for ML models
include "configs/prediction_models/common/MLCommon.gin"

# Train params
train_common.model = @RUSBClassifier

model/hyperparameter.class_to_tune = @RUSBClassifier
model/hyperparameter.n_estimators = (10, 50, 100, 200, 500)
model/hyperparameter.learning_rate = (0.005, 1, "log")
model/hyperparameter.sampling_strategy = "auto"


6 changes: 3 additions & 3 deletions configs/prediction_models/TCN.gin
@@ -9,12 +9,12 @@ train_common.model = @TemporalConvNet
 # Optimizer params
 optimizer/hyperparameter.class_to_tune = @Adam
 optimizer/hyperparameter.weight_decay = 1e-6
-optimizer/hyperparameter.lr = (1e-5, 3e-4)
+optimizer/hyperparameter.lr = (1e-6, 3e-4)

 # Encoder params
 model/hyperparameter.class_to_tune = @TemporalConvNet
 model/hyperparameter.num_classes = %NUM_CLASSES
 model/hyperparameter.max_seq_length = %HORIZON
-model/hyperparameter.num_channels = (32, 256, "log-uniform", 2)
-model/hyperparameter.kernel_size = (2, 32, "log-uniform", 2)
+model/hyperparameter.num_channels = (32, 256, "log")
+model/hyperparameter.kernel_size = (2, 128, "log")
 model/hyperparameter.dropout = (0.0, 0.4)
14 changes: 7 additions & 7 deletions configs/prediction_models/Transformer.gin
@@ -8,17 +8,17 @@ train_common.model = @Transformer

 optimizer/hyperparameter.class_to_tune = @Adam
 optimizer/hyperparameter.weight_decay = 1e-6
-optimizer/hyperparameter.lr = (1e-5, 3e-4)
+optimizer/hyperparameter.lr = (1e-6, 1e-4)

 # Encoder params
 model/hyperparameter.class_to_tune = @Transformer
-model/hyperparameter.ff_hidden_mult = 2
-model/hyperparameter.l1_reg = 0.0
+model/hyperparameter.ff_hidden_mult = (2, 4, 6, 8)
+model/hyperparameter.l1_reg = (0.0, 1.0)
 model/hyperparameter.num_classes = %NUM_CLASSES
-model/hyperparameter.hidden = (32, 256, "log-uniform", 2)
-model/hyperparameter.heads = (1, 8, "log-uniform", 2)
+model/hyperparameter.hidden = (32, 512, "log")
+model/hyperparameter.heads = (1, 8, "log")
 model/hyperparameter.depth = (1, 3)
-model/hyperparameter.dropout = (0.0, 0.4)
-model/hyperparameter.dropout_att = (0.0, 0.4)
+model/hyperparameter.dropout = 0  # no improvement (0.0, 0.4)
+model/hyperparameter.dropout_att = (0.0, 1.0)

17 changes: 17 additions & 0 deletions configs/prediction_models/XGBClassifier.gin
@@ -0,0 +1,17 @@
# Settings for XGBoost classifier.

# Common settings for ML models
include "configs/prediction_models/common/MLCommon.gin"

# Train params
train_common.model = @XGBClassifier

model/hyperparameter.class_to_tune = @XGBClassifier
model/hyperparameter.learning_rate = (0.01, 0.1, "log")
model/hyperparameter.n_estimators = [50, 100, 250, 500, 750, 1000, 1500, 2000]
model/hyperparameter.max_depth = [3, 5, 10, 15]
model/hyperparameter.scale_pos_weight = [1, 5, 10, 15, 20, 25, 30, 35, 40, 50, 75, 99, 100, 1000]
model/hyperparameter.min_child_weight = [1, 0.5]
model/hyperparameter.max_delta_step = [0, 1, 2, 3, 4, 5, 10]
model/hyperparameter.colsample_bytree = [0.1, 0.25, 0.5, 0.75, 1.0]
model/hyperparameter.eval_metric = "aucpr"
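
The wide scale_pos_weight grid here targets heavily imbalanced clinical outcomes; XGBoost's documentation suggests values near the ratio of negative to positive samples. A toy computation illustrating that rule of thumb (the labels below are made up):

    import numpy as np

    y = np.array([0] * 96 + [1] * 4)          # made-up labels: 4% positives
    ratio = (y == 0).sum() / (y == 1).sum()   # negatives / positives = 24.0
    print(f"rule-of-thumb scale_pos_weight ~ {ratio:.0f}")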
6 changes: 4 additions & 2 deletions configs/prediction_models/common/DLCommon.gin
@@ -3,7 +3,9 @@
 # Imports to register the models
 import gin.torch.external_configurables
 import icu_benchmarks.models.wrappers
-import icu_benchmarks.models.dl_models
+import icu_benchmarks.models.dl_models.rnn
+import icu_benchmarks.models.dl_models.transformer
+import icu_benchmarks.models.dl_models.tcn
 import icu_benchmarks.models.utils

 # Do not generate features from dynamic data
@@ -12,7 +14,7 @@ base_regression_preprocessor.generate_features = False

 # Train params
 train_common.optimizer = @Adam
-train_common.epochs = 1000
+train_common.epochs = 50
 train_common.batch_size = 64
 train_common.patience = 10
 train_common.min_delta = 1e-4
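
The epochs/patience/min_delta trio above is standard early stopping: training ends once the validation loss has failed to improve by at least min_delta for patience consecutive epochs. A generic sketch of that logic, not YAIB's actual training wrapper (train_one_epoch is a stand-in):

    import random

    def train_one_epoch() -> float:
        return random.random()       # stand-in for the real validation loss

    best, wait = float("inf"), 0
    for epoch in range(50):          # train_common.epochs = 50
        val_loss = train_one_epoch()
        if val_loss < best - 1e-4:   # improved by at least min_delta
            best, wait = val_loss, 0
        else:
            wait += 1
            if wait >= 10:           # train_common.patience
                break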
2 changes: 1 addition & 1 deletion configs/prediction_models/common/DLTuning.gin
@@ -2,4 +2,4 @@
 tune_hyperparameters.scopes = ["model", "optimizer"]
 tune_hyperparameters.n_initial_points = 5
 tune_hyperparameters.n_calls = 30
-tune_hyperparameters.folds_to_tune_on = 2
+tune_hyperparameters.folds_to_tune_on = 5
6 changes: 5 additions & 1 deletion configs/prediction_models/common/MLCommon.gin
@@ -3,7 +3,11 @@
 # Imports to register the models
 import gin.torch.external_configurables
 import icu_benchmarks.models.wrappers
-import icu_benchmarks.models.ml_models
+import icu_benchmarks.models.ml_models.sklearn
+import icu_benchmarks.models.ml_models.lgbm
+import icu_benchmarks.models.ml_models.xgboost
+import icu_benchmarks.models.ml_models.imblearn
+import icu_benchmarks.models.ml_models.catboost
 import icu_benchmarks.models.utils

 # Patience for early stopping
6 changes: 3 additions & 3 deletions configs/prediction_models/common/MLTuning.gin
@@ -1,5 +1,5 @@
 # Hyperparameter tuner settings for classical Machine Learning.
 tune_hyperparameters.scopes = ["model"]
-tune_hyperparameters.n_initial_points = 10
-tune_hyperparameters.n_calls = 3
-tune_hyperparameters.folds_to_tune_on = 1
+tune_hyperparameters.n_initial_points = 5
+tune_hyperparameters.n_calls = 30
+tune_hyperparameters.folds_to_tune_on = 5
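
Read together, these settings mean: 5 randomly sampled warm-up configurations, 30 tuning trials in total, each trial scored across 5 cross-validation folds. A self-contained Optuna sketch of that loop, offered only as an illustration of the semantics (train_and_eval is a dummy stand-in, not YAIB code):

    import optuna

    def train_and_eval(lr: float, fold: int) -> float:
        # Dummy stand-in for fitting a model on one CV fold and returning its loss.
        return (lr - 0.01) ** 2 + 0.001 * fold

    def objective(trial: optuna.Trial) -> float:
        lr = trial.suggest_float("learning_rate", 1e-4, 0.5, log=True)
        scores = [train_and_eval(lr, fold) for fold in range(5)]  # folds_to_tune_on = 5
        return sum(scores) / len(scores)

    study = optuna.create_study(
        direction="minimize",
        sampler=optuna.samplers.TPESampler(n_startup_trials=5),   # n_initial_points = 5
    )
    study.optimize(objective, n_trials=30)                        # n_calls = 30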
5 changes: 2 additions & 3 deletions configs/tasks/BinaryClassification.gin
@@ -19,11 +19,10 @@ DLPredictionWrapper.loss = @cross_entropy

 # SELECTING PREPROCESSOR
 preprocess.preprocessor = @base_classification_preprocessor
+preprocess.modality_mapping = %modality_mapping
 preprocess.vars = %vars
 preprocess.use_static = True

 # SELECTING DATASET
-PredictionDataset.vars = %vars
-PredictionDataset.ram_cache = True
-
+include "configs/tasks/common/Dataloader.gin"

43 changes: 43 additions & 0 deletions configs/tasks/CassClassification.gin
@@ -0,0 +1,43 @@
# COMMON IMPORTS
include "configs/tasks/common/Imports.gin"


# CROSS-VALIDATION
include "configs/tasks/common/CrossValidation.gin"

# MODE SETTINGS
Run.mode = "Classification"
NUM_CLASSES = 2 # Binary classification
HORIZON = 12
train_common.weight = "balanced"
train_common.ram_cache = True

# DEEP LEARNING
DLPredictionWrapper.loss = @cross_entropy

# SELECTING PREPROCESSOR
preprocess.preprocessor = @base_classification_preprocessor
preprocess.vars = %vars
preprocess.use_static = True

# SELECTING DATASET
include "configs/tasks/common/Dataloader.gin"
# SELECTING DATASET
# PredictionDataset.vars = %vars
# PredictionDataset.ram_cache = True

# DATASET CONFIGURATION
preprocess.file_names = {
"DYNAMIC": "dyn.parquet",
"OUTCOME": "outc.parquet",
"STATIC": "sta.parquet",
}

#include "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/14-08-2024_normal_ward/vars.gin"
#include "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/19-08-2024_all_wards/vars.gin"
# include "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/dataset_x_segment_duration_x_transfer_2024-08-26 13:43:20/vars.gin"
# include "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/dataset_x_segment_duration_x_transfer_2024-09-12T16:22:01/vars.gin"
# include "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/dataset_baseline_flat_2024-09-17T17:25:20/vars.gin"
# include "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/dataset_segment_1.0_horizon_6:00:00_transfer_full_2024-09-27T16:14:05/vars.gin"
include "/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format/dataset_segment_1.0_horizon_6:00:00_transfer_3_2024-10-07T23:26:22/vars.gin"
# preprocess.modality_mapping = %modality_mapping
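
The file_names mapping above means this task expects dyn.parquet, outc.parquet, and sta.parquet inside the chosen data directory. A quick sanity check with polars (which this PR introduces); the directory path is hypothetical:

    import polars as pl
    from pathlib import Path

    data_dir = Path("path/to/cass_dataset")   # hypothetical location
    for stem in ("dyn", "outc", "sta"):
        df = pl.read_parquet(data_dir / f"{stem}.parquet")
        print(stem, df.shape)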
4 changes: 2 additions & 2 deletions configs/tasks/DatasetImputation.gin
@@ -22,6 +22,6 @@ preprocess.file_names = {
 preprocess.preprocessor = @base_imputation_preprocessor

 preprocess.vars = %vars
-ImputationDataset.vars = %vars
-ImputationDataset.ram_cache = True
+
+include "configs/tasks/common/Dataloader.gin"

3 changes: 1 addition & 2 deletions configs/tasks/Regression.gin
@@ -28,6 +28,5 @@ base_regression_preprocessor.outcome_min = 0
 base_regression_preprocessor.outcome_max = 15

 # SELECTING DATASET
-PredictionDataset.vars = %vars
-PredictionDataset.ram_cache = True
+include "configs/tasks/common/Dataloader.gin"

8 changes: 8 additions & 0 deletions configs/tasks/common/Dataloader.gin
@@ -0,0 +1,8 @@
# Prediction
PredictionPandasDataset.vars = %vars
PredictionPandasDataset.ram_cache = True
PredictionPolarsDataset.vars = %vars
PredictionPolarsDataset.ram_cache = True
# Imputation
ImputationPandasDataset.vars = %vars
ImputationPandasDataset.ram_cache = True
8 changes: 8 additions & 0 deletions configs/tasks/common/PredictionTaskVariables.gin
@@ -15,4 +15,12 @@ vars = {
     "methb", "mg", "na", "neut", "o2sat", "pco2", "ph", "phos", "plt", "po2", "ptt", "resp", "sbp", "temp", "tnt", "urine",
     "wbc"],
     "STATIC": ["age", "sex", "height", "weight"],
 }
+
+modality_mapping = {
+    "DYNAMIC": ["alb", "alp", "alt", "ast", "be", "bicar", "bili", "bili_dir", "bnd", "bun", "ca", "cai", "ck", "ckmb", "cl",
+    "crea", "crp", "dbp", "fgn", "fio2", "glu", "hgb", "hr", "inr_pt", "k", "lact", "lymph", "map", "mch", "mchc", "mcv",
+    "methb", "mg", "na", "neut", "o2sat", "pco2", "ph", "phos", "plt", "po2", "ptt", "resp", "sbp", "temp", "tnt", "urine",
+    "wbc"],
+    "STATIC": ["age", "sex", "height", "weight"],
+}
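
The new modality_mapping mirrors vars so that entire modalities can be included or excluded at preprocessing time. A hypothetical helper illustrating that filtering idea; select_modalities is an illustrative name, not necessarily the function this PR adds:

    def select_modalities(vars_dict: dict, mapping: dict, selected: list) -> dict:
        # Keep only columns whose modality is in `selected`.
        keep = {col for mod in selected for col in mapping.get(mod, [])}
        return {group: [c for c in cols if c in keep] for group, cols in vars_dict.items()}

    vars_dict = {"DYNAMIC": ["hr", "map", "o2sat"], "STATIC": ["age", "sex"]}
    mapping = {"DYNAMIC": ["hr", "map", "o2sat"], "STATIC": ["age", "sex"]}
    print(select_modalities(vars_dict, mapping, ["STATIC"]))
    # {'DYNAMIC': [], 'STATIC': ['age', 'sex']}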
18 changes: 18 additions & 0 deletions docs/adding_model/RNN.gin
@@ -0,0 +1,18 @@
# Settings for Recurrent Neural Network (RNN) models.

# Common settings for DL models
include "configs/prediction_models/common/DLCommon.gin"

# Train params
train_common.model = @RNNet

# Optimizer params
optimizer/hyperparameter.class_to_tune = @Adam
optimizer/hyperparameter.weight_decay = 1e-6
optimizer/hyperparameter.lr = (1e-5, 3e-4)

# Encoder params
model/hyperparameter.class_to_tune = @RNNet
model/hyperparameter.num_classes = %NUM_CLASSES
model/hyperparameter.hidden_dim = (32, 256, "log-uniform", 2)
model/hyperparameter.layer_dim = (1, 3)
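
For context on how this docs config connects to code, here is a hedged sketch of what the matching RNNet module might look like, assuming YAIB's gin-registered model pattern and the DLPredictionWrapper base from icu_benchmarks.models.wrappers; the constructor signature is an assumption, not the repository's exact code:

    import gin
    import torch.nn as nn
    from icu_benchmarks.models.wrappers import DLPredictionWrapper

    @gin.configurable
    class RNNet(DLPredictionWrapper):
        # hidden_dim, layer_dim, and num_classes line up with the
        # model/hyperparameter bindings in RNN.gin above.
        def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.rnn = nn.RNN(input_size, hidden_dim, layer_dim, batch_first=True)
            self.logit = nn.Linear(hidden_dim, num_classes)

        def forward(self, x):
            out, _ = self.rnn(x)    # (batch, time, hidden_dim)
            return self.logit(out)  # per-timestep class logits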