# main.py (forked from cltl/aproof-icf-classifier)
"""
The script generates predictions of the level of functioning that is described in a clinical note in Dutch. The predictions are made for 9 WHO-ICF domains: 'ADM', 'ATT', 'BER', 'ENR', 'ETN', 'FAC', 'INS', 'MBW', 'STM'.
The script can be customized with the following parameters:
--in_csv: path to input csv file
--text_col: name of the column containing the text
To change the default value of a parameter, pass it on the command line, e.g.:
$ python main.py --in_csv myfile.csv --text_col notitie_tekst
"""
import spacy
import argparse
import warnings
import pandas as pd
from pathlib import Path
from shutil import ReadError
from src.text_processing import anonymize
from src.icf_classifiers import predict_domains, predict_levels


def add_level_predictions(
sents,
domains,
):
"""
For each domain, select the sentences in `sents` that were predicted as discussing this domain. Apply the relevant levels regression model to get level predictions and join them back to `sents`.
Parameters
----------
sents: pd DataFrame
df with sentences and `predictions` of the domains classifier
domains: list
list of all the domains, in the order in which they appear in the multi-label
Returns
-------
sents: pd DataFrame
the input df with additional columns containing levels predictions
"""
for i, dom in enumerate(domains):
boolean = sents['predictions'].apply(lambda x: bool(x[i]))
results = sents[boolean]
if results.empty:
print(f'There are no sentences for which {dom} was predicted.')
else:
print(f'Generating levels predictions for {dom}.')
lvl_model = f'CLTL/icf-levels-{dom.lower()}'
predictions = predict_levels(results['text'], 'roberta', lvl_model).rename(f"{dom}_lvl")
sents = sents.join(predictions)
return sents


def main(
in_csv,
text_col,
):
"""
Read the `in_csv` file, process the text by row (anonymize, split to sentences), predict domains and levels per sentence, aggregate the results back to note-level, write the results to the output file.
Parameters
----------
in_csv: str
path to csv file with the text to process; the csv must follow the following specs: sep=';', quotechar='"', encoding='utf-8', first row is the header
text_col: str
name of the column containing the text
Returns
-------
None
"""
    domains = ['ADM', 'ATT', 'BER', 'ENR', 'ETN', 'FAC', 'INS', 'MBW', 'STM']
levels = [f"{domain}_lvl" for domain in domains]
# check path
in_csv = Path(in_csv)
msg = f'The csv file cannot be found in this location: "{in_csv}"'
assert in_csv.exists(), msg
# read csv
print(f'Loading input csv file: {in_csv}')
    try:
        df = pd.read_csv(
            in_csv,
            sep=';',
            header=0,
            quotechar='"',
            encoding='utf-8',
            low_memory=False,
        )
        print(f'Input csv file ({in_csv}) is successfully loaded!')
    except Exception as e:
        raise ReadError('The input csv file cannot be read. Please check that it conforms with the required specifications (separator, header, quotechar, encoding).') from e
if len(df) > 3000:
        warnings.warn('The csv file contains more than 3,000 rows. This is not recommended, as it might cause problems when generating predictions; consider splitting it into several smaller files.')
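        # A minimal, illustrative way to split a large input file (the chunk
        # size of 1,000 rows is an arbitrary choice):
        #   reader = pd.read_csv(in_csv, sep=';', quotechar='"', encoding='utf-8', chunksize=1000)
        #   for i, chunk in enumerate(reader):
        #       chunk.to_csv(f'{in_csv.stem}_part{i}.csv', sep=';', index=False)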
# anonymize
print(f'Anonymizing the text in "{text_col}" column. This might take a while.')
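    # requires the Dutch spaCy model; if missing, install it with: python -m spacy download nl_core_news_lg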
nlp = spacy.load('nl_core_news_lg')
anonym_notes = df[text_col].apply(lambda i: anonymize(i, nlp)).rename('anonym_text')
# split to sentences
print(f'Splitting the text in "{text_col}" column to sentences. This might take a while.')
to_sentence = lambda txt: [str(sent) for sent in list(nlp(txt).sents)]
sents = anonym_notes.apply(to_sentence).explode().rename('text').reset_index().rename(columns={'index': 'note_index'})
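    # `explode` turns each note's list of sentences into one row per sentence;
    # `note_index` records which note a sentence came from, so the sentence-level
    # predictions can be aggregated back to note level further down.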
# predict domains
print('Generating domains predictions. This might take a while.')
sents['predictions'] = predict_domains(
sents['text'],
'roberta',
'CLTL/icf-domains',
)
# predict levels
print('Processing domains predictions.')
sents = add_level_predictions(sents, domains)
    # aggregate to note-level; only aggregate the level columns that were
    # actually produced (a column is missing when no sentence was predicted
    # for the corresponding domain)
    existing_levels = [lvl for lvl in levels if lvl in sents.columns]
    note_predictions = sents.groupby('note_index')[existing_levels].mean()
df = df.merge(
note_predictions,
how='left',
left_index=True,
right_index=True,
)
# save output file
out_csv = in_csv.parent / (in_csv.stem + '_output.csv')
df.to_csv(out_csv, sep='\t', index=False)
print(f'The output file is saved: {out_csv}')


if __name__ == '__main__':
argparser = argparse.ArgumentParser()
argparser.add_argument('--in_csv', default='./example/input.csv')
argparser.add_argument('--text_col', default='text')
args = argparser.parse_args()
main(
args.in_csv,
args.text_col,
)