From f045fe8587128243c8b20f18c7e04aaf627d4700 Mon Sep 17 00:00:00 2001
From: PrimozGodec
Date: Fri, 10 Feb 2023 17:07:53 +0100
Subject: [PATCH] Ontology - Enable insert in ontology with multiple roots
---
orangecontrib/text/ontology.py | 37 +++++++++++++++--------
orangecontrib/text/tests/test_ontology.py | 12 ++++++++
orangecontrib/text/widgets/owontology.py | 2 +-
3 files changed, 38 insertions(+), 13 deletions(-)
diff --git a/orangecontrib/text/ontology.py b/orangecontrib/text/ontology.py
index 3da9a7a5d..d1b8cb9ad 100644
--- a/orangecontrib/text/ontology.py
+++ b/orangecontrib/text/ontology.py
@@ -250,17 +250,11 @@ def generate(
)
return Tree.from_prufer_sequence(ontology, words, root).to_dict(), skipped
- def insert(
- self,
- tree: Dict,
- words: List[str],
- callback: Callable = dummy_callback
- ) -> Tuple[Dict, int]:
- tree = Tree.from_dict(tree)
- ticks = iter(np.linspace(0.3, 0.9, len(words)))
+ def insert_in_tree(
+ self, tree: Tree, words: List[str], callback: Callable
+ ) -> Tuple[Tree, int]:
skipped = 0
-
- for word in words:
+ for iw, word in enumerate(words, start=1):
tree.adj_list.append(set())
tree.labels.append(word)
embeddings = self.embedder(tree.labels)
@@ -282,9 +276,28 @@ def insert(
best = np.argmax(scores)
tree.adj_list[best].add(idx)
tree.adj_list[idx].add(best)
- callback(next(ticks))
+ callback(iw / len(words))
+ return tree, skipped
+
+ __TEMP_ROOT = "" # root different from any word
+
+ def insert(
+ self, tree: Dict, words: List[str], callback: Callable = dummy_callback
+ ) -> Tuple[Dict, int]:
+ dummy_root_used = False
+ if len(tree) > 1:
+ # if ontology has multiple roots insert temporary root
+ tree = {self.__TEMP_ROOT: tree}
+ dummy_root_used = True
+
+ tree = Tree.from_dict(tree)
+ tree, skipped = self.insert_in_tree(tree, words, callback)
+ tree = tree.to_dict()
- return tree.to_dict(), skipped
+ if dummy_root_used:
+ # if temporary root inserted remove it
+ tree = tree[self.__TEMP_ROOT]
+ return tree, skipped
def score(self, tree: Dict, callback: Callable = dummy_callback) -> float:
if not tree:
diff --git a/orangecontrib/text/tests/test_ontology.py b/orangecontrib/text/tests/test_ontology.py
index ba02828eb..545def87e 100644
--- a/orangecontrib/text/tests/test_ontology.py
+++ b/orangecontrib/text/tests/test_ontology.py
@@ -130,6 +130,18 @@ def test_insert(self):
)
self.assertEqual(skipped, 0)
+ @patch(
+ "httpx.AsyncClient.post",
+ make_dummy_post(arrays_to_response(RESPONSE2 + RESPONSE2)),
+ )
+ def test_insert_not_tree(self):
+ """Insert should also work when ontology has multiple roots"""
+ tree = {"1": {"2": {}}, "4": {}}
+ new_tree, skipped = self.handler.insert(tree, ["7"])
+ # 7 goes under number 1 since it has the same embedding as 1
+ self.assertDictEqual(new_tree, {"1": {"2": {}, "7": {}}, "4": {}})
+ self.assertEqual(skipped, 0)
+
@patch('httpx.AsyncClient.post', make_dummy_post(array_to_response(np.zeros(384))))
def test_score(self):
tree, skipped = self.handler.generate(['1', '2', '3'])
diff --git a/orangecontrib/text/widgets/owontology.py b/orangecontrib/text/widgets/owontology.py
index 112d62581..e74d6c5f4 100644
--- a/orangecontrib/text/widgets/owontology.py
+++ b/orangecontrib/text/widgets/owontology.py
@@ -922,7 +922,7 @@ def _save_state(self):
def _enable_include_button(self):
tree = self.__ontology_view.get_data()
words = self.__get_selected_input_words()
- enabled = len(tree) == 1 and len(words) > 0
+ enabled = len(tree) >= 1 and len(words) > 0
self.__inc_button.setEnabled(enabled)
def send_report(self):