Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Version release based on practice Cassandra use #155

Merged
merged 94 commits into from
Oct 17, 2024

Conversation

rvandewater
Copy link
Owner

@rvandewater rvandewater commented Oct 15, 2024

Updates from the cassandra project currently include:

  • Optuna hyperparam optimization
    
  • Polars backend to improve speed and ram usage
    
  • New models
    
  • Probably more :)
    

Summary by CodeRabbit

  • New Features

    • Introduced configuration settings for multiple classifiers including Balanced Random Forest, CatBoost, XGBoost, and RUSBClassifier.
    • Added new deep learning model configurations for RNN, GRU, TCN, and Transformer.
    • Implemented new dataset handling capabilities with Polars integration.
    • Enhanced capabilities for binary classification, regression, and dataset imputation tasks.
    • Added new configuration for benchmarking experiments.
    • New files for executing Weights & Biases sweeps and job submissions on HPC clusters.
  • Bug Fixes

    • Enhanced error handling in cross-validation processes.
  • Documentation

    • Updated instructions for adding new models to the framework.
  • Chores

    • Updated dependencies and package versions for better compatibility.

Copy link

coderabbitai bot commented Oct 15, 2024

Caution

Review failed

The pull request is closed.

Walkthrough

The changes in this pull request introduce new configuration files for various classifiers, including Balanced Random Forest, CatBoost, and XGBoost, along with modifications to existing configurations for models like GRU, Random Forest, and Transformer. Additionally, significant updates were made to support data handling using the Polars library, enhancing data loading and preprocessing functionalities. New scripts for benchmarking and job scheduling on HPC clusters were added, while several files were created for new deep learning architectures, including RNNs and Transformers. Various updates were also made to enhance hyperparameter tuning and model training processes.

Changes

File Path Change Summary
configs/prediction_models/BRFClassifier.gin New file for Balanced Random Forest Classifier configuration with hyperparameters.
configs/prediction_models/CBClassifier.gin New file for CatBoost classifier configuration with hyperparameters.
configs/prediction_models/GRU.gin Updates to GRU model configuration, including learning rate and hidden dimensions.
configs/prediction_models/RFClassifier.gin Modifications to Random Forest Classifier configuration, adjusting hyperparameters.
configs/prediction_models/RUSBClassifier.gin New file for RUSBClassifier configuration with hyperparameters.
configs/prediction_models/TCN.gin Updates to TCN model configuration, including learning rate and kernel size.
configs/prediction_models/Transformer.gin Modifications to Transformer model configuration, updating learning rate and hyperparameters.
configs/prediction_models/XGBClassifier.gin New file for XGBoost classifier configuration with hyperparameters.
configs/prediction_models/common/DLCommon.gin Updates to deep learning common configurations, including import statements and epochs.
configs/prediction_models/common/MLCommon.gin Modifications to machine learning common configurations, updating import statements.
configs/tasks/BinaryClassification.gin Modifications to binary classification task configuration, enhancing preprocessing.
configs/tasks/CassClassification.gin New file for a classification task in deep learning with specific parameters.
configs/tasks/DatasetImputation.gin Modifications to data imputation task configuration, including data loading settings.
configs/tasks/Regression.gin Modifications to regression task configuration, enhancing data loading.
configs/tasks/common/Dataloader.gin New configuration settings for dataset classes for improved data handling.
configs/tasks/common/PredictionTaskVariables.gin New configuration section for modality mapping in prediction tasks.
docs/adding_model/RNN.gin New file for RNN model configuration settings.
docs/adding_model/instructions.md Updated guidelines for adding new models to the YAIB framework.
docs/adding_model/rnn.py New file defining RNN model using PyTorch.
environment.yml Updated pip dependency specification to a version range.
experiments/benchmark_cass.yml New configuration for a benchmarking experiment with execution parameters.
experiments/charhpc_wandb_sweep.sh New Bash script for executing a WandB sweep on HPC cluster.
experiments/charhpc_wandb_sweep_cpu.sh New shell script for executing a job on HPC cluster.
experiments/slurm_base_char_sc.sh New Bash script for submitting binary classification jobs.
icu_benchmarks/cross_validation.py Enhancements to error handling in the execute_repeated_cv function.
icu_benchmarks/data/loader.py Significant modifications for Polars data manipulation, including new classes and renaming.
icu_benchmarks/data/preprocessor.py Major updates to implement new preprocessing classes and methods for Polars.
icu_benchmarks/data/split_process_data.py Enhancements for Polars data manipulation with new methods added.
icu_benchmarks/models/constants.py New import statement for confusion matrix and updates to metrics dictionaries.
icu_benchmarks/models/custom_metrics.py Added function for computing confusion matrices.
icu_benchmarks/models/dl_models.py Deleted file with various neural network architectures.
icu_benchmarks/models/dl_models/rnn.py New file for RNN architectures with multiple classes.
icu_benchmarks/models/dl_models/tcn.py New class for Temporal Convolutional Network.
icu_benchmarks/models/dl_models/transformer.py New file defining Transformer and LocalTransformer classes.
icu_benchmarks/models/ml_models/catboost.py New file for CatBoost classification tasks.
icu_benchmarks/models/ml_models/imblearn.py New classes for imbalanced learning tasks.
icu_benchmarks/models/ml_models/lgbm.py New file for LightGBM model wrappers.
icu_benchmarks/models/ml_models/sklearn.py Significant modifications, including removal of LightGBM wrapper classes.
icu_benchmarks/models/ml_models/xgboost.py New file for XGBoost classification tasks.
icu_benchmarks/models/train.py Updates to replace pandas with Polars for data handling.
icu_benchmarks/models/utils.py Added functionality related to scoring and label smoothing.
icu_benchmarks/models/wrappers.py Enhanced metrics integration and data handling.
icu_benchmarks/run.py Enhancements to configurability and robustness.
icu_benchmarks/run_utils.py Updates to configuration and logging functionality.
icu_benchmarks/tuning/hyperparameters.py Enhanced hyperparameter tuning functionality.
requirements.txt Updates to package versions and dependencies.

Poem

In the garden of code, new blooms arise,
Classifiers sprout, reaching for the skies.
With tweaks and turns, the models grow,
RNNs and TCNs, putting on a show.
Data flows freely, like rivers run,
In the world of AI, we’ve just begun! 🐇✨


Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media?

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Generate unit testing code for this file.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai generate unit testing code for this file.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and generate unit testing code.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

CodeRabbit Configuration File (.coderabbit.yaml)

  • You can programmatically configure CodeRabbit by adding a .coderabbit.yaml file to the root of your repository.
  • Please see the configuration documentation for more information.
  • If your editor has YAML language server enabled, you can add the path at the top of this file to enable auto-completion and validation: # yaml-language-server: $schema=https://coderabbit.ai/integrations/schema.v2.json

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Inline review comments failed to post

🛑 Comments failed to post (80)
environment.yml (1)

5-5: 🛠️ Refactor suggestion

Consider upgrading to a newer Python version

The environment is currently using Python 3.10. While this is still a supported version, newer versions (like 3.11 or 3.12) offer performance improvements and new features.

Consider upgrading to a newer Python version if your project doesn't have specific dependencies requiring Python 3.10. This could potentially improve your application's performance and allow you to use newer language features.

experiments/charhpc_wandb_sweep.sh (2)

10-13: 🛠️ Refactor suggestion

Simplify Conda initialization.

The current Conda initialization uses both source and eval methods, which might be redundant. You can simplify this by using only one method.

Consider updating the Conda initialization as follows:

-source /etc/profile.d/conda.sh
-
-eval "$(conda shell.bash hook)"
+. /etc/profile.d/conda.sh

This change uses the . command (equivalent to source) to load the Conda profile, which should be sufficient for initializing Conda in most environments.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

. /etc/profile.d/conda.sh

conda activate yaib_req_pl

14-14: 🛠️ Refactor suggestion

Add error handling for the sweep ID argument.

The WandB agent command looks good, but it assumes that the sweep ID will always be provided as an argument. It's a good practice to add error handling to ensure the script fails gracefully if the argument is missing.

Consider adding a check for the command-line argument before running the WandB agent:

+if [ $# -eq 0 ]; then
+    echo "Error: Sweep ID not provided"
+    echo "Usage: $0 <sweep_id>"
+    exit 1
+fi
+
 wandb agent --count 1 cassandra_hpi/cassandra/"$1"

This change will print an error message and exit if no argument is provided, making the script more robust.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

if [ $# -eq 0 ]; then
    echo "Error: Sweep ID not provided"
    echo "Usage: $0 <sweep_id>"
    exit 1
fi

wandb agent --count 1 cassandra_hpi/cassandra/"$1"
configs/tasks/DatasetImputation.gin (1)

26-26: 💡 Codebase verification

Potential Conflicts Detected in Dataloader.gin

  • ImputationPandasDataset.vars and ImputationPandasDataset.ram_cache are still defined in configs/tasks/common/Dataloader.gin.

    These settings may conflict with previous configurations where ImputationDataset.vars and ImputationDataset.ram_cache were removed or altered.

Recommended Actions:

  • Review the Dataloader.gin file to determine if the ImputationPandasDataset configurations should be updated or removed to align with the intended changes.
  • Ensure that all related configurations across the codebase are consistent with the updated data loading approach.
🔗 Analysis chain

Confirm the contents and impact of the new Dataloader configuration

The addition of include "configs/tasks/common/Dataloader.gin" suggests a move towards a more standardized approach for data loading. This change could improve consistency across different tasks or datasets.

To ensure this change is implemented correctly and doesn't introduce any issues, please run the following script:

Please review the output to ensure that:

  1. The Dataloader.gin file exists and contains the expected configurations.
  2. Other relevant files are also updated to use this new Dataloader configuration.
  3. There are no conflicts with existing configurations, especially regarding the previously removed ImputationDataset.vars and ImputationDataset.ram_cache.
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the contents and usage of the new Dataloader configuration

echo "Contents of configs/tasks/common/Dataloader.gin:"
cat configs/tasks/common/Dataloader.gin

echo "Checking for other files that include Dataloader.gin:"
rg 'include.*"configs/tasks/common/Dataloader.gin"' --type gin

echo "Checking for potential conflicts or overrides with existing configurations:"
rg "ImputationDataset|vars|ram_cache" configs/tasks/common/Dataloader.gin

Length of output: 1183


Script:

#!/bin/bash
# Description: Verify the contents and usage of the new Dataloader configuration

echo "Contents of configs/tasks/common/Dataloader.gin:"
cat configs/tasks/common/Dataloader.gin

echo "Checking for other files that include Dataloader.gin:"
rg 'include\s+"configs/tasks/common/Dataloader\.gin"' --ignore-case

echo "Checking for potential conflicts or overrides with existing configurations:"
rg "ImputationDataset|vars|ram_cache" configs/tasks/common/Dataloader.gin

Length of output: 1390

docs/adding_model/RNN.gin (1)

6-7: 💡 Codebase verification

Issue Found: RNNet Class Definition Missing

The @RNNet model is specified in the train params, but the RNNet class definition was not found in the codebase.

🔗 Analysis chain

LGTM: Clear train params configuration. Verify @rnnet definition.

The train params configuration is concise and clear. However, it's important to ensure that @RNNet is properly defined elsewhere in the codebase.

Run the following script to verify the RNNet class definition:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify the existence and definition of the RNNet class.

# Test: Search for the RNNet class definition
rg --type python -A 5 'class RNNet'

Length of output: 1678

configs/prediction_models/RFClassifier.gin (1)

16-16: 💡 Codebase verification

Inconsistent tuple syntax found in configuration files.
The following files still use tuple syntax for class_weight and should be updated to use list syntax for consistency:

  • configs/prediction_models/RFClassifier.gin
  • configs/prediction_models/BRFClassifier.gin
🔗 Analysis chain

LGTM! Verify consistency across all configuration files.

The change from tuple to list syntax for class_weight is appropriate and consistent with the changes made to other hyperparameters.

To ensure consistency across the entire project, please run the following script to check if all similar configuration files have been updated with the new list syntax:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for inconsistent syntax in gin configuration files

# Test: Search for tuple syntax in hyperparameter definitions
rg --type gin -g 'configs/prediction_models/*.gin' '(True, False)|(None,)|(None, .+)' -l

# If the above command returns any results, it indicates files that still use tuple syntax and may need updating.

Length of output: 118


Script:

#!/bin/bash
# Description: Check for inconsistent tuple syntax in .gin configuration files

# Test: Search for tuple syntax in class_weight hyperparameter definitions within .gin files
rg '(True, False)|(None,)|(None, .+)' -g '*.gin' -l

Length of output: 140

configs/prediction_models/Transformer.gin (1)

22-22: 🛠️ Refactor suggestion

⚠️ Potential issue

Consider adjusting the upper bound for attention dropout

The attention dropout range has been expanded from (0.0, 0.4) to (0.0, 1.0). An upper bound of 1.0 for dropout is extremely high and could potentially degrade model performance significantly.

Consider adjusting the upper bound to a more reasonable value, such as 0.5:

-model/hyperparameter.dropout_att = (0.0, 1.0)
+model/hyperparameter.dropout_att = (0.0, 0.5)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

model/hyperparameter.dropout_att = (0.0, 0.5)
configs/prediction_models/XGBClassifier.gin (1)

9-17: 🛠️ Refactor suggestion

Consider narrowing the range for scale_pos_weight

The hyperparameter settings look good overall and cover important parameters for XGBoost tuning. However, the range for scale_pos_weight is very wide (1 to 1000). This might lead to suboptimal results or unnecessarily long tuning times.

Consider narrowing down the range based on your dataset's class imbalance ratio. A typical approach is to set it close to the ratio of negative to positive instances. For example, if you have a 1:10 imbalance, you might use:

-model/hyperparameter.scale_pos_weight = [1, 5, 10, 15, 20, 25, 30, 35, 40, 50, 75, 99, 100, 1000]
+model/hyperparameter.scale_pos_weight = [1, 5, 10, 15, 20]

This change would focus the search on more relevant values and potentially speed up the tuning process.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

model/hyperparameter.class_to_tune = @XGBClassifier
model/hyperparameter.learning_rate = (0.01, 0.1, "log")
model/hyperparameter.n_estimators = [50, 100, 250, 500, 750, 1000,1500,2000]
model/hyperparameter.max_depth = [3, 5, 10, 15]
model/hyperparameter.scale_pos_weight = [1, 5, 10, 15, 20]
model/hyperparameter.min_child_weight = [1, 0.5]
model/hyperparameter.max_delta_step = [0, 1, 2, 3, 4, 5, 10]
model/hyperparameter.colsample_bytree = [0.1, 0.25, 0.5, 0.75, 1.0]
model/hyperparameter.eval_metric = "aucpr"
icu_benchmarks/models/ml_models/catboost.py (1)

14-17: 🛠️ Refactor suggestion

⚠️ Potential issue

Consider clarifying model initialization and addressing potential issues.

  1. The commented-out line suggests GPU support was considered. If GPU support is planned for the future, consider adding a TODO comment explaining this.

  2. The use of set_model_args is good for flexibility, but it's unclear what arguments are being passed. Consider adding a comment explaining the expected arguments or their purpose.

  3. The static analysis tool flagged a potential issue with star-arg unpacking after a keyword argument on line 16. This could lead to unexpected behavior. Consider refactoring to avoid this issue:

def __init__(self, *args, **kwargs):
    model_args = self.set_model_args(*args, **kwargs)
    self.model = cb.CatBoostClassifier(task_type="CPU", **model_args)
    super().__init__(*args, **kwargs)

This refactoring separates the argument preparation from the model initialization, which may be clearer and avoid the star-arg unpacking issue.

🧰 Tools
🪛 Ruff

16-16: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)

configs/tasks/common/PredictionTaskVariables.gin (1)

20-25: 🛠️ Refactor suggestion

Consider the necessity and potential redundancy of the modality_mapping section.

The newly added modality_mapping section appears to be an exact duplicate of the "DYNAMIC" and "STATIC" lists in the vars section. While this addition might serve a specific purpose, it's worth considering the following points:

  1. Redundancy: Having identical information in two places can lead to maintenance issues if one section is updated and the other is forgotten.
  2. Purpose: The purpose of modality_mapping is not immediately clear from the context. It would be helpful to add a comment explaining its specific use case and how it differs from vars.
  3. DRY principle: To adhere to the Don't Repeat Yourself (DRY) principle, consider refactoring this configuration to avoid duplication.

Consider one of the following approaches:

  1. If modality_mapping serves a distinct purpose, add a comment explaining its role and how it differs from vars.
  2. If modality_mapping is intended to replace vars, update all references to use the new structure and remove the old one.
  3. If the two structures are meant to be identical, consider using a single source of truth:
dynamic_vars = ["alb", "alp", "alt", "ast", "be", "bicar", "bili", "bili_dir", "bnd", "bun", "ca", "cai", "ck", "ckmb", "cl",
    "crea", "crp", "dbp", "fgn", "fio2", "glu", "hgb", "hr", "inr_pt", "k", "lact", "lymph", "map", "mch", "mchc", "mcv",
    "methb", "mg", "na", "neut", "o2sat", "pco2", "ph", "phos", "plt", "po2", "ptt", "resp", "sbp", "temp", "tnt", "urine",
    "wbc"]
static_vars = ["age", "sex", "height", "weight"]

vars = {
    "GROUP": "stay_id",
    "SEQUENCE": "time",
    "LABEL": "label",
    "DYNAMIC": dynamic_vars,
    "STATIC": static_vars,
}

modality_mapping = {
    "DYNAMIC": dynamic_vars,
    "STATIC": static_vars,
}

This approach ensures that any updates to the variable lists are automatically reflected in both structures.

experiments/slurm_base_char_sc.sh (4)

22-22: ⚠️ Potential issue

Fix typo and address undefined DATASETS array.

  1. There's a typo in the "echo" command:
-echi "Task type:" ${TASK}
+echo "Task type:" ${TASK}
  1. The DATASETS array is used but not defined. Either define it before use or replace it with the appropriate variable:
-echo "Dataset: "${DATASETS[$SLURM_ARRAY_TASK_ID]}
+echo "Dataset: ${DATASET_ROOT_PATH}"

If you intended to use multiple datasets, make sure to define the DATASETS array before this line.

Also applies to: 25-25


18-18: ⚠️ Potential issue

Remove space after '=' in DATASET_ROOT_PATH assignment.

The space after the '=' in the DATASET_ROOT_PATH assignment can cause issues in Bash. Remove it to ensure proper variable assignment:

-DATASET_ROOT_PATH= /sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format
+DATASET_ROOT_PATH=/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

DATASET_ROOT_PATH=/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format #data/YAIB_Datasets/data
🧰 Tools
🪛 Shellcheck

[warning] 18-18: Remove space after = if trying to assign a value (for empty string, use var='' ... ).

(SC1007)


28-28: 🛠️ Refactor suggestion

Improve error handling for directory change.

Add error handling to the cd command to exit the script if the directory change fails:

-cd ${YAIB_PATH}
+cd ${YAIB_PATH} || { echo "Failed to change directory to ${YAIB_PATH}"; exit 1; }

This ensures that the script doesn't continue executing in the wrong directory if the cd command fails.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

cd ${YAIB_PATH} || { echo "Failed to change directory to ${YAIB_PATH}"; exit 1; }
🧰 Tools
🪛 Shellcheck

[warning] 28-28: Use 'cd ... || exit' or 'cd ... || return' in case cd fails.

(SC2164)


35-44: ⚠️ Potential issue

Address undefined variables and document --tune flag.

  1. The ${DATASETS} variable is used but not defined earlier in the script. Either define it or replace it with the appropriate variable:
-  -n ${DATASETS} \
+  -n ${DATASET_ROOT_PATH} \
  1. Consider adding a comment to explain the purpose of the --tune flag, e.g.:
+  # Enable hyperparameter tuning
   --tune

Ensure that all variables used in this command are properly defined earlier in the script.

Committable suggestion was skipped due to low confidence.

icu_benchmarks/models/constants.py (1)

48-50: 💡 Codebase verification

**Concerns About Removing `average_precision_score`**

1. The `average_precision_score` is still imported and used in multiple files:
   - `icu_benchmarks/models/wrappers.py`
   - `icu_benchmarks/models/utils.py`
   - `icu_benchmarks/models/constants.py`

2. Commenting out or removing it from `constants.py` may adversely affect these modules that rely on this metric.

Could you please confirm whether the removal of `average_precision_score` is intentional and ensure that it does not impact existing functionalities?

🔗 Analysis chain

Clarify the removal of average_precision_score and approve addition of confusion_matrix

  1. The addition of confusion_matrix to the MULTICLASS_CLASSIFICATION dictionary is appropriate and valuable for evaluating multiclass classification models.

  2. However, the average_precision_score has been commented out. This might impact the calculation of precision-recall metrics for multiclass scenarios. Could you please provide the rationale for removing this metric? If it's no longer needed, consider removing the commented line entirely to maintain code cleanliness.

To ensure that the removal of average_precision_score doesn't negatively impact any existing code, let's verify its usage:



🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for any remaining usage of average_precision_score in multiclass scenarios

# Test: Search for usage of average_precision_score in Python files
rg --type python 'average_precision_score.*multiclass'

Length of output: 87


Script:

#!/bin/bash
# Description: Check for any remaining usage of average_precision_score in multiclass scenarios

# Search for usage of average_precision_score in Python files related to multiclass classification
rg --glob '*.py' 'average_precision_score.*multiclass'

Length of output: 56


Script:

#!/bin/bash
# Description: Check for any remaining usage of average_precision_score in the codebase

# Search for usage of average_precision_score in all Python files
rg --glob '*.py' 'average_precision_score'

Length of output: 885

experiments/benchmark_cass.yml (1)

29-75: 🛠️ Refactor suggestion

Review and optimize parameter configurations

The parameters section is comprehensive, but there are several points to consider:

  1. Data Directory:

    • The current data directory path (line 32) is very specific. Consider using environment variables or a configuration file to make it more portable across different environments.
    • There are many commented-out data directory options. Document the reasons for these alternatives or remove them if they're no longer relevant.
  2. Model Selection:

    • The current selection includes both traditional (XGBClassifier) and deep learning (GRU, Transformer) models. This is good for comparison, but ensure that the computational resources are sufficient for all these models.
  3. Modalities:

    • Currently set to "all" (line 59). Consider experimenting with specific combinations of modalities to understand their individual impacts on the model's performance.
  4. Seed:

    • The fixed seed (1111) ensures reproducibility, which is good. However, for a more robust evaluation, consider running multiple experiments with different seeds.
  5. Pretrained Imputation:

    • Set to None (line 75). Depending on your data quality, using pretrained imputation might improve model performance. Consider experimenting with this option.

To make the configuration more flexible and easier to maintain, consider the following changes:

  1. Use environment variables for paths:
data_dir:
  values:
    - ${YAIB_DATA_DIR}/dataset_segment_1.0_horizon_6:00:00_transfer_3_2024-10-07T23:26:22
  1. Add a comment explaining the modalities choice:
modalities:
  values:
    - "all"  # Using all available modalities for comprehensive analysis
  1. Consider adding multiple seeds for robustness:
seed:
  values:
    - 1111
    - 2222
    - 3333
  1. Experiment with pretrained imputation:
use_pretrained_imputation:
  values:
    - None
    - True  # Add this line to experiment with pretrained imputation
icu_benchmarks/models/custom_metrics.py (2)

137-150: 🛠️ Refactor suggestion

⚠️ Potential issue

Refine confusion_matrix function implementation and signature.

The new confusion_matrix function is a valuable addition, but there are a few issues to address:

  1. The function signature indicates a return type of torch.tensor, but it actually returns a dictionary.
  2. There's an unused variable confusion_tensor.
  3. There are commented-out code blocks that may be unnecessary.

Please consider the following changes:

  1. Update the function signature to match the actual return type:
def confusion_matrix(y_true: ndarray, y_pred: ndarray, normalize=False) -> dict:
  1. Remove the unused variable:
-    confusion_tensor = torch.tensor(confusion)
  1. Remove or uncomment the logging statement if it's needed:
-    # logging.info(f"Confusion matrix: {confusion_dict}")
  1. Consider removing the commented-out alternative dictionary creation if it's no longer needed:
-    # dict = {"TP": confusion[0][0], "FP": confusion[0][1], "FN": confusion[1][0], "TN": confusion[1][1]}

To improve readability and efficiency, consider using a dictionary comprehension:

confusion_dict = {f"class_{i}_pred_{j}": confusion[i][j] for i in range(confusion.shape[0]) for j in range(confusion.shape[1])}

This change will make the code more concise and potentially more efficient.

🧰 Tools
🪛 Ruff

142-142: Local variable confusion_tensor is assigned to but never used

Remove assignment to unused variable confusion_tensor

(F841)


1-8: ⚠️ Potential issue

Remove unused import and approve new imports.

The new imports are relevant to the added confusion_matrix function. However, there's an unused import that should be removed.

Please apply the following change:

-import logging

The other new imports (ndarray from numpy and confusion_matrix as sk_confusion_matrix from sklearn.metrics) are appropriate for the new functionality.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

import torch
from typing import Callable
import numpy as np
from ignite.metrics import EpochMetric
from numpy import ndarray
from sklearn.metrics import balanced_accuracy_score, mean_absolute_error, confusion_matrix as sk_confusion_matrix
🧰 Tools
🪛 Ruff

1-1: logging imported but unused

Remove unused import: logging

(F401)

icu_benchmarks/cross_validation.py (1)

126-128: 🛠️ Refactor suggestion

Consider making epochs and patience configurable parameters.

The epochs and patience values are currently hardcoded in the train_common function call. To improve flexibility, consider making these configurable parameters of the execute_repeated_cv function or reading them from a configuration file.

You could modify the function signature and call like this:

def execute_repeated_cv(
    # ... other parameters ...
    epochs: int = 20,
    patience: int = 5,
    # ... other parameters ...
):
    # ... existing code ...
    agg_loss += train_common(
        # ... other parameters ...
        train_only=complete_train,
        epochs=epochs,
        patience=patience
    )
    # ... rest of the function ...

This change would allow users to easily adjust these parameters without modifying the function body.

icu_benchmarks/models/ml_models/imblearn.py (3)

2-2: ⚠️ Potential issue

Fix typo in the import statement.

The module name contants seems to be misspelled. It should be constants.

Apply this change:

-from icu_benchmarks.contants import RunMode
+from icu_benchmarks.constants import RunMode
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

from icu_benchmarks.constants import RunMode

10-12: 🛠️ Refactor suggestion

Call super().__init__ before setting self.model.

It's a good practice to initialize the parent class before accessing instance attributes to ensure proper setup.

Apply this change:

 def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
     self.model = self.set_model_args(BalancedRandomForestClassifier, *args, **kwargs)
-    super().__init__(*args, **kwargs)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = self.set_model_args(BalancedRandomForestClassifier, *args, **kwargs)

17-19: 🛠️ Refactor suggestion

Call super().__init__ before setting self.model.

To ensure proper initialization, invoke the parent class constructor first.

Apply this change:

 def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
     self.model = self.set_model_args(RUSBoostClassifier, *args, **kwargs)
-    super().__init__(*args, **kwargs)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.model = self.set_model_args(RUSBoostClassifier, *args, **kwargs)
docs/adding_model/rnn.py (4)

19-19: ⚠️ Potential issue

Validate input_size before accessing input_size[2].

Accessing input_size[2] assumes that input_size has at least three elements. If input_size does not meet this condition, it could lead to an IndexError. Ensure that input_size is correctly defined and consider adding a validation check.

Consider adding this validation:

if len(input_size) < 3:
    raise ValueError("input_size must have at least three elements")

3-3: ⚠️ Potential issue

Fix the typo in the module name 'contants' to 'constants'.

There is a typographical error in the import statement on line 3. The module name should be corrected to properly import RunMode.

Apply this diff to fix the typo:

-from icu_benchmarks.contants import RunMode
+from icu_benchmarks.constants import RunMode
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

from icu_benchmarks.constants import RunMode

15-15: ⚠️ Potential issue

Avoid unpacking *args and **kwargs after keyword arguments in function calls.

Passing *args and **kwargs after keyword arguments in a function call is discouraged and can lead to unexpected behavior. Consider refactoring the call to super().__init__ to improve code clarity and adherence to best practices.

Apply this diff to reorder the arguments:

 super().__init__(
-    input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs
+    *args, 
+    **kwargs, 
+    input_size=input_size, 
+    hidden_dim=hidden_dim, 
+    layer_dim=layer_dim, 
+    num_classes=num_classes
 )

Alternatively, if possible, pass all parameters explicitly without using *args and **kwargs.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

            *args, 
            **kwargs, 
            input_size=input_size, 
            hidden_dim=hidden_dim, 
            layer_dim=layer_dim, 
            num_classes=num_classes
🧰 Tools
🪛 Ruff

15-15: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


28-30: 🛠️ Refactor suggestion

Clarify whether sequence-level or time-step predictions are intended.

In the forward method, passing out directly to self.logit applies the linear layer to all time steps, resulting in predictions at each time step. If the intended behavior is to make sequence-level predictions, consider using only the last time step's output.

Apply this diff to modify the forward pass:

 def forward(self, x):
     h0 = self.init_hidden(x)
     out, hn = self.rnn(x, h0)
-    pred = self.logit(out)
+    pred = self.logit(out[:, -1, :])
     return pred
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

        out, hn = self.rnn(x, h0)
        pred = self.logit(out[:, -1, :])
        return pred
icu_benchmarks/models/ml_models/lgbm.py (4)

44-49: 🛠️ Refactor suggestion

Add predict method to LGBMRegressor for consistency

The LGBMRegressor class lacks a predict method, which may lead to inconsistencies when using the model in pipelines that expect this method. Consider adding a predict method similar to LGBMClassifier.

Suggested addition:

def predict(self, features):
    """Predicts outputs for the given features."""
    return self.model.predict(features)

7-7: ⚠️ Potential issue

Typo in module name 'contants'

There's a typo in the module name contants; it should be constants.

Apply this fix:

- from icu_benchmarks.contants import RunMode
+ from icu_benchmarks.constants import RunMode
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

from icu_benchmarks.constants import RunMode

15-15: ⚠️ Potential issue

Incorrect usage of LightGBM callbacks

The LightGBM callback functions early_stopping and log_evaluation should be accessed from the lgbm.callback module, not directly from lgbm.

Apply this fix:

- callbacks = [lgbm.early_stopping(self.hparams.patience, verbose=True), lgbm.log_evaluation(period=-1)]
+ callbacks = [lgbm.callback.early_stopping(self.hparams.patience, verbose=True), lgbm.callback.log_evaluation(period=-1)]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

        callbacks = [lgbm.callback.early_stopping(self.hparams.patience, verbose=True), lgbm.callback.log_evaluation(period=-1)]

23-23: ⚠️ Potential issue

Incorrect format for eval_set parameter

The eval_set parameter in the fit method should be a list of tuples, not a single tuple. This ensures proper evaluation during training.

Apply this fix:

- eval_set=(val_data, val_labels),
+ eval_set=[(val_data, val_labels)],
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

            eval_set=[(val_data, val_labels)],
icu_benchmarks/models/dl_models/tcn.py (2)

7-7: ⚠️ Potential issue

Typo in module name 'contants'

There's a typo in the module name contants. It should be constants. This correction is necessary to avoid import errors.

Apply this diff to fix the typo:

-from icu_benchmarks.contants import RunMode
+from icu_benchmarks.constants import RunMode
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

from icu_benchmarks.constants import RunMode

23-23: ⚠️ Potential issue

Avoid star-arg unpacking after a keyword argument ([B026])

Placing *args after keyword arguments in a function call is discouraged as it can lead to unexpected behavior. According to PEP 8, positional arguments (*args) should come before keyword arguments. Consider rearranging the arguments or removing *args if it's not needed.

Apply this diff to reposition *args:

             input_size=input_size,
             num_channels=num_channels,
             num_classes=num_classes,
-            *args,
             max_seq_length=max_seq_length,
             kernel_size=kernel_size,
             dropout=dropout,
+            *args,
             **kwargs,

Committable suggestion was skipped due to low confidence.

🧰 Tools
🪛 Ruff

23-23: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)

