Skip to content

Commit

Permalink
Document embedder - use base vectorizer
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed May 26, 2022
1 parent f8dc6b1 commit aaad632
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 231 deletions.
36 changes: 13 additions & 23 deletions orangecontrib/text/tests/test_documentembedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,22 @@ def tearDown(self):

@patch(PATCH_METHOD)
def test_with_empty_corpus(self, mock):
self.assertIsNone(self.embedder(self.corpus[:0])[0])
self.assertIsNone(self.embedder(self.corpus[:0])[1])
self.assertIsNone(self.embedder.transform(self.corpus[:0])[0])
self.assertIsNone(self.embedder.transform(self.corpus[:0])[1])
mock.request.assert_not_called()
mock.get_response.assert_not_called()
self.assertEqual(self.embedder._embedder._cache._cache_dict, dict())

@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
def test_success_subset(self):
res, skipped = self.embedder(self.corpus[[0]])
res, skipped = self.embedder.transform(self.corpus[[0]])
assert_array_equal(res.X, [[0.3, 1]])
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1)
self.assertIsNone(skipped)

@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
def test_success_shapes(self):
res, skipped = self.embedder(self.corpus)
res, skipped = self.embedder.transform(self.corpus)
self.assertEqual(res.X.shape, (len(self.corpus), 2))
self.assertEqual(len(res.domain.variables),
len(self.corpus.domain.variables) + 2)
Expand All @@ -58,31 +58,31 @@ def test_success_shapes(self):
@patch(PATCH_METHOD, make_dummy_post(b''))
def test_empty_response(self):
with self.assertWarns(RuntimeWarning):
res, skipped = self.embedder(self.corpus[[0]])
res, skipped = self.embedder.transform(self.corpus[[0]])
self.assertIsNone(res)
self.assertEqual(len(skipped), 1)
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)

@patch(PATCH_METHOD, make_dummy_post(b'str'))
def test_invalid_response(self):
with self.assertWarns(RuntimeWarning):
res, skipped = self.embedder(self.corpus[[0]])
res, skipped = self.embedder.transform(self.corpus[[0]])
self.assertIsNone(res)
self.assertEqual(len(skipped), 1)
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)

@patch(PATCH_METHOD, make_dummy_post(b'{"embeddings": [0.3, 1]}'))
def test_invalid_json_key(self):
with self.assertWarns(RuntimeWarning):
res, skipped = self.embedder(self.corpus[[0]])
res, skipped = self.embedder.transform(self.corpus[[0]])
self.assertIsNone(res)
self.assertEqual(len(skipped), 1)
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)

@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
def test_persistent_caching(self):
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
self.embedder(self.corpus[[0]])
self.embedder.transform(self.corpus[[0]])
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1)
self.embedder._embedder._cache.persist_cache()

Expand All @@ -98,7 +98,7 @@ def test_cache_for_different_languages(self):
embedder = DocumentEmbedder(language='sl')
embedder.clear_cache()
self.assertEqual(len(embedder._embedder._cache._cache_dict), 0)
embedder(self.corpus[[0]])
embedder.transform(self.corpus[[0]])
self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
embedder._embedder._cache.persist_cache()

Expand All @@ -116,44 +116,34 @@ def test_cache_for_different_aggregators(self):
embedder = DocumentEmbedder(aggregator='max')
embedder.clear_cache()
self.assertEqual(len(embedder._embedder._cache._cache_dict), 0)
embedder(self.corpus[[0]])
embedder.transform(self.corpus[[0]])
self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
embedder._embedder._cache.persist_cache()

embedder = DocumentEmbedder(aggregator='min')
self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
embedder(self.corpus[[0]])
embedder.transform(self.corpus[[0]])
self.assertEqual(len(embedder._embedder._cache._cache_dict), 2)

@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
def test_with_statement(self):
with self.embedder as embedder:
res, skipped = embedder(self.corpus[[0]])
assert_array_equal(res.X, [[0.3, 1]])

@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
def test_cancel(self):
self.assertFalse(self.embedder._embedder._cancelled)
self.embedder._embedder._cancelled = True
with self.assertRaises(Exception):
self.embedder(self.corpus[[0]])
self.embedder.transform(self.corpus[[0]])

