Skip to content

Commit

Permalink
changed 0/1 to categorical vars
Browse files Browse the repository at this point in the history
  • Loading branch information
youssef.mecky committed Oct 9, 2023
1 parent 8880ade commit 869a611
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 103 deletions.
99 changes: 50 additions & 49 deletions configs/prediction_models/RNNpytorch.gin
Original file line number Diff line number Diff line change
Expand Up @@ -125,55 +125,7 @@ PredictionDatasetTFTpytorch.target=[
"MissingIndicator_48",
"label",
]
PredictionDatasetTFTpytorch.time_varying_unknown_reals=["alb",
"alp",
"alt",
"ast",
"be",
"bicar",
"bili",
"bili_dir",
"bnd",
"bun",
"ca",
"cai",
"ck",
"ckmb",
"cl",
"crea",
"crp",
"dbp",
"fgn",
"fio2",
"glu",
"hgb",
"hr",
"inr_pt",
"k",
"lact",
"lymph",
"map",
"mch",
"mchc",
"mcv",
"methb",
"mg",
"na",
"neut",
"o2sat",
"pco2",
"ph",
"phos",
"plt",
"po2",
"ptt",
"resp",
"sbp",
"temp",
"tnt",
"urine",
"wbc",
"MissingIndicator_1",
PredictionDatasetTFTpytorch.time_varying_unknown_categoricals=["MissingIndicator_1",
"MissingIndicator_2",
"MissingIndicator_3",
"MissingIndicator_4",
Expand Down Expand Up @@ -223,3 +175,52 @@ PredictionDatasetTFTpytorch.time_varying_unknown_reals=["alb",
"MissingIndicator_48",
"label",
]
PredictionDatasetTFTpytorch.time_varying_unknown_reals=["alb",
"alp",
"alt",
"ast",
"be",
"bicar",
"bili",
"bili_dir",
"bnd",
"bun",
"ca",
"cai",
"ck",
"ckmb",
"cl",
"crea",
"crp",
"dbp",
"fgn",
"fio2",
"glu",
"hgb",
"hr",
"inr_pt",
"k",
"lact",
"lymph",
"map",
"mch",
"mchc",
"mcv",
"methb",
"mg",
"na",
"neut",
"o2sat",
"pco2",
"ph",
"phos",
"plt",
"po2",
"ptt",
"resp",
"sbp",
"temp",
"tnt",
"urine",
"wbc",]

101 changes: 51 additions & 50 deletions configs/prediction_models/TFTpytorch.gin
Original file line number Diff line number Diff line change
Expand Up @@ -26,55 +26,7 @@ PredictionDatasetTFTpytorch.max_prediction_length = 24
PredictionDatasetTFTpytorch.target="label"
PredictionDatasetTFTpytorch.time_varying_known_reals=["time_idx"]
PredictionDatasetTFTpytorch.add_relative_time_idx=True
PredictionDatasetTFTpytorch.time_varying_unknown_reals=["alb",
"alp",
"alt",
"ast",
"be",
"bicar",
"bili",
"bili_dir",
"bnd",
"bun",
"ca",
"cai",
"ck",
"ckmb",
"cl",
"crea",
"crp",
"dbp",
"fgn",
"fio2",
"glu",
"hgb",
"hr",
"inr_pt",
"k",
"lact",
"lymph",
"map",
"mch",
"mchc",
"mcv",
"methb",
"mg",
"na",
"neut",
"o2sat",
"pco2",
"ph",
"phos",
"plt",
"po2",
"ptt",
"resp",
"sbp",
"temp",
"tnt",
"urine",
"wbc",
"MissingIndicator_1",
PredictionDatasetTFTpytorch.time_varying_unknown_categoricals=["MissingIndicator_1",
"MissingIndicator_2",
"MissingIndicator_3",
"MissingIndicator_4",
Expand Down Expand Up @@ -123,4 +75,53 @@ PredictionDatasetTFTpytorch.time_varying_unknown_reals=["alb",
"MissingIndicator_47",
"MissingIndicator_48",
"label",
]
]
PredictionDatasetTFTpytorch.time_varying_unknown_reals=["alb",
"alp",
"alt",
"ast",
"be",
"bicar",
"bili",
"bili_dir",
"bnd",
"bun",
"ca",
"cai",
"ck",
"ckmb",
"cl",
"crea",
"crp",
"dbp",
"fgn",
"fio2",
"glu",
"hgb",
"hr",
"inr_pt",
"k",
"lact",
"lymph",
"map",
"mch",
"mchc",
"mcv",
"methb",
"mg",
"na",
"neut",
"o2sat",
"pco2",
"ph",
"phos",
"plt",
"po2",
"ptt",
"resp",
"sbp",
"temp",
"tnt",
"urine",
"wbc",]

7 changes: 5 additions & 2 deletions icu_benchmarks/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,7 @@ def __init__(
time_varying_unknown_reals: List[str],
target: Union[str, List[str]],
time_varying_known_reals: List[str],
time_varying_unknown_categoricals: List[str],
*args,
ram_cache: bool = False,
add_relative_time_idx: bool = False,
Expand All @@ -486,6 +487,7 @@ def __init__(
) # combine labels and features
# self.data["sex"].replace([0, 1], ["Female", "Male"], inplace=True)
# List of column names to convert from boolean to float

boolean_columns = [
"MissingIndicator_1",
"MissingIndicator_2",
Expand Down Expand Up @@ -540,8 +542,9 @@ def __init__(

# Convert multiple columns from boolean to float
self.data[boolean_columns] = self.data[boolean_columns].astype(
float
str
) # changing boolean to floats to allow input to models

self.split = split
self.args = args
self.ram_cache = ram_cache
Expand All @@ -561,7 +564,7 @@ def __init__(
static_reals=["height", "weight", "age", "sex"],
time_varying_known_categoricals=[],
time_varying_known_reals=time_varying_known_reals,
time_varying_unknown_categoricals=[],
time_varying_unknown_categoricals=time_varying_unknown_categoricals,
time_varying_unknown_reals=time_varying_unknown_reals,
add_relative_time_idx=add_relative_time_idx,
add_target_scales=True,
Expand Down
3 changes: 1 addition & 2 deletions icu_benchmarks/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,7 @@ def train_common(
num_workers: Number of workers to use for data loading.
"""
logging.info(f"Training model: {model.__name__}.")
with open("data.pkl", "wb") as f:
pickle.dump(data, f)

# choose dataset_class based on the model
dataset_class = (
ImputationDataset
Expand Down

0 comments on commit 869a611

Please sign in to comment.