From f045fe8587128243c8b20f18c7e04aaf627d4700 Mon Sep 17 00:00:00 2001 From: PrimozGodec Date: Fri, 10 Feb 2023 17:07:53 +0100 Subject: [PATCH] Ontology - Enable insert in ontology with multiple roots --- orangecontrib/text/ontology.py | 37 +++++++++++++++-------- orangecontrib/text/tests/test_ontology.py | 12 ++++++++ orangecontrib/text/widgets/owontology.py | 2 +- 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/orangecontrib/text/ontology.py b/orangecontrib/text/ontology.py index 3da9a7a5d..d1b8cb9ad 100644 --- a/orangecontrib/text/ontology.py +++ b/orangecontrib/text/ontology.py @@ -250,17 +250,11 @@ def generate( ) return Tree.from_prufer_sequence(ontology, words, root).to_dict(), skipped - def insert( - self, - tree: Dict, - words: List[str], - callback: Callable = dummy_callback - ) -> Tuple[Dict, int]: - tree = Tree.from_dict(tree) - ticks = iter(np.linspace(0.3, 0.9, len(words))) + def insert_in_tree( + self, tree: Tree, words: List[str], callback: Callable + ) -> Tuple[Tree, int]: skipped = 0 - - for word in words: + for iw, word in enumerate(words, start=1): tree.adj_list.append(set()) tree.labels.append(word) embeddings = self.embedder(tree.labels) @@ -282,9 +276,28 @@ def insert( best = np.argmax(scores) tree.adj_list[best].add(idx) tree.adj_list[idx].add(best) - callback(next(ticks)) + callback(iw / len(words)) + return tree, skipped + + __TEMP_ROOT = "" # root different from any word + + def insert( + self, tree: Dict, words: List[str], callback: Callable = dummy_callback + ) -> Tuple[Dict, int]: + dummy_root_used = False + if len(tree) > 1: + # if ontology has multiple roots insert temporary root + tree = {self.__TEMP_ROOT: tree} + dummy_root_used = True + + tree = Tree.from_dict(tree) + tree, skipped = self.insert_in_tree(tree, words, callback) + tree = tree.to_dict() - return tree.to_dict(), skipped + if dummy_root_used: + # if temporary root inserted remove it + tree = tree[self.__TEMP_ROOT] + return tree, skipped def score(self, tree: Dict, callback: Callable = dummy_callback) -> float: if not tree: diff --git a/orangecontrib/text/tests/test_ontology.py b/orangecontrib/text/tests/test_ontology.py index ba02828eb..545def87e 100644 --- a/orangecontrib/text/tests/test_ontology.py +++ b/orangecontrib/text/tests/test_ontology.py @@ -130,6 +130,18 @@ def test_insert(self): ) self.assertEqual(skipped, 0) + @patch( + "httpx.AsyncClient.post", + make_dummy_post(arrays_to_response(RESPONSE2 + RESPONSE2)), + ) + def test_insert_not_tree(self): + """Insert should also work when ontology has multiple roots""" + tree = {"1": {"2": {}}, "4": {}} + new_tree, skipped = self.handler.insert(tree, ["7"]) + # 7 goes under number 1 since it has the same embedding as 1 + self.assertDictEqual(new_tree, {"1": {"2": {}, "7": {}}, "4": {}}) + self.assertEqual(skipped, 0) + @patch('httpx.AsyncClient.post', make_dummy_post(array_to_response(np.zeros(384)))) def test_score(self): tree, skipped = self.handler.generate(['1', '2', '3']) diff --git a/orangecontrib/text/widgets/owontology.py b/orangecontrib/text/widgets/owontology.py index 112d62581..e74d6c5f4 100644 --- a/orangecontrib/text/widgets/owontology.py +++ b/orangecontrib/text/widgets/owontology.py @@ -922,7 +922,7 @@ def _save_state(self): def _enable_include_button(self): tree = self.__ontology_view.get_data() words = self.__get_selected_input_words() - enabled = len(tree) == 1 and len(words) > 0 + enabled = len(tree) >= 1 and len(words) > 0 self.__inc_button.setEnabled(enabled) def send_report(self):