icu_benchmarks/models/ml_models/sklearn.py (2)

3-3: ⚠️ Potential issue

Typo in import statement: 'contants' should be 'constants'

There is a typo in the import statement on line 3. The module name should be 'constants' instead of 'contants'.

Apply this diff to fix the typo:

-from icu_benchmarks.contants import RunMode
+from icu_benchmarks.constants import RunMode
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

from icu_benchmarks.constants import RunMode

8-8: ⚠️ Potential issue

Inconsistent naming of attribute _supported_run_modes

In line 8, the LogisticRegression class defines __supported_run_modes with double underscores, whereas other classes use a single underscore _supported_run_modes. For consistency and to avoid unintended name mangling, please change it to a single underscore.

Apply this diff to fix the inconsistency:

 class LogisticRegression(MLWrapper):
-    __supported_run_modes = [RunMode.classification]
+    _supported_run_modes = [RunMode.classification]

Committable suggestion was skipped due to low confidence.

icu_benchmarks/models/dl_models/rnn.py (4)

4-4: ⚠️ Potential issue

Typo in module name 'contants'

There's a typo in the import statement:

from icu_benchmarks.contants import RunMode

The module name contants might be intended to be constants. Please correct the module name to ensure proper import.

Apply this diff to fix the typo:

-from icu_benchmarks.contants import RunMode
+from icu_benchmarks.constants import RunMode

16-16: ⚠️ Potential issue

Avoid star-arg unpacking after keyword arguments

In the __init__ methods of RNNet, LSTMNet, and GRUNet, the use of *args and **kwargs after keyword arguments is discouraged and can lead to unexpected behaviors.

Apply this diff to reorder the arguments:

 def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs):
     super().__init__(
-        input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs
+        input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs
     )

Alternatively, consider placing *args and **kwargs before the keyword arguments in the method definition if appropriate.

Also applies to: 42-42, 69-69

🧰 Tools
🪛 Ruff

16-16: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


23-25: 🛠️ Refactor suggestion

Ensure correct initialization of hidden states

The init_hidden methods use x.new_zeros to initialize hidden states:

h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)

While this works, it's recommended to use torch.zeros with the appropriate device to ensure compatibility across different devices (CPU/GPU).

Consider modifying the hidden state initialization:

-def init_hidden(self, x):
-    h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
+def init_hidden(self, x):
+    h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)

For the LSTMNet, ensure both h0 and c0 are initialized similarly.

Also applies to: 49-52, 76-78


30-31: 💡 Codebase verification

⚠️ Potential issue

Output Dimensions Verification Needed Before Passing to self.logit

The out tensor is being passed directly to self.logit(out) in multiple locations, which may include outputs for all timesteps. To ensure the correct output shape, modify the code to use only the last timestep's output.

  • Affected Files and Lines:
    • icu_benchmarks/models/dl_models/rnn.py:30-31
    • icu_benchmarks/models/dl_models/rnn.py:57-58
    • icu_benchmarks/models/dl_models/rnn.py:82-85
    • docs/adding_model/rnn.py:pred = self.logit(out)

Suggested Changes:

-pred = self.logit(out)
+pred = self.logit(out[:, -1, :])
🔗 Analysis chain

Verify the output dimensions before passing to the linear layer

The out tensor from the RNN layers is passed directly to self.logit(out). Depending on the RNN configuration, out may contain outputs for all timesteps. Ensure that self.logit is applied correctly to produce the desired output shape.

Run the following script to check how out is used in the codebase:

If you intend to use only the output from the last timestep, modify the code as follows:

Apply this diff to use the last timestep's output:

-pred = self.logit(out)
+pred = self.logit(out[:, -1, :])

Also applies to: 57-58, 82-85

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Verify that the output from RNN layers is correctly processed.

# Test: Search for usages of 'self.logit' after RNNs.
# Expect: 'out' is appropriately reshaped or indexed.
rg --type python "self\.logit\(out\)"

Length of output: 70


Script:

#!/bin/bash
# Description: Verify that the output from RNN layers is correctly processed.

# Test: Search for usages of 'self.logit' after RNNs.
# Expect: 'out' is appropriately reshaped or indexed.
rg --type py "self\.logit\(out\)"

Length of output: 301

icu_benchmarks/models/ml_models/xgboost.py (8)

23-23: ⚠️ Potential issue

Avoid star-arg unpacking after a keyword argument

Using *args after a keyword argument (device="cpu") is discouraged and can lead to unexpected behavior. Consider reordering the arguments or including device in kwargs.

Apply this diff to reorder arguments:

-self.model = self.set_model_args(xgb.XGBClassifier, device="cpu", *args, **kwargs)
+self.model = self.set_model_args(xgb.XGBClassifier, *args, device="cpu", **kwargs)

Alternatively, include device in kwargs:

-self.model = self.set_model_args(xgb.XGBClassifier, device="cpu", *args, **kwargs)
+self.model = self.set_model_args(xgb.XGBClassifier, *args, **kwargs)

Then, in your code, ensure device is set within kwargs:

kwargs.setdefault('device', 'cpu')
🧰 Tools
🪛 Ruff

23-23: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


6-6: ⚠️ Potential issue

Remove unused import torch

The torch library is imported but not used anywhere in this module. Removing it will clean up the imports.

Apply this diff to remove the unused import:

-import torch
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.


🧰 Tools
🪛 Ruff

6-6: torch imported but unused

Remove unused import: torch

(F401)


11-11: ⚠️ Potential issue

Remove unused import LearningRateScheduler

The LearningRateScheduler from xgboost.callback is imported but not used in this module. Please remove it to keep the code clean.

Apply this diff:

-from xgboost.callback import EarlyStopping, LearningRateScheduler
+from xgboost.callback import EarlyStopping
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

from xgboost.callback import EarlyStopping
🧰 Tools
🪛 Ruff

11-11: xgboost.callback.LearningRateScheduler imported but unused

Remove unused import: xgboost.callback.LearningRateScheduler

(F401)


15-15: ⚠️ Potential issue

Remove unused import XGBoostPruningCallback

The XGBoostPruningCallback from optuna.integration is not used in this module. Consider removing it to reduce unnecessary imports.

Apply this diff:

-from optuna.integration import XGBoostPruningCallback
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.


🧰 Tools
🪛 Ruff

15-15: optuna.integration.XGBoostPruningCallback imported but unused

Remove unused import: optuna.integration.XGBoostPruningCallback

(F401)


37-37: ⚠️ Potential issue

Avoid logging training labels directly

Logging train_labels directly can lead to excessively large log outputs and potential exposure of sensitive data. Consider summarizing the labels instead.

Apply this diff:

-logging.info(train_labels)
+logging.info(f"Training labels distribution: {np.bincount(train_labels)}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

        logging.info(f"Training labels distribution: {np.bincount(train_labels)}")

73-82: ⚠️ Potential issue

Filter keyword arguments to valid model hyperparameters

Currently, all keyword arguments are passed to the model constructor without filtering, which may cause errors if invalid parameters are included. Consider filtering kwargs to include only valid hyperparameters accepted by the model.

Modify the set_model_args method as follows:

 def set_model_args(self, model, *args, **kwargs):
     """XGBoost signature does not include the hyperparams so we need to pass them manually."""
     signature = inspect.signature(model.__init__).parameters
     possible_hps = list(signature.keys())
     # Get passed keyword arguments
     arguments = kwargs
     # Get valid hyperparameters
-    hyperparams = arguments
+    hyperparams = {k: v for k, v in arguments.items() if k in possible_hps}
     logging.debug(f"Creating model with: {hyperparams}.")
     return model(**hyperparams)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    def set_model_args(self, model, *args, **kwargs):
        """XGBoost signature does not include the hyperparams so we need to pass them manually."""
        signature = inspect.signature(model.__init__).parameters
        possible_hps = list(signature.keys())
        # Get passed keyword arguments
        arguments = kwargs
        # Get valid hyperparameters
        hyperparams = {k: v for k, v in arguments.items() if k in possible_hps}
        logging.debug(f"Creating model with: {hyperparams}.")
        return model(**hyperparams)

59-60: ⚠️ Potential issue

Remove unnecessary f-strings and fix filename inconsistency

The strings 'pred_indicators.csv' and 'row_indicators.csv' do not contain placeholders, so the f prefix is unnecessary. Additionally, there is an inconsistency between the filename used when saving and logging.

Apply this diff to remove unnecessary f-strings and fix the filename:

-np.savetxt(os.path.join(self.logger.save_dir,f'pred_indicators.csv'), pred_indicators, delimiter=",")
+np.savetxt(os.path.join(self.logger.save_dir,'pred_indicators.csv'), pred_indicators, delimiter=",")

-logging.debug(f"Saved row indicators to {os.path.join(self.logger.save_dir,f'row_indicators.csv')}")
+logging.debug(f"Saved prediction indicators to {os.path.join(self.logger.save_dir,'pred_indicators.csv')}")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

            np.savetxt(os.path.join(self.logger.save_dir,'pred_indicators.csv'), pred_indicators, delimiter=",")
            logging.debug(f"Saved prediction indicators to {os.path.join(self.logger.save_dir,'pred_indicators.csv')}")
🧰 Tools
🪛 Ruff

59-59: f-string without any placeholders

Remove extraneous f prefix

(F541)


60-60: f-string without any placeholders

Remove extraneous f prefix

(F541)


32-38: ⚠️ Potential issue

Pass callbacks parameter to model.fit method

The callbacks list is defined but not used in the self.model.fit call. To enable early stopping and other callbacks, pass the callbacks parameter to the fit method.

Apply this diff:

 self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=False)
+self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=False, callbacks=callbacks)

Committable suggestion was skipped due to low confidence.

icu_benchmarks/models/dl_models/transformer.py (4)

15-77: 🛠️ Refactor suggestion

Refactor to reduce code duplication between Transformer and LocalTransformer

Both Transformer and LocalTransformer classes share significant code in their __init__ and forward methods. Consider refactoring by creating a base class or utility functions to encapsulate the shared logic. This will improve maintainability and reduce redundancy.

Also applies to: 83-148

🧰 Tools
🪛 Ruff

37-37: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


52-52: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


4-4: ⚠️ Potential issue

Fix typo in module name in import statement

There's a typo in the module name contants; it should be constants.

Apply this diff to fix the typo:

-from icu_benchmarks.contants import RunMode
+from icu_benchmarks.constants import RunMode
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

from icu_benchmarks.constants import RunMode

106-106: ⚠️ Potential issue

Avoid unpacking *args after keyword arguments in function call

Unpacking *args after keyword arguments in function calls is discouraged as it may lead to unexpected behavior. Consider moving *args before the keyword arguments.

Apply this diff to reorder the arguments in the super().__init__ call:

         super().__init__(
             input_size=input_size,
             hidden=hidden,
             heads=heads,
             ff_hidden_mult=ff_hidden_mult,
             depth=depth,
             num_classes=num_classes,
+            *args,
             dropout=dropout,
             l1_reg=l1_reg,
             pos_encoding=pos_encoding,
             local_context=local_context,
             dropout_att=dropout_att,
-            *args,
             **kwargs,
         )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

        super().__init__(
            input_size=input_size,
            hidden=hidden,
            heads=heads,
            ff_hidden_mult=ff_hidden_mult,
            depth=depth,
            num_classes=num_classes,
            *args,
            dropout=dropout,
            l1_reg=l1_reg,
            pos_encoding=pos_encoding,
            local_context=local_context,
            dropout_att=dropout_att,
            **kwargs,
        )
🧰 Tools
🪛 Ruff

106-106: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


37-37: ⚠️ Potential issue

Avoid unpacking *args after keyword arguments in function call

Unpacking *args after keyword arguments in function calls is discouraged as it may lead to unexpected behavior. Consider moving *args before the keyword arguments.

Apply this diff to reorder the arguments in the super().__init__ call:

         super().__init__(
             input_size=input_size,
             hidden=hidden,
             heads=heads,
             ff_hidden_mult=ff_hidden_mult,
             depth=depth,
             num_classes=num_classes,
+            *args,
             dropout=dropout,
             l1_reg=l1_reg,
             pos_encoding=pos_encoding,
             dropout_att=dropout_att,
-            *args,
             **kwargs,
         )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

        super().__init__(
            input_size=input_size,
            hidden=hidden,
            heads=heads,
            ff_hidden_mult=ff_hidden_mult,
            depth=depth,
            num_classes=num_classes,
            *args,
            dropout=dropout,
            l1_reg=l1_reg,
            pos_encoding=pos_encoding,
            dropout_att=dropout_att,
            **kwargs,
        )
🧰 Tools
🪛 Ruff

37-37: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)

icu_benchmarks/run.py (2)

61-64: 🛠️ Refactor suggestion

Enhance error messages with available options

To improve user experience, include the list of available tasks or models in the error messages. This helps users correct their input when a task or model is not found.

Apply this diff to modify the error messages:

 if task not in tasks:
-    raise ValueError(f"Task {task} not found in tasks.")
+    raise ValueError(f"Task '{task}' not found. Available tasks: {', '.join(tasks)}.")

 if model not in models:
-    raise ValueError(f"Model {model} not found in models.")
+    raise ValueError(f"Model '{model}' not found. Available models: {', '.join(models)}.")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    if task not in tasks:
        raise ValueError(f"Task '{task}' not found. Available tasks: {', '.join(tasks)}.")
    if model not in models:
        raise ValueError(f"Model '{model}' not found. Available models: {', '.join(models)}.")

191-194: ⚠️ Potential issue

Improve exception handling for result aggregation

Catching the base Exception may suppress important errors. Consider catching specific exceptions or re-raising the exception after logging. Also, use logging.exception to log the stack trace for better debugging.

Apply this diff to improve exception handling:

 try:
     aggregate_results(run_dir, execution_time)
-except Exception as e:
-    logging.error(f"Failed to aggregate results: {e}")
+except Exception:
+    logging.exception("Failed to aggregate results")
+    raise
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    try:
        aggregate_results(run_dir, execution_time)
    except Exception:
        logging.exception("Failed to aggregate results")
        raise
icu_benchmarks/models/train.py (4)

113-113: 🛠️ Refactor suggestion

Uncomment pin_memory if required for performance

The pin_memory parameter in the DataLoader is commented out. If data loading performance is a priority and the hardware supports it, consider uncommenting pin_memory=not cpu to enable faster data transfer to GPU memory.

Apply this diff to enable pin_memory:

-    # pin_memory=not cpu,
+    pin_memory=not cpu,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

        pin_memory=not cpu,

85-86: 🛠️ Refactor suggestion

Improve readability of nested conditional expression

The assignment of dataset_class uses a nested ternary operator, which can reduce code readability. Refactor to an explicit if-elif-else structure for clarity.

Apply this diff to enhance readability:

-    dataset_class = ImputationPandasDataset if mode == RunMode.imputation else PredictionPolarsDataset if polars else PredictionPandasDataset
+    if mode == RunMode.imputation:
+        dataset_class = ImputationPandasDataset
+    elif polars:
+        dataset_class = PredictionPolarsDataset
+    else:
+        dataset_class = PredictionPandasDataset
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    if mode == RunMode.imputation:
        dataset_class = ImputationPandasDataset
    elif polars:
        dataset_class = PredictionPolarsDataset
    else:
        dataset_class = PredictionPandasDataset
    # dataset_class = ImputationPandasDataset if mode == RunMode.imputation else PredictionPandasDataset

41-42: ⚠️ Potential issue

Re-evaluate default batch_size and epochs values

The default batch_size has been changed from 64 to 1, and epochs from 1000 to 100. A batch_size of 1 can lead to slow training due to less efficient GPU utilization. Unless there's a specific reason for these values, consider setting a larger batch_size and increasing epochs to ensure adequate training.

Apply this diff to adjust the default values:

-    batch_size=1,
-    epochs=100,
+    batch_size=64,
+    epochs=1000,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    batch_size=64,
    epochs=1000,

200-219: 🛠️ Refactor suggestion

⚠️ Potential issue

Handle exceptions specifically and correct file naming in logs

In the persist_data function:

  • Catching the general Exception can obscure specific errors. Consider catching more specific exceptions like AttributeError or IOError.
  • The logging messages refer to 'test_shap_values.parquet', but the files are saved as 'shap_values_test.parquet' and 'shap_values_train.parquet'. This mismatch could confuse users checking the logs.

Apply this diff to correct the log messages:

-    logging.info(f"Saved shap values to {log_dir / 'test_shap_values.parquet'}")
+    logging.info(f"Saved shap values to {log_dir / 'shap_values_test.parquet'}")

-    logging.info(f"Saved shap values to {log_dir / 'test_shap_values.parquet'}")
+    logging.info(f"Saved shap values to {log_dir / 'shap_values_train.parquet'}")

Additionally, update the exception handling:

-    except Exception as e:
+    except (AttributeError, IOError) as e:

Committable suggestion was skipped due to low confidence.

icu_benchmarks/models/utils.py (5)

199-205: ⚠️ Potential issue

Potential misuse of average_precision_score with class labels instead of scores

In the __call__ method of scorer_wrapper, when y_pred is multi-dimensional and y_true has at most two unique values (binary classification), np.argmax is used to convert y_pred to class labels before passing it to self.scorer. However, average_precision_score expects probability estimates or confidence scores, not discrete class labels. Passing class labels may lead to incorrect metric calculations.

Consider passing the probability estimates directly to self.scorer without applying np.argmax. If y_pred contains logits or probabilities, they should be used to accurately compute the average precision score.


237-238: ⚠️ Potential issue

Possible IndexError when accessing pos_event_change_full[-1]

In the condition:

if ((last_know_label == 1) and (len(pos_event_change_full) == 0)) or (
        (last_know_label == 1) and (last_know_idx >= pos_event_change_full[-1])):

If pos_event_change_full is empty, accessing pos_event_change_full[-1] will raise an IndexError. The second part of the or condition does not check if the list is non-empty before accessing the last element.

Ensure that pos_event_change_full is not empty before accessing pos_event_change_full[-1]. You can restructure the condition or add a check to prevent the potential IndexError.


248-250: ⚠️ Potential issue

Potential IndexError due to empty array when accessing last_pos_change

In the code:

pos_change = np.where((diffs_label >= 1) & (label_for_event == 1))[0]
last_pos_change = pos_change[np.where(pos_change > max(last_know_event, last_known_stable))][0]

If pos_change is empty or if no elements satisfy the condition (pos_change > max(last_know_event, last_known_stable)), accessing index [0] will raise an IndexError.

Add a check to ensure that the array is not empty before accessing [0]. Handle cases where pos_change has no elements satisfying the condition to prevent runtime errors.


206-208: 🛠️ Refactor suggestion

Use __str__ or __repr__ instead of defining __name__ method

Defining a __name__ method in a class is unconventional in Python. The __name__ attribute is typically associated with modules and functions, not class instances. To provide a string representation of the class or its instances, consider implementing the __str__ or __repr__ method.

Apply this diff to replace __name__ with __str__:

-def __name__(self):
+def __str__(self):
     return "scorer_wrapper"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    def __str__(self):
        return "scorer_wrapper"

192-195: 🛠️ Refactor suggestion

Adjust docstring indentation in scorer_wrapper class

The docstring inside the scorer_wrapper class has incorrect indentation. According to PEP 257, the docstring should be directly under the class definition with proper indentation for readability.

Apply this diff to fix the docstring indentation:

 class scorer_wrapper:
-    """
-        Wrapper that flattens the binary classification input such that we can use a broader range of sklearn metrics.
-    """
+    """
+    Wrapper that flattens the binary classification input such that we can use a broader range of sklearn metrics.
+    """
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

class scorer_wrapper:
    """
    Wrapper that flattens the binary classification input such that we can use a broader range of sklearn metrics.
    """
icu_benchmarks/run_utils.py (3)

79-84: 🛠️ Refactor suggestion

Simplify directory creation logic

The current logic for creating log_dir_run can be simplified to handle directory clashes more elegantly and avoid potential race conditions.

Consider refactoring using a loop to ensure a unique directory is created:

while True:
    log_dir_run = log_dir / datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")
    try:
        log_dir_run.mkdir(parents=True)
        break
    except FileExistsError:
        continue

This approach attempts to create the directory and, in case of a FileExistsError, retries with a new timestamp until it succeeds.


112-112: ⚠️ Potential issue

Remove unused variable shap_values_train

The variable shap_values_train is assigned but never used, which may lead to confusion.

Apply this diff to remove the unused variable:

-        shap_values_train = []
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.


🧰 Tools
🪛 Ruff

112-112: Local variable shap_values_train is assigned to but never used

Remove assignment to unused variable shap_values_train

(F841)


257-268: 🛠️ Refactor suggestion

Use Path methods instead of mixing os.path and pathlib.Path

The current implementation mixes os.path and pathlib.Path, which can lead to inconsistencies and errors. Since config_dir is a Path object, it's better to use Path methods throughout.

Refactor the function to use Path.glob() and Path properties:

-def get_config_files(config_dir: Path):
-    tasks = glob.glob(os.path.join(config_dir / "tasks", '*'))
-    models = glob.glob(os.path.join(config_dir / "prediction_models", '*'))
-    tasks = [os.path.splitext(os.path.basename(task))[0] for task in tasks]
-    models = [os.path.splitext(os.path.basename(model))[0] for model in models]
+def get_config_files(config_dir: Path):
+    tasks = [task.stem for task in (config_dir / "tasks").glob('*')]
+    models = [model.stem for model in (config_dir / "prediction_models").glob('*')]
     if "common" in tasks:
         tasks.remove("common")
     if "common" in models:
         models.remove("common")
     logging.info(f"Found tasks: {tasks}")
     logging.info(f"Found models: {models}")
     return tasks, models

This change enhances readability and consistency by fully utilizing pathlib.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

def get_config_files(config_dir: Path):
    tasks = [task.stem for task in (config_dir / "tasks").glob('*')]
    models = [model.stem for model in (config_dir / "prediction_models").glob('*')]
    if "common" in tasks:
        tasks.remove("common")
    if "common" in models:
        models.remove("common")
    logging.info(f"Found tasks: {tasks}")
    logging.info(f"Found models: {models}")
    return tasks, models
