Skip to content

Commit

Permalink
DocumentEmbedder - support language from Corpus
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Mar 29, 2023
1 parent 9bd6012 commit 57b54fc
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 97 deletions.
94 changes: 58 additions & 36 deletions orangecontrib/text/tests/test_documentembedder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import unittest
from unittest.mock import patch
from unittest.mock import patch, ANY
import asyncio

from Orange.misc.utils.embedder_utils import EmbedderCache
from numpy.testing import assert_array_equal

from orangecontrib.text.vectorization.document_embedder import DocumentEmbedder
Expand Down Expand Up @@ -29,23 +31,24 @@ class DocumentEmbedderTest(unittest.TestCase):
def setUp(self):
self.embedder = DocumentEmbedder() # default params
self.corpus = Corpus.from_file('deerwester')
self.embedder.clear_cache("en")

def tearDown(self):
self.embedder.clear_cache()
self.embedder.clear_cache("en")

@patch(PATCH_METHOD)
def test_with_empty_corpus(self, mock):
self.assertIsNone(self.embedder.transform(self.corpus[:0])[0])
self.assertIsNone(self.embedder.transform(self.corpus[:0])[1])
mock.request.assert_not_called()
mock.get_response.assert_not_called()
self.assertEqual(self.embedder._embedder._cache._cache_dict, dict())
self.assertEqual(EmbedderCache("fasttext-en")._cache_dict, dict())

@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
def test_success_subset(self):
res, skipped = self.embedder.transform(self.corpus[[0]])
assert_array_equal(res.X, [[0.3, 1]])
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1)
self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1)
self.assertIsNone(skipped)

@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
Expand All @@ -62,69 +65,63 @@ def test_empty_response(self):
res, skipped = self.embedder.transform(self.corpus[[0]])
self.assertIsNone(res)
self.assertEqual(len(skipped), 1)
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)

@patch(PATCH_METHOD, make_dummy_post(b'str'))
def test_invalid_response(self):
with self.assertWarns(RuntimeWarning):
res, skipped = self.embedder.transform(self.corpus[[0]])
self.assertIsNone(res)
self.assertEqual(len(skipped), 1)
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)

@patch(PATCH_METHOD, make_dummy_post(b'{"embeddings": [0.3, 1]}'))
def test_invalid_json_key(self):
with self.assertWarns(RuntimeWarning):
res, skipped = self.embedder.transform(self.corpus[[0]])
self.assertIsNone(res)
self.assertEqual(len(skipped), 1)
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)

@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
def test_persistent_caching(self):
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)
self.embedder.transform(self.corpus[[0]])
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1)
self.embedder._embedder._cache.persist_cache()
self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1)

self.embedder = DocumentEmbedder()
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1)
self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1)

self.embedder.clear_cache()
self.embedder.clear_cache("en")
self.embedder = DocumentEmbedder()
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)

@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
def test_cache_for_different_languages(self):
embedder = DocumentEmbedder(language='sl')
embedder.clear_cache()
self.assertEqual(len(embedder._embedder._cache._cache_dict), 0)
embedder.transform(self.corpus[[0]])
self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
embedder._embedder._cache.persist_cache()
def test_different_languages(self):
self.corpus.attributes["language"] = "sl"

self.embedder = DocumentEmbedder()
self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
self.embedder._embedder._cache.persist_cache()

embedder = DocumentEmbedder(language='sl')
self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
embedder.clear_cache()
self.embedder.clear_cache()
embedder = DocumentEmbedder()
embedder.clear_cache("sl")
self.assertEqual(len(EmbedderCache("fasttext-sl")._cache_dict), 0)
embedder.transform(self.corpus[[0]])
self.assertEqual(len(EmbedderCache("fasttext-sl")._cache_dict), 1)
self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)
self.assertEqual(len(EmbedderCache("fasttext-sl")._cache_dict), 1)
embedder.clear_cache("sl")

@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
def test_cache_for_different_aggregators(self):
embedder = DocumentEmbedder(aggregator='max')
embedder.clear_cache()
self.assertEqual(len(embedder._embedder._cache._cache_dict), 0)
embedder.clear_cache("en")

self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)
embedder.transform(self.corpus[[0]])
self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
embedder._embedder._cache.persist_cache()
self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1)

embedder = DocumentEmbedder(aggregator='min')
self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1)
embedder.transform(self.corpus[[0]])
self.assertEqual(len(embedder._embedder._cache._cache_dict), 2)
self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 2)

@patch(PATCH_METHOD, side_effect=OSError)
def test_connection_error(self, _):
Expand All @@ -133,11 +130,36 @@ def test_connection_error(self, _):
embedder.transform(self.corpus[[0]])

def test_invalid_parameters(self):
with self.assertRaises(ValueError):
with self.assertRaises(AssertionError):
self.embedder = DocumentEmbedder(language='eng')
with self.assertRaises(ValueError):
with self.assertRaises(AssertionError):
self.embedder = DocumentEmbedder(aggregator='average')

@patch("orangecontrib.text.vectorization.document_embedder._ServerEmbedder")
def test_set_language(self, m):
# method 1: language from corpus
self.corpus.attributes["language"] = "sl"
embedder = DocumentEmbedder()
embedder.transform(self.corpus)
m.assert_called_with(
"mean",
model_name="fasttext-sl",
max_parallel_requests=ANY,
server_url=ANY,
embedder_type=ANY,
)

