From 57b54fc30772dbae4393a4fba84dce4a39294649 Mon Sep 17 00:00:00 2001
From: PrimozGodec
Date: Wed, 29 Mar 2023 11:50:03 +0200
Subject: [PATCH] DocumentEmbedder - support language from Corpus
---
.../text/tests/test_documentembedder.py | 94 +++++++++++-------
.../text/vectorization/document_embedder.py | 97 +++++++------------
2 files changed, 94 insertions(+), 97 deletions(-)
diff --git a/orangecontrib/text/tests/test_documentembedder.py b/orangecontrib/text/tests/test_documentembedder.py
index 9deb1747f..e9ebe1a3f 100644
--- a/orangecontrib/text/tests/test_documentembedder.py
+++ b/orangecontrib/text/tests/test_documentembedder.py
@@ -1,6 +1,8 @@
import unittest
-from unittest.mock import patch
+from unittest.mock import patch, ANY
import asyncio
+
+from Orange.misc.utils.embedder_utils import EmbedderCache
from numpy.testing import assert_array_equal
from orangecontrib.text.vectorization.document_embedder import DocumentEmbedder
@@ -29,9 +31,10 @@ class DocumentEmbedderTest(unittest.TestCase):
def setUp(self):
self.embedder = DocumentEmbedder() # default params
self.corpus = Corpus.from_file('deerwester')
+ self.embedder.clear_cache("en")
def tearDown(self):
- self.embedder.clear_cache()
+ self.embedder.clear_cache("en")
@patch(PATCH_METHOD)
def test_with_empty_corpus(self, mock):
@@ -39,13 +42,13 @@ def test_with_empty_corpus(self, mock):
self.assertIsNone(self.embedder.transform(self.corpus[:0])[1])
mock.request.assert_not_called()
mock.get_response.assert_not_called()
- self.assertEqual(self.embedder._embedder._cache._cache_dict, dict())
+ self.assertEqual(EmbedderCache("fasttext-en")._cache_dict, dict())
@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
def test_success_subset(self):
res, skipped = self.embedder.transform(self.corpus[[0]])
assert_array_equal(res.X, [[0.3, 1]])
- self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1)
+ self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1)
self.assertIsNone(skipped)
@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
@@ -62,7 +65,7 @@ def test_empty_response(self):
res, skipped = self.embedder.transform(self.corpus[[0]])
self.assertIsNone(res)
self.assertEqual(len(skipped), 1)
- self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
+ self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)
@patch(PATCH_METHOD, make_dummy_post(b'str'))
def test_invalid_response(self):
@@ -70,7 +73,7 @@ def test_invalid_response(self):
res, skipped = self.embedder.transform(self.corpus[[0]])
self.assertIsNone(res)
self.assertEqual(len(skipped), 1)
- self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
+ self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)
@patch(PATCH_METHOD, make_dummy_post(b'{"embeddings": [0.3, 1]}'))
def test_invalid_json_key(self):
@@ -78,53 +81,47 @@ def test_invalid_json_key(self):
res, skipped = self.embedder.transform(self.corpus[[0]])
self.assertIsNone(res)
self.assertEqual(len(skipped), 1)
- self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
+ self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)
@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
def test_persistent_caching(self):
- self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
+ self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)
self.embedder.transform(self.corpus[[0]])
- self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1)
- self.embedder._embedder._cache.persist_cache()
+ self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1)
self.embedder = DocumentEmbedder()
- self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1)
+ self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1)
- self.embedder.clear_cache()
+ self.embedder.clear_cache("en")
self.embedder = DocumentEmbedder()
- self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
+ self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)
@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
- def test_cache_for_different_languages(self):
- embedder = DocumentEmbedder(language='sl')
- embedder.clear_cache()
- self.assertEqual(len(embedder._embedder._cache._cache_dict), 0)
- embedder.transform(self.corpus[[0]])
- self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
- embedder._embedder._cache.persist_cache()
+ def test_different_languages(self):
+ self.corpus.attributes["language"] = "sl"
- self.embedder = DocumentEmbedder()
- self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0)
- self.embedder._embedder._cache.persist_cache()
-
- embedder = DocumentEmbedder(language='sl')
- self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
- embedder.clear_cache()
- self.embedder.clear_cache()
+ embedder = DocumentEmbedder()
+ embedder.clear_cache("sl")
+ self.assertEqual(len(EmbedderCache("fasttext-sl")._cache_dict), 0)
+ embedder.transform(self.corpus[[0]])
+ self.assertEqual(len(EmbedderCache("fasttext-sl")._cache_dict), 1)
+ self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)
+ self.assertEqual(len(EmbedderCache("fasttext-sl")._cache_dict), 1)
+ embedder.clear_cache("sl")
@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}'))
def test_cache_for_different_aggregators(self):
embedder = DocumentEmbedder(aggregator='max')
- embedder.clear_cache()
- self.assertEqual(len(embedder._embedder._cache._cache_dict), 0)
+ embedder.clear_cache("en")
+
+ self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0)
embedder.transform(self.corpus[[0]])
- self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
- embedder._embedder._cache.persist_cache()
+ self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1)
embedder = DocumentEmbedder(aggregator='min')
- self.assertEqual(len(embedder._embedder._cache._cache_dict), 1)
+ self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1)
embedder.transform(self.corpus[[0]])
- self.assertEqual(len(embedder._embedder._cache._cache_dict), 2)
+ self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 2)
@patch(PATCH_METHOD, side_effect=OSError)
def test_connection_error(self, _):
@@ -133,11 +130,36 @@ def test_connection_error(self, _):
embedder.transform(self.corpus[[0]])
def test_invalid_parameters(self):
- with self.assertRaises(ValueError):
+ with self.assertRaises(AssertionError):
self.embedder = DocumentEmbedder(language='eng')
- with self.assertRaises(ValueError):
+ with self.assertRaises(AssertionError):
self.embedder = DocumentEmbedder(aggregator='average')
+ @patch("orangecontrib.text.vectorization.document_embedder._ServerEmbedder")
+ def test_set_language(self, m):
+ # method 1: language from corpus
+ self.corpus.attributes["language"] = "sl"
+ embedder = DocumentEmbedder()
+ embedder.transform(self.corpus)
+ m.assert_called_with(
+ "mean",
+ model_name="fasttext-sl",
+ max_parallel_requests=ANY,
+ server_url=ANY,
+ embedder_type=ANY,
+ )
+
+ # method 2: language explicitly set
+ embedder = DocumentEmbedder(language="es")
+ embedder.transform(self.corpus)
+ m.assert_called_with(
+ "mean",
+ model_name="fasttext-es",
+ max_parallel_requests=ANY,
+ server_url=ANY,
+ embedder_type=ANY,
+ )
+
if __name__ == "__main__":
unittest.main()
diff --git a/orangecontrib/text/vectorization/document_embedder.py b/orangecontrib/text/vectorization/document_embedder.py
index 10adc89bf..dace5f620 100644
--- a/orangecontrib/text/vectorization/document_embedder.py
+++ b/orangecontrib/text/vectorization/document_embedder.py
@@ -10,47 +10,20 @@
import numpy as np
from Orange.misc.server_embedder import ServerEmbedderCommunicator
+from Orange.misc.utils.embedder_utils import EmbedderCache
from Orange.util import dummy_callback
from orangecontrib.text import Corpus
from orangecontrib.text.vectorization.base import BaseVectorizer
-AGGREGATORS = ['Mean', 'Sum', 'Max', 'Min']
-AGGREGATORS_L = ['mean', 'sum', 'max', 'min']
-LANGS_TO_ISO = {
- 'English': 'en',
- 'Slovenian': 'sl',
- 'German': 'de',
- 'Arabic': 'ar',
- 'Azerbaijani': 'az',
- 'Bengali': 'bn',
- 'Chinese': 'zh',
- 'Danish': 'da',
- 'Dutch': 'nl',
- 'Finnish': 'fi',
- 'French': 'fr',
- 'Greek': 'el',
- 'Hebrew': 'he',
- 'Hindi': 'hi',
- 'Hungarian': 'hu',
- 'Indonesian': 'id',
- 'Italian': 'it',
- 'Japanese': 'ja',
- 'Kazakh': 'kk',
- 'Korean': 'ko',
- 'Nepali': 'ne',
- 'Norwegian (Bokm\u00e5l)': 'no',
- 'Norwegian (Nynorsk)': 'nn',
- 'Polish': 'pl',
- 'Portuguese': 'pt',
- 'Romanian': 'ro',
- 'Russian': 'ru',
- 'Spanish': 'es',
- 'Swedish': 'sv',
- 'Tajik': 'tg',
- 'Turkish': 'tr'
-}
-LANGUAGES = list(LANGS_TO_ISO.values())
+AGGREGATORS = ["mean", "sum", "max", "min"]
+# fmt: off
+LANGUAGES = [
+ 'en', 'sl', 'de', 'ar', 'az', 'bn', 'zh', 'da', 'nl', 'fi', 'fr', 'el',
+ 'he', 'hi', 'hu', 'id', 'it', 'ja', 'kk', 'ko', 'ne', 'no', 'nn', 'pl',
+ 'pt', 'ro', 'ru', 'es', 'sv', 'tg', 'tr'
+]
+# fmt: on
class DocumentEmbedder(BaseVectorizer):
@@ -62,10 +35,7 @@ class DocumentEmbedder(BaseVectorizer):
Evaluation, 2018.
Embedding is performed on server so the internet connection is a
- prerequisite for using the class. Currently supported languages are:
- - English (en)
- - Slovenian (sl)
- - German (de)
+ prerequisite for using the class.
Attributes
----------
@@ -74,25 +44,18 @@ class DocumentEmbedder(BaseVectorizer):
aggregator : str
Aggregator which creates document embedding (single
vector) from word embeddings (multiple vectors).
- Allowed values are Mean, Sum, Max, Min.
+ Allowed values are mean, sum, max, min.
"""
- def __init__(self, language: str = 'en',
- aggregator: str = 'Mean') -> None:
- lang_error = '{} is not a valid language. Allowed values: {}'
- agg_error = '{} is not a valid aggregator. Allowed values: {}'
- if language.lower() not in LANGUAGES:
- raise ValueError(lang_error.format(language, ', '.join(LANGUAGES)))
- self.language = language.lower()
- if aggregator.lower() not in AGGREGATORS_L:
- raise ValueError(agg_error.format(aggregator, ', '.join(AGGREGATORS)))
- self.aggregator = aggregator.lower()
-
- self._embedder = _ServerEmbedder(self.aggregator,
- model_name='fasttext-'+self.language,
- max_parallel_requests=100,
- server_url='https://api.garaza.io',
- embedder_type='text')
+ def __init__(
+ self, language: Optional[str] = None, aggregator: str = "mean"
+ ) -> None:
+ if language is not None and language not in LANGUAGES:  # kept under -O
+ raise AssertionError(f"Language should be one of: {LANGUAGES}")
+ if aggregator not in AGGREGATORS:  # explicit raise; assert is stripped
+ raise AssertionError(f"Aggregator should be one of: {AGGREGATORS}")
+ self.aggregator = aggregator
+ self.language = language
def _transform(
self, corpus: Corpus, _, callback=dummy_callback
@@ -111,7 +74,19 @@ def _transform(
Skipped documents
Corpus of documents that were not embedded
"""
- embs = self._embedder.embedd_data(
+ language = self.language or getattr(corpus, "language", None)  # may be a list
+ if language not in LANGUAGES:
+ raise ValueError(
+ "The FastText embedding does not support the Corpus's language."
+ )
+ embedder = _ServerEmbedder(
+ self.aggregator,
+ model_name="fasttext-" + language,
+ max_parallel_requests=100,
+ server_url="https://api.garaza.io",
+ embedder_type="text",
+ )
+ embs = embedder.embedd_data(
list(corpus.ngrams) if isinstance(corpus, Corpus) else corpus,
callback=callback,
)
@@ -171,10 +146,10 @@ def report(self) -> Tuple[Tuple[str, str], ...]:
("Aggregator", self.aggregator),
)
- def clear_cache(self):
+ @staticmethod
+ def clear_cache(language):
"""Clears embedder cache"""
- if self._embedder:
- self._embedder.clear_cache()
+ EmbedderCache(f"fasttext-{language}").clear_cache()
class _ServerEmbedder(ServerEmbedderCommunicator):