diff --git a/icu_benchmarks/data/split_process_data.py b/icu_benchmarks/data/split_process_data.py index a9ae6ed7..5c0effbb 100644 --- a/icu_benchmarks/data/split_process_data.py +++ b/icu_benchmarks/data/split_process_data.py @@ -96,9 +96,10 @@ def preprocess_data( logging.info(f"Loading data from directory {data_dir.absolute()}") data = {f: pl.read_parquet(data_dir / file_names[f]) for f in file_names.keys()} + logging.debug(f"Modality mapping: {modality_mapping}") if len(modality_mapping) > 0: # Optional modality selection - if not selected_modalities == "all": + if not (selected_modalities == "all" or selected_modalities == ["all"] or selected_modalities is None): data, vars = modality_selection(data, modality_mapping, selected_modalities, vars) else: logging.info(f"Selecting all modalities.") @@ -162,16 +163,25 @@ def preprocess_data( def modality_selection(data: dict[pl.DataFrame], modality_mapping: dict[str], selected_modalities: list[str], vars) -> dict[pl.DataFrame]: logging.info(f"Selected modalities: {selected_modalities}") selected_columns =[modality_mapping[cols] for cols in selected_modalities if cols in modality_mapping.keys()] + if selected_columns == []: + logging.info(f"No columns selected. 
Using all columns.") + return data, vars selected_columns = sum(selected_columns, []) selected_columns.extend([vars[Var.group], vars[Var.label], vars[Var.sequence]]) + old_columns =[] + # Update vars dict for key, value in vars.items(): if key not in [Var.group, Var.label, Var.sequence]: + old_columns.extend(value) vars[key] = [col for col in value if col in selected_columns] - logging.info(f"Selected columns: {selected_columns}") + # -3 because of the standard columns (group, label, sequence) + logging.info(f"Selected columns: {len(selected_columns)-3}, old columns: {len(old_columns)}") + logging.debug(f"Difference: {set(old_columns) - set(selected_columns)}") + # Update data dict for key in data.keys(): sel_col = [col for col in data[key].columns if col in selected_columns] data[key] = data[key].select(sel_col) - logging.debug(f"Selected columns in {key}: {data[key].columns}") + logging.debug(f"Selected columns in {key}: {len(data[key].columns)}") return data, vars def make_train_val( diff --git a/icu_benchmarks/run_utils.py b/icu_benchmarks/run_utils.py index f29419c4..a38a4a1b 100644 --- a/icu_benchmarks/run_utils.py +++ b/icu_benchmarks/run_utils.py @@ -74,7 +74,12 @@ def create_run_dir(log_dir: Path, randomly_searched_params: str = None) -> Path: log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) else: log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")) - log_dir_run.mkdir(parents=True) + try: + log_dir_run.mkdir(parents=True) + except FileExistsError: + # Directory clash at last moment + log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f")) + log_dir_run.mkdir(parents=True) if randomly_searched_params: (log_dir_run / randomly_searched_params).touch() return log_dir_run @@ -244,8 +249,10 @@ def get_config_files(config_dir: Path): models = glob.glob(os.path.join(config_dir / "prediction_models", '*')) tasks = [os.path.splitext(os.path.basename(task))[0] for task in tasks] models = 
[os.path.splitext(os.path.basename(model))[0] for model in models] - tasks.remove("common") - models.remove("common") + if "common" in tasks: + tasks.remove("common") + if "common" in models: + models.remove("common") logging.info(f"Found tasks: {tasks}") logging.info(f"Found models: {models}") return tasks, models \ No newline at end of file diff --git a/icu_benchmarks/wandb_utils.py b/icu_benchmarks/wandb_utils.py index 225693a4..2ea06b57 100644 --- a/icu_benchmarks/wandb_utils.py +++ b/icu_benchmarks/wandb_utils.py @@ -62,7 +62,8 @@ def set_wandb_experiment_name(args, mode): data_dir = Path(args.data_dir) args.name = data_dir.name run_name = f"{mode}_{args.model}_{args.name}" - + if args.modalities: + run_name += f"_mods_{args.modalities}" if args.fine_tune: run_name += f"_source_{args.source_name}_fine-tune_{args.fine_tune}_samples" elif args.eval: