From ae7cf75f9309dd41ef8e646d9802e5ca71e32c9a Mon Sep 17 00:00:00 2001 From: djukicn Date: Wed, 27 Apr 2022 13:17:23 +0200 Subject: [PATCH 1/4] Document Embedding: add SBERT --- orangecontrib/text/vectorization/sbert.py | 52 ++++++++++++++++++- .../text/widgets/owdocumentembedding.py | 3 +- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/orangecontrib/text/vectorization/sbert.py b/orangecontrib/text/vectorization/sbert.py index b2367e18a..9a33f1f6e 100644 --- a/orangecontrib/text/vectorization/sbert.py +++ b/orangecontrib/text/vectorization/sbert.py @@ -2,12 +2,13 @@ import base64 import zlib import sys -from typing import Any, List, Optional, Callable +from typing import Any, List, Optional, Callable, Tuple, Union import numpy as np from Orange.misc.server_embedder import ServerEmbedderCommunicator from Orange.util import dummy_callback +from orangecontrib.text import Corpus # maximum document size that we still send to the server MAX_PACKAGE_SIZE = 3000000 @@ -100,6 +101,55 @@ def _make_chunks(self, encoded_texts, sizes, depth=0): result.append(chunks[i]) return [list(r) for r in result if len(r) > 0] + def embed_and_add_to_corpus( + self, + corpus: Corpus, + callback: Optional[Callable] = dummy_callback + ) -> Union[Tuple[Corpus, Corpus], List[Optional[List[float]]]]: + """Used for Document Embedding widget. See DocumentEmbedder for details.""" + + embs = self(corpus.documents, callback) + + dim = None + for emb in embs: # find embedding dimension + if emb is not None: + dim = len(emb) + break + # Check if some documents in corpus in weren't embedded + # for some reason. This is a very rare case. + skipped_documents = [emb is None for emb in embs] + embedded_documents = np.logical_not(skipped_documents) + + variable_attrs = { + 'hidden': True, + 'skip-normalization': True, + 'embedding-feature': True + } + + new_corpus = None + if np.any(embedded_documents): + # if at least one embedding is not None, extend attributes + new_corpus = corpus[embedded_documents] + new_corpus = new_corpus.extend_attributes( + np.array( + [e for e in embs if e], + dtype=float, + ), + ['Dim{}'.format(i + 1) for i in range(dim)], + var_attrs=variable_attrs + ) + + skipped_corpus = None + if np.any(skipped_documents): + skipped_corpus = corpus[skipped_documents].copy() + skipped_corpus.name = "Skipped documents" + warnings.warn(("Some documents were not embedded for " + + "unknown reason. Those documents " + + "are skipped."), + RuntimeWarning) + + return new_corpus, skipped_corpus + def clear_cache(self): if self._server_communicator: self._server_communicator.clear_cache() diff --git a/orangecontrib/text/widgets/owdocumentembedding.py b/orangecontrib/text/widgets/owdocumentembedding.py index 8bae94e30..1a8524d6e 100644 --- a/orangecontrib/text/widgets/owdocumentembedding.py +++ b/orangecontrib/text/widgets/owdocumentembedding.py @@ -30,7 +30,6 @@ def _transform(self, callback): self.new_corpus = embeddings self.skipped_documents = skipped - class OWDocumentEmbedding(OWBaseVectorizer): name = "Document Embedding" description = "Document embedding using pretrained models." @@ -56,6 +55,7 @@ class Error(OWWidget.Error): class Warning(OWWidget.Warning): unsuccessful_embeddings = Msg("Some embeddings were unsuccessful.") + method = Setting(default=0) language = Setting(default="English") aggregator = Setting(default="Mean") @@ -133,6 +133,7 @@ def migrate_settings(cls, settings: Dict[str, Any], version: Optional[int]): settings["aggregator"] = AGGREGATORS[settings["aggregator"]] + if __name__ == "__main__": from orangewidget.utils.widgetpreview import WidgetPreview From 181f9bc8178054f71fde08dd636f4fa3c750a6ed Mon Sep 17 00:00:00 2001 From: Primoz Godec Date: Wed, 29 Jun 2022 15:01:35 +0200 Subject: [PATCH 2/4] document_embedder - add method to report --- orangecontrib/text/vectorization/document_embedder.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/orangecontrib/text/vectorization/document_embedder.py b/orangecontrib/text/vectorization/document_embedder.py index 2f3fdd604..10adc89bf 100644 --- a/orangecontrib/text/vectorization/document_embedder.py +++ b/orangecontrib/text/vectorization/document_embedder.py @@ -157,7 +157,7 @@ def _transform( return new_corpus, skipped_corpus - def report(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: + def report(self) -> Tuple[Tuple[str, str], ...]: """Reports on current parameters of DocumentEmbedder. Returns @@ -165,8 +165,11 @@ def report(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: tuple Tuple of parameters. """ - return (('Language', self.language), - ('Aggregator', self.aggregator)) + return ( + ("Embedder", "fastText"), + ("Language", self.language), + ("Aggregator", self.aggregator), + ) def clear_cache(self): """Clears embedder cache""" From c16e7adddf6b4077cebc07043a8c7e6445742bc2 Mon Sep 17 00:00:00 2001 From: Primoz Godec Date: Wed, 29 Jun 2022 15:04:41 +0200 Subject: [PATCH 3/4] sbert - depend on base vectorizer, updates --- orangecontrib/text/tests/test_sbert.py | 67 +++------ orangecontrib/text/vectorization/sbert.py | 162 +++++++++------------- 2 files changed, 88 insertions(+), 141 deletions(-) diff --git a/orangecontrib/text/tests/test_sbert.py b/orangecontrib/text/tests/test_sbert.py index 2e677ff9b..c32c8ec7a 100644 --- a/orangecontrib/text/tests/test_sbert.py +++ b/orangecontrib/text/tests/test_sbert.py @@ -3,12 +3,7 @@ from collections.abc import Iterator import asyncio -from orangecontrib.text.vectorization.sbert import ( - SBERT, - MIN_CHUNKS, - MAX_PACKAGE_SIZE, - EMB_DIM -) +from orangecontrib.text.vectorization.sbert import SBERT, EMB_DIM from orangecontrib.text import Corpus PATCH_METHOD = 'httpx.AsyncClient.post' @@ -37,47 +32,17 @@ async def dummy_post(url, headers, data): class TestSBERT(unittest.TestCase): - def setUp(self): self.sbert = SBERT() + self.sbert.clear_cache() self.corpus = Corpus.from_file('deerwester') def tearDown(self): self.sbert.clear_cache() - def test_make_chunks_small(self): - chunks = self.sbert._make_chunks( - self.corpus.documents, [100] * len(self.corpus.documents) - ) - self.assertEqual(len(chunks), min(len(self.corpus.documents), MIN_CHUNKS)) - - def test_make_chunks_medium(self): - num_docs = len(self.corpus.documents) - documents = self.corpus.documents - if num_docs < MIN_CHUNKS: - documents = [documents[0]] * MIN_CHUNKS - chunks = self.sbert._make_chunks( - documents, [MAX_PACKAGE_SIZE / MIN_CHUNKS - 1] * len(documents) - ) - self.assertEqual(len(chunks), MIN_CHUNKS) - - def test_make_chunks_large(self): - num_docs = len(self.corpus.documents) - documents = self.corpus.documents - if num_docs < MIN_CHUNKS: - documents = [documents[0]] * MIN_CHUNKS * 100 - mps = MAX_PACKAGE_SIZE - chunks = self.sbert._make_chunks( - documents, - [mps / 100] * (len(documents) - 2) + [0.3 * mps, 0.9 * mps, mps] - ) - self.assertGreater(len(chunks), MIN_CHUNKS) - @patch(PATCH_METHOD) def test_empty_corpus(self, mock): - self.assertEqual( - len(self.sbert(self.corpus.documents[:0])), 0 - ) + self.assertEqual(len(self.sbert(self.corpus.documents[:0])), 0) mock.request.assert_not_called() mock.get_response.assert_not_called() self.assertEqual( @@ -95,14 +60,24 @@ def test_none_result(self): result = self.sbert(self.corpus.documents) self.assertEqual(result, IDEAL_RESPONSE[:-1] + [None]) - @patch(PATCH_METHOD, make_dummy_post(RESPONSE[0])) - def test_success_chunks(self): - num_docs = len(self.corpus.documents) - documents = self.corpus.documents - if num_docs < MIN_CHUNKS: - documents = [documents[0]] * MIN_CHUNKS - result = self.sbert(documents) - self.assertEqual(len(result), MIN_CHUNKS) + @patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE))) + def test_transform(self): + res, skipped = self.sbert.transform(self.corpus) + self.assertIsNone(skipped) + self.assertEqual(len(self.corpus), len(res)) + self.assertTupleEqual(self.corpus.domain.metas, res.domain.metas) + self.assertEqual(384, len(res.domain.attributes)) + + @patch(PATCH_METHOD, make_dummy_post(iter(RESPONSE[:-1] + [None] * 3))) + def test_transform_skipped(self): + res, skipped = self.sbert.transform(self.corpus) + self.assertEqual(len(self.corpus) - 1, len(res)) + self.assertTupleEqual(self.corpus.domain.metas, res.domain.metas) + self.assertEqual(384, len(res.domain.attributes)) + + self.assertEqual(1, len(skipped)) + self.assertTupleEqual(self.corpus.domain.metas, skipped.domain.metas) + self.assertEqual(0, len(skipped.domain.attributes)) if __name__ == "__main__": diff --git a/orangecontrib/text/vectorization/sbert.py b/orangecontrib/text/vectorization/sbert.py index 9a33f1f6e..7019a2217 100644 --- a/orangecontrib/text/vectorization/sbert.py +++ b/orangecontrib/text/vectorization/sbert.py @@ -1,31 +1,29 @@ import json import base64 +import warnings import zlib import sys -from typing import Any, List, Optional, Callable, Tuple, Union +from typing import Any, List, Optional, Callable, Tuple import numpy as np - from Orange.misc.server_embedder import ServerEmbedderCommunicator from Orange.util import dummy_callback + from orangecontrib.text import Corpus +from orangecontrib.text.vectorization.base import BaseVectorizer # maximum document size that we still send to the server MAX_PACKAGE_SIZE = 3000000 -# maximum size of a chunk - when one document is longer send is as a chunk with -# a single document -MAX_CHUNK_SIZE = 50000 -MIN_CHUNKS = 20 EMB_DIM = 384 -class SBERT: +class SBERT(BaseVectorizer): def __init__(self) -> None: self._server_communicator = _ServerCommunicator( - model_name='sbert', + model_name="sbert", max_parallel_requests=100, - server_url='https://api.garaza.io', - embedder_type='text', + server_url="https://api.garaza.io", + embedder_type="text", ) def __call__( @@ -42,90 +40,47 @@ def __call__( ------- An array of embeddings. """ - if len(texts) == 0: return [] + # sort text by their lengths that longer texts start to embed first. It + # prevents that long text with long embedding times start embedding + # at the end and thus add extra time to the complete embedding time + sorted_texts = sorted( + enumerate(texts), + key=lambda x: len(x[1][0]) if x[1] is not None else 0, + reverse=True, + ) + indices, sorted_texts = zip(*sorted_texts) + # embedd - send to server + results = self._server_communicator.embedd_data(sorted_texts, callback=callback) + # unsort and unpack + return [x[0] if x else None for _, x in sorted(zip(indices, results))] + + def _transform( + self, corpus: Corpus, _, callback=dummy_callback + ) -> Tuple[Corpus, Optional[Corpus]]: + """ + Computes embeddings for given corpus and append results to the corpus - skipped = list() - - encoded_texts = list() - sizes = list() - chunks = list() - for i, text in enumerate(texts): - encoded = base64.b64encode(zlib.compress( - text.encode('utf-8', 'replace'), level=-1) - ).decode('utf-8', 'replace') - size = sys.getsizeof(encoded) - if size > MAX_PACKAGE_SIZE: - skipped.append(i) - continue - encoded_texts.append(encoded) - sizes.append(size) - - chunks = self._make_chunks(encoded_texts, sizes) - - result_ = self._server_communicator.embedd_data(chunks, callback=callback) - if result_ is None: - return [None] * len(texts) - - result = list() - assert len(result_) == len(chunks) - for res_chunk, orig_chunk in zip(result_, chunks): - if res_chunk is None: - # when embedder fails (Timeout or other error) result will be None - result.extend([None] * len(orig_chunk)) - else: - result.extend(res_chunk) - - results = list() - idx = 0 - for i in range(len(texts)): - if i in skipped: - results.append(None) - else: - results.append(result[idx]) - idx += 1 - - return results - - def _make_chunks(self, encoded_texts, sizes, depth=0): - chunks = np.array_split(encoded_texts, MIN_CHUNKS if depth == 0 else 2) - chunk_sizes = np.array_split(sizes, MIN_CHUNKS if depth == 0 else 2) - result = list() - for i in range(len(chunks)): - # checking that more than one text in chunk prevent recursion to infinity - # when one text is bigger than MAX_CHUNK_SIZE - if len(chunks[i]) > 1 and np.sum(chunk_sizes[i]) > MAX_CHUNK_SIZE: - result.extend(self._make_chunks(chunks[i], chunk_sizes[i], depth + 1)) - else: - result.append(chunks[i]) - return [list(r) for r in result if len(r) > 0] - - def embed_and_add_to_corpus( - self, - corpus: Corpus, - callback: Optional[Callable] = dummy_callback - ) -> Union[Tuple[Corpus, Corpus], List[Optional[List[float]]]]: - """Used for Document Embedding widget. See DocumentEmbedder for details.""" + Parameters + ---------- + corpus + Corpus on which transform is performed. + Returns + ------- + Embeddings + Corpus with new features added. + Skipped documents + Corpus of documents that were not embedded + """ embs = self(corpus.documents, callback) - dim = None - for emb in embs: # find embedding dimension - if emb is not None: - dim = len(emb) - break # Check if some documents in corpus in weren't embedded # for some reason. This is a very rare case. skipped_documents = [emb is None for emb in embs] embedded_documents = np.logical_not(skipped_documents) - variable_attrs = { - 'hidden': True, - 'skip-normalization': True, - 'embedding-feature': True - } - new_corpus = None if np.any(embedded_documents): # if at least one embedding is not None, extend attributes @@ -135,34 +90,51 @@ def embed_and_add_to_corpus( [e for e in embs if e], dtype=float, ), - ['Dim{}'.format(i + 1) for i in range(dim)], - var_attrs=variable_attrs + ["Dim{}".format(i + 1) for i in range(EMB_DIM)], + var_attrs={ + "embedding-feature": True, + "hidden": True, + }, ) skipped_corpus = None if np.any(skipped_documents): skipped_corpus = corpus[skipped_documents].copy() skipped_corpus.name = "Skipped documents" - warnings.warn(("Some documents were not embedded for " + - "unknown reason. Those documents " + - "are skipped."), - RuntimeWarning) + warnings.warn( + "Some documents were not embedded for unknown reason. Those " + "documents are skipped.", + RuntimeWarning, + ) return new_corpus, skipped_corpus + def report(self) -> Tuple[Tuple[str, str], ...]: + """Reports on current parameters of DocumentEmbedder. + + Returns + ------- + tuple + Tuple of parameters. + """ + return (("Embedder", "Multilingual SBERT"),) + def clear_cache(self): if self._server_communicator: self._server_communicator.clear_cache() - def __enter__(self): - return self - class _ServerCommunicator(ServerEmbedderCommunicator): - def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - self.content_type = 'application/json' + self.content_type = "application/json" async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]: - return json.dumps(data_instance).encode('utf-8', 'replace') + data = base64.b64encode( + zlib.compress(data_instance.encode("utf-8", "replace"), level=-1) + ).decode("utf-8", "replace") + if sys.getsizeof(data) > 500000: + # Document in corpus is too large. Size limit is 500 KB + # (after compression). - document skipped + return None + return json.dumps([data]).encode("utf-8", "replace") From 170cd4c3fc882149c95b8829fe7aad1007c38122 Mon Sep 17 00:00:00 2001 From: Primoz Godec Date: Wed, 29 Jun 2022 15:05:18 +0200 Subject: [PATCH 4/4] Document Embedding - add SBERT method to widget --- .../text/widgets/owdocumentembedding.py | 59 +++++++++++-------- .../widgets/tests/test_owdocumentembedding.py | 18 ++++++ .../text/widgets/utils/owbasevectorizer.py | 2 +- 3 files changed, 54 insertions(+), 25 deletions(-) diff --git a/orangecontrib/text/widgets/owdocumentembedding.py b/orangecontrib/text/widgets/owdocumentembedding.py index 1a8524d6e..21f1cf381 100644 --- a/orangecontrib/text/widgets/owdocumentembedding.py +++ b/orangecontrib/text/widgets/owdocumentembedding.py @@ -1,7 +1,7 @@ from typing import Dict, Optional, Any from AnyQt.QtCore import Qt -from AnyQt.QtWidgets import QGridLayout, QLabel, QPushButton, QStyle +from AnyQt.QtWidgets import QVBoxLayout, QPushButton, QStyle from Orange.misc.utils.embedder_utils import EmbeddingConnectionError from Orange.widgets import gui from Orange.widgets.settings import Setting @@ -13,7 +13,7 @@ LANGS_TO_ISO, DocumentEmbedder, ) -from orangecontrib.text.widgets.utils import widgets +from orangecontrib.text.vectorization.sbert import SBERT from orangecontrib.text.widgets.utils.owbasevectorizer import ( OWBaseVectorizer, Vectorizer, @@ -30,6 +30,7 @@ def _transform(self, callback): self.new_corpus = embeddings self.skipped_documents = skipped + class OWDocumentEmbedding(OWBaseVectorizer): name = "Document Embedding" description = "Document embedding using pretrained models." @@ -40,7 +41,7 @@ class OWDocumentEmbedding(OWBaseVectorizer): buttons_area_orientation = Qt.Vertical settings_version = 2 - Method = DocumentEmbedder + Methods = [DocumentEmbedder, SBERT] class Outputs(OWBaseVectorizer.Outputs): skipped = Output("Skipped documents", Corpus) @@ -55,9 +56,9 @@ class Error(OWWidget.Error): class Warning(OWWidget.Warning): unsuccessful_embeddings = Msg("Some embeddings were unsuccessful.") - method = Setting(default=0) - language = Setting(default="English") - aggregator = Setting(default="Mean") + method: int = Setting(default=0) + language: str = Setting(default="English") + aggregator: str = Setting(default="Mean") def __init__(self): super().__init__() @@ -69,32 +70,43 @@ def __init__(self): self.cancel_button.setDisabled(True) def create_configuration_layout(self): - layout = QGridLayout() - layout.setSpacing(10) - - combo = widgets.ComboBox( + layout = QVBoxLayout() + rbtns = gui.radioButtons(None, self, "method", callback=self.on_change) + layout.addWidget(rbtns) + + gui.appendRadioButton(rbtns, "fastText:") + ibox = gui.indentedBox(rbtns) + gui.comboBox( + ibox, self, "language", items=LANGUAGES, + label="Language:", + sendSelectedValue=True, # value is actual string not index + orientation=Qt.Horizontal, + callback=self.on_change, + ) + gui.comboBox( + ibox, + self, + "aggregator", + items=AGGREGATORS, + label="Aggregator:", + sendSelectedValue=True, # value is actual string not index + orientation=Qt.Horizontal, + callback=self.on_change, ) - combo.currentIndexChanged.connect(self.on_change) - layout.addWidget(QLabel("Language:")) - layout.addWidget(combo, 0, 1) - - combo = widgets.ComboBox(self, "aggregator", items=AGGREGATORS) - combo.currentIndexChanged.connect(self.on_change) - layout.addWidget(QLabel("Aggregator:")) - layout.addWidget(combo, 1, 1) + gui.appendRadioButton(rbtns, "Multilingual SBERT:") return layout def update_method(self): self.vectorizer = EmbeddingVectorizer(self.init_method(), self.corpus) def init_method(self): - return self.Method( - language=LANGS_TO_ISO[self.language], aggregator=self.aggregator - ) + params = dict(language=LANGS_TO_ISO[self.language], aggregator=self.aggregator) + kwargs = (params, {})[self.method] + return self.Methods[self.method](**kwargs) @gui.deferred def commit(self): @@ -103,13 +115,13 @@ def commit(self): self.cancel_button.setDisabled(False) super().commit() - def on_done(self, _): + def on_done(self, result): self.cancel_button.setDisabled(True) skipped = self.vectorizer.skipped_documents self.Outputs.skipped.send(skipped) if skipped is not None and len(skipped) > 0: self.Warning.unsuccessful_embeddings() - super().on_done(_) + super().on_done(result) def on_exception(self, ex: Exception): self.cancel_button.setDisabled(True) @@ -133,7 +145,6 @@ def migrate_settings(cls, settings: Dict[str, Any], version: Optional[int]): settings["aggregator"] = AGGREGATORS[settings["aggregator"]] - if __name__ == "__main__": from orangewidget.utils.widgetpreview import WidgetPreview diff --git a/orangecontrib/text/widgets/tests/test_owdocumentembedding.py b/orangecontrib/text/widgets/tests/test_owdocumentembedding.py index 7a0210a1c..4346d26d4 100644 --- a/orangecontrib/text/widgets/tests/test_owdocumentembedding.py +++ b/orangecontrib/text/widgets/tests/test_owdocumentembedding.py @@ -1,12 +1,15 @@ import unittest from unittest.mock import Mock, patch +import numpy as np from AnyQt.QtWidgets import QComboBox from Orange.widgets.tests.base import WidgetTest from Orange.widgets.tests.utils import simulate from Orange.misc.utils.embedder_utils import EmbeddingConnectionError +from PyQt5.QtWidgets import QRadioButton from orangecontrib.text.tests.test_documentembedder import PATCH_METHOD, make_dummy_post +from orangecontrib.text.vectorization.sbert import EMB_DIM from orangecontrib.text.widgets.owdocumentembedding import OWDocumentEmbedding from orangecontrib.text import Corpus @@ -14,6 +17,9 @@ async def none_method(_, __): return None +_response_list = str(np.arange(0, EMB_DIM, dtype=float).tolist()) +SBERT_RESPONSE = f'{{"embedding": [{_response_list}]}}'.encode() + class TestOWDocumentEmbedding(WidgetTest): def setUp(self): @@ -105,6 +111,18 @@ def test_skipped_documents(self): self.assertEqual(len(self.get_output(self.widget.Outputs.skipped)), len(self.corpus)) self.assertTrue(self.widget.Warning.unsuccessful_embeddings.is_shown()) + @patch(PATCH_METHOD, make_dummy_post(SBERT_RESPONSE)) + def test_sbert(self): + self.widget.findChildren(QRadioButton)[1].click() + self.widget.vectorizer.method.clear_cache() + + self.send_signal("Corpus", self.corpus) + result = self.get_output(self.widget.Outputs.corpus) + self.assertIsInstance(result, Corpus) + self.assertEqual(len(self.corpus), len(result)) + self.assertTupleEqual(self.corpus.domain.metas, result.domain.metas) + self.assertEqual(384, len(result.domain.attributes)) + if __name__ == "__main__": unittest.main() diff --git a/orangecontrib/text/widgets/utils/owbasevectorizer.py b/orangecontrib/text/widgets/utils/owbasevectorizer.py index 41b029775..0434250bd 100644 --- a/orangecontrib/text/widgets/utils/owbasevectorizer.py +++ b/orangecontrib/text/widgets/utils/owbasevectorizer.py @@ -117,7 +117,7 @@ def on_change(self): self.commit.deferred() def send_report(self): - self.report_items(self.method.report()) + self.report_items(self.vectorizer.method.report()) def create_configuration_layout(self): raise NotImplementedError