icu_benchmarks/tuning/hyperparameters.py (3)

243-260: ⚠️ Potential issue

Correct typos: Replace trail with trial and trails with trials

There are several instances where trail is used instead of trial, and trails instead of trials. This may lead to errors or confusion. Please correct these typos throughout the code.

Apply these diffs to fix the typos:

  1. Replace trail with trial in the objective function and the study.optimize call:
 def objective(trail, hyperparams_bounds, hyperparams_names):
+def objective(trial, hyperparams_bounds, hyperparams_names):
     hyperparams = {}
     logging.info(f"Bounds: {hyperparams_bounds}, Names: {hyperparams_names}")
     for name, value in zip(hyperparams_names, hyperparams_bounds):
         if isinstance(value, tuple):
             # Check for range or "list-type" hyperparameter bounds
             if isinstance(value[0], (int, float)) and isinstance(value[1], (int, float)):
                 if len(value) == 3 and isinstance(value[2], str):
                     if isinstance(value[0], int) and isinstance(value[1], int):
-                        hyperparams[name] = trail.suggest_int(name, value[0], value[1], log=value[2] == "log")
+                        hyperparams[name] = trial.suggest_int(name, value[0], value[1], log=value[2] == "log")
                     elif isinstance(value[0], (int, float)) and isinstance(value[1], (int, float)):
-                        hyperparams[name] = trail.suggest_float(name, value[0], value[1], log=value[2] == "log")
+                        hyperparams[name] = trial.suggest_float(name, value[0], value[1], log=value[2] == "log")
                     else:
-                        hyperparams[name] = trail.suggest_categorical(name, value)
+                        hyperparams[name] = trial.suggest_categorical(name, value)
                 elif len(value) == 2:
                     if isinstance(value[0], int) and isinstance(value[1], int):
-                        hyperparams[name] = trail.suggest_int(name, value[0], value[1])
+                        hyperparams[name] = trial.suggest_int(name, value[0], value[1])
                     elif isinstance(value[0], (int, float)) and isinstance(value[1], (int, float)):
-                        hyperparams[name] = trail.suggest_float(name, value[0], value[1])
+                        hyperparams[name] = trial.suggest_float(name, value[0], value[1])
                     else:
-                        hyperparams[name] = trail.suggest_categorical(name, value)
+                        hyperparams[name] = trial.suggest_categorical(name, value)
                 else:
-                    hyperparams[name] = trail.suggest_categorical(name, value)
+                    hyperparams[name] = trial.suggest_categorical(name, value)
         else:
-            hyperparams[name] = trail.suggest_categorical(name, value)
+            hyperparams[name] = trial.suggest_categorical(name, value)
     return bind_params_and_train(hyperparams)
  1. Update the lambda function in study.optimize:
 study.optimize(
-    lambda trail: objective(trail, hyperparams_bounds, hyperparams_names),
+    lambda trial: objective(trial, hyperparams_bounds, hyperparams_names),
     n_trials=n_calls,
     callbacks=callbacks,
     gc_after_trial=True,
 )
  1. Correct trails to trials in the logging statement:
 logging.info(f"Starting or resuming Optuna study with {n_calls} trails and callbacks: {callbacks}.")
+logging.info(f"Starting or resuming Optuna study with {n_calls} trials and callbacks: {callbacks}.")

Also applies to: 355-355, 352-352


373-382: ⚠️ Potential issue

Correct the plotting and saving of Optuna visualization figures

The functions plot_param_importances and plot_optimization_history from optuna.visualization return Plotly figures, not Matplotlib figures. Using plt.savefig will not save these figures correctly.

To save the Plotly figures, you can use the write_image method. Here's how you can modify the code:

 if plot:
     try:
         logging.info("Plotting hyperparameter importances.")
-        plot_param_importances(study)
-        plt.savefig(log_dir / "param_importances.png")
-        plot_optimization_history(study)
-        plt.savefig(log_dir / "optimization_history.png")
+        import plotly.io as pio
+
+        fig1 = plot_param_importances(study)
+        pio.write_image(fig1, str(log_dir / "param_importances.png"))
+
+        fig2 = plot_optimization_history(study)
+        pio.write_image(fig2, str(log_dir / "optimization_history.png"))
     except Exception as e:
         logging.error(f"Failed to plot hyperparameter importances: {e}")

Additionally, ensure that you have installed the necessary dependencies (e.g., kaleido) to enable Plotly to export images.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    if plot:
        try:
            logging.info("Plotting hyperparameter importances.")
            import plotly.io as pio

            fig1 = plot_param_importances(study)
            pio.write_image(fig1, str(log_dir / "param_importances.png"))

            fig2 = plot_optimization_history(study)
            pio.write_image(fig2, str(log_dir / "optimization_history.png"))
        except Exception as e:
            logging.error(f"Failed to plot hyperparameter importances: {e}")

323-326: ⚠️ Potential issue

Fix logical condition when checking for checkpoint existence

The condition in the if statement contradicts the logging message. If the checkpoint exists, it should proceed to load it instead of logging that it does not exist. Adjust the condition to check if the checkpoint does not exist.

Apply this diff to correct the condition:

 if checkpoint and checkpoint.exists():
-    logging.warning(f"Hyperparameter checkpoint {checkpoint} does not exist.")
+    logging.info(f"Loading checkpoint at {checkpoint}")
     study = optuna.load_study(
         study_name="tuning",
         storage="sqlite:///" + str(checkpoint),
         sampler=sampler,
         pruner=pruner,
     )
     n_calls = n_calls - len(study.trials)
 else:
     if checkpoint:
         logging.warning("Checkpoint path given as flag but not found, starting from scratch.")
     # Continue with creating a new study
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    if checkpoint and checkpoint.exists():
        logging.info(f"Loading checkpoint at {checkpoint}")
        study = optuna.load_study(
            study_name="tuning",
            storage="sqlite:///" + str(checkpoint),
            sampler=sampler,
            pruner=pruner,
        )
        n_calls = n_calls - len(study.trials)
    else:
        if checkpoint:
            logging.warning("Checkpoint path given as flag but not found, starting from scratch.")
        # Continue with creating a new study
icu_benchmarks/data/loader.py (2)

15-184: 🛠️ Refactor suggestion

Consider refactoring to reduce code duplication

There is significant code duplication between the Polars-based and Pandas-based dataset classes (CommonPolarsDataset vs. CommonPandasDataset, PredictionPolarsDataset vs. PredictionPandasDataset). This can make maintenance more challenging and increase the risk of inconsistencies. Consider refactoring by creating a base class that encapsulates shared functionality or using a strategy pattern to handle the differences between Polars and Pandas. This will promote code reuse and improve maintainability.

Also applies to: 186-336


309-309: ⚠️ Potential issue

Remove unused variable weights

At line 309, the variable weights is assigned but not used before being returned. You can return the computed value directly without assigning it to weights.

Apply this diff to fix the issue:

- weights = list((1 / counts) * np.sum(counts) / counts.shape[0])
  return list((1 / counts) * np.sum(counts) / counts.shape[0])

Committable suggestion was skipped due to low confidence.

🧰 Tools
🪛 Ruff

309-309: Local variable weights is assigned to but never used

Remove assignment to unused variable weights

(F841)

icu_benchmarks/data/split_process_data.py (6)

9-9: ⚠️ Potential issue

Remove unused imports to improve code cleanliness.

The imports pyarrow.parquet as pq, sequence from setuptools.dist, and false from sqlalchemy are not used in the code and can be safely removed.

Apply this diff to remove the unused imports:

-import pyarrow.parquet as pq
-from setuptools.dist import sequence
-from sqlalchemy import false

Also applies to: 15-15, 17-17


31-32: ⚠️ Potential issue

Avoid using mutable default arguments.

Using mutable default arguments like {} and [] can lead to unexpected behavior because default arguments are evaluated only once. It's better to use None and assign within the function to avoid shared mutable defaults.

Apply this diff to fix the default arguments:

-    modality_mapping: dict[str] = {},
-    selected_modalities: list[str] = "all",
...
-    vars_to_exclude: list[str] = [],
+    modality_mapping: dict[str] = None,
+    selected_modalities: list[str] = None,
...
+    vars_to_exclude: list[str] = None,

Then, within the function, initialize them:

if modality_mapping is None:
    modality_mapping = {}

if selected_modalities is None:
    selected_modalities = ["all"]

if vars_to_exclude is None:
    vars_to_exclude = []

Also applies to: 46-46

🧰 Tools
🪛 Ruff

31-31: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


32-32: ⚠️ Potential issue

Type mismatch in default value for selected_modalities.

The parameter selected_modalities is annotated as list[str], but the default value is a string "all". This can lead to type errors. Consider initializing it to None and setting the default within the function.

Apply this diff:

-selected_modalities: list[str] = "all",
+selected_modalities: list[str] = None,

Then, within the function:

if selected_modalities is None:
    selected_modalities = ["all"]

129-129: ⚠️ Potential issue

Use is for comparison to None and simplify condition.

Comparison to None should be done using is or is not. Additionally, the condition can be simplified for better readability.

Apply this diff:

-if not (selected_modalities == "all" or selected_modalities == ["all"] or selected_modalities == None):
+if selected_modalities not in [None, "all", ["all"]]:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

        if selected_modalities not in [None, "all", ["all"]]:
🧰 Tools
🪛 Ruff

129-129: Comparison to None should be cond is None

Replace with cond is None

(E711)


161-162: ⚠️ Potential issue

Avoid using dict as a variable name to prevent shadowing built-in types.

Using dict as a variable name shadows the built-in dict type in Python, which can lead to unexpected behaviors. Consider renaming the variable to something like data_dict or data_split.

Apply this diff:

-for dict in data.values():
-    for key, val in dict.items():
+for data_split in data.values():
+    for key, val in data_split.items():
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    for data_split in data.values():
        for key, val in data_split.items():

192-192: ⚠️ Potential issue

Variable sequence is redefined and shadows an imported name.

The variable sequence is redefined, which could potentially lead to confusion or errors.

Consider renaming the variable or removing the unused import if it's not needed.

Apply this diff:

-from setuptools.dist import sequence

...

- sequence = vars[Var.sequence] if Var.sequence in vars.keys() else None
+ sequence_var = vars[Var.sequence] if Var.sequence in vars else None
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

    sequence_var = vars[Var.sequence] if Var.sequence in vars else None
🧰 Tools
🪛 Ruff

192-192: Redefinition of unused sequence from line 15

(F811)


192-192: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/models/wrappers.py (3)

6-6: ⚠️ Potential issue

Remove unused imports average_precision_score and roc_auc_score.

The functions average_precision_score and roc_auc_score are imported but not used in the code. Removing unused imports helps keep the codebase clean and reduces potential confusion.

Apply this diff to remove the unused imports:

-from sklearn.metrics import log_loss, mean_squared_error, average_precision_score, roc_auc_score
+from sklearn.metrics import log_loss, mean_squared_error
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

from sklearn.metrics import log_loss, mean_squared_error

20-20: ⚠️ Potential issue

Remove unused import scorer_wrapper.

The scorer_wrapper function is imported from icu_benchmarks.models.utils but is not utilized within the code. Cleaning up unused imports improves code readability and maintenance.

Apply this diff to remove the unused import:

-from icu_benchmarks.models.utils import create_optimizer, create_scheduler, scorer_wrapper
+from icu_benchmarks.models.utils import create_optimizer, create_scheduler
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

from icu_benchmarks.models.utils import create_optimizer, create_scheduler
🧰 Tools
🪛 Ruff

20-20: icu_benchmarks.models.utils.scorer_wrapper imported but unused

Remove unused import: icu_benchmarks.models.utils.scorer_wrapper

(F401)


33-34: ⚠️ Potential issue

Remove unnecessary gin configurations for unused metrics.

The average_precision_score and roc_auc_score metrics are registered with Gin but are not used elsewhere in the code. Consider removing these configurations to streamline the code.

Apply this diff to remove the unnecessary configurations:

-gin.config.external_configurable(average_precision_score, module="sklearn.metrics")
-gin.config.external_configurable(roc_auc_score, module="sklearn.metrics")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.


icu_benchmarks/data/preprocessor.py (2)

9-9: ⚠️ Potential issue

Remove unused import polars.selectors as cs

The import statement import polars.selectors as cs on line 9 is unused in the code and can be removed to clean up the imports.

Apply this diff to remove the unused import:

-import polars.selectors as cs
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.


🧰 Tools
🪛 Ruff

9-9: polars.selectors imported but unused

Remove unused import: polars.selectors

(F401)


203-203: ⚠️ Potential issue

Handle potential NoneType for self.imputation_model

In the to_cache_string method, accessing self.imputation_model.__class__.__name__ assumes that self.imputation_model is not None. If self.imputation_model is None, this will raise an AttributeError. Consider handling this case to prevent potential errors.

Apply this diff to handle the potential NoneType:

+        imputation_model_name = self.imputation_model.__class__.__name__ if self.imputation_model else 'None'
         return (
             super().to_cache_string()
-            + f"_classification_{self.generate_features}_{self.scaling}_{self.imputation_model.__class__.__name__}"
+            + f"_classification_{self.generate_features}_{self.scaling}_{imputation_model_name}"
         )

Committable suggestion was skipped due to low confidence.

Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 12

🧹 Outside diff range and nitpick comments (41)
icu_benchmarks/models/ml_models/catboost.py (1)

17-19: Enhance the docstring for the predict method

The predict method looks good, but the docstring could be more informative.

Consider expanding the docstring to include information about the input and output:

def predict(self, features):
    """
    Predicts class probabilities for the given features.

    Args:
        features: Input features for prediction.

    Returns:
        numpy.ndarray: Predicted probabilities for each class.
    """
    return self.model.predict_proba(features)

This provides more context about what the method expects and what it returns.

icu_benchmarks/models/ml_models/xgboost.py (4)

1-15: LGTM! Consider organizing imports.

The imports look appropriate for the XGBoost classifier implementation. However, consider organizing them according to PEP 8 guidelines: standard library imports first, then third-party imports, and finally local imports.

Here's a suggested organization:

import inspect
import logging
import os
from statistics import mean

import gin
import numpy as np
import shap
import wandb
import xgboost as xgb
from xgboost.callback import EarlyStopping
from wandb.integration.xgboost import wandb_callback as wandb_xgb

from icu_benchmarks.constants import RunMode
from icu_benchmarks.models.wrappers import MLWrapper

# Uncomment if needed in the future
# from optuna.integration import XGBoostPruningCallback

30-46: LGTM! Consider minor improvements in fit_model.

The fit_model method looks good overall:

  • EarlyStopping callback is used to prevent overfitting.
  • Optional WandB integration is flexible.
  • SHAP values are computed for model interpretability.

Consider the following improvements:

  1. The SHAP computation might be expensive for large datasets. Consider adding an option to disable it or compute it on a subset of the data.
  2. The evaluation score calculation could be more explicit. Consider using a named metric instead of the first available one.

Here's a suggested improvement for the evaluation score:

eval_results = self.model.evals_result_["validation_0"]
primary_metric = list(eval_results.keys())[0]  # e.g., 'auc', 'logloss'
eval_score = mean(eval_results[primary_metric])
return eval_score

48-73: LGTM with suggestions for improved clarity and robustness.

The test_step method is comprehensive, handling prediction, logging, and SHAP value computation. However, consider the following improvements:

  1. The handling of pred_indicators could be more explicit. Add type hints or documentation to clarify the expected structure.
  2. The purpose of self.mps is not clear. Consider adding a comment or renaming for clarity.
  3. The CSV saving logic could be extracted into a separate method for better readability.

Here's a suggested refactoring for the CSV saving logic:

def save_predictions_to_csv(self, pred_indicators, test_label, test_pred):
    if (len(pred_indicators.shape) > 1 and 
        len(test_pred.shape) > 1 and 
        pred_indicators.shape[1] == test_pred.shape[1]):
        
        output = np.hstack((pred_indicators, 
                            test_label.reshape(-1, 1), 
                            test_pred))
        output_path = os.path.join(self.logger.save_dir, "pred_indicators.csv")
        np.savetxt(output_path, output, delimiter=",")
        logging.debug(f"Saved predictions to {output_path}")

# Call this method in test_step
self.save_predictions_to_csv(pred_indicators, test_label, test_pred)

This refactoring improves readability and separates concerns in the test_step method.

🧰 Tools
🪛 Ruff

63-63: f-string without any placeholders

Remove extraneous f prefix

(F541)


85-86: LGTM! Consider adding error handling.

The get_feature_importance method correctly returns the feature importances from the XGBoost model. However, consider adding error handling in case the model hasn't been fit yet.

Here's a suggested improvement:

def get_feature_importance(self):
    if not hasattr(self.model, 'feature_importances_'):
        raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
    return self.model.feature_importances_

This addition will provide a clear error message if the method is called before the model is fit.

icu_benchmarks/run.py (3)

54-58: LGTM: New logic for handling modalities and labels

The new code block effectively handles the modalities and label arguments, using gin bindings to set them appropriately. The debug logging for modalities is helpful for tracking.

Consider adding a similar debug log for the label binding:

if args.label:
    logging.debug(f"Binding label: {args.label}")
    gin.bind_parameter("preprocess.label", args.label)

This would maintain consistency with the modalities logging and provide additional clarity during execution.


59-64: LGTM: New logic for config file retrieval and validation

The addition of the get_config_files function call and subsequent validation checks for tasks and models enhances the robustness of the script. This ensures that only valid tasks and models are processed, preventing potential runtime errors.

Consider combining the two separate checks into a single if statement for slightly more efficient error handling:

if task not in tasks or model not in models:
    raise ValueError(f"Invalid task or model. Task: {task} {'not ' if task not in tasks else ''}found. Model: {model} {'not ' if model not in models else ''}found.")

This change would provide a single, comprehensive error message if either the task or model is invalid.


191-194: LGTM: Added error handling for aggregate_results

The introduction of a try-except block around the aggregate_results call is a good addition. It improves the script's robustness by preventing crashes due to issues during result aggregation.

Consider adding more detailed error logging:

try:
    aggregate_results(run_dir, execution_time)
except Exception as e:
    logging.error(f"Failed to aggregate results: {e}")
    logging.debug("Error details:", exc_info=True)

This change would provide a stack trace in debug mode, which could be helpful for troubleshooting more complex issues.

icu_benchmarks/data/pooling.py (1)

69-70: Approve changes with a minor optimization suggestion

The use of pq.read_table().to_pandas(self_destruct=True) is a good performance optimization. It allows the Arrow memory to be reused, potentially reducing memory usage.

However, there's a minor optimization opportunity:

Consider simplifying the dictionary comprehension by removing .keys():

data[folder.name] = {
    f: pq.read_table(folder / self.file_names[f]).to_pandas(self_destruct=True)
    for f in self.file_names
}

This change slightly improves performance by avoiding the creation of a dict_keys object.

🧰 Tools
🪛 Ruff

70-70: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/models/train.py (5)

84-89: LGTM: Flexible dataset class selection implemented

The new logic for selecting the dataset class based on the polars parameter enhances flexibility. The added logging statement improves transparency and aids in debugging.

Consider simplifying the dataset class selection logic using a dictionary mapping:

dataset_classes = {
    RunMode.imputation: ImputationPandasDataset,
    RunMode.classification: PredictionPolarsDataset if polars else PredictionPandasDataset,
    RunMode.regression: PredictionPolarsDataset if polars else PredictionPandasDataset,
}
dataset_class = dataset_classes[mode]

This approach would be more scalable if additional modes are added in the future.

Also applies to: 92-92


95-98: LGTM: DataLoader updates for performance and polars compatibility

The addition of the persistent_workers parameter to the DataLoader instantiation could potentially improve data loading performance. The commented-out code for converting datasets to pandas is no longer necessary due to the transition to polars.

Consider removing the commented-out code (lines 96-98) to improve code cleanliness. If this code needs to be preserved for reference, consider moving it to a separate document or adding a more detailed comment explaining why it's being kept.

Also applies to: 119-119, 128-128


155-155: LGTM: Trainer configuration improved

The changes to the Trainer instantiation improve the training process:

  1. Setting min_epochs=1 ensures at least one epoch of training is performed.
  2. Adding num_sanity_val_steps=2 helps catch errors in the validation loop before training begins.

Consider adding a comment explaining the rationale behind changing min_epochs from 0 to 1, as this might not be immediately obvious to other developers.

Also applies to: 164-164


181-181: LGTM: Test dataset handling improved and data persistence added

The changes to test dataset creation and evaluation are consistent with earlier modifications:

  1. Addition of ram_cache parameter to test dataset creation.
  2. Inclusion of persistent_workers in the test DataLoader.

The new call to persist_data(trainer, log_dir) suggests additional data persistence functionality has been implemented.

Please add documentation (preferably a docstring) for the new persist_data function to explain its purpose, parameters, and return value (if any).

Also applies to: 192-192, 200-200


205-223: LGTM: New persist_data function for SHAP value persistence

The new persist_data function is a valuable addition for persisting SHAP values:

  • It handles both test and train SHAP values.
  • Uses polars DataFrame for efficient data handling.
  • Implements proper error handling and logging.

Consider the following improvements:

  1. Add a docstring explaining the function's purpose, parameters, and return value (if any).
  2. The commented-out code (lines 209-212) can be removed if it's no longer needed.
  3. Consider using a context manager (with) when writing parquet files to ensure proper file handling.
  4. The error message in the except block could be more specific, e.g., include which operation failed (test or train SHAP values).

Example improvement for point 3:

with (log_dir / "shap_values_test.parquet").open("wb") as f:
    shaps_test.write_parquet(f)
icu_benchmarks/run_utils.py (4)

58-59: LGTM: New command-line arguments enhance script flexibility.

The addition of --modalities and --label arguments improves the script's versatility by allowing users to specify evaluation parameters. The implementation is consistent with the existing code style.

Consider adding help text for the --modalities argument to clarify its purpose and expected input format, similar to the --label argument:

-    parser.add_argument("-mo", "--modalities", nargs="+", help="Modalities to use for evaluation.")
+    parser.add_argument("-mo", "--modalities", nargs="+", help="Modalities to use for evaluation. Specify multiple modalities separated by spaces.")

80-85: LGTM: Improved directory creation logic.

The new implementation ensures unique directory names and handles potential naming conflicts effectively. The added checks for directory existence before creation are a good practice.

Consider simplifying the logic slightly by using a loop to handle potential conflicts:

-    if not log_dir_run.exists():
-        log_dir_run.mkdir(parents=True)
-    else:
-        # Directory clash at last moment
-        log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f"))
-        log_dir_run.mkdir(parents=True)
+    while log_dir_run.exists():
+        log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f"))
+    log_dir_run.mkdir(parents=True)

This approach reduces code duplication and handles multiple potential conflicts in a single loop.


Line range hint 112-137: LGTM: Added SHAP values aggregation.

The new implementation effectively aggregates SHAP values from test datasets using polars for efficient parquet file handling. The aggregated values are correctly saved to a new parquet file.

Consider adding error handling for the parquet file operations:

     if shap_values_test:
-        shap_values = pl.concat(shap_values_test)
-        shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet")
+        try:
+            shap_values = pl.concat(shap_values_test)
+            shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet")
+        except Exception as e:
+            logging.error(f"Error aggregating or writing SHAP values: {e}")

This addition will help catch and log any potential issues with the parquet operations.


258-269: LGTM: New function for retrieving configuration files.

The get_config_files function effectively retrieves and processes configuration files from the specified directories. It correctly filters out "common" entries and logs the found tasks and models.

Consider adding error handling and using Path objects consistently for better cross-platform compatibility:

 def get_config_files(config_dir: Path):
-    tasks = glob.glob(os.path.join(config_dir / "tasks", "*"))
-    models = glob.glob(os.path.join(config_dir / "prediction_models", "*"))
-    tasks = [os.path.splitext(os.path.basename(task))[0] for task in tasks]
-    models = [os.path.splitext(os.path.basename(model))[0] for model in models]
+    try:
+        tasks = list((config_dir / "tasks").glob("*"))
+        models = list((config_dir / "prediction_models").glob("*"))
+        tasks = [task.stem for task in tasks if task.is_file()]
+        models = [model.stem for model in models if model.is_file()]
+    except Exception as e:
+        logging.error(f"Error retrieving config files: {e}")
+        return [], []
     if "common" in tasks:
         tasks.remove("common")
     if "common" in models:
         models.remove("common")
     logging.info(f"Found tasks: {tasks}")
     logging.info(f"Found models: {models}")
     return tasks, models

This modification improves error handling and uses Path objects consistently, enhancing robustness and cross-platform compatibility.

icu_benchmarks/tuning/hyperparameters.py (5)

5-5: LGTM! Consider grouping related imports.

The new imports for matplotlib and Optuna visualization functions are appropriate for the added visualization capabilities. The plot parameter in the function signature is a good addition for flexibility.

Consider grouping related imports together for better readability. For example, you could move the matplotlib import closer to the Optuna visualization imports.

Also applies to: 18-18, 194-194


282-292: LGTM! Consider consolidating checkpoint loading logic.

The updated checkpoint handling using Optuna's study loading functionality is a significant improvement. The added error handling and logging provide better visibility into the process.

Consider consolidating the checkpoint loading logic into a separate function for better modularity and reusability. This would make the main function cleaner and easier to maintain.

Also applies to: 324-342


244-261: LGTM! Consider refactoring for improved readability.

The updated objective function now handles hyperparameter bounds more flexibly, providing better support for various types and ranges. The logic for determining the suggestion type is comprehensive and robust.

Consider refactoring the nested if-else statements into separate functions for each hyperparameter type. This could improve readability and make the code easier to maintain. For example:

def suggest_int_param(trial, name, value):
    return trial.suggest_int(name, value[0], value[1], log=value[2] == "log" if len(value) == 3 else False)

def suggest_float_param(trial, name, value):
    return trial.suggest_float(name, value[0], value[1], log=value[2] == "log" if len(value) == 3 else False)

def suggest_categorical_param(trial, name, value):
    return trial.suggest_categorical(name, value)

# Then in the objective function:
if isinstance(value[0], int) and isinstance(value[1], int):
    hyperparams[name] = suggest_int_param(trial, name, value)
elif isinstance(value[0], (int, float)) and isinstance(value[1], (int, float)):
    hyperparams[name] = suggest_float_param(trial, name, value)
else:
    hyperparams[name] = suggest_categorical_param(trial, name, value)

This approach would make the code more modular and easier to extend in the future.


343-365: LGTM! Robust Optuna study implementation with W&B integration.

The Optuna study creation and optimization logic is well-implemented, supporting resumable studies and integrating with Weights & Biases for enhanced tracking. The code handles both new studies and resuming from checkpoints effectively.

Consider adding a comment explaining the gc_after_trial=True parameter in the study.optimize call. This parameter is important for memory management, especially for long-running optimizations, and a brief explanation would be helpful for future maintainers.


374-383: LGTM! Valuable visualization additions with proper error handling.

