From 4cc521f7bb8fb448e5f673b55c3041b7594e0524 Mon Sep 17 00:00:00 2001
From: PrimozGodec
Date: Tue, 30 Aug 2022 11:17:01 +0200
Subject: [PATCH] Ontology - remove cache

---
 orangecontrib/text/ontology.py            | 129 ++--------------------
 orangecontrib/text/tests/test_ontology.py |  82 +++-----------
 orangecontrib/text/widgets/owontology.py  |   6 +-
 3 files changed, 33 insertions(+), 184 deletions(-)

diff --git a/orangecontrib/text/ontology.py b/orangecontrib/text/ontology.py
index f734e3216..56f9df517 100644
--- a/orangecontrib/text/ontology.py
+++ b/orangecontrib/text/ontology.py
@@ -1,17 +1,13 @@
 from typing import List, Set, Dict, Tuple, Optional, Callable
 from collections import Counter
 from itertools import chain
-import os
-import pickle
 
 import numpy as np
+from sklearn.metrics.pairwise import cosine_similarity
 
 from orangecontrib.text.vectorization.sbert import SBERT
-from Orange.misc.environ import cache_dir
 from Orange.util import dummy_callback, wrap_callback
 
-EMB_DIM = 384
-
 
 class Tree:
@@ -222,56 +218,9 @@ def generate_ontology(
     return generation[best, :], roots[best]
 
 
-def cos_sim(x: np.array, y: np.array) -> float:
-    dot = np.dot(x, y)
-    return 0 if np.allclose(dot, 0) else dot / (np.linalg.norm(x) * np.linalg.norm(y))
-
-
-class EmbeddingStorage:
-
-    def __init__(self):
-        self.cache_dir = os.path.join(cache_dir(), 'ontology')
-        if not os.path.isdir(self.cache_dir):
-            os.makedirs(self.cache_dir)
-        self.similarities = dict()
-        try:
-            with open(os.path.join(self.cache_dir, 'sims.pkl'), 'rb') as file:
-                self.similarities = pickle.load(file)
-        except IOError:
-            self.similarities = dict()
-        self.embeddings = dict()
-
-    def save_similarities(self):
-        with open(os.path.join(self.cache_dir, 'sims.pkl'), 'wb') as file:
-            pickle.dump(self.similarities, file)
-
-    def get_embedding(self, word: str) -> Optional[np.array]:
-        if word in self.embeddings:
-            return self.embeddings[word]
-        try:
-            emb = np.load(os.path.join(self.cache_dir, f'{word}.npy'))
-            self.embeddings[word] = emb
-            return emb
-        except IOError:
-            return None
-
-    def save_embedding(self, word: str, emb: np.array) -> None:
-        self.embeddings[word] = emb
-        np.save(os.path.join(self.cache_dir, f'{word}.npy'), emb)
-
-    def clear_storage(self) -> None:
-        self.similarities = dict()
-        self.embeddings = dict()
-        if os.path.isdir(self.cache_dir):
-            for file in os.listdir(self.cache_dir):
-                os.remove(os.path.join(self.cache_dir, file))
-
-
 class OntologyHandler:
-
     def __init__(self):
         self.embedder = SBERT()
-        self.storage = EmbeddingStorage()
 
     def generate(
         self,
@@ -284,11 +233,9 @@
             return {words[0]: {}}
         if len(words) == 2:
             return {sorted(words)[0]: {sorted(words)[1]: {}}}
-        sims = self._get_similarities(
-            words,
-            self._get_embeddings(words, wrap_callback(callback, end=0.1)),
-            wrap_callback(callback, start=0.1, end=0.2)
-        )
+        embeddings = self.embedder(words, wrap_callback(callback, end=0.1))
+        sims = self._get_similarities(embeddings)
+        callback(0.2)
         if len(words) == 3:
             root = np.argmin(np.sum(sims, axis=1))
             rest = sorted([words[i] for i in range(3) if i != root])
@@ -307,18 +254,13 @@ def insert(
         callback: Callable = dummy_callback
     ) -> Dict:
         tree = Tree.from_dict(tree)
-        self._get_embeddings(words, wrap_callback(callback, end=0.3))
         ticks = iter(np.linspace(0.3, 0.9, len(words)))
         for word in words:
-            tick = next(ticks)
             tree.adj_list.append(set())
             tree.labels.append(word)
-            sims = self._get_similarities(
-                tree.labels,
-                self._get_embeddings(tree.labels, lambda x: callback(tick)),
-                lambda x: callback(tick)
-            )
+            embeddings = self.embedder(tree.labels)
+            sims = self._get_similarities(embeddings)
             idx = len(tree.adj_list) - 1
             fitness_function = FitnessFunction(tree.labels, sims).fitness
             scores = list()
@@ -331,65 +273,18 @@ def insert(
             best = np.argmax(scores)
             tree.adj_list[best].add(idx)
             tree.adj_list[idx].add(best)
-            callback(tick)
+            callback(next(ticks))
         return tree.to_dict()
 
     def score(self, tree: Dict, callback: Callable = dummy_callback) -> float:
         tree = Tree.from_dict(tree)
-        sims = self._get_similarities(
-            tree.labels,
-            self._get_embeddings(tree.labels, wrap_callback(callback, end=0.7)),
-            wrap_callback(callback, start=0.7, end=0.8)
-        )
+        embeddings = self.embedder(tree.labels, wrap_callback(callback, end=0.7))
+        sims = self._get_similarities(embeddings)
         callback(0.9)
         fitness_function = FitnessFunction(tree.labels, sims).fitness
         return fitness_function(tree, tree.root)[0]
 
-    def _get_embeddings(
-        self,
-        words: List[str],
-        callback: Callable = dummy_callback
-    ) -> np.array:
-        embeddings = np.zeros((len(words), EMB_DIM))
-        missing, missing_idx = list(), list()
-        ticks = iter(np.linspace(0.0, 0.6, len(words)))
-        for i, word in enumerate(words):
-            callback(next(ticks))
-            emb = self.storage.get_embedding(word)
-            if emb is None:
-                missing.append(word)
-                missing_idx.append(i)
-            else:
-                embeddings[i, :] = emb
-        if len(missing_idx) > 0:
-            embs = self.embedder(missing, callback=wrap_callback(callback, start=0.6, end=0.9))
-            if None in embs:
-                raise RuntimeError("Couldn't obtain embeddings.")
-            embeddings[missing_idx, :] = np.array(embs)
-        for i in missing_idx:
-            self.storage.save_embedding(words[i], embeddings[i, :])
-
-        return embeddings
-
-    def _get_similarities(
-        self,
-        words: List[str],
-        embeddings: np.array,
-        callback: Callable = dummy_callback
-    ) -> np.array:
-        sims = np.zeros((len(words), len(words)))
-        ticks = iter(np.linspace(0.0, 1.0, int(len(words) * (len(words) - 1) / 2)))
-        for i in range(len(words)):
-            for j in range(i + 1, len(words)):
-                callback(next(ticks))
-                key = tuple(sorted((words[i], words[j])))
-                try:
-                    sim = self.storage.similarities[key]
-                except KeyError:
-                    sim = cos_sim(embeddings[i, :], embeddings[j, :])
-                    self.storage.similarities[key] = sim
-                sims[i, j] = sim
-                sims[j, i] = sim
-        self.storage.save_similarities()
-        return sims
+    @staticmethod
+    def _get_similarities(embeddings: np.array) -> np.array:
+        return cosine_similarity(embeddings, embeddings)
diff --git a/orangecontrib/text/tests/test_ontology.py b/orangecontrib/text/tests/test_ontology.py
index cd0251377..245b65cf5 100644
--- a/orangecontrib/text/tests/test_ontology.py
+++ b/orangecontrib/text/tests/test_ontology.py
@@ -1,22 +1,32 @@
 import unittest
+from typing import List, Union
 from unittest.mock import patch
 from collections.abc import Iterator
-import os
 import asyncio
 
 import numpy as np
 
-from orangecontrib.text.ontology import Tree, EmbeddingStorage, OntologyHandler, EMB_DIM
+from orangecontrib.text.ontology import Tree, OntologyHandler
 
+EMB_DIM = 384
 RESPONSE = [
     f'{{ "embedding": {[i] * EMB_DIM} }}'.encode()
     for i in range(4)
 ]
+RESPONSE2 = [np.zeros(384), np.ones(384), np.zeros(384), np.ones(384)*2]
+RESPONSE3 = [np.zeros(384), np.ones(384), np.arange(384), np.ones(384)*2]
 
 
-class DummyResponse:
+def arrays_to_response(array: List[Union[np.ndarray, List]]) -> Iterator[bytes]:
+    return iter(array_to_response(a) for a in array)
+
+
+def array_to_response(array: Union[np.ndarray, List]) -> bytes:
+    return f'{{ "embedding": {array.tolist()} }}'.encode()
 
+
+class DummyResponse:
     def __init__(self, content):
         self.content = content
 
@@ -72,54 +82,11 @@ def test_assertion_errors(self):
             Tree.from_prufer_sequence([1, 0, 3], list(map(str, range(4))))
 
 
-class TestEmbeddingStorage(unittest.TestCase):
-
-    def setUp(self):
-        self.storage = EmbeddingStorage()
-
-    def tearDown(self):
-        self.storage.clear_storage()
-
-    def test_clear_storage(self):
-        self.storage.save_embedding("testword", np.zeros(3))
-        self.assertEqual(len(self.storage.embeddings), 1)
-        self.storage.clear_storage()
-        self.assertEqual(len(self.storage.embeddings), 0)
-        self.assertEqual(len(os.listdir(self.storage.cache_dir)), 0)
-
-    def test_save_embedding(self):
-        self.storage.save_embedding("testword", np.zeros(3))
-        self.storage.save_embedding("testword2", np.zeros(3))
-        self.assertEqual(len(self.storage.embeddings), 2)
-        self.assertEqual(len(os.listdir(self.storage.cache_dir)), 2)
-
-    def test_get_embedding(self):
-        self.storage.save_embedding("testword", np.arange(3))
-        emb = self.storage.get_embedding("testword")
-        self.assertEqual(emb.tolist(), [0, 1, 2])
-
-    def test_get_from_cache(self):
-        self.storage.save_embedding("testword", np.arange(3))
-        self.storage.embeddings = dict()
-        emb = self.storage.get_embedding("testword")
-        self.assertEqual(emb.tolist(), [0, 1, 2])
-
-    def test_similarities(self):
-        self.storage.similarities['a', 'b'] = 0.75
-        self.storage.save_similarities()
-        storage = EmbeddingStorage()
-        self.assertEqual(len(storage.similarities), 1)
-        self.assertTrue(('a', 'b') in storage.similarities)
-        self.assertEqual(storage.similarities['a', 'b'], 0.75)
-
-
 class TestOntologyHandler(unittest.TestCase):
-
     def setUp(self):
         self.handler = OntologyHandler()
 
     def tearDown(self):
-        self.handler.storage.clear_storage()
         self.handler.embedder.clear_cache()
 
     @patch('orangecontrib.text.ontology.generate_ontology')
@@ -128,34 +95,23 @@ def test_small_trees(self, mock):
         self.handler.generate(words)
         mock.assert_not_called()
 
+    @patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE3)))
     def test_generate_small(self):
-        self.handler.storage.save_embedding('1', np.zeros(384))
-        self.handler.storage.save_embedding('2', np.ones(384))
-        self.handler.storage.save_embedding('3', np.arange(384))
         tree = self.handler.generate(['1', '2', '3'])
         self.assertTrue(isinstance(tree, dict))
 
-    @patch('httpx.AsyncClient.post')
-    def test_generate(self, mock):
-        self.handler.storage.save_embedding('1', np.zeros(384))
-        self.handler.storage.save_embedding('2', np.ones(384))
-        self.handler.storage.save_embedding('3', np.arange(384))
-        self.handler.storage.save_embedding('4', np.ones(384) * 2)
+    @patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE3)))
+    def test_generate(self):
        tree = self.handler.generate(['1', '2', '3', '4'])
         self.assertTrue(isinstance(tree, dict))
-        mock.request.assert_not_called()
-        mock.get_response.assert_not_called()
 
     @patch('httpx.AsyncClient.post', make_dummy_post(iter(RESPONSE)))
     def test_generate_with_unknown_embeddings(self):
         tree = self.handler.generate(['1', '2', '3', '4'])
         self.assertTrue(isinstance(tree, dict))
 
+    @patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE2)))
     def test_insert(self):
-        self.handler.storage.save_embedding('1', np.zeros(384))
-        self.handler.storage.save_embedding('2', np.ones(384))
-        self.handler.storage.save_embedding('3', np.arange(384))
-        self.handler.storage.save_embedding('4', np.ones(384) * 2)
         tree = self.handler.generate(['1', '2', '3'])
         new_tree = self.handler.insert(tree, ['4'])
         self.assertGreater(
@@ -163,10 +119,8 @@ def test_insert(self):
             len(Tree.from_dict(tree).adj_list)
         )
 
+    @patch('httpx.AsyncClient.post', make_dummy_post(array_to_response(np.zeros(384))))
     def test_score(self):
-        self.handler.storage.save_embedding('1', np.zeros(384))
-        self.handler.storage.save_embedding('2', np.ones(384))
-        self.handler.storage.save_embedding('3', np.arange(384))
         tree = self.handler.generate(['1', '2', '3'])
         score = self.handler.score(tree)
         self.assertGreater(score, 0)
diff --git a/orangecontrib/text/widgets/owontology.py b/orangecontrib/text/widgets/owontology.py
index b306befb2..38dbb14ef 100644
--- a/orangecontrib/text/widgets/owontology.py
+++ b/orangecontrib/text/widgets/owontology.py
@@ -153,7 +153,7 @@ def __init__(self, data_changed_cb: Callable):
         edit_triggers = QTreeView.DoubleClicked | QTreeView.EditKeyPressed
         super().__init__(
-            editTriggers=int(edit_triggers),
+            editTriggers=edit_triggers,
             selectionMode=QTreeView.ExtendedSelection,
             dragEnabled=True,
             acceptDrops=True,
@@ -165,7 +165,7 @@ def __init__(self, data_changed_cb: Callable):
         self.__disconnected = False
 
-    def startDrag(self, actions: Qt.DropActions):
+    def startDrag(self, actions: Qt.DropAction):
         with disconnected(self.model().dataChanged, self.__data_changed_cb):
             super().startDrag(actions)
         self.drop_finished.emit()
@@ -626,7 +626,7 @@ def _setup_gui(self):
         edit_triggers = QListView.DoubleClicked | QListView.EditKeyPressed
         self.__library_view = QListView(
-            editTriggers=int(edit_triggers),
+            editTriggers=edit_triggers,
             minimumWidth=200,
             sizePolicy=QSizePolicy(QSizePolicy.Ignored, QSizePolicy.Expanding),
         )
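
For reference, below the patch: a minimal sketch of what the new OntologyHandler._get_similarities computes, assuming scikit-learn's cosine_similarity and a toy random matrix standing in for real SBERT embeddings (the variable names here are illustrative, not part of the patch):

    import numpy as np
    from sklearn.metrics.pairwise import cosine_similarity

    # Toy stand-in for SBERT output: one 384-dimensional vector per word.
    embeddings = np.random.default_rng(0).random((4, 384))

    # The patched _get_similarities delegates the whole pairwise computation
    # to scikit-learn instead of caching per-pair cos_sim values on disk.
    sims = cosine_similarity(embeddings, embeddings)

    assert sims.shape == (4, 4)              # full pairwise similarity matrix
    assert np.allclose(np.diag(sims), 1.0)   # each vector is maximally similar to itself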