Version release based on practice Cassandra use #155
Conversation
# Conflicts: # environment.yml
…ary for ML models
Caution: Review failed. The pull request is closed.

Walkthrough

The changes in this pull request introduce new configuration files for various classifiers, including Balanced Random Forest, CatBoost, and XGBoost, along with modifications to existing configurations for models like GRU, Random Forest, and Transformer. Additionally, significant updates were made to support data handling using the Polars library, enhancing data loading and preprocessing functionalities. New scripts for benchmarking and job scheduling on HPC clusters were added, while several files were created for new deep learning architectures, including RNNs and Transformers. Various updates were also made to enhance hyperparameter tuning and model training processes.
Caution: Inline review comments failed to post.
🛑 Comments failed to post (80)
environment.yml (1)
5-5: 🛠️ Refactor suggestion
Consider upgrading to a newer Python version
The environment is currently using Python 3.10. While this is still a supported version, newer versions (like 3.11 or 3.12) offer performance improvements and new features.
Consider upgrading to a newer Python version if your project doesn't have specific dependencies requiring Python 3.10. This could potentially improve your application's performance and allow you to use newer language features.
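A minimal sketch of the suggested bump, assuming the interpreter is pinned in the dependencies list of environment.yml (the exact line and pin format may differ):

```diff
 dependencies:
-  - python=3.10
+  - python=3.11
```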
experiments/charhpc_wandb_sweep.sh (2)
10-13: 🛠️ Refactor suggestion
Simplify Conda initialization.
The current Conda initialization uses both `source` and `eval` methods, which might be redundant. You can simplify this by using only one method. Consider updating the Conda initialization as follows:

```diff
-source /etc/profile.d/conda.sh
-
-eval "$(conda shell.bash hook)"
+. /etc/profile.d/conda.sh
```

This change uses the `.` command (equivalent to `source`) to load the Conda profile, which should be sufficient for initializing Conda in most environments.

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```bash
. /etc/profile.d/conda.sh
conda activate yaib_req_pl
```
14-14: 🛠️ Refactor suggestion
Add error handling for the sweep ID argument.
The WandB agent command looks good, but it assumes that the sweep ID will always be provided as an argument. It's a good practice to add error handling to ensure the script fails gracefully if the argument is missing.
Consider adding a check for the command-line argument before running the WandB agent:
```diff
+if [ $# -eq 0 ]; then
+  echo "Error: Sweep ID not provided"
+  echo "Usage: $0 <sweep_id>"
+  exit 1
+fi
+
 wandb agent --count 1 cassandra_hpi/cassandra/"$1"
```

This change will print an error message and exit if no argument is provided, making the script more robust.
📝 Committable suggestion

```bash
if [ $# -eq 0 ]; then
  echo "Error: Sweep ID not provided"
  echo "Usage: $0 <sweep_id>"
  exit 1
fi

wandb agent --count 1 cassandra_hpi/cassandra/"$1"
```
configs/tasks/DatasetImputation.gin (1)
26-26: 💡 Codebase verification
Potential conflicts detected in `Dataloader.gin`

`ImputationPandasDataset.vars` and `ImputationPandasDataset.ram_cache` are still defined in `configs/tasks/common/Dataloader.gin`. These settings may conflict with previous configurations where `ImputationDataset.vars` and `ImputationDataset.ram_cache` were removed or altered.

Recommended actions:
- Review the `Dataloader.gin` file to determine if the `ImputationPandasDataset` configurations should be updated or removed to align with the intended changes.
- Ensure that all related configurations across the codebase are consistent with the updated data loading approach.
🔗 Analysis chain

Confirm the contents and impact of the new Dataloader configuration.

The addition of `include "configs/tasks/common/Dataloader.gin"` suggests a move towards a more standardized approach for data loading. This change could improve consistency across different tasks or datasets. To ensure this change is implemented correctly and doesn't introduce any issues, please run the following scripts.

Please review the output to ensure that:
- The Dataloader.gin file exists and contains the expected configurations.
- Other relevant files are also updated to use this new Dataloader configuration.
- There are no conflicts with existing configurations, especially regarding the previously removed `ImputationDataset.vars` and `ImputationDataset.ram_cache`.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

```bash
#!/bin/bash
# Description: Verify the contents and usage of the new Dataloader configuration

echo "Contents of configs/tasks/common/Dataloader.gin:"
cat configs/tasks/common/Dataloader.gin

echo "Checking for other files that include Dataloader.gin:"
rg 'include.*"configs/tasks/common/Dataloader.gin"' --type gin

echo "Checking for potential conflicts or overrides with existing configurations:"
rg "ImputationDataset|vars|ram_cache" configs/tasks/common/Dataloader.gin
```

Length of output: 1183

Script:

```bash
#!/bin/bash
# Description: Verify the contents and usage of the new Dataloader configuration

echo "Contents of configs/tasks/common/Dataloader.gin:"
cat configs/tasks/common/Dataloader.gin

echo "Checking for other files that include Dataloader.gin:"
rg 'include\s+"configs/tasks/common/Dataloader\.gin"' --ignore-case

echo "Checking for potential conflicts or overrides with existing configurations:"
rg "ImputationDataset|vars|ram_cache" configs/tasks/common/Dataloader.gin
```

Length of output: 1390
docs/adding_model/RNN.gin (1)
6-7: 💡 Codebase verification

Issue found: `RNNet` class definition missing.

The `@RNNet` model is specified in the train params, but the `RNNet` class definition was not found in the codebase.

🔗 Analysis chain

LGTM: clear train params configuration. Verify the `@RNNet` definition.

The train params configuration is concise and clear. However, it's important to ensure that `@RNNet` is properly defined elsewhere in the codebase. Run the following script to verify the `RNNet` class definition:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

```bash
#!/bin/bash
# Description: Verify the existence and definition of the RNNet class.

# Test: Search for the RNNet class definition
rg --type python -A 5 'class RNNet'
```

Length of output: 1678
configs/prediction_models/RFClassifier.gin (1)
16-16: 💡 Codebase verification

Inconsistent tuple syntax found in configuration files.

The following files still use tuple syntax for `class_weight` and should be updated to use list syntax for consistency:
- configs/prediction_models/RFClassifier.gin
- configs/prediction_models/BRFClassifier.gin

🔗 Analysis chain
LGTM! Verify consistency across all configuration files.
The change from tuple to list syntax for `class_weight` is appropriate and consistent with the changes made to other hyperparameters.
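For illustration, a hypothetical before/after of the tuple-to-list change (the exact values in the .gin files may differ):

```diff
-model/hyperparameter.class_weight = (None, "balanced")
+model/hyperparameter.class_weight = [None, "balanced"]
```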
To ensure consistency across the entire project, please run the following script to check if all similar configuration files have been updated with the new list syntax:
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

```bash
#!/bin/bash
# Description: Check for inconsistent syntax in gin configuration files

# Test: Search for tuple syntax in hyperparameter definitions
rg --type gin -g 'configs/prediction_models/*.gin' '(True, False)|(None,)|(None, .+)' -l

# If the above command returns any results, it indicates files that still use
# tuple syntax and may need updating.
```

Length of output: 118

Script:

```bash
#!/bin/bash
# Description: Check for inconsistent tuple syntax in .gin configuration files

# Test: Search for tuple syntax in class_weight hyperparameter definitions within .gin files
rg '(True, False)|(None,)|(None, .+)' -g '*.gin' -l
```

Length of output: 140
configs/prediction_models/Transformer.gin (1)
22-22: 🛠️ Refactor suggestion
⚠️ Potential issue

Consider adjusting the upper bound for attention dropout.

The attention dropout range has been expanded from (0.0, 0.4) to (0.0, 1.0). An upper bound of 1.0 for dropout is extremely high and could degrade model performance significantly. Consider adjusting the upper bound to a more reasonable value, such as 0.5:

```diff
-model/hyperparameter.dropout_att = (0.0, 1.0)
+model/hyperparameter.dropout_att = (0.0, 0.5)
```

📝 Committable suggestion

```gin
model/hyperparameter.dropout_att = (0.0, 0.5)
```
configs/prediction_models/XGBClassifier.gin (1)
9-17: 🛠️ Refactor suggestion
Consider narrowing the range for `scale_pos_weight`.

The hyperparameter settings look good overall and cover important parameters for XGBoost tuning. However, the range for `scale_pos_weight` is very wide (1 to 1000). This might lead to suboptimal results or unnecessarily long tuning times. Consider narrowing the range based on your dataset's class imbalance ratio. A typical approach is to set it close to the ratio of negative to positive instances. For example, if you have a 1:10 imbalance, you might use:

```diff
-model/hyperparameter.scale_pos_weight = [1, 5, 10, 15, 20, 25, 30, 35, 40, 50, 75, 99, 100, 1000]
+model/hyperparameter.scale_pos_weight = [1, 5, 10, 15, 20]
```

This change would focus the search on more relevant values and potentially speed up the tuning process.
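If helpful, the ratio can be computed directly from the training labels. A minimal sketch (the label array here is hypothetical; use the real training split):

```python
import numpy as np

# Hypothetical binary labels; in practice use the training split's label column
train_labels = np.array([0] * 90 + [1] * 10)
neg, pos = int((train_labels == 0).sum()), int((train_labels == 1).sum())
scale_pos_weight = neg / pos  # 9.0 here; roughly the negative:positive ratio
print(scale_pos_weight)
```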
📝 Committable suggestion

```gin
model/hyperparameter.class_to_tune = @XGBClassifier
model/hyperparameter.learning_rate = (0.01, 0.1, "log")
model/hyperparameter.n_estimators = [50, 100, 250, 500, 750, 1000, 1500, 2000]
model/hyperparameter.max_depth = [3, 5, 10, 15]
model/hyperparameter.scale_pos_weight = [1, 5, 10, 15, 20]
model/hyperparameter.min_child_weight = [1, 0.5]
model/hyperparameter.max_delta_step = [0, 1, 2, 3, 4, 5, 10]
model/hyperparameter.colsample_bytree = [0.1, 0.25, 0.5, 0.75, 1.0]
model/hyperparameter.eval_metric = "aucpr"
```
icu_benchmarks/models/ml_models/catboost.py (1)
14-17: 🛠️ Refactor suggestion
⚠️ Potential issue

Consider clarifying model initialization and addressing potential issues.

The commented-out line suggests GPU support was considered. If GPU support is planned for the future, consider adding a TODO comment explaining this.

The use of `set_model_args` is good for flexibility, but it's unclear what arguments are being passed. Consider adding a comment explaining the expected arguments or their purpose.

The static analysis tool flagged a potential issue with star-arg unpacking after a keyword argument on line 16. This could lead to unexpected behavior. Consider refactoring to avoid this issue:

```python
def __init__(self, *args, **kwargs):
    model_args = self.set_model_args(*args, **kwargs)
    self.model = cb.CatBoostClassifier(task_type="CPU", **model_args)
    super().__init__(*args, **kwargs)
```

This refactoring separates the argument preparation from the model initialization, which may be clearer and avoids the star-arg unpacking issue.

🧰 Tools
🪛 Ruff

16-16: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
configs/tasks/common/PredictionTaskVariables.gin (1)
20-25: 🛠️ Refactor suggestion
Consider the necessity and potential redundancy of the `modality_mapping` section.

The newly added `modality_mapping` section appears to be an exact duplicate of the "DYNAMIC" and "STATIC" lists in the `vars` section. While this addition might serve a specific purpose, it's worth considering the following points:
- Redundancy: Having identical information in two places can lead to maintenance issues if one section is updated and the other is forgotten.
- Purpose: The purpose of `modality_mapping` is not immediately clear from the context. It would be helpful to add a comment explaining its specific use case and how it differs from `vars`.
- DRY principle: To adhere to the Don't Repeat Yourself (DRY) principle, consider refactoring this configuration to avoid duplication.

Consider one of the following approaches:
- If `modality_mapping` serves a distinct purpose, add a comment explaining its role and how it differs from `vars`.
- If `modality_mapping` is intended to replace `vars`, update all references to use the new structure and remove the old one.
- If the two structures are meant to be identical, consider using a single source of truth:

```gin
dynamic_vars = ["alb", "alp", "alt", "ast", "be", "bicar", "bili", "bili_dir", "bnd", "bun", "ca", "cai", "ck", "ckmb", "cl", "crea", "crp", "dbp", "fgn", "fio2", "glu", "hgb", "hr", "inr_pt", "k", "lact", "lymph", "map", "mch", "mchc", "mcv", "methb", "mg", "na", "neut", "o2sat", "pco2", "ph", "phos", "plt", "po2", "ptt", "resp", "sbp", "temp", "tnt", "urine", "wbc"]
static_vars = ["age", "sex", "height", "weight"]

vars = {
    "GROUP": "stay_id",
    "SEQUENCE": "time",
    "LABEL": "label",
    "DYNAMIC": dynamic_vars,
    "STATIC": static_vars,
}

modality_mapping = {
    "DYNAMIC": dynamic_vars,
    "STATIC": static_vars,
}
```

This approach ensures that any updates to the variable lists are automatically reflected in both structures.
experiments/slurm_base_char_sc.sh (4)
22-22: ⚠️ Potential issue

Fix typo and address the undefined DATASETS array.

- There's a typo in the "echo" command:

```diff
-echi "Task type:" ${TASK}
+echo "Task type:" ${TASK}
```

- The DATASETS array is used but not defined. Either define it before use or replace it with the appropriate variable:

```diff
-echo "Dataset: "${DATASETS[$SLURM_ARRAY_TASK_ID]}
+echo "Dataset: ${DATASET_ROOT_PATH}"
```

If you intended to use multiple datasets, make sure to define the DATASETS array before this line.
Also applies to: 25-25
18-18: ⚠️ Potential issue

Remove the space after '=' in the DATASET_ROOT_PATH assignment.

The space after the '=' in the DATASET_ROOT_PATH assignment can cause issues in Bash. Remove it to ensure proper variable assignment:

```diff
-DATASET_ROOT_PATH= /sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format
+DATASET_ROOT_PATH=/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format
```

📝 Committable suggestion

```bash
DATASET_ROOT_PATH=/sc-projects/sc-proj-cc08-cassandra/Prospective_Preprocessed/yaib_format #data/YAIB_Datasets/data
```

🧰 Tools
🪛 Shellcheck

[warning] 18-18: Remove space after = if trying to assign a value (for empty string, use var=''). (SC1007)
28-28: 🛠️ Refactor suggestion

Improve error handling for the directory change.

Add error handling to the `cd` command to exit the script if the directory change fails:

```diff
-cd ${YAIB_PATH}
+cd ${YAIB_PATH} || { echo "Failed to change directory to ${YAIB_PATH}"; exit 1; }
```

This ensures that the script doesn't continue executing in the wrong directory if the `cd` command fails.

📝 Committable suggestion

```bash
cd ${YAIB_PATH} || { echo "Failed to change directory to ${YAIB_PATH}"; exit 1; }
```

🧰 Tools
🪛 Shellcheck

[warning] 28-28: Use 'cd ... || exit' or 'cd ... || return' in case cd fails. (SC2164)
35-44: ⚠️ Potential issue

Address undefined variables and document the --tune flag.

- The ${DATASETS} variable is used but not defined earlier in the script. Either define it or replace it with the appropriate variable:

```diff
-    -n ${DATASETS} \
+    -n ${DATASET_ROOT_PATH} \
```

- Consider adding a comment to explain the purpose of the --tune flag, e.g.:

```diff
+    # Enable hyperparameter tuning
     --tune
```

Ensure that all variables used in this command are properly defined earlier in the script.

Committable suggestion was skipped due to low confidence.
icu_benchmarks/models/constants.py (1)
48-50: 💡 Codebase verification

Concerns about removing `average_precision_score`:
- `average_precision_score` is still imported and used in multiple files:
  - icu_benchmarks/models/wrappers.py
  - icu_benchmarks/models/utils.py
  - icu_benchmarks/models/constants.py
- Commenting it out or removing it from constants.py may adversely affect the modules that rely on this metric.

Could you please confirm whether the removal of `average_precision_score` is intentional and ensure that it does not impact existing functionality?

🔗 Analysis chain

Clarify the removal of average_precision_score and approve the addition of confusion_matrix.

The addition of `confusion_matrix` to the `MULTICLASS_CLASSIFICATION` dictionary is appropriate and valuable for evaluating multiclass classification models.

However, `average_precision_score` has been commented out. This might impact the calculation of precision-recall metrics for multiclass scenarios. Could you please provide the rationale for removing this metric? If it's no longer needed, consider removing the commented line entirely to maintain code cleanliness.

To ensure that the removal of `average_precision_score` doesn't negatively impact any existing code, let's verify its usage:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

```bash
#!/bin/bash
# Description: Check for any remaining usage of average_precision_score in multiclass scenarios

# Test: Search for usage of average_precision_score in Python files
rg --type python 'average_precision_score.*multiclass'
```

Length of output: 87

Script:

```bash
#!/bin/bash
# Description: Check for any remaining usage of average_precision_score in multiclass scenarios

# Search for usage of average_precision_score in Python files related to multiclass classification
rg --glob '*.py' 'average_precision_score.*multiclass'
```

Length of output: 56

Script:

```bash
#!/bin/bash
# Description: Check for any remaining usage of average_precision_score in the codebase

# Search for usage of average_precision_score in all Python files
rg --glob '*.py' 'average_precision_score'
```

Length of output: 885
experiments/benchmark_cass.yml (1)
29-75: 🛠️ Refactor suggestion
Review and optimize parameter configurations
The parameters section is comprehensive, but there are several points to consider:
Data Directory:
- The current data directory path (line 32) is very specific. Consider using environment variables or a configuration file to make it more portable across different environments.
- There are many commented-out data directory options. Document the reasons for these alternatives or remove them if they're no longer relevant.
Model Selection:
- The current selection includes both traditional (XGBClassifier) and deep learning (GRU, Transformer) models. This is good for comparison, but ensure that the computational resources are sufficient for all these models.
Modalities:
- Currently set to "all" (line 59). Consider experimenting with specific combinations of modalities to understand their individual impacts on the model's performance.
Seed:
- The fixed seed (1111) ensures reproducibility, which is good. However, for a more robust evaluation, consider running multiple experiments with different seeds.
Pretrained Imputation:
- Set to None (line 75). Depending on your data quality, using pretrained imputation might improve model performance. Consider experimenting with this option.
To make the configuration more flexible and easier to maintain, consider the following changes:

- Use environment variables for paths:

```yaml
data_dir:
  values:
    - ${YAIB_DATA_DIR}/dataset_segment_1.0_horizon_6:00:00_transfer_3_2024-10-07T23:26:22
```

- Add a comment explaining the modalities choice:

```yaml
modalities:
  values:
    - "all" # Using all available modalities for comprehensive analysis
```

- Consider adding multiple seeds for robustness:

```yaml
seed:
  values:
    - 1111
    - 2222
    - 3333
```

- Experiment with pretrained imputation:

```yaml
use_pretrained_imputation:
  values:
    - None
    - True # Add this line to experiment with pretrained imputation
```

icu_benchmarks/models/custom_metrics.py (2)
137-150: 🛠️ Refactor suggestion
⚠️ Potential issue

Refine the confusion_matrix function implementation and signature.

The new `confusion_matrix` function is a valuable addition, but there are a few issues to address:
- The function signature indicates a return type of `torch.tensor`, but it actually returns a dictionary.
- There's an unused variable `confusion_tensor`.
- There are commented-out code blocks that may be unnecessary.

Please consider the following changes:

- Update the function signature to match the actual return type:

```python
def confusion_matrix(y_true: ndarray, y_pred: ndarray, normalize=False) -> dict:
```

- Remove the unused variable:

```diff
-    confusion_tensor = torch.tensor(confusion)
```

- Remove or uncomment the logging statement if it's needed:

```diff
-    # logging.info(f"Confusion matrix: {confusion_dict}")
```

- Consider removing the commented-out alternative dictionary creation if it's no longer needed:

```diff
-    # dict = {"TP": confusion[0][0], "FP": confusion[0][1], "FN": confusion[1][0], "TN": confusion[1][1]}
```

To improve readability and efficiency, consider using a dictionary comprehension:

```python
confusion_dict = {
    f"class_{i}_pred_{j}": confusion[i][j]
    for i in range(confusion.shape[0])
    for j in range(confusion.shape[1])
}
```

This change will make the code more concise and potentially more efficient.
🧰 Tools
🪛 Ruff

142-142: Local variable `confusion_tensor` is assigned to but never used. Remove assignment to unused variable `confusion_tensor`. (F841)
1-8: ⚠️ Potential issue

Remove unused import and approve new imports.

The new imports are relevant to the added `confusion_matrix` function. However, there's an unused import that should be removed. Please apply the following change:

```diff
-import logging
```

The other new imports (`ndarray` from `numpy` and `confusion_matrix as sk_confusion_matrix` from `sklearn.metrics`) are appropriate for the new functionality.

📝 Committable suggestion

```python
import torch
from typing import Callable
import numpy as np
from ignite.metrics import EpochMetric
from numpy import ndarray
from sklearn.metrics import balanced_accuracy_score, mean_absolute_error, confusion_matrix as sk_confusion_matrix
```

🧰 Tools
🪛 Ruff

1-1: `logging` imported but unused. Remove unused import: `logging`. (F401)
icu_benchmarks/cross_validation.py (1)
126-128: 🛠️ Refactor suggestion

Consider making `epochs` and `patience` configurable parameters.

The `epochs` and `patience` values are currently hardcoded in the `train_common` function call. To improve flexibility, consider making these configurable parameters of the `execute_repeated_cv` function or reading them from a configuration file. You could modify the function signature and call like this:

```python
def execute_repeated_cv(
    # ... other parameters ...
    epochs: int = 20,
    patience: int = 5,
    # ... other parameters ...
):
    # ... existing code ...
    agg_loss += train_common(
        # ... other parameters ...
        train_only=complete_train,
        epochs=epochs,
        patience=patience,
    )
    # ... rest of the function ...
```

This change would allow users to easily adjust these parameters without modifying the function body.
icu_benchmarks/models/ml_models/imblearn.py (3)
2-2: ⚠️ Potential issue

Fix typo in the import statement.

The module name `contants` seems to be misspelled. It should be `constants`. Apply this change:

```diff
-from icu_benchmarks.contants import RunMode
+from icu_benchmarks.constants import RunMode
```

📝 Committable suggestion

```python
from icu_benchmarks.constants import RunMode
```
10-12: 🛠️ Refactor suggestion

Call `super().__init__` before setting `self.model`.

It's good practice to initialize the parent class before accessing instance attributes, to ensure proper setup. Apply this change:

```diff
 def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
     self.model = self.set_model_args(BalancedRandomForestClassifier, *args, **kwargs)
-    super().__init__(*args, **kwargs)
```

📝 Committable suggestion

```python
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.model = self.set_model_args(BalancedRandomForestClassifier, *args, **kwargs)
```
17-19: 🛠️ Refactor suggestion

Call `super().__init__` before setting `self.model`.

To ensure proper initialization, invoke the parent class constructor first. Apply this change:

```diff
 def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
     self.model = self.set_model_args(RUSBoostClassifier, *args, **kwargs)
-    super().__init__(*args, **kwargs)
```

📝 Committable suggestion

```python
def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.model = self.set_model_args(RUSBoostClassifier, *args, **kwargs)
```
docs/adding_model/rnn.py (4)
19-19: ⚠️ Potential issue

Validate `input_size` before accessing `input_size[2]`.

Accessing `input_size[2]` assumes that `input_size` has at least three elements. If `input_size` does not meet this condition, it could lead to an `IndexError`. Ensure that `input_size` is correctly defined, and consider adding a validation check:

```python
if len(input_size) < 3:
    raise ValueError("input_size must have at least three elements")
```
3-3: ⚠️ Potential issue

Fix the typo in the module name 'contants' to 'constants'.

There is a typographical error in the import statement on line 3. The module name should be corrected to properly import `RunMode`. Apply this diff to fix the typo:

```diff
-from icu_benchmarks.contants import RunMode
+from icu_benchmarks.constants import RunMode
```

📝 Committable suggestion

```python
from icu_benchmarks.constants import RunMode
```
15-15: ⚠️ Potential issue

Avoid unpacking `*args` and `**kwargs` after keyword arguments in function calls.

Passing `*args` and `**kwargs` after keyword arguments in a function call is discouraged and can lead to unexpected behavior. Consider refactoring the call to `super().__init__` to improve code clarity and adherence to best practices. Apply this diff to reorder the arguments:

```diff
 super().__init__(
-    input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs
+    *args,
+    **kwargs,
+    input_size=input_size,
+    hidden_dim=hidden_dim,
+    layer_dim=layer_dim,
+    num_classes=num_classes,
 )
```

Alternatively, if possible, pass all parameters explicitly without using `*args` and `**kwargs`.

📝 Committable suggestion

```python
*args,
**kwargs,
input_size=input_size,
hidden_dim=hidden_dim,
layer_dim=layer_dim,
num_classes=num_classes,
```

🧰 Tools
🪛 Ruff

15-15: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
28-30: 🛠️ Refactor suggestion

Clarify whether sequence-level or time-step predictions are intended.

In the `forward` method, passing `out` directly to `self.logit` applies the linear layer to all time steps, resulting in predictions at each time step. If the intended behavior is to make sequence-level predictions, consider using only the last time step's output. Apply this diff to modify the forward pass:

```diff
 def forward(self, x):
     h0 = self.init_hidden(x)
     out, hn = self.rnn(x, h0)
-    pred = self.logit(out)
+    pred = self.logit(out[:, -1, :])
     return pred
```

📝 Committable suggestion

```python
out, hn = self.rnn(x, h0)
pred = self.logit(out[:, -1, :])
return pred
```
icu_benchmarks/models/ml_models/lgbm.py (4)
44-49: 🛠️ Refactor suggestion

Add a `predict` method to `LGBMRegressor` for consistency.

The `LGBMRegressor` class lacks a `predict` method, which may lead to inconsistencies when using the model in pipelines that expect this method. Consider adding a `predict` method similar to `LGBMClassifier`:

```python
def predict(self, features):
    """Predicts outputs for the given features."""
    return self.model.predict(features)
```
7-7: ⚠️ Potential issue

Typo in module name 'contants'.

There's a typo in the module name `contants`; it should be `constants`. Apply this fix:

```diff
-from icu_benchmarks.contants import RunMode
+from icu_benchmarks.constants import RunMode
```

📝 Committable suggestion

```python
from icu_benchmarks.constants import RunMode
```
15-15: ⚠️ Potential issue

Incorrect usage of LightGBM callbacks.

The LightGBM callback functions `early_stopping` and `log_evaluation` should be accessed from the `lgbm.callback` module, not directly from `lgbm`. Apply this fix:

```diff
-callbacks = [lgbm.early_stopping(self.hparams.patience, verbose=True), lgbm.log_evaluation(period=-1)]
+callbacks = [lgbm.callback.early_stopping(self.hparams.patience, verbose=True), lgbm.callback.log_evaluation(period=-1)]
```

📝 Committable suggestion

```python
callbacks = [lgbm.callback.early_stopping(self.hparams.patience, verbose=True), lgbm.callback.log_evaluation(period=-1)]
```
23-23: ⚠️ Potential issue

Incorrect format for the `eval_set` parameter.

The `eval_set` parameter in the `fit` method should be a list of tuples, not a single tuple. This ensures proper evaluation during training. Apply this fix:

```diff
-    eval_set=(val_data, val_labels),
+    eval_set=[(val_data, val_labels)],
```

📝 Committable suggestion

```python
eval_set=[(val_data, val_labels)],
```
icu_benchmarks/models/dl_models/tcn.py (2)
7-7: ⚠️ Potential issue

Typo in module name 'contants'.

There's a typo in the module name `contants`. It should be `constants`. This correction is necessary to avoid import errors. Apply this diff to fix the typo:

```diff
-from icu_benchmarks.contants import RunMode
+from icu_benchmarks.constants import RunMode
```

📝 Committable suggestion

```python
from icu_benchmarks.constants import RunMode
```
23-23: ⚠️ Potential issue

Avoid star-arg unpacking after a keyword argument (B026).

Placing `*args` after keyword arguments in a function call is discouraged, as it can lead to unexpected behavior. According to PEP 8, positional arguments (`*args`) should come before keyword arguments. Consider rearranging the arguments or removing `*args` if it's not needed. Apply this diff to reposition `*args`:

```diff
     input_size=input_size,
     num_channels=num_channels,
     num_classes=num_classes,
-    *args,
     max_seq_length=max_seq_length,
     kernel_size=kernel_size,
     dropout=dropout,
+    *args,
     **kwargs,
```

Committable suggestion was skipped due to low confidence.

🧰 Tools
🪛 Ruff

23-23: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
icu_benchmarks/models/ml_models/sklearn.py (2)
3-3: ⚠️ Potential issue

Typo in import statement: 'contants' should be 'constants'.

There is a typo in the import statement on line 3. The module name should be `constants` instead of `contants`. Apply this diff to fix the typo:

```diff
-from icu_benchmarks.contants import RunMode
+from icu_benchmarks.constants import RunMode
```

📝 Committable suggestion

```python
from icu_benchmarks.constants import RunMode
```
8-8: ⚠️ Potential issue

Inconsistent naming of attribute `_supported_run_modes`.

In line 8, the `LogisticRegression` class defines `__supported_run_modes` with double underscores, whereas other classes use a single underscore (`_supported_run_modes`). For consistency, and to avoid unintended name mangling, please change it to a single underscore:

```diff
 class LogisticRegression(MLWrapper):
-    __supported_run_modes = [RunMode.classification]
+    _supported_run_modes = [RunMode.classification]
```

Committable suggestion was skipped due to low confidence.
icu_benchmarks/models/dl_models/rnn.py (4)
4-4: ⚠️ Potential issue

Typo in module name 'contants'.

There's a typo in the import statement:

```python
from icu_benchmarks.contants import RunMode
```

The module name `contants` is presumably intended to be `constants`. Apply this diff to fix the typo:

```diff
-from icu_benchmarks.contants import RunMode
+from icu_benchmarks.constants import RunMode
```
16-16: ⚠️ Potential issue

Avoid star-arg unpacking after keyword arguments.

In the `__init__` methods of `RNNet`, `LSTMNet`, and `GRUNet`, the use of `*args` and `**kwargs` after keyword arguments is discouraged and can lead to unexpected behavior. Apply this diff to reorder the arguments:

```diff
 def __init__(self, input_size, hidden_dim, layer_dim, num_classes, *args, **kwargs):
     super().__init__(
-        input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, *args, **kwargs
+        *args, input_size=input_size, hidden_dim=hidden_dim, layer_dim=layer_dim, num_classes=num_classes, **kwargs
     )
```

Alternatively, consider placing `*args` and `**kwargs` before the keyword arguments in the method definition, if appropriate.

Also applies to: 42-42, 69-69

🧰 Tools
🪛 Ruff

16-16: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
23-25: 🛠️ Refactor suggestion

Ensure correct initialization of hidden states.

The `init_hidden` methods use `x.new_zeros` to initialize hidden states:

```python
h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
```

While this works, it's recommended to use `torch.zeros` with the appropriate device to ensure compatibility across different devices (CPU/GPU). Consider modifying the hidden state initialization:

```diff
 def init_hidden(self, x):
-    h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
+    h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
```

For the `LSTMNet`, ensure both `h0` and `c0` are initialized similarly.

Also applies to: 49-52, 76-78
30-31: 💡 Codebase verification / ⚠️ Potential issue

Output dimensions need verification before passing to `self.logit`.

The `out` tensor is being passed directly to `self.logit(out)` in multiple locations, which may include outputs for all timesteps. To ensure the correct output shape, modify the code to use only the last timestep's output.

Affected files and lines:
- icu_benchmarks/models/dl_models/rnn.py:30-31
- icu_benchmarks/models/dl_models/rnn.py:57-58
- icu_benchmarks/models/dl_models/rnn.py:82-85
- docs/adding_model/rnn.py: `pred = self.logit(out)`

Suggested change:

```diff
-pred = self.logit(out)
+pred = self.logit(out[:, -1, :])
```

🔗 Analysis chain

Verify the output dimensions before passing to the linear layer.

The `out` tensor from the RNN layers is passed directly to `self.logit(out)`. Depending on the RNN configuration, `out` may contain outputs for all timesteps. Ensure that `self.logit` is applied correctly to produce the desired output shape. Run the following script to check how `out` is used in the codebase. If you intend to use only the output from the last timestep, apply this diff:

```diff
-pred = self.logit(out)
+pred = self.logit(out[:, -1, :])
```

Also applies to: 57-58, 82-85

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

```bash
#!/bin/bash
# Description: Verify that the output from RNN layers is correctly processed.

# Test: Search for usages of 'self.logit' after RNNs.
# Expect: 'out' is appropriately reshaped or indexed.
rg --type python "self\.logit\(out\)"
```

Length of output: 70

Script:

```bash
#!/bin/bash
# Description: Verify that the output from RNN layers is correctly processed.

# Test: Search for usages of 'self.logit' after RNNs.
# Expect: 'out' is appropriately reshaped or indexed.
rg --type py "self\.logit\(out\)"
```

Length of output: 301
icu_benchmarks/models/ml_models/xgboost.py (8)
23-23: ⚠️ Potential issue

Avoid star-arg unpacking after a keyword argument.

Using `*args` after a keyword argument (`device="cpu"`) is discouraged and can lead to unexpected behavior. Consider reordering the arguments or including `device` in `kwargs`. Apply this diff to reorder the arguments:

```diff
-self.model = self.set_model_args(xgb.XGBClassifier, device="cpu", *args, **kwargs)
+self.model = self.set_model_args(xgb.XGBClassifier, *args, device="cpu", **kwargs)
```

Alternatively, include `device` in `kwargs`:

```diff
-self.model = self.set_model_args(xgb.XGBClassifier, device="cpu", *args, **kwargs)
+self.model = self.set_model_args(xgb.XGBClassifier, *args, **kwargs)
```

Then, in your code, ensure `device` is set within `kwargs`:

```python
kwargs.setdefault("device", "cpu")
```

🧰 Tools
🪛 Ruff

23-23: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
6-6: ⚠️ Potential issue

Remove unused import `torch`.

The `torch` library is imported but not used anywhere in this module. Removing it will clean up the imports. Apply this diff:

```diff
-import torch
```

🧰 Tools
🪛 Ruff

6-6: `torch` imported but unused. Remove unused import: `torch`. (F401)
11-11: ⚠️ Potential issue

Remove unused import `LearningRateScheduler`.

The `LearningRateScheduler` from `xgboost.callback` is imported but not used in this module. Please remove it to keep the code clean. Apply this diff:

```diff
-from xgboost.callback import EarlyStopping, LearningRateScheduler
+from xgboost.callback import EarlyStopping
```

📝 Committable suggestion

```python
from xgboost.callback import EarlyStopping
```

🧰 Tools
🪛 Ruff

11-11: `xgboost.callback.LearningRateScheduler` imported but unused. Remove unused import: `xgboost.callback.LearningRateScheduler`. (F401)
15-15: ⚠️ Potential issue

Remove unused import `XGBoostPruningCallback`.

The `XGBoostPruningCallback` from `optuna.integration` is not used in this module. Consider removing it to reduce unnecessary imports. Apply this diff:

```diff
-from optuna.integration import XGBoostPruningCallback
```

🧰 Tools
🪛 Ruff

15-15: `optuna.integration.XGBoostPruningCallback` imported but unused. Remove unused import: `optuna.integration.XGBoostPruningCallback`. (F401)
37-37: ⚠️ Potential issue

Avoid logging training labels directly.

Logging `train_labels` directly can lead to excessively large log output and potential exposure of sensitive data. Consider summarizing the labels instead. Apply this diff:

```diff
-logging.info(train_labels)
+logging.info(f"Training labels distribution: {np.bincount(train_labels)}")
```

📝 Committable suggestion

```python
logging.info(f"Training labels distribution: {np.bincount(train_labels)}")
```
73-82: ⚠️ Potential issue

Filter keyword arguments to valid model hyperparameters.

Currently, all keyword arguments are passed to the model constructor without filtering, which may cause errors if invalid parameters are included. Consider filtering `kwargs` to include only valid hyperparameters accepted by the model. Modify the `set_model_args` method as follows:

```diff
 def set_model_args(self, model, *args, **kwargs):
     """XGBoost signature does not include the hyperparams so we need to pass them manually."""
     signature = inspect.signature(model.__init__).parameters
     possible_hps = list(signature.keys())
     # Get passed keyword arguments
     arguments = kwargs
     # Get valid hyperparameters
-    hyperparams = arguments
+    hyperparams = {k: v for k, v in arguments.items() if k in possible_hps}
     logging.debug(f"Creating model with: {hyperparams}.")
     return model(**hyperparams)
```

📝 Committable suggestion

```python
def set_model_args(self, model, *args, **kwargs):
    """XGBoost signature does not include the hyperparams so we need to pass them manually."""
    signature = inspect.signature(model.__init__).parameters
    possible_hps = list(signature.keys())
    # Get passed keyword arguments
    arguments = kwargs
    # Get valid hyperparameters
    hyperparams = {k: v for k, v in arguments.items() if k in possible_hps}
    logging.debug(f"Creating model with: {hyperparams}.")
    return model(**hyperparams)
```
59-60: ⚠️ Potential issue

Remove unnecessary f-strings and fix filename inconsistency.

The strings `'pred_indicators.csv'` and `'row_indicators.csv'` do not contain placeholders, so the `f` prefix is unnecessary. Additionally, there is an inconsistency between the filename used when saving and the one used when logging. Apply this diff:

```diff
-np.savetxt(os.path.join(self.logger.save_dir, f'pred_indicators.csv'), pred_indicators, delimiter=",")
+np.savetxt(os.path.join(self.logger.save_dir, 'pred_indicators.csv'), pred_indicators, delimiter=",")
-logging.debug(f"Saved row indicators to {os.path.join(self.logger.save_dir, f'row_indicators.csv')}")
+logging.debug(f"Saved prediction indicators to {os.path.join(self.logger.save_dir, 'pred_indicators.csv')}")
```

📝 Committable suggestion

```python
np.savetxt(os.path.join(self.logger.save_dir, 'pred_indicators.csv'), pred_indicators, delimiter=",")
logging.debug(f"Saved prediction indicators to {os.path.join(self.logger.save_dir, 'pred_indicators.csv')}")
```

🧰 Tools
🪛 Ruff

59-59: f-string without any placeholders. Remove extraneous `f` prefix. (F541)
60-60: f-string without any placeholders. Remove extraneous `f` prefix. (F541)
32-38: ⚠️ Potential issue

Pass the `callbacks` parameter to the `model.fit` method.

The `callbacks` list is defined but not used in the `self.model.fit` call. To enable early stopping and other callbacks, pass the `callbacks` parameter to the `fit` method. Apply this diff:

```diff
-self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=False)
+self.model.fit(train_data, train_labels, eval_set=[(val_data, val_labels)], verbose=False, callbacks=callbacks)
```

Committable suggestion was skipped due to low confidence.
icu_benchmarks/models/dl_models/transformer.py (4)
15-77: 🛠️ Refactor suggestion

Refactor to reduce code duplication between `Transformer` and `LocalTransformer`.

Both `Transformer` and `LocalTransformer` classes share significant code in their `__init__` and `forward` methods. Consider refactoring by creating a base class or utility functions to encapsulate the shared logic; a rough sketch follows the tool notes below. This will improve maintainability and reduce redundancy.

Also applies to: 83-148

🧰 Tools
🪛 Ruff

37-37: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
52-52: Loop control variable `i` not used within loop body. Rename unused `i` to `_i`. (B007)
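As a rough illustration only, a minimal sketch of a shared trunk both classes could inherit from (all names here are hypothetical; the real classes take further arguments such as `pos_encoding`, `l1_reg`, and `local_context`, and their attention variant differs):

```python
import torch.nn as nn


class TransformerBase(nn.Module):
    """Hypothetical shared trunk for Transformer and LocalTransformer.

    Subclasses would only swap in their attention/encoder variant; the
    input embedding, encoder stack, and classification head live here.
    """

    def __init__(self, input_size, hidden, heads, depth, num_classes,
                 ff_hidden_mult=4, dropout=0.0):
        super().__init__()
        self.embedding = nn.Linear(input_size, hidden)
        layer = nn.TransformerEncoderLayer(
            d_model=hidden,
            nhead=heads,
            dim_feedforward=ff_hidden_mult * hidden,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.logit = nn.Linear(hidden, num_classes)

    def forward(self, x):
        # x: (batch, time, features) -> per-timestep class logits
        return self.logit(self.encoder(self.embedding(x)))
```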
4-4: ⚠️ Potential issue

Fix typo in module name in the import statement.

There's a typo in the module name `contants`; it should be `constants`. Apply this diff to fix the typo:

```diff
-from icu_benchmarks.contants import RunMode
+from icu_benchmarks.constants import RunMode
```

📝 Committable suggestion

```python
from icu_benchmarks.constants import RunMode
```
106-106: ⚠️ Potential issue

Avoid unpacking `*args` after keyword arguments in a function call.

Unpacking `*args` after keyword arguments in function calls is discouraged, as it may lead to unexpected behavior. Consider moving `*args` before the keyword arguments. Apply this diff to reorder the arguments in the `super().__init__` call:

```diff
 super().__init__(
     input_size=input_size,
     hidden=hidden,
     heads=heads,
     ff_hidden_mult=ff_hidden_mult,
     depth=depth,
     num_classes=num_classes,
+    *args,
     dropout=dropout,
     l1_reg=l1_reg,
     pos_encoding=pos_encoding,
     local_context=local_context,
     dropout_att=dropout_att,
-    *args,
     **kwargs,
 )
```

📝 Committable suggestion

```python
super().__init__(
    input_size=input_size,
    hidden=hidden,
    heads=heads,
    ff_hidden_mult=ff_hidden_mult,
    depth=depth,
    num_classes=num_classes,
    *args,
    dropout=dropout,
    l1_reg=l1_reg,
    pos_encoding=pos_encoding,
    local_context=local_context,
    dropout_att=dropout_att,
    **kwargs,
)
```

🧰 Tools
🪛 Ruff

106-106: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
37-37: ⚠️ Potential issue

Avoid unpacking `*args` after keyword arguments in a function call.

Unpacking `*args` after keyword arguments in function calls is discouraged, as it may lead to unexpected behavior. Consider moving `*args` before the keyword arguments. Apply this diff to reorder the arguments in the `super().__init__` call:

```diff
 super().__init__(
     input_size=input_size,
     hidden=hidden,
     heads=heads,
     ff_hidden_mult=ff_hidden_mult,
     depth=depth,
     num_classes=num_classes,
+    *args,
     dropout=dropout,
     l1_reg=l1_reg,
     pos_encoding=pos_encoding,
     dropout_att=dropout_att,
-    *args,
     **kwargs,
 )
```

📝 Committable suggestion

```python
super().__init__(
    input_size=input_size,
    hidden=hidden,
    heads=heads,
    ff_hidden_mult=ff_hidden_mult,
    depth=depth,
    num_classes=num_classes,
    *args,
    dropout=dropout,
    l1_reg=l1_reg,
    pos_encoding=pos_encoding,
    dropout_att=dropout_att,
    **kwargs,
)
```

🧰 Tools
🪛 Ruff

37-37: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
icu_benchmarks/run.py (2)
61-64: 🛠️ Refactor suggestion

Enhance error messages with available options.

To improve the user experience, include the list of available tasks or models in the error messages. This helps users correct their input when a task or model is not found. Apply this diff:

```diff
 if task not in tasks:
-    raise ValueError(f"Task {task} not found in tasks.")
+    raise ValueError(f"Task '{task}' not found. Available tasks: {', '.join(tasks)}.")
 if model not in models:
-    raise ValueError(f"Model {model} not found in models.")
+    raise ValueError(f"Model '{model}' not found. Available models: {', '.join(models)}.")
```

📝 Committable suggestion

```python
if task not in tasks:
    raise ValueError(f"Task '{task}' not found. Available tasks: {', '.join(tasks)}.")
if model not in models:
    raise ValueError(f"Model '{model}' not found. Available models: {', '.join(models)}.")
```
191-194: ⚠️ Potential issue

Improve exception handling for result aggregation.

Catching the base `Exception` may suppress important errors. Consider catching specific exceptions or re-raising the exception after logging. Also, use `logging.exception` to log the stack trace for better debugging. Apply this diff:

```diff
 try:
     aggregate_results(run_dir, execution_time)
-except Exception as e:
-    logging.error(f"Failed to aggregate results: {e}")
+except Exception:
+    logging.exception("Failed to aggregate results")
+    raise
```

📝 Committable suggestion

```python
try:
    aggregate_results(run_dir, execution_time)
except Exception:
    logging.exception("Failed to aggregate results")
    raise
```
icu_benchmarks/models/train.py (4)
113-113: 🛠️ Refactor suggestion

Uncomment `pin_memory` if required for performance.

The `pin_memory` parameter in the `DataLoader` is commented out. If data loading performance is a priority and the hardware supports it, consider uncommenting `pin_memory=not cpu` to enable faster data transfer to GPU memory. Apply this diff:

```diff
-    # pin_memory=not cpu,
+    pin_memory=not cpu,
```

📝 Committable suggestion

```python
pin_memory=not cpu,
```
85-86: 🛠️ Refactor suggestion

Improve readability of the nested conditional expression.

The assignment of `dataset_class` uses a nested ternary operator, which can reduce code readability. Refactor to an explicit `if-elif-else` structure for clarity. Apply this diff:

```diff
-    dataset_class = ImputationPandasDataset if mode == RunMode.imputation else PredictionPolarsDataset if polars else PredictionPandasDataset
+    if mode == RunMode.imputation:
+        dataset_class = ImputationPandasDataset
+    elif polars:
+        dataset_class = PredictionPolarsDataset
+    else:
+        dataset_class = PredictionPandasDataset
```

📝 Committable suggestion

```python
if mode == RunMode.imputation:
    dataset_class = ImputationPandasDataset
elif polars:
    dataset_class = PredictionPolarsDataset
else:
    dataset_class = PredictionPandasDataset
# dataset_class = ImputationPandasDataset if mode == RunMode.imputation else PredictionPandasDataset
```
41-42: ⚠️ Potential issue

Re-evaluate the default `batch_size` and `epochs` values.

The default `batch_size` has been changed from `64` to `1`, and `epochs` from `1000` to `100`. A `batch_size` of `1` can lead to slow training due to less efficient GPU utilization. Unless there's a specific reason for these values, consider setting a larger `batch_size` and increasing `epochs` to ensure adequate training. Apply this diff to adjust the default values:

```diff
-    batch_size=1,
-    epochs=100,
+    batch_size=64,
+    epochs=1000,
```

📝 Committable suggestion

```python
batch_size=64,
epochs=1000,
```
200-219: 🛠️ Refactor suggestion / ⚠️ Potential issue

Handle exceptions specifically and correct file naming in logs.

In the `persist_data` function:
- Catching the general `Exception` can obscure specific errors. Consider catching more specific exceptions like `AttributeError` or `IOError`.
- The logging messages refer to `'test_shap_values.parquet'`, but the files are saved as `'shap_values_test.parquet'` and `'shap_values_train.parquet'`. This mismatch could confuse users checking the logs.

Apply this diff to correct the log messages:

```diff
-    logging.info(f"Saved shap values to {log_dir / 'test_shap_values.parquet'}")
+    logging.info(f"Saved shap values to {log_dir / 'shap_values_test.parquet'}")
-    logging.info(f"Saved shap values to {log_dir / 'test_shap_values.parquet'}")
+    logging.info(f"Saved shap values to {log_dir / 'shap_values_train.parquet'}")
```

Additionally, update the exception handling:

```diff
-    except Exception as e:
+    except (AttributeError, IOError) as e:
```

Committable suggestion was skipped due to low confidence.
icu_benchmarks/models/utils.py (5)
199-205: ⚠️ Potential issue

Potential misuse of `average_precision_score` with class labels instead of scores.

In the `__call__` method of `scorer_wrapper`, when `y_pred` is multi-dimensional and `y_true` has at most two unique values (binary classification), `np.argmax` is used to convert `y_pred` to class labels before passing it to `self.scorer`. However, `average_precision_score` expects probability estimates or confidence scores, not discrete class labels. Passing class labels may lead to incorrect metric calculations.

Consider passing the probability estimates directly to `self.scorer` without applying `np.argmax`. If `y_pred` contains logits or probabilities, they should be used to accurately compute the average precision score.
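For illustration, a minimal sketch of the difference (the toy arrays are assumptions; a two-column probability output is assumed):

```python
import numpy as np
from sklearn.metrics import average_precision_score

y_true = np.array([0, 1, 1, 0])
y_pred = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.7, 0.3]])

# Discards the ranking information the metric relies on
ap_labels = average_precision_score(y_true, np.argmax(y_pred, axis=-1))

# Uses the positive-class scores, as the metric expects
ap_scores = average_precision_score(y_true, y_pred[:, 1])
```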
237-238: ⚠️ Potential issue

Possible `IndexError` when accessing `pos_event_change_full[-1]`.

In the condition:

```python
if ((last_know_label == 1) and (len(pos_event_change_full) == 0)) or (
        (last_know_label == 1) and (last_know_idx >= pos_event_change_full[-1])):
```

If `pos_event_change_full` is empty, accessing `pos_event_change_full[-1]` will raise an `IndexError`. The second part of the `or` condition does not check if the list is non-empty before accessing the last element. Ensure that `pos_event_change_full` is not empty before accessing `pos_event_change_full[-1]`. You can restructure the condition or add a check to prevent the potential `IndexError`.
248-250:
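As an illustration of one possible restructuring (variable names are taken from the snippet above; the surrounding logic is assumed):

    # Equivalent logic with short-circuiting: the [-1] index is only evaluated
    # when pos_event_change_full is non-empty, so no IndexError can occur.
    if last_know_label == 1 and (
        len(pos_event_change_full) == 0
        or last_know_idx >= pos_event_change_full[-1]
    ):
        ...  # placeholder for the original branch body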
248-250: ⚠️ Potential issue
Potential IndexError due to an empty array when accessing last_pos_change

In the code:

    pos_change = np.where((diffs_label >= 1) & (label_for_event == 1))[0]
    last_pos_change = pos_change[np.where(pos_change > max(last_know_event, last_known_stable))][0]

If pos_change is empty, or if no elements satisfy the condition (pos_change > max(last_know_event, last_known_stable)), accessing index [0] will raise an IndexError.

Add a check to ensure that the array is not empty before accessing [0], and handle cases where pos_change has no elements satisfying the condition to prevent runtime errors.
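A minimal sketch of such a guard, reusing the names from the snippet above; the fallback value is a placeholder assumption, not the repository's actual handling:

    import numpy as np

    candidates = pos_change[pos_change > max(last_know_event, last_known_stable)]
    if candidates.size > 0:
        last_pos_change = candidates[0]
    else:
        last_pos_change = None  # placeholder: choose the appropriate fallback here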
206-208: 🛠️ Refactor suggestion
Use __str__ or __repr__ instead of defining a __name__ method

Defining a __name__ method in a class is unconventional in Python. The __name__ attribute is typically associated with modules and functions, not class instances. To provide a string representation of the class or its instances, consider implementing the __str__ or __repr__ method.

Apply this diff to replace __name__ with __str__:

    -def __name__(self):
    +def __str__(self):
         return "scorer_wrapper"
192-195: 🛠️ Refactor suggestion
Adjust docstring indentation in the scorer_wrapper class

The docstring inside the scorer_wrapper class has incorrect indentation. According to PEP 257, the docstring should sit directly under the class definition with consistent indentation for readability.

Apply this diff to fix the docstring indentation:

     class scorer_wrapper:
    -   """
    -   Wrapper that flattens the binary classification input such that we can use a broader range of sklearn metrics.
    -   """
    +    """
    +    Wrapper that flattens the binary classification input such that we can use a broader range of sklearn metrics.
    +    """
icu_benchmarks/run_utils.py (3)
79-84: 🛠️ Refactor suggestion
Simplify directory creation logic

The current logic for creating log_dir_run can be simplified to handle directory clashes more elegantly and avoid potential race conditions.

Consider refactoring using a loop to ensure a unique directory is created:

    while True:
        log_dir_run = log_dir / datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")
        try:
            log_dir_run.mkdir(parents=True)
            break
        except FileExistsError:
            continue

This approach attempts to create the directory and, in case of a FileExistsError, retries with a new timestamp until it succeeds.
112-112: ⚠️ Potential issue
Remove unused variable shap_values_train

The variable shap_values_train is assigned but never used, which may lead to confusion.

Apply this diff to remove the unused variable:

    - shap_values_train = []

🧰 Tools
🪛 Ruff
112-112: Local variable shap_values_train is assigned to but never used. Remove assignment to unused variable shap_values_train (F841)
257-268: 🛠️ Refactor suggestion
Use Path methods instead of mixing os.path and pathlib.Path

The current implementation mixes os.path and pathlib.Path, which can lead to inconsistencies and errors. Since config_dir is a Path object, it's better to use Path methods throughout.

Refactor the function to use Path.glob() and Path properties:

    -def get_config_files(config_dir: Path):
    -    tasks = glob.glob(os.path.join(config_dir / "tasks", '*'))
    -    models = glob.glob(os.path.join(config_dir / "prediction_models", '*'))
    -    tasks = [os.path.splitext(os.path.basename(task))[0] for task in tasks]
    -    models = [os.path.splitext(os.path.basename(model))[0] for model in models]
    +def get_config_files(config_dir: Path):
    +    tasks = [task.stem for task in (config_dir / "tasks").glob('*')]
    +    models = [model.stem for model in (config_dir / "prediction_models").glob('*')]
         if "common" in tasks:
             tasks.remove("common")
         if "common" in models:
             models.remove("common")
         logging.info(f"Found tasks: {tasks}")
         logging.info(f"Found models: {models}")
         return tasks, models

This change enhances readability and consistency by fully utilizing pathlib.
icu_benchmarks/tuning/hyperparameters.py (3)
243-260: ⚠️ Potential issue
Correct typos: replace trail with trial and trails with trials

There are several instances where trail is used instead of trial, and trails instead of trials. This may lead to errors or confusion. Please correct these typos throughout the code.

Replace trail with trial in the objective function:

    -def objective(trail, hyperparams_bounds, hyperparams_names):
    +def objective(trial, hyperparams_bounds, hyperparams_names):
         hyperparams = {}
         logging.info(f"Bounds: {hyperparams_bounds}, Names: {hyperparams_names}")
         for name, value in zip(hyperparams_names, hyperparams_bounds):
             if isinstance(value, tuple):  # Check for range or "list-type" hyperparameter bounds
                 if isinstance(value[0], (int, float)) and isinstance(value[1], (int, float)):
                     if len(value) == 3 and isinstance(value[2], str):
                         if isinstance(value[0], int) and isinstance(value[1], int):
    -                        hyperparams[name] = trail.suggest_int(name, value[0], value[1], log=value[2] == "log")
    +                        hyperparams[name] = trial.suggest_int(name, value[0], value[1], log=value[2] == "log")
                         elif isinstance(value[0], (int, float)) and isinstance(value[1], (int, float)):
    -                        hyperparams[name] = trail.suggest_float(name, value[0], value[1], log=value[2] == "log")
    +                        hyperparams[name] = trial.suggest_float(name, value[0], value[1], log=value[2] == "log")
                         else:
    -                        hyperparams[name] = trail.suggest_categorical(name, value)
    +                        hyperparams[name] = trial.suggest_categorical(name, value)
                     elif len(value) == 2:
                         if isinstance(value[0], int) and isinstance(value[1], int):
    -                        hyperparams[name] = trail.suggest_int(name, value[0], value[1])
    +                        hyperparams[name] = trial.suggest_int(name, value[0], value[1])
                         elif isinstance(value[0], (int, float)) and isinstance(value[1], (int, float)):
    -                        hyperparams[name] = trail.suggest_float(name, value[0], value[1])
    +                        hyperparams[name] = trial.suggest_float(name, value[0], value[1])
                         else:
    -                        hyperparams[name] = trail.suggest_categorical(name, value)
    +                        hyperparams[name] = trial.suggest_categorical(name, value)
                     else:
    -                    hyperparams[name] = trail.suggest_categorical(name, value)
    +                    hyperparams[name] = trial.suggest_categorical(name, value)
                 else:
    -                hyperparams[name] = trail.suggest_categorical(name, value)
    +                hyperparams[name] = trial.suggest_categorical(name, value)
             else:
    -            hyperparams[name] = trail.suggest_categorical(name, value)
    +            hyperparams[name] = trial.suggest_categorical(name, value)
         return bind_params_and_train(hyperparams)

Update the lambda function in study.optimize:

     study.optimize(
    -    lambda trail: objective(trail, hyperparams_bounds, hyperparams_names),
    +    lambda trial: objective(trial, hyperparams_bounds, hyperparams_names),
         n_trials=n_calls,
         callbacks=callbacks,
         gc_after_trial=True,
     )

Correct trails to trials in the logging statement:

    -logging.info(f"Starting or resuming Optuna study with {n_calls} trails and callbacks: {callbacks}.")
    +logging.info(f"Starting or resuming Optuna study with {n_calls} trials and callbacks: {callbacks}.")

Also applies to: 355-355, 352-352
373-382: ⚠️ Potential issue
Correct the plotting and saving of Optuna visualization figures

The functions plot_param_importances and plot_optimization_history from optuna.visualization return Plotly figures, not Matplotlib figures. Using plt.savefig will not save these figures correctly.

To save the Plotly figures, you can use the write_image method. Here's how you can modify the code:

     if plot:
         try:
             logging.info("Plotting hyperparameter importances.")
    -        plot_param_importances(study)
    -        plt.savefig(log_dir / "param_importances.png")
    -        plot_optimization_history(study)
    -        plt.savefig(log_dir / "optimization_history.png")
    +        import plotly.io as pio
    +
    +        fig1 = plot_param_importances(study)
    +        pio.write_image(fig1, str(log_dir / "param_importances.png"))
    +
    +        fig2 = plot_optimization_history(study)
    +        pio.write_image(fig2, str(log_dir / "optimization_history.png"))
         except Exception as e:
             logging.error(f"Failed to plot hyperparameter importances: {e}")

Additionally, ensure that you have installed the necessary dependencies (e.g., kaleido) to enable Plotly to export images.
323-326: ⚠️ Potential issue
Fix logical condition when checking for checkpoint existence

The condition in the if statement contradicts the logging message. If the checkpoint exists, it should proceed to load it instead of logging that it does not exist. Adjust the logging so the two branches match the condition.

Apply this diff to correct the condition:

     if checkpoint and checkpoint.exists():
    -    logging.warning(f"Hyperparameter checkpoint {checkpoint} does not exist.")
    +    logging.info(f"Loading checkpoint at {checkpoint}")
         study = optuna.load_study(
             study_name="tuning",
             storage="sqlite:///" + str(checkpoint),
             sampler=sampler,
             pruner=pruner,
         )
         n_calls = n_calls - len(study.trials)
     else:
         if checkpoint:
             logging.warning("Checkpoint path given as flag but not found, starting from scratch.")
         # Continue with creating a new study
icu_benchmarks/data/loader.py (2)
15-184: 🛠️ Refactor suggestion
Consider refactoring to reduce code duplication

There is significant code duplication between the Polars-based and Pandas-based dataset classes (CommonPolarsDataset vs. CommonPandasDataset, PredictionPolarsDataset vs. PredictionPandasDataset). This can make maintenance more challenging and increase the risk of inconsistencies. Consider refactoring by creating a base class that encapsulates shared functionality, or by using a strategy pattern to handle the differences between Polars and Pandas. This will promote code reuse and improve maintainability.

Also applies to: 186-336
309-309: ⚠️ Potential issue
Remove unused variable weights

At line 309, the variable weights is assigned but not used before being returned. You can return the computed value directly without assigning it to weights.

Apply this diff to fix the issue:

    - weights = list((1 / counts) * np.sum(counts) / counts.shape[0])
      return list((1 / counts) * np.sum(counts) / counts.shape[0])

Committable suggestion was skipped due to low confidence.

🧰 Tools
🪛 Ruff
309-309: Local variable weights is assigned to but never used. Remove assignment to unused variable weights (F841)
icu_benchmarks/data/split_process_data.py (6)
9-9: ⚠️ Potential issue
Remove unused imports to improve code cleanliness

The imports pyarrow.parquet as pq, sequence from setuptools.dist, and false from sqlalchemy are not used in the code and can be safely removed.

Apply this diff to remove the unused imports:

    -import pyarrow.parquet as pq
    -from setuptools.dist import sequence
    -from sqlalchemy import false

Also applies to: 15-15, 17-17
31-32: ⚠️ Potential issue
Avoid using mutable default arguments

Using mutable default arguments like {} and [] can lead to unexpected behavior because default arguments are evaluated only once. It's better to use None and assign within the function to avoid shared mutable defaults.

Apply this diff to fix the default arguments:

    - modality_mapping: dict[str] = {},
    - selected_modalities: list[str] = "all",
      ...
    - vars_to_exclude: list[str] = [],
    + modality_mapping: dict[str] = None,
    + selected_modalities: list[str] = None,
      ...
    + vars_to_exclude: list[str] = None,

Then, within the function, initialize them:

    if modality_mapping is None:
        modality_mapping = {}
    if selected_modalities is None:
        selected_modalities = ["all"]
    if vars_to_exclude is None:
        vars_to_exclude = []

Also applies to: 46-46

🧰 Tools
🪛 Ruff
31-31: Do not use mutable data structures for argument defaults. Replace with None; initialize within function (B006)
32-32: ⚠️ Potential issue
Type mismatch in default value for selected_modalities

The parameter selected_modalities is annotated as list[str], but the default value is a string "all". This can lead to type errors. Consider initializing it to None and setting the default within the function.

Apply this diff:

    -selected_modalities: list[str] = "all",
    +selected_modalities: list[str] = None,

Then, within the function:

    if selected_modalities is None:
        selected_modalities = ["all"]
129-129: ⚠️ Potential issue
Use is for comparison to None and simplify the condition

Comparison to None should be done using is or is not. Additionally, the condition can be simplified for better readability.

Apply this diff:

    -if not (selected_modalities == "all" or selected_modalities == ["all"] or selected_modalities == None):
    +if selected_modalities not in [None, "all", ["all"]]:

🧰 Tools
🪛 Ruff
129-129: Comparison to None should be cond is None. Replace with cond is None (E711)
161-162: ⚠️ Potential issue
Avoid using dict as a variable name to prevent shadowing built-in types

Using dict as a variable name shadows the built-in dict type in Python, which can lead to unexpected behaviors. Consider renaming the variable to something like data_dict or data_split.

Apply this diff:

    -for dict in data.values():
    -    for key, val in dict.items():
    +for data_split in data.values():
    +    for key, val in data_split.items():
192-192: ⚠️ Potential issue
Variable sequence is redefined and shadows an imported name

The variable sequence is redefined, which could potentially lead to confusion or errors. Consider renaming the variable or removing the unused import if it's not needed.

Apply this diff:

    -from setuptools.dist import sequence
     ...
    -    sequence = vars[Var.sequence] if Var.sequence in vars.keys() else None
    +    sequence_var = vars[Var.sequence] if Var.sequence in vars else None

🧰 Tools
🪛 Ruff
192-192: Redefinition of unused sequence from line 15 (F811)
192-192: Use key in dict instead of key in dict.keys(). Remove .keys() (SIM118)
icu_benchmarks/models/wrappers.py (3)
6-6: ⚠️ Potential issue
Remove unused imports average_precision_score and roc_auc_score

The functions average_precision_score and roc_auc_score are imported but not used in the code. Removing unused imports helps keep the codebase clean and reduces potential confusion.

Apply this diff to remove the unused imports:

    -from sklearn.metrics import log_loss, mean_squared_error, average_precision_score, roc_auc_score
    +from sklearn.metrics import log_loss, mean_squared_error
20-20: ⚠️ Potential issue
Remove unused import scorer_wrapper

The scorer_wrapper function is imported from icu_benchmarks.models.utils but is not utilized within the code. Cleaning up unused imports improves code readability and maintenance.

Apply this diff to remove the unused import:

    -from icu_benchmarks.models.utils import create_optimizer, create_scheduler, scorer_wrapper
    +from icu_benchmarks.models.utils import create_optimizer, create_scheduler

🧰 Tools
🪛 Ruff
20-20: icu_benchmarks.models.utils.scorer_wrapper imported but unused. Remove unused import (F401)
33-34: ⚠️ Potential issue
Remove unnecessary gin configurations for unused metrics

The average_precision_score and roc_auc_score metrics are registered with Gin but are not used elsewhere in the code. Consider removing these configurations to streamline the code.

Apply this diff to remove the unnecessary configurations:

    -gin.config.external_configurable(average_precision_score, module="sklearn.metrics")
    -gin.config.external_configurable(roc_auc_score, module="sklearn.metrics")
icu_benchmarks/data/preprocessor.py (2)
9-9: ⚠️ Potential issue
Remove unused import polars.selectors as cs

The import statement import polars.selectors as cs on line 9 is unused in the code and can be removed to clean up the imports.

Apply this diff to remove the unused import:

    -import polars.selectors as cs

🧰 Tools
🪛 Ruff
9-9: polars.selectors imported but unused. Remove unused import (F401)
203-203: ⚠️ Potential issue
Handle potential NoneType for self.imputation_model

In the to_cache_string method, accessing self.imputation_model.__class__.__name__ assumes that self.imputation_model is not None. If self.imputation_model is None, this will raise an AttributeError. Consider handling this case to prevent potential errors.

Apply this diff to handle the potential NoneType:

    + imputation_model_name = self.imputation_model.__class__.__name__ if self.imputation_model else 'None'
      return (
          super().to_cache_string()
    -     + f"_classification_{self.generate_features}_{self.scaling}_{self.imputation_model.__class__.__name__}"
    +     + f"_classification_{self.generate_features}_{self.scaling}_{imputation_model_name}"
      )

Committable suggestion was skipped due to low confidence.
Actionable comments posted: 12
🧹 Outside diff range and nitpick comments (41)
icu_benchmarks/models/ml_models/catboost.py (1)
17-19: Enhance the docstring for the predict method

The predict method looks good, but the docstring could be more informative. Consider expanding the docstring to include information about the input and output:

    def predict(self, features):
        """
        Predicts class probabilities for the given features.

        Args:
            features: Input features for prediction.

        Returns:
            numpy.ndarray: Predicted probabilities for each class.
        """
        return self.model.predict_proba(features)

This provides more context about what the method expects and what it returns.
icu_benchmarks/models/ml_models/xgboost.py (4)
1-15: LGTM! Consider organizing imports.

The imports look appropriate for the XGBoost classifier implementation. However, consider organizing them according to PEP 8 guidelines: standard library imports first, then third-party imports, and finally local imports.

Here's a suggested organization:

    import inspect
    import logging
    import os
    from statistics import mean

    import gin
    import numpy as np
    import shap
    import wandb
    import xgboost as xgb
    from xgboost.callback import EarlyStopping
    from wandb.integration.xgboost import wandb_callback as wandb_xgb

    from icu_benchmarks.constants import RunMode
    from icu_benchmarks.models.wrappers import MLWrapper

    # Uncomment if needed in the future
    # from optuna.integration import XGBoostPruningCallback

30-46: LGTM! Consider minor improvements in fit_model.

The fit_model method looks good overall:

- The EarlyStopping callback is used to prevent overfitting.
- Optional WandB integration is flexible.
- SHAP values are computed for model interpretability.

Consider the following improvements:

- The SHAP computation might be expensive for large datasets. Consider adding an option to disable it or compute it on a subset of the data.
- The evaluation score calculation could be more explicit. Consider using a named metric instead of the first available one.

Here's a suggested improvement for the evaluation score:

    eval_results = self.model.evals_result_["validation_0"]
    primary_metric = list(eval_results.keys())[0]  # e.g., 'auc', 'logloss'
    eval_score = mean(eval_results[primary_metric])
    return eval_score

48-73: LGTM with suggestions for improved clarity and robustness.

The test_step method is comprehensive, handling prediction, logging, and SHAP value computation. However, consider the following improvements:

- The handling of pred_indicators could be more explicit. Add type hints or documentation to clarify the expected structure.
- The purpose of self.mps is not clear. Consider adding a comment or renaming for clarity.
- The CSV saving logic could be extracted into a separate method for better readability.

Here's a suggested refactoring for the CSV saving logic:

    def save_predictions_to_csv(self, pred_indicators, test_label, test_pred):
        if (len(pred_indicators.shape) > 1 and len(test_pred.shape) > 1
                and pred_indicators.shape[1] == test_pred.shape[1]):
            output = np.hstack((pred_indicators, test_label.reshape(-1, 1), test_pred))
            output_path = os.path.join(self.logger.save_dir, "pred_indicators.csv")
            np.savetxt(output_path, output, delimiter=",")
            logging.debug(f"Saved predictions to {output_path}")

    # Call this method in test_step
    self.save_predictions_to_csv(pred_indicators, test_label, test_pred)

This refactoring improves readability and separates concerns in the test_step method.

🧰 Tools
🪛 Ruff
63-63: f-string without any placeholders. Remove extraneous f prefix (F541)

85-86: LGTM! Consider adding error handling.

The get_feature_importance method correctly returns the feature importances from the XGBoost model. However, consider adding error handling in case the model hasn't been fit yet.

Here's a suggested improvement:

    def get_feature_importance(self):
        if not hasattr(self.model, 'feature_importances_'):
            raise ValueError("Model has not been fit yet. Call fit_model() before getting feature importances.")
        return self.model.feature_importances_

This addition will provide a clear error message if the method is called before the model is fit.
icu_benchmarks/run.py (3)
54-58: LGTM: New logic for handling modalities and labels

The new code block effectively handles the modalities and label arguments, using gin bindings to set them appropriately. The debug logging for modalities is helpful for tracking.

Consider adding a similar debug log for the label binding:

    if args.label:
        logging.debug(f"Binding label: {args.label}")
        gin.bind_parameter("preprocess.label", args.label)

This would maintain consistency with the modalities logging and provide additional clarity during execution.

59-64: LGTM: New logic for config file retrieval and validation

The addition of the get_config_files function call and subsequent validation checks for tasks and models enhances the robustness of the script. This ensures that only valid tasks and models are processed, preventing potential runtime errors.

Consider combining the two separate checks into a single if statement for slightly more efficient error handling:

    if task not in tasks or model not in models:
        raise ValueError(f"Invalid task or model. Task: {task} {'not ' if task not in tasks else ''}found. Model: {model} {'not ' if model not in models else ''}found.")

This change would provide a single, comprehensive error message if either the task or model is invalid.

191-194: LGTM: Added error handling for aggregate_results

The introduction of a try-except block around the aggregate_results call is a good addition. It improves the script's robustness by preventing crashes due to issues during result aggregation.

Consider adding more detailed error logging:

    try:
        aggregate_results(run_dir, execution_time)
    except Exception as e:
        logging.error(f"Failed to aggregate results: {e}")
        logging.debug("Error details:", exc_info=True)

This change would provide a stack trace in debug mode, which could be helpful for troubleshooting more complex issues.
icu_benchmarks/data/pooling.py (1)
69-70: Approve changes with a minor optimization suggestion

The use of pq.read_table().to_pandas(self_destruct=True) is a good performance optimization. It allows the Arrow memory to be reused, potentially reducing memory usage.

However, there's a minor optimization opportunity: consider simplifying the dictionary comprehension by removing .keys():

    data[folder.name] = {
        f: pq.read_table(folder / self.file_names[f]).to_pandas(self_destruct=True) for f in self.file_names
    }

This change slightly improves performance by avoiding the creation of a dict_keys object.

🧰 Tools
🪛 Ruff
70-70: Use key in dict instead of key in dict.keys(). Remove .keys() (SIM118)
icu_benchmarks/models/train.py (5)
84-89: LGTM: Flexible dataset class selection implemented

The new logic for selecting the dataset class based on the polars parameter enhances flexibility. The added logging statement improves transparency and aids in debugging.

Consider simplifying the dataset class selection logic using a dictionary mapping:

    dataset_classes = {
        RunMode.imputation: ImputationPandasDataset,
        RunMode.classification: PredictionPolarsDataset if polars else PredictionPandasDataset,
        RunMode.regression: PredictionPolarsDataset if polars else PredictionPandasDataset,
    }
    dataset_class = dataset_classes[mode]

This approach would be more scalable if additional modes are added in the future.

Also applies to: 92-92

95-98: LGTM: DataLoader updates for performance and polars compatibility

The addition of the persistent_workers parameter to the DataLoader instantiation could potentially improve data loading performance. The commented-out code for converting datasets to pandas is no longer necessary due to the transition to polars.

Consider removing the commented-out code (lines 96-98) to improve code cleanliness. If this code needs to be preserved for reference, consider moving it to a separate document or adding a more detailed comment explaining why it's being kept.

Also applies to: 119-119, 128-128

155-155: LGTM: Trainer configuration improved

The changes to the Trainer instantiation improve the training process:

- Setting min_epochs=1 ensures at least one epoch of training is performed.
- Adding num_sanity_val_steps=2 helps catch errors in the validation loop before training begins.

Consider adding a comment explaining the rationale behind changing min_epochs from 0 to 1, as this might not be immediately obvious to other developers.

Also applies to: 164-164

181-181: LGTM: Test dataset handling improved and data persistence added

The changes to test dataset creation and evaluation are consistent with earlier modifications:

- Addition of the ram_cache parameter to test dataset creation.
- Inclusion of persistent_workers in the test DataLoader.

The new call to persist_data(trainer, log_dir) suggests additional data persistence functionality has been implemented.

Please add documentation (preferably a docstring) for the new persist_data function to explain its purpose, parameters, and return value (if any).

Also applies to: 192-192, 200-200

205-223: LGTM: New persist_data function for SHAP value persistence

The new persist_data function is a valuable addition for persisting SHAP values:

- It handles both test and train SHAP values.
- Uses polars DataFrame for efficient data handling.
- Implements proper error handling and logging.

Consider the following improvements:

1. Add a docstring explaining the function's purpose, parameters, and return value (if any).
2. The commented-out code (lines 209-212) can be removed if it's no longer needed.
3. Consider using a context manager (with) when writing parquet files to ensure proper file handling.
4. The error message in the except block could be more specific, e.g., include which operation failed (test or train SHAP values).

Example improvement for point 3:

    with (log_dir / "shap_values_test.parquet").open("wb") as f:
        shaps_test.write_parquet(f)
icu_benchmarks/run_utils.py (4)

58-59: LGTM: New command-line arguments enhance script flexibility.

The addition of --modalities and --label arguments improves the script's versatility by allowing users to specify evaluation parameters. The implementation is consistent with the existing code style.

Consider adding help text for the --modalities argument to clarify its purpose and expected input format, similar to the --label argument:

    - parser.add_argument("-mo", "--modalities", nargs="+", help="Modalities to use for evaluation.")
    + parser.add_argument("-mo", "--modalities", nargs="+", help="Modalities to use for evaluation. Specify multiple modalities separated by spaces.")

80-85: LGTM: Improved directory creation logic.

The new implementation ensures unique directory names and handles potential naming conflicts effectively. The added checks for directory existence before creation are a good practice.

Consider simplifying the logic slightly by using a loop to handle potential conflicts:

    - if not log_dir_run.exists():
    -     log_dir_run.mkdir(parents=True)
    - else:
    -     # Directory clash at last moment
    -     log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f"))
    -     log_dir_run.mkdir(parents=True)
    + while log_dir_run.exists():
    +     log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f"))
    + log_dir_run.mkdir(parents=True)

This approach reduces code duplication and handles multiple potential conflicts in a single loop.

Line range hint 112-137: LGTM: Added SHAP values aggregation.

The new implementation effectively aggregates SHAP values from test datasets using polars for efficient parquet file handling. The aggregated values are correctly saved to a new parquet file.

Consider adding error handling for the parquet file operations:

     if shap_values_test:
    -    shap_values = pl.concat(shap_values_test)
    -    shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet")
    +    try:
    +        shap_values = pl.concat(shap_values_test)
    +        shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet")
    +    except Exception as e:
    +        logging.error(f"Error aggregating or writing SHAP values: {e}")

This addition will help catch and log any potential issues with the parquet operations.

258-269: LGTM: New function for retrieving configuration files.

The get_config_files function effectively retrieves and processes configuration files from the specified directories. It correctly filters out "common" entries and logs the found tasks and models.

Consider adding error handling and using Path objects consistently for better cross-platform compatibility:

     def get_config_files(config_dir: Path):
    -    tasks = glob.glob(os.path.join(config_dir / "tasks", "*"))
    -    models = glob.glob(os.path.join(config_dir / "prediction_models", "*"))
    -    tasks = [os.path.splitext(os.path.basename(task))[0] for task in tasks]
    -    models = [os.path.splitext(os.path.basename(model))[0] for model in models]
    +    try:
    +        tasks = list((config_dir / "tasks").glob("*"))
    +        models = list((config_dir / "prediction_models").glob("*"))
    +        tasks = [task.stem for task in tasks if task.is_file()]
    +        models = [model.stem for model in models if model.is_file()]
    +    except Exception as e:
    +        logging.error(f"Error retrieving config files: {e}")
    +        return [], []
         if "common" in tasks:
             tasks.remove("common")
         if "common" in models:
             models.remove("common")
         logging.info(f"Found tasks: {tasks}")
         logging.info(f"Found models: {models}")
         return tasks, models

This modification improves error handling and uses Path objects consistently, enhancing robustness and cross-platform compatibility.
objects consistently, enhancing robustness and cross-platform compatibility.icu_benchmarks/tuning/hyperparameters.py (5)
5-5
: LGTM! Consider grouping related imports.The new imports for matplotlib and Optuna visualization functions are appropriate for the added visualization capabilities. The
plot
parameter in the function signature is a good addition for flexibility.Consider grouping related imports together for better readability. For example, you could move the matplotlib import closer to the Optuna visualization imports.
Also applies to: 18-18, 194-194
282-292
: LGTM! Consider consolidating checkpoint loading logic.The updated checkpoint handling using Optuna's study loading functionality is a significant improvement. The added error handling and logging provide better visibility into the process.
Consider consolidating the checkpoint loading logic into a separate function for better modularity and reusability. This would make the main function cleaner and easier to maintain.
Also applies to: 324-342
244-261
: LGTM! Consider refactoring for improved readability.The updated objective function now handles hyperparameter bounds more flexibly, providing better support for various types and ranges. The logic for determining the suggestion type is comprehensive and robust.
Consider refactoring the nested if-else statements into separate functions for each hyperparameter type. This could improve readability and make the code easier to maintain. For example:
def suggest_int_param(trial, name, value): return trial.suggest_int(name, value[0], value[1], log=value[2] == "log" if len(value) == 3 else False) def suggest_float_param(trial, name, value): return trial.suggest_float(name, value[0], value[1], log=value[2] == "log" if len(value) == 3 else False) def suggest_categorical_param(trial, name, value): return trial.suggest_categorical(name, value) # Then in the objective function: if isinstance(value[0], int) and isinstance(value[1], int): hyperparams[name] = suggest_int_param(trial, name, value) elif isinstance(value[0], (int, float)) and isinstance(value[1], (int, float)): hyperparams[name] = suggest_float_param(trial, name, value) else: hyperparams[name] = suggest_categorical_param(trial, name, value)This approach would make the code more modular and easier to extend in the future.
343-365
: LGTM! Robust Optuna study implementation with W&B integration.The Optuna study creation and optimization logic is well-implemented, supporting resumable studies and integrating with Weights & Biases for enhanced tracking. The code handles both new studies and resuming from checkpoints effectively.
Consider adding a comment explaining the
gc_after_trial=True
parameter in thestudy.optimize
call. This parameter is important for memory management, especially for long-running optimizations, and a brief explanation would be helpful for future maintainers.
374-383
: LGTM! Valuable visualization additions with proper error handling.The new visualization code for plotting hyperparameter importances and optimization history provides valuable insights into the optimization process. The error handling ensures that the main process continues even if plotting fails, which is a good practice.
Consider adding a configuration option to specify the output format of the plots (e.g., PNG, PDF, SVG). This would provide more flexibility for users who might prefer different formats for their documentation or analysis. You could add a parameter like
plot_format: str = 'png'
to the function and use it in thesavefig
calls.icu_benchmarks/data/loader.py (5)
15-80
: Good implementation ofCommonPolarsDataset
using PolarsThe transition to Polars for data manipulation is a good choice for improving performance with large datasets. The class structure and methods are well-implemented.
However, there are a few minor improvements that can be made:
- Remove commented-out code to improve readability. For example, lines 33-35 can be deleted.
- Consider adding type hints to method return values for better code documentation.
Here's an example of adding type hints:
- def get_feature_names(self): + def get_feature_names(self) -> List[str]: return self.features_df.columns
Line range hint
185-243
: Consider deprecatingCommonPandasDataset
in favor ofCommonPolarsDataset
The
CommonPandasDataset
class has been kept, likely for backward compatibility. However, as the codebase is transitioning to Polars, it might be beneficial to encourage migration to the Polars version.Suggestions for improvement:
- Add a deprecation warning to the class to encourage users to migrate to
CommonPolarsDataset
. For example:import warnings @gin.configurable("CommonPandasDataset") class CommonPandasDataset(Dataset): def __init__(self, *args, **kwargs): warnings.warn("CommonPandasDataset is deprecated. Use CommonPolarsDataset instead.", DeprecationWarning, stacklevel=2) super().__init__(*args, **kwargs) # ... rest of the class implementation
- Add type hints to method return values for consistency with the Polars version. For example:
- def get_feature_names(self): + def get_feature_names(self) -> List[str]: return self.features_df.columns
- Consider adding a migration guide or helper function to assist users in transitioning from the Pandas version to the Polars version.
Line range hint
244-335
: Consider deprecatingPredictionPandasDataset
in favor ofPredictionPolarsDataset
The
PredictionPandasDataset
class has been kept, likely for backward compatibility. However, as the codebase is transitioning to Polars, it might be beneficial to encourage migration to the Polars version.Suggestions for improvement:
- Add a deprecation warning to the class to encourage users to migrate to
PredictionPolarsDataset
. For example:import warnings @gin.configurable("PredictionPandasDataset") class PredictionPandasDataset(CommonPandasDataset): def __init__(self, *args, **kwargs): warnings.warn("PredictionPandasDataset is deprecated. Use PredictionPolarsDataset instead.", DeprecationWarning, stacklevel=2) super().__init__(*args, **kwargs) # ... rest of the class implementation
- Add type hints to method return values for consistency with the Polars version. For example:
- def get_balance(self) -> list: + def get_balance(self) -> List[float]:
Consider adding a migration guide or helper function to assist users in transitioning from the Pandas version to the Polars version.
There's an inconsistency in the
get_balance
method. The commented-out line 308 should be removed, and the return statement should be simplified:- # weights = list((1 / counts) * np.sum(counts) / counts.shape[0]) - return list((1 / counts) * np.sum(counts) / counts.shape[0]) + return ((1 / counts) * np.sum(counts) / counts.shape[0]).tolist()
Line range hint
336-465
: UpdateImputationPandasDataset
andImputationPredictionDataset
to use PolarsThe
ImputationPandasDataset
andImputationPredictionDataset
classes have not been updated to use Polars, which is inconsistent with the rest of the file.Suggestions for improvement:
Create Polars versions of these classes for consistency with the rest of the file. This would involve rewriting the data manipulation logic using Polars operations.
Add deprecation warnings to the existing Pandas-based classes to encourage migration to the Polars versions once they are implemented. For example:
import warnings @gin.configurable("ImputationPandasDataset") class ImputationPandasDataset(CommonPandasDataset): def __init__(self, *args, **kwargs): warnings.warn("ImputationPandasDataset is deprecated. Use ImputationPolarsDataset instead.", DeprecationWarning, stacklevel=2) super().__init__(*args, **kwargs) # ... rest of the class implementation @gin.configurable("ImputationPredictionDataset") class ImputationPredictionDataset(Dataset): def __init__(self, *args, **kwargs): warnings.warn("ImputationPredictionDataset is deprecated. Use ImputationPredictionPolarsDataset instead.", DeprecationWarning, stacklevel=2) super().__init__(*args, **kwargs) # ... rest of the class implementation
Add type hints to method return values for consistency with the other classes in the file.
Consider updating the
ampute_data
function to work with Polars DataFrames, or create a Polars-compatible version of this function.As part of the transition to Polars, consider creating a base class (e.g.,
BasePolarsDataset
) that implements common functionality for all Polars-based dataset classes. This would help reduce code duplication and ensure consistency across different dataset types.
Line range hint
1-465
: Overall positive changes with room for further improvementsThe transition to Polars for most dataset classes is a significant improvement that should enhance performance, especially for large datasets. The new
CommonPolarsDataset
andPredictionPolarsDataset
classes are well-implemented and make good use of Polars' features.Key points and recommendations:
The transition to Polars is a positive change that should improve performance and maintainability.
Complete the transition to Polars by updating
ImputationPandasDataset
andImputationPredictionDataset
to use Polars.Add deprecation warnings to all Pandas-based classes to encourage migration to their Polars counterparts.
Improve code documentation by adding type hints consistently across all classes and methods.
Consider creating a base Polars dataset class to reduce code duplication and ensure consistency across different dataset types.
Remove any remaining commented-out code to improve readability.
By addressing these points, the code will be more consistent, maintainable, and future-proof. Great work on the transition to Polars, and keep up the momentum to fully leverage its benefits across all dataset classes.
icu_benchmarks/data/split_process_data.py (4)
3-4
: LGTM! Transition to Polars and new functionality additions.The changes in imports and function signature reflect the transition to Polars and the addition of new functionality for modality selection and data exclusion. This should improve performance and flexibility.
Consider using type hints for the new parameters
modality_mapping
andselected_modalities
to improve code readability and maintainability.Also applies to: 9-9, 12-12, 14-14, 23-23, 26-27, 40-42
74-80
: Improved data handling and preprocessing.The changes enhance label handling flexibility, introduce data sanitization, and provide better logging for data quality issues. The transition to Polars is consistently implemented.
Consider adding error handling for the case when no label is selected after filtering. For example:
if not vars[Var.label]: raise ValueError("No label selected after filtering.")Also applies to: 118-134, 156-177
190-238
: New data preprocessing functions: Approved.The new functions
check_sanitize_data
andmodality_selection
add important steps for data cleaning and filtering. They are well-implemented and should improve data quality.In the
modality_selection
function, consider adding error handling for the case when none of the selected modalities are found in the modality mapping. For example:if not any(col in modality_mapping.keys() for col in selected_modalities): raise ValueError("None of the selected modalities found in modality mapping.")🧰 Tools
🪛 Ruff
192-192: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
193-193: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
195-195: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
199-199: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
203-203: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
218-218: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
234-234: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
Line range hint
240-442
: Data splitting functions updated for Polars support: Approved.The
make_train_val
andmake_single_split
functions have been successfully adapted to work with both Polars and Pandas DataFrames. This maintains backwards compatibility while allowing for the performance benefits of Polars.Consider refactoring the common logic for Polars and Pandas into separate helper functions to reduce code duplication. For example:
def _get_stays(data, id, polars): return pl.Series(name=id, values=data[Segment.outcome][id].unique()) if polars else pd.Series(data[Segment.outcome][id].unique(), name=id) def _get_labels(data, id, vars, polars): if polars: return data[Segment.outcome].group_by(id).max()[vars[Var.label]] else: return data[Segment.outcome].groupby(id).max()[vars[Var.label]].reset_index(drop=True) # Use these helper functions in both make_train_val and make_single_splitThis refactoring would improve maintainability and reduce the risk of inconsistencies between the two functions.
🧰 Tools
🪛 Ruff
303-303: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
311-311: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
315-315: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
icu_benchmarks/models/wrappers.py (3)
6-6
: Unused imports and configurations detected.The newly added imports
average_precision_score
androc_auc_score
fromsklearn.metrics
, along with their corresponding gin configurations, are not currently used in the code. Consider removing them if they are not needed, or implement their usage if they are intended to be used in the future.Also applies to: 33-34
376-379
: Commented-out code for new metrics.There's commented-out code related to
average_precision_score
androc_auc_score
. If these metrics are intended to be used, consider implementing them. Otherwise, it might be better to remove the commented code to improve readability.
Line range hint
562-589
: Approved: Good improvements to the ImputationWrapper class.The new
init_weights
method provides a flexible way to initialize weights using different strategies. The changes inon_fit_start
andstep_fn
methods enhance the functionality of the imputation wrapper.As a minor suggestion, consider adding type hints to the
init_weights
method parameters for consistency with the rest of the codebase:def init_weights(self, init_type: ImputationInit = ImputationInit.NORMAL, gain: float = 0.02):Also applies to: 591-595, 597-611
icu_benchmarks/data/preprocessor.py (6)
77-143
: LGTM! Comprehensive implementation with good error handling.The
apply
method effectively handles both static and dynamic features, with proper checks and integration. The logging statements are helpful for debugging.Consider extracting the repeated code for joining static and dynamic data into a separate method to improve readability and maintainability.
Here's a suggestion for extracting the join operation:
def _join_static_dynamic(self, data, vars): for split in [Split.train, Split.val, Split.test]: data[split][Segment.dynamic] = data[split][Segment.dynamic].join( data[split][Segment.static], on=vars["GROUP"] ) return dataYou can then call this method in the
apply
function instead of repeating the join operation for each split.
145-198
: LGTM! Comprehensive preprocessing for both static and dynamic features.The
_process_static
and_process_dynamic
methods provide a thorough approach to data preprocessing. The use of the Recipe class allows for a clear and modular implementation of various preprocessing steps.Consider adding error handling for potential exceptions that might occur during the preprocessing steps, especially for the
apply_recipe_to_splits
function calls.Here's a suggestion for adding error handling:
try: data = apply_recipe_to_splits(sta_rec, data, Segment.static, self.save_cache, self.load_cache) except Exception as e: logging.error(f"Error occurred while processing static features: {str(e)}") raiseApply similar error handling to the dynamic feature processing as well.
215-276
: LGTM! Well-structured regression preprocessor.The
PolarsRegressionPreprocessor
effectively extends the classification preprocessor for regression tasks. The outcome processing is flexible, allowing for predefined ranges or automatic scaling.Consider moving the outcome processing logic to a separate method to improve code organization and reusability.
Here's a suggestion for extracting the outcome processing logic:
def _scale_outcome(self, outcome_data, vars): outcome_rec = Recipe(outcome_data, vars["LABEL"], [], vars["GROUP"]) if self.outcome_max is not None and self.outcome_min is not None: outcome_rec.add_step( StepSklearn( sklearn_transformer=FunctionTransformer( func=lambda x: ((x - self.outcome_min) / (self.outcome_max - self.outcome_min)) ), sel=all_outcomes(), ) ) else: outcome_rec.add_step(StepSklearn(MinMaxScaler(), sel=all_outcomes())) return outcome_rec.prep().bake()You can then call this method in the
_process_outcome
function.
Line range hint
279-475
: LGTM! Consistent implementation for Pandas-based preprocessing.The Pandas-based preprocessor classes maintain consistency with their Polars counterparts, which is excellent for code maintainability and understanding.
Consider creating a base preprocessor class that contains common functionality for both Polars and Pandas implementations. This could help reduce code duplication and make it easier to maintain both versions in the future.
Here's a high-level suggestion for creating a base preprocessor:
class BasePreprocessor(abc.ABC): @abc.abstractmethod def apply(self, data, vars): pass @abc.abstractmethod def _process_static(self, data, vars): pass @abc.abstractmethod def _process_dynamic(self, data, vars): pass # Common methods and properties can be defined hereThen, both Polars and Pandas preprocessors can inherit from this base class:
class PolarsClassificationPreprocessor(BasePreprocessor): # Polars-specific implementation class PandasClassificationPreprocessor(BasePreprocessor): # Pandas-specific implementationThis structure would help centralize common logic and make it easier to maintain both implementations.
Line range hint
476-536
: LGTM! Focused implementation for imputation preprocessing.The
PandasImputationPreprocessor
class provides a targeted approach for imputation tasks. The options for scaling and using static features offer flexibility in preprocessing.However, the current implementation of
_process_dynamic_data
removes entire stays (groups of data) if any missing values are found. This approach might lead to significant data loss in some scenarios.Consider implementing a more granular approach to handling missing values. For example:
- Allow for a threshold of missing values before removing a stay.
- Implement different strategies for handling missing values based on the nature of the data (e.g., forward fill for time series data).
Here's a suggestion for a more flexible approach:
def _process_dynamic_data(self, data, vars): if self.filter_missing_values: missing_ratio = data[Segment.dynamic][vars[Segment.dynamic]].isna().mean(axis=1) ids_to_remove = data[Segment.dynamic].loc[missing_ratio > self.missing_threshold][vars["GROUP"]].unique() data = {table_name: table.loc[~table[vars["GROUP"]].isin(ids_to_remove)] for table_name, table in data.items()} logging.info(f"Removed {len(ids_to_remove)} stays with more than {self.missing_threshold*100}% missing values.") return dataThis approach allows you to set a threshold for the acceptable percentage of missing values before removing a stay.
Line range hint
537-589
: LGTM! Robust implementation of recipe application and caching.The utility functions provide a flexible approach to applying recipes across different data splits and managing recipe caching. The support for both Polars and Pandas DataFrames is a nice touch for compatibility.
For consistency, consider adding similar error handling to the
cache_recipe
function as you have inrestore_recipe
.Here's a suggestion for adding error handling to
cache_recipe
:def cache_recipe(recipe: Recipe, cache_file: str) -> None: """Cache recipe to make it available for e.g. transfer learning.""" try: recipe_cache = copy.deepcopy(recipe) recipe_cache.cache() if not (cache_file / "..").exists(): (cache_file / "..").mkdir(parents=True, exist_ok=True) with open(cache_file, "wb") as f: pickle.dump(recipe_cache, f, pickle.HIGHEST_PROTOCOL) logging.info(f"Cached recipe in {cache_file}.") except Exception as e: logging.error(f"Error occurred while caching recipe: {str(e)}") raiseThis change adds error handling and ensures the parent directory is created if it doesn't exist.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (16)
- icu_benchmarks/cross_validation.py (5 hunks)
- icu_benchmarks/data/loader.py (5 hunks)
- icu_benchmarks/data/pooling.py (5 hunks)
- icu_benchmarks/data/preprocessor.py (8 hunks)
- icu_benchmarks/data/split_process_data.py (11 hunks)
- icu_benchmarks/models/constants.py (2 hunks)
- icu_benchmarks/models/custom_metrics.py (2 hunks)
- icu_benchmarks/models/ml_models/catboost.py (1 hunks)
- icu_benchmarks/models/ml_models/imblearn.py (1 hunks)
- icu_benchmarks/models/ml_models/xgboost.py (1 hunks)
- icu_benchmarks/models/train.py (10 hunks)
- icu_benchmarks/models/utils.py (2 hunks)
- icu_benchmarks/models/wrappers.py (10 hunks)
- icu_benchmarks/run.py (7 hunks)
- icu_benchmarks/run_utils.py (6 hunks)
- icu_benchmarks/tuning/hyperparameters.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- icu_benchmarks/cross_validation.py
- icu_benchmarks/models/constants.py
- icu_benchmarks/models/custom_metrics.py
- icu_benchmarks/models/ml_models/imblearn.py
- icu_benchmarks/models/utils.py
🧰 Additional context used
🪛 Ruff
icu_benchmarks/data/pooling.py
70-70: Use key in dict instead of key in dict.keys(). Remove .keys() (SIM118)
106-106: Do not use mutable data structures for argument defaults. Replace with None; initialize within function (B006)
108-108: Do not use mutable data structures for argument defaults. Replace with None; initialize within function (B006)

icu_benchmarks/data/split_process_data.py
26-26: Do not use mutable data structures for argument defaults. Replace with None; initialize within function (B006)
41-41: Do not use mutable data structures for argument defaults. Replace with None; initialize within function (B006)
119-119: Use key in dict instead of key in dict.keys(). Remove .keys() (SIM118)
124-124: Use key in dict instead of key in dict.keys(). Remove .keys() (SIM118)
124-124: Test for membership should be not in. Convert to not in (E713)
192-192, 193-193, 195-195, 199-199, 203-203, 218-218, 234-234, 303-303, 311-311, 315-315: Use key in dict instead of key in dict.keys(). Remove .keys() (SIM118)
409-412: Use ternary operator dev_stays = stays[dev] if polars else stays.iloc[dev] instead of if-else-block (SIM108)
428-428, 436-436, 440-440: Use key in dict instead of key in dict.keys(). Remove .keys() (SIM118)

icu_benchmarks/models/ml_models/catboost.py
14-14: Star-arg unpacking after a keyword argument is strongly discouraged (B026)

icu_benchmarks/models/ml_models/xgboost.py
23-23: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
63-63: f-string without any placeholders. Remove extraneous f prefix (F541)

icu_benchmarks/models/wrappers.py
474-474: Use key in dict instead of key in dict.keys(). Remove .keys() (SIM118)
🔇 Additional comments (22)
icu_benchmarks/models/ml_models/catboost.py (2)

7-9: LGTM: Class definition and attributes

The class definition looks good. The use of `@gin.configurable` allows for easy configuration, and the `_supported_run_modes` attribute correctly specifies the supported task type.

1-19: Overall impression: Good implementation with room for minor improvements

The `CBClassifier` class is well-structured and implements the CatBoost classifier effectively. The use of `gin` for configuration and inheritance from `MLWrapper` shows good design practices. There are a few minor improvements suggested throughout the review, including:

- Fixing a typo in the import statement.
- Making the task type configurable in the constructor.
- Enhancing the docstring for the `predict` method.
- Refactoring to address the star-arg unpacking warning.

Implementing these suggestions will further improve the code's clarity, flexibility, and maintainability. Overall, this is a solid addition to the project.

🧰 Tools
🪛 Ruff

14-14: Star-arg unpacking after a keyword argument is strongly discouraged (B026)

icu_benchmarks/models/ml_models/xgboost.py (3)

18-24: LGTM! Well-structured class definition.

The class definition and initialization look good:

- The `@gin.configurable` decorator allows for flexible configuration.
- The `_supported_run_modes` class attribute clearly defines the intended use.
- The `__init__` method is concise and uses `set_model_args` for model initialization.

🧰 Tools
🪛 Ruff

23-23: Star-arg unpacking after a keyword argument is strongly discouraged (B026)

26-28: LGTM! Concise predict method.

The `predict` method is simple and correctly uses `predict_proba` for classification tasks, returning probability predictions.

1-86: Overall assessment: Well-implemented XGBoost classifier with room for minor improvements.

The `XGBClassifier` implementation is solid and includes essential methods for training, prediction, and evaluation. It demonstrates good practices such as:

- Use of gin for configuration
- Integration with WandB for experiment tracking
- Computation of SHAP values for model interpretability
- Flexible model argument setting

To further improve the implementation, consider addressing the following:

- Organize imports according to PEP 8 guidelines
- Improve robustness of the `set_model_args` method
- Add error handling to `get_feature_importance`
- Refactor the `test_step` method for better readability
- Address the static analysis warnings

These improvements will enhance the overall quality, readability, and robustness of the code.

🧰 Tools
🪛 Ruff

23-23: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
63-63: f-string without any placeholders; remove extraneous `f` prefix (F541)

icu_benchmarks/run.py (4)

21-21: LGTM: New import added for `get_config_files`

The addition of `get_config_files` to the imports is appropriate. This function will likely be used to retrieve configuration files for tasks and models, which aligns with the changes described in the AI-generated summary.

53-53: LGTM: New `modalities` argument added to main function

The addition of the `modalities` argument to the main function enhances the flexibility of the script, allowing users to specify which data modalities to use in the analysis. This change is consistent with the improvements mentioned in the AI-generated summary.

84-86: LGTM: Improved formatting for wandb config update

The changes to the `update_wandb_config` call enhance readability by breaking the long line into multiple lines. The use of a ternary operator to handle the case where `pretrained_imputation_model` is None is a concise and effective approach.

145-148: LGTM: Switched to keyword arguments for improved clarity

The modification of the `choose_and_bind_hyperparameters_optuna` function call to use keyword arguments instead of positional arguments is a positive change. This improves code readability, makes the function call more explicit, and reduces the likelihood of errors when modifying the call in the future. It's consistent with best practices for function calls with multiple arguments.

icu_benchmarks/data/pooling.py (3)

19-29: LGTM: Improved method signature formatting

The `__init__` method signature has been reformatted for better readability, with parameters aligned vertically. This change enhances code clarity without affecting functionality.

51-55: LGTM: Improved method signature formatting

The `generate` method signature has been reformatted for better readability, with parameters aligned vertically. This change enhances code clarity without affecting functionality.

148-150: LGTM: Improved stratification in train-test split

The modification to use `stratify=labels` in the `train_test_split` function for classification mode is a good improvement. This ensures that the class distribution is roughly the same in both the train and test sets, which is crucial for maintaining the integrity of the model evaluation process.
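For illustration, a minimal, self-contained sketch of the effect (toy data, not the benchmark's pooling code):

```python
# Minimal sketch of stratified splitting with scikit-learn; the toy labels
# below are illustrative, not taken from the benchmark datasets.
import numpy as np
from sklearn.model_selection import train_test_split

X = np.arange(20).reshape(-1, 1)
y = np.array([0] * 16 + [1] * 4)  # imbalanced labels (80% / 20%)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=42
)

# Both splits keep roughly the original 80/20 class ratio.
print(y_train.mean(), y_test.mean())  # ~0.2 and 0.2
```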
6-6
: LGTM: Imports updated for polars integrationThe addition of polars import and the inclusion of PredictionPolarsDataset are consistent with the transition to using polars for data handling. These changes lay the groundwork for the polars-based data processing in the rest of the file.
Also applies to: 14-14
30-30
: Verify impact of parameter changes intrain_common
The function signature has been updated to reflect the transition to polars and add flexibility:
- Data type hint changed from
pd.DataFrame
topl.DataFrame
.- New
polars
parameter added with defaultTrue
.However, there are significant changes to default values:
batch_size
default changed from 64 to 1epochs
default changed from 1000 to 100These changes could significantly impact the training process and model performance.
Please clarify the rationale behind these changes and confirm that they have been thoroughly tested. Consider adding comments in the code to explain these default value changes.
Also applies to: 41-42, 54-54
134-134
: LGTM: Model loading and instantiation updated for CPU flexibilityThe addition of the
cpu
parameter to both theload_model
call and model instantiation enhances flexibility, allowing for easier CPU-based operations when needed. This change is consistent with thecpu
parameter in thetrain_common
function signature.Also applies to: 136-136
icu_benchmarks/run_utils.py (2)
18-20
: LGTM: New imports added for enhanced functionality.The addition of
os
,glob
, andpolars
imports is appropriate for the new features implemented in this file. These libraries are used effectively in the newget_config_files
function and the updatedaggregate_results
function.
Line range hint
1-269
: Overall assessment: Positive improvements to run_utils.pyThe changes made to
run_utils.py
enhance its functionality and flexibility in several ways:
- New imports support additional features.
- Additional command-line arguments allow for more specific evaluations.
- Improved directory creation logic ensures unique run directories.
- SHAP values aggregation adds valuable analytical capabilities.
- The new
get_config_files
function streamlines configuration file management.These modifications collectively improve the script's versatility and robustness. The implementation is consistent with the existing code style, and the changes are well-integrated into the file's structure.
icu_benchmarks/tuning/hyperparameters.py (2)

Line range hint 266-271: LGTM! Improved callback integration with Optuna.

The updated tuning callback function now works seamlessly with Optuna's study and trial objects. The logging improvements provide better visibility into the tuning process, which is valuable for monitoring and debugging.

The integration with Weights & Biases (wandb) for logging the optimization iteration is a good addition for those using wandb for experiment tracking.
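For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of an Optuna tuning callback (a generic example, not the project's actual callback; the wandb line is optional and assumes an active run):

```python
# Minimal sketch of an Optuna tuning callback.
import optuna


def tuning_callback(study: optuna.Study, trial: optuna.trial.FrozenTrial) -> None:
    # Called by Optuna after every finished trial.
    print(f"Trial {trial.number} finished: value={trial.value}, params={trial.params}")
    print(f"Best value so far: {study.best_value}")
    # wandb.log({"optimization_iteration": trial.number})  # optional tracking


study = optuna.create_study(direction="minimize")
study.optimize(lambda t: (t.suggest_float("x", -10, 10) - 2) ** 2,
               n_trials=20, callbacks=[tuning_callback])
```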
Line range hint 1-383: Overall, excellent implementation of Optuna for hyperparameter optimization!

The changes in this file significantly enhance the hyperparameter tuning capabilities of the project. The integration of Optuna provides a more flexible and powerful optimization framework, while the added visualization features offer valuable insights into the tuning process. The code is well-structured, with appropriate error handling and logging throughout.

Key improvements include:

- Flexible handling of various hyperparameter types and ranges
- Robust checkpoint management for resumable studies
- Integration with Weights & Biases for enhanced tracking
- Visualization of hyperparameter importances and optimization history

These changes will greatly benefit the project by enabling more efficient and insightful hyperparameter tuning processes.
icu_benchmarks/data/split_process_data.py (1)

Line range hint 1-442: Overall assessment: Significant improvements with minor suggestions for refinement.

The transition to Polars and the addition of new functionality for modality selection and data sanitization are well-implemented. These changes should lead to improved performance and more flexible data handling. The suggested refinements, including error handling improvements, refactoring to reduce code duplication, and addressing static analysis warnings, will further enhance the code quality and maintainability.

🧰 Tools
🪛 Ruff

192-192, 193-193, 195-195, 199-199, 203-203, 218-218, 234-234: Use `key in dict` instead of `key in dict.keys()`; remove `.keys()` (SIM118)
icu_benchmarks/models/wrappers.py (1)

Line range hint 1-624: Summary of review findings

- The changes to the wrapper classes, especially the ImputationWrapper, generally improve functionality and flexibility.
- There are several instances of unused imports, parameters, and attributes that should be addressed.
- The MPS compatibility type casting in the `log_metrics` method needs review.
- Minor improvements can be made for performance and code consistency.

Overall, the changes are positive, but addressing the identified issues will further enhance the code quality and maintainability.

🧰 Tools
🪛 Ruff

474-474: Use `key in dict` instead of `key in dict.keys()`; remove `.keys()` (SIM118)

icu_benchmarks/data/preprocessor.py (1)

Line range hint 9-75: LGTM! Polars integration looks promising.

The introduction of Polars and the new `PolarsClassificationPreprocessor` class is a positive change. Polars is known for its performance benefits in data processing tasks. The flexible initialization parameters allow for customized preprocessing, which is great for adaptability to different datasets and requirements.
icu_benchmarks/data/loader.py
Outdated
```python
@gin.configurable("PredictionPolarsDataset")
class PredictionPolarsDataset(CommonPolarsDataset):
    """Subclass of common dataset for prediction tasks.

    Args:
        ram_cache (bool, optional): Whether the complete dataset should be stored in ram. Defaults to True.
    """

    def __init__(self, *args, ram_cache: bool = True, **kwargs):
        super().__init__(*args, **kwargs)
        self.outcome_df = self.grouping_df
        self.ram_cache(ram_cache)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
        """Function to sample from the data split of choice. Used for deep learning implementations.

        Args:
            idx: A specific row index to sample.

        Returns:
            A sample from the data, consisting of data, labels and padding mask.
        """
        if self._cached_dataset is not None:
            return self._cached_dataset[idx]

        pad_value = 0.0
        # stay_id = self.outcome_df.index.unique()[idx]  # [self.vars["GROUP"]]
        stay_id = self.outcome_df[self.vars["GROUP"]].unique()[idx]  # [self.vars["GROUP"]]

        # slice to make sure to always return a DF
        # window = self.features_df.loc[stay_id:stay_id].to_numpy()
        # labels = self.outcome_df.loc[stay_id:stay_id][self.vars["LABEL"]].to_numpy(dtype=float)
        window = self.features_df.filter(pl.col(self.vars["GROUP"]) == stay_id).to_numpy()
        labels = self.outcome_df.filter(pl.col(self.vars["GROUP"]) == stay_id)[self.vars["LABEL"]].to_numpy().astype(float)

        if len(labels) == 1:
            # only one label per stay, align with window
            labels = np.concatenate([np.empty(window.shape[0] - 1) * np.nan, labels], axis=0)

        length_diff = self.maxlen - window.shape[0]
        pad_mask = np.ones(window.shape[0])

        # Padding the array to fulfill size requirement
        if length_diff > 0:
            # window shorter than the longest window in dataset, pad to same length
            window = np.concatenate([window, np.ones((length_diff, window.shape[1])) * pad_value], axis=0)
            labels = np.concatenate([labels, np.ones(length_diff) * pad_value], axis=0)
            pad_mask = np.concatenate([pad_mask, np.zeros(length_diff)], axis=0)

        not_labeled = np.argwhere(np.isnan(labels))
        if len(not_labeled) > 0:
            labels[not_labeled] = -1
            pad_mask[not_labeled] = 0

        pad_mask = pad_mask.astype(bool)
        labels = labels.astype(np.float32)
        data = window.astype(np.float32)

        return from_numpy(data), from_numpy(labels), from_numpy(pad_mask)

    def get_balance(self) -> list:
        """Return the weight balance for the split of interest.

        Returns:
            Weights for each label.
        """
        counts = self.outcome_df[self.vars["LABEL"]].value_counts(parallel=True).get_columns()[1]
        counts = counts.to_numpy()
        weights = list((1 / counts) * np.sum(counts) / counts.shape[0])
        return weights

    def get_data_and_labels(self) -> Tuple[np.array, np.array, np.array]:
        """Function to return all the data and labels aligned at once.

        We use this function for the ML methods which don't require an iterator.

        Returns:
            A Tuple containing data points and label for the split.
        """
        labels = self.outcome_df[self.vars["LABEL"]].to_numpy().astype(float)
        rep = self.features_df

        if len(labels) == self.num_stays:
            # order of groups could be random, we make sure not to change it
            # rep = rep.groupby(level=self.vars["GROUP"], sort=False).last()
            rep = rep.group_by(self.vars["GROUP"]).last()
        else:
            # Adding segment count for each stay id and timestep.
            rep = rep.with_columns(pl.col(self.vars["GROUP"]).cum_count().over(self.vars["GROUP"]).alias("counter"))
        rep = rep.to_numpy().astype(float)
        logging.debug(f"rep shape: {rep.shape}")
        logging.debug(f"labels shape: {labels.shape}")
        return rep, labels, self.row_indicators.to_numpy()

    def to_tensor(self):
        data, labels, row_indicators = self.get_data_and_labels()
        if self.mps:
            return from_numpy(data).to(float32), from_numpy(labels).to(float32)
        else:
            return from_numpy(data), from_numpy(labels), row_indicators
```
🛠️ Refactor suggestion

Well-implemented `PredictionPolarsDataset` with Polars integration

The `PredictionPolarsDataset` class is well-implemented and consistently uses Polars for data manipulation. The methods are appropriately adapted for prediction tasks.

Here are some suggestions for improvement:

- Add type hints to method return values for better code documentation. For example:

```diff
- def get_data_and_labels(self):
+ def get_data_and_labels(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
```

- In the `get_balance` method, consider adding a comment explaining the weight calculation formula for better understanding.
- In the `__getitem__` method, consider using Polars' built-in methods for performance optimization. For example, replace the numpy operations with Polars operations where possible.

Here's an example of how you might optimize the `__getitem__` method:

```python
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
    if self._cached_dataset is not None:
        return self._cached_dataset[idx]

    pad_value = 0.0
    stay_id = self.outcome_df[self.vars["GROUP"]].unique()[idx]
    window = self.features_df.filter(pl.col(self.vars["GROUP"]) == stay_id)
    labels = self.outcome_df.filter(pl.col(self.vars["GROUP"]) == stay_id)[self.vars["LABEL"]]

    if len(labels) == 1:
        labels = pl.concat([pl.Series([None] * (window.height - 1)), labels])

    length_diff = self.maxlen - window.height
    pad_mask = pl.Series([1] * window.height)

    if length_diff > 0:
        window = window.vstack(pl.DataFrame({col: [pad_value] * length_diff for col in window.columns}))
        labels = labels.extend(pl.Series([pad_value] * length_diff))
        pad_mask = pad_mask.extend(pl.Series([0] * length_diff))

    labels = labels.fill_null(-1)
    pad_mask = pad_mask.where(labels != -1, 0)

    return (
        from_numpy(window.to_numpy().astype(np.float32)),
        from_numpy(labels.to_numpy().astype(np.float32)),
        from_numpy(pad_mask.to_numpy().astype(bool)),
    )
```

This version uses more Polars operations and should be more efficient, especially for larger datasets.
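As an aside on the `get_balance` suggestion above: reading the weight computation in the diff, the balance weights follow standard inverse-frequency class weighting, which a docstring or comment could state as

```latex
w_c = \frac{1}{n_c} \cdot \frac{N}{C}
```

where $n_c$ is the count of label $c$, $N = \sum_c n_c$ is the total number of labels, and $C$ is the number of classes, so rarer classes receive proportionally larger weights.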
Actionable comments posted: 15
🧹 Outside diff range and nitpick comments (53)
icu_benchmarks/models/ml_models/lgbm.py (2)

11-27: LGTM: LGBMWrapper implementation is solid, with a minor suggestion for improvement.

The LGBMWrapper class provides a robust implementation for fitting LightGBM models. It includes early stopping, logging callbacks, and wandb integration, which are all excellent practices. The method returns the validation loss, which is useful for model evaluation.

Consider adding type hints to improve code readability and maintainability. For example:

```python
def fit_model(self, train_data: np.ndarray, train_labels: np.ndarray, val_data: np.ndarray, val_labels: np.ndarray) -> float:
```

This would make the expected input and output types clearer to other developers.

30-48: LGTM: LGBMClassifier implementation is well-structured, with a minor suggestion.

The LGBMClassifier class is well-implemented, utilizing gin for configuration and correctly specifying its supported run mode. The predict method is properly documented and returns class probabilities as expected.

To enhance code robustness, consider adding input validation in the predict method. For example:

```python
def predict(self, features: np.ndarray) -> np.ndarray:
    if not isinstance(features, np.ndarray):
        raise TypeError("features must be a numpy array")
    if features.ndim != 2:
        raise ValueError("features must be a 2D array")
    return self.model.predict_proba(features)
```

This would help catch potential errors earlier in the prediction process.

icu_benchmarks/models/dl_models/tcn.py (2)

12-16: Enhance the class docstring for better documentation.

While the current docstring mentions the adaptation from the original TCN paper, it could be more informative. Consider expanding it to include:

- A brief explanation of what a Temporal Convolutional Network does
- Key features or advantages of this implementation
- Description of important parameters
- Example usage, if applicable

This would greatly improve the documentation and make it easier for other developers to understand and use the class.

1-62: Consider adding type hints for improved code quality.

While the implementation is solid, adding type hints to method parameters and return values could further improve code readability and help catch potential type-related issues early in development. This is especially useful for complex classes like TCN.

Example for the `__init__` method:

```python
def __init__(self, input_size: Tuple[int, int, int], num_channels: Union[int, List[int]], num_classes: int, *args, max_seq_length: int = 0, kernel_size: int = 2, dropout: float = 0.0, **kwargs) -> None:
```

And for the `forward` method:

```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
```

Consider adding similar type hints throughout the class for improved code quality and maintainability.

🧰 Tools
🪛 Ruff

23-23: Star-arg unpacking after a keyword argument is strongly discouraged (B026)

icu_benchmarks/models/ml_models/sklearn.py (2)

Line range hint 54-66: Minor inconsistency in SVM classes.

The `SVMClassifier` and `SVMRegressor` classes use `self.model_args` instead of `self.set_model_args`, which is used in all other classes. This might be a typo or an oversight.

Consider updating these lines for consistency:

```diff
- self.model = self.model_args(svm.SVC, *args, **kwargs)
+ self.model = self.set_model_args(svm.SVC, *args, **kwargs)
- self.model = self.model_args(svm.SVR, *args, **kwargs)
+ self.model = self.set_model_args(svm.SVR, *args, **kwargs)
```

Line range hint 1-93: Overall assessment: Simplified sklearn.py with potential wider impact.

The changes to this file have significantly simplified it by removing LightGBM-related code and focusing on scikit-learn models. This may improve maintainability and reduce dependencies. However, it's crucial to ensure that:

- The removal of LightGBM functionality doesn't negatively impact other parts of the codebase that may have depended on it.
- Any documentation or user guides are updated to reflect the removal of LightGBM support.
- The minor inconsistency in the SVM classes is addressed for better code uniformity.

These changes align with the PR objectives of incorporating enhancements from the Cassandra project, although the specific connection isn't clear from this file alone. The introduction of new models mentioned in the PR objectives isn't evident in this file, so it may be worth checking other files in the PR for those additions.

Consider updating the module docstring (if it exists) to reflect the current set of supported models and the removal of LightGBM support.
icu_benchmarks/models/dl_models/rnn.py (2)

8-31: LGTM: RNNet class implementation is well-structured.

The RNNet class is well-implemented with appropriate initialization, hidden state management, and forward pass. The use of @gin.configurable allows for flexible configuration.

Consider adding type hints to method parameters and return values for improved code readability and maintainability. For example:

```python
def init_hidden(self, x: torch.Tensor) -> torch.Tensor:
    # ...

def forward(self, x: torch.Tensor) -> torch.Tensor:
    # ...
```

🧰 Tools
🪛 Ruff

16-16: Star-arg unpacking after a keyword argument is strongly discouraged (B026)

1-85: Summary: Good implementation of RNN models with some areas for improvement

Overall, the implementation of RNNet, LSTMNet, and GRUNet classes is well-structured and follows good practices. However, there are a few areas that need attention:

- Add the missing return statement in the GRUNet's forward method to ensure correct functionality.
- Refactor the `super().__init__()` calls in all classes to avoid star-arg unpacking after keyword arguments.
- Consider adding type hints to improve code readability and maintainability.

These changes will enhance the code quality and ensure proper functionality of the RNN models. Once these issues are addressed, the implementation will be robust and ready for use.

To further improve the code, consider the following suggestions:

- Implement a base RNN class that encapsulates common functionality, and have RNNet, LSTMNet, and GRUNet inherit from it. This would reduce code duplication and improve maintainability (see the sketch after this list).
- Add docstrings to methods, especially `forward`, to clarify the expected input and output shapes.
- Consider adding a method to save and load model weights, which could be useful for model persistence and deployment.
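A minimal sketch of what such a base class could look like (hypothetical names and signatures, assuming the three variants differ only in the recurrent cell they instantiate):

```python
# Hypothetical refactoring sketch: a shared base class for the RNN variants.
# Names (BaseRNNet, input_size, hidden_dim, layer_dim, num_classes) are
# illustrative and may not match the project's actual signatures.
import torch
import torch.nn as nn


class BaseRNNet(nn.Module):
    rnn_cls: type = nn.RNN  # overridden by subclasses

    def __init__(self, input_size: int, hidden_dim: int, layer_dim: int, num_classes: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.rnn = self.rnn_cls(input_size, hidden_dim, layer_dim, batch_first=True)
        self.logit = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.rnn(x)     # (batch, seq_len, hidden_dim)
        return self.logit(out)   # (batch, seq_len, num_classes)


class RNNet(BaseRNNet):
    rnn_cls = nn.RNN


class LSTMNet(BaseRNNet):
    rnn_cls = nn.LSTM


class GRUNet(BaseRNNet):
    rnn_cls = nn.GRU
```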
🧰 Tools
🪛 Ruff

16-16: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
42-42: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
69-69: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
icu_benchmarks/models/dl_models/transformer.py (2)

52-64: Consider using enumerate or range(depth) directly.

The loop control variables `i` in the initialization of transformer blocks are unused. Consider using `enumerate` if you need the index, or use `range(depth)` directly if you don't need it.

Here's a suggested change:

```python
# In Transformer class
self.tblocks = nn.Sequential(*[
    TransformerBlock(
        emb=hidden,
        hidden=hidden,
        heads=heads,
        mask=True,
        ff_hidden_mult=ff_hidden_mult,
        dropout=dropout,
        dropout_att=dropout_att,
    )
    for _ in range(depth)
])

# In LocalTransformer class
self.tblocks = nn.Sequential(*[
    LocalBlock(
        emb=hidden,
        hidden=hidden,
        heads=heads,
        mask=True,
        ff_hidden_mult=ff_hidden_mult,
        local_context=local_context,
        dropout=dropout,
        dropout_att=dropout_att,
    )
    for _ in range(depth)
])
```

This change eliminates the unused variable warning and makes the code more concise.

Also applies to: 123-135

🧰 Tools
🪛 Ruff

52-52: Loop control variable `i` not used within loop body; rename unused `i` to `_i` (B007)

1-148: Overall assessment: Good implementation with room for improvement

The `transformer.py` file implements two transformer-based models (`Transformer` and `LocalTransformer`) with a clear structure and good use of the gin configuration system. The code is functional and follows the expected architecture for transformer models.

Main points:

- Good use of gin for configuration and clear support for classification and regression modes.
- Well-structured forward methods implementing the transformer architecture.
- Significant code duplication between `Transformer` and `LocalTransformer` classes.
- Use of `*args` in `__init__` methods, which can lead to unexpected behavior.
- Minor issues with unused loop control variables.

Recommendations:

- Refactor to introduce a base class for both transformer types to reduce code duplication.
- Use configuration objects instead of numerous parameters and `*args` in `__init__` methods.
- Address minor issues like unused loop control variables.

These changes would significantly improve code maintainability and readability while preserving the current functionality.

🧰 Tools
🪛 Ruff

37-37: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
52-52: Loop control variable `i` not used within loop body; rename unused `i` to `_i` (B007)
106-106: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
123-123: Loop control variable `i` not used within loop body; rename unused `i` to `_i` (B007)

icu_benchmarks/models/ml_models/xgboost.py (1)

42-59: LGTM: `fit_model` method is well-implemented with a minor suggestion.

The `fit_model` method is correctly implemented with good practices such as early stopping and optional integration with Weights & Biases. The conditional computation of SHAP values is a good approach for performance.

Minor suggestion: consider using an f-string for the logging statement on line 48 for consistency with other logging calls in the file:

```python
logging.info(f"train_data: {train_data.shape}, train_labels: {train_labels.shape}")
```
53-59
: LGTM: Added modalities and label binding with proper logging.The new code block enhances flexibility by allowing modalities and labels to be bound to the gin configuration. The logging statements are helpful for debugging.
Consider adding type checking for
modalities
andargs.label
to ensure they are of the expected type before binding. For example:if modalities and isinstance(modalities, list): logging.debug(f"Binding modalities: {modalities}") gin.bind_parameter("preprocess.selected_modalities", modalities) elif modalities: logging.warning(f"Unexpected type for modalities: {type(modalities)}. Expected list.") if args.label and isinstance(args.label, str): logging.debug(f"Binding label: {args.label}") gin.bind_parameter("preprocess.label", args.label) elif args.label: logging.warning(f"Unexpected type for label: {type(args.label)}. Expected string.")
60-64
: LGTM: Added config file retrieval and validation.The new code block adds an important validation check for the existence of the specified task and model. This is a good practice to catch configuration errors early in the execution.
To improve readability, consider using f-strings for the entire error message and simplifying the condition checks:
tasks, models = get_config_files(Path("configs")) if task not in tasks or model not in models: raise ValueError( f"Invalid configuration. " f"Task '{task}' {'not ' if task not in tasks else ''}found in tasks. " f"Model '{model}' {'not ' if model not in models else ''}found in models." )
134-135
: LGTM: Added support for different model types based on run mode.The modification to the
model_path
assignment adds support for different model types (imputation or prediction) based on the run mode. This change aligns well with the summary mentioning updates to support various model types.To improve readability, consider using a dictionary for model types:
model_types = { RunMode.imputation: "imputation_models", RunMode.prediction: "prediction_models" } model_path = Path("configs") / model_types[mode] / f"{model}.gin"This approach would make it easier to add new model types in the future if needed.
191-195
: LGTM: Added error handling for aggregate_results function.The addition of a try-except block for the
aggregate_results
function call is a good practice. It ensures that any exceptions during result aggregation are properly caught and logged, which aligns with the summary mentioning enhanced error handling.For consistency with the rest of the code, consider using f-strings for error logging:
try: aggregate_results(run_dir, execution_time) except Exception as e: logging.error(f"Failed to aggregate results: {e}") logging.debug(f"Error details:", exc_info=True)icu_benchmarks/models/train.py (5)
30-30
: Update function signature and default valuesThe changes to the
train_common
function signature reflect the transition to Polars and introduce new configuration options. However, there are a few points to consider:
The reduction in default
batch_size
(from 64 to 1) andepochs
(from 1000 to 100) is significant. This might affect training performance and model convergence. Please ensure this change is intentional and doesn't negatively impact the model's learning process.The new
polars
parameter should be documented in the function's docstring, explaining its purpose and impact on the data handling process.The
persistent_workers
parameter is introduced but set to None by default. Consider adding documentation about when and why this parameter should be used.Could you please update the function's docstring to include explanations for the new parameters and the reasoning behind the changes in default values?
Also applies to: 41-42, 54-55
85-91
: LGTM: Flexible dataset class selectionThe new dataset class selection logic is a good addition, allowing for flexibility between pandas and polars-based datasets. This change supports both backward compatibility and performance optimization.
However, there's a TODO comment that needs attention:
# todo: add support for polars versions of datasets
Would you like assistance in implementing the polars versions of the datasets to fully leverage the benefits of the Polars library?
113-113: LGTM: Addition of persistent_workers to DataLoader

The addition of the `persistent_workers` parameter to the DataLoader instantiation is a good optimization. This can potentially improve data loading performance, especially for larger datasets, by keeping worker processes alive between data loading iterations.

Consider adding a comment or updating the documentation to explain:

- The performance implications of using persistent workers (see the sketch below).
- Any potential memory considerations when using this feature.
- Guidelines on when to enable or disable this feature based on dataset size or system resources.
Also applies to: 121-121
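For context, a minimal sketch of where `persistent_workers` fits in a DataLoader (a toy dataset stands in for the benchmark's loaders):

```python
# Toy example of persistent_workers; TensorDataset stands in for the
# benchmark's actual dataset classes.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1000, 8), torch.randint(0, 2, (1000,)))

# persistent_workers=True keeps the worker processes alive across epochs,
# avoiding the cost of re-spawning them at every iteration over the loader.
# It requires num_workers > 0 and increases resident memory accordingly.
loader = DataLoader(dataset, batch_size=64, num_workers=2, persistent_workers=True)

for epoch in range(3):
    for batch_x, batch_y in loader:
        pass  # training step would go here
```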
148-148: LGTM: Improved Trainer configuration

The changes to the Trainer configuration are beneficial:

- Setting `min_epochs=1` ensures that at least one epoch of training occurs, which is crucial for obtaining results.
- Setting `num_sanity_val_steps=2` helps catch potential errors in the validation loop before the actual training begins.

These additions provide better control and error checking in the training process.

Consider adding comments to explain:

- Why a minimum of 1 epoch is necessary.
- The reasoning behind choosing 2 sanity validation steps and what kinds of errors this might catch (see the sketch after this list).

Also applies to: 157-157
198-220: LGTM: New function to persist SHAP values

The addition of the `persist_shap_data` function is a valuable enhancement to the model's interpretability features. Saving SHAP values to parquet files for both test and train datasets will allow for efficient storage and retrieval of this important information.

Positive aspects:

- Use of parquet files for efficient storage.
- Separate handling for test and train SHAP values.
- Inclusion of error handling.

Consider enhancing the error logging in the except block:

```python
except Exception as e:
    logging.error(f"Failed to save SHAP values: {e}", exc_info=True)
```

This will provide more detailed stack trace information, which can be helpful for debugging.
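As a rough illustration of the pattern this function implements (writing SHAP values to parquet with Polars; the array shape, column names, and file name here are assumptions, not the function's actual signature):

```python
# Hypothetical sketch: persisting a SHAP value matrix to parquet with Polars.
# `shap_values` is a (n_samples, n_features) numpy array and `feature_names`
# a matching list -- both stand-ins for whatever the training code produces.
import numpy as np
import polars as pl

shap_values = np.random.rand(5, 3)
feature_names = ["hr", "map", "resp"]

df = pl.DataFrame({name: shap_values[:, i] for i, name in enumerate(feature_names)})
df.write_parquet("test_shap_values.parquet")  # file name follows the review
```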
icu_benchmarks/run_utils.py (5)

18-20: Consider removing unused imports

The `os` and `glob` modules are imported but not used in the visible changes. If they are not used elsewhere in the file, consider removing these imports to keep the code clean:

```diff
-import os
-import glob
 import polars as pl
```

🧰 Tools
🪛 Ruff

18-18: `os` imported but unused; remove unused import (F401)
19-19: `glob` imported but unused; remove unused import (F401)

58-64: LGTM! Consider clarifying help text

The addition of `--modalities` and `--label` arguments enhances the flexibility of the benchmarking process. Good job!

Consider clarifying the help text for the `--label` argument to indicate its relationship with multiple labels:

```diff
-    parser.add_argument("--label", type=str, help="Label to use for evaluation in case of multiple labels.", default=None)
+    parser.add_argument("--label", type=str, help="Specific label to use for evaluation when the dataset contains multiple labels.", default=None)
```

81-82: LGTM! Consider optimizing directory creation

The addition of the while loop ensures unique directory names, preventing potential conflicts. This is a good improvement for data integrity.

Consider using `Path.mkdir(parents=True, exist_ok=True)` to simplify the directory creation process:

```diff
-    log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
-    while log_dir_run.exists():
-        log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f"))
-    log_dir_run.mkdir(parents=True)
+    while True:
+        log_dir_run = log_dir / datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")
+        try:
+            log_dir_run.mkdir(parents=True, exist_ok=False)
+            break
+        except FileExistsError:
+            continue
```

This approach is more efficient and handles race conditions better in multi-threaded environments.

Line range hint 111-142: Remove duplicate code and optimize SHAP value aggregation

The addition of SHAP value aggregation is valuable. However, there's unnecessary duplication in the code.

Please remove the duplicate block and optimize the SHAP value aggregation:

```diff
     shap_values_test = []
-    # shap_values_train = []
     for repetition in log_dir.iterdir():
         if repetition.is_dir():
             aggregated[repetition.name] = {}
             for fold_iter in repetition.iterdir():
                 aggregated[repetition.name][fold_iter.name] = {}
                 if (fold_iter / "test_metrics.json").is_file():
                     with open(fold_iter / "test_metrics.json", "r") as f:
                         result = json.load(f)
                         aggregated[repetition.name][fold_iter.name].update(result)
                 elif (fold_iter / "val_metrics.csv").is_file():
                     with open(fold_iter / "val_metrics.csv", "r") as f:
                         result = json.load(f)
                         aggregated[repetition.name][fold_iter.name].update(result)
                 # Add durations to metrics
                 if (fold_iter / "durations.json").is_file():
                     with open(fold_iter / "durations.json", "r") as f:
                         result = json.load(f)
                         aggregated[repetition.name][fold_iter.name].update(result)
                 if (fold_iter / "test_shap_values.parquet").is_file():
                     shap_values_test.append(pl.read_parquet(fold_iter / "test_shap_values.parquet"))
-    if shap_values_test:
-        shap_values = pl.concat(shap_values_test)
-        shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet")
     try:
-        shap_values = pl.concat(shap_values_test)
-        shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet")
+        if shap_values_test:
+            shap_values = pl.concat(shap_values_test)
+            shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet")
     except Exception as e:
         logging.error(f"Error aggregating or writing SHAP values: {e}")
```

This change removes the duplicate code and ensures that SHAP values are only processed once, improving efficiency and readability.

🧰 Tools
🪛 Ruff

145-145: Loop control variable `repetition` not used within loop body (B007)

262-286: LGTM! Consider using list comprehension for efficiency

The `get_config_files` function is a valuable addition for retrieving task and model configurations. The error handling and logging are well implemented.

Consider using list comprehensions to make the code more concise and potentially more efficient:

```diff
 def get_config_files(config_dir: Path):
     try:
-        tasks = list((config_dir / "tasks").glob("*"))
-        models = list((config_dir / "prediction_models").glob("*"))
-        tasks = [task.stem for task in tasks if task.is_file()]
-        models = [model.stem for model in models if model.is_file()]
+        tasks = [task.stem for task in (config_dir / "tasks").glob("*") if task.is_file() and task.stem != "common"]
+        models = [model.stem for model in (config_dir / "prediction_models").glob("*") if model.is_file() and model.stem != "common"]
     except Exception as e:
         logging.error(f"Error retrieving config files: {e}")
         return [], []
-    if "common" in tasks:
-        tasks.remove("common")
-    if "common" in models:
-        models.remove("common")
     logging.info(f"Found tasks: {tasks}")
     logging.info(f"Found models: {models}")
     return tasks, models
```

This change combines the file listing and filtering steps, reducing the number of iterations and eliminating the need for separate `remove` operations.
194-194
: LGTM: Newplot
parameter added.The addition of the
plot
parameter provides useful flexibility for controlling visualization output.Consider updating the function's docstring to include a description of the new
plot
parameter.
238-261: LGTM: Improved hyperparameter suggestion logic.

The refactoring of the `objective` function to use Optuna's trial object and the addition of helper functions for different parameter types significantly improve the flexibility and readability of the hyperparameter suggestion process.

Consider adding error handling for unexpected value types in the `hyperparams_bounds` to make the function more robust.
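To make the pattern concrete, here is a generic sketch of an Optuna objective that dispatches on parameter type; the bounds dictionary and parameter names are made up for illustration:

```python
# Illustrative Optuna objective; `hyperparams_bounds` and the parameter names
# are hypothetical, not the project's actual configuration.
import optuna

hyperparams_bounds = {
    "lr": (1e-5, 1e-2, "log"),       # float range on a log scale
    "num_layers": (1, 4),            # integer range
    "activation": ["relu", "tanh"],  # categorical choice
}


def objective(trial: optuna.Trial) -> float:
    params = {}
    for name, bounds in hyperparams_bounds.items():
        if isinstance(bounds, list):
            params[name] = trial.suggest_categorical(name, bounds)
        elif len(bounds) == 3 and bounds[2] == "log":
            params[name] = trial.suggest_float(name, bounds[0], bounds[1], log=True)
        elif all(isinstance(b, int) for b in bounds):
            params[name] = trial.suggest_int(name, bounds[0], bounds[1])
        else:
            params[name] = trial.suggest_float(name, bounds[0], bounds[1])
    # A real objective would train a model with `params` and return its score.
    return params["lr"] * params["num_layers"]
```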
372-381: LGTM: Added visualization for hyperparameter tuning results.

The new plotting functionality for hyperparameter importances and optimization history enhances the interpretability of the tuning results. The error handling for the plotting process is a good practice.

Consider closing figures explicitly (e.g., `plt.close(fig)`) after saving the plots to ensure proper cleanup of matplotlib resources.
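A minimal sketch of what such plotting code typically looks like with Optuna's matplotlib backend (the plotting functions are from `optuna.visualization.matplotlib`; file names are placeholders):

```python
# Saving Optuna tuning diagnostics with the matplotlib backend; these helpers
# are marked experimental in Optuna and may emit a warning.
import matplotlib.pyplot as plt
import optuna
from optuna.visualization.matplotlib import plot_optimization_history, plot_param_importances


def toy_objective(trial: optuna.Trial) -> float:
    x = trial.suggest_float("x", -10, 10)
    y = trial.suggest_int("y", 0, 5)
    return (x - 2) ** 2 + y


study = optuna.create_study(direction="minimize")
study.optimize(toy_objective, n_trials=30)

plot_param_importances(study)
plt.savefig("hyperparameter_importances.png")
plt.close()

plot_optimization_history(study)
plt.savefig("optimization_history.png")
plt.close()
```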
Line range hint 1-411: Overall: Significant improvement in hyperparameter tuning capabilities.

The changes in this file represent a substantial enhancement to the project's hyperparameter tuning capabilities:

- The implementation of Optuna provides more flexibility and potentially better optimization results.
- New visualization features enhance the interpretability of the tuning process.
- Improved checkpoint handling and integration with wandb for better experiment tracking.

These changes should lead to more efficient and effective model tuning. Great work on this refactoring!

As the project evolves, consider creating a separate module for visualization functions to keep the main tuning logic clean and modular.

Remove the redundant commented-out checkpoint code: line 325 contains a commented-out logging statement related to "Attempting to find latest checkpoint file". Since this functionality is active at line 85, the commented-out code appears unnecessary and can be removed to clean up the codebase.

🔗 Analysis chain

Line range hint 280-340: LGTM: Updated checkpoint loading for Optuna studies.

The checkpoint loading process has been successfully adapted to work with Optuna studies, which is appropriate for the new implementation.

There's a commented-out section for finding the latest checkpoint file. Could you clarify if this functionality is still needed or if it can be removed?

🏁 Scripts executed

The following scripts were executed for the analysis:

```shell
#!/bin/bash
# Check for commented-out code related to checkpoint finding
rg -n "Attempting to find latest checkpoint file" icu_benchmarks/tuning/hyperparameters.py
```

Length of output: 239
icu_benchmarks/data/loader.py (8)
Line range hint 1-15: LGTM! Consider grouping imports

The addition of Polars (`import polars as pl`) is appropriate for the new Polars-based dataset implementations.

Consider grouping imports by standard library, third-party, and local imports for better readability. For example:

```python
# Standard library imports
import warnings
from typing import List, Dict, Tuple

# Third-party imports
import gin
import numpy as np
import polars as pl
from torch import Tensor, cat, from_numpy, float32
from torch.utils.data import Dataset

# Local imports
from icu_benchmarks.imputation.amputations import ampute_data
from .constants import DataSegment as Segment
from .constants import DataSplit as Split
```

16-82: Well-implemented `CommonPolarsDataset`, consider these enhancements

The `CommonPolarsDataset` class is a good implementation of a Polars-based dataset. Here are some suggestions for improvement:

1. Add type hints to method return values for better code documentation.
2. Consider using Polars' lazy execution for better performance in the `__init__` method.
3. The `to_tensor` method could be optimized for memory usage.

Here are the suggested changes:

1. Add return type hints:

```diff
- def __len__(self):
+ def __len__(self) -> int:
- def get_feature_names(self):
+ def get_feature_names(self) -> List[str]:
- def to_tensor(self):
+ def to_tensor(self) -> List[Tensor]:
```

2. Use lazy execution in `__init__`:

```diff
- self.features_df = data[split][Segment.features]
- self.features_df = self.features_df.sort([self.vars["GROUP"], self.vars["SEQUENCE"]])
- self.features_df = self.features_df.drop(self.vars["SEQUENCE"])
+ self.features_df = (
+     data[split][Segment.features]
+     .lazy()
+     .sort([self.vars["GROUP"], self.vars["SEQUENCE"]])
+     .drop(self.vars["SEQUENCE"])
+     .collect()
+ )
```

3. Optimize the `to_tensor` method:

```python
def to_tensor(self) -> List[Tensor]:
    return [cat([entry[i].unsqueeze(0) for entry in self], dim=0) for i in range(len(self[0]))]
```

This version avoids creating intermediate lists and should be more memory-efficient.

84-183: Well-implemented `PredictionPolarsDataset`, consider these optimizations

The `PredictionPolarsDataset` class is a good implementation for prediction tasks. Here are some suggestions for improvement:

1. Optimize the `__getitem__` method using Polars operations (see the Polars-based version suggested in the earlier comment on this class).
2. Add error handling for potential edge cases:

```python
def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor]:
    if idx < 0 or idx >= len(self):
        raise IndexError(f"Index {idx} is out of bounds for dataset of size {len(self)}")
    # ... rest of the method
```

3. Improve type hinting and docstrings:

```python
def get_balance(self) -> List[float]:
    """
    Calculate the weight balance for the split of interest.

    Returns:
        List[float]: Weights for each label, calculated as (1 / count) * sum(counts) / num_labels.
    """
    # ... method implementation


def get_data_and_labels(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Retrieve all data and labels aligned at once.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: A tuple containing:
            - Data points for the split
            - Labels for the split
            - Row indicators
    """
    # ... method implementation
```

These changes should improve performance, readability, and robustness of the `PredictionPolarsDataset` class.
Line range hint 186-245: Consider a deprecation strategy for `CommonPandasDataset`

The `CommonPandasDataset` class is marked as deprecated, but it's still present in the codebase. While this approach maintains backward compatibility, it may lead to confusion and maintenance issues in the long term.

Consider the following deprecation strategy:

1. Add a deprecation timeline in the class docstring, indicating when this class will be removed.
2. Implement a wrapper class that uses the new `CommonPolarsDataset` internally but exposes the same interface as `CommonPandasDataset`. This allows for a smoother transition.
3. In the `__init__` method, log a warning message that includes migration instructions.

Here's an example implementation:

```python
import warnings
from functools import wraps


def deprecated(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"{func.__name__} is deprecated and will be removed in version X.X. Use CommonPolarsDataset instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return func(*args, **kwargs)

    return wrapper


@gin.configurable("CommonPandasDataset")
class CommonPandasDataset(Dataset):
    """
    DEPRECATED: This class will be removed in version X.X. Use CommonPolarsDataset instead.

    This is a wrapper around CommonPolarsDataset that maintains the old interface
    for backward compatibility.
    """

    @deprecated
    def __init__(self, *args, **kwargs):
        self._polars_dataset = CommonPolarsDataset(*args, **kwargs)
        warnings.warn(
            "CommonPandasDataset is deprecated and will be removed in version X.X. "
            "Use CommonPolarsDataset instead. "
            "Migration guide: [URL to migration guide]",
            DeprecationWarning,
            stacklevel=2,
        )

    @deprecated
    def __getitem__(self, idx):
        return self._polars_dataset[idx]

    @deprecated
    def __len__(self):
        return len(self._polars_dataset)

    # Implement other methods as needed, delegating to self._polars_dataset
```

This approach allows users of the old API to continue using it while encouraging migration to the new Polars-based implementation. It also provides a clear timeline for when the old implementation will be removed.

Line range hint 246-337: Apply consistent deprecation strategy for `PredictionPandasDataset`

Similar to `CommonPandasDataset`, the `PredictionPandasDataset` class is deprecated but still present in the codebase. To maintain consistency and facilitate a smooth transition, apply the same deprecation strategy as suggested for `CommonPandasDataset`.

Implement a wrapper class for `PredictionPandasDataset` that uses the new `PredictionPolarsDataset` internally:

```python
@gin.configurable("PredictionPandasDataset")
class PredictionPandasDataset(CommonPandasDataset):
    """
    DEPRECATED: This class will be removed in version X.X. Use PredictionPolarsDataset instead.

    This is a wrapper around PredictionPolarsDataset that maintains the old interface
    for backward compatibility.
    """

    @deprecated
    def __init__(self, *args, **kwargs):
        self._polars_dataset = PredictionPolarsDataset(*args, **kwargs)
        warnings.warn(
            "PredictionPandasDataset is deprecated and will be removed in version X.X. "
            "Use PredictionPolarsDataset instead. "
            "Migration guide: [URL to migration guide]",
            DeprecationWarning,
            stacklevel=2,
        )

    @deprecated
    def __getitem__(self, idx):
        return self._polars_dataset[idx]

    @deprecated
    def get_balance(self):
        return self._polars_dataset.get_balance()

    @deprecated
    def get_data_and_labels(self):
        return self._polars_dataset.get_data_and_labels()

    # Implement other methods as needed, delegating to self._polars_dataset
```

This approach ensures consistency across the codebase and provides a clear path for users to migrate to the new Polars-based implementation.

Line range hint 338-401: Consider creating a Polars-based version of `ImputationPandasDataset`

The `ImputationPandasDataset` class is still using Pandas, which is inconsistent with the transition to Polars in other parts of the codebase. To maintain consistency and potentially improve performance, consider creating a Polars-based version of this class.

Here are the steps to consider:

1. Create a new `ImputationPolarsDataset` class that uses Polars instead of Pandas.
2. Deprecate the current `ImputationPandasDataset` class.
3. Provide a migration path for users of the old class.

Here's a skeleton for the new Polars-based class:

```python
@gin.configurable("ImputationPolarsDataset")
class ImputationPolarsDataset(CommonPolarsDataset):
    """Polars-based dataset for imputation models."""

    def __init__(
        self,
        data: Dict[str, pl.DataFrame],
        split: str = Split.train,
        vars: Dict[str, str] = gin.REQUIRED,
        mask_proportion=0.3,
        mask_method="MCAR",
        mask_observation_proportion=0.3,
        ram_cache: bool = True,
    ):
        super().__init__(data, split, vars, grouping_segment=Segment.static)
        # Convert the ampute_data function to work with Polars DataFrames
        self.amputated_values, self.amputation_mask = ampute_data_polars(
            self.features_df, mask_method, mask_proportion, mask_observation_proportion
        )
        self.amputation_mask = (self.amputation_mask | self.features_df.is_null()).cast(pl.Boolean)
        self.amputation_mask = self.amputation_mask.with_columns(pl.col(self.vars["GROUP"]))
        self.amputation_mask = self.amputation_mask.set_index(self.vars["GROUP"])

        self.target_missingness_mask = self.features_df.is_null()
        self.features_df = self.features_df.fill_null(0)
        self.ram_cache(ram_cache)

    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        # Implement this method using Polars operations
        pass

    # Implement other methods as needed
```

After creating this new class, deprecate the old `ImputationPandasDataset`:

```python
@gin.configurable("ImputationPandasDataset")
class ImputationPandasDataset(CommonPandasDataset):
    """
    DEPRECATED: This class will be removed in version X.X. Use ImputationPolarsDataset instead.
    """

    @deprecated
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "ImputationPandasDataset is deprecated and will be removed in version X.X. "
            "Use ImputationPolarsDataset instead. "
            "Migration guide: [URL to migration guide]",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
        # Rest of the initialization
```

This approach will help maintain consistency across the codebase and provide a clear migration path for users.

Line range hint 404-478: Consider creating a Polars-based version of `ImputationPredictionDataset`

The `ImputationPredictionDataset` class is still using Pandas, which is inconsistent with the transition to Polars in other parts of the codebase. To maintain consistency and potentially improve performance, consider creating a Polars-based version of this class.

Here are the steps to consider:

1. Create a new `ImputationPredictionPolarsDataset` class that uses Polars instead of Pandas.
2. Deprecate the current `ImputationPredictionDataset` class.
3. Provide a migration path for users of the old class.

Here's a skeleton for the new Polars-based class:

```python
@gin.configurable("ImputationPredictionPolarsDataset")
class ImputationPredictionPolarsDataset(Dataset):
    """Polars-based dataset for imputation prediction tasks."""

    def __init__(
        self,
        data: pl.DataFrame,
        grouping_column: str = "stay_id",
        select_columns: List[str] = None,
        ram_cache: bool = True,
    ):
        self.dyn_df = data

        if select_columns is not None:
            self.dyn_df = self.dyn_df.select(select_columns + [grouping_column])

        if grouping_column is not None:
            self.dyn_df = self.dyn_df.set_index(grouping_column)

        # calculate basic info for the data
        self.group_indices = self.dyn_df.get_column(grouping_column).unique()
        self.maxlen = self.dyn_df.group_by(grouping_column).agg(pl.count()).max().item()

        self._cached_dataset = None
        if ram_cache:
            logging.info("Caching dataset in ram.")
            self._cached_dataset = [self[i] for i in range(len(self))]

    def __len__(self) -> int:
        return len(self.group_indices)

    def __getitem__(self, idx: int) -> Tensor:
        if self._cached_dataset is not None:
            return self._cached_dataset[idx]
        stay_id = self.group_indices[idx]
        window = self.dyn_df.filter(pl.col(self.dyn_df.index_column) == stay_id)
        return from_numpy(window.to_numpy()).to(float32)
```

After creating this new class, deprecate the old `ImputationPredictionDataset`:

```python
@gin.configurable("ImputationPredictionDataset")
class ImputationPredictionDataset(Dataset):
    """
    DEPRECATED: This class will be removed in version X.X. Use ImputationPredictionPolarsDataset instead.
    """

    @deprecated
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "ImputationPredictionDataset is deprecated and will be removed in version X.X. "
            "Use ImputationPredictionPolarsDataset instead. "
            "Migration guide: [URL to migration guide]",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__()
        # Rest of the initialization
```

This approach will help maintain consistency across the codebase and provide a clear migration path for users. It will also potentially improve performance by leveraging Polars' efficient data manipulation capabilities.

Line range hint 1-478: Overall assessment: Good progress on Polars transition, but further consistency needed

The transition from Pandas to Polars in this file is a positive change that should improve performance and efficiency. However, there are still some inconsistencies and areas for improvement:

- The new Polars-based classes (`CommonPolarsDataset` and `PredictionPolarsDataset`) are well-implemented but could benefit from some optimizations and improved type hinting.
- The deprecated Pandas-based classes (`CommonPandasDataset` and `PredictionPandasDataset`) are still present. A clear deprecation strategy with wrappers for backward compatibility has been suggested.
- Two classes (`ImputationPandasDataset` and `ImputationPredictionDataset`) are still using Pandas and have not been updated to use Polars. Creating Polars-based versions of these classes has been recommended.

To improve the overall quality and consistency of the codebase:

1. Implement the suggested optimizations for the Polars-based classes.
2. Apply the proposed deprecation strategy for Pandas-based classes.
3. Create Polars-based versions of the remaining Pandas classes.
4. Ensure consistent error handling and type hinting across all classes.
5. Consider adding more comprehensive unit tests to verify the behavior of both Pandas and Polars implementations during the transition period.

These changes will result in a more consistent, efficient, and maintainable codebase.
icu_benchmarks/data/split_process_data.py (6)
Line range hint 22-42: LGTM: Function signature updates align with new features and Polars transition.

The changes to the `preprocess_data` function signature appropriately reflect the new features and the transition to Polars. However, there's a minor improvement we can make.

Consider updating the default values for `modality_mapping` and `vars_to_exclude` to use `None` instead of mutable default arguments:

```diff
- modality_mapping: dict[str] = {},
+ modality_mapping: dict[str] | None = None,
- vars_to_exclude: list[str] = [],
+ vars_to_exclude: list[str] | None = None,
```

Then, initialize these variables at the beginning of the function:

```python
modality_mapping = modality_mapping or {}
vars_to_exclude = vars_to_exclude or []
```

This change follows the best practice of avoiding mutable default arguments in Python.
🧰 Tools
🪛 Ruff

26-26: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function (B006)
192-213: LGTM: check_sanitize_data function implementation.

The `check_sanitize_data` function effectively removes duplicates from the loaded data, which is crucial for data integrity. The implementation looks correct and efficient.

Consider adding a docstring to explain the function's parameters and return value:

```python
def check_sanitize_data(data, vars):
    """
    Check for duplicates in the loaded data and remove them.

    Args:
        data (dict): Dictionary containing data segments (static, dynamic, outcome).
        vars (dict): Dictionary containing variable mappings.

    Returns:
        dict: Data with duplicates removed.
    """
    # ... (existing implementation)
```
🪛 Ruff
194-194: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
195-195: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
197-197: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
201-201: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
205-205: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
216-242
: LGTM: modality_selection function implementation.The
modality_selection
function provides a useful way to select specific modalities based on the provided mapping. The implementation looks correct and handles edge cases well.Consider adding type hints to the function signature for better code readability and maintainability:
def modality_selection( data: dict[str, pl.DataFrame], modality_mapping: dict[str, list[str]], selected_modalities: list[str], vars: dict[str, Any] ) -> tuple[dict[str, pl.DataFrame], dict[str, Any]]: # ... (existing implementation)🧰 Tools
🪛 Ruff
220-220: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
221-221: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
238-238: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
Line range hint 244-318: LGTM: make_train_val function updates for Polars support.

The `make_train_val` function has been successfully updated to support both Pandas and Polars DataFrames. The changes maintain backwards compatibility while allowing for the transition to Polars.

Consider refactoring the function to reduce code duplication between Pandas and Polars operations. You could create helper functions for common operations:

```python
def sample_data(data, fraction, seed, polars):
    return (
        data.sample(fraction=fraction, seed=seed)
        if polars
        else data.sample(frac=fraction, random_state=seed)
    )


def create_split(stays, train, val, polars):
    if polars:
        return {
            Split.train: stays[train].cast(pl.datatypes.Int64).to_frame(),
            Split.val: stays[val].cast(pl.datatypes.Int64).to_frame(),
        }
    return {Split.train: stays.iloc[train], Split.val: stays.iloc[val]}


# Use these helper functions in make_train_val
```

This refactoring will make the code more maintainable and easier to read.
🧰 Tools
🪛 Ruff

301-301: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
309-309: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
313-313: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
Line range hint 332-452: LGTM: make_single_split function updates for Polars support.

The `make_single_split` function has been successfully updated to support both Pandas and Polars DataFrames. The changes maintain backwards compatibility while allowing for the transition to Polars.

Similar to the `make_train_val` function, consider refactoring to reduce code duplication:

- Create helper functions for common operations (e.g., getting labels, creating splits).
- Use ternary operators for simple Polars/Pandas conditionals, e.g. `dev_stays = stays[dev] if polars else stays.iloc[dev]`.
- Consider extracting the data split creation logic into a separate function, as it's similar in both `make_train_val` and `make_single_split`.

These refactorings will improve code maintainability and readability.
Line range hint 1-465: Overall LGTM: Successful transition to Polars with enhanced functionality.

The changes in this file represent a significant improvement to the data processing pipeline:
- The transition from Pandas to Polars is well-implemented, which should lead to better performance and reduced memory usage.
- New features like modality selection and data sanitization add valuable functionality to the preprocessing steps.
- Backwards compatibility with Pandas is maintained, allowing for a smooth transition.
These changes enhance the overall functionality and efficiency of the module. Great job on the implementation!
For future improvements, consider:
- Further modularizing the code by extracting common operations into utility functions.
- Adding more comprehensive error handling and logging throughout the pipeline.
- Implementing unit tests for the new functions and modified logic to ensure robustness.
🧰 Tools
🪛 Ruff

301-301: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
309-309: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
313-313: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
icu_benchmarks/models/wrappers.py (2)
413-414: LGTM: Added row_indicators to data retrieval.

The addition of `row_indicators` to the `get_data_and_labels()` calls enhances the data handling capabilities, possibly to manage row-level information. This change is consistent with similar modifications in other methods like `validation_step` and `test_step`.

Consider updating the method's docstring to reflect the new `row_indicators` parameter and its purpose.
Line range hint 562-621: LGTM: Added flexible weight initialization and improved on_fit_start.

The addition of the `init_weights` method provides valuable flexibility in weight initialization, supporting various techniques like normal, Xavier, Kaiming, and orthogonal initialization. The modification of `on_fit_start` to initialize weights and reset metrics ensures proper setup before training begins.

Consider adding type hints to the `init_weights` method for improved code readability:

```python
def init_weights(self, init_type: str = "normal", gain: float = 0.02):
    # ... (rest of the method remains the same)
```

This change would make the expected types of the parameters more explicit.
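For reference, a minimal sketch of what such a typed initializer could look like. The dispatch on `init_type` is an assumption based on the techniques named above, not the repository's actual implementation:

```python
import torch.nn as nn


def init_weights(self, init_type: str = "normal", gain: float = 0.02):
    def init_fn(m):
        # Only (re)initialize layers that carry a weight matrix.
        if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d)):
            if init_type == "normal":
                nn.init.normal_(m.weight, 0.0, gain)
            elif init_type == "xavier":
                nn.init.xavier_normal_(m.weight, gain=gain)
            elif init_type == "kaiming":
                nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                nn.init.orthogonal_(m.weight, gain=gain)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    self.apply(init_fn)
```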
icu_benchmarks/data/preprocessor.py (6)
32-34: Document the usage of `vars_to_exclude` in derived classes.

The `vars_to_exclude` parameter has been added to the `apply` method signature in the abstract base class. However, it's not clear how this parameter is used in the derived classes. Consider adding documentation or implementing this feature consistently across all preprocessor classes.
164-177: Clarify the usage of the `_model_impute` method.

The `_model_impute` method is defined but never called within this class. If it's intended to be used by derived classes or externally, consider adding a docstring explaining its purpose and usage.
Line range hint 279-350: Add error handling for missing data segments.

In the `apply` method, the code assumes that all required data segments (dynamic, static, features) exist. Consider adding error handling to gracefully handle cases where these segments might be missing from the data dictionary.

Suggested addition:

```python
if Segment.dynamic not in data[Split.train]:
    raise ValueError(f"Missing {Segment.dynamic} segment in {Split.train} data.")
# Add similar checks for other required segments
```
Line range hint 476-536: Enhance logging for the missing value filtering process.

In the `_process_dynamic_data` method, consider adding more detailed logging about the filtering process. This could include the number of rows removed and the percentage of data affected.

Suggested addition:

```python
if self.filter_missing_values:
    total_rows = len(data[Segment.dynamic])
    rows_removed = rows_to_remove.sum()
    logging.info(f"Removed {rows_removed} rows ({rows_removed/total_rows:.2%}) with missing values.")
    logging.info(f"Removed {len(ids_to_remove)} stays with missing values.")
```
Line range hint 574-584: Add error handling for file operations in cache_recipe.

Consider adding try-except blocks to handle potential IOErrors when creating directories or writing to files. This will make the function more robust against file system issues.

Suggested addition:

```python
def cache_recipe(recipe: Recipe, cache_file: str) -> None:
    recipe_cache = copy.deepcopy(recipe)
    recipe_cache.cache()
    try:
        if not (cache_file / "..").exists():
            (cache_file / "..").mkdir(parents=True, exist_ok=True)
        cache_file.touch()
        with open(cache_file, "wb") as f:
            pickle.dump(recipe_cache, f, pickle.HIGHEST_PROTOCOL)
        logging.info(f"Cached recipe in {cache_file}.")
    except IOError as e:
        logging.error(f"Failed to cache recipe: {e}")
```
Security Issue Confirmed: Usage of `pickle.load` in `preprocessor.py`.

The `restore_recipe` function uses `pickle.load(f)`, which can execute arbitrary code if the source of the cached recipe is not trusted. This poses a significant security risk.

- File: `icu_benchmarks/data/preprocessor.py`
- Line: `recipe = pickle.load(f)`

Consider using a more secure serialization method, such as JSON with custom encoding/decoding for complex objects, or implement safety checks before loading pickled data.
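If pickled caches must be kept, one common safeguard is a restricted unpickler. The sketch below assumes cached recipes only contain objects from an explicit allow-list of modules; the list shown is illustrative, not derived from the repository:

```python
import pickle


class RestrictedUnpickler(pickle.Unpickler):
    """Refuse to unpickle anything outside an explicit allow-list of modules."""

    ALLOWED_MODULES = {"recipys", "builtins"}  # illustrative allow-list

    def find_class(self, module, name):
        if module.split(".")[0] in self.ALLOWED_MODULES:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"Refusing to load {module}.{name}")


def restricted_load(path):
    with open(path, "rb") as f:
        return RestrictedUnpickler(f).load()
```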
🔗 Analysis chain
Line range hint 587-596: Consider security implications of using pickle for recipe restoration.

Using `pickle.load()` can be a security risk if the source of the cached recipe is not trusted, as it can execute arbitrary code. Consider using a more secure serialization method like JSON with custom encoding/decoding for complex objects, or implement safety checks before loading pickled data.

To verify the potential security risk, you can run the scripts below. If this search returns multiple results, it might indicate a wider use of pickle that should be reviewed for security.
🏁 Scripts executed
The following scripts were executed for the analysis:

Script:

```bash
#!/bin/bash
# Check for uses of pickle.load in the codebase
rg --type python 'pickle\.load'
```

Length of output: 64

Script:

```bash
#!/bin/bash
# Check for uses of pickle.load in Python files
rg 'pickle\.load' --glob '*.py'
```

Length of output: 185
docs/adding_model/rnn.py (1)
9-9: Enhance the class docstring for better clarity.

The current docstring, `"""Torch standard RNN model"""`, is brief. Expanding it to include details about the model's parameters, expected input shapes, and usage examples would improve readability and maintainability.
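For illustration, an expanded docstring could look like the sketch below. The class name and the `(batch, time, features)` shape convention are assumptions for the example, not taken from the file:

```python
import torch.nn as nn


class RNNet(nn.Module):  # hypothetical name for the documented class
    """Torch standard RNN model.

    Args:
        input_size: Shape of one input batch, assumed (batch, time, features).
        hidden_dim: Number of hidden units in each RNN layer.
        layer_dim: Number of stacked RNN layers.
        num_classes: Dimension of the per-timestep output.

    The forward pass is assumed to return a tensor of shape (batch, time, num_classes).
    """
```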
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (20)
- docs/adding_model/rnn.py (1 hunks)
- icu_benchmarks/cross_validation.py (6 hunks)
- icu_benchmarks/data/loader.py (7 hunks)
- icu_benchmarks/data/pooling.py (6 hunks)
- icu_benchmarks/data/preprocessor.py (8 hunks)
- icu_benchmarks/data/split_process_data.py (10 hunks)
- icu_benchmarks/models/dl_models/rnn.py (1 hunks)
- icu_benchmarks/models/dl_models/tcn.py (1 hunks)
- icu_benchmarks/models/dl_models/transformer.py (1 hunks)
- icu_benchmarks/models/ml_models/catboost.py (1 hunks)
- icu_benchmarks/models/ml_models/imblearn.py (1 hunks)
- icu_benchmarks/models/ml_models/lgbm.py (1 hunks)
- icu_benchmarks/models/ml_models/sklearn.py (1 hunks)
- icu_benchmarks/models/ml_models/xgboost.py (1 hunks)
- icu_benchmarks/models/train.py (10 hunks)
- icu_benchmarks/models/wrappers.py (9 hunks)
- icu_benchmarks/run.py (7 hunks)
- icu_benchmarks/run_utils.py (6 hunks)
- icu_benchmarks/tuning/hyperparameters.py (7 hunks)
- requirements.txt (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- icu_benchmarks/cross_validation.py
- icu_benchmarks/models/ml_models/imblearn.py
- requirements.txt
🧰 Additional context used
🪛 Ruff
docs/adding_model/rnn.py
15-15: Star-arg unpacking after a keyword argument is strongly discouraged (B026)

icu_benchmarks/data/pooling.py
106-106: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function (B006)
108-108: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function (B006)

icu_benchmarks/data/split_process_data.py
26-26: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function (B006)
41-41: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function (B006)
121-121: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
126-126: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
126-126: Test for membership should be `not in`. Convert to `not in` (E713)
194-194: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
195-195: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
197-197: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
201-201: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
205-205: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
220-220: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
221-221: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
238-238: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
301-301: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
309-309: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
313-313: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
419-422: Use ternary operator `dev_stays = stays[dev] if polars else stays.iloc[dev]` instead of `if`-`else`-block. Replace `if`-`else`-block with `dev_stays = stays[dev] if polars else stays.iloc[dev]` (SIM108)
438-438: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
446-446: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
450-450: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)

icu_benchmarks/models/dl_models/rnn.py
16-16: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
42-42: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
69-69: Star-arg unpacking after a keyword argument is strongly discouraged (B026)

icu_benchmarks/models/dl_models/tcn.py
23-23: Star-arg unpacking after a keyword argument is strongly discouraged (B026)

icu_benchmarks/models/dl_models/transformer.py
37-37: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
52-52: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
106-106: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
123-123: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)

icu_benchmarks/models/ml_models/catboost.py
14-14: Star-arg unpacking after a keyword argument is strongly discouraged (B026)

icu_benchmarks/models/ml_models/xgboost.py
27-27: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
76-76: f-string without any placeholders. Remove extraneous `f` prefix (F541)

icu_benchmarks/models/wrappers.py
474-474: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)

icu_benchmarks/run_utils.py
18-18: `os` imported but unused. Remove unused import: `os` (F401)
19-19: `glob` imported but unused. Remove unused import: `glob` (F401)
🔇 Additional comments (41)
icu_benchmarks/models/ml_models/catboost.py (4)
1-9: LGTM: Imports and class declaration look good.

The imports are correct, including the fixed import from `constants`. The `CBClassifier` class is properly defined with the `@gin.configurable` decorator and the correct `_supported_run_modes` attribute.
17-27: LGTM: Predict method is well-implemented.

The `predict` method is correctly implemented to return class probabilities using the CatBoost model's `predict_proba` method. The docstring provides clear and concise information about the method's functionality, parameters, and return value.
1-27: Overall assessment: Good implementation with minor improvement suggestions.

The `CBClassifier` class provides a solid wrapper for the CatBoost classifier, integrating well with the project's structure. The implementation is generally good, with clear method definitions and appropriate use of decorators.

The main area for improvement is in the constructor, as noted in the previous comment. Addressing these suggestions will enhance the flexibility and maintainability of the code.

Great job on fixing the import statement and implementing the predict method with clear documentation!
🧰 Tools
🪛 Ruff
14-14: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
11-15: 🛠️ Refactor suggestion: Refactor constructor for improved flexibility and code cleanliness.

Consider the following improvements to the constructor:

- Make the task type configurable instead of hardcoding it to "CPU".
- Remove the commented-out code to improve readability.
- Address the star-arg unpacking warning.

Here's a suggested implementation:

```python
def __init__(self, task_type="CPU", *args, **kwargs):
    model_kwargs = {"task_type": task_type}
    self.model = self.set_model_args(cb.CatBoostClassifier, **model_kwargs, **kwargs)
    super().__init__(*args, **kwargs)
```

This change allows the task type to be specified when initializing the class, defaulting to "CPU" if not provided, and addresses the star-arg unpacking warning.
🧰 Tools
🪛 Ruff
14-14: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
icu_benchmarks/models/ml_models/lgbm.py (3)
1-8: LGTM: Import statements are appropriate and well-organized.

The import statements are concise and include all necessary libraries for the LightGBM wrapper implementation, including gin for configuration, numpy for numerical operations, and wandb for logging. The specific import of the wandb callback for LightGBM is a good practice for integration.
1-57: Overall, excellent implementation of LightGBM wrappers with minor suggestions for enhancement.

This file provides a well-structured and robust implementation of LightGBM wrappers for both classification and regression tasks. The use of gin for configuration and wandb for logging demonstrates good practices in machine learning model development. The code is clean, readable, and follows a consistent style.
Key strengths:
- Proper separation of concerns between wrapper, classifier, and regressor classes.
- Integration with wandb for experiment tracking.
- Use of early stopping and callbacks for efficient training.
Suggestions for improvement:
- Add type hints to method signatures for better code readability and maintainability.
- Consider adding input validation in the predict methods to enhance robustness.
- Clarify the inheritance of the predict method in the LGBMRegressor class.
These minor enhancements would further improve an already solid implementation.
51-57: LGTM: LGBMRegressor implementation is concise, but clarification needed.

The LGBMRegressor class is correctly implemented, using gin for configuration and properly specifying its supported run mode as regression.

However, I noticed that unlike the LGBMClassifier, this class doesn't include a predict method. Is this intentional? If so, it might be helpful to add a comment explaining that the predict method is inherited from the parent class. If not, consider adding a predict method specific to regression tasks.
To verify the existence of a predict method in the parent class, we can run the following script:
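The script itself was elided above; a search along these lines would answer the question (a sketch assuming the shared base class lives in icu_benchmarks/models/wrappers.py, per the files listed in this review):

```bash
#!/bin/bash
# Look for a predict method in the shared ML wrapper base class
rg -n "def predict" icu_benchmarks/models/wrappers.py
```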
This will help us confirm whether the predict method is indeed inherited from the parent class.
icu_benchmarks/models/dl_models/tcn.py (3)
1-9: LGTM: Imports are appropriate and well-organized.

The imports cover all necessary dependencies for implementing the Temporal Convolutional Network, including configuration management (gin), numerical operations (numpy), neural network components (torch.nn), and custom modules from the icu_benchmarks package.

18-55: LGTM: Well-structured constructor with automatic channel computation.

The constructor is well-designed, handling various input scenarios and automatically computing the number of channels when needed. The network construction logic clearly implements the TCN architecture using TemporalBlock instances.
🧰 Tools
🪛 Ruff
23-23: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
57-62: LGTM: Forward method is correctly implemented.

The forward method efficiently handles the necessary tensor permutations and correctly passes the data through the network and final linear layer. The implementation follows the expected flow for a TCN.
icu_benchmarks/models/ml_models/sklearn.py (2)
2-2: LGTM: Import statements have been streamlined.

The consolidation of scikit-learn imports into a single line improves code clarity and readability. The removal of LightGBM and WandB imports is consistent with the elimination of LightGBM-related functionality.
Line range hint 1-93: Verify the impact of LightGBM removal.

The LightGBM-related classes and methods have been removed as per the AI summary. While this change simplifies the `sklearn.py` file, it's crucial to ensure that this removal doesn't cause issues elsewhere in the codebase.

Please run the following script to check for any remaining references to LightGBM in the codebase:
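The script was not included above; a minimal sketch of such a search (command is an assumption, not the exact script CodeRabbit ran):

```bash
#!/bin/bash
# Find any remaining references to LightGBM in Python files
rg -n -i "lightgbm|lgbm" --glob '*.py'
```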
If any references are found, they may need to be removed or updated to maintain consistency with this change.
icu_benchmarks/models/dl_models/rnn.py (1)
1-5: LGTM: Import statements are appropriate and concise.

The import statements are well-organized and include all necessary modules for the implemented functionality.
icu_benchmarks/models/dl_models/transformer.py (3)
1-6: LGTM: Import statements are well-organized and relevant.

The import statements are appropriately structured, following the convention of standard library imports first, followed by third-party libraries, and then local modules. All imports seem relevant to the implementation of transformer models.
9-14: Good use of gin for configuration and clear mode support.

The `Transformer` class is well-structured with the `@gin.configurable` decorator, allowing for flexible configuration. The explicit definition of supported run modes (`classification` and `regression`) is a good practice for clarity.
: LGTM: Forward method is concise and follows the transformer architecture.The
forward
method is well-implemented, following the expected transformer architecture. It processes the input through the embedding layer, positional encoding (if enabled), transformer blocks, and finally the output layer.icu_benchmarks/models/ml_models/xgboost.py (4)
1-25
: LGTM: Imports and class definition look good.The imports are appropriate for the XGBoost classifier implementation. The use of the
@gin.configurable
decorator allows for flexible configuration, which is a good practice. The_supported_run_modes
attribute clearly defines that this class is intended for classification tasks only.
30-40
: LGTM:predict
method is well-implemented.The
predict
method is correctly implemented, using the model'spredict_proba
method to return class probabilities. The docstring is well-written, providing clear information about the method's input, output, and functionality.
98-101
: LGTM:get_feature_importance
method is well-implemented.The
get_feature_importance
method is correctly implemented. It properly checks if the model has been fit before attempting to return the feature importances. The error message is clear and informative, guiding the user to callfit_model()
if needed.
87-96
:⚠️ Potential issueImprove robustness of
set_model_args
method.The current implementation of
set_model_args
can be improved for better robustness and clarity. Consider implementing the following changes:
- Directly use the
kwargs
parameter instead oflocals()['kwargs']
.- Filter out invalid arguments to prevent potential errors when initializing the model.
Here's a suggested implementation:
def set_model_args(self, model, *args, **kwargs): """Set model arguments based on the model's signature.""" signature = inspect.signature(model.__init__).parameters valid_params = signature.keys() # Filter out invalid arguments valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_params} logging.debug(f"Creating model with: {valid_kwargs}.") return model(**valid_kwargs)This refactoring improves robustness by filtering out invalid arguments and avoids the use of
locals()
.icu_benchmarks/data/pooling.py (2)
19-29
: Approve formatting changes in__init__
method signature.The vertical alignment of parameters in the
__init__
method signature improves code readability without affecting functionality. This change is in line with good coding practices.
51-55
: Approve formatting changes ingenerate
method signature.The vertical alignment of parameters in the
generate
method signature enhances code readability without altering functionality. This change is consistent with the formatting improvements made throughout the class.icu_benchmarks/run.py (4)
21-21
: LGTM: New import added for configuration file retrieval.The addition of
get_config_files
to the imports is consistent with the new functionality described in the summary. This change enhances the module's capability to handle configuration files.
84-86
: LGTM: Improved readability of wandb config update.The modification to the
update_wandb_config
call enhances code readability by breaking the long line into multiple lines. The functionality remains unchanged, which is good.
145-148
: LGTM: Improved readability of hyperparameter tuning function call.The modification to the
choose_and_bind_hyperparameters_optuna
function call enhances code readability by using keyword arguments. This change not only improves clarity but also reduces the likelihood of errors when passing arguments. It aligns well with the summary mentioning updates to enhance hyperparameter tuning processes.
Line range hint
1-203
: Overall assessment: The changes improve the code quality and functionality.The modifications in this file align well with the PR objectives and the AI-generated summary. Key improvements include:
- Enhanced configuration file handling and validation.
- Added support for modalities and label binding.
- Improved error handling and logging.
- Better readability through the use of keyword arguments and line breaks.
- Support for different model types based on run mode.
These changes contribute to a more robust and flexible codebase. The minor suggestions provided in the review comments can further enhance code quality and maintainability.
icu_benchmarks/models/train.py (3)
6-6: LGTM: Transition to Polars for improved performance.

The addition of the Polars (pl) import and PredictionPolarsDataset is a positive change. Polars is known for its superior performance and lower memory usage compared to pandas, especially for large datasets. This transition should lead to improved efficiency in data handling operations.

Also applies to: 14-14
127-127: LGTM: Flexible model loading with CPU option.

The addition of the `cpu` parameter to the model loading process is a good improvement. This change enhances flexibility, particularly for CPU-only environments, and improves cross-platform compatibility.
Line range hint 1-247: Overall assessment: Significant improvements with minor suggestions.

The changes to `train.py` represent a substantial improvement in the codebase:
- The transition from pandas to polars should lead to better performance and reduced memory usage.
- New configuration options provide greater flexibility in training and data handling.
- The addition of SHAP value persistence enhances model interpretability.
These changes are well-implemented and should positively impact the project. However, to further improve the code:
- Update the docstrings for functions with new or changed parameters.
- Add inline comments explaining the reasoning behind specific configuration choices.
- Address the TODO comment regarding full support for polars versions of datasets.
- Consider the suggestions for minor improvements in error handling and logging.
Great work on these enhancements! The changes show a clear focus on performance, flexibility, and interpretability.

icu_benchmarks/tuning/hyperparameters.py (3)
icu_benchmarks/tuning/hyperparameters.py (3)
5-5
: LGTM: New imports for Optuna and visualization.The added imports for matplotlib and Optuna visualization functions are appropriate for the new hyperparameter tuning implementation. The relocation of the RunMode import doesn't affect functionality.
Also applies to: 16-18
264-269
: LGTM: Updated callback function for Optuna compatibility.The
tune_step_callback
function has been successfully adapted to work with Optuna's study and trial objects. The logging improvements provide better visibility into the tuning progress.
341-363
: LGTM: Optuna study creation and optimization implemented correctly.The creation of the Optuna study with specified sampler and pruner, as well as the optimization process with appropriate callbacks, has been implemented correctly. The integration with wandb through the WeightsAndBiasesCallback is a nice addition for experiment tracking.
icu_benchmarks/data/split_process_data.py (1)
3-4
: LGTM: Import changes align with the transition to Polars.The import changes correctly reflect the shift from Pandas to Polars for data processing. The addition of
os
andtimeit
imports, along with the removal ofpyarrow.parquet
, are appropriate for the new implementation.Also applies to: 9-9, 12-12, 14-15
icu_benchmarks/models/wrappers.py (5)
6-7: LGTM: New metric imports added.

The addition of `average_precision_score` and `roc_auc_score` from sklearn.metrics enhances the model evaluation capabilities. These are valuable metrics for classification tasks, particularly for imbalanced datasets.

33-34: LGTM: New metrics configured for gin.

The addition of gin configurations for `average_precision_score` and `roc_auc_score` allows for flexible use of these metrics in the gin-config system. This is consistent with the existing pattern and enhances the configurability of the model evaluation process.
376-383: ⚠️ Potential issue: Consider removing or implementing commented code and unused attribute.

The commented-out code (lines 376-379) for setting `self.loss` based on the loss function name is not being used. If this functionality is no longer needed, consider removing it entirely. If it's intended for future use, add a TODO comment explaining its purpose and when it might be implemented.

The `self.loss_weight` attribute (line 383) is added but not used anywhere in the class. If it's intended for future use, add a TODO comment explaining its purpose. If it's not needed, consider removing it to avoid confusion.

To check if `self.loss_weight` is used elsewhere in the codebase:

```bash
#!/bin/bash
# Search for usage of self.loss_weight in Python files
rg "self\.loss_weight" --glob '*.py'
```
:⚠️ Potential issueConsider using or removing the unused row_indicators variable.
The
row_indicators
variable is retrieved fromget_data_and_labels()
but not used in the method body. If it's intended for future use, consider adding a TODO comment explaining its purpose. If it's not needed, you can remove it to keep the code clean.To check if
row_indicators
is used elsewhere in thevalidation_step
method:
474-488
:⚠️ Potential issueSimplify dictionary key check and review MPS-dependent type casting.
Simplify the dictionary key check:
Instead ofif "Confusion_Matrix" in self.metrics.keys():
, useif "Confusion_Matrix" in self.metrics:
. This is more efficient as it avoids creating a temporary view object for the dictionary keys.Review MPS-dependent type casting:
The current implementation of MPS-dependent type casting seems redundant. Both branches of the conditional expression perform the same operation:metric(self.label_transform(label), self.output_transform(pred))This might not achieve the intended MPS compatibility. Consider reviewing and simplifying this logic, or add a comment explaining the necessity of this structure if it's intentional.
Here's a suggested simplification:
if "Confusion_Matrix" in self.metrics: self.log_dict(confusion_matrix(self.label_transform(label), self.output_transform(pred)), sync_dist=True) self.log_dict( { f"{metric_type}/{name}": metric(self.label_transform(label), self.output_transform(pred)) for name, metric in self.metrics.items() if not isinstance(metric(self.label_transform(label), self.output_transform(pred)), tuple) and name != "Confusion_Matrix" }, sync_dist=True, )If MPS-specific handling is necessary, please add a comment explaining its purpose and implementation.
To check if there are any MPS-specific implementations elsewhere in the codebase:
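The script was not included above; a search along these lines would surface other MPS-specific handling (command is a sketch):

```bash
#!/bin/bash
# Look for MPS-specific handling elsewhere in the codebase
rg -n "mps" --glob '*.py'
```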
🧰 Tools
🪛 Ruff
474-474: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
icu_benchmarks/data/preprocessor.py (2)
Line range hint 9-21: LGTM: New imports for Polars and additional preprocessing steps.

The new imports for Polars and additional preprocessing steps from the recipys library are appropriate for the changes made in this file.

Line range hint 537-571: LGTM: Flexible implementation for applying recipes.

The `apply_recipe_to_splits` function is well-implemented, handling both Polars and Pandas DataFrames, and providing options for caching and loading recipes. This flexibility is beneficial for various preprocessing scenarios.
19-19
:⚠️ Potential issueVerify the indexing of
input_size
when initializingnn.RNN
.In line 19,
input_size[2]
is used as the input dimension fornn.RNN
:self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True)This implies that
input_size
is an iterable with at least three elements. Ensure thatinput_size[2]
correctly represents the feature dimension expected by the RNN.To confirm, please verify how
input_size
is defined throughout the codebase. Ifinput_size
is intended to be an integer representing the input feature size, adjust the code accordingly.Apply this diff if
input_size
should be an integer:- self.rnn = nn.RNN(input_size[2], hidden_dim, layer_dim, batch_first=True) + self.rnn = nn.RNN(input_size, hidden_dim, layer_dim, batch_first=True)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 16
🧹 Outside diff range and nitpick comments (33)
icu_benchmarks/models/dl_models/rnn.py (2)
34-58: LGTM: LSTMNet implementation is correct.

The LSTMNet class is well-implemented with appropriate initialization, hidden state management, and forward pass. The use of @gin.configurable allows for flexible configuration.

The issues mentioned in the past review comment have been addressed:

- The argument unpacking in the `super().__init__()` call has been fixed.
- The implementation is correct and doesn't have any critical issues.

As a minor suggestion, consider adding type hints to method parameters and return values for improved code readability and maintainability.
61-84: LGTM: GRUNet implementation is correct.

The GRUNet class is well-implemented with appropriate initialization, hidden state management, and forward pass. The use of @gin.configurable allows for flexible configuration.

The issues mentioned in the past review comment have been addressed:

- The missing return statement in the forward method has been added.
- The argument unpacking in the `super().__init__()` call has been fixed.

As a minor suggestion, consider adding type hints to method parameters and return values for improved code readability and maintainability.
icu_benchmarks/models/ml_models/xgboost.py (2)
42-59: LGTM with suggestion: `fit_model` method is well-implemented.

The `fit_model` method correctly fits the model to the training data, incorporating early stopping and optional Weights & Biases integration. The conditional computation of SHAP values is a good approach for flexibility.

Suggestion for improvement: consider allowing the user to specify which validation metric to use for the return value, instead of always using the first one. This would provide more flexibility in model evaluation. Example:

```python
def fit_model(self, train_data, train_labels, val_data, val_labels, eval_metric="auto"):
    # ... existing code ...
    eval_score = mean(self.model.evals_result_["validation_0"][eval_metric])
    return eval_score
```

This change would allow users to specify which metric they want to use for evaluation.
61-84: LGTM with minor fix: `test_step` method is well-implemented.

The `test_step` method correctly processes test data, makes predictions, and logs results. It handles both multi-processing and standard logging scenarios effectively.

Minor fix needed: line 76 is flagged by Ruff as an f-string without any placeholders (F541); remove the extraneous `f` prefix from the affected call:

```python
logging.debug(f"Saved row indicators to {os.path.join(self.logger.save_dir, 'row_indicators.csv')}")
```

This change resolves the static analysis warning and improves code clarity.
🧰 Tools
🪛 Ruff
76-76: f-string without any placeholders. Remove extraneous `f` prefix (F541)
icu_benchmarks/run_utils.py (5)
58-64: LGTM! Consider adding type hints for clarity.

The new arguments `--modalities` and `--label` are well-implemented and enhance the flexibility of the parser. To improve code clarity, consider adding type hints around the `nargs` and `type` usages. Note that annotations are only valid on standalone assignments, not on keyword arguments inside a call, so any hints would need to live on wrapper variables rather than in `add_argument` itself. The current calls are:

```python
parser.add_argument(
    "-mo",
    "--modalities",
    nargs="+",
    help="Optional modality selection to use. Specify multiple modalities separated by spaces.",
)
parser.add_argument("--label", type=str, help="Label to use for evaluation in case of multiple labels.", default=None)
```
: LGTM! Consider using a more efficient approachThe addition of the while loop effectively prevents directory name conflicts. However, for better efficiency, consider using a single call to
datetime.now()
and formatting it with both seconds and microseconds.Here's a suggested optimization:
-log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) -while log_dir_run.exists(): - log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")) +now = datetime.now() +log_dir_run = log_dir / now.strftime("%Y-%m-%dT%H-%M-%S.%f") +while log_dir_run.exists(): + now = datetime.now() + log_dir_run = log_dir / now.strftime("%Y-%m-%dT%H-%M-%S.%f")This approach reduces the number of
datetime.now()
calls and ensures unique directory names.
Line range hint
111-142
: Refactor duplicate code and improve error handlingThe addition of SHAP value aggregation is a valuable enhancement. However, there are a few issues to address:
- There's duplicate code for concatenating and writing SHAP values (lines 134-136 and 138-140).
- The error handling could be more specific and informative.
Here's a suggested refactoring:
shap_values_test = [] -# shap_values_train = [] for repetition in log_dir.iterdir(): if repetition.is_dir(): aggregated[repetition.name] = {} for fold_iter in repetition.iterdir(): aggregated[repetition.name][fold_iter.name] = {} if (fold_iter / "test_metrics.json").is_file(): with open(fold_iter / "test_metrics.json", "r") as f: result = json.load(f) aggregated[repetition.name][fold_iter.name].update(result) elif (fold_iter / "val_metrics.csv").is_file(): with open(fold_iter / "val_metrics.csv", "r") as f: result = json.load(f) aggregated[repetition.name][fold_iter.name].update(result) # Add durations to metrics if (fold_iter / "durations.json").is_file(): with open(fold_iter / "durations.json", "r") as f: result = json.load(f) aggregated[repetition.name][fold_iter.name].update(result) if (fold_iter / "test_shap_values.parquet").is_file(): shap_values_test.append(pl.read_parquet(fold_iter / "test_shap_values.parquet")) -if shap_values_test: - shap_values = pl.concat(shap_values_test) - shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet") - try: - shap_values = pl.concat(shap_values_test) - shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet") + if shap_values_test: + shap_values = pl.concat(shap_values_test) + shap_values.write_parquet(log_dir / "aggregated_shap_values.parquet") + logging.info(f"Successfully aggregated and wrote SHAP values to {log_dir / 'aggregated_shap_values.parquet'}") except Exception as e: - logging.error(f"Error aggregating or writing SHAP values: {e}") + logging.error(f"Error aggregating or writing SHAP values: {str(e)}", exc_info=True)This refactoring removes the duplicate code, improves error handling, and adds a success log message for better traceability.
🧰 Tools
🪛 Ruff
145-145: Loop control variable
repetition
not used within loop body(B007)
262-286
: LGTM! Consider enhancing error handling and readabilityThe
get_config_files
function is a valuable addition for dynamically discovering available tasks and models. However, there are a few suggestions to improve its robustness and readability:
- Use
Path.glob()
instead of creating alist
immediately, as it's more memory-efficient.- Combine the filtering and processing steps for better readability.
- Add more specific error handling for different scenarios (e.g., directory not found, permission issues).
Here's a suggested refactoring:
def get_config_files(config_dir: Path): """ Get all task and model config files in the specified directory. Args: config_dir: Path to the directory containing the config gin files. Returns: tuple: Lists of task names and model names. """ tasks, models = [], [] try: tasks = [task.stem for task in (config_dir / "tasks").glob("*") if task.is_file() and task.stem != "common"] models = [model.stem for model in (config_dir / "prediction_models").glob("*") if model.is_file() and model.stem != "common"] except FileNotFoundError as e: logging.error(f"Directory not found: {e}") except PermissionError as e: logging.error(f"Permission denied when accessing config files: {e}") except Exception as e: logging.error(f"Unexpected error when retrieving config files: {e}") logging.info(f"Found tasks: {tasks}") logging.info(f"Found models: {models}") return tasks, modelsThis refactored version improves readability, efficiency, and error handling while maintaining the original functionality.
288-301
: LGTM! Consider adding an option for soft validationThe
check_required_keys
function is a valuable addition for input validation. It's well-implemented and follows good practices. To enhance its flexibility, consider adding an option for "soft" validation that returns a boolean instead of raising an exception.Here's a suggested enhancement:
def check_required_keys(vars: dict, required_keys: list, raise_error: bool = True) -> bool: """ Checks if all required keys are present in the vars dictionary. Args: vars (dict): The dictionary to check. required_keys (list): The list of required keys. raise_error (bool): If True, raises a KeyError when keys are missing. If False, returns a boolean. Returns: bool: True if all required keys are present, False otherwise (only if raise_error is False). Raises: KeyError: If any required key is missing and raise_error is True. """ missing_keys = [key for key in required_keys if key not in vars] if missing_keys: if raise_error: raise KeyError(f"Missing required keys in vars: {', '.join(missing_keys)}") return False return TrueThis enhancement allows the function to be used in different scenarios, providing more flexibility to the caller.
icu_benchmarks/tuning/hyperparameters.py (2)
196-196
: Consider documenting theplot
parameter.The new
plot
parameter has been added to the function signature, but it's not documented in the function's docstring. Consider adding a description of this parameter to maintain consistency with the documentation of other parameters.
Line range hint
201-201
: Clarify thesampler
parameter and its usage.The
sampler
parameter is added to the function signature and used later in the function, but it's not documented in the function's docstring. Additionally, the default value is set tooptuna.samplers.GPSampler
, but this is instantiated differently based on whether it's this specific sampler or not. Consider:
- Adding documentation for the
sampler
parameter in the function's docstring.- Clarifying the logic for sampler instantiation, possibly by creating a separate function for this.
Here's a suggested refactor for the sampler instantiation:
def create_sampler(sampler_class, seed, n_initial_points): if sampler_class == optuna.samplers.GPSampler: return sampler_class(seed=seed, n_startup_trials=n_initial_points, deterministic_objective=True) return sampler_class(seed=seed) # In the main function sampler = create_sampler(sampler, seed, n_initial_points)Also applies to: 322-323
icu_benchmarks/data/loader.py (4)
18-28
: Add type hints to__init__
method parametersConsider adding type hints to the
__init__
method parameters for better code documentation and to leverage static type checking tools. For example:def __init__( self, data: Dict[str, pl.DataFrame], split: str = Split.train, vars: Dict[str, str] = gin.REQUIRED, grouping_segment: str = Segment.outcome, mps: bool = False, name: str = "", *args, **kwargs, ):
Line range hint
185-203
: Clarify deprecation path forCommonPandasDataset
The renaming of
CommonDataset
toCommonPandasDataset
and the addition of a deprecation warning are appropriate steps in the transition to Polars. However, the class still uses Pandas operations, which may be confusing for developers.Consider the following suggestions:
- Add a comment explaining the deprecation timeline and migration path to the Polars-based classes.
- If possible, update the class to use Polars operations internally while maintaining the same interface. This would allow for a smoother transition for existing code.
Example comment:
""" This class is deprecated and will be removed in version X.X. Please migrate to CommonPolarsDataset. For assistance with migration, refer to the migration guide at <link_to_migration_guide>. """
245-246
: Add deprecation warning toPredictionPandasDataset
For consistency with the
CommonPandasDataset
class and to facilitate a smooth transition to the Polars-based implementation, consider adding a deprecation warning to thePredictionPandasDataset
class. This will help users understand that they should migrate to the newPredictionPolarsDataset
class in the future.Example:
@gin.configurable("PredictionPandasDataset") class PredictionPandasDataset(CommonPandasDataset): def __init__(self, *args, ram_cache: bool = True, **kwargs): warnings.warn("PredictionPandasDataset is deprecated. Use PredictionPolarsDataset instead.", DeprecationWarning, stacklevel=2) super().__init__(*args, grouping_segment=Segment.outcome, **kwargs) # ... rest of the __init__ method ...
Line range hint
1-462
: Overall assessment: Good progress on Pandas to Polars transition with room for improvementThe transition from Pandas to Polars in this file is a significant improvement that should enhance performance and reduce memory usage. The addition of new Polars-based classes (
CommonPolarsDataset
andPredictionPolarsDataset
) alongside the renamed Pandas-based classes provides a clear path for migration.However, there are a few areas where the transition could be more consistent and complete:
- Ensure all Pandas-based classes have appropriate deprecation warnings.
- Consider creating Polars-based versions of all existing classes (e.g.,
ImputationPolarsDataset
).- Update utility functions like
ampute_data
to work with Polars DataFrames.- Provide clear documentation or a migration guide to help users transition from Pandas-based to Polars-based classes.
- Review and optimize all methods in the new Polars-based classes to fully leverage Polars operations for maximum efficiency.
Addressing these points will help complete the transition to Polars and provide a more consistent and efficient codebase.
icu_benchmarks/data/split_process_data.py (8)
Line range hint 20-42: LGTM! Enhanced flexibility with new parameters.

The function signature changes improve flexibility, particularly with modality selection. The default preprocessor change aligns with the Polars transition.

Suggestion: consider using `None` as the default for `modality_mapping` instead of an empty dict to avoid potential issues with mutable default arguments.

```diff
- modality_mapping: dict[str] = {},
+ modality_mapping: dict[str] | None = None,
```

Then, initialize it within the function:

```python
modality_mapping = modality_mapping or {}
```

This change would align with Python best practices for handling mutable default arguments.
🧰 Tools
🪛 Ruff
27-27: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function (B006)
Line range hint 70-137: LGTM! Improved robustness and flexibility in data preprocessing.

The changes enhance the function's capabilities, particularly in label handling, modality selection, and the transition to Polars for data loading. The improved logging provides better visibility into the preprocessing steps.

Suggestion: in the label handling logic, consider using a more explicit approach for selecting the default label:

```diff
- if label is not None:
-     vars[Var.label] = [label]
- else:
-     logging.debug(f"Multiple labels found and no value provided. Using first label: {vars[Var.label]}")
-     vars[Var.label] = vars[Var.label][0]
+ if label is not None:
+     vars[Var.label] = [label]
+ elif isinstance(vars[Var.label], list) and len(vars[Var.label]) > 0:
+     logging.debug(f"Multiple labels found and no value provided. Using first label: {vars[Var.label][0]}")
+     vars[Var.label] = [vars[Var.label][0]]
+ else:
+     raise ValueError("No valid label found or provided.")
```

This change ensures that there's always a valid label and provides a clear error message if no valid label is found.
159-180: LGTM! Improved preprocessing and data quality checks.

The changes enhance the preprocessing step with better timing measurement and more robust NaN/null handling using Polars capabilities. The detailed logging of data quality issues is a valuable addition.

Suggestion: consider consolidating the NaN and null handling steps:

```diff
- dict[key] = val.fill_null(strategy="zero")
- dict[key] = val.fill_nan(0)
+ dict[key] = val.fill_null(strategy="zero").fill_nan(0)
```

This change combines the two operations into a single method chain, which might be slightly more efficient and readable.
193-214: LGTM! Valuable addition for data sanitization.

The `check_sanitize_data` function is a great addition for ensuring data quality by removing duplicates from different data segments. It effectively uses Polars for efficient data manipulation.
def check_sanitize_data(data, vars): """Check for duplicates in the loaded data and remove them.""" group = vars[Var.group] if Var.group in vars.keys() else None sequence = vars[Var.sequence] if Var.sequence in vars.keys() else None + if group is None: + raise ValueError("Group variable not found in vars dictionary.") keep = "last" if Segment.static in data.keys(): + if group not in data[Segment.static].columns: + raise ValueError(f"Group column '{group}' not found in static data.") old_len = len(data[Segment.static]) data[Segment.static] = data[Segment.static].unique(subset=group, keep=keep, maintain_order=True) logging.warning(f"Removed {old_len - len(data[Segment.static])} duplicates from static data.") if Segment.dynamic in data.keys(): + if group not in data[Segment.dynamic].columns or sequence not in data[Segment.dynamic].columns: + raise ValueError(f"Group column '{group}' or sequence column '{sequence}' not found in dynamic data.") old_len = len(data[Segment.dynamic]) data[Segment.dynamic] = data[Segment.dynamic].unique(subset=[group, sequence], keep=keep, maintain_order=True) logging.warning(f"Removed {old_len - len(data[Segment.dynamic])} duplicates from dynamic data.") # ... (similar checks for outcome data) return dataThis addition will help catch potential issues early in the data processing pipeline.
🧰 Tools
🪛 Ruff
195-195: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
196-196: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
198-198: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
202-202: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
206-206: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
217-243: LGTM! Excellent addition for flexible modality selection.

The `modality_selection` function adds valuable flexibility to the data processing pipeline, allowing users to select specific modalities. It effectively uses the modality_mapping for column filtering and updates both data and vars dictionaries.
def modality_selection( data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars ) -> dict[pl.DataFrame]: logging.info(f"Selected modalities: {selected_modalities}") + if not modality_mapping: + raise ValueError("Modality mapping is empty or None.") + if not selected_modalities or selected_modalities == ["all"]: + logging.info("No specific modalities selected. Using all columns.") + return data, vars selected_columns = [modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping.keys()] if not any(col in modality_mapping.keys() for col in selected_modalities): raise ValueError("None of the selected modalities found in modality mapping.") - if selected_columns == []: - logging.info("No columns selected. Using all columns.") - return data, vars + if not selected_columns: + raise ValueError("No columns were selected based on the provided modalities.") selected_columns = sum(selected_columns, []) selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]]) # ... (rest of the function remains the same)These changes provide more explicit handling of edge cases and clearer error messages.
🧰 Tools
🪛 Ruff
221-221: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
222-222: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
239-239: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
Line range hint 245-330: LGTM! Improved compatibility and code organization.

The changes to the `make_train_val` function successfully add support for Polars DataFrames while maintaining compatibility with Pandas. The new helper functions `_get_stays` and `_get_labels` improve code organization and readability.
🧰 Tools
🪛 Ruff
302-302: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
310-310: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
314-314: Use
key in dict
instead ofkey in dict.keys()
Remove
.keys()
(SIM118)
Line range hint
334-453
: LGTM! Improved flexibility and error handling.The changes to
make_single_split
function successfully add support for Polars DataFrames while maintaining Pandas compatibility. The enhanced error handling for classification tasks with insufficient samples is a valuable addition.Suggestion: Consider extracting the data splitting logic into a separate function to improve readability and maintainability:
def _split_data(stays, labels, train_size, cv_repetitions, cv_folds, repetition_index, fold_index, seed, runmode): if Var.label in vars and runmode is RunMode.classification: if train_size: outer_cv = StratifiedShuffleSplit(cv_repetitions, train_size=train_size) else: outer_cv = StratifiedKFold(cv_repetitions, shuffle=True, random_state=seed) inner_cv = StratifiedKFold(cv_folds, shuffle=True, random_state=seed) else: if train_size: outer_cv = ShuffleSplit(cv_repetitions, train_size=train_size) else: outer_cv = KFold(cv_repetitions, shuffle=True, random_state=seed) inner_cv = KFold(cv_folds, shuffle=True, random_state=seed) dev, test = list(outer_cv.split(stays, labels))[repetition_index] dev_stays = stays[dev] if isinstance(stays, pl.Series) else stays.iloc[dev] dev_labels = labels[dev] if isinstance(labels, pl.Series) else labels.iloc[dev] train, val = list(inner_cv.split(dev_stays, dev_labels))[fold_index] return dev, test, train, val, dev_stays # Use this function in make_single_split dev, test, train, val, dev_stays = _split_data(stays, labels, train_size, cv_repetitions, cv_folds, repetition_index, fold_index, seed, runmode)This refactoring would make the
make_single_split
function more concise and easier to maintain.
Line range hint `1-465`: Great job on the Polars transition and new features!

The changes in this file significantly improve performance through the use of Polars while maintaining backward compatibility with Pandas. The new features, such as modality selection and improved data sanitization, add valuable flexibility to the data processing pipeline.

General suggestions for further improvement:
1. Consider adding type hints consistently throughout the file, especially for the new Polars-related code.
2. Add more comprehensive docstrings to the new functions (`check_sanitize_data`, `modality_selection`) explaining their parameters, return values, and any raised exceptions.
3. Consider creating a configuration object or using a library like `pydantic` to manage and validate the various options and parameters used throughout the file (see the sketch after this list).
4. Add unit tests for the new functions and the updated logic to ensure everything works correctly with both Polars and Pandas.
5. Consider adding a feature flag or configuration option to completely switch to Polars in the future, removing the Pandas code paths for simplicity once the transition is complete.

Overall, this is a solid improvement to the data processing pipeline. Great work!
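A minimal sketch of the `pydantic` idea from point 3 above; the option names (`train_size`, `seed`, `selected_modalities`) are illustrative assumptions, not the project's actual configuration:

```python
# Minimal sketch of validating pipeline options with pydantic (v2 API).
# The field names here are assumptions for illustration only.
from pydantic import BaseModel, Field, field_validator


class SplitOptions(BaseModel):
    train_size: float = Field(0.8, gt=0.0, lt=1.0)  # must be a proper fraction
    seed: int = 42
    selected_modalities: list[str] = ["all"]

    @field_validator("selected_modalities")
    @classmethod
    def non_empty(cls, v: list[str]) -> list[str]:
        if not v:
            raise ValueError("selected_modalities must not be empty")
        return v


# Invalid values fail fast with a clear error instead of surfacing mid-pipeline:
opts = SplitOptions(train_size=0.8, selected_modalities=["vitals"])
```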
🧰 Tools
🪛 Ruff
302-302: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
310-310: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
314-314: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
icu_benchmarks/models/wrappers.py (3)
`33-35`: LGTM! New metrics configured for gin.

The addition of average_precision_score and roc_auc_score to the gin configuration is a good improvement for model evaluation.

Consider addressing the commented-out line: there's a commented-out line for `scorer_wrapper`. If this is no longer needed, consider removing it to keep the code clean.
`491-505`: LGTM! New methods for model explanation and output saving.

The addition of the `_explain_model` and `_save_model_outputs` methods enhances the class's capabilities for model interpretability and debugging.

Minor formatting issue in f-string: on line 503, there's an unnecessary inner f-string:

```python
logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / f'row_indicators.csv'}")
```

The inner f-string can be removed:

```python
logging.debug(f"Saved row indicators to {Path(self.logger.save_dir) / 'row_indicators.csv'}")
```
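As a rough illustration of what such an explanation hook can look like, a SHAP-based sketch follows; the method and attribute names (`_explain_model`, `self.model`) are assumptions about the surrounding wrapper, not the repository's actual API:

```python
# Hedged sketch of a model-explanation hook using SHAP's generic Explainer API.
import shap


def _explain_model(self, features):
    # Build an explainer for the fitted model and compute per-feature attributions.
    explainer = shap.Explainer(self.model)
    shap_values = explainer(features)
    return shap_values.values  # array of SHAP values, one row per sample
```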
🪛 Ruff
503-503: f-string without any placeholders
Remove extraneous
f
prefix(F541)
Line range hint `1-530`: Overall assessment: Significant improvements with minor issues.

This update to `icu_benchmarks/models/wrappers.py` introduces several valuable enhancements:
- New metrics and improved metric handling
- Enhanced debugging and model explanation capabilities
- More flexible weight setting for handling imbalanced datasets

There are a few minor issues to address:
- Unused `row_indicators` parameter in multiple methods
- Unused `loss_weight` attribute in the `MLWrapper` class
- A minor formatting issue with an f-string

These changes significantly improve the functionality and flexibility of the wrapper classes. Addressing the minor issues will further refine the code quality.
🪛 Ruff
503-503: f-string without any placeholders
Remove extraneous
f
prefix(F541)
icu_benchmarks/data/preprocessor.py (7)
`50-145`: LGTM! Comprehensive Polars-based classification preprocessor implementation.

The `PolarsClassificationPreprocessor` class provides a robust implementation for preprocessing classification data using Polars. It handles both static and dynamic features, applies scaling, and generates additional features when required. The implementation is well-structured and covers various preprocessing scenarios.

Consider optimizing the join operations in the `apply` method. Instead of performing separate joins for each split (train, val, test), you could create a function to handle the join operation and apply it to all splits. This would reduce code duplication and improve maintainability. For example:

```python
def join_static_dynamic(dynamic, static, group_var):
    return dynamic.join(static, on=group_var)

# Then in the apply method:
for split in [Split.train, Split.val, Split.test]:
    data[split][Segment.dynamic] = join_static_dynamic(
        data[split][Segment.dynamic], data[split][Segment.static], vars["GROUP"]
    )
```
`147-200`: LGTM! Comprehensive static and dynamic data preprocessing.

The `_process_static` and `_process_dynamic` methods provide thorough preprocessing for static and dynamic data using Polars. They handle scaling, imputation, and feature generation effectively. The implementation follows best practices for data preprocessing.

Consider adding error handling for cases where the required columns might be missing in the data. For example, in the `_process_dynamic` method:

```python
def _process_dynamic(self, data, vars):
    required_columns = set(vars[Segment.dynamic] + [vars["GROUP"], vars["SEQUENCE"]])
    missing_columns = required_columns - set(data[Split.train][Segment.dynamic].columns)
    if missing_columns:
        raise ValueError(f"Missing required columns: {missing_columns}")
    # Rest of the method implementation...
```

This addition would make the code more robust against potential data inconsistencies.
: LGTM! Well-implemented regression preprocessor with outcome scaling.The
PolarsRegressionPreprocessor
class effectively extends the classification preprocessor to handle regression tasks. It includes outcome processing with scaling based on predefined or data-driven ranges, which is crucial for regression problems.Consider adding an option to skip outcome scaling entirely. This could be useful in cases where the outcome is already in the desired range or when preserving the original scale is important. You could add a
scale_outcome
parameter to the constructor:def __init__( self, generate_features: bool = False, scaling: bool = True, use_static_features: bool = True, outcome_max=None, outcome_min=None, scale_outcome: bool = True, save_cache=None, load_cache=None, ): # ... existing code ... self.scale_outcome = scale_outcome def _process_outcome(self, data, vars, split): if not self.scale_outcome: return data # ... existing scaling code ...This addition would provide more flexibility in handling different regression scenarios.
Line range hint
284-480
: LGTM! Clear renaming and consistency in Pandas-based preprocessors.The renaming of Pandas-based preprocessor classes improves clarity and consistency with the new Polars-based classes. The structure remains similar, which is good for code organization and maintainability.
While the current implementation maintains separate Polars and Pandas classes, consider for future refactoring:
- Extracting common preprocessing logic into a shared base class or utility functions.
- Using a strategy pattern to switch between Polars and Pandas implementations.
This could reduce code duplication and make it easier to maintain both versions in the long run. For example:
class BasePreprocessor(abc.ABC): @abc.abstractmethod def _process_static(self, data, vars): pass @abc.abstractmethod def _process_dynamic(self, data, vars): pass class PolarsPreprocessor(BasePreprocessor): # Polars-specific implementations class PandasPreprocessor(BasePreprocessor): # Pandas-specific implementations class ClassificationPreprocessor: def __init__(self, backend='polars'): self.preprocessor = PolarsPreprocessor() if backend == 'polars' else PandasPreprocessor() def apply(self, data, vars): return self.preprocessor.apply(data, vars)This structure would allow easier maintenance and extension of both Polars and Pandas implementations.
Line range hint `481-538`: LGTM! Well-implemented imputation preprocessor with customization options.

The `PandasImputationPreprocessor` class provides a specialized preprocessor for imputation tasks with good flexibility through its customization options. The implementation is correct and consistent with the established patterns in the file.

Consider enhancing the logging in the `_process_dynamic_data` method to provide more detailed information about the removed data. This could be helpful for debugging and monitoring the preprocessing process. For example:

```python
def _process_dynamic_data(self, data, vars):
    if self.filter_missing_values:
        rows_to_remove = data[Segment.dynamic][vars[Segment.dynamic]].isna().sum(axis=1) != 0
        ids_to_remove = data[Segment.dynamic].loc[rows_to_remove][vars["GROUP"]].unique()
        original_size = {table_name: len(table) for table_name, table in data.items()}
        data = {
            table_name: table.loc[~table[vars["GROUP"]].isin(ids_to_remove)]
            for table_name, table in data.items()
        }
        removed_size = {table_name: original_size[table_name] - len(table) for table_name, table in data.items()}
        logging.info(f"Removed {len(ids_to_remove)} stays with missing values.")
        logging.info(f"Rows removed per table: {removed_size}")
    return data
```

This addition would provide more detailed information about the impact of filtering missing values on each data segment.
Line range hint `541-594`: LGTM! Robust utility functions for recipe handling with multi-library support.

The utility functions for applying recipes to splits, caching recipes, and restoring recipes are well-implemented and now support both Polars and Pandas DataFrames. This flexibility is valuable for supporting different data processing libraries.

Consider enhancing the error handling in the `restore_recipe` function to provide more informative error messages. For example:

```python
def restore_recipe(cache_file: str) -> Recipe:
    """Restore recipe from cache to use for e.g. transfer learning."""
    try:
        if not cache_file.exists():
            raise FileNotFoundError(f"Cache file {cache_file} not found.")
        with open(cache_file, "rb") as f:
            logging.info(f"Loading cached recipe from {cache_file}.")
            recipe = pickle.load(f)
            return recipe
    except (pickle.UnpicklingError, EOFError) as e:
        raise ValueError(f"Error unpickling recipe from {cache_file}: {str(e)}") from e
    except Exception as e:
        raise RuntimeError(f"Unexpected error while restoring recipe from {cache_file}: {str(e)}") from e
```

This enhancement would provide more specific error messages, making it easier to diagnose issues with recipe restoration.
Line range hint `1-594`: Excellent enhancements to the preprocessing module with multi-library support.

The changes to `icu_benchmarks/data/preprocessor.py` significantly improve the flexibility and capabilities of the preprocessing module. Key improvements include:
- Addition of Polars support alongside Pandas, potentially offering performance benefits.
- New preprocessor classes for classification, regression, and imputation tasks using both Polars and Pandas.
- Consistent code structure and naming conventions across different preprocessor types.
- Enhanced utility functions supporting multiple DataFrame types.
These changes provide a robust foundation for handling various preprocessing scenarios in the ICU benchmarks project. The multi-library support allows for easy comparison and selection of the most suitable data processing approach for different use cases.
For future development, consider:
- Implementing benchmarks to compare the performance of Polars vs. Pandas implementations for typical preprocessing tasks in your project.
- Creating a factory pattern for preprocessor creation to simplify the selection between Polars and Pandas implementations based on configuration or data characteristics.
- Developing a comprehensive suite of unit tests to ensure the consistency of results between Polars and Pandas implementations.
These steps would further enhance the robustness and usability of the preprocessing module.
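For the benchmarking suggestion above, a minimal timing harness could look like the following sketch; the frame shape and the groupby-mean operation are placeholders, not the project's actual workload:

```python
# Rough sketch for timing a Polars operation against its Pandas equivalent.
import timeit

import numpy as np
import pandas as pd
import polars as pl

rng = np.random.default_rng(42)
data = {"group": rng.integers(0, 100, 1_000_000), "value": rng.random(1_000_000)}
pdf = pd.DataFrame(data)
plf = pl.DataFrame(data)

pandas_time = timeit.timeit(lambda: pdf.groupby("group")["value"].mean(), number=10)
polars_time = timeit.timeit(lambda: plf.group_by("group").agg(pl.col("value").mean()), number=10)
print(f"pandas: {pandas_time:.3f}s, polars: {polars_time:.3f}s")
```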
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (15)
- icu_benchmarks/cross_validation.py (6 hunks)
- icu_benchmarks/data/loader.py (7 hunks)
- icu_benchmarks/data/pooling.py (7 hunks)
- icu_benchmarks/data/preprocessor.py (8 hunks)
- icu_benchmarks/data/split_process_data.py (10 hunks)
- icu_benchmarks/models/custom_metrics.py (2 hunks)
- icu_benchmarks/models/dl_models/rnn.py (1 hunks)
- icu_benchmarks/models/dl_models/tcn.py (1 hunks)
- icu_benchmarks/models/dl_models/transformer.py (1 hunks)
- icu_benchmarks/models/ml_models/catboost.py (1 hunks)
- icu_benchmarks/models/ml_models/xgboost.py (1 hunks)
- icu_benchmarks/models/wrappers.py (11 hunks)
- icu_benchmarks/run_utils.py (6 hunks)
- icu_benchmarks/tuning/hyperparameters.py (8 hunks)
- requirements.txt (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- icu_benchmarks/cross_validation.py
- icu_benchmarks/data/pooling.py
- icu_benchmarks/models/dl_models/tcn.py
- icu_benchmarks/models/ml_models/catboost.py
- requirements.txt
🧰 Additional context used
🪛 Ruff
icu_benchmarks/data/split_process_data.py
27-27: Do not use mutable data structures for argument defaults. Replace with `None`; initialize within function (B006)
122-122: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
127-127: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
127-127: Test for membership should be `not in`. Convert to `not in` (E713)
195-195: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
196-196: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
198-198: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
202-202: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
206-206: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
221-221: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
222-222: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
239-239: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
302-302: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
310-310: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
314-314: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
420-423: Use ternary operator `dev_stays = stays[dev] if polars else stays.iloc[dev]` instead of `if`-`else`-block. Replace `if`-`else`-block with `dev_stays = stays[dev] if polars else stays.iloc[dev]` (SIM108)
439-439: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
447-447: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
451-451: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
icu_benchmarks/models/dl_models/transformer.py
47-47: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
62-62: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
107-107: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
124-124: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
icu_benchmarks/models/ml_models/xgboost.py
27-27: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
76-76: f-string without any placeholders. Remove extraneous `f` prefix (F541)
icu_benchmarks/models/wrappers.py
503-503: f-string without any placeholders. Remove extraneous `f` prefix (F541)
icu_benchmarks/run_utils.py
18-18: `os` imported but unused. Remove unused import: `os` (F401)
19-19: `glob` imported but unused. Remove unused import: `glob` (F401)
🔇 Additional comments (23)
icu_benchmarks/models/dl_models/rnn.py (1)
`1-84`: Overall, excellent implementation of RNN, LSTM, and GRU models.

This file provides well-structured and correct implementations of three recurrent neural network architectures: RNN, LSTM, and GRU. The use of @gin.configurable for all classes allows for flexible configuration, and the support for both classification and regression run modes adds versatility to the models.
Minor suggestions for improvement across all classes:
- Add type hints to method parameters and return values for improved code readability and maintainability.
- Consider adding docstrings to methods for better documentation.
Great job on addressing the issues mentioned in previous reviews and creating a solid foundation for these RNN-based models.
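A small sketch of the suggested typing and docstring style; the class name, layer sizes, and shapes are illustrative assumptions, not the repository's actual models:

```python
# Illustrative typing/docstring style for an RNN-family model.
from torch import Tensor, nn


class GRUNet(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int) -> None:
        super().__init__()
        self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Map a batch of shape (batch, time, features) to class logits."""
        out, _ = self.rnn(x)
        return self.head(out[:, -1])  # use the last time step's hidden state
```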
icu_benchmarks/models/dl_models/transformer.py (2)
`9-18`: LGTM: BaseTransformer class implementation.

The BaseTransformer class is well-structured and implements common functionality for transformer-based models. The forward method correctly follows the standard transformer architecture flow, promoting code reuse and maintainability.
`1-140`: Overall assessment: Good implementation with room for improvement.

The implementation of transformer-based models for ICU benchmarks is correct and follows the HiRID-Benchmark specifications. The code structure is generally good, with clear separation of concerns between the `BaseTransformer`, `Transformer`, and `LocalTransformer` classes.

However, there are opportunities for improvement:
- Refactoring the initialization process to use configuration objects can enhance readability and maintainability (see the sketch below).
- Reducing code duplication between the `Transformer` and `LocalTransformer` classes can improve the overall code structure.
- Minor improvements like removing unused loop variables can be made.

Implementing the suggested refactoring will result in a more robust and maintainable codebase while preserving the current functionality.
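One possible shape for the configuration-object refactor, sketched with a dataclass and PyTorch's built-in encoder; the field names are assumptions, and the encoder construction stands in for the custom transformer blocks:

```python
# Sketch of the configuration-object idea for transformer initialization.
from dataclasses import dataclass

from torch import nn


@dataclass
class TransformerConfig:
    emb_dim: int = 64
    num_heads: int = 4
    ff_dim: int = 256
    num_layers: int = 2
    dropout: float = 0.1


def build_encoder(cfg: TransformerConfig) -> nn.TransformerEncoder:
    layer = nn.TransformerEncoderLayer(
        d_model=cfg.emb_dim,
        nhead=cfg.num_heads,
        dim_feedforward=cfg.ff_dim,
        dropout=cfg.dropout,
        batch_first=True,
    )
    return nn.TransformerEncoder(layer, num_layers=cfg.num_layers)
```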
🧰 Tools
🪛 Ruff
47-47: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
62-62: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
107-107: Star-arg unpacking after a keyword argument is strongly discouraged (B026)
124-124: Loop control variable `i` not used within loop body. Rename unused `i` to `_i` (B007)
icu_benchmarks/models/ml_models/xgboost.py (3)
`1-25`: LGTM: Imports and class definition are well-structured.

The imports are appropriate for the XGBoost classifier implementation. The use of the `@gin.configurable` decorator allows for flexible configuration, and the `_supported_run_modes` attribute correctly specifies the classification run mode.
`30-40`: LGTM: `predict` method is well-implemented and documented.

The `predict` method correctly uses the model's `predict_proba` method to return class probabilities. The docstring provides clear information about the method's functionality, inputs, and outputs.
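A minimal sketch of the probability-returning pattern described above, using the standard XGBoost scikit-learn API; the wrapper class and attribute names are assumptions, not the repository's code:

```python
# Sketch of a predict method that returns class probabilities via predict_proba.
import numpy as np
from xgboost import XGBClassifier


class XGBWrapper:
    def __init__(self, **kwargs):
        self.model = XGBClassifier(**kwargs)

    def fit_model(self, X: np.ndarray, y: np.ndarray) -> None:
        self.model.fit(X, y)

    def predict(self, X: np.ndarray) -> np.ndarray:
        # Per-class probabilities, shape (n_samples, n_classes).
        return self.model.predict_proba(X)
```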
`96-99`: LGTM: `get_feature_importance` method is well-implemented.

The `get_feature_importance` method correctly retrieves feature importances from the trained model. It includes a proper check to ensure the model has been fit before attempting to access the feature importances. The error message is clear and informative, guiding the user to call `fit_model()` if needed.
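Continuing the sketch above, the fit check being described could be expressed roughly as follows (attribute names remain assumptions):

```python
# Hedged sketch of guarding feature-importance access behind a fitted-model check.
def get_feature_importance(self):
    if not hasattr(self.model, "feature_importances_"):
        raise ValueError("Model has not been fit yet. Call fit_model() before requesting feature importances.")
    return self.model.feature_importances_
```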
5-6
: LGTM: New imports are appropriate and well-structured.The addition of
ndarray
fromnumpy
for type hinting and the alias forconfusion_matrix
fromsklearn.metrics
are good practices. This helps improve code readability and avoids potential naming conflicts.
`136-145`: ⚠️ Potential issue: Function implementation needs refinement.

The `confusion_matrix` function has a few issues that need to be addressed:
- The return type in the function signature (`torch.tensor`) doesn't match the actual return type (dictionary).
- There's no input validation, which could lead to runtime errors if incorrect data types are provided.
- The function doesn't handle the case where `y_true` and `y_pred` have different shapes.

Here's a suggested improvement:

```python
def confusion_matrix(y_true: ndarray, y_pred: ndarray, normalize: bool = False) -> dict:
    if y_true.shape != y_pred.shape:
        raise ValueError("y_true and y_pred must have the same shape")

    y_pred = np.rint(y_pred).astype(int)
    confusion = sk_confusion_matrix(y_true, y_pred)
    if normalize:
        confusion = confusion.astype("float") / confusion.sum(axis=1)[:, np.newaxis]

    confusion_dict = {
        f"class_{i}_pred_{j}": confusion[i][j]
        for i in range(confusion.shape[0])
        for j in range(confusion.shape[1])
    }
    return confusion_dict
```

This implementation:
- Corrects the return type annotation
- Adds basic input validation
- Uses a more efficient dictionary comprehension
- Improves normalization to avoid division by zero

To ensure the function is used correctly throughout the codebase, you can search for its call sites. This will help identify any places where the function is called, allowing you to verify that the changes won't break existing code.
✅ Verification successful

Confusion Matrix Function Return Type Mismatch.
- The `confusion_matrix` function's return type annotation (`torch.tensor`) does not match its actual return type (`dict`).
- It is only used in `icu_benchmarks/models/wrappers.py`, where the returned dictionary is passed to `self.log_dict()`, which expects a dictionary.
, which expects a dictionary.🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash # Search for uses of the confusion_matrix function rg "confusion_matrix\(" --type pythonLength of output: 70
Script:
#!/bin/bash # Search for uses of the confusion_matrix function rg "confusion_matrix\(" --type pyLength of output: 408
Script:
#!/bin/bash # Get the lines surrounding the confusion_matrix function call in wrappers.py rg "confusion_matrix\(" --context 2 icu_benchmarks/models/wrappers.pyLength of output: 315
icu_benchmarks/run_utils.py (1)
Line range hint `1-301`: Overall, excellent enhancements to the run_utils.py file.

The changes made to this file significantly improve the functionality and flexibility of the ICU benchmarking framework. Key improvements include:
- Enhanced command-line argument parsing with support for modalities and labels.
- Improved directory creation logic to prevent conflicts.
- Addition of SHAP value aggregation in the results processing.
- New utility functions for discovering configuration files and validating required keys.
These changes make the framework more robust and adaptable to different scenarios. The code is generally well-structured and readable, with only minor suggestions for improvement provided in the previous comments.
Great job on these enhancements! The added functionality will greatly benefit users of this framework.
icu_benchmarks/tuning/hyperparameters.py (3)
`5-5`: LGTM: New imports for Optuna and visualization.

The added imports for matplotlib.pyplot and the Optuna visualization functions are appropriate for the new Optuna-based implementation and plotting capabilities.
Also applies to: 18-18
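For orientation, a minimal Optuna loop in the direction described; the objective below is a stand-in for a real validation score, not the project's tuning code:

```python
# Minimal Optuna sketch with the matplotlib-based plotting mentioned above.
import optuna
from optuna.visualization.matplotlib import plot_optimization_history


def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    depth = trial.suggest_int("depth", 2, 10)
    return (lr - 0.01) ** 2 + depth * 1e-4  # pretend score to minimize


study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_params)
ax = plot_optimization_history(study)  # matplotlib Axes, ready to be saved
```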
`25-25`: LGTM: Proper deprecation of the scikit-optimize function.

The renaming of `choose_and_bind_hyperparameters` to `choose_and_bind_hyperparameters_scikit_optimize` and the addition of a deprecation warning are good practices. This approach maintains backwards compatibility while clearly signaling the transition to the new Optuna-based implementation.

Also applies to: 66-67
1-383
: Overall: Significant improvement with transition to Optuna.The transition from scikit-optimize to Optuna for hyperparameter tuning is a substantial improvement. Optuna offers more advanced tuning capabilities and better visualization options. The overall structure of the code is good, incorporating the new Optuna-based approach effectively.
Key improvements:
- Advanced hyperparameter tuning with Optuna.
- Addition of visualization capabilities.
- Proper deprecation of the old implementation.
Areas for further improvement:
- Enhanced documentation, especially for new parameters.
- Refactoring of some complex sections for better clarity and maintainability.
- Expansion of plotting capabilities.
Overall, these changes significantly enhance the hyperparameter tuning functionality of the project.
icu_benchmarks/data/loader.py (1)
10-10
: Appropriate addition of Polars library importThe import of the Polars library is consistent with the transition from Pandas to Polars for data handling. This change aligns with the overall goal of improving performance and reducing memory usage.
icu_benchmarks/data/split_process_data.py (2)
3-15
: LGTM! Transition to Polars for improved performance.The import changes reflect the shift from Pandas to Polars for data processing, which should lead to improved performance. The addition of 'os' and 'timeit' imports suggests new file operations and performance measurement capabilities.
Line range hint
456-465
: LGTM, but consider reviewing for Polars compatibility.The caching function remains unchanged, which is fine if it's working as expected. However, given the transition to Polars in other parts of the code, it might be worth reviewing this function to ensure it's fully compatible with Polars DataFrames.
Consider testing the caching mechanism with Polars DataFrames to confirm that serialization and deserialization work correctly. If issues are found, you may need to update the pickling process or consider using a different serialization method that's optimized for Polars.
🧰 Tools
🪛 Ruff
420-423: Use ternary operator `dev_stays = stays[dev] if polars else stays.iloc[dev]` instead of `if`-`else`-block (SIM108)
439-439: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
447-447: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
451-451: Use `key in dict` instead of `key in dict.keys()`. Remove `.keys()` (SIM118)
icu_benchmarks/models/wrappers.py (7)
`4-6`: LGTM: New imports added.

The addition of these imports (Path, torchmetrics, and additional sklearn metrics) is appropriate for the changes made in the file.
`19-25`: LGTM: Import statements updated.

The addition of the confusion_matrix import and the change in the RunMode import location are consistent with the file's purpose and likely reflect changes in the project structure.
`49-50`: LGTM: New attributes added to BaseModule.

The addition of the `debug` and `explain_features` attributes provides useful functionality for debugging and model interpretability.
: LGTM: Enhanced set_weight method.The updated
set_weight
method now supports dataset-specific weight balancing, which is a valuable improvement for handling imbalanced datasets.
410-411
: Unused parameter persists.The
row_indicators
parameter is still not used in the method body. This issue has been previously identified and remains unaddressed.
Line range hint
437-457
: LGTM: Enhanced debugging and feature explanation.The addition of debug output saving and feature explanation functionality in the
test_step
method is a valuable improvement for model interpretability and debugging.
476-487
: LGTM: Improved metric logging.The updates to the
log_metrics
method, including special handling for the Confusion_Matrix and refined metric logging logic, enhance the class's ability to track and report model performance.icu_benchmarks/data/preprocessor.py (1)
9-9
: LGTM! Enhanced flexibility with Polars integration andvars_to_exclude
parameter.The addition of Polars library import and the modification to the
Preprocessor
base class to includevars_to_exclude
parameter provide more flexibility in data preprocessing. This change allows for better control over which variables are processed, which can be particularly useful in complex preprocessing scenarios.Also applies to: 34-34
Updates from the cassandra project currently include:
Summary by CodeRabbit
New Features
Bug Fixes
Documentation
Chores