Ontology - remove cache
PrimozGodec committed Aug 30, 2022
1 parent 339ad59 commit 4cc521f
Showing 3 changed files with 33 additions and 184 deletions.
129 changes: 12 additions & 117 deletions orangecontrib/text/ontology.py
@@ -1,17 +1,13 @@
from typing import List, Set, Dict, Tuple, Optional, Callable
from collections import Counter
from itertools import chain
import os
import pickle

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

from orangecontrib.text.vectorization.sbert import SBERT
from Orange.misc.environ import cache_dir
from Orange.util import dummy_callback, wrap_callback

EMB_DIM = 384


class Tree:

@@ -222,56 +218,9 @@ def generate_ontology(
return generation[best, :], roots[best]


def cos_sim(x: np.array, y: np.array) -> float:
dot = np.dot(x, y)
return 0 if np.allclose(dot, 0) else dot / (np.linalg.norm(x) * np.linalg.norm(y))
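
For context, a minimal sketch (vector values are invented) of why the removed cos_sim helper is no longer needed: sklearn's cosine_similarity over the stacked embedding matrix reproduces the same pairwise value, minus the explicit zero-vector guard.

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Two illustrative embedding vectors (values are made up for the example).
x = np.array([1.0, 2.0, 3.0])
y = np.array([3.0, 2.0, 1.0])

pairwise = np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y))
matrix = cosine_similarity(np.vstack([x, y]))  # 2x2 matrix with ones on the diagonal

assert np.isclose(pairwise, matrix[0, 1])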


class EmbeddingStorage:

def __init__(self):
self.cache_dir = os.path.join(cache_dir(), 'ontology')
if not os.path.isdir(self.cache_dir):
os.makedirs(self.cache_dir)
self.similarities = dict()
try:
with open(os.path.join(self.cache_dir, 'sims.pkl'), 'rb') as file:
self.similarities = pickle.load(file)
except IOError:
self.similarities = dict()
self.embeddings = dict()

def save_similarities(self):
with open(os.path.join(self.cache_dir, 'sims.pkl'), 'wb') as file:
pickle.dump(self.similarities, file)

def get_embedding(self, word: str) -> Optional[np.array]:
if word in self.embeddings:
return self.embeddings[word]
try:
emb = np.load(os.path.join(self.cache_dir, f'{word}.npy'))
self.embeddings[word] = emb
return emb
except IOError:
return None

def save_embedding(self, word: str, emb: np.array) -> None:
self.embeddings[word] = emb
np.save(os.path.join(self.cache_dir, f'{word}.npy'), emb)

def clear_storage(self) -> None:
self.similarities = dict()
self.embeddings = dict()
if os.path.isdir(self.cache_dir):
for file in os.listdir(self.cache_dir):
os.remove(os.path.join(self.cache_dir, file))


class OntologyHandler:

def __init__(self):
self.embedder = SBERT()
self.storage = EmbeddingStorage()

def generate(
self,
@@ -284,11 +233,9 @@ def generate(
return {words[0]: {}}
if len(words) == 2:
return {sorted(words)[0]: {sorted(words)[1]: {}}}
sims = self._get_similarities(
words,
self._get_embeddings(words, wrap_callback(callback, end=0.1)),
wrap_callback(callback, start=0.1, end=0.2)
)
embeddings = self.embedder(words, wrap_callback(callback, end=0.1))
sims = self._get_similarities(embeddings)
callback(0.2)
if len(words) == 3:
root = np.argmin(np.sum(sims, axis=1))
rest = sorted([words[i] for i in range(3) if i != root])
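
A small illustration of the three-word branch above (similarity values are invented): the root is the word whose summed similarity to the others is smallest, and the remaining two words, sorted, make up the rest of the small tree.

import numpy as np

# Invented 3x3 similarity matrix; row/column order follows the input word list.
sims = np.array([
    [1.0, 0.2, 0.3],
    [0.2, 1.0, 0.9],
    [0.3, 0.9, 1.0],
])
root = int(np.argmin(np.sum(sims, axis=1)))              # index 0: smallest row sum
words = ["w1", "w2", "w3"]                               # placeholder words
rest = sorted(words[i] for i in range(3) if i != root)   # the root's remaining words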
@@ -307,18 +254,13 @@ def insert(
callback: Callable = dummy_callback
) -> Dict:
tree = Tree.from_dict(tree)
self._get_embeddings(words, wrap_callback(callback, end=0.3))
ticks = iter(np.linspace(0.3, 0.9, len(words)))

for word in words:
tick = next(ticks)
tree.adj_list.append(set())
tree.labels.append(word)
sims = self._get_similarities(
tree.labels,
self._get_embeddings(tree.labels, lambda x: callback(tick)),
lambda x: callback(tick)
)
embeddings = self.embedder(tree.labels)
sims = self._get_similarities(embeddings)
idx = len(tree.adj_list) - 1
fitness_function = FitnessFunction(tree.labels, sims).fitness
scores = list()
@@ -331,65 +273,18 @@
best = np.argmax(scores)
tree.adj_list[best].add(idx)
tree.adj_list[idx].add(best)
callback(tick)
callback(next(ticks))

return tree.to_dict()

def score(self, tree: Dict, callback: Callable = dummy_callback) -> float:
tree = Tree.from_dict(tree)
sims = self._get_similarities(
tree.labels,
self._get_embeddings(tree.labels, wrap_callback(callback, end=0.7)),
wrap_callback(callback, start=0.7, end=0.8)
)
embeddings = self.embedder(tree.labels, wrap_callback(callback, end=0.7))
sims = self._get_similarities(embeddings)
callback(0.9)
fitness_function = FitnessFunction(tree.labels, sims).fitness
return fitness_function(tree, tree.root)[0]

def _get_embeddings(
self,
words: List[str],
callback: Callable = dummy_callback
) -> np.array:
embeddings = np.zeros((len(words), EMB_DIM))
missing, missing_idx = list(), list()
ticks = iter(np.linspace(0.0, 0.6, len(words)))
for i, word in enumerate(words):
callback(next(ticks))
emb = self.storage.get_embedding(word)
if emb is None:
missing.append(word)
missing_idx.append(i)
else:
embeddings[i, :] = emb
if len(missing_idx) > 0:
embs = self.embedder(missing, callback=wrap_callback(callback, start=0.6, end=0.9))
if None in embs:
raise RuntimeError("Couldn't obtain embeddings.")
embeddings[missing_idx, :] = np.array(embs)
for i in missing_idx:
self.storage.save_embedding(words[i], embeddings[i, :])

return embeddings

def _get_similarities(
self,
words: List[str],
embeddings: np.array,
callback: Callable = dummy_callback
) -> np.array:
sims = np.zeros((len(words), len(words)))
ticks = iter(np.linspace(0.0, 1.0, int(len(words) * (len(words) - 1) / 2)))
for i in range(len(words)):
for j in range(i + 1, len(words)):
callback(next(ticks))
key = tuple(sorted((words[i], words[j])))
try:
sim = self.storage.similarities[key]
except KeyError:
sim = cos_sim(embeddings[i, :], embeddings[j, :])
self.storage.similarities[key] = sim
sims[i, j] = sim
sims[j, i] = sim
self.storage.save_similarities()
return sims
@staticmethod
def _get_similarities(embeddings: np.array) -> np.array:
return cosine_similarity(embeddings, embeddings)
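
With the cache gone, the handler's public API is unchanged; a minimal usage sketch (the word list is illustrative, and it assumes a running SBERT embedding service is reachable):

from orangecontrib.text.ontology import OntologyHandler

handler = OntologyHandler()
tree = handler.generate(["animal", "dog", "cat", "plant"])  # nested-dict ontology
tree = handler.insert(tree, ["bird"])                       # attach new words to the tree
quality = handler.score(tree)                               # fitness score of the tree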
82 changes: 18 additions & 64 deletions orangecontrib/text/tests/test_ontology.py
@@ -1,22 +1,32 @@
import unittest
from typing import List, Union
from unittest.mock import patch
from collections.abc import Iterator
import os
import asyncio

import numpy as np

from orangecontrib.text.ontology import Tree, EmbeddingStorage, OntologyHandler, EMB_DIM
from orangecontrib.text.ontology import Tree, OntologyHandler


EMB_DIM = 384
RESPONSE = [
f'{{ "embedding": {[i] * EMB_DIM} }}'.encode()
for i in range(4)
]
RESPONSE2 = [np.zeros(384), np.ones(384), np.zeros(384), np.ones(384)*2]
RESPONSE3 = [np.zeros(384), np.ones(384), np.arange(384), np.ones(384)*2]


class DummyResponse:
def arrays_to_response(array: List[Union[np.ndarray, List]]) -> Iterator[bytes]:
return iter(array_to_response(a) for a in array)


def array_to_response(array: Union[np.ndarray, List]) -> bytes:
return f'{{ "embedding": {array.tolist()} }}'.encode()
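
A quick check of the helper defined above: it serializes an array into the JSON byte payload the mocked httpx post returns (mirroring the hand-written RESPONSE strings), and arrays_to_response yields one such payload per request, presumably consumed by make_dummy_post defined elsewhere in this module.

import numpy as np

payload = array_to_response(np.array([0.0, 1.0]))
assert payload == b'{ "embedding": [0.0, 1.0] }'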


class DummyResponse:
def __init__(self, content):
self.content = content

@@ -72,54 +82,11 @@ def test_assertion_errors(self):
Tree.from_prufer_sequence([1, 0, 3], list(map(str, range(4))))


class TestEmbeddingStorage(unittest.TestCase):

def setUp(self):
self.storage = EmbeddingStorage()

def tearDown(self):
self.storage.clear_storage()

def test_clear_storage(self):
self.storage.save_embedding("testword", np.zeros(3))
self.assertEqual(len(self.storage.embeddings), 1)
self.storage.clear_storage()
self.assertEqual(len(self.storage.embeddings), 0)
self.assertEqual(len(os.listdir(self.storage.cache_dir)), 0)

def test_save_embedding(self):
self.storage.save_embedding("testword", np.zeros(3))
self.storage.save_embedding("testword2", np.zeros(3))
self.assertEqual(len(self.storage.embeddings), 2)
self.assertEqual(len(os.listdir(self.storage.cache_dir)), 2)

def test_get_embedding(self):
self.storage.save_embedding("testword", np.arange(3))
emb = self.storage.get_embedding("testword")
self.assertEqual(emb.tolist(), [0, 1, 2])

def test_get_from_cache(self):
self.storage.save_embedding("testword", np.arange(3))
self.storage.embeddings = dict()
emb = self.storage.get_embedding("testword")
self.assertEqual(emb.tolist(), [0, 1, 2])

def test_similarities(self):
self.storage.similarities['a', 'b'] = 0.75
self.storage.save_similarities()
storage = EmbeddingStorage()
self.assertEqual(len(storage.similarities), 1)
self.assertTrue(('a', 'b') in storage.similarities)
self.assertEqual(storage.similarities['a', 'b'], 0.75)


class TestOntologyHandler(unittest.TestCase):

def setUp(self):
self.handler = OntologyHandler()

def tearDown(self):
self.handler.storage.clear_storage()
self.handler.embedder.clear_cache()

@patch('orangecontrib.text.ontology.generate_ontology')
@@ -128,45 +95,32 @@ def test_small_trees(self, mock):
self.handler.generate(words)
mock.assert_not_called()

@patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE3)))
def test_generate_small(self):
self.handler.storage.save_embedding('1', np.zeros(384))
self.handler.storage.save_embedding('2', np.ones(384))
self.handler.storage.save_embedding('3', np.arange(384))
tree = self.handler.generate(['1', '2', '3'])
self.assertTrue(isinstance(tree, dict))

@patch('httpx.AsyncClient.post')
def test_generate(self, mock):
self.handler.storage.save_embedding('1', np.zeros(384))
self.handler.storage.save_embedding('2', np.ones(384))
self.handler.storage.save_embedding('3', np.arange(384))
self.handler.storage.save_embedding('4', np.ones(384) * 2)
@patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE3)))
def test_generate(self):
tree = self.handler.generate(['1', '2', '3', '4'])
self.assertTrue(isinstance(tree, dict))
mock.request.assert_not_called()
mock.get_response.assert_not_called()

@patch('httpx.AsyncClient.post', make_dummy_post(iter(RESPONSE)))
def test_generate_with_unknown_embeddings(self):
tree = self.handler.generate(['1', '2', '3', '4'])
self.assertTrue(isinstance(tree, dict))

@patch('httpx.AsyncClient.post', make_dummy_post(arrays_to_response(RESPONSE2)))
def test_insert(self):
self.handler.storage.save_embedding('1', np.zeros(384))
self.handler.storage.save_embedding('2', np.ones(384))
self.handler.storage.save_embedding('3', np.arange(384))
self.handler.storage.save_embedding('4', np.ones(384) * 2)
tree = self.handler.generate(['1', '2', '3'])
new_tree = self.handler.insert(tree, ['4'])
self.assertGreater(
len(Tree.from_dict(new_tree).adj_list),
len(Tree.from_dict(tree).adj_list)
)

@patch('httpx.AsyncClient.post', make_dummy_post(array_to_response(np.zeros(384))))
def test_score(self):
self.handler.storage.save_embedding('1', np.zeros(384))
self.handler.storage.save_embedding('2', np.ones(384))
self.handler.storage.save_embedding('3', np.arange(384))
tree = self.handler.generate(['1', '2', '3'])
score = self.handler.score(tree)
self.assertGreater(score, 0)
6 changes: 3 additions & 3 deletions orangecontrib/text/widgets/owontology.py
@@ -153,7 +153,7 @@ def __init__(self, data_changed_cb: Callable):

edit_triggers = QTreeView.DoubleClicked | QTreeView.EditKeyPressed
super().__init__(
editTriggers=int(edit_triggers),
editTriggers=edit_triggers,
selectionMode=QTreeView.ExtendedSelection,
dragEnabled=True,
acceptDrops=True,
@@ -165,7 +165,7 @@ def __init__(self, data_changed_cb: Callable):

self.__disconnected = False

def startDrag(self, actions: Qt.DropActions):
def startDrag(self, actions: Qt.DropAction):
with disconnected(self.model().dataChanged, self.__data_changed_cb):
super().startDrag(actions)
self.drop_finished.emit()
@@ -626,7 +626,7 @@ def _setup_gui(self):

edit_triggers = QListView.DoubleClicked | QListView.EditKeyPressed
self.__library_view = QListView(
editTriggers=int(edit_triggers),
editTriggers=edit_triggers,
minimumWidth=200,
sizePolicy=QSizePolicy(QSizePolicy.Ignored, QSizePolicy.Expanding),
)
