-
Notifications
You must be signed in to change notification settings - Fork 25
/
utils.py
145 lines (106 loc) · 5.04 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Tuple, List, Dict
def get_episode_indices(episodes_string: str) -> List[int]:
    """
    Parse a string such as '2' or '1-5' into a list of integers such as [2] or [1, 2, 3, 4, 5].

    Args:
        episodes_string: Either a single integer ('2') or an inclusive range 'start-end' ('1-5').
            None and the empty string are both accepted and mean "no episodes".

    Returns:
        The list of episode indices (empty if no episodes were specified).

    Raises:
        ValueError: If a part of the string is not an integer, or if the string contains
            more than one '-' separator.
    """
    # Use truthiness rather than `is not ''`: comparing identity against a string literal
    # relies on interning and emits a SyntaxWarning on Python >= 3.8.
    if not episodes_string:
        return []

    bounds = [int(item) for item in episodes_string.split('-')]
    if len(bounds) == 1:
        return bounds
    if len(bounds) != 2:
        raise ValueError(
            f"Expected an episode specification like '2' or '1-5', got {episodes_string!r}"
        )

    first, last = bounds
    # the range is inclusive on both ends
    return list(range(first, last + 1))
def expand_tokens(tokens: List[str], augmentations: List[Tuple[List[tuple], int, int]],
                  entity_tree: Dict[int, List[int]], root: int,
                  begin_entity_token: str, sep_token: str, relation_sep_token: str, end_entity_token: str) \
        -> List[str]:
    """
    Build the augmented-natural-language token list for the span covered by node `root`.

    The children of `root` (as listed in `entity_tree`) are expanded recursively; each child
    is wrapped in begin/end entity tokens and followed by its tags. Plain tokens between
    children are copied through unchanged.
    Helper for the augment_sentence function below (see the documentation there).
    """
    # a negative root stands for the whole sentence; otherwise use the span of that augmentation
    if root >= 0:
        span_start, span_end = augmentations[root][1:]
    else:
        span_start, span_end = 0, len(tokens)

    expanded = []
    cursor = span_start  # first token not yet copied to the output

    for child in entity_tree[root]:
        child_tags, child_start, child_end = augmentations[child]

        # plain tokens between the previous entity (or span start) and this one
        expanded.extend(tokens[cursor:child_start])

        # the entity itself, including anything nested inside it
        expanded.append(begin_entity_token)
        expanded.extend(expand_tokens(
            tokens, augmentations, entity_tree, child,
            begin_entity_token, sep_token, relation_sep_token, end_entity_token,
        ))

        for tag in child_tags:
            # the first element is emitted only when non-empty
            if tag[0]:
                expanded.extend((sep_token, tag[0]))

            # every remaining element is preceded by the relation separator
            for part in tag[1:]:
                expanded.extend((relation_sep_token, part))

        expanded.append(end_entity_token)
        cursor = child_end

    # plain tokens after the last entity, up to the end of this span
    expanded.extend(tokens[cursor:span_end])
    return expanded
def augment_sentence(tokens: List[str], augmentations: List[Tuple[List[tuple], int, int]], begin_entity_token: str,
                     sep_token: str, relation_sep_token: str, end_entity_token: str) -> str:
    """
    Produce the augmented version of a sentence by inserting tags around the given spans.

    Args:
        tokens: Tokens of the sentence to augment.
        augmentations: List of tuples (tags, start, end).
        begin_entity_token: Beginning token for an entity, e.g. '['
        sep_token: Separator token, e.g. '|'
        relation_sep_token: Separator token for relations, e.g. '='
        end_entity_token: End token for an entity e.g. ']'

    Returns:
        The augmented sentence, with all output tokens joined by single spaces.

    An example follows.

    tokens:
    ['Tolkien', 'was', 'born', 'here']

    augmentations:
    [
        ([('person',), ('born in', 'here')], 0, 1),
        ([('location',)], 3, 4),
    ]

    output augmented sentence:
    [ Tolkien | person | born in = here ] was born [ here | location ]
    """
    # order by start position; among equal starts, the longer span comes first
    ordered = sorted(augmentations, key=lambda item: (item[1], -item[2]))

    # Nodes are identified by their position in `ordered`; the artificial root is -1.
    root = -1
    children = {root: []}  # node -> list of child nodes
    open_nodes = [root]    # path from the root to the most recently inserted node

    for j, x in enumerate(ordered):
        _, start, end = x

        # An annotation that starts strictly inside a node on the current path and ends
        # strictly after it would break the tree structure, so it is dropped.
        crosses_boundary = False
        for k in open_nodes:
            if ordered[k][1] < start < ordered[k][2] < end:
                crosses_boundary = True
                break
        if crosses_boundary:
            logging.warning(f'Tree structure is not satisfied! Dropping annotation {x}')
            continue

        # climb towards the root until the node on top of the stack contains this span
        while open_nodes[-1] >= 0:
            top = open_nodes[-1]
            if ordered[top][1] <= start <= end <= ordered[top][2]:
                break
            open_nodes.pop()

        # attach to the parent, register the new (still childless) node, and descend into it
        children[open_nodes[-1]].append(j)
        children[j] = []
        open_nodes.append(j)

    augmented_tokens = expand_tokens(
        tokens, ordered, children, root, begin_entity_token, sep_token, relation_sep_token, end_entity_token
    )
    return ' '.join(augmented_tokens)
def get_span(l: List[str], span: List[int]):
    """
    Join the elements of `l` at positions span[0] (inclusive) to span[1] (exclusive) with spaces.

    Positions that are not smaller than len(l) are skipped. `span` must have exactly two elements.
    """
    assert len(span) == 2

    first = span[0]
    stop = span[1]

    picked = []
    for position in range(first, stop):
        # ignore positions past the end of the list
        if position < len(l):
            picked.append(l[position])

    return " ".join(picked)
def get_precision_recall_f1(num_correct, num_predicted, num_gt):
    """
    Compute precision, recall and F1 from raw counts.

    Args:
        num_correct: Number of correct predictions; must not exceed either of the other counts.
        num_predicted: Number of predictions made.
        num_gt: Number of ground-truth items.

    Returns:
        A tuple (precision, recall, f1). Any quantity whose denominator would be zero is 0.
    """
    assert 0 <= num_correct <= num_predicted
    assert 0 <= num_correct <= num_gt

    precision = 0.
    if num_predicted > 0:
        precision = num_correct / num_predicted

    recall = 0.
    if num_gt > 0:
        recall = num_correct / num_gt

    f1 = 0.
    if num_correct > 0:
        # harmonic mean of precision and recall (both are non-zero here)
        f1 = 2. / (1. / precision + 1. / recall)

    return precision, recall, f1