From 4cc521f7bb8fb448e5f673b55c3041b7594e0524 Mon Sep 17 00:00:00 2001
From: PrimozGodec
Date: Tue, 30 Aug 2022 11:17:01 +0200
Subject: [PATCH] Ontology - remove cache
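
Drop the on-disk EmbeddingStorage cache (pickled similarities plus per-word
.npy files under Orange's cache dir), the hand-rolled cos_sim helper and the
EMB_DIM constant from orangecontrib/text/ontology.py. OntologyHandler now
requests embeddings directly from SBERT, which keeps its own cache, and
computes the whole pairwise matrix in one call to sklearn's cosine_similarity.
The EmbeddingStorage tests are dropped and the handler tests mock
httpx.AsyncClient.post with embedding responses instead of pre-seeding the
storage; owontology.py additionally gets two small Qt cleanups (no int() cast
on editTriggers, corrected startDrag annotation).

A minimal standalone sketch of the new similarity path (the word list is
illustrative, the in-handler call also passes a progress callback, and the
embedder needs a reachable SBERT server and may return None entries for words
it fails to embed):

    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity
    from orangecontrib.text.vectorization.sbert import SBERT

    words = ["foo", "bar", "baz"]                      # illustrative input
    embeddings = np.array(SBERT()(words))              # one embedding vector per word
    sims = cosine_similarity(embeddings, embeddings)   # (n, n) cosine similarity matrix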
---
orangecontrib/text/ontology.py | 129 ++--------------------
orangecontrib/text/tests/test_ontology.py | 82 +++-----------
orangecontrib/text/widgets/owontology.py | 6 +-
3 files changed, 33 insertions(+), 184 deletions(-)
diff --git a/orangecontrib/text/ontology.py b/orangecontrib/text/ontology.py
index f734e3216..56f9df517 100644
--- a/orangecontrib/text/ontology.py
+++ b/orangecontrib/text/ontology.py
@@ -1,17 +1,13 @@
from typing import List, Set, Dict, Tuple, Optional, Callable
from collections import Counter
from itertools import chain
-import os
-import pickle
import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
from orangecontrib.text.vectorization.sbert import SBERT
-from Orange.misc.environ import cache_dir
from Orange.util import dummy_callback, wrap_callback
-EMB_DIM = 384
-
class Tree:
@@ -222,56 +218,9 @@ def generate_ontology(
return generation[best, :], roots[best]
-def cos_sim(x: np.array, y: np.array) -> float:
- dot = np.dot(x, y)
-    return 0 if np.allclose(dot, 0) else dot / (np.linalg.norm(x) * np.linalg.norm(y))
-
-
-class EmbeddingStorage:
-
- def __init__(self):
- self.cache_dir = os.path.join(cache_dir(), 'ontology')
- if not os.path.isdir(self.cache_dir):
- os.makedirs(self.cache_dir)
- self.similarities = dict()
- try:
- with open(os.path.join(self.cache_dir, 'sims.pkl'), 'rb') as file:
- self.similarities = pickle.load(file)
- except IOError:
- self.similarities = dict()
- self.embeddings = dict()
-
- def save_similarities(self):
- with open(os.path.join(self.cache_dir, 'sims.pkl'), 'wb') as file:
- pickle.dump(self.similarities, file)
-
- def get_embedding(self, word: str) -> Optional[np.array]:
- if word in self.embeddings:
- return self.embeddings[word]
- try:
- emb = np.load(os.path.join(self.cache_dir, f'{word}.npy'))
- self.embeddings[word] = emb
- return emb
- except IOError:
- return None
-
- def save_embedding(self, word: str, emb: np.array) -> None:
- self.embeddings[word] = emb
- np.save(os.path.join(self.cache_dir, f'{word}.npy'), emb)
-
- def clear_storage(self) -> None:
- self.similarities = dict()
- self.embeddings = dict()
- if os.path.isdir(self.cache_dir):
- for file in os.listdir(self.cache_dir):
- os.remove(os.path.join(self.cache_dir, file))
-
-
class OntologyHandler:
-
def __init__(self):
self.embedder = SBERT()
- self.storage = EmbeddingStorage()
def generate(
self,
@@ -284,11 +233,9 @@ def generate(
return {words[0]: {}}
if len(words) == 2:
return {sorted(words)[0]: {sorted(words)[1]: {}}}
- sims = self._get_similarities(
- words,
- self._get_embeddings(words, wrap_callback(callback, end=0.1)),
- wrap_callback(callback, start=0.1, end=0.2)
- )
+ embeddings = self.embedder(words, wrap_callback(callback, end=0.1))
+ sims = self._get_similarities(embeddings)
+ callback(0.2)
if len(words) == 3:
root = np.argmin(np.sum(sims, axis=1))
rest = sorted([words[i] for i in range(3) if i != root])
@@ -307,18 +254,13 @@ def insert(
callback: Callable = dummy_callback
) -> Dict:
tree = Tree.from_dict(tree)
- self._get_embeddings(words, wrap_callback(callback, end=0.3))
ticks = iter(np.linspace(0.3, 0.9, len(words)))
for word in words:
- tick = next(ticks)
tree.adj_list.append(set())
tree.labels.append(word)
- sims = self._get_similarities(
- tree.labels,
- self._get_embeddings(tree.labels, lambda x: callback(tick)),
- lambda x: callback(tick)
- )
+ embeddings = self.embedder(tree.labels)
+ sims = self._get_similarities(embeddings)
idx = len(tree.adj_list) - 1
fitness_function = FitnessFunction(tree.labels, sims).fitness
scores = list()
@@ -331,65 +273,18 @@ def insert(
best = np.argmax(scores)
tree.adj_list[best].add(idx)
tree.adj_list[idx].add(best)
- callback(tick)
+ callback(next(ticks))
return tree.to_dict()
def score(self, tree: Dict, callback: Callable = dummy_callback) -> float:
tree = Tree.from_dict(tree)
- sims = self._get_similarities(
- tree.labels,
- self._get_embeddings(tree.labels, wrap_callback(callback, end=0.7)),
- wrap_callback(callback, start=0.7, end=0.8)
- )
+ embeddings = self.embedder(tree.labels, wrap_callback(callback, end=0.7))
+ sims = self._get_similarities(embeddings)
callback(0.9)
fitness_function = FitnessFunction(tree.labels, sims).fitness
return fitness_function(tree, tree.root)[0]
- def _get_embeddings(
- self,
- words: List[str],
- callback: Callable = dummy_callback
- ) -> np.array:
- embeddings = np.zeros((len(words), EMB_DIM))
- missing, missing_idx = list(), list()
- ticks = iter(np.linspace(0.0, 0.6, len(words)))
- for i, word in enumerate(words):
- callback(next(ticks))
- emb = self.storage.get_embedding(word)
- if emb is None:
- missing.append(word)
- missing_idx.append(i)
- else:
- embeddings[i, :] = emb
- if len(missing_idx) > 0:
- embs = self.embedder(missing, callback=wrap_callback(callback, start=0.6, end=0.9))
- if None in embs:
- raise RuntimeError("Couldn't obtain embeddings.")
- embeddings[missing_idx, :] = np.array(embs)
- for i in missing_idx:
- self.storage.save_embedding(words[i], embeddings[i, :])
-
- return embeddings
-
- def _get_similarities(
- self,
- words: List[str],
- embeddings: np.array,
- callback: Callable = dummy_callback
- ) -> np.array:
- sims = np.zeros((len(words), len(words)))
- ticks = iter(np.linspace(0.0, 1.0, int(len(words) * (len(words) - 1) / 2)))
- for i in range(len(words)):
- for j in range(i + 1, len(words)):
- callback(next(ticks))
- key = tuple(sorted((words[i], words[j])))
- try:
- sim = self.storage.similarities[key]
- except KeyError:
- sim = cos_sim(embeddings[i, :], embeddings[j, :])
- self.storage.similarities[key] = sim
- sims[i, j] = sim
- sims[j, i] = sim
- self.storage.save_similarities()
- return sims
+ @staticmethod
+ def _get_similarities(embeddings: np.array) -> np.array:
+ return cosine_similarity(embeddings, embeddings)
diff --git a/orangecontrib/text/tests/test_ontology.py b/orangecontrib/text/tests/test_ontology.py
index cd0251377..245b65cf5 100644
--- a/orangecontrib/text/tests/test_ontology.py
+++ b/orangecontrib/text/tests/test_ontology.py
@@ -1,22 +1,32 @@
import unittest
+from typing import List, Union
from unittest.mock import patch
from collections.abc import Iterator
-import os
import asyncio
import numpy as np
-from orangecontrib.text.ontology import Tree, EmbeddingStorage, OntologyHandler, EMB_DIM
+from orangecontrib.text.ontology import Tree, OntologyHandler
+EMB_DIM = 384
RESPONSE = [
f'{{ "embedding": {[i] * EMB_DIM} }}'.encode()
for i in range(4)
]
+RESPONSE2 = [np.zeros(384), np.ones(384), np.zeros(384), np.ones(384)*2]
+RESPONSE3 = [np.zeros(384), np.ones(384), np.arange(384), np.ones(384)*2]
-class DummyResponse:
+def arrays_to_response(array: List[Union[np.ndarray, List]]) -> Iterator[bytes]:
+ return iter(array_to_response(a) for a in array)
+
+
+def array_to_response(array: Union[np.ndarray, List]) -> bytes:
+ return f'{{ "embedding": {array.tolist()} }}'.encode()
+
+class DummyResponse:
def __init__(self, content):
self.content = content
@@ -72,54 +82,11 @@ def test_assertion_errors(self):
Tree.from_prufer_sequence([1, 0, 3], list(map(str, range(4))))
-class TestEmbeddingStorage(unittest.TestCase):
-
- def setUp(self):
- self.storage = EmbeddingStorage()
-
- def tearDown(self):
- self.storage.clear_storage()
-
- def test_clear_storage(self):
- self.storage.save_embedding("testword", np.zeros(3))
- self.assertEqual(len(self.storage.embeddings), 1)
- self.storage.clear_storage()
- self.assertEqual(len(self.storage.embeddings), 0)
- self.assertEqual(len(os.listdir(self.storage.cache_dir)), 0)
-
- def test_save_embedding(self):
- self.storage.save_embedding("testword", np.zeros(3))
- self.storage.save_embedding("testword2", np.zeros(3))
- self.assertEqual(len(self.storage.embeddings), 2)
- self.assertEqual(len(os.listdir(self.storage.cache_dir)), 2)
-
- def test_get_embedding(self):
- self.storage.save_embedding("testword", np.arange(3))
- emb = self.storage.get_embedding("testword")
- self.assertEqual(emb.tolist(), [0, 1, 2])
-
- def test_get_from_cache(self):
- self.storage.save_embedding("testword", np.arange(3))
- self.storage.embeddings = dict()
- emb = self.storage.get_embedding("testword")
- self.assertEqual(emb.tolist(), [0, 1, 2])
-
- def test_similarities(self):
- self.storage.similarities['a', 'b'] = 0.75
- self.storage.save_similarities()
- storage = EmbeddingStorage()
- self.assertEqual(len(storage.similarities), 1)
- self.assertTrue(('a', 'b') in storage.similarities)
- self.assertEqual(storage.similarities['a', 'b'], 0.75)
-
-
class TestOntologyHandler(unittest.TestCase):
-
def setUp(self):
self.handler = OntologyHandler()
def tearDown(self):
- self.handler.storage.clear_storage()
self.handler.embedder.clear_cache()
@patch('orangecontrib.text.ontology.generate_ontology')
@@ -128,34 +95,23 @@ def test_small_trees(self, mock):
self.handler.generate(words)
mock.assert_not_called()
+ @patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE3)))
def test_generate_small(self):
- self.handler.storage.save_embedding('1', np.zeros(384))
- self.handler.storage.save_embedding('2', np.ones(384))
- self.handler.storage.save_embedding('3', np.arange(384))
tree = self.handler.generate(['1', '2', '3'])
self.assertTrue(isinstance(tree, dict))
- @patch('httpx.AsyncClient.post')
- def test_generate(self, mock):
- self.handler.storage.save_embedding('1', np.zeros(384))
- self.handler.storage.save_embedding('2', np.ones(384))
- self.handler.storage.save_embedding('3', np.arange(384))
- self.handler.storage.save_embedding('4', np.ones(384) * 2)
+ @patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE3)))
+ def test_generate(self):
tree = self.handler.generate(['1', '2', '3', '4'])
self.assertTrue(isinstance(tree, dict))
- mock.request.assert_not_called()
- mock.get_response.assert_not_called()
@patch('httpx.AsyncClient.post', make_dummy_post(iter(RESPONSE)))
def test_generate_with_unknown_embeddings(self):
tree = self.handler.generate(['1', '2', '3', '4'])
self.assertTrue(isinstance(tree, dict))
+ @patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE2)))
def test_insert(self):
- self.handler.storage.save_embedding('1', np.zeros(384))
- self.handler.storage.save_embedding('2', np.ones(384))
- self.handler.storage.save_embedding('3', np.arange(384))
- self.handler.storage.save_embedding('4', np.ones(384) * 2)
tree = self.handler.generate(['1', '2', '3'])
new_tree = self.handler.insert(tree, ['4'])
         self.assertGreater(
             len(Tree.from_dict(new_tree).adj_list),
@@ -163,10 +119,8 @@ def test_insert(self):
len(Tree.from_dict(tree).adj_list)
)
+ @patch('httpx.AsyncClient.post', make_dummy_post(array_to_response(np.zeros(384))))
def test_score(self):
- self.handler.storage.save_embedding('1', np.zeros(384))
- self.handler.storage.save_embedding('2', np.ones(384))
- self.handler.storage.save_embedding('3', np.arange(384))
tree = self.handler.generate(['1', '2', '3'])
score = self.handler.score(tree)
self.assertGreater(score, 0)
diff --git a/orangecontrib/text/widgets/owontology.py b/orangecontrib/text/widgets/owontology.py
index b306befb2..38dbb14ef 100644
--- a/orangecontrib/text/widgets/owontology.py
+++ b/orangecontrib/text/widgets/owontology.py
@@ -153,7 +153,7 @@ def __init__(self, data_changed_cb: Callable):
edit_triggers = QTreeView.DoubleClicked | QTreeView.EditKeyPressed
super().__init__(
- editTriggers=int(edit_triggers),
+ editTriggers=edit_triggers,
selectionMode=QTreeView.ExtendedSelection,
dragEnabled=True,
acceptDrops=True,
@@ -165,7 +165,7 @@ def __init__(self, data_changed_cb: Callable):
self.__disconnected = False
- def startDrag(self, actions: Qt.DropActions):
+ def startDrag(self, actions: Qt.DropAction):
with disconnected(self.model().dataChanged, self.__data_changed_cb):
super().startDrag(actions)
self.drop_finished.emit()
@@ -626,7 +626,7 @@ def _setup_gui(self):
edit_triggers = QListView.DoubleClicked | QListView.EditKeyPressed
self.__library_view = QListView(
- editTriggers=int(edit_triggers),
+ editTriggers=edit_triggers,
minimumWidth=200,
sizePolicy=QSizePolicy(QSizePolicy.Ignored, QSizePolicy.Expanding),
)