Skip to content

Commit

Permalink
Ontology - Enable insert in ontology with multiple roots
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Feb 10, 2023
1 parent ad40020 commit e765686
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 32 deletions.
37 changes: 25 additions & 12 deletions orangecontrib/text/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,11 @@ def generate(
)
return Tree.from_prufer_sequence(ontology, words, root).to_dict(), skipped

def insert(
self,
tree: Dict,
words: List[str],
callback: Callable = dummy_callback
) -> Tuple[Dict, int]:
tree = Tree.from_dict(tree)
ticks = iter(np.linspace(0.3, 0.9, len(words)))
def insert_in_tree(
self, tree: Tree, words: List[str], callback: Callable
) -> Tuple[Tree, int]:
skipped = 0

for word in words:
for iw, word in enumerate(words, start=1):
tree.adj_list.append(set())
tree.labels.append(word)
embeddings = self.embedder(tree.labels)
Expand All @@ -282,9 +276,28 @@ def insert(
best = np.argmax(scores)
tree.adj_list[best].add(idx)
tree.adj_list[idx].add(best)
callback(next(ticks))
callback(iw / len(words))
return tree, skipped

__TEMP_ROOT = "<R-O-O-T>" # root different from any word

def insert(
    self, tree: Dict, words: List[str], callback: Callable = dummy_callback
) -> Tuple[Dict, int]:
    """Insert ``words`` into an existing ontology given as a nested dict.

    The ontology may have a single root or several top-level entries.
    When it has several, they are wrapped under one temporary root for
    the duration of the insertion and unwrapped again before returning,
    so the caller always gets back the same shape it passed in.

    Args:
        tree: Ontology as a nested dict.
        words: Words to insert.
        callback: Progress callback, forwarded to ``insert_in_tree``.

    Returns:
        A tuple of the updated ontology (nested dict) and the skipped
        count reported by ``insert_in_tree``.
    """
    # NOTE(review): presumably Tree.from_dict needs a single root, hence
    # the temporary one -- confirm against Tree.
    has_multiple_roots = len(tree) > 1
    if has_multiple_roots:
        tree = {self.__TEMP_ROOT: tree}

    tree = Tree.from_dict(tree)
    tree, skipped = self.insert_in_tree(tree, words, callback)
    tree = tree.to_dict()

    if has_multiple_roots:
        # Drop the temporary root so only the original roots are returned.
        tree = tree[self.__TEMP_ROOT]
    return tree, skipped

def score(self, tree: Dict, callback: Callable = dummy_callback) -> float:
if not tree:
Expand Down
12 changes: 12 additions & 0 deletions orangecontrib/text/tests/test_ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,18 @@ def test_insert(self):
)
self.assertEqual(skipped, 0)

@patch(
    "httpx.AsyncClient.post",
    new=make_dummy_post(arrays_to_response(RESPONSE2 + RESPONSE2)),
)
def test_insert_not_tree(self):
    """Inserting must also succeed for an ontology with several roots."""
    forest = {"1": {"2": {}}, "4": {}}
    # "7" is expected to land under "1": the mocked responses give both
    # words the same embedding.
    expected = {"1": {"2": {}, "7": {}}, "4": {}}

    result, n_skipped = self.handler.insert(forest, ["7"])

    self.assertDictEqual(result, expected)
    self.assertEqual(n_skipped, 0)

@patch('httpx.AsyncClient.post', make_dummy_post(array_to_response(np.zeros(384))))
def test_score(self):
tree, skipped = self.handler.generate(['1', '2', '3'])
Expand Down
2 changes: 1 addition & 1 deletion orangecontrib/text/widgets/owontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,7 +922,7 @@ def _save_state(self):
def _enable_include_button(self):
    """Enable the include button when insertion is possible.

    The button is enabled only if the ontology view holds at least one
    top-level entry and at least one input word is selected.
    """
    tree = self.__ontology_view.get_data()
    words = self.__get_selected_input_words()
    # Any non-empty ontology qualifies, including one with several roots.
    enabled = len(tree) >= 1 and len(words) > 0
    self.__inc_button.setEnabled(enabled)

def send_report(self):
Expand Down
39 changes: 20 additions & 19 deletions orangecontrib/text/widgets/tests/test_owontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from orangecontrib.text.widgets.utils.words import create_words_table


SBERT_PATCH_METHOD = "orangecontrib.text.ontology.SBERT.__call__"


