Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Problem Definition: Object-Oriented Improvements for generate_cutoff_times(...) #32

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
83 changes: 79 additions & 4 deletions cardea/problem_definition/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,31 @@


class ProblemDefinition:
"""A class that defines the prediction problem
by specifying cutoff times and generating the target label if it does not exist.
"""Base class that defines a prediction problem.

Attributes:
target_label_column_name: The target label of the prediction problem.
target_entity: Name of the entity containing the target label.
cutoff_time_label: The cutoff time label of the prediction problem.
cutoff_entity: Name of the entity containing the cutoff time label.
prediction_type: The type of the machine learning prediction.
"""

def __init__(self, target_label_column_name,
target_entity, cutoff_time_label,
cutoff_entity, prediction_type,
updated_es=None, conn=None):

self.target_label_column_name = target_label_column_name
self.target_entity = target_entity
self.cutoff_time_label = cutoff_time_label
self.cutoff_entity = cutoff_entity
self.prediction_type = prediction_type

# optionals
self.conn = conn
self.updated_es = updated_es

def check_target_label(self, entity_set, target_entity, target_label):
"""Checks if target label exists in the entity set.

Expand Down Expand Up @@ -49,11 +70,12 @@ def generate_target_label(self, entity_set, target_entity, target_label):
Target entity with the generated label.
"""

def generate_cutoff_times(self, entity_set):
def generate_cutoff_times(self, entity_set,
cutoff_time_unifier='unify_cutoff_time_admission_time'):
"""Generates cutoff times for the predection problem.

Args:
entity_set: fhir entityset.
entity_set: the FHIR entityset.

Returns:
entity_set, target_entity, series of target_labels and a dataframe of cutoff_times.
Expand All @@ -62,6 +84,59 @@ def generate_cutoff_times(self, entity_set):
ValueError: An error occurs if the cutoff variable does not exist.
"""

loader = DataLoader()

target_label_exists = loader.check_column_existence(
entity_set, self.target_entity, self.target_label_column_name
)

target_label_has_missing_values = loader.check_for_missing_values(
entity_set, self.target_entity, self.target_label_column_name
)

if target_label_exists and not target_label_has_missing_values:
cutoff_time_label_exists = loader.check_column_existence(
entity_set, self.cutoff_entity, self.cutoff_time_label
)

if not cutoff_time_label_exists:
raise ValueError(
'Cutoff time label {} does not exist in table {}'.format(
self.cutoff_time_label,
self.cutoff_entity
)
)

cutoff_time_unifier_func = getattr(self, cutoff_time_unifier)
generated_cts = cutoff_time_unifier_func(
entity_set, self.cutoff_entity, self.cutoff_time_label
)

# new entity set
es = entity_set.entity_from_dataframe(
entity_id=self.cutoff_entity, dataframe=generated_cts, index='object_id'
)

label = es[self.target_entity].df[self.conn].values

instance_id = list(es[self.target_entity].df.index)

# get cutoff_times
cutoff_times = es[self.cutoff_entity].df['ct'].to_frame()
cutoff_times = cutoff_times.reindex(index=label)
cutoff_times = cutoff_times[cutoff_times.index.isin(label)]
cutoff_times['instance_id'] = instance_id
cutoff_times.columns = ['cutoff_time', 'instance_id']
cutoff_times['label'] = list(es[self.target_entity].df[self.target_label_column_name])

return (es, self.target_entity, cutoff_times)

# get a new entity set
self.updated_es = self.generate_target_label(entity_set)

# recursive call
return self.generate_cutoff_times(self.updated_es)

def unify_cutoff_times_hours_admission_time(self, df, cutoff_time_label):
"""Unify records cutoff times based on shared time.

Expand Down
80 changes: 13 additions & 67 deletions cardea/problem_definition/length_of_stay.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,77 +5,23 @@
from cardea.problem_definition import ProblemDefinition


class LengthOfStay (ProblemDefinition):
"""Defines the problem of length of stay, predicting how many days
the patient will be in the hospital.

Attributes:
target_label_column_name: The target label of the prediction problem.
target_entity: Name of the entity containing the target label.
cutoff_time_label: The cutoff time label of the prediction problem.
cutoff_entity: Name of the entity containing the cutoff time label.
prediction_type: The type of the machine learning prediction.
class LengthOfStay(ProblemDefinition):
"""Defines the problem of Length of Stay.

It predicts how many days the patient will be in the hospital.
"""

__name__ = 'los'

updated_es = None
target_label_column_name = 'length'
target_entity = 'Encounter'
cutoff_time_label = 'start'
cutoff_entity = 'Period'
conn = 'period'
prediction_type = 'regression'

def generate_cutoff_times(self, es):
"""Generates cutoff times for the predection problem.

Args:
es: fhir entityset.

Returns:
entity_set, target_entity, and a dataframe of cutoff_times and target_labels.