The new visualization code for plotting hyperparameter importances and optimization history provides valuable insights into the optimization process. The error handling ensures that the main process continues even if plotting fails, which is a good practice.

Consider adding a configuration option to specify the output format of the plots (e.g., PNG, PDF, SVG). This would provide more flexibility for users who might prefer different formats for their documentation or analysis. You could add a parameter like plot_format: str = 'png' to the function and use it in the savefig calls.

icu_benchmarks/data/loader.py (5)

15-80: Good implementation of CommonPolarsDataset using Polars

The transition to Polars for data manipulation is a good choice for improving performance with large datasets. The class structure and methods are well-implemented.

However, there are a few minor improvements that can be made:

  1. Remove commented-out code to improve readability. For example, lines 33-35 can be deleted.
  2. Consider adding type hints to method return values for better code documentation.

Here's an example of adding type hints:

- def get_feature_names(self):
+ def get_feature_names(self) -> List[str]:
     return self.features_df.columns

Line range hint 185-243: Consider deprecating CommonPandasDataset in favor of CommonPolarsDataset

The CommonPandasDataset class has been kept, likely for backward compatibility. However, as the codebase is transitioning to Polars, it might be beneficial to encourage migration to the Polars version.

Suggestions for improvement:

  1. Add a deprecation warning to the class to encourage users to migrate to CommonPolarsDataset. For example:
import warnings

@gin.configurable("CommonPandasDataset")
class CommonPandasDataset(Dataset):
    def __init__(self, *args, **kwargs):
        warnings.warn("CommonPandasDataset is deprecated. Use CommonPolarsDataset instead.", DeprecationWarning, stacklevel=2)
        super().__init__(*args, **kwargs)
    # ... rest of the class implementation
  1. Add type hints to method return values for consistency with the Polars version. For example:
- def get_feature_names(self):
+ def get_feature_names(self) -> List[str]:
     return self.features_df.columns
  1. Consider adding a migration guide or helper function to assist users in transitioning from the Pandas version to the Polars version.

Line range hint 244-335: Consider deprecating PredictionPandasDataset in favor of PredictionPolarsDataset

The PredictionPandasDataset class has been kept, likely for backward compatibility. However, as the codebase is transitioning to Polars, it might be beneficial to encourage migration to the Polars version.

Suggestions for improvement:

  1. Add a deprecation warning to the class to encourage users to migrate to PredictionPolarsDataset. For example:
import warnings

@gin.configurable("PredictionPandasDataset")
class PredictionPandasDataset(CommonPandasDataset):
    def __init__(self, *args, **kwargs):
        warnings.warn("PredictionPandasDataset is deprecated. Use PredictionPolarsDataset instead.", DeprecationWarning, stacklevel=2)
        super().__init__(*args, **kwargs)
    # ... rest of the class implementation
  1. Add type hints to method return values for consistency with the Polars version. For example:
- def get_balance(self) -> list:
+ def get_balance(self) -> List[float]:
  1. Consider adding a migration guide or helper function to assist users in transitioning from the Pandas version to the Polars version.

  2. There's an inconsistency in the get_balance method. The commented-out line 308 should be removed, and the return statement should be simplified:

- # weights = list((1 / counts) * np.sum(counts) / counts.shape[0])
- return list((1 / counts) * np.sum(counts) / counts.shape[0])
+ return ((1 / counts) * np.sum(counts) / counts.shape[0]).tolist()

Line range hint 336-465: Update ImputationPandasDataset and ImputationPredictionDataset to use Polars

The ImputationPandasDataset and ImputationPredictionDataset classes have not been updated to use Polars, which is inconsistent with the rest of the file.

Suggestions for improvement:

  1. Create Polars versions of these classes for consistency with the rest of the file. This would involve rewriting the data manipulation logic using Polars operations.

  2. Add deprecation warnings to the existing Pandas-based classes to encourage migration to the Polars versions once they are implemented. For example:

import warnings

@gin.configurable("ImputationPandasDataset")
class ImputationPandasDataset(CommonPandasDataset):
    def __init__(self, *args, **kwargs):
        warnings.warn("ImputationPandasDataset is deprecated. Use ImputationPolarsDataset instead.", DeprecationWarning, stacklevel=2)
        super().__init__(*args, **kwargs)
    # ... rest of the class implementation

@gin.configurable("ImputationPredictionDataset")
class ImputationPredictionDataset(Dataset):
    def __init__(self, *args, **kwargs):
        warnings.warn("ImputationPredictionDataset is deprecated. Use ImputationPredictionPolarsDataset instead.", DeprecationWarning, stacklevel=2)
        super().__init__(*args, **kwargs)
    # ... rest of the class implementation
  1. Add type hints to method return values for consistency with the other classes in the file.

  2. Consider updating the ampute_data function to work with Polars DataFrames, or create a Polars-compatible version of this function.

  3. As part of the transition to Polars, consider creating a base class (e.g., BasePolarsDataset) that implements common functionality for all Polars-based dataset classes. This would help reduce code duplication and ensure consistency across different dataset types.


Line range hint 1-465: Overall positive changes with room for further improvements

The transition to Polars for most dataset classes is a significant improvement that should enhance performance, especially for large datasets. The new CommonPolarsDataset and PredictionPolarsDataset classes are well-implemented and make good use of Polars' features.

Key points and recommendations:

  1. The transition to Polars is a positive change that should improve performance and maintainability.

  2. Complete the transition to Polars by updating ImputationPandasDataset and ImputationPredictionDataset to use Polars.

  3. Add deprecation warnings to all Pandas-based classes to encourage migration to their Polars counterparts.

  4. Improve code documentation by adding type hints consistently across all classes and methods.

  5. Consider creating a base Polars dataset class to reduce code duplication and ensure consistency across different dataset types.

  6. Remove any remaining commented-out code to improve readability.

By addressing these points, the code will be more consistent, maintainable, and future-proof. Great work on the transition to Polars, and keep up the momentum to fully leverage its benefits across all dataset classes.

icu_benchmarks/data/split_process_data.py (4)

3-4: LGTM! Transition to Polars and new functionality additions.

The changes in imports and function signature reflect the transition to Polars and the addition of new functionality for modality selection and data exclusion. This should improve performance and flexibility.

Consider using type hints for the new parameters modality_mapping and selected_modalities to improve code readability and maintainability.

Also applies to: 9-9, 12-12, 14-14, 23-23, 26-27, 40-42


74-80: Improved data handling and preprocessing.

The changes enhance label handling flexibility, introduce data sanitization, and provide better logging for data quality issues. The transition to Polars is consistently implemented.

Consider adding error handling for the case when no label is selected after filtering. For example:

if not vars[Var.label]:
    raise ValueError("No label selected after filtering.")

Also applies to: 118-134, 156-177


190-238: New data preprocessing functions: Approved.

The new functions check_sanitize_data and modality_selection add important steps for data cleaning and filtering. They are well-implemented and should improve data quality.

In the modality_selection function, consider adding error handling for the case when none of the selected modalities are found in the modality mapping. For example:

if not any(col in modality_mapping.keys() for col in selected_modalities):
    raise ValueError("None of the selected modalities found in modality mapping.")
🧰 Tools
🪛 Ruff

192-192: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


193-193: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


195-195: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


199-199: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


203-203: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


218-218: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


234-234: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


Line range hint 240-442: Data splitting functions updated for Polars support: Approved.

The make_train_val and make_single_split functions have been successfully adapted to work with both Polars and Pandas DataFrames. This maintains backwards compatibility while allowing for the performance benefits of Polars.

Consider refactoring the common logic for Polars and Pandas into separate helper functions to reduce code duplication. For example:

def _get_stays(data, id, polars):
    return pl.Series(name=id, values=data[Segment.outcome][id].unique()) if polars else pd.Series(data[Segment.outcome][id].unique(), name=id)

def _get_labels(data, id, vars, polars):
    if polars:
        return data[Segment.outcome].group_by(id).max()[vars[Var.label]]
    else:
        return data[Segment.outcome].groupby(id).max()[vars[Var.label]].reset_index(drop=True)

# Use these helper functions in both make_train_val and make_single_split

This refactoring would improve maintainability and reduce the risk of inconsistencies between the two functions.

🧰 Tools
🪛 Ruff

303-303: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


311-311: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


315-315: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/models/wrappers.py (3)

6-6: Unused imports and configurations detected.

The newly added imports average_precision_score and roc_auc_score from sklearn.metrics, along with their corresponding gin configurations, are not currently used in the code. Consider removing them if they are not needed, or implement their usage if they are intended to be used in the future.

Also applies to: 33-34


376-379: Commented-out code for new metrics.

There's commented-out code related to average_precision_score and roc_auc_score. If these metrics are intended to be used, consider implementing them. Otherwise, it might be better to remove the commented code to improve readability.


Line range hint 562-589: Approved: Good improvements to the ImputationWrapper class.

The new init_weights method provides a flexible way to initialize weights using different strategies. The changes in on_fit_start and step_fn methods enhance the functionality of the imputation wrapper.

As a minor suggestion, consider adding type hints to the init_weights method parameters for consistency with the rest of the codebase:

def init_weights(self, init_type: ImputationInit = ImputationInit.NORMAL, gain: float = 0.02):

Also applies to: 591-595, 597-611

icu_benchmarks/data/preprocessor.py (6)

77-143: LGTM! Comprehensive implementation with good error handling.

The apply method effectively handles both static and dynamic features, with proper checks and integration. The logging statements are helpful for debugging.

Consider extracting the repeated code for joining static and dynamic data into a separate method to improve readability and maintainability.

Here's a suggestion for extracting the join operation:

def _join_static_dynamic(self, data, vars):
    for split in [Split.train, Split.val, Split.test]:
        data[split][Segment.dynamic] = data[split][Segment.dynamic].join(
            data[split][Segment.static], on=vars["GROUP"]
        )
    return data

You can then call this method in the apply function instead of repeating the join operation for each split.


145-198: LGTM! Comprehensive preprocessing for both static and dynamic features.

The _process_static and _process_dynamic methods provide a thorough approach to data preprocessing. The use of the Recipe class allows for a clear and modular implementation of various preprocessing steps.

Consider adding error handling for potential exceptions that might occur during the preprocessing steps, especially for the apply_recipe_to_splits function calls.

Here's a suggestion for adding error handling:

try:
    data = apply_recipe_to_splits(sta_rec, data, Segment.static, self.save_cache, self.load_cache)
except Exception as e:
    logging.error(f"Error occurred while processing static features: {str(e)}")
    raise

Apply similar error handling to the dynamic feature processing as well.


215-276: LGTM! Well-structured regression preprocessor.

The PolarsRegressionPreprocessor effectively extends the classification preprocessor for regression tasks. The outcome processing is flexible, allowing for predefined ranges or automatic scaling.

Consider moving the outcome processing logic to a separate method to improve code organization and reusability.

Here's a suggestion for extracting the outcome processing logic:

def _scale_outcome(self, outcome_data, vars):
    outcome_rec = Recipe(outcome_data, vars["LABEL"], [], vars["GROUP"])
    if self.outcome_max is not None and self.outcome_min is not None:
        outcome_rec.add_step(
            StepSklearn(
                sklearn_transformer=FunctionTransformer(
                    func=lambda x: ((x - self.outcome_min) / (self.outcome_max - self.outcome_min))
                ),
                sel=all_outcomes(),
            )
        )
    else:
        outcome_rec.add_step(StepSklearn(MinMaxScaler(), sel=all_outcomes()))
    return outcome_rec.prep().bake()

You can then call this method in the _process_outcome function.


Line range hint 279-475: LGTM! Consistent implementation for Pandas-based preprocessing.

The Pandas-based preprocessor classes maintain consistency with their Polars counterparts, which is excellent for code maintainability and understanding.

Consider creating a base preprocessor class that contains common functionality for both Polars and Pandas implementations. This could help reduce code duplication and make it easier to maintain both versions in the future.

Here's a high-level suggestion for creating a base preprocessor:

class BasePreprocessor(abc.ABC):
    @abc.abstractmethod
    def apply(self, data, vars):
        pass

    @abc.abstractmethod
    def _process_static(self, data, vars):
        pass

    @abc.abstractmethod
    def _process_dynamic(self, data, vars):
        pass

    # Common methods and properties can be defined here

Then, both Polars and Pandas preprocessors can inherit from this base class:

class PolarsClassificationPreprocessor(BasePreprocessor):
    # Polars-specific implementation

class PandasClassificationPreprocessor(BasePreprocessor):
    # Pandas-specific implementation

This structure would help centralize common logic and make it easier to maintain both implementations.


Line range hint 476-536: LGTM! Focused implementation for imputation preprocessing.

The PandasImputationPreprocessor class provides a targeted approach for imputation tasks. The options for scaling and using static features offer flexibility in preprocessing.

However, the current implementation of _process_dynamic_data removes entire stays (groups of data) if any missing values are found. This approach might lead to significant data loss in some scenarios.

Consider implementing a more granular approach to handling missing values. For example:

  1. Allow for a threshold of missing values before removing a stay.
  2. Implement different strategies for handling missing values based on the nature of the data (e.g., forward fill for time series data).

Here's a suggestion for a more flexible approach:

def _process_dynamic_data(self, data, vars):
    if self.filter_missing_values:
        missing_ratio = data[Segment.dynamic][vars[Segment.dynamic]].isna().mean(axis=1)
        ids_to_remove = data[Segment.dynamic].loc[missing_ratio > self.missing_threshold][vars["GROUP"]].unique()
        data = {table_name: table.loc[~table[vars["GROUP"]].isin(ids_to_remove)] for table_name, table in data.items()}
        logging.info(f"Removed {len(ids_to_remove)} stays with more than {self.missing_threshold*100}% missing values.")
    return data

This approach allows you to set a threshold for the acceptable percentage of missing values before removing a stay.


Line range hint 537-589: LGTM! Robust implementation of recipe application and caching.

The utility functions provide a flexible approach to applying recipes across different data splits and managing recipe caching. The support for both Polars and Pandas DataFrames is a nice touch for compatibility.

For consistency, consider adding similar error handling to the cache_recipe function as you have in restore_recipe.

Here's a suggestion for adding error handling to cache_recipe:

def cache_recipe(recipe: Recipe, cache_file: str) -> None:
    """Cache recipe to make it available for e.g. transfer learning."""
    try:
        recipe_cache = copy.deepcopy(recipe)
        recipe_cache.cache()
        if not (cache_file / "..").exists():
            (cache_file / "..").mkdir(parents=True, exist_ok=True)
        with open(cache_file, "wb") as f:
            pickle.dump(recipe_cache, f, pickle.HIGHEST_PROTOCOL)
        logging.info(f"Cached recipe in {cache_file}.")
    except Exception as e:
        logging.error(f"Error occurred while caching recipe: {str(e)}")
        raise

This change adds error handling and ensures the parent directory is created if it doesn't exist.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 860bf3a and 4a4ab0c.

📒 Files selected for processing (16)
  • icu_benchmarks/cross_validation.py (5 hunks)
  • icu_benchmarks/data/loader.py (5 hunks)
  • icu_benchmarks/data/pooling.py (5 hunks)
  • icu_benchmarks/data/preprocessor.py (8 hunks)
  • icu_benchmarks/data/split_process_data.py (11 hunks)
  • icu_benchmarks/models/constants.py (2 hunks)
  • icu_benchmarks/models/custom_metrics.py (2 hunks)
  • icu_benchmarks/models/ml_models/catboost.py (1 hunks)
  • icu_benchmarks/models/ml_models/imblearn.py (1 hunks)
  • icu_benchmarks/models/ml_models/xgboost.py (1 hunks)
  • icu_benchmarks/models/train.py (10 hunks)
  • icu_benchmarks/models/utils.py (2 hunks)
  • icu_benchmarks/models/wrappers.py (10 hunks)
  • icu_benchmarks/run.py (7 hunks)
  • icu_benchmarks/run_utils.py (6 hunks)
  • icu_benchmarks/tuning/hyperparameters.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
  • icu_benchmarks/cross_validation.py
  • icu_benchmarks/models/constants.py
  • icu_benchmarks/models/custom_metrics.py
  • icu_benchmarks/models/ml_models/imblearn.py
  • icu_benchmarks/models/utils.py
🧰 Additional context used
🪛 Ruff
icu_benchmarks/data/pooling.py

70-70: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


106-106: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


108-108: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

icu_benchmarks/data/split_process_data.py

26-26: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


41-41: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


119-119: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


124-124: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


124-124: Test for membership should be not in

Convert to not in

(E713)


192-192: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


193-193: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


195-195: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


199-199: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


203-203: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


218-218: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


234-234: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


303-303: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


311-311: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


315-315: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


409-412: Use ternary operator dev_stays = stays[dev] if polars else stays.iloc[dev] instead of if-else-block

Replace if-else-block with dev_stays = stays[dev] if polars else stays.iloc[dev]

(SIM108)


428-428: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


436-436: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


440-440: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/models/ml_models/catboost.py

14-14: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)

icu_benchmarks/models/ml_models/xgboost.py

23-23: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


63-63: f-string without any placeholders

Remove extraneous f prefix

(F541)

icu_benchmarks/models/wrappers.py

474-474: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

🔇 Additional comments (22)
icu_benchmarks/models/ml_models/catboost.py (2)

7-9: LGTM: Class definition and attributes

The class definition looks good. The use of @gin.configurable allows for easy configuration, and the _supported_run_modes attribute correctly specifies the supported task type.


1-19: Overall impression: Good implementation with room for minor improvements

The CBClassifier class is well-structured and implements the CatBoost classifier effectively. The use of gin for configuration and inheritance from MLWrapper shows good design practices. There are a few minor improvements suggested throughout the review, including:

  1. Fixing a typo in the import statement.
  2. Making the task type configurable in the constructor.
  3. Enhancing the docstring for the predict method.
  4. Refactoring to address the star-arg unpacking warning.

Implementing these suggestions will further improve the code's clarity, flexibility, and maintainability. Overall, this is a solid addition to the project.

🧰 Tools
🪛 Ruff

14-14: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)

icu_benchmarks/models/ml_models/xgboost.py (3)

18-24: LGTM! Well-structured class definition.

The class definition and initialization look good:

  • The @gin.configurable decorator allows for flexible configuration.
  • The _supported_run_modes class attribute clearly defines the intended use.
  • The __init__ method is concise and uses set_model_args for model initialization.
🧰 Tools
🪛 Ruff

23-23: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


26-28: LGTM! Concise predict method.

The predict method is simple and correctly uses predict_proba for classification tasks, returning probability predictions.


1-86: Overall assessment: Well-implemented XGBoost classifier with room for minor improvements.

The XGBClassifier implementation is solid and includes essential methods for training, prediction, and evaluation. It demonstrates good practices such as:

  • Use of gin for configuration
  • Integration with WandB for experiment tracking
  • Computation of SHAP values for model interpretability
  • Flexible model argument setting

To further improve the implementation, consider addressing the following:

  1. Organize imports according to PEP 8 guidelines
  2. Improve robustness of the set_model_args method
  3. Add error handling to get_feature_importance
  4. Refactor the test_step method for better readability
  5. Address the static analysis warnings

These improvements will enhance the overall quality, readability, and robustness of the code.

🧰 Tools
🪛 Ruff

23-23: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


63-63: f-string without any placeholders

Remove extraneous f prefix

(F541)

icu_benchmarks/run.py (4)

21-21: LGTM: New import added for get_config_files

The addition of get_config_files to the imports is appropriate. This function will likely be used to retrieve configuration files for tasks and models, which aligns with the changes described in the AI-generated summary.


53-53: LGTM: New modalities argument added to main function

The addition of the modalities argument to the main function enhances the flexibility of the script, allowing users to specify which data modalities to use in the analysis. This change is consistent with the improvements mentioned in the AI-generated summary.


84-86: LGTM: Improved formatting for wandb config update

The changes to the update_wandb_config call enhance readability by breaking the long line into multiple lines. The use of a ternary operator to handle the case where pretrained_imputation_model is None is a concise and effective approach.


145-148: LGTM: Switched to keyword arguments for improved clarity

The modification of the choose_and_bind_hyperparameters_optuna function call to use keyword arguments instead of positional arguments is a positive change. This improves code readability, makes the function call more explicit, and reduces the likelihood of errors when modifying the call in the future. It's consistent with best practices for function calls with multiple arguments.

icu_benchmarks/data/pooling.py (3)

19-29: LGTM: Improved method signature formatting

The __init__ method signature has been reformatted for better readability, with parameters aligned vertically. This change enhances code clarity without affecting functionality.


51-55: LGTM: Improved method signature formatting

The generate method signature has been reformatted for better readability, with parameters aligned vertically. This change enhances code clarity without affecting functionality.


148-150: LGTM: Improved stratification in train-test split

The modification to use stratify=labels in the train_test_split function for classification mode is a good improvement. This ensures that the class distribution is roughly the same in both the train and test sets, which is crucial for maintaining the integrity of the model evaluation process.

icu_benchmarks/models/train.py (3)

6-6: LGTM: Imports updated for polars integration

The addition of polars import and the inclusion of PredictionPolarsDataset are consistent with the transition to using polars for data handling. These changes lay the groundwork for the polars-based data processing in the rest of the file.

Also applies to: 14-14


30-30: Verify impact of parameter changes in train_common

The function signature has been updated to reflect the transition to polars and add flexibility:

  1. Data type hint changed from pd.DataFrame to pl.DataFrame.
  2. New polars parameter added with default True.

However, there are significant changes to default values:

  • batch_size default changed from 64 to 1
  • epochs default changed from 1000 to 100

These changes could significantly impact the training process and model performance.

Please clarify the rationale behind these changes and confirm that they have been thoroughly tested. Consider adding comments in the code to explain these default value changes.

Also applies to: 41-42, 54-54


134-134: LGTM: Model loading and instantiation updated for CPU flexibility

The addition of the cpu parameter to both the load_model call and model instantiation enhances flexibility, allowing for easier CPU-based operations when needed. This change is consistent with the cpu parameter in the train_common function signature.

Also applies to: 136-136

icu_benchmarks/run_utils.py (2)

18-20: LGTM: New imports added for enhanced functionality.

The addition of os, glob, and polars imports is appropriate for the new features implemented in this file. These libraries are used effectively in the new get_config_files function and the updated aggregate_results function.


Line range hint 1-269: Overall assessment: Positive improvements to run_utils.py

The changes made to run_utils.py enhance its functionality and flexibility in several ways:

  1. New imports support additional features.
  2. Additional command-line arguments allow for more specific evaluations.
  3. Improved directory creation logic ensures unique run directories.
  4. SHAP values aggregation adds valuable analytical capabilities.
  5. The new get_config_files function streamlines configuration file management.

These modifications collectively improve the script's versatility and robustness. The implementation is consistent with the existing code style, and the changes are well-integrated into the file's structure.

icu_benchmarks/tuning/hyperparameters.py (2)

Line range hint 266-271: LGTM! Improved callback integration with Optuna.

The updated tuning callback function now works seamlessly with Optuna's study and trial objects. The logging improvements provide better visibility into the tuning process, which is valuable for monitoring and debugging.

The integration with Weights & Biases (wandb) for logging the optimization iteration is a good addition for those using wandb for experiment tracking.


Line range hint 1-383: Overall, excellent implementation of Optuna for hyperparameter optimization!

The changes in this file significantly enhance the hyperparameter tuning capabilities of the project. The integration of Optuna provides a more flexible and powerful optimization framework, while the added visualization features offer valuable insights into the tuning process. The code is well-structured, with appropriate error handling and logging throughout.

Key improvements include:

  1. Flexible handling of various hyperparameter types and ranges
  2. Robust checkpoint management for resumable studies
  3. Integration with Weights & Biases for enhanced tracking
  4. Visualization of hyperparameter importances and optimization history

These changes will greatly benefit the project by enabling more efficient and insightful hyperparameter tuning processes.

icu_benchmarks/data/split_process_data.py (1)

Line range hint 1-442: Overall assessment: Significant improvements with minor suggestions for refinement.

The transition to Polars and the addition of new functionality for modality selection and data sanitization are well-implemented. These changes should lead to improved performance and more flexible data handling. The suggested refinements, including error handling improvements, refactoring to reduce code duplication, and addressing static analysis warnings, will further enhance the code quality and maintainability.

🧰 Tools
🪛 Ruff

192-192: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


193-193: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


195-195: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


199-199: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


203-203: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


218-218: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


234-234: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/models/wrappers.py (1)

Line range hint 1-624: Summary of review findings

  1. The changes to the wrapper classes, especially the ImputationWrapper, generally improve functionality and flexibility.
  2. There are several instances of unused imports, parameters, and attributes that should be addressed.
  3. The MPS compatibility type casting in the log_metrics method needs review.
  4. Minor improvements can be made for performance and code consistency.

Overall, the changes are positive, but addressing the identified issues will further enhance the code quality and maintainability.

🧰 Tools
🪛 Ruff

474-474: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/data/preprocessor.py (1)

Line range hint 9-75: LGTM! Polars integration looks promising.

The introduction of Polars and the new PolarsClassificationPreprocessor class is a positive change. Polars is known for its performance benefits in data processing tasks. The flexible initialization parameters allow for customized preprocessing, which is great for adaptability to different datasets and requirements.

icu_benchmarks/models/ml_models/catboost.py Outdated Show resolved Hide resolved
icu_benchmarks/models/ml_models/catboost.py Outdated Show resolved Hide resolved
icu_benchmarks/models/ml_models/catboost.py Outdated Show resolved Hide resolved
icu_benchmarks/models/ml_models/xgboost.py Outdated Show resolved Hide resolved
Comment on lines 83 to 182
@gin.configurable("PredictionPolarsDataset")
class PredictionPolarsDataset(CommonPolarsDataset):
"""Subclass of common dataset for prediction tasks.

Args:
ram_cache (bool, optional): Whether the complete dataset should be stored in ram. Defaults to True.
"""

def __init__(self, *args, ram_cache: bool = True, **kwargs):
super().__init__(*args, **kwargs)
self.outcome_df = self.grouping_df
self.ram_cache(ram_cache)

def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
"""Function to sample from the data split of choice. Used for deep learning implementations.

Args:
idx: A specific row index to sample.

Returns:
A sample from the data, consisting of data, labels and padding mask.
"""
if self._cached_dataset is not None:
return self._cached_dataset[idx]

pad_value = 0.0
# stay_id = self.outcome_df.index.unique()[idx] # [self.vars["GROUP"]]
stay_id = self.outcome_df[self.vars["GROUP"]].unique()[idx] # [self.vars["GROUP"]]

# slice to make sure to always return a DF
# window = self.features_df.loc[stay_id:stay_id].to_numpy()
# labels = self.outcome_df.loc[stay_id:stay_id][self.vars["LABEL"]].to_numpy(dtype=float)
window = self.features_df.filter(pl.col(self.vars["GROUP"]) == stay_id).to_numpy()
labels = self.outcome_df.filter(pl.col(self.vars["GROUP"]) == stay_id)[self.vars["LABEL"]].to_numpy().astype(float)

