From 57b54fc30772dbae4393a4fba84dce4a39294649 Mon Sep 17 00:00:00 2001 From: PrimozGodec Date: Wed, 29 Mar 2023 11:50:03 +0200 Subject: [PATCH] DocumentEmbedder - support language from Corpus --- .../text/tests/test_documentembedder.py | 94 +++++++++++------- .../text/vectorization/document_embedder.py | 97 +++++++------------ 2 files changed, 94 insertions(+), 97 deletions(-) diff --git a/orangecontrib/text/tests/test_documentembedder.py b/orangecontrib/text/tests/test_documentembedder.py index 9deb1747f..e9ebe1a3f 100644 --- a/orangecontrib/text/tests/test_documentembedder.py +++ b/orangecontrib/text/tests/test_documentembedder.py @@ -1,6 +1,8 @@ import unittest -from unittest.mock import patch +from unittest.mock import patch, ANY import asyncio + +from Orange.misc.utils.embedder_utils import EmbedderCache from numpy.testing import assert_array_equal from orangecontrib.text.vectorization.document_embedder import DocumentEmbedder @@ -29,9 +31,10 @@ class DocumentEmbedderTest(unittest.TestCase): def setUp(self): self.embedder = DocumentEmbedder() # default params self.corpus = Corpus.from_file('deerwester') + self.embedder.clear_cache("en") def tearDown(self): - self.embedder.clear_cache() + self.embedder.clear_cache("en") @patch(PATCH_METHOD) def test_with_empty_corpus(self, mock): @@ -39,13 +42,13 @@ def test_with_empty_corpus(self, mock): self.assertIsNone(self.embedder.transform(self.corpus[:0])[1]) mock.request.assert_not_called() mock.get_response.assert_not_called() - self.assertEqual(self.embedder._embedder._cache._cache_dict, dict()) + self.assertEqual(EmbedderCache("fasttext-en")._cache_dict, dict()) @patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}')) def test_success_subset(self): res, skipped = self.embedder.transform(self.corpus[[0]]) assert_array_equal(res.X, [[0.3, 1]]) - self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1) + self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1) self.assertIsNone(skipped) @patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}')) @@ -62,7 +65,7 @@ def test_empty_response(self): res, skipped = self.embedder.transform(self.corpus[[0]]) self.assertIsNone(res) self.assertEqual(len(skipped), 1) - self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0) + self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0) @patch(PATCH_METHOD, make_dummy_post(b'str')) def test_invalid_response(self): @@ -70,7 +73,7 @@ def test_invalid_response(self): res, skipped = self.embedder.transform(self.corpus[[0]]) self.assertIsNone(res) self.assertEqual(len(skipped), 1) - self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0) + self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0) @patch(PATCH_METHOD, make_dummy_post(b'{"embeddings": [0.3, 1]}')) def test_invalid_json_key(self): @@ -78,53 +81,47 @@ def test_invalid_json_key(self): res, skipped = self.embedder.transform(self.corpus[[0]]) self.assertIsNone(res) self.assertEqual(len(skipped), 1) - self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0) + self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0) @patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}')) def test_persistent_caching(self): - self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0) + self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0) self.embedder.transform(self.corpus[[0]]) - self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1) - self.embedder._embedder._cache.persist_cache() + self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1) self.embedder = DocumentEmbedder() - self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 1) + self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1) - self.embedder.clear_cache() + self.embedder.clear_cache("en") self.embedder = DocumentEmbedder() - self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0) + self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0) @patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}')) - def test_cache_for_different_languages(self): - embedder = DocumentEmbedder(language='sl') - embedder.clear_cache() - self.assertEqual(len(embedder._embedder._cache._cache_dict), 0) - embedder.transform(self.corpus[[0]]) - self.assertEqual(len(embedder._embedder._cache._cache_dict), 1) - embedder._embedder._cache.persist_cache() + def test_different_languages(self): + self.corpus.attributes["language"] = "sl" - self.embedder = DocumentEmbedder() - self.assertEqual(len(self.embedder._embedder._cache._cache_dict), 0) - self.embedder._embedder._cache.persist_cache() - - embedder = DocumentEmbedder(language='sl') - self.assertEqual(len(embedder._embedder._cache._cache_dict), 1) - embedder.clear_cache() - self.embedder.clear_cache() + embedder = DocumentEmbedder() + embedder.clear_cache("sl") + self.assertEqual(len(EmbedderCache("fasttext-sl")._cache_dict), 0) + embedder.transform(self.corpus[[0]]) + self.assertEqual(len(EmbedderCache("fasttext-sl")._cache_dict), 1) + self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0) + self.assertEqual(len(EmbedderCache("fasttext-sl")._cache_dict), 1) + embedder.clear_cache("sl") @patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [0.3, 1]}')) def test_cache_for_different_aggregators(self): embedder = DocumentEmbedder(aggregator='max') - embedder.clear_cache() - self.assertEqual(len(embedder._embedder._cache._cache_dict), 0) + embedder.clear_cache("en") + + self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 0) embedder.transform(self.corpus[[0]]) - self.assertEqual(len(embedder._embedder._cache._cache_dict), 1) - embedder._embedder._cache.persist_cache() + self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1) embedder = DocumentEmbedder(aggregator='min') - self.assertEqual(len(embedder._embedder._cache._cache_dict), 1) + self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 1) embedder.transform(self.corpus[[0]]) - self.assertEqual(len(embedder._embedder._cache._cache_dict), 2) + self.assertEqual(len(EmbedderCache("fasttext-en")._cache_dict), 2) @patch(PATCH_METHOD, side_effect=OSError) def test_connection_error(self, _): @@ -133,11 +130,36 @@ def test_connection_error(self, _): embedder.transform(self.corpus[[0]]) def test_invalid_parameters(self): - with self.assertRaises(ValueError): + with self.assertRaises(AssertionError): self.embedder = DocumentEmbedder(language='eng') - with self.assertRaises(ValueError): + with self.assertRaises(AssertionError): self.embedder = DocumentEmbedder(aggregator='average') + @patch("orangecontrib.text.vectorization.document_embedder._ServerEmbedder") + def test_set_language(self, m): + # method 1: language from corpus + self.corpus.attributes["language"] = "sl" + embedder = DocumentEmbedder() + embedder.transform(self.corpus) + m.assert_called_with( + "mean", + model_name="fasttext-sl", + max_parallel_requests=ANY, + server_url=ANY, + embedder_type=ANY, + ) + + # method 2: language explicitly set + embedder = DocumentEmbedder(language="es") + embedder.transform(self.corpus) + m.assert_called_with( + "mean", + model_name="fasttext-es", + max_parallel_requests=ANY, + server_url=ANY, + embedder_type=ANY, + ) + if __name__ == "__main__": unittest.main() diff --git a/orangecontrib/text/vectorization/document_embedder.py b/orangecontrib/text/vectorization/document_embedder.py index 10adc89bf..dace5f620 100644 --- a/orangecontrib/text/vectorization/document_embedder.py +++ b/orangecontrib/text/vectorization/document_embedder.py @@ -10,47 +10,20 @@ import numpy as np from Orange.misc.server_embedder import ServerEmbedderCommunicator +from Orange.misc.utils.embedder_utils import EmbedderCache from Orange.util import dummy_callback from orangecontrib.text import Corpus from orangecontrib.text.vectorization.base import BaseVectorizer -AGGREGATORS = ['Mean', 'Sum', 'Max', 'Min'] -AGGREGATORS_L = ['mean', 'sum', 'max', 'min'] -LANGS_TO_ISO = { - 'English': 'en', - 'Slovenian': 'sl', - 'German': 'de', - 'Arabic': 'ar', - 'Azerbaijani': 'az', - 'Bengali': 'bn', - 'Chinese': 'zh', - 'Danish': 'da', - 'Dutch': 'nl', - 'Finnish': 'fi', - 'French': 'fr', - 'Greek': 'el', - 'Hebrew': 'he', - 'Hindi': 'hi', - 'Hungarian': 'hu', - 'Indonesian': 'id', - 'Italian': 'it', - 'Japanese': 'ja', - 'Kazakh': 'kk', - 'Korean': 'ko', - 'Nepali': 'ne', - 'Norwegian (Bokm\u00e5l)': 'no', - 'Norwegian (Nynorsk)': 'nn', - 'Polish': 'pl', - 'Portuguese': 'pt', - 'Romanian': 'ro', - 'Russian': 'ru', - 'Spanish': 'es', - 'Swedish': 'sv', - 'Tajik': 'tg', - 'Turkish': 'tr' -} -LANGUAGES = list(LANGS_TO_ISO.values()) +AGGREGATORS = ["mean", "sum", "max", "min"] +# fmt: off +LANGUAGES = [ + 'en', 'sl', 'de', 'ar', 'az', 'bn', 'zh', 'da', 'nl', 'fi', 'fr', 'el', + 'he', 'hi', 'hu', 'id', 'it', 'ja', 'kk', 'ko', 'ne', 'no', 'nn', 'pl', + 'pt', 'ro', 'ru', 'es', 'sv', 'tg', 'tr' +] +# fmt: on class DocumentEmbedder(BaseVectorizer): @@ -62,10 +35,7 @@ class DocumentEmbedder(BaseVectorizer): Evaluation, 2018. Embedding is performed on server so the internet connection is a - prerequisite for using the class. Currently supported languages are: - - English (en) - - Slovenian (sl) - - German (de) + prerequisite for using the class. Attributes ---------- @@ -74,25 +44,18 @@ class DocumentEmbedder(BaseVectorizer): aggregator : str Aggregator which creates document embedding (single vector) from word embeddings (multiple vectors). - Allowed values are Mean, Sum, Max, Min. + Allowed values are mean, sum, max, min. """ - def __init__(self, language: str = 'en', - aggregator: str = 'Mean') -> None: - lang_error = '{} is not a valid language. Allowed values: {}' - agg_error = '{} is not a valid aggregator. Allowed values: {}' - if language.lower() not in LANGUAGES: - raise ValueError(lang_error.format(language, ', '.join(LANGUAGES))) - self.language = language.lower() - if aggregator.lower() not in AGGREGATORS_L: - raise ValueError(agg_error.format(aggregator, ', '.join(AGGREGATORS))) - self.aggregator = aggregator.lower() - - self._embedder = _ServerEmbedder(self.aggregator, - model_name='fasttext-'+self.language, - max_parallel_requests=100, - server_url='https://api.garaza.io', - embedder_type='text') + def __init__( + self, language: Optional[str] = None, aggregator: str = "mean" + ) -> None: + assert ( + language is None or language in LANGUAGES + ), f"Language should be one of: {LANGUAGES}" + assert aggregator in AGGREGATORS, f"Aggregator should be one of: {AGGREGATORS}" + self.aggregator = aggregator + self.language = language def _transform( self, corpus: Corpus, _, callback=dummy_callback @@ -111,7 +74,19 @@ def _transform( Skipped documents Corpus of documents that were not embedded """ - embs = self._embedder.embedd_data( + language = self.language if self.language else corpus.language + if language not in LANGUAGES: + raise ValueError( + "The FastText embedding does not support the Corpus's language." + ) + embedder = _ServerEmbedder( + self.aggregator, + model_name="fasttext-" + language, + max_parallel_requests=100, + server_url="https://api.garaza.io", + embedder_type="text", + ) + embs = embedder.embedd_data( list(corpus.ngrams) if isinstance(corpus, Corpus) else corpus, callback=callback, ) @@ -171,10 +146,10 @@ def report(self) -> Tuple[Tuple[str, str], ...]: ("Aggregator", self.aggregator), ) - def clear_cache(self): + @staticmethod + def clear_cache(language): """Clears embedder cache""" - if self._embedder: - self._embedder.clear_cache() + EmbedderCache(f"fasttext-{language}").clear_cache() class _ServerEmbedder(ServerEmbedderCommunicator):