Raises:
ValueError: An error occurs if the cutoff variable does not exist.
"""

if (self.check_target_label(es,
self.target_entity,
self.target_label_column_name) and not
self.check_for_missing_values_in_target_label(es,
self.target_entity,
self.target_label_column_name)):
if DL().check_column_existence(es,
self.cutoff_entity,
self.cutoff_time_label):
generated_cts = self.unify_cutoff_time_admission_time(
es, self.cutoff_entity, self.cutoff_time_label)

es = es.entity_from_dataframe(entity_id=self.cutoff_entity,
dataframe=generated_cts,
index='object_id')

cutoff_times = es[self.cutoff_entity].df['ct'].to_frame()

label = es[self.target_entity].df[self.conn].values
instance_id = list(es[self.target_entity].df.index)
cutoff_times = cutoff_times.reindex(index=label)
cutoff_times = cutoff_times[cutoff_times.index.isin(label)]
cutoff_times['instance_id'] = instance_id
cutoff_times.columns = ['cutoff_time', 'instance_id']

cutoff_times['label'] = list(
es[self.target_entity].df[self.target_label_column_name])
return(es, self.target_entity, cutoff_times)
else:
raise ValueError('Cutoff time label {} in table {}' +
'does not exist'.format(self.cutoff_time_label,
self.target_entity))

else:
updated_es = self.generate_target_label(es)
return self.generate_cutoff_times(updated_es)
def __init__(self):
super().__init__(
'length', # target_label_column_name
'Encounter', # target_entity
'start', # cutoff_time_label
'Period', # cutoff_entity
'regression', # prediction_type
conn='period'
)

def generate_target_label(self, es):
"""Generates target labels in the case of having missing label in the entityset.
Expand Down
85 changes: 23 additions & 62 deletions cardea/problem_definition/mortality_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,86 +3,47 @@
from cardea.data_loader import DataLoader
from cardea.problem_definition import ProblemDefinition

DEFAULT_CAUSES = ['X60', 'X84', 'Y87.0', 'X85', 'Y09', 'Y87.1',
'V02', 'V04', 'V09.0', 'V09.2', 'V12', 'V14']

class MortalityPrediction (ProblemDefinition):
"""Defines the problem of diagnosis Prediction.

Finding whether a patient will be diagnosed with a specifed diagnosis.
class MortalityPrediction(ProblemDefinition):
"""Defines the problem of Diagnosis Prediction.

It finds whether a patient will be diagnosed with a specifed diagnosis.

Note:
The patient visit is considered a readmission if he visits
the hospital again within 30 days.

The readmission diagnosis does not have to be the same as the initial visit diagnosis,
(he could be diagnosed of something that is a complication of the initial diagnosis).

Attributes:

target_label_column_name: The target label of the prediction problem.
target_entity: Name of the entity containing the target label.
cutoff_time_label: The cutoff time label of the prediction problem.
cutoff_entity: Name of the entity containing the cutoff time label.
prediction_type: The type of the machine learning prediction.
"""
__name__ = 'mortality'

updated_es = None
target_label_column_name = 'diagnosis'
target_entity = 'Encounter'
cutoff_time_label = 'start'
cutoff_entity = 'Period'
prediction_type = 'classification'
conn = 'period'
causes_of_death = ['X60', 'X84', 'Y87.0', 'X85', 'Y09',
'Y87.1', 'V02', 'V04', 'V09.0', 'V09.2', 'V12', 'V14']
def __init__(self, causes_of_death=DEFAULT_CAUSES):
self.causes_of_death = causes_of_death

def generate_cutoff_times(self, es):
"""Generates cutoff times for the predection problem.

Args:
es: fhir entityset.

Returns:
entity_set, target_entity, and a dataframe of cutoff_times and target_labels.

Raises:
ValueError: An error occurs if the cutoff variable does not exist.
"""
super().__init__(
'diagnosis', # target_label_column_name
'Encounter', # target_entity
'start', # cutoff_time_label
'Period', # cutoff_entity
'classification', # prediction_type
conn='period'
)

def generate_cutoff_times(self, es):
es = self.generate_target_label(es)

if DataLoader().check_column_existence(
es,
self.cutoff_entity,
self.cutoff_time_label): # check the existance of the cutoff label

generated_cts = self.unify_cutoff_time_admission_time(
es, self.cutoff_entity, self.cutoff_time_label)

es = es.entity_from_dataframe(entity_id=self.cutoff_entity,
dataframe=generated_cts,
index='object_id')
entity_set, target_entity, cutoff_times = super().generate_cutoff_times(es)

cutoff_times = es[self.cutoff_entity].df['ct'].to_frame()
# post-processing step
for (idx, row) in cutoff_times.iterrows():
new_val = row.loc['label'] in self.causes_of_death
cutoff_times.set_value(idx, 'label', new_val)

label = es[self.target_entity].df[self.conn].values
instance_id = list(es[self.target_entity].df.index)
cutoff_times = cutoff_times.reindex(index=label)

cutoff_times = cutoff_times[cutoff_times.index.isin(label)]
cutoff_times['instance_id'] = instance_id
cutoff_times.columns = ['cutoff_time', 'instance_id']

cutoff_times['label'] = list(es[self.target_entity].df[self.target_label_column_name])

for (idx, row) in cutoff_times.iterrows():
new_val = row.loc['label'] in self.causes_of_death
cutoff_times.set_value(idx, 'label', new_val)

return(es, self.target_entity, cutoff_times)
else:
raise ValueError('Cutoff time label {} in table {} does not exist'
.format(self.cutoff_time_label, self.target_entity))
return (entity_set, target_entity, cutoff_times)

def generate_target_label(self, es):
"""Generates target labels in the case of having missing label in the entityset.
Expand Down
Loading