if len(labels) == 1:
# only one label per stay, align with window
labels = np.concatenate([np.empty(window.shape[0] - 1) * np.nan, labels], axis=0)

length_diff = self.maxlen - window.shape[0]
pad_mask = np.ones(window.shape[0])

# Padding the array to fulfill size requirement
if length_diff > 0:
# window shorter than the longest window in dataset, pad to same length
window = np.concatenate([window, np.ones((length_diff, window.shape[1])) * pad_value], axis=0)
labels = np.concatenate([labels, np.ones(length_diff) * pad_value], axis=0)
pad_mask = np.concatenate([pad_mask, np.zeros(length_diff)], axis=0)

not_labeled = np.argwhere(np.isnan(labels))
if len(not_labeled) > 0:
labels[not_labeled] = -1
pad_mask[not_labeled] = 0

pad_mask = pad_mask.astype(bool)
labels = labels.astype(np.float32)
data = window.astype(np.float32)

return from_numpy(data), from_numpy(labels), from_numpy(pad_mask)

def get_balance(self) -> list:
"""Return the weight balance for the split of interest.

Returns:
Weights for each label.
"""
counts = self.outcome_df[self.vars["LABEL"]].value_counts(parallel=True).get_columns()[1]
counts = counts.to_numpy()
weights = list((1 / counts) * np.sum(counts) / counts.shape[0])
return weights

def get_data_and_labels(self) -> Tuple[np.array, np.array, np.array]:
"""Function to return all the data and labels aligned at once.

We use this function for the ML methods which don't require an iterator.

Returns:
A Tuple containing data points and label for the split.
"""
labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(float)
rep = self.features_df

if len(labels) == self.num_stays:
# order of groups could be random, we make sure not to change it
# rep = rep.groupby(level=self.vars["GROUP"], sort=False).last()
rep = rep.group_by(self.vars["GROUP"]).last()
else:
# Adding segment count for each stay id and timestep.
rep = rep.with_columns(pl.col(self.vars["GROUP"]).cum_count().over(self.vars["GROUP"]).alias("counter"))
rep = rep.to_numpy().astype(float)
logging.debug(f"rep shape: {rep.shape}")
logging.debug(f"labels shape: {labels.shape}")
return rep, labels, self.row_indicators.to_numpy()

def to_tensor(self):
data, labels, row_indicators = self.get_data_and_labels()
if self.mps:
return from_numpy(data).to(float32), from_numpy(labels).to(float32)
else:
return from_numpy(data), from_numpy(labels), row_indicators
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Well-implemented PredictionPolarsDataset with Polars integration

The PredictionPolarsDataset class is well-implemented and consistently uses Polars for data manipulation. The methods are appropriately adapted for prediction tasks.

Here are some suggestions for improvement:

  1. Add type hints to method return values for better code documentation. For example:
- def get_data_and_labels(self):
+ def get_data_and_labels(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  1. In the get_balance method, consider adding a comment explaining the weight calculation formula for better understanding.

  2. In the __getitem__ method, consider using Polars' built-in methods for performance optimization. For example, replace the numpy operations with Polars operations where possible.

Here's an example of how you might optimize the __getitem__ method:

def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
    if self._cached_dataset is not None:
        return self._cached_dataset[idx]

    pad_value = 0.0
    stay_id = self.outcome_df[self.vars["GROUP"]].unique()[idx]

    window = self.features_df.filter(pl.col(self.vars["GROUP"]) == stay_id)
    labels = self.outcome_df.filter(pl.col(self.vars["GROUP"]) == stay_id)[self.vars["LABEL"]]

    if len(labels) == 1:
        labels = pl.concat([pl.Series([None] * (window.height - 1)), labels])

    length_diff = self.maxlen - window.height
    pad_mask = pl.Series([1] * window.height)

    if length_diff > 0:
        window = window.vstack(pl.DataFrame({col: [pad_value] * length_diff for col in window.columns}))
        labels = labels.extend(pl.Series([pad_value] * length_diff))
        pad_mask = pad_mask.extend(pl.Series([0] * length_diff))

    labels = labels.fill_null(-1)
    pad_mask = pad_mask.where(labels != -1, 0)

    return (
        from_numpy(window.to_numpy().astype(np.float32)),
        from_numpy(labels.to_numpy().astype(np.float32)),
        from_numpy(pad_mask.to_numpy().astype(bool))
    )

This version uses more Polars operations and should be more efficient, especially for larger datasets.

icu_benchmarks/data/split_process_data.py Show resolved Hide resolved
icu_benchmarks/models/wrappers.py Show resolved Hide resolved
icu_benchmarks/models/wrappers.py Outdated Show resolved Hide resolved
icu_benchmarks/models/wrappers.py Outdated Show resolved Hide resolved
Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 15

🧹 Outside diff range and nitpick comments (53)
icu_benchmarks/models/ml_models/lgbm.py (2)

11-27: LGTM: LGBMWrapper implementation is solid, with a minor suggestion for improvement.

The LGBMWrapper class provides a robust implementation for fitting LightGBM models. It includes early stopping, logging callbacks, and wandb integration, which are all excellent practices. The method returns the validation loss, which is useful for model evaluation.

Consider adding type hints to improve code readability and maintainability. For example:

def fit_model(self, train_data: np.ndarray, train_labels: np.ndarray, 
              val_data: np.ndarray, val_labels: np.ndarray) -> float:

This would make the expected input and output types clearer to other developers.


30-48: LGTM: LGBMClassifier implementation is well-structured, with a minor suggestion.

The LGBMClassifier class is well-implemented, utilizing gin for configuration and correctly specifying its supported run mode. The predict method is properly documented and returns class probabilities as expected.

To enhance code robustness, consider adding input validation in the predict method. For example:

def predict(self, features: np.ndarray) -> np.ndarray:
    if not isinstance(features, np.ndarray):
        raise TypeError("features must be a numpy array")
    if features.ndim != 2:
        raise ValueError("features must be a 2D array")
    return self.model.predict_proba(features)

This would help catch potential errors earlier in the prediction process.

icu_benchmarks/models/dl_models/tcn.py (2)

12-16: Enhance the class docstring for better documentation.

While the current docstring mentions the adaptation from the original TCN paper, it could be more informative. Consider expanding it to include:

  1. A brief explanation of what a Temporal Convolutional Network does
  2. Key features or advantages of this implementation
  3. Description of important parameters
  4. Example usage, if applicable

This would greatly improve the documentation and make it easier for other developers to understand and use the class.


1-62: Consider adding type hints for improved code quality.

While the implementation is solid, adding type hints to method parameters and return values could further improve code readability and help catch potential type-related issues early in development. This is especially useful for complex classes like TCN.

Example for the __init__ method:

def __init__(self, input_size: Tuple[int, int, int], num_channels: Union[int, List[int]], num_classes: int, *args, max_seq_length: int = 0, kernel_size: int = 2, dropout: float = 0.0, **kwargs) -> None:

And for the forward method:

def forward(self, x: torch.Tensor) -> torch.Tensor:

Consider adding similar type hints throughout the class for improved code quality and maintainability.

🧰 Tools
🪛 Ruff

23-23: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)

icu_benchmarks/models/ml_models/sklearn.py (2)

Line range hint 54-66: Minor inconsistency in SVM classes.

The SVMClassifier and SVMRegressor classes use self.model_args instead of self.set_model_args which is used in all other classes. This might be a typo or an oversight.

Consider updating these lines for consistency:

-        self.model = self.model_args(svm.SVC, *args, **kwargs)
+        self.model = self.set_model_args(svm.SVC, *args, **kwargs)
-        self.model = self.model_args(svm.SVR, *args, **kwargs)
+        self.model = self.set_model_args(svm.SVR, *args, **kwargs)

Line range hint 1-93: Overall assessment: Simplified sklearn.py with potential wider impact.

The changes to this file have significantly simplified it by removing LightGBM-related code and focusing on scikit-learn models. This may improve maintainability and reduce dependencies. However, it's crucial to ensure that:

  1. The removal of LightGBM functionality doesn't negatively impact other parts of the codebase that may have depended on it.
  2. Any documentation or user guides are updated to reflect the removal of LightGBM support.
  3. The minor inconsistency in the SVM classes is addressed for better code uniformity.

These changes align with the PR objectives of incorporating enhancements from the Cassandra project, although the specific connection isn't clear from this file alone. The introduction of new models mentioned in the PR objectives isn't evident in this file, so it may be worth checking other files in the PR for those additions.

Consider updating the module docstring (if it exists) to reflect the current set of supported models and the removal of LightGBM support.

icu_benchmarks/models/dl_models/rnn.py (2)

8-31: LGTM: RNNet class implementation is well-structured.

The RNNet class is well-implemented with appropriate initialization, hidden state management, and forward pass. The use of @gin.configurable allows for flexible configuration.

Consider adding type hints to method parameters and return values for improved code readability and maintainability. For example:

def init_hidden(self, x: torch.Tensor) -> torch.Tensor:
    # ...

def forward(self, x: torch.Tensor) -> torch.Tensor:
    # ...
🧰 Tools
🪛 Ruff

16-16: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


1-85: Summary: Good implementation of RNN models with some areas for improvement

Overall, the implementation of RNNet, LSTMNet, and GRUNet classes is well-structured and follows good practices. However, there are a few areas that need attention:

  1. Add the missing return statement in the GRUNet's forward method to ensure correct functionality.
  2. Refactor the super().init() calls in all classes to avoid star-arg unpacking after keyword arguments.
  3. Consider adding type hints to improve code readability and maintainability.

These changes will enhance the code quality and ensure proper functionality of the RNN models. Once these issues are addressed, the implementation will be robust and ready for use.

To further improve the code, consider the following suggestions:

  1. Implement a base RNN class that encapsulates common functionality, and have RNNet, LSTMNet, and GRUNet inherit from it. This would reduce code duplication and improve maintainability.
  2. Add docstrings to methods, especially forward, to clarify the expected input and output shapes.
  3. Consider adding a method to save and load model weights, which could be useful for model persistence and deployment.
🧰 Tools
🪛 Ruff

16-16: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


42-42: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


69-69: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)

icu_benchmarks/models/dl_models/transformer.py (2)

52-64: Consider using enumerate or range(depth) directly.

The loop control variables i in the initialization of transformer blocks are unused. Consider using enumerate if you need the index, or use range(depth) directly if you don't need it.

Here's a suggested change:

# In Transformer class
self.tblocks = nn.Sequential(*[
    TransformerBlock(
        emb=hidden,
        hidden=hidden,
        heads=heads,
        mask=True,
        ff_hidden_mult=ff_hidden_mult,
        dropout=dropout,
        dropout_att=dropout_att,
    ) for _ in range(depth)
])

# In LocalTransformer class
self.tblocks = nn.Sequential(*[
    LocalBlock(
        emb=hidden,
        hidden=hidden,
        heads=heads,
        mask=True,
        ff_hidden_mult=ff_hidden_mult,
        local_context=local_context,
        dropout=dropout,
        dropout_att=dropout_att,
    ) for _ in range(depth)
])

This change eliminates the unused variable warning and makes the code more concise.

Also applies to: 123-135

🧰 Tools
🪛 Ruff

52-52: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


1-148: Overall assessment: Good implementation with room for improvement

The transformer.py file implements two transformer-based models (Transformer and LocalTransformer) with a clear structure and good use of the gin configuration system. The code is functional and follows the expected architecture for transformer models.

Main points:

  1. Good use of gin for configuration and clear support for classification and regression modes.
  2. Well-structured forward methods implementing the transformer architecture.
  3. Significant code duplication between Transformer and LocalTransformer classes.
  4. Use of *args in __init__ methods, which can lead to unexpected behavior.
  5. Minor issues with unused loop control variables.

Recommendations:

  1. Refactor to introduce a base class for both transformer types to reduce code duplication.
  2. Use configuration objects instead of numerous parameters and *args in __init__ methods.
  3. Address minor issues like unused loop control variables.

These changes would significantly improve code maintainability and readability while preserving the current functionality.

🧰 Tools
🪛 Ruff

37-37: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


52-52: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


106-106: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


123-123: Loop control variable i not used within loop body

Rename unused i to _i

(B007)

icu_benchmarks/models/ml_models/xgboost.py (1)

42-59: LGTM: fit_model method is well-implemented with a minor suggestion.

The fit_model method is correctly implemented with good practices such as early stopping and optional integration with Weights & Biases. The conditional computation of SHAP values is a good approach for performance.

Minor suggestion:
Consider using an f-string for the logging statement on line 48 for consistency with other logging calls in the file:

logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}")
icu_benchmarks/run.py (4)

53-59: LGTM: Added modalities and label binding with proper logging.

The new code block enhances flexibility by allowing modalities and labels to be bound to the gin configuration. The logging statements are helpful for debugging.

Consider adding type checking for modalities and args.label to ensure they are of the expected type before binding. For example:

if modalities and isinstance(modalities, list):
    logging.debug(f"Binding modalities: {modalities}")
    gin.bind_parameter("preprocess.selected_modalities", modalities)
elif modalities:
    logging.warning(f"Unexpected type for modalities: {type(modalities)}. Expected list.")

if args.label and isinstance(args.label, str):
    logging.debug(f"Binding label: {args.label}")
    gin.bind_parameter("preprocess.label", args.label)
elif args.label:
    logging.warning(f"Unexpected type for label: {type(args.label)}. Expected string.")

60-64: LGTM: Added config file retrieval and validation.

The new code block adds an important validation check for the existence of the specified task and model. This is a good practice to catch configuration errors early in the execution.

To improve readability, consider using f-strings for the entire error message and simplifying the condition checks:

tasks, models = get_config_files(Path("configs"))
if task not in tasks or model not in models:
    raise ValueError(
        f"Invalid configuration. "
        f"Task '{task}' {'not ' if task not in tasks else ''}found in tasks. "
        f"Model '{model}' {'not ' if model not in models else ''}found in models."
    )

134-135: LGTM: Added support for different model types based on run mode.

The modification to the model_path assignment adds support for different model types (imputation or prediction) based on the run mode. This change aligns well with the summary mentioning updates to support various model types.

To improve readability, consider using a dictionary for model types:

model_types = {
    RunMode.imputation: "imputation_models",
    RunMode.prediction: "prediction_models"
}
model_path = Path("configs") / model_types[mode] / f"{model}.gin"

This approach would make it easier to add new model types in the future if needed.


191-195: LGTM: Added error handling for aggregate_results function.

The addition of a try-except block for the aggregate_results function call is a good practice. It ensures that any exceptions during result aggregation are properly caught and logged, which aligns with the summary mentioning enhanced error handling.

For consistency with the rest of the code, consider using f-strings for error logging:

try:
    aggregate_results(run_dir, execution_time)
except Exception as e:
    logging.error(f"Failed to aggregate results: {e}")
    logging.debug(f"Error details:", exc_info=True)
icu_benchmarks/models/train.py (5)

30-30: Update function signature and default values

The changes to the train_common function signature reflect the transition to Polars and introduce new configuration options. However, there are a few points to consider:

  1. The reduction in default batch_size (from 64 to 1) and epochs (from 1000 to 100) is significant. This might affect training performance and model convergence. Please ensure this change is intentional and doesn't negatively impact the model's learning process.

  2. The new polars parameter should be documented in the function's docstring, explaining its purpose and impact on the data handling process.

  3. The persistent_workers parameter is introduced but set to None by default. Consider adding documentation about when and why this parameter should be used.

Could you please update the function's docstring to include explanations for the new parameters and the reasoning behind the changes in default values?

Also applies to: 41-42, 54-55


85-91: LGTM: Flexible dataset class selection

The new dataset class selection logic is a good addition, allowing for flexibility between pandas and polars-based datasets. This change supports both backward compatibility and performance optimization.

However, there's a TODO comment that needs attention:

# todo: add support for polars versions of datasets

Would you like assistance in implementing the polars versions of the datasets to fully leverage the benefits of the Polars library?


113-113: LGTM: Addition of persistent_workers to DataLoader

The addition of the persistent_workers parameter to the DataLoader instantiation is a good optimization. This can potentially improve data loading performance, especially for larger datasets, by keeping worker processes alive between data loading iterations.

Consider adding a comment or updating the documentation to explain:

  1. The performance implications of using persistent workers.
  2. Any potential memory considerations when using this feature.
  3. Guidelines on when to enable or disable this feature based on dataset size or system resources.

Also applies to: 121-121


148-148: LGTM: Improved Trainer configuration

The changes to the Trainer configuration are beneficial:

  1. Setting min_epochs=1 ensures that at least one epoch of training occurs, which is crucial for obtaining results.
  2. Setting num_sanity_val_steps=2 helps catch potential errors in the validation loop before the actual training begins.

These additions provide better control and error checking in the training process.

Consider adding comments to explain:

  1. Why a minimum of 1 epoch is necessary.
  2. The reasoning behind choosing 2 sanity validation steps and what kinds of errors this might catch.

Also applies to: 157-157


198-220: LGTM: New function to persist SHAP values

The addition of the persist_shap_data function is a valuable enhancement to the model's interpretability features. Saving SHAP values to parquet files for both test and train datasets will allow for efficient storage and retrieval of this important information.

Positive aspects:

  1. Use of parquet files for efficient storage.
  2. Separate handling for test and train SHAP values.
  3. Inclusion of error handling.

Consider enhancing the error logging in the except block:

except Exception as e:
    logging.error(f"Failed to save SHAP values: {e}", exc_info=True)

This will provide more detailed stack trace information, which can be helpful for debugging.

icu_benchmarks/run_utils.py (5)

18-20: Consider removing unused imports

The os and glob modules are imported but not used in the visible changes. If they are not used elsewhere in the file, consider removing these imports to keep the code clean.

If these imports are indeed unused, you can remove them:

-import os
-import glob
import polars as pl
🧰 Tools
🪛 Ruff

18-18: os imported but unused

Remove unused import: os

(F401)


19-19: glob imported but unused

Remove unused import: glob

(F401)


58-64: LGTM! Consider clarifying help text

The addition of --modalities and --label arguments enhances the flexibility of the benchmarking process. Good job!

Consider clarifying the help text for the --label argument to indicate its relationship with multiple labels:

-    parser.add_argument("--label", type=str, help="Label to use for evaluation in case of multiple labels.", default=None)
+    parser.add_argument("--label", type=str, help="Specific label to use for evaluation when the dataset contains multiple labels.", default=None)

81-82: LGTM! Consider optimizing directory creation

The addition of the while loop ensures unique directory names, preventing potential conflicts. This is a good improvement for data integrity.

Consider using Path.mkdir(parents=True, exist_ok=True) to simplify the directory creation process:

-    log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
-    while log_dir_run.exists():
-        log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f"))
-    log_dir_run.mkdir(parents=True)
+    while True:
+        log_dir_run = log_dir / datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")
+        try:
+            log_dir_run.mkdir(parents=True, exist_ok=False)
+            break
+        except FileExistsError:
+            continue

This approach is more efficient and handles race conditions better in multi-threaded environments.


Line range hint 111-142: Remove duplicate code and optimize SHAP value aggregation

The addition of SHAP value aggregation is valuable. However, there's unnecessary duplication in the code.

Please remove the duplicate block and optimize the SHAP value aggregation:

     shap_values_test = []
-    # shap_values_train = []
     for repetition in log_dir.iterdir():
         if repetition.is_dir():
             aggregated[repetition.name] = {}
             for fold_iter in repetition.iterdir():
                 aggregated[repetition.name][fold_iter.name] = {}
                 if (fold_iter / "test_metrics.json").is_file():
                     with open(fold_iter / "test_metrics.json", "r") as f:
                         result = json.load(f)
                         aggregated[repetition.name][fold_iter.name].update(result)
                 elif (fold_iter / "val_metrics.csv").is_file():
                     with open(fold_iter / "val_metrics.csv", "r") as f:
                         result = json.load(f)
                         aggregated[repetition.name][fold_iter.name].update(result)
                 # Add durations to metrics
                 if (fold_iter / "durations.json").is_file():
                     with open(fold_iter / "durations.json", "r") as f:
                         result = json.load(f)
                         aggregated[repetition.name][fold_iter.name].update(result)
                 if (fold_iter / "test_shap_values.parquet").is_file():
                     shap_values_test.append(pl.read_parquet(fold_iter / "test_shap_values.parquet"))

-    if shap_values_test:
-        shap_values = pl.concat(shap_values_test)
-        shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet")
-
     try:
-        shap_values = pl.concat(shap_values_test)
-        shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet")
+        if shap_values_test:
+            shap_values = pl.concat(shap_values_test)
+            shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet")
     except Exception as e:
         logging.error(f"Error aggregating or writing SHAP values: {e}")

This change removes the duplicate code and ensures that SHAP values are only processed once, improving efficiency and readability.

🧰 Tools
🪛 Ruff

145-145: Loop control variable repetition not used within loop body

(B007)


262-286: LGTM! Consider using list comprehension for efficiency

The get_config_files function is a valuable addition for retrieving task and model configurations. The error handling and logging are well implemented.

Consider using list comprehensions to make the code more concise and potentially more efficient:

 def get_config_files(config_dir: Path):
     try:
-        tasks = list((config_dir / "tasks").glob("*"))
-        models = list((config_dir / "prediction_models").glob("*"))
-        tasks = [task.stem for task in tasks if task.is_file()]
-        models = [model.stem for model in models if model.is_file()]
+        tasks = [task.stem for task in (config_dir / "tasks").glob("*") if task.is_file() and task.stem != "common"]
+        models = [model.stem for model in (config_dir / "prediction_models").glob("*") if model.is_file() and model.stem != "common"]
     except Exception as e:
         logging.error(f"Error retrieving config files: {e}")
         return [], []
-    if "common" in tasks:
-        tasks.remove("common")
-    if "common" in models:
-        models.remove("common")
     logging.info(f"Found tasks: {tasks}")
     logging.info(f"Found models: {models}")
     return tasks, models

This change combines the file listing and filtering steps, reducing the number of iterations and eliminating the need for separate remove operations.

icu_benchmarks/tuning/hyperparameters.py (5)

194-194: LGTM: New plot parameter added.

The addition of the plot parameter provides useful flexibility for controlling visualization output.

Consider updating the function's docstring to include a description of the new plot parameter.


238-261: LGTM: Improved hyperparameter suggestion logic.

The refactoring of the objective function to use Optuna's trial object and the addition of helper functions for different parameter types significantly improve the flexibility and readability of the hyperparameter suggestion process.

Consider adding error handling for unexpected value types in the hyperparams_bounds to make the function more robust.


372-381: LGTM: Added visualization for hyperparameter tuning results.

The new plotting functionality for hyperparameter importances and optimization history enhances the interpretability of the tuning results. The error handling for the plotting process is a good practice.

Consider using a context manager (with plt.figure():) when saving the plots to ensure proper cleanup of matplotlib resources.


Line range hint 1-411: Overall: Significant improvement in hyperparameter tuning capabilities.

The changes in this file represent a substantial enhancement to the project's hyperparameter tuning capabilities:

  1. The implementation of Optuna provides more flexibility and potentially better optimization results.
  2. New visualization features enhance the interpretability of the tuning process.
  3. Improved checkpoint handling and integration with wandb for better experiment tracking.

These changes should lead to more efficient and effective model tuning. Great work on this refactoring!

As the project evolves, consider creating a separate module for visualization functions to keep the main tuning logic clean and modular.


Remove the redundant commented-out checkpoint code.

Line 325 contains a commented-out logging statement related to Attempting to find latest checkpoint file. Since this functionality is active at line 85, the commented-out code appears unnecessary and can be removed to clean up the codebase.

🔗 Analysis chain

Line range hint 280-340: LGTM: Updated checkpoint loading for Optuna studies.

The checkpoint loading process has been successfully adapted to work with Optuna studies, which is appropriate for the new implementation.

There's a commented-out section for finding the latest checkpoint file. Could you clarify if this functionality is still needed or if it can be removed?

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for commented-out code related to checkpoint finding
rg -n "Attempting to find latest checkpoint file" icu_benchmarks/tuning/hyperparameters.py

Length of output: 239

icu_benchmarks/data/loader.py (8)

Line range hint 1-15: LGTM! Consider grouping imports

The addition of Polars (import polars as pl) is appropriate for the new Polars-based dataset implementations.

Consider grouping imports by standard library, third-party, and local imports for better readability. For example:

# Standard library imports
import warnings
from typing import List, Dict, Tuple

# Third-party imports
import gin
import numpy as np
import polars as pl
from torch import Tensor, cat, from_numpy, float32
from torch.utils.data import Dataset

# Local imports
from icu_benchmarks.imputation.amputations import ampute_data
from .constants import DataSegment as Segment
from .constants import DataSplit as Split

16-82: Well-implemented CommonPolarsDataset, consider these enhancements

The CommonPolarsDataset class is a good implementation of a Polars-based dataset. Here are some suggestions for improvement:

  1. Add type hints to method return values for better code documentation.
  2. Consider using Polars' lazy execution for better performance in the __init__ method.
  3. The to_tensor method could be optimized for memory usage.

Here are the suggested changes:

  1. Add return type hints:
- def __len__(self) -> int:
+ def __len__(self) -> int:

- def get_feature_names(self) -> List[str]:
+ def get_feature_names(self) -> List[str]:

- def to_tensor(self) -> List[Tensor]:
+ def to_tensor(self) -> List[Tensor]:
  1. Use lazy execution in __init__:
- self.features_df = data[split][Segment.features]
- self.features_df = self.features_df.sort([self.vars["GROUP"], self.vars["SEQUENCE"]])
- self.features_df = self.features_df.drop(self.vars["SEQUENCE"])
+ self.features_df = (
+     data[split][Segment.features]
+     .lazy()
+     .sort([self.vars["GROUP"], self.vars["SEQUENCE"]])
+     .drop(self.vars["SEQUENCE"])
+     .collect()
+ )
  1. Optimize to_tensor method:
def to_tensor(self) -> List[Tensor]:
    return [cat([entry[i].unsqueeze(0) for entry in self], dim=0) for i in range(len(self[0]))]

This version avoids creating intermediate lists and should be more memory-efficient.


84-183: Well-implemented PredictionPolarsDataset, consider these optimizations

The PredictionPolarsDataset class is a good implementation for prediction tasks. Here are some suggestions for improvement:

  1. Optimize the __getitem__ method using Polars operations.
  2. Add error handling for potential edge cases.
  3. Improve type hinting and docstrings.

