Skip to content

Commit

Permalink
modality mapping enhancements and naming clash
Browse files Browse the repository at this point in the history
  • Loading branch information
rvandewater committed Jul 29, 2024
1 parent b81346d commit a15db4f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 7 deletions.
16 changes: 13 additions & 3 deletions icu_benchmarks/data/split_process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,10 @@ def preprocess_data(
logging.info(f"Loading data from directory {data_dir.absolute()}")
data = {f: pl.read_parquet(data_dir / file_names[f]) for f in file_names.keys()}

logging.debug(f"Modality mapping: {modality_mapping}")
if len(modality_mapping) > 0:
# Optional modality selection
if not selected_modalities == "all":
if not (selected_modalities == "all" or selected_modalities == ["all"] or selected_modalities == None):
data, vars = modality_selection(data, modality_mapping, selected_modalities, vars)
else:
logging.info(f"Selecting all modalities.")
Expand Down Expand Up @@ -162,16 +163,25 @@ def preprocess_data(
def modality_selection(
    data: dict[str, pl.DataFrame],
    modality_mapping: dict[str, list[str]],
    selected_modalities: list[str],
    vars,
) -> tuple[dict[str, pl.DataFrame], dict]:
    """Restrict the data frames and the variable mapping to the selected modalities.

    The group, label and sequence columns named in ``vars`` are always kept in
    addition to the columns of the selected modalities.

    Args:
        data: Data frames keyed by name. Updated in place: each frame is replaced
            by a selection of its columns that are part of the selection.
        modality_mapping: Maps a modality name to the columns belonging to it.
        selected_modalities: Modality names to keep. Names that are not present in
            ``modality_mapping`` are ignored (a warning is logged).
        vars: Variable mapping. Updated in place: every entry except the group,
            label and sequence entries is filtered down to the selected columns.
            NOTE(review): those other entries are assumed to be lists of column
            names - confirm against the caller.

    Returns:
        The ``(data, vars)`` pair. Both are returned unchanged when none of the
        selected modalities is found in ``modality_mapping``.
    """
    logging.info(f"Selected modalities: {selected_modalities}")

    # Surface typos in the requested modalities instead of dropping them silently.
    unknown_modalities = [name for name in selected_modalities if name not in modality_mapping]
    if unknown_modalities:
        logging.warning(f"Modalities not found in modality mapping (ignored): {unknown_modalities}")

    matched_groups = [modality_mapping[name] for name in selected_modalities if name in modality_mapping]
    if not matched_groups:
        logging.info("No columns selected. Using all columns.")
        return data, vars

    # Flatten in linear time; sum(list_of_lists, []) copies the accumulator on every step.
    selected_columns = [column for group in matched_groups for column in group]
    # Count the modality columns before the standard columns are appended.
    modality_column_count = len(selected_columns)

    standard_keys = (Var.group, Var.label, Var.sequence)
    selected_columns.extend(vars[key] for key in standard_keys)
    # Set for O(1) membership tests in the filters below.
    selected_lookup = set(selected_columns)

    old_columns = []
    # Update vars dict
    for key, value in vars.items():
        if key not in standard_keys:
            old_columns.extend(value)
            vars[key] = [col for col in value if col in selected_lookup]
    logging.info(f"Selected columns: {selected_columns}")
    logging.info(f"Selected columns: {modality_column_count}, old columns: {len(old_columns)}")
    logging.debug(f"Difference: {set(old_columns) - selected_lookup}")

    # Update data dict
    for key, frame in data.items():
        kept_columns = [col for col in frame.columns if col in selected_lookup]
        data[key] = frame.select(kept_columns)
        logging.debug(f"Selected columns in {key}: {data[key].columns}")
        logging.debug(f"Selected columns in {key}: {len(data[key].columns)}")
    return data, vars

def make_train_val(
Expand Down
13 changes: 10 additions & 3 deletions icu_benchmarks/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,12 @@ def create_run_dir(log_dir: Path, randomly_searched_params: str = None) -> Path:
log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S"))
else:
log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f"))
log_dir_run.mkdir(parents=True)
if not log_dir_run.exists():
log_dir_run.mkdir(parents=True)
else:
# Directory clash at last moment
log_dir_run = log_dir / str(datetime.now().strftime("%Y-%m-%dT%H-%M-%S.%f"))
log_dir_run.mkdir(parents=True)
if randomly_searched_params:
(log_dir_run / randomly_searched_params).touch()
return log_dir_run
Expand Down Expand Up @@ -244,8 +249,10 @@ def get_config_files(config_dir: Path):
models = glob.glob(os.path.join(config_dir / "prediction_models", '*'))
tasks = [os.path.splitext(os.path.basename(task))[0] for task in tasks]
models = [os.path.splitext(os.path.basename(model))[0] for model in models]
tasks.remove("common")
models.remove("common")
if "common" in tasks:
tasks.remove("common")
if "common" in models:
models.remove("common")
logging.info(f"Found tasks: {tasks}")
logging.info(f"Found models: {models}")
return tasks, models
3 changes: 2 additions & 1 deletion icu_benchmarks/wandb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def set_wandb_experiment_name(args, mode):
data_dir = Path(args.data_dir)
args.name = data_dir.name
run_name = f"{mode}_{args.model}_{args.name}"

if args.modalities:
run_name += f"_mods_{args.modalities}"
if args.fine_tune:
run_name += f"_source_{args.source_name}_fine-tune_{args.fine_tune}_samples"
elif args.eval:
Expand Down

0 comments on commit a15db4f

Please sign in to comment.