# method 2: language explicitly set
embedder = DocumentEmbedder(language="es")
embedder.transform(self.corpus)
m.assert_called_with(
"mean",
model_name="fasttext-es",
max_parallel_requests=ANY,
server_url=ANY,
embedder_type=ANY,
)


if __name__ == "__main__":
unittest.main()
97 changes: 36 additions & 61 deletions orangecontrib/text/vectorization/document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,47 +10,20 @@

import numpy as np
from Orange.misc.server_embedder import ServerEmbedderCommunicator
from Orange.misc.utils.embedder_utils import EmbedderCache
from Orange.util import dummy_callback

from orangecontrib.text import Corpus
from orangecontrib.text.vectorization.base import BaseVectorizer

AGGREGATORS = ['Mean', 'Sum', 'Max', 'Min']
AGGREGATORS_L = ['mean', 'sum', 'max', 'min']
LANGS_TO_ISO = {
'English': 'en',
'Slovenian': 'sl',
'German': 'de',
'Arabic': 'ar',
'Azerbaijani': 'az',
'Bengali': 'bn',
'Chinese': 'zh',
'Danish': 'da',
'Dutch': 'nl',
'Finnish': 'fi',
'French': 'fr',
'Greek': 'el',
'Hebrew': 'he',
'Hindi': 'hi',
'Hungarian': 'hu',
'Indonesian': 'id',
'Italian': 'it',
'Japanese': 'ja',
'Kazakh': 'kk',
'Korean': 'ko',
'Nepali': 'ne',
'Norwegian (Bokm\u00e5l)': 'no',
'Norwegian (Nynorsk)': 'nn',
'Polish': 'pl',
'Portuguese': 'pt',
'Romanian': 'ro',
'Russian': 'ru',
'Spanish': 'es',
'Swedish': 'sv',
'Tajik': 'tg',
'Turkish': 'tr'
}
LANGUAGES = list(LANGS_TO_ISO.values())
AGGREGATORS = ["mean", "sum", "max", "min"]
# fmt: off
LANGUAGES = [
'en', 'sl', 'de', 'ar', 'az', 'bn', 'zh', 'da', 'nl', 'fi', 'fr', 'el',
'he', 'hi', 'hu', 'id', 'it', 'ja', 'kk', 'ko', 'ne', 'no', 'nn', 'pl',
'pt', 'ro', 'ru', 'es', 'sv', 'tg', 'tr'
]
# fmt: on


class DocumentEmbedder(BaseVectorizer):
Expand All @@ -62,10 +35,7 @@ class DocumentEmbedder(BaseVectorizer):
Evaluation, 2018.
Embedding is performed on server so the internet connection is a
prerequisite for using the class. Currently supported languages are:
- English (en)
- Slovenian (sl)
- German (de)
prerequisite for using the class.
Attributes
----------
Expand All @@ -74,25 +44,18 @@ class DocumentEmbedder(BaseVectorizer):
aggregator : str
Aggregator which creates document embedding (single
vector) from word embeddings (multiple vectors).
Allowed values are Mean, Sum, Max, Min.
Allowed values are mean, sum, max, min.
"""

def __init__(self, language: str = 'en',
aggregator: str = 'Mean') -> None:
lang_error = '{} is not a valid language. Allowed values: {}'
agg_error = '{} is not a valid aggregator. Allowed values: {}'
if language.lower() not in LANGUAGES:
raise ValueError(lang_error.format(language, ', '.join(LANGUAGES)))
self.language = language.lower()
if aggregator.lower() not in AGGREGATORS_L:
raise ValueError(agg_error.format(aggregator, ', '.join(AGGREGATORS)))
self.aggregator = aggregator.lower()

self._embedder = _ServerEmbedder(self.aggregator,
model_name='fasttext-'+self.language,
max_parallel_requests=100,
server_url='https://api.garaza.io',
embedder_type='text')
def __init__(
self, language: Optional[str] = None, aggregator: str = "mean"
) -> None:
assert (
language is None or language in LANGUAGES
), f"Language should be one of: {LANGUAGES}"
assert aggregator in AGGREGATORS, f"Aggregator should be one of: {AGGREGATORS}"
self.aggregator = aggregator
self.language = language

def _transform(
self, corpus: Corpus, _, callback=dummy_callback
Expand All @@ -111,7 +74,19 @@ def _transform(
Skipped documents
Corpus of documents that were not embedded
"""
embs = self._embedder.embedd_data(
language = self.language if self.language else corpus.language
if language not in LANGUAGES:
raise ValueError(
"The FastText embedding does not support the Corpus's language."
)
embedder = _ServerEmbedder(
self.aggregator,
model_name="fasttext-" + language,
max_parallel_requests=100,
server_url="https://api.garaza.io",
embedder_type="text",
)
embs = embedder.embedd_data(
list(corpus.ngrams) if isinstance(corpus, Corpus) else corpus,
callback=callback,
)
Expand Down Expand Up @@ -171,10 +146,10 @@ def report(self) -> Tuple[Tuple[str, str], ...]:
("Aggregator", self.aggregator),
)

def clear_cache(self):
@staticmethod
def clear_cache(language):
"""Clears embedder cache"""
if self._embedder:
self._embedder.clear_cache()
EmbedderCache(f"fasttext-{language}").clear_cache()


class _ServerEmbedder(ServerEmbedderCommunicator):
Expand Down

0 comments on commit 57b54fc

Please sign in to comment.