Here are the suggested changes:

  1. Optimize __getitem__:
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
    if self._cached_dataset is not None:
        return self._cached_dataset[idx]

    stay_id = self.outcome_df[self.vars["GROUP"]].unique()[idx]
    
    window = self.features_df.filter(pl.col(self.vars["GROUP"]) == stay_id)
    labels = self.outcome_df.filter(pl.col(self.vars["GROUP"]) == stay_id)[self.vars["LABEL"]]

    if len(labels) == 1:
        labels = pl.concat([pl.Series([None] * (window.height - 1)), labels])

    length_diff = self.maxlen - window.height
    pad_mask = pl.Series([1] * window.height)

    if length_diff > 0:
        window = window.vstack(pl.DataFrame({col: [0] * length_diff for col in window.columns}))
        labels = labels.extend(pl.Series([0] * length_diff))
        pad_mask = pad_mask.extend(pl.Series([0] * length_diff))

    labels = labels.fill_null(-1)
    pad_mask = pad_mask.where(labels != -1, 0)

    return (
        from_numpy(window.to_numpy().astype(np.float32)),
        from_numpy(labels.to_numpy().astype(np.float32)),
        from_numpy(pad_mask.to_numpy().astype(bool))
    )
  1. Add error handling:
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
    if idx < 0 or idx >= len(self):
        raise IndexError(f"Index {idx} is out of bounds for dataset of size {len(self)}")
    # ... rest of the method
  1. Improve type hinting and docstrings:
def get_balance(self) -> List[float]:
    """
    Calculate the weight balance for the split of interest.

    Returns:
        List[float]: Weights for each label, calculated as (1 / count) * sum(counts) / num_labels.
    """
    # ... method implementation

def get_data_and_labels(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Retrieve all data and labels aligned at once.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
            - Data points for the split
            - Labels for the split
            - Row indicators
    """
    # ... method implementation

These changes should improve performance, readability, and robustness of the PredictionPolarsDataset class.


Line range hint 186-245: Consider a deprecation strategy for CommonPandasDataset

The CommonPandasDataset class is marked as deprecated, but it's still present in the codebase. While this approach maintains backward compatibility, it may lead to confusion and maintenance issues in the long term.

Consider the following deprecation strategy:

  1. Add a deprecation timeline in the class docstring, indicating when this class will be removed.
  2. Implement a wrapper class that uses the new CommonPolarsDataset internally but exposes the same interface as CommonPandasDataset. This allows for a smoother transition.
  3. In the __init__ method, log a warning message that includes migration instructions.

Here's an example implementation:

import warnings
from functools import wraps

def deprecated(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        warnings.warn(f"{func.__name__} is deprecated and will be removed in version X.X. Use CommonPolarsDataset instead.", DeprecationWarning, stacklevel=2)
        return func(*args, **kwargs)
    return wrapper

@gin.configurable("CommonPandasDataset")
class CommonPandasDataset(Dataset):
    """
    DEPRECATED: This class will be removed in version X.X. Use CommonPolarsDataset instead.
    
    This is a wrapper around CommonPolarsDataset that maintains the old interface for backward compatibility.
    """

    @deprecated
    def __init__(self, *args, **kwargs):
        self._polars_dataset = CommonPolarsDataset(*args, **kwargs)
        warnings.warn(
            "CommonPandasDataset is deprecated and will be removed in version X.X. "
            "Use CommonPolarsDataset instead. "
            "Migration guide: [URL to migration guide]",
            DeprecationWarning,
            stacklevel=2
        )

    @deprecated
    def __getitem__(self, idx):
        return self._polars_dataset[idx]

    @deprecated
    def __len__(self):
        return len(self._polars_dataset)

    # Implement other methods as needed, delegating to self._polars_dataset

This approach allows users of the old API to continue using it while encouraging migration to the new Polars-based implementation. It also provides a clear timeline for when the old implementation will be removed.


Line range hint 246-337: Apply consistent deprecation strategy for PredictionPandasDataset

Similar to CommonPandasDataset, the PredictionPandasDataset class is deprecated but still present in the codebase. To maintain consistency and facilitate a smooth transition, apply the same deprecation strategy as suggested for CommonPandasDataset.

Implement a wrapper class for PredictionPandasDataset that uses the new PredictionPolarsDataset internally:

@gin.configurable("PredictionPandasDataset")
class PredictionPandasDataset(CommonPandasDataset):
    """
    DEPRECATED: This class will be removed in version X.X. Use PredictionPolarsDataset instead.
    
    This is a wrapper around PredictionPolarsDataset that maintains the old interface for backward compatibility.
    """

    @deprecated
    def __init__(self, *args, **kwargs):
        self._polars_dataset = PredictionPolarsDataset(*args, **kwargs)
        warnings.warn(
            "PredictionPandasDataset is deprecated and will be removed in version X.X. "
            "Use PredictionPolarsDataset instead. "
            "Migration guide: [URL to migration guide]",
            DeprecationWarning,
            stacklevel=2
        )

    @deprecated
    def __getitem__(self, idx):
        return self._polars_dataset[idx]

    @deprecated
    def get_balance(self):
        return self._polars_dataset.get_balance()

    @deprecated
    def get_data_and_labels(self):
        return self._polars_dataset.get_data_and_labels()

    # Implement other methods as needed, delegating to self._polars_dataset

This approach ensures consistency across the codebase and provides a clear path for users to migrate to the new Polars-based implementation.


Line range hint 338-401: Consider creating a Polars-based version of ImputationPandasDataset

The ImputationPandasDataset class is still using Pandas, which is inconsistent with the transition to Polars in other parts of the codebase. To maintain consistency and potentially improve performance, consider creating a Polars-based version of this class.

Here are the steps to consider:

  1. Create a new ImputationPolarsDataset class that uses Polars instead of Pandas.
  2. Deprecate the current ImputationPandasDataset class.
  3. Provide a migration path for users of the old class.

Here's a skeleton for the new Polars-based class:

@gin.configurable("ImputationPolarsDataset")
class ImputationPolarsDataset(CommonPolarsDataset):
    """Polars-based dataset for imputation models."""

    def __init__(
        self,
        data: Dict[str, pl.DataFrame],
        split: str = Split.train,
        vars: Dict[str, str] = gin.REQUIRED,
        mask_proportion=0.3,
        mask_method="MCAR",
        mask_observation_proportion=0.3,
        ram_cache: bool = True,
    ):
        super().__init__(data, split, vars, grouping_segment=Segment.static)
        # Convert ampute_data function to work with Polars DataFrames
        self.amputated_values, self.amputation_mask = ampute_data_polars(
            self.features_df, mask_method, mask_proportion, mask_observation_proportion
        )
        self.amputation_mask = (self.amputation_mask | self.features_df.is_null()).cast(pl.Boolean)
        self.amputation_mask = self.amputation_mask.with_columns(pl.col(self.vars["GROUP"]))
        self.amputation_mask = self.amputation_mask.set_index(self.vars["GROUP"])

        self.target_missingness_mask = self.features_df.is_null()
        self.features_df = self.features_df.fill_null(0)
        self.ram_cache(ram_cache)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        # Implement this method using Polars operations
        pass

    # Implement other methods as needed

After creating this new class, deprecate the old ImputationPandasDataset:

@gin.configurable("ImputationPandasDataset")
class ImputationPandasDataset(CommonPandasDataset):
    """
    DEPRECATED: This class will be removed in version X.X. Use ImputationPolarsDataset instead.
    """

    @deprecated
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "ImputationPandasDataset is deprecated and will be removed in version X.X. "
            "Use ImputationPolarsDataset instead. "
            "Migration guide: [URL to migration guide]",
            DeprecationWarning,
            stacklevel=2
        )
        super().__init__(*args, **kwargs)
        # Rest of the initialization

This approach will help maintain consistency across the codebase and provide a clear migration path for users.


Line range hint 404-478: Consider creating a Polars-based version of ImputationPredictionDataset

The ImputationPredictionDataset class is still using Pandas, which is inconsistent with the transition to Polars in other parts of the codebase. To maintain consistency and potentially improve performance, consider creating a Polars-based version of this class.

Here are the steps to consider:

  1. Create a new ImputationPredictionPolarsDataset class that uses Polars instead of Pandas.
  2. Deprecate the current ImputationPredictionDataset class.
  3. Provide a migration path for users of the old class.

Here's a skeleton for the new Polars-based class:

@gin.configurable("ImputationPredictionPolarsDataset")
class ImputationPredictionPolarsDataset(Dataset):
    """Polars-based dataset for imputation prediction tasks."""

    def __init__(
        self,
        data: pl.DataFrame,
        grouping_column: str = "stay_id",
        select_columns: List[str] = None,
        ram_cache: bool = True,
    ):
        self.dyn_df = data

        if select_columns is not None:
            self.dyn_df = self.dyn_df.select(select_columns + [grouping_column])

        if grouping_column is not None:
            self.dyn_df = self.dyn_df.set_index(grouping_column)

        # calculate basic info for the data
        self.group_indices = self.dyn_df.get_column(grouping_column).unique()
        self.maxlen = self.dyn_df.group_by(grouping_column).agg(pl.count()).max().item()

        self._cached_dataset = None
        if ram_cache:
            logging.info("Caching dataset in ram.")
            self._cached_dataset = [self[i] for i in range(len(self))]

    def __len__(self) -> int:
        return len(self.group_indices)

    def __getitem__(self, idx: int) -> Tensor:
        if self._cached_dataset is not None:
            return self._cached_dataset[idx]
        stay_id = self.group_indices[idx]

        window = self.dyn_df.filter(pl.col(self.dyn_df.index_column) == stay_id)

        return from_numpy(window.to_numpy()).to(float32)

After creating this new class, deprecate the old ImputationPredictionDataset:

@gin.configurable("ImputationPredictionDataset")
class ImputationPredictionDataset(Dataset):
    """
    DEPRECATED: This class will be removed in version X.X. Use ImputationPredictionPolarsDataset instead.
    """

    @deprecated
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "ImputationPredictionDataset is deprecated and will be removed in version X.X. "
            "Use ImputationPredictionPolarsDataset instead. "
            "Migration guide: [URL to migration guide]",
            DeprecationWarning,
            stacklevel=2
        )
        super().__init__()
        # Rest of the initialization

This approach will help maintain consistency across the codebase and provide a clear migration path for users. It will also potentially improve performance by leveraging Polars' efficient data manipulation capabilities.


Line range hint 1-478: Overall assessment: Good progress on Polars transition, but further consistency needed

The transition from Pandas to Polars in this file is a positive change that should improve performance and efficiency. However, there are still some inconsistencies and areas for improvement:

  1. The new Polars-based classes (CommonPolarsDataset and PredictionPolarsDataset) are well-implemented but could benefit from some optimizations and improved type hinting.

  2. The deprecated Pandas-based classes (CommonPandasDataset and PredictionPandasDataset) are still present. A clear deprecation strategy with wrappers for backward compatibility has been suggested.

  3. Two classes (ImputationPandasDataset and ImputationPredictionDataset) are still using Pandas and have not been updated to use Polars. Creating Polars-based versions of these classes has been recommended.

To improve the overall quality and consistency of the codebase:

  1. Implement the suggested optimizations for the Polars-based classes.
  2. Apply the proposed deprecation strategy for Pandas-based classes.
  3. Create Polars-based versions of the remaining Pandas classes.
  4. Ensure consistent error handling and type hinting across all classes.
  5. Consider adding more comprehensive unit tests to verify the behavior of both Pandas and Polars implementations during the transition period.

These changes will result in a more consistent, efficient, and maintainable codebase.

icu_benchmarks/data/split_process_data.py (6)

Line range hint 22-42: LGTM: Function signature updates align with new features and Polars transition.

The changes to the preprocess_data function signature appropriately reflect the new features and the transition to Polars. However, there's a minor improvement we can make:

Consider updating the default values for modality_mapping and vars_to_exclude to use None instead of mutable default arguments:

-    modality_mapping: dict[str] = {},
+    modality_mapping: dict[str] | None = None,
-    vars_to_exclude: list[str] = [],
+    vars_to_exclude: list[str] | None = None,

Then, initialize these variables at the beginning of the function:

modality_mapping = modality_mapping or {}
vars_to_exclude = vars_to_exclude or []

This change follows the best practice of avoiding mutable default arguments in Python.

🧰 Tools
🪛 Ruff

26-26: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


192-213: LGTM: check_sanitize_data function implementation.

The check_sanitize_data function effectively removes duplicates from the loaded data, which is crucial for data integrity. The implementation looks correct and efficient.

Consider adding a docstring to explain the function's parameters and return value:

def check_sanitize_data(data, vars):
    """
    Check for duplicates in the loaded data and remove them.

    Args:
        data (dict): Dictionary containing data segments (static, dynamic, outcome).
        vars (dict): Dictionary containing variable mappings.

    Returns:
        dict: Data with duplicates removed.
    """
    # ... (existing implementation)
🧰 Tools
🪛 Ruff

194-194: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


195-195: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


197-197: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


201-201: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


205-205: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


216-242: LGTM: modality_selection function implementation.

The modality_selection function provides a useful way to select specific modalities based on the provided mapping. The implementation looks correct and handles edge cases well.

Consider adding type hints to the function signature for better code readability and maintainability:

def modality_selection(
    data: dict[str, pl.DataFrame],
    modality_mapping: dict[str, list[str]],
    selected_modalities: list[str],
    vars: dict[str, Any]
) -> tuple[dict[str, pl.DataFrame], dict[str, Any]]:
    # ... (existing implementation)
🧰 Tools
🪛 Ruff

220-220: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


221-221: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


238-238: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


Line range hint 244-318: LGTM: make_train_val function updates for Polars support.

The make_train_val function has been successfully updated to support both Pandas and Polars DataFrames. The changes maintain backwards compatibility while allowing for the transition to Polars.

Consider refactoring the function to reduce code duplication between Pandas and Polars operations. You could create helper functions for common operations:

def sample_data(data, fraction, seed, polars):
    return (
        data.sample(fraction=fraction, seed=seed)
        if polars
        else data.sample(frac=fraction, random_state=seed)
    )

def create_split(stays, train, val, polars):
    if polars:
        return {
            Split.train: stays[train].cast(pl.datatypes.Int64).to_frame(),
            Split.val: stays[val].cast(pl.datatypes.Int64).to_frame(),
        }
    else:
        return {Split.train: stays.iloc[train], Split.val: stays.iloc[val]}

# Use these helper functions in make_train_val

This refactoring will make the code more maintainable and easier to read.

🧰 Tools
🪛 Ruff

301-301: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


309-309: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


313-313: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


Line range hint 332-452: LGTM: make_single_split function updates for Polars support.

The make_single_split function has been successfully updated to support both Pandas and Polars DataFrames. The changes maintain backwards compatibility while allowing for the transition to Polars.

Similar to the make_train_val function, consider refactoring to reduce code duplication:

  1. Create helper functions for common operations (e.g., getting labels, creating splits).
  2. Use ternary operators for simple Polars/Pandas conditionals:
dev_stays = stays[dev] if polars else stays.iloc[dev]
  1. Consider extracting the data split creation logic into a separate function, as it's similar in both make_train_val and make_single_split.

These refactorings will improve code maintainability and readability.


Line range hint 1-465: Overall LGTM: Successful transition to Polars with enhanced functionality.

The changes in this file represent a significant improvement to the data processing pipeline:

  1. The transition from Pandas to Polars is well-implemented, which should lead to better performance and reduced memory usage.
  2. New features like modality selection and data sanitization add valuable functionality to the preprocessing steps.
  3. Backwards compatibility with Pandas is maintained, allowing for a smooth transition.

These changes enhance the overall functionality and efficiency of the module. Great job on the implementation!

For future improvements, consider:

  1. Further modularizing the code by extracting common operations into utility functions.
  2. Adding more comprehensive error handling and logging throughout the pipeline.
  3. Implementing unit tests for the new functions and modified logic to ensure robustness.
🧰 Tools
🪛 Ruff

301-301: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


309-309: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


313-313: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/models/wrappers.py (2)

413-414: LGTM: Added row_indicators to data retrieval.

The addition of row_indicators to the get_data_and_labels() calls enhances the data handling capabilities, possibly to manage row-level information. This change is consistent with similar modifications in other methods like validation_step and test_step.

Consider updating the method's docstring to reflect the new row_indicators parameter and its purpose.


Line range hint 562-621: LGTM: Added flexible weight initialization and improved on_fit_start.

The addition of the init_weights method provides valuable flexibility in weight initialization, supporting various techniques like normal, Xavier, Kaiming, and orthogonal initialization. The modification of on_fit_start to initialize weights and reset metrics ensures proper setup before training begins.

Consider adding type hints to the init_weights method for improved code readability:

def init_weights(self, init_type: str = "normal", gain: float = 0.02):
    # ... (rest of the method remains the same)

This change would make the expected types of the parameters more explicit.

icu_benchmarks/data/preprocessor.py (6)

32-34: Document the usage of vars_to_exclude in derived classes.

The vars_to_exclude parameter has been added to the apply method signature in the abstract base class. However, it's not clear how this parameter is used in the derived classes. Consider adding documentation or implementing this feature consistently across all preprocessor classes.


164-177: Clarify the usage of the _model_impute method.

The _model_impute method is defined but never called within this class. If it's intended to be used by derived classes or externally, consider adding a docstring explaining its purpose and usage.


Line range hint 279-350: Add error handling for missing data segments.

In the apply method, the code assumes that all required data segments (dynamic, static, features) exist. Consider adding error handling to gracefully handle cases where these segments might be missing from the data dictionary.

Suggested addition:

if Segment.dynamic not in data[Split.train]:
    raise ValueError(f"Missing {Segment.dynamic} segment in {Split.train} data.")
# Add similar checks for other required segments

Line range hint 476-536: Enhance logging for the missing value filtering process.

In the _process_dynamic_data method, consider adding more detailed logging about the filtering process. This could include the number of rows removed and the percentage of data affected.

Suggested addition:

if self.filter_missing_values:
    total_rows = len(data[Segment.dynamic])
    rows_removed = rows_to_remove.sum()
    logging.info(f"Removed {rows_removed} rows ({rows_removed/total_rows:.2%}) with missing values.")
    logging.info(f"Removed {len(ids_to_remove)} stays with missing values.")

Line range hint 574-584: Add error handling for file operations in cache_recipe.

Consider adding try-except blocks to handle potential IOErrors when creating directories or writing to files. This will make the function more robust against file system issues.

Suggested addition:

def cache_recipe(recipe: Recipe, cache_file: str) -> None:
    recipe_cache = copy.deepcopy(recipe)
    recipe_cache.cache()
    try:
        if not (cache_file / "..").exists():
            (cache_file / "..").mkdir(parents=True, exist_ok=True)
        cache_file.touch()
        with open(cache_file, "wb") as f:
            pickle.dump(recipe_cache, f, pickle.HIGHEST_PROTOCOL)
        logging.info(f"Cached recipe in {cache_file}.")
    except IOError as e:
        logging.error(f"Failed to cache recipe: {e}")

Security Issue Confirmed: Usage of pickle.load in preprocessor.py.

The restore_recipe function uses pickle.load(f), which can execute arbitrary code if the source of the cached recipe is not trusted. This poses a significant security risk.

  • File: icu_benchmarks/data/preprocessor.py
    • Line: recipe = pickle.load(f)

Consider using a more secure serialization method, such as JSON with custom encoding/decoding for complex objects, or implement safety checks before loading pickled data.

🔗 Analysis chain

Line range hint 587-596: Consider security implications of using pickle for recipe restoration.

Using pickle.load() can be a security risk if the source of the cached recipe is not trusted, as it can execute arbitrary code. Consider using a more secure serialization method like JSON with custom encoding/decoding for complex objects, or implement safety checks before loading pickled data.

To verify the potential security risk, you can run the following script:

If this search returns multiple results, it might indicate a wider use of pickle that should be reviewed for security.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Check for uses of pickle.load in the codebase
rg --type python 'pickle\.load'

Length of output: 64


Script:

#!/bin/bash
# Check for uses of pickle.load in Python files
rg 'pickle\.load' --glob '*.py'

Length of output: 185

docs/adding_model/rnn.py (1)

9-9: Enhance the class docstring for better clarity.

The current docstring """Torch standard RNN model""" is brief. Expanding it to include details about the model's parameters, expected input shapes, and usage examples would improve readability and maintainability.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 4a4ab0c and 4ed54cf.

📒 Files selected for processing (20)
  • docs/adding_model/rnn.py (1 hunks)
  • icu_benchmarks/cross_validation.py (6 hunks)
  • icu_benchmarks/data/loader.py (7 hunks)
  • icu_benchmarks/data/pooling.py (6 hunks)
  • icu_benchmarks/data/preprocessor.py (8 hunks)
  • icu_benchmarks/data/split_process_data.py (10 hunks)
  • icu_benchmarks/models/dl_models/rnn.py (1 hunks)
  • icu_benchmarks/models/dl_models/tcn.py (1 hunks)
  • icu_benchmarks/models/dl_models/transformer.py (1 hunks)
  • icu_benchmarks/models/ml_models/catboost.py (1 hunks)
  • icu_benchmarks/models/ml_models/imblearn.py (1 hunks)
  • icu_benchmarks/models/ml_models/lgbm.py (1 hunks)
  • icu_benchmarks/models/ml_models/sklearn.py (1 hunks)
  • icu_benchmarks/models/ml_models/xgboost.py (1 hunks)
  • icu_benchmarks/models/train.py (10 hunks)
  • icu_benchmarks/models/wrappers.py (9 hunks)
  • icu_benchmarks/run.py (7 hunks)
  • icu_benchmarks/run_utils.py (6 hunks)
  • icu_benchmarks/tuning/hyperparameters.py (7 hunks)
  • requirements.txt (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • icu_benchmarks/cross_validation.py
  • icu_benchmarks/models/ml_models/imblearn.py
  • requirements.txt
🧰 Additional context used
🪛 Ruff
docs/adding_model/rnn.py

15-15: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)

icu_benchmarks/data/pooling.py

106-106: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


108-108: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)

icu_benchmarks/data/split_process_data.py

26-26: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


41-41: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


121-121: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


126-126: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


126-126: Test for membership should be not in

Convert to not in

(E713)


194-194: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


195-195: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


197-197: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


201-201: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


205-205: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


220-220: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


221-221: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


238-238: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


301-301: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


309-309: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


313-313: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


419-422: Use ternary operator dev_stays = stays[dev] if polars else stays.iloc[dev] instead of if-else-block

Replace if-else-block with dev_stays = stays[dev] if polars else stays.iloc[dev]

(SIM108)


438-438: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


446-446: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


450-450: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/models/dl_models/rnn.py

16-16: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


42-42: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


69-69: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)

icu_benchmarks/models/dl_models/tcn.py

23-23: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)

icu_benchmarks/models/dl_models/transformer.py

37-37: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


52-52: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


106-106: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


123-123: Loop control variable i not used within loop body

Rename unused i to _i

(B007)

icu_benchmarks/models/ml_models/catboost.py

14-14: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)

icu_benchmarks/models/ml_models/xgboost.py

27-27: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


76-76: f-string without any placeholders

Remove extraneous f prefix

(F541)

icu_benchmarks/models/wrappers.py

474-474: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/run_utils.py

18-18: os imported but unused

Remove unused import: os

(F401)


19-19: glob imported but unused

Remove unused import: glob

(F401)

🔇 Additional comments (41)
icu_benchmarks/models/ml_models/catboost.py (4)

1-9: LGTM: Imports and class declaration look good

The imports are correct, including the fixed import from constants. The CBClassifier class is properly defined with the @gin.configurable decorator and the correct _supported_run_modes attribute.


17-27: LGTM: Predict method is well-implemented

The predict method is correctly implemented to return class probabilities using the CatBoost model's predict_proba method. The docstring provides clear and concise information about the method's functionality, parameters, and return value.


1-27: Overall assessment: Good implementation with minor improvement suggestions

The CBClassifier class provides a solid wrapper for the CatBoost classifier, integrating well with the project's structure. The implementation is generally good, with clear method definitions and appropriate use of decorators.

The main area for improvement is in the constructor, as noted in the previous comment. Addressing these suggestions will enhance the flexibility and maintainability of the code.

Great job on fixing the import statement and implementing the predict method with clear documentation!

🧰 Tools
🪛 Ruff

14-14: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


11-15: 🛠️ Refactor suggestion

Refactor constructor for improved flexibility and code cleanliness

Consider the following improvements to the constructor:

  1. Make the task type configurable instead of hardcoding it to "CPU".
  2. Remove the commented-out code to improve readability.
  3. Address the star-arg unpacking warning.

Here's a suggested implementation:

def __init__(self, task_type="CPU", *args, **kwargs):
    model_kwargs = {'task_type': task_type}
    self.model = self.set_model_args(cb.CatBoostClassifier, **model_kwargs, **kwargs)
    super().__init__(*args, **kwargs)

This change allows the task type to be specified when initializing the class, defaulting to "CPU" if not provided, and addresses the star-arg unpacking warning.

🧰 Tools
🪛 Ruff

14-14: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)

icu_benchmarks/models/ml_models/lgbm.py (3)

1-8: LGTM: Import statements are appropriate and well-organized.

The import statements are concise and include all necessary libraries for the LightGBM wrapper implementation, including gin for configuration, numpy for numerical operations, and wandb for logging. The specific import of the wandb callback for LightGBM is a good practice for integration.


1-57: Overall, excellent implementation of LightGBM wrappers with minor suggestions for enhancement.

This file provides a well-structured and robust implementation of LightGBM wrappers for both classification and regression tasks. The use of gin for configuration and wandb for logging demonstrates good practices in machine learning model development. The code is clean, readable, and follows a consistent style.

Key strengths:

  1. Proper separation of concerns between wrapper, classifier, and regressor classes.
  2. Integration with wandb for experiment tracking.
  3. Use of early stopping and callbacks for efficient training.

Suggestions for improvement:

  1. Add type hints to method signatures for better code readability and maintainability.
  2. Consider adding input validation in the predict methods to enhance robustness.
  3. Clarify the inheritance of the predict method in the LGBMRegressor class.

These minor enhancements would further improve an already solid implementation.


51-57: LGTM: LGBMRegressor implementation is concise, but clarification needed.

The LGBMRegressor class is correctly implemented, using gin for configuration and properly specifying its supported run mode as regression.

However, I noticed that unlike the LGBMClassifier, this class doesn't include a predict method. Is this intentional? If so, it might be helpful to add a comment explaining that the predict method is inherited from the parent class. If not, consider adding a predict method specific to regression tasks.

To verify the existence of a predict method in the parent class, we can run the following script:

This will help us confirm whether the predict method is indeed inherited from the parent class.

icu_benchmarks/models/dl_models/tcn.py (3)

1-9: LGTM: Imports are appropriate and well-organized.

The imports cover all necessary dependencies for implementing the Temporal Convolutional Network, including configuration management (gin), numerical operations (numpy), neural network components (torch.nn), and custom modules from the icu_benchmarks package.


18-55: LGTM: Well-structured constructor with automatic channel computation.

The constructor is well-designed, handling various input scenarios and automatically computing the number of channels when needed. The network construction logic clearly implements the TCN architecture using TemporalBlock instances.

🧰 Tools
🪛 Ruff

23-23: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


57-62: LGTM: Forward method is correctly implemented.

The forward method efficiently handles the necessary tensor permutations and correctly passes the data through the network and final linear layer. The implementation follows the expected flow for a TCN.

icu_benchmarks/models/ml_models/sklearn.py (2)

2-2: LGTM: Import statements have been streamlined.

The consolidation of scikit-learn imports into a single line improves code clarity and readability. The removal of LightGBM and WandB imports is consistent with the elimination of LightGBM-related functionality.


Line range hint 1-93: Verify the impact of LightGBM removal.

The LightGBM-related classes and methods have been removed as per the AI summary. While this change simplifies the sklearn.py file, it's crucial to ensure that this removal doesn't cause issues elsewhere in the codebase.

Please run the following script to check for any remaining references to LightGBM in the codebase:

If any references are found, they may need to be removed or updated to maintain consistency with this change.

icu_benchmarks/models/dl_models/rnn.py (1)

1-5: LGTM: Import statements are appropriate and concise.

The import statements are well-organized and include all necessary modules for the implemented functionality.

icu_benchmarks/models/dl_models/transformer.py (3)

1-6: LGTM: Import statements are well-organized and relevant.

The import statements are appropriately structured, following the convention of standard library imports first, followed by third-party libraries, and then local modules. All imports seem relevant to the implementation of transformer models.


9-14: Good use of gin for configuration and clear mode support.

The Transformer class is well-structured with @gin.configurable decorator, allowing for flexible configuration. The explicit definition of supported run modes (classification and regression) is a good practice for clarity.


69-76: LGTM: Forward method is concise and follows the transformer architecture.

The forward method is well-implemented, following the expected transformer architecture. It processes the input through the embedding layer, positional encoding (if enabled), transformer blocks, and finally the output layer.

icu_benchmarks/models/ml_models/xgboost.py (4)

1-25: LGTM: Imports and class definition look good.

The imports are appropriate for the XGBoost classifier implementation. The use of the @gin.configurable decorator allows for flexible configuration, which is a good practice. The _supported_run_modes attribute clearly defines that this class is intended for classification tasks only.


30-40: LGTM: predict method is well-implemented.

The predict method is correctly implemented, using the model's predict_proba method to return class probabilities. The docstring is well-written, providing clear information about the method's input, output, and functionality.


98-101: LGTM: get_feature_importance method is well-implemented.

The get_feature_importance method is correctly implemented. It properly checks if the model has been fit before attempting to return the feature importances. The error message is clear and informative, guiding the user to call fit_model() if needed.


87-96: ⚠️ Potential issue

Improve robustness of set_model_args method.

The current implementation of set_model_args can be improved for better robustness and clarity. Consider implementing the following changes:

  1. Directly use the kwargs parameter instead of locals()['kwargs'].
  2. Filter out invalid arguments to prevent potential errors when initializing the model.

Here's a suggested implementation:

def set_model_args(self, model, *args, **kwargs):
    """Set model arguments based on the model's signature."""
    signature = inspect.signature(model.__init__).parameters
    valid_params = signature.keys()
    
    # Filter out invalid arguments
    valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
    
    logging.debug(f"Creating model with: {valid_kwargs}.")
    return model(**valid_kwargs)

This refactoring improves robustness by filtering out invalid arguments and avoids the use of locals().

icu_benchmarks/data/pooling.py (2)

19-29: Approve formatting changes in __init__ method signature.

The vertical alignment of parameters in the __init__ method signature improves code readability without affecting functionality. This change is in line with good coding practices.


51-55: Approve formatting changes in generate method signature.

The vertical alignment of parameters in the generate method signature enhances code readability without altering functionality. This change is consistent with the formatting improvements made throughout the class.

icu_benchmarks/run.py (4)

21-21: LGTM: New import added for configuration file retrieval.

The addition of get_config_files to the imports is consistent with the new functionality described in the summary. This change enhances the module's capability to handle configuration files.


84-86: LGTM: Improved readability of wandb config update.

The modification to the update_wandb_config call enhances code readability by breaking the long line into multiple lines. The functionality remains unchanged, which is good.


145-148: LGTM: Improved readability of hyperparameter tuning function call.

The modification to the choose_and_bind_hyperparameters_optuna function call enhances code readability by using keyword arguments. This change not only improves clarity but also reduces the likelihood of errors when passing arguments. It aligns well with the summary mentioning updates to enhance hyperparameter tuning processes.


Line range hint 1-203: Overall assessment: The changes improve the code quality and functionality.

The modifications in this file align well with the PR objectives and the AI-generated summary. Key improvements include:

  1. Enhanced configuration file handling and validation.
  2. Added support for modalities and label binding.
  3. Improved error handling and logging.
  4. Better readability through the use of keyword arguments and line breaks.
  5. Support for different model types based on run mode.

These changes contribute to a more robust and flexible codebase. The minor suggestions provided in the review comments can further enhance code quality and maintainability.

icu_benchmarks/models/train.py (3)

6-6: LGTM: Transition to Polars for improved performance

The addition of Polars (pl) import and PredictionPolarsDataset is a positive change. Polars is known for its superior performance and lower memory usage compared to pandas, especially for large datasets. This transition should lead to improved efficiency in data handling operations.

Also applies to: 14-14


127-127: LGTM: Flexible model loading with CPU option

The addition of the cpu parameter to the model loading process is a good improvement. This change enhances flexibility, particularly for CPU-only environments, and improves cross-platform compatibility.


Line range hint 1-247: Overall assessment: Significant improvements with minor suggestions

The changes to train.py represent a substantial improvement in the codebase:

  1. The transition from pandas to polars should lead to better performance and reduced memory usage.
  2. New configuration options provide greater flexibility in training and data handling.
  3. The addition of SHAP value persistence enhances model interpretability.

These changes are well-implemented and should positively impact the project. However, to further improve the code:

  1. Update the docstrings for functions with new or changed parameters.
  2. Add inline comments explaining the reasoning behind specific configuration choices.
  3. Address the TODO comment regarding full support for polars versions of datasets.
  4. Consider the suggestions for minor improvements in error handling and logging.

Great work on these enhancements! The changes show a clear focus on performance, flexibility, and interpretability.

icu_benchmarks/tuning/hyperparameters.py (3)

5-5: LGTM: New imports for Optuna and visualization.

The added imports for matplotlib and Optuna visualization functions are appropriate for the new hyperparameter tuning implementation. The relocation of the RunMode import doesn't affect functionality.

Also applies to: 16-18


264-269: LGTM: Updated callback function for Optuna compatibility.

The tune_step_callback function has been successfully adapted to work with Optuna's study and trial objects. The logging improvements provide better visibility into the tuning progress.


341-363: LGTM: Optuna study creation and optimization implemented correctly.

The creation of the Optuna study with specified sampler and pruner, as well as the optimization process with appropriate callbacks, has been implemented correctly. The integration with wandb through the WeightsAndBiasesCallback is a nice addition for experiment tracking.

icu_benchmarks/data/split_process_data.py (1)

3-4: LGTM: Import changes align with the transition to Polars.

The import changes correctly reflect the shift from Pandas to Polars for data processing. The addition of os and timeit imports, along with the removal of pyarrow.parquet, are appropriate for the new implementation.

Also applies to: 9-9, 12-12, 14-15

icu_benchmarks/models/wrappers.py (5)

6-7: LGTM: New metric imports added.

The addition of average_precision_score and roc_auc_score from sklearn.metrics enhances the model evaluation capabilities. These are valuable metrics for classification tasks, particularly for imbalanced datasets.


33-34: LGTM: New metrics configured for gin.

The addition of gin configurations for average_precision_score and roc_auc_score allows for flexible use of these metrics in the gin-config system. This is consistent with the existing pattern and enhances the configurability of the model evaluation process.


376-383: ⚠️ Potential issue

Consider removing or implementing commented code and unused attribute.

  1. The commented-out code (lines 376-379) for setting self.loss based on the loss function name is not being used. If this functionality is no longer needed, consider removing it entirely. If it's intended for future use, add a TODO comment explaining its purpose and when it might be implemented.

  2. The self.loss_weight attribute (line 383) is added but not used anywhere in the class. If it's intended for future use, add a TODO comment explaining its purpose. If it's not needed, consider removing it to avoid confusion.

To check if self.loss_weight is used elsewhere in the codebase:

#!/bin/bash
# Search for usage of self.loss_weight in Python files
rg "self\.loss_weight" --type python

440-440: ⚠️ Potential issue

Consider using or removing the unused row_indicators variable.

The row_indicators variable is retrieved from get_data_and_labels() but not used in the method body. If it's intended for future use, consider adding a TODO comment explaining its purpose. If it's not needed, you can remove it to keep the code clean.

To check if row_indicators is used elsewhere in the validation_step method:


474-488: ⚠️ Potential issue

Simplify dictionary key check and review MPS-dependent type casting.

  1. Simplify the dictionary key check:
    Instead of if "Confusion_Matrix" in self.metrics.keys():, use if "Confusion_Matrix" in self.metrics:. This is more efficient as it avoids creating a temporary view object for the dictionary keys.

  2. Review MPS-dependent type casting:
    The current implementation of MPS-dependent type casting seems redundant. Both branches of the conditional expression perform the same operation:

    metric(self.label_transform(label), self.output_transform(pred))

    This might not achieve the intended MPS compatibility. Consider reviewing and simplifying this logic, or add a comment explaining the necessity of this structure if it's intentional.

Here's a suggested simplification:

if "Confusion_Matrix" in self.metrics:
    self.log_dict(confusion_matrix(self.label_transform(label), self.output_transform(pred)), sync_dist=True)

self.log_dict(
    {
        f"{metric_type}/{name}": metric(self.label_transform(label), self.output_transform(pred))
        for name, metric in self.metrics.items()
        if not isinstance(metric(self.label_transform(label), self.output_transform(pred)), tuple)
        and name != "Confusion_Matrix"
    },
    sync_dist=True,
)

If MPS-specific handling is necessary, please add a comment explaining its purpose and implementation.

To check if there are any MPS-specific implementations elsewhere in the codebase:

🧰 Tools
🪛 Ruff

474-474: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/data/preprocessor.py (2)

Line range hint 9-21: LGTM: New imports for Polars and additional preprocessing steps.

The new imports for Polars and additional preprocessing steps from the recipys library are appropriate for the changes made in this file.


Line range hint 537-571: LGTM: Flexible implementation for applying recipes.

The apply_recipe_to_splits function is well-implemented, handling both Polars and Pandas DataFrames, and providing options for caching and loading recipes. This flexibility is beneficial for various preprocessing scenarios.

docs/adding_model/rnn.py (1)

19-19: ⚠️ Potential issue

Verify the indexing of input_size when initializing nn.RNN.

In line 19, input_size[2] is used as the input dimension for nn.RNN:

self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True)