@patch(PATCH_METHOD, side_effect=OSError)
def test_connection_error(self, _):
embedder = DocumentEmbedder()
with self.assertRaises(ConnectionError):
embedder(self.corpus[[0]])
embedder.transform(self.corpus[[0]])

def test_invalid_parameters(self):
with self.assertRaises(ValueError):
self.embedder = DocumentEmbedder(language='eng')
with self.assertRaises(ValueError):
self.embedder = DocumentEmbedder(aggregator='average')

def test_invalid_corpus_type(self):
with self.assertRaises(ValueError):
self.embedder(self.corpus[0])


if __name__ == "__main__":
unittest.main()
46 changes: 16 additions & 30 deletions orangecontrib/text/vectorization/document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import sys
import warnings
import zlib
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional, Tuple

import numpy as np
from Orange.misc.server_embedder import ServerEmbedderCommunicator
from Orange.util import dummy_callback

from orangecontrib.text import Corpus
from orangecontrib.text.vectorization.base import BaseVectorizer

AGGREGATORS = ['Mean', 'Sum', 'Max', 'Min']
AGGREGATORS_L = ['mean', 'sum', 'max', 'min']
Expand Down Expand Up @@ -52,7 +53,7 @@
LANGUAGES = list(LANGS_TO_ISO.values())


class DocumentEmbedder:
class DocumentEmbedder(BaseVectorizer):
"""This class is used for obtaining dense embeddings of documents in
corpus using fastText pretrained models from:
E. Grave, P. Bojanowski, P. Gupta, A. Joulin, T. Mikolov,
Expand Down Expand Up @@ -93,9 +94,9 @@ def __init__(self, language: str = 'en',
server_url='https://apiv2.garaza.io',
embedder_type='text')

def __call__(
self, corpus: Union[Corpus, List[List[str]]], callback=dummy_callback
) -> Union[Tuple[Corpus, Corpus], List[Optional[List[float]]]]:
def _transform(
self, corpus: Corpus, _, callback=dummy_callback
) -> Tuple[Corpus, Corpus]:
"""Adds matrix of document embeddings to a corpus.
Parameters
Expand All @@ -109,14 +110,7 @@ def __call__(
Corpus (original or a copy) with new features added.
Skipped documents
Corpus of documents that were not embedded
Raises
------
ValueError
If corpus is not instance of Corpus.
"""
if not isinstance(corpus, (Corpus, list)):
raise ValueError("Input should be instance of Corpus or list.")
embs = self._embedder.embedd_data(
list(corpus.ngrams) if isinstance(corpus, Corpus) else corpus,
callback=callback,
Expand All @@ -135,12 +129,6 @@ def __call__(
skipped_documents = [emb is None for emb in embs]
embedded_documents = np.logical_not(skipped_documents)

variable_attrs = {
'hidden': True,
'skip-normalization': True,
'embedding-feature': True
}

new_corpus = None
if np.any(embedded_documents):
# if at least one embedding is not None, extend attributes
Expand All @@ -150,18 +138,22 @@ def __call__(
[e for e, ns in zip(embs, embedded_documents) if ns],
dtype=float,
),
['Dim{}'.format(i + 1) for i in range(dim)],
var_attrs=variable_attrs
["Dim{}".format(i + 1) for i in range(dim)],
var_attrs={
"embedding-feature": True,
"hidden": True,
},
)

skipped_corpus = None
if np.any(skipped_documents):
skipped_corpus = corpus[skipped_documents].copy()
skipped_corpus.name = "Skipped documents"
warnings.warn(("Some documents were not embedded for " +
"unknown reason. Those documents " +
"are skipped."),
RuntimeWarning)
warnings.warn(
"Some documents were not embedded for unknown reason. Those "
"documents are skipped.",
RuntimeWarning,
)

return new_corpus, skipped_corpus

Expand All @@ -181,12 +173,6 @@ def clear_cache(self):
if self._embedder:
self._embedder.clear_cache()

def __enter__(self):
return self

def __exit__(self, _, __, ___):
pass


class _ServerEmbedder(ServerEmbedderCommunicator):
def __init__(self, aggregator: str, *args, **kwargs) -> None:
Expand Down
Loading

0 comments on commit aaad632

Please sign in to comment.