class TestUtils(unittest.TestCase):
def test_tree_to_html(self):
tree = {"foo": {"bar": {},
Expand Down Expand Up @@ -133,6 +136,7 @@ def test_set_data_with_selection(self):


class TestOWOntology(WidgetTest):
@patch(SBERT_PATCH_METHOD, Mock(return_value=[np.ones(300)] * 3))
def setUp(self):
self._ontology_1 = {"foo1": {"bar1": {}, "baz1": {}}}
self._ontology_2 = {"foo2": {"bar2": {}, "baz2": {}}}
Expand All @@ -146,7 +150,8 @@ def setUp(self):
}
self.widget = self.create_widget(OWOntology, stored_settings=settings)

def test_input_words(self):
@patch(SBERT_PATCH_METHOD, return_value=[np.ones(300)] * 3)
def test_input_words(self, _):
get_ontology_data = self.widget._OWOntology__ontology_view.get_data

words = create_words_table(["foo"])
Expand Down Expand Up @@ -188,14 +193,16 @@ def select_words(indices):
output = self.get_output(self.widget.Outputs.words)
self.assert_table_equal(words, output)

@patch(SBERT_PATCH_METHOD, return_value=[np.ones(300)] * 3)
def test_library_sel_changed(self, _):
    """Selecting another library row shows that row's ontology."""
    get_ontology_data = self.widget._OWOntology__ontology_view.get_data
    # The first library entry is shown initially.
    self.assertEqual(get_ontology_data(), self._ontology_1)
    self.widget._OWOntology__set_selected_row(1)
    self.assertEqual(self.widget._OWOntology__get_selected_row(), 1)
    self.assertEqual(get_ontology_data(), self._ontology_2)

def test_library_add(self):
@patch(SBERT_PATCH_METHOD, return_value=[np.ones(300)] * 3)
def test_library_add(self, _):
get_ontology_data = self.widget._OWOntology__ontology_view.get_data

self.widget._OWOntology__on_add()
Expand All @@ -205,7 +212,8 @@ def test_library_add(self):
self.widget._OWOntology__set_selected_row(1)
self.assertEqual(get_ontology_data(), self._ontology_2)

def test_library_remove(self):
@patch(SBERT_PATCH_METHOD, return_value=[np.ones(300)] * 3)
def test_library_remove(self, _):
get_ontology_data = self.widget._OWOntology__ontology_view.get_data

self.widget._OWOntology__on_remove()
Expand All @@ -218,7 +226,8 @@ def test_library_remove(self):
self.assertEqual(get_ontology_data(), self._ontology_2)
self.assertIsNone(self.widget._OWOntology__get_selected_row())

def test_library_update(self):
@patch(SBERT_PATCH_METHOD, return_value=[np.ones(300)] * 3)
def test_library_update(self, _):
self.assertEqual(self.widget._OWOntology__get_selected_row(), 0)
model = self.widget._OWOntology__ontology_view._EditableTreeView__model
model.setData(model.index(0, 0), "foo3", role=Qt.EditRole)
Expand All @@ -232,7 +241,8 @@ def test_library_update(self):
self.assertEqual(settings["ontology_library"][0]["ontology"],
{"foo3": {"bar1": {}, "baz1": {}}})

def test_library_import(self):
@patch(SBERT_PATCH_METHOD, return_value=[np.ones(300)] * 3)
def test_library_import(self, _):
ontology = {"foo3": {"bar3": {}, "baz3": {}}}
get_ontology_data = self.widget._OWOntology__ontology_view.get_data

Expand Down Expand Up @@ -271,10 +281,7 @@ def test_skipped_words_generate(self):
self.assertDictEqual(get_ontology_data(), {"foo1": {"bar1": {}, "baz1": {}}})

# generate with embedding error - two skipped
with patch(
"orangecontrib.text.vectorization.sbert.SBERT.__call__",
return_value=[np.ones(300), None, None],
):
with patch(SBERT_PATCH_METHOD, return_value=[np.ones(300), None, None]):
self.widget._OWOntology__run_button.click()
self.wait_until_finished()
self.assertDictEqual(get_ontology_data(), {"foo1": {}})
Expand All @@ -285,10 +292,7 @@ def test_skipped_words_generate(self):
)

# generate without embedding error
with patch(
"orangecontrib.text.vectorization.sbert.SBERT.__call__",
return_value=[np.ones(300)],
):
with patch(SBERT_PATCH_METHOD, return_value=[np.ones(300)]):
self.widget._OWOntology__run_button.click()
self.wait_until_finished()
self.assertDictEqual(get_ontology_data(), {"foo1": {}})
Expand All @@ -304,7 +308,7 @@ def test_skipped_words_insert(self):

# insert with an embedding error
with patch(
"orangecontrib.text.vectorization.sbert.SBERT.__call__",
SBERT_PATCH_METHOD,
side_effect=[
[np.ones(300), np.ones(300), np.ones(300), None],
[np.ones(300), np.ones(300), np.ones(300)],
Expand All @@ -330,10 +334,7 @@ def test_skipped_words_insert(self):
)

# insert without embedding error
with patch(
"orangecontrib.text.vectorization.sbert.SBERT.__call__",
return_value=[np.ones(300)] * 4,
):
with patch(SBERT_PATCH_METHOD, return_value=[np.ones(300)] * 4):
self.widget._OWOntology__inc_button.click()
self.wait_until_finished()
self.assertDictEqual(
Expand Down

0 comments on commit e765686

Please sign in to comment.