This implies that input_size is an iterable with at least three elements. Ensure that input_size[2] correctly represents the feature dimension expected by the RNN.

To confirm, please verify how input_size is defined throughout the codebase. If input_size is intended to be an integer representing the input feature size, adjust the code accordingly.

Apply this diff if input_size should be an integer:

- self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True)
+ self.rnn = nn.RNN(input_size, hidden_dim, layer_dim, batch_first=True)

icu_benchmarks/models/dl_models/tcn.py Outdated Show resolved Hide resolved
icu_benchmarks/models/dl_models/rnn.py Outdated Show resolved Hide resolved
icu_benchmarks/models/dl_models/rnn.py Outdated Show resolved Hide resolved
icu_benchmarks/models/dl_models/transformer.py Outdated Show resolved Hide resolved
icu_benchmarks/models/dl_models/transformer.py Outdated Show resolved Hide resolved
icu_benchmarks/data/preprocessor.py Show resolved Hide resolved
icu_benchmarks/data/preprocessor.py Show resolved Hide resolved
icu_benchmarks/data/preprocessor.py Outdated Show resolved Hide resolved
docs/adding_model/rnn.py Show resolved Hide resolved
docs/adding_model/rnn.py Show resolved Hide resolved
Copy link

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 16

🧹 Outside diff range and nitpick comments (33)
icu_benchmarks/models/dl_models/rnn.py (2)

34-58: LGTM: LSTMNet implementation is correct.

The LSTMNet class is well-implemented with appropriate initialization, hidden state management, and forward pass. The use of @gin.configurable allows for flexible configuration.

The issues mentioned in the past review comment have been addressed:

  1. The argument unpacking in the super().init() call has been fixed.
  2. The implementation is correct and doesn't have any critical issues.

As a minor suggestion, consider adding type hints to method parameters and return values for improved code readability and maintainability.


61-84: LGTM: GRUNet implementation is correct.

The GRUNet class is well-implemented with appropriate initialization, hidden state management, and forward pass. The use of @gin.configurable allows for flexible configuration.

The issues mentioned in the past review comment have been addressed:

  1. The missing return statement in the forward method has been added.
  2. The argument unpacking in the super().init() call has been fixed.

As a minor suggestion, consider adding type hints to method parameters and return values for improved code readability and maintainability.

icu_benchmarks/models/ml_models/xgboost.py (2)

42-59: LGTM with suggestion: fit_model method is well-implemented.

The fit_model method correctly fits the model to the training data, incorporating early stopping and optional Weights & Biases integration. The conditional computation of SHAP values is a good approach for flexibility.

Suggestion for improvement:
Consider allowing the user to specify which validation metric to use for the return value, instead of always using the first one. This would provide more flexibility in model evaluation.

Example:

def fit_model(self, train_data, train_labels, val_data, val_labels, eval_metric='auto'):
    # ... existing code ...
    eval_score = mean(self.model.evals_result_["validation_0"][eval_metric])
    return eval_score

This change would allow users to specify which metric they want to use for evaluation.


61-84: LGTM with minor fix: test_step method is well-implemented.

The test_step method correctly processes test data, makes predictions, and logs results. It handles both multi-processing and standard logging scenarios effectively.

Minor fix needed:
On line 76, there's an f-string without any placeholders. Remove the f prefix:

logging.debug(f"Saved row indicators to {os.path.join(self.logger.save_dir, 'row_indicators.csv')}")

This change resolves the static analysis warning and improves code clarity.

🧰 Tools
🪛 Ruff

76-76: f-string without any placeholders

Remove extraneous f prefix

(F541)

icu_benchmarks/run_utils.py (5)

58-64: LGTM! Consider adding type hints for clarity

The new arguments --modalities and --label are well-implemented and enhance the flexibility of the parser. To improve code clarity, consider adding type hints to the nargs and type parameters.

Here's a suggested improvement:

 parser.add_argument(
     "-mo",
     "--modalities",
-    nargs="+",
+    nargs: str = "+",
     help="Optional modality selection to use. Specify multiple modalities separated by spaces.",
 )
-parser.add_argument("--label", type=str, help="Label to use for evaluation in case of multiple labels.", default=None)
+parser.add_argument("--label", type: str = str, help="Label to use for evaluation in case of multiple labels.", default=None)

81-82: LGTM! Consider using a more efficient approach

The addition of the while loop effectively prevents directory name conflicts. However, for better efficiency, consider using a single call to datetime.now() and formatting it with both seconds and microseconds.

Here's a suggested optimization:

-log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
-while log_dir_run.exists():
-    log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f"))
+now = datetime.now()
+log_dir_run = log_dir / now.strftime("%Y-%m-%dT%H-%M-%S.%f")
+while log_dir_run.exists():
+    now = datetime.now()
+    log_dir_run = log_dir / now.strftime("%Y-%m-%dT%H-%M-%S.%f")

This approach reduces the number of datetime.now() calls and ensures unique directory names.


Line range hint 111-142: Refactor duplicate code and improve error handling

The addition of SHAP value aggregation is a valuable enhancement. However, there are a few issues to address:

  1. There's duplicate code for concatenating and writing SHAP values (lines 134-136 and 138-140).
  2. The error handling could be more specific and informative.

Here's a suggested refactoring:

 shap_values_test = []
-# shap_values_train = []
 for repetition in log_dir.iterdir():
     if repetition.is_dir():
         aggregated[repetition.name] = {}
         for fold_iter in repetition.iterdir():
             aggregated[repetition.name][fold_iter.name] = {}
             if (fold_iter / "test_metrics.json").is_file():
                 with open(fold_iter / "test_metrics.json", "r") as f:
                     result = json.load(f)
                     aggregated[repetition.name][fold_iter.name].update(result)
             elif (fold_iter / "val_metrics.csv").is_file():
                 with open(fold_iter / "val_metrics.csv", "r") as f:
                     result = json.load(f)
                     aggregated[repetition.name][fold_iter.name].update(result)
             # Add durations to metrics
             if (fold_iter / "durations.json").is_file():
                 with open(fold_iter / "durations.json", "r") as f:
                     result = json.load(f)
                     aggregated[repetition.name][fold_iter.name].update(result)
             if (fold_iter / "test_shap_values.parquet").is_file():
                 shap_values_test.append(pl.read_parquet(fold_iter / "test_shap_values.parquet"))

-if shap_values_test:
-    shap_values = pl.concat(shap_values_test)
-    shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet")
-
 try:
-    shap_values = pl.concat(shap_values_test)
-    shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet")
+    if shap_values_test:
+        shap_values = pl.concat(shap_values_test)
+        shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet")
+        logging.info(f"Successfully aggregated and wrote SHAP values to {log_dir / 'aggregated_shap_values.parquet'}")
 except Exception as e:
-    logging.error(f"Error aggregating or writing SHAP values: {e}")
+    logging.error(f"Error aggregating or writing SHAP values: {str(e)}", exc_info=True)

This refactoring removes the duplicate code, improves error handling, and adds a success log message for better traceability.

🧰 Tools
🪛 Ruff

145-145: Loop control variable repetition not used within loop body

(B007)


262-286: LGTM! Consider enhancing error handling and readability

The get_config_files function is a valuable addition for dynamically discovering available tasks and models. However, there are a few suggestions to improve its robustness and readability:

  1. Use Path.glob() instead of creating a list immediately, as it's more memory-efficient.
  2. Combine the filtering and processing steps for better readability.
  3. Add more specific error handling for different scenarios (e.g., directory not found, permission issues).

Here's a suggested refactoring:

def get_config_files(config_dir: Path):
    """
    Get all task and model config files in the specified directory.
    
    Args:
        config_dir: Path to the directory containing the config gin files.

    Returns:
        tuple: Lists of task names and model names.
    """
    tasks, models = [], []
    
    try:
        tasks = [task.stem for task in (config_dir / "tasks").glob("*") 
                 if task.is_file() and task.stem != "common"]
        models = [model.stem for model in (config_dir / "prediction_models").glob("*") 
                  if model.is_file() and model.stem != "common"]
    except FileNotFoundError as e:
        logging.error(f"Directory not found: {e}")
    except PermissionError as e:
        logging.error(f"Permission denied when accessing config files: {e}")
    except Exception as e:
        logging.error(f"Unexpected error when retrieving config files: {e}")
    
    logging.info(f"Found tasks: {tasks}")
    logging.info(f"Found models: {models}")
    
    return tasks, models

This refactored version improves readability, efficiency, and error handling while maintaining the original functionality.


288-301: LGTM! Consider adding an option for soft validation

The check_required_keys function is a valuable addition for input validation. It's well-implemented and follows good practices. To enhance its flexibility, consider adding an option for "soft" validation that returns a boolean instead of raising an exception.

Here's a suggested enhancement:

def check_required_keys(vars: dict, required_keys: list, raise_error: bool = True) -> bool:
    """
    Checks if all required keys are present in the vars dictionary.

    Args:
        vars (dict): The dictionary to check.
        required_keys (list): The list of required keys.
        raise_error (bool): If True, raises a KeyError when keys are missing. If False, returns a boolean.

    Returns:
        bool: True if all required keys are present, False otherwise (only if raise_error is False).

    Raises:
        KeyError: If any required key is missing and raise_error is True.
    """
    missing_keys = [key for key in required_keys if key not in vars]
    if missing_keys:
        if raise_error:
            raise KeyError(f"Missing required keys in vars: {', '.join(missing_keys)}")
        return False
    return True

This enhancement allows the function to be used in different scenarios, providing more flexibility to the caller.

icu_benchmarks/tuning/hyperparameters.py (2)

196-196: Consider documenting the plot parameter.

The new plot parameter has been added to the function signature, but it's not documented in the function's docstring. Consider adding a description of this parameter to maintain consistency with the documentation of other parameters.


Line range hint 201-201: Clarify the sampler parameter and its usage.

The sampler parameter is added to the function signature and used later in the function, but it's not documented in the function's docstring. Additionally, the default value is set to optuna.samplers.GPSampler, but this is instantiated differently based on whether it's this specific sampler or not. Consider:

  1. Adding documentation for the sampler parameter in the function's docstring.
  2. Clarifying the logic for sampler instantiation, possibly by creating a separate function for this.

Here's a suggested refactor for the sampler instantiation:

def create_sampler(sampler_class, seed, n_initial_points):
    if sampler_class == optuna.samplers.GPSampler:
        return sampler_class(seed=seed, n_startup_trials=n_initial_points, deterministic_objective=True)
    return sampler_class(seed=seed)

# In the main function
sampler = create_sampler(sampler, seed, n_initial_points)

Also applies to: 322-323

icu_benchmarks/data/loader.py (4)

18-28: Add type hints to __init__ method parameters

Consider adding type hints to the __init__ method parameters for better code documentation and to leverage static type checking tools. For example:

def __init__(
    self,
    data: Dict[str, pl.DataFrame],
    split: str = Split.train,
    vars: Dict[str, str] = gin.REQUIRED,
    grouping_segment: str = Segment.outcome,
    mps: bool = False,
    name: str = "",
    *args,
    **kwargs,
):

Line range hint 185-203: Clarify deprecation path for CommonPandasDataset

The renaming of CommonDataset to CommonPandasDataset and the addition of a deprecation warning are appropriate steps in the transition to Polars. However, the class still uses Pandas operations, which may be confusing for developers.

Consider the following suggestions:

  1. Add a comment explaining the deprecation timeline and migration path to the Polars-based classes.
  2. If possible, update the class to use Polars operations internally while maintaining the same interface. This would allow for a smoother transition for existing code.

Example comment:

"""
This class is deprecated and will be removed in version X.X. Please migrate to CommonPolarsDataset.
For assistance with migration, refer to the migration guide at <link_to_migration_guide>.
"""

245-246: Add deprecation warning to PredictionPandasDataset

For consistency with the CommonPandasDataset class and to facilitate a smooth transition to the Polars-based implementation, consider adding a deprecation warning to the PredictionPandasDataset class. This will help users understand that they should migrate to the new PredictionPolarsDataset class in the future.

Example:

@gin.configurable("PredictionPandasDataset")
class PredictionPandasDataset(CommonPandasDataset):
    def __init__(self, *args, ram_cache: bool = True, **kwargs):
        warnings.warn("PredictionPandasDataset is deprecated. Use PredictionPolarsDataset instead.", DeprecationWarning, stacklevel=2)
        super().__init__(*args, grouping_segment=Segment.outcome, **kwargs)
        # ... rest of the __init__ method ...

Line range hint 1-462: Overall assessment: Good progress on Pandas to Polars transition with room for improvement

The transition from Pandas to Polars in this file is a significant improvement that should enhance performance and reduce memory usage. The addition of new Polars-based classes (CommonPolarsDataset and PredictionPolarsDataset) alongside the renamed Pandas-based classes provides a clear path for migration.

However, there are a few areas where the transition could be more consistent and complete:

  1. Ensure all Pandas-based classes have appropriate deprecation warnings.
  2. Consider creating Polars-based versions of all existing classes (e.g., ImputationPolarsDataset).
  3. Update utility functions like ampute_data to work with Polars DataFrames.
  4. Provide clear documentation or a migration guide to help users transition from Pandas-based to Polars-based classes.
  5. Review and optimize all methods in the new Polars-based classes to fully leverage Polars operations for maximum efficiency.

Addressing these points will help complete the transition to Polars and provide a more consistent and efficient codebase.

icu_benchmarks/data/split_process_data.py (8)

Line range hint 20-42: LGTM! Enhanced flexibility with new parameters.

The function signature changes improve flexibility, particularly with modality selection. The default preprocessor change aligns with the Polars transition.

Suggestion: Consider using None as the default for modality_mapping instead of an empty dict to avoid potential issues with mutable default arguments.

-    modality_mapping: dict[str] = {},
+    modality_mapping: dict[str] | None = None,

Then, initialize it within the function:

modality_mapping = modality_mapping or {}

This change would align with Python best practices for handling mutable default arguments.

🧰 Tools
🪛 Ruff

27-27: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


Line range hint 70-137: LGTM! Improved robustness and flexibility in data preprocessing.

The changes enhance the function's capabilities, particularly in label handling, modality selection, and the transition to Polars for data loading. The improved logging provides better visibility into the preprocessing steps.

Suggestion: In the label handling logic, consider using a more explicit approach for selecting the default label:

- if label is not None:
-     vars[Var.label] = [label]
- else:
-     logging.debug(f"Multiple labels found and no value provided. Using first label: {vars[Var.label]}")
-     vars[Var.label] = vars[Var.label][0]
+ if label is not None:
+     vars[Var.label] = [label]
+ elif isinstance(vars[Var.label], list) and len(vars[Var.label]) > 0:
+     logging.debug(f"Multiple labels found and no value provided. Using first label: {vars[Var.label][0]}")
+     vars[Var.label] = [vars[Var.label][0]]
+ else:
+     raise ValueError("No valid label found or provided.")

This change ensures that there's always a valid label and provides a clear error message if no valid label is found.


159-180: LGTM! Improved preprocessing and data quality checks.

The changes enhance the preprocessing step with better timing measurement and more robust NaN/null handling using Polars capabilities. The detailed logging of data quality issues is a valuable addition.

Suggestion: Consider consolidating the NaN and null handling steps:

- dict[key] = val.fill_null(strategy="zero")
- dict[key] = val.fill_nan(0)
+ dict[key] = val.fill_null(strategy="zero").fill_nan(0)

This change combines the two operations into a single method chain, which might be slightly more efficient and readable.


193-214: LGTM! Valuable addition for data sanitization.

The check_sanitize_data function is a great addition for ensuring data quality by removing duplicates from different data segments. It effectively uses Polars for efficient data manipulation.

Suggestion: Consider adding error handling for cases where expected columns are missing:

 def check_sanitize_data(data, vars):
     """Check for duplicates in the loaded data and remove them."""
     group = vars[Var.group] if Var.group in vars.keys() else None
     sequence = vars[Var.sequence] if Var.sequence in vars.keys() else None
+    if group is None:
+        raise ValueError("Group variable not found in vars dictionary.")
     keep = "last"
     if Segment.static in data.keys():
+        if group not in data[Segment.static].columns:
+            raise ValueError(f"Group column '{group}' not found in static data.")
         old_len = len(data[Segment.static])
         data[Segment.static] = data[Segment.static].unique(subset=group, keep=keep, maintain_order=True)
         logging.warning(f"Removed {old_len - len(data[Segment.static])} duplicates from static data.")
     if Segment.dynamic in data.keys():
+        if group not in data[Segment.dynamic].columns or sequence not in data[Segment.dynamic].columns:
+            raise ValueError(f"Group column '{group}' or sequence column '{sequence}' not found in dynamic data.")
         old_len = len(data[Segment.dynamic])
         data[Segment.dynamic] = data[Segment.dynamic].unique(subset=[group, sequence], keep=keep, maintain_order=True)
         logging.warning(f"Removed {old_len - len(data[Segment.dynamic])} duplicates from dynamic data.")
     # ... (similar checks for outcome data)
     return data

