Skip to content

Commit

Permalink
Ontology - Enable insert in ontology with multiple roots
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Feb 14, 2023
1 parent cd3e804 commit f045fe8
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 13 deletions.
37 changes: 25 additions & 12 deletions orangecontrib/text/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,11 @@ def generate(
)
return Tree.from_prufer_sequence(ontology, words, root).to_dict(), skipped

def insert(
self,
tree: Dict,
words: List[str],
callback: Callable = dummy_callback
) -> Tuple[Dict, int]:
tree = Tree.from_dict(tree)
ticks = iter(np.linspace(0.3, 0.9, len(words)))
def insert_in_tree(
self, tree: Tree, words: List[str], callback: Callable
) -> Tuple[Tree, int]:
skipped = 0

for word in words:
for iw, word in enumerate(words, start=1):
tree.adj_list.append(set())
tree.labels.append(word)
embeddings = self.embedder(tree.labels)
Expand All @@ -282,9 +276,28 @@ def insert(
best = np.argmax(scores)
tree.adj_list[best].add(idx)
tree.adj_list[idx].add(best)
callback(next(ticks))
callback(iw / len(words))
return tree, skipped

__TEMP_ROOT = "<R-O-O-T>" # root different from any word

def insert(
self, tree: Dict, words: List[str], callback: Callable = dummy_callback
) -> Tuple[Dict, int]:
dummy_root_used = False
if len(tree) > 1:
# if ontology has multiple roots insert temporary root
tree = {self.__TEMP_ROOT: tree}
dummy_root_used = True

tree = Tree.from_dict(tree)
tree, skipped = self.insert_in_tree(tree, words, callback)
tree = tree.to_dict()

return tree.to_dict(), skipped
if dummy_root_used:
# if temporary root inserted remove it
tree = tree[self.__TEMP_ROOT]
return tree, skipped

def score(self, tree: Dict, callback: Callable = dummy_callback) -> float:
if not tree:
Expand Down
12 changes: 12 additions & 0 deletions orangecontrib/text/tests/test_ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,18 @@ def test_insert(self):
)
self.assertEqual(skipped, 0)

@patch(
"httpx.AsyncClient.post",
make_dummy_post(arrays_to_response(RESPONSE2 + RESPONSE2)),
)
def test_insert_not_tree(self):
"""Insert should also work when ontology has multiple roots"""
tree = {"1": {"2": {}}, "4": {}}
new_tree, skipped = self.handler.insert(tree, ["7"])
# 7 goes under number 1 since it has the same embedding as 1
self.assertDictEqual(new_tree, {"1": {"2": {}, "7": {}}, "4": {}})
self.assertEqual(skipped, 0)

@patch('httpx.AsyncClient.post', make_dummy_post(array_to_response(np.zeros(384))))
def test_score(self):
tree, skipped = self.handler.generate(['1', '2', '3'])
Expand Down
2 changes: 1 addition & 1 deletion orangecontrib/text/widgets/owontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ def _save_state(self):
def _enable_include_button(self):
tree = self.__ontology_view.get_data()
words = self.__get_selected_input_words()
enabled = len(tree) == 1 and len(words) > 0
enabled = len(tree) >= 1 and len(words) > 0
self.__inc_button.setEnabled(enabled)

def send_report(self):
Expand Down

0 comments on commit f045fe8

Please sign in to comment.