This addition will help catch potential issues early in the data processing pipeline.

🧰 Tools
🪛 Ruff

195-195: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


196-196: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


198-198: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


202-202: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


206-206: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


217-243: LGTM! Excellent addition for flexible modality selection.

The modality_selection function adds valuable flexibility to the data processing pipeline, allowing users to select specific modalities. It effectively uses the modality_mapping for column filtering and updates both data and vars dictionaries.

Suggestion: Consider adding more robust error handling and validation:

 def modality_selection(
     data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars
 ) -> dict[pl.DataFrame]:
     logging.info(f"Selected modalities: {selected_modalities}")
+    if not modality_mapping:
+        raise ValueError("Modality mapping is empty or None.")
+    if not selected_modalities or selected_modalities == ["all"]:
+        logging.info("No specific modalities selected. Using all columns.")
+        return data, vars
     selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping.keys()]
     if not any(col in modality_mapping.keys() for col in selected_modalities):
         raise ValueError("None of the selected modalities found in modality mapping.")
-    if selected_columns == []:
-        logging.info("No columns selected. Using all columns.")
-        return data, vars
+    if not selected_columns:
+        raise ValueError("No columns were selected based on the provided modalities.")
     selected_columns = sum(selected_columns, [])
     selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]])
     # ... (rest of the function remains the same)

These changes provide more explicit handling of edge cases and clearer error messages.

🧰 Tools
🪛 Ruff

221-221: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


222-222: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


239-239: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


Line range hint 245-330: LGTM! Improved compatibility and code organization.

The changes to make_train_val function successfully add support for Polars DataFrames while maintaining compatibility with Pandas. The new helper functions _get_stays and _get_labels improve code organization and readability.

Suggestion: Consider updating the type hints to reflect the dual support for Pandas and Polars:

-def make_train_val(
-    data: dict[pd.DataFrame],
-    vars: dict[str],
-    train_size=0.8,
-    seed: int = 42,
-    debug: bool = False,
-    runmode: RunMode = RunMode.classification,
-    polars: bool = True,
-) -> dict[dict[pl.DataFrame]]:
+from typing import Union, Dict
+
+def make_train_val(
+    data: Dict[str, Union[pd.DataFrame, pl.DataFrame]],
+    vars: Dict[str, str],
+    train_size: float = 0.8,
+    seed: int = 42,
+    debug: bool = False,
+    runmode: RunMode = RunMode.classification,
+    polars: bool = True,
+) -> Dict[str, Dict[str, Union[pd.DataFrame, pl.DataFrame]]]:

This change more accurately represents the function's ability to work with both Pandas and Polars DataFrames.

🧰 Tools
🪛 Ruff

302-302: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


310-310: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


314-314: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


Line range hint 334-453: LGTM! Improved flexibility and error handling.

The changes to make_single_split function successfully add support for Polars DataFrames while maintaining Pandas compatibility. The enhanced error handling for classification tasks with insufficient samples is a valuable addition.

Suggestion: Consider extracting the data splitting logic into a separate function to improve readability and maintainability:

def _split_data(stays, labels, train_size, cv_repetitions, cv_folds, repetition_index, fold_index, seed, runmode):
    if Var.label in vars and runmode is RunMode.classification:
        if train_size:
            outer_cv = StratifiedShuffleSplit(cv_repetitions, train_size=train_size)
        else:
            outer_cv = StratifiedKFold(cv_repetitions, shuffle=True, random_state=seed)
        inner_cv = StratifiedKFold(cv_folds, shuffle=True, random_state=seed)
    else:
        if train_size:
            outer_cv = ShuffleSplit(cv_repetitions, train_size=train_size)
        else:
            outer_cv = KFold(cv_repetitions, shuffle=True, random_state=seed)
        inner_cv = KFold(cv_folds, shuffle=True, random_state=seed)

    dev, test = list(outer_cv.split(stays, labels))[repetition_index]
    dev_stays = stays[dev] if isinstance(stays, pl.Series) else stays.iloc[dev]
    dev_labels = labels[dev] if isinstance(labels, pl.Series) else labels.iloc[dev]
    train, val = list(inner_cv.split(dev_stays, dev_labels))[fold_index]

    return dev, test, train, val, dev_stays

# Use this function in make_single_split
dev, test, train, val, dev_stays = _split_data(stays, labels, train_size, cv_repetitions, cv_folds, repetition_index, fold_index, seed, runmode)

This refactoring would make the make_single_split function more concise and easier to maintain.


Line range hint 1-465: Great job on the Polars transition and new features!

The changes in this file significantly improve performance through the use of Polars while maintaining backward compatibility with Pandas. The new features, such as modality selection and improved data sanitization, add valuable flexibility to the data processing pipeline.

General suggestions for further improvement:

  1. Consider adding type hints consistently throughout the file, especially for the new Polars-related code.

  2. Add more comprehensive docstrings to the new functions (check_sanitize_data, modality_selection) explaining their parameters, return values, and any raised exceptions.

  3. Consider creating a configuration object or using a library like pydantic to manage and validate the various options and parameters used throughout the file.

  4. Add unit tests for the new functions and the updated logic to ensure everything works correctly with both Polars and Pandas.

  5. Consider adding a feature flag or configuration option to completely switch to Polars in the future, removing the Pandas code paths for simplicity once the transition is complete.

Overall, this is a solid improvement to the data processing pipeline. Great work!

🧰 Tools
🪛 Ruff

302-302: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


310-310: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


314-314: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/models/wrappers.py (3)

33-35: LGTM: New metrics configured for gin.

The addition of average_precision_score and roc_auc_score to the gin configuration is a good improvement for model evaluation.

Consider addressing the commented-out line.

There's a commented-out line for scorer_wrapper. If this is no longer needed, consider removing it to keep the code clean.


491-505: LGTM: New methods for model explanation and output saving.

The addition of _explain_model and _save_model_outputs methods enhances the class's capabilities for model interpretability and debugging.

Minor formatting issue in f-string.

On line 503, there's an unnecessary f-string:

logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}")

The inner f-string can be removed:

logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}")
🧰 Tools
🪛 Ruff

503-503: f-string without any placeholders

Remove extraneous f prefix

(F541)


Line range hint 1-530: Overall assessment: Significant improvements with minor issues.

This update to icu_benchmarks/models/wrappers.py introduces several valuable enhancements:

  1. New metrics and improved metric handling
  2. Enhanced debugging and model explanation capabilities
  3. More flexible weight setting for handling imbalanced datasets

There are a few minor issues to address:

  1. Unused row_indicators parameter in multiple methods
  2. Unused loss_weight attribute in the MLWrapper class
  3. A minor formatting issue with an f-string

These changes significantly improve the functionality and flexibility of the wrapper classes. Addressing the minor issues will further refine the code quality.

🧰 Tools
🪛 Ruff

503-503: f-string without any placeholders

Remove extraneous f prefix

(F541)

icu_benchmarks/data/preprocessor.py (7)

50-145: LGTM! Comprehensive Polars-based classification preprocessor implementation.

The PolarsClassificationPreprocessor class provides a robust implementation for preprocessing classification data using Polars. It handles both static and dynamic features, applies scaling, and generates additional features when required. The implementation is well-structured and covers various preprocessing scenarios.

Consider optimizing the join operations in the apply method. Instead of performing separate joins for each split (train, val, test), you could create a function to handle the join operation and apply it to all splits. This would reduce code duplication and improve maintainability. For example:

def join_static_dynamic(dynamic, static, group_var):
    return dynamic.join(static, on=group_var)

# Then in the apply method:
for split in [Split.train, Split.val, Split.test]:
    data[split][Segment.dynamic] = join_static_dynamic(
        data[split][Segment.dynamic],
        data[split][Segment.static],
        vars["GROUP"]
    )

147-200: LGTM! Comprehensive static and dynamic data preprocessing.

The _process_static and _process_dynamic methods provide thorough preprocessing for static and dynamic data using Polars. They handle scaling, imputation, and feature generation effectively. The implementation follows best practices for data preprocessing.

Consider adding error handling for cases where the required columns might be missing in the data. For example, in the _process_dynamic method:

def _process_dynamic(self, data, vars):
    required_columns = set(vars[Segment.dynamic] + [vars["GROUP"], vars["SEQUENCE"]])
    missing_columns = required_columns - set(data[Split.train][Segment.dynamic].columns)
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")
    
    # Rest of the method implementation...

This addition would make the code more robust against potential data inconsistencies.


218-281: LGTM! Well-implemented regression preprocessor with outcome scaling.

The PolarsRegressionPreprocessor class effectively extends the classification preprocessor to handle regression tasks. It includes outcome processing with scaling based on predefined or data-driven ranges, which is crucial for regression problems.

Consider adding an option to skip outcome scaling entirely. This could be useful in cases where the outcome is already in the desired range or when preserving the original scale is important. You could add a scale_outcome parameter to the constructor:

def __init__(
    self,
    generate_features: bool = False,
    scaling: bool = True,
    use_static_features: bool = True,
    outcome_max=None,
    outcome_min=None,
    scale_outcome: bool = True,
    save_cache=None,
    load_cache=None,
):
    # ... existing code ...
    self.scale_outcome = scale_outcome

def _process_outcome(self, data, vars, split):
    if not self.scale_outcome:
        return data
    # ... existing scaling code ...

This addition would provide more flexibility in handling different regression scenarios.


Line range hint 284-480: LGTM! Clear renaming and consistency in Pandas-based preprocessors.

The renaming of Pandas-based preprocessor classes improves clarity and consistency with the new Polars-based classes. The structure remains similar, which is good for code organization and maintainability.

While the current implementation maintains separate Polars and Pandas classes, consider for future refactoring:

  1. Extracting common preprocessing logic into a shared base class or utility functions.
  2. Using a strategy pattern to switch between Polars and Pandas implementations.

This could reduce code duplication and make it easier to maintain both versions in the long run. For example:

class BasePreprocessor(abc.ABC):
    @abc.abstractmethod
    def _process_static(self, data, vars):
        pass

    @abc.abstractmethod
    def _process_dynamic(self, data, vars):
        pass

class PolarsPreprocessor(BasePreprocessor):
    # Polars-specific implementations

class PandasPreprocessor(BasePreprocessor):
    # Pandas-specific implementations

class ClassificationPreprocessor:
    def __init__(self, backend='polars'):
        self.preprocessor = PolarsPreprocessor() if backend == 'polars' else PandasPreprocessor()

    def apply(self, data, vars):
        return self.preprocessor.apply(data, vars)

This structure would allow easier maintenance and extension of both Polars and Pandas implementations.


Line range hint 481-538: LGTM! Well-implemented imputation preprocessor with customization options.

The PandasImputationPreprocessor class provides a specialized preprocessor for imputation tasks with good flexibility through its customization options. The implementation is correct and consistent with the established patterns in the file.

Consider enhancing the logging in the _process_dynamic_data method to provide more detailed information about the removed data. This could be helpful for debugging and monitoring the preprocessing process. For example:

def _process_dynamic_data(self, data, vars):
    if self.filter_missing_values:
        rows_to_remove = data[Segment.dynamic][vars[Segment.dynamic]].isna().sum(axis=1) != 0
        ids_to_remove = data[Segment.dynamic].loc[rows_to_remove][vars["GROUP"]].unique()
        original_size = {table_name: len(table) for table_name, table in data.items()}
        data = {table_name: table.loc[~table[vars["GROUP"]].isin(ids_to_remove)] for table_name, table in data.items()}
        removed_size = {table_name: original_size[table_name] - len(table) for table_name, table in data.items()}
        logging.info(f"Removed {len(ids_to_remove)} stays with missing values.")
        logging.info(f"Rows removed per table: {removed_size}")
    return data

This addition would provide more detailed information about the impact of filtering missing values on each data segment.


Line range hint 541-594: LGTM! Robust utility functions for recipe handling with multi-library support.

The utility functions for applying recipes to splits, caching recipes, and restoring recipes are well-implemented and now support both Polars and Pandas DataFrames. This flexibility is valuable for supporting different data processing libraries.

Consider enhancing the error handling in the restore_recipe function to provide more informative error messages. For example:

def restore_recipe(cache_file: str) -> Recipe:
    """Restore recipe from cache to use for e.g. transfer learning."""
    try:
        if not cache_file.exists():
            raise FileNotFoundError(f"Cache file {cache_file} not found.")
        
        with open(cache_file, "rb") as f:
            logging.info(f"Loading cached recipe from {cache_file}.")
            recipe = pickle.load(f)
            return recipe
    except (pickle.UnpicklingError, EOFError) as e:
        raise ValueError(f"Error unpickling recipe from {cache_file}: {str(e)}") from e
    except Exception as e:
        raise RuntimeError(f"Unexpected error while restoring recipe from {cache_file}: {str(e)}") from e

This enhancement would provide more specific error messages, making it easier to diagnose issues with recipe restoration.


Line range hint 1-594: Excellent enhancements to the preprocessing module with multi-library support.

The changes to icu_benchmarks/data/preprocessor.py significantly improve the flexibility and capabilities of the preprocessing module. Key improvements include:

  1. Addition of Polars support alongside Pandas, potentially offering performance benefits.
  2. New preprocessor classes for classification, regression, and imputation tasks using both Polars and Pandas.
  3. Consistent code structure and naming conventions across different preprocessor types.
  4. Enhanced utility functions supporting multiple DataFrame types.

These changes provide a robust foundation for handling various preprocessing scenarios in the ICU benchmarks project. The multi-library support allows for easy comparison and selection of the most suitable data processing approach for different use cases.

For future development, consider:

  1. Implementing benchmarks to compare the performance of Polars vs. Pandas implementations for typical preprocessing tasks in your project.
  2. Creating a factory pattern for preprocessor creation to simplify the selection between Polars and Pandas implementations based on configuration or data characteristics.
  3. Developing a comprehensive suite of unit tests to ensure the consistency of results between Polars and Pandas implementations.

These steps would further enhance the robustness and usability of the preprocessing module.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Files that changed from the base of the PR and between 4ed54cf and 47c594a.

📒 Files selected for processing (15)
  • icu_benchmarks/cross_validation.py (6 hunks)
  • icu_benchmarks/data/loader.py (7 hunks)
  • icu_benchmarks/data/pooling.py (7 hunks)
  • icu_benchmarks/data/preprocessor.py (8 hunks)
  • icu_benchmarks/data/split_process_data.py (10 hunks)
  • icu_benchmarks/models/custom_metrics.py (2 hunks)
  • icu_benchmarks/models/dl_models/rnn.py (1 hunks)
  • icu_benchmarks/models/dl_models/tcn.py (1 hunks)
  • icu_benchmarks/models/dl_models/transformer.py (1 hunks)
  • icu_benchmarks/models/ml_models/catboost.py (1 hunks)
  • icu_benchmarks/models/ml_models/xgboost.py (1 hunks)
  • icu_benchmarks/models/wrappers.py (11 hunks)
  • icu_benchmarks/run_utils.py (6 hunks)
  • icu_benchmarks/tuning/hyperparameters.py (8 hunks)
  • requirements.txt (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
  • icu_benchmarks/cross_validation.py
  • icu_benchmarks/data/pooling.py
  • icu_benchmarks/models/dl_models/tcn.py
  • icu_benchmarks/models/ml_models/catboost.py
  • requirements.txt
🧰 Additional context used
🪛 Ruff
icu_benchmarks/data/split_process_data.py

27-27: Do not use mutable data structures for argument defaults

Replace with None; initialize within function

(B006)


122-122: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


127-127: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


127-127: Test for membership should be not in

Convert to not in

(E713)


195-195: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


196-196: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


198-198: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


202-202: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


206-206: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


221-221: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


222-222: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


239-239: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


302-302: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


310-310: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


314-314: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


420-423: Use ternary operator dev_stays = stays[dev] if polars else stays.iloc[dev] instead of if-else-block

Replace if-else-block with dev_stays = stays[dev] if polars else stays.iloc[dev]

(SIM108)


439-439: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


447-447: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


451-451: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/models/dl_models/transformer.py

47-47: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


62-62: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


107-107: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


124-124: Loop control variable i not used within loop body

Rename unused i to _i

(B007)

icu_benchmarks/models/ml_models/xgboost.py

27-27: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


76-76: f-string without any placeholders

Remove extraneous f prefix

(F541)

icu_benchmarks/models/wrappers.py

503-503: f-string without any placeholders

Remove extraneous f prefix

(F541)

icu_benchmarks/run_utils.py

18-18: os imported but unused

Remove unused import: os

(F401)


19-19: glob imported but unused

Remove unused import: glob

(F401)

🔇 Additional comments (23)
icu_benchmarks/models/dl_models/rnn.py (1)

1-84: Overall, excellent implementation of RNN, LSTM, and GRU models.

This file provides well-structured and correct implementations of three recurrent neural network architectures: RNN, LSTM, and GRU. The use of @gin.configurable for all classes allows for flexible configuration, and the support for both classification and regression run modes adds versatility to the models.

Minor suggestions for improvement across all classes:

  1. Add type hints to method parameters and return values for improved code readability and maintainability.
  2. Consider adding docstrings to methods for better documentation.

Great job on addressing the issues mentioned in previous reviews and creating a solid foundation for these RNN-based models.

icu_benchmarks/models/dl_models/transformer.py (2)

9-18: LGTM: BaseTransformer class implementation

The BaseTransformer class is well-structured and implements common functionality for transformer-based models. The forward method correctly follows the standard transformer architecture flow, promoting code reuse and maintainability.


1-140: Overall assessment: Good implementation with room for improvement

The implementation of transformer-based models for ICU benchmarks is correct and follows the HiRID-Benchmark specifications. The code structure is generally good, with clear separation of concerns between the BaseTransformer, Transformer, and LocalTransformer classes.

However, there are opportunities for improvement:

  1. Refactoring the initialization process to use configuration objects can enhance readability and maintainability.
  2. Reducing code duplication between Transformer and LocalTransformer classes can improve the overall code structure.
  3. Minor improvements like removing unused loop variables can be made.

Implementing the suggested refactoring will result in a more robust and maintainable codebase while preserving the current functionality.

🧰 Tools
🪛 Ruff

47-47: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


62-62: Loop control variable i not used within loop body

Rename unused i to _i

(B007)


107-107: Star-arg unpacking after a keyword argument is strongly discouraged

(B026)


124-124: Loop control variable i not used within loop body

Rename unused i to _i

(B007)

icu_benchmarks/models/ml_models/xgboost.py (3)

1-25: LGTM: Imports and class definition are well-structured.

The imports are appropriate for the XGBoost classifier implementation. The use of @gin.configurable decorator allows for flexible configuration, and the _supported_run_modes attribute correctly specifies the classification run mode.


30-40: LGTM: predict method is well-implemented and documented.

The predict method correctly uses the model's predict_proba method to return class probabilities. The docstring provides clear information about the method's functionality, inputs, and outputs.


96-99: LGTM: get_feature_importance method is well-implemented.

The get_feature_importance method correctly retrieves feature importances from the trained model. It includes a proper check to ensure the model has been fit before attempting to access the feature importances. The error message is clear and informative, guiding the user to call fit_model() if needed.

icu_benchmarks/models/custom_metrics.py (2)

5-6: LGTM: New imports are appropriate and well-structured.

The addition of ndarray from numpy for type hinting and the alias for confusion_matrix from sklearn.metrics are good practices. This helps improve code readability and avoids potential naming conflicts.


136-145: ⚠️ Potential issue

Function implementation needs refinement.

The confusion_matrix function has a few issues that need to be addressed:

  1. The return type in the function signature (torch.tensor) doesn't match the actual return type (dictionary).
  2. There's no input validation, which could lead to runtime errors if incorrect data types are provided.
  3. The function doesn't handle the case where y_true and y_pred have different shapes.

Here's a suggested improvement:

def confusion_matrix(y_true: ndarray, y_pred: ndarray, normalize: bool = False) -> dict:
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape")
    
    y_pred = np.rint(y_pred).astype(int)
    confusion = sk_confusion_matrix(y_true, y_pred)
    
    if normalize:
        confusion = confusion.astype('float') / confusion.sum(axis=1)[:, np.newaxis]
    
    confusion_dict = {f"class_{i}_pred_{j}": confusion[i][j] 
                      for i in range(confusion.shape[0]) 
                      for j in range(confusion.shape[1])}
    
    return confusion_dict

This implementation:

  • Corrects the return type annotation
  • Adds basic input validation
  • Uses a more efficient dictionary comprehension
  • Improves normalization to avoid division by zero

To ensure the function is used correctly throughout the codebase, you can run:

This will help identify any places where the function is called, allowing you to verify that the changes won't break existing code.

✅ Verification successful

Confusion Matrix Function Return Type Mismatch.

  • The confusion_matrix function's return type annotation (torch.tensor) does not match its actual return type (dict).
  • It is only used in icu_benchmarks/models/wrappers.py, where the returned dictionary is passed to self.log_dict(), which expects a dictionary.
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for uses of the confusion_matrix function
rg "confusion_matrix\(" --type python

Length of output: 70


Script:

#!/bin/bash
# Search for uses of the confusion_matrix function
rg "confusion_matrix\(" --type py

Length of output: 408


Script:

#!/bin/bash
# Get the lines surrounding the confusion_matrix function call in wrappers.py
rg "confusion_matrix\(" --context 2 icu_benchmarks/models/wrappers.py

Length of output: 315

icu_benchmarks/run_utils.py (1)

Line range hint 1-301: Overall, excellent enhancements to the run_utils.py file

The changes made to this file significantly improve the functionality and flexibility of the ICU benchmarking framework. Key improvements include:

  1. Enhanced command-line argument parsing with support for modalities and labels.
  2. Improved directory creation logic to prevent conflicts.
  3. Addition of SHAP value aggregation in the results processing.
  4. New utility functions for discovering configuration files and validating required keys.

These changes make the framework more robust and adaptable to different scenarios. The code is generally well-structured and readable, with only minor suggestions for improvement provided in the previous comments.

Great job on these enhancements! The added functionality will greatly benefit users of this framework.

icu_benchmarks/tuning/hyperparameters.py (3)

5-5: LGTM: New imports for Optuna and visualization.

The added imports for matplotlib.pyplot and Optuna visualization functions are appropriate for the new Optuna-based implementation and plotting capabilities.

Also applies to: 18-18


25-25: LGTM: Proper deprecation of the scikit-optimize function.

The renaming of choose_and_bind_hyperparameters to choose_and_bind_hyperparameters_scikit_optimize and the addition of a deprecation warning are good practices. This approach maintains backwards compatibility while clearly signaling the transition to the new Optuna-based implementation.

Also applies to: 66-67


Line range hint 1-383: Overall: Significant improvement with transition to Optuna.

The transition from scikit-optimize to Optuna for hyperparameter tuning is a substantial improvement. Optuna offers more advanced tuning capabilities and better visualization options. The overall structure of the code is good, incorporating the new Optuna-based approach effectively.

Key improvements:

  1. Advanced hyperparameter tuning with Optuna.
  2. Addition of visualization capabilities.
  3. Proper deprecation of the old implementation.

Areas for further improvement:

  1. Enhanced documentation, especially for new parameters.
  2. Refactoring of some complex sections for better clarity and maintainability.
  3. Expansion of plotting capabilities.

Overall, these changes significantly enhance the hyperparameter tuning functionality of the project.

icu_benchmarks/data/loader.py (1)

10-10: Appropriate addition of Polars library import

The import of the Polars library is consistent with the transition from Pandas to Polars for data handling. This change aligns with the overall goal of improving performance and reducing memory usage.

icu_benchmarks/data/split_process_data.py (2)

3-15: LGTM! Transition to Polars for improved performance.

The import changes reflect the shift from Pandas to Polars for data processing, which should lead to improved performance. The addition of 'os' and 'timeit' imports suggests new file operations and performance measurement capabilities.


Line range hint 456-465: LGTM, but consider reviewing for Polars compatibility.

The caching function remains unchanged, which is fine if it's working as expected. However, given the transition to Polars in other parts of the code, it might be worth reviewing this function to ensure it's fully compatible with Polars DataFrames.

Consider testing the caching mechanism with Polars DataFrames to confirm that serialization and deserialization work correctly. If issues are found, you may need to update the pickling process or consider using a different serialization method that's optimized for Polars.

🧰 Tools
🪛 Ruff

420-423: Use ternary operator dev_stays = stays[dev] if polars else stays.iloc[dev] instead of if-else-block

Replace if-else-block with dev_stays = stays[dev] if polars else stays.iloc[dev]

(SIM108)


439-439: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


447-447: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)


451-451: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

icu_benchmarks/models/wrappers.py (7)

4-6: LGTM: New imports added.

The addition of these imports (Path, torchmetrics, and additional sklearn metrics) is appropriate for the changes made in the file.


19-25: LGTM: Import statements updated.

The addition of the confusion_matrix import and the change in the RunMode import location are consistent with the file's purpose and likely reflect changes in the project structure.


49-50: LGTM: New attributes added to BaseModule.

The addition of debug and explain_features attributes provides useful functionality for debugging and model interpretability.


67-74: LGTM: Enhanced set_weight method.

The updated set_weight method now supports dataset-specific weight balancing, which is a valuable improvement for handling imbalanced datasets.


410-411: Unused parameter persists.

The row_indicators parameter is still not used in the method body. This issue has been previously identified and remains unaddressed.


Line range hint 437-457: LGTM: Enhanced debugging and feature explanation.

The addition of debug output saving and feature explanation functionality in the test_step method is a valuable improvement for model interpretability and debugging.


476-487: LGTM: Improved metric logging.

The updates to the log_metrics method, including special handling for the Confusion_Matrix and refined metric logging logic, enhance the class's ability to track and report model performance.

icu_benchmarks/data/preprocessor.py (1)

9-9: LGTM! Enhanced flexibility with Polars integration and vars_to_exclude parameter.

The addition of Polars library import and the modification to the Preprocessor base class to include vars_to_exclude parameter provide more flexibility in data preprocessing. This change allows for better control over which variables are processed, which can be particularly useful in complex preprocessing scenarios.

Also applies to: 34-34

icu_benchmarks/models/dl_models/rnn.py Outdated Show resolved Hide resolved
icu_benchmarks/models/dl_models/transformer.py Outdated Show resolved Hide resolved
icu_benchmarks/models/dl_models/transformer.py Outdated Show resolved Hide resolved
icu_benchmarks/models/ml_models/xgboost.py Outdated Show resolved Hide resolved
icu_benchmarks/models/ml_models/xgboost.py Outdated Show resolved Hide resolved
icu_benchmarks/data/loader.py Show resolved Hide resolved
icu_benchmarks/data/loader.py Show resolved Hide resolved
icu_benchmarks/data/loader.py Show resolved Hide resolved
icu_benchmarks/data/loader.py Show resolved Hide resolved
icu_benchmarks/models/wrappers.py Show resolved Hide resolved
@rvandewater rvandewater merged commit 4dd915e into development Oct 17, 2024
1 of 2 checks passed
@rvandewater rvandewater deleted the version_release branch October 17, 2024 15:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant