Skip to content

Commit

Permalink
Document Embedding widget - use language from Corpus
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Mar 29, 2023
1 parent 57b54fc commit 90fc0a5
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 18 deletions.
37 changes: 28 additions & 9 deletions orangecontrib/text/widgets/owdocumentembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,18 @@
from Orange.widgets.widget import Msg, Output, OWWidget

from orangecontrib.text.corpus import Corpus
from orangecontrib.text.language import ISO2LANG, LANG2ISO
from orangecontrib.text.vectorization.document_embedder import (
AGGREGATORS,
LANGS_TO_ISO,
DocumentEmbedder,
LANGUAGES,
)
from orangecontrib.text.vectorization.sbert import SBERT
from orangecontrib.text.widgets.utils.owbasevectorizer import (
OWBaseVectorizer,
Vectorizer,
)

LANGUAGES = sorted(list(LANGS_TO_ISO.keys()))


class EmbeddingVectorizer(Vectorizer):
skipped_documents = None
Expand All @@ -42,6 +41,7 @@ class OWDocumentEmbedding(OWBaseVectorizer):
settings_version = 2

Methods = [SBERT, DocumentEmbedder]
DEFAULT_LANGUAGE = "English"

class Outputs(OWBaseVectorizer.Outputs):
skipped = Output("Skipped documents", Corpus)
Expand All @@ -57,7 +57,7 @@ class Warning(OWWidget.Warning):
unsuccessful_embeddings = Msg("Some embeddings were unsuccessful.")

method: int = Setting(default=0)
language: str = Setting(default="English")
language: str = Setting(default=DEFAULT_LANGUAGE, shema_only=True)
aggregator: str = Setting(default="Mean")

def __init__(self):
Expand All @@ -68,6 +68,8 @@ def __init__(self):
self.cancel_button.clicked.connect(self.cancel)
self.buttonsArea.layout().addWidget(self.cancel_button)
self.cancel_button.setDisabled(True)
# it should be only set when setting loaded from schema/workflow
self.__pending_language = self.language

def create_configuration_layout(self):
layout = QVBoxLayout()
Expand All @@ -81,7 +83,7 @@ def create_configuration_layout(self):
ibox,
self,
"language",
items=LANGUAGES,
items=[ISO2LANG[lg] for lg in LANGUAGES],
label="Language:",
sendSelectedValue=True, # value is actual string not index
orientation=Qt.Horizontal,
Expand All @@ -92,24 +94,41 @@ def create_configuration_layout(self):
ibox,
self,
"aggregator",
items=AGGREGATORS,
items=[a.capitalize() for a in AGGREGATORS],
label="Aggregator:",
sendSelectedValue=True, # value is actual string not index
orientation=Qt.Horizontal,
callback=self.on_change,
searchable=True,
)

return layout

@OWBaseVectorizer.Inputs.corpus
def set_data(self, corpus):
# set language from corpus as selected language
if corpus and corpus.language in LANGUAGES:
self.language = ISO2LANG[corpus.language]
else:
# if Corpus's language not supported use default language
self.language = self.DEFAULT_LANGUAGE

# when workflow loaded use language saved in workflow
if self.__pending_language is not None:
self.language = self.__pending_language
self.__pending_language = None

super().set_data(corpus)

def update_method(self):
disabled = self.method == 0
self.aggregator_cb.setDisabled(disabled)
self.language_cb.setDisabled(disabled)
self.vectorizer = EmbeddingVectorizer(self.init_method(), self.corpus)

def init_method(self):
params = dict(language=LANGS_TO_ISO[self.language], aggregator=self.aggregator)
params = dict(
language=LANG2ISO[self.language], aggregator=self.aggregator.lower()
)
kwargs = ({}, params)[self.method]
return self.Methods[self.method](**kwargs)

Expand All @@ -133,7 +152,7 @@ def on_exception(self, ex: Exception):
if isinstance(ex, EmbeddingConnectionError):
self.Error.no_connection()
else:
self.Error.unexpected_error(type(ex).__name__)
self.Error.unexpected_error(str(ex))
self.cancel()

def cancel(self):
Expand Down
87 changes: 78 additions & 9 deletions orangecontrib/text/widgets/tests/test_owdocumentembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from Orange.misc.utils.embedder_utils import EmbeddingConnectionError

from orangecontrib.text.tests.test_documentembedder import PATCH_METHOD, make_dummy_post
from orangecontrib.text.vectorization.sbert import EMB_DIM
from orangecontrib.text.vectorization.document_embedder import DocumentEmbedder
from orangecontrib.text.vectorization.sbert import EMB_DIM, SBERT
from orangecontrib.text.widgets.owdocumentembedding import OWDocumentEmbedding
from orangecontrib.text import Corpus

Expand All @@ -28,10 +29,14 @@ def setUp(self):

# test on fastText, except for tests that change the setting
self.widget.findChildren(QRadioButton)[1].click()
self.widget.vectorizer.method.clear_cache()
SBERT().clear_cache()
DocumentEmbedder.clear_cache("en")
DocumentEmbedder.clear_cache("sl")

def tearDown(self):
self.widget.vectorizer.method.clear_cache()
SBERT().clear_cache()
DocumentEmbedder.clear_cache("en")
DocumentEmbedder.clear_cache("sl")

def test_input(self):
set_data = self.widget.set_data = Mock()
Expand All @@ -57,10 +62,11 @@ def test_output(self):
@patch(PATCH_METHOD, make_dummy_post(b''))
def test_some_failed(self):
simulate.combobox_activate_index(
self.widget.controlArea.findChildren(QComboBox)[1], 1
self.widget.controlArea.findChildren(QComboBox)[0], 1
)
self.send_signal("Corpus", self.corpus)
self.wait_until_finished()
with self.assertWarns(RuntimeWarning): # avoid warnings in test logs
self.send_signal("Corpus", self.corpus)
self.wait_until_finished()
result = self.get_output(self.widget.Outputs.corpus)
skipped = self.get_output(self.widget.Outputs.skipped)
self.assertIsNone(result)
Expand Down Expand Up @@ -111,16 +117,17 @@ def test_rerun_on_new_data(self):
@patch('orangecontrib.text.vectorization.document_embedder' +
'._ServerEmbedder._encode_data_instance', none_method)
def test_skipped_documents(self):
self.send_signal("Corpus", self.corpus)
self.wait_until_finished()
with self.assertWarns(RuntimeWarning): # avoid warnings in test logs
self.send_signal("Corpus", self.corpus)
self.wait_until_finished()
self.assertIsNone(self.get_output(self.widget.Outputs.corpus))
self.assertEqual(len(self.get_output(self.widget.Outputs.skipped)), len(self.corpus))
self.assertTrue(self.widget.Warning.unsuccessful_embeddings.is_shown())

@patch(PATCH_METHOD, make_dummy_post(SBERT_RESPONSE))
def test_sbert(self):
self.widget.findChildren(QRadioButton)[0].click()
self.widget.vectorizer.method.clear_cache()
SBERT().clear_cache()

self.send_signal("Corpus", self.corpus)
result = self.get_output(self.widget.Outputs.corpus)
Expand All @@ -145,6 +152,68 @@ def test_corpus_name_preserved(self):
self.assertIsNotNone(result)
self.assertEqual("deerwester", result.name)

@patch(PATCH_METHOD, make_dummy_post(b'{"embedding": [1.3, 1]}'))
def test_fasttext_language(self):
# english corpus
self.send_signal(self.widget.Inputs.corpus, self.corpus)
self.assertEqual("English", self.widget.language)
result = self.get_output(self.widget.Outputs.corpus)
self.assertEqual(9, len(result))

# slovenian corpus
self.corpus.attributes["language"] = "sl"
self.send_signal(self.widget.Inputs.corpus, self.corpus)
self.assertEqual("Slovenian", self.widget.language)
result = self.get_output(self.widget.Outputs.corpus)
self.assertEqual(9, len(result))

# language none
self.corpus.attributes["language"] = None
self.send_signal(self.widget.Inputs.corpus, self.corpus)
# use widgets default language English
self.assertEqual(self.widget.DEFAULT_LANGUAGE, self.widget.language)
result = self.get_output(self.widget.Outputs.corpus)
self.assertEqual(9, len(result))

# language not supported
self.corpus.attributes["language"] = "be"
self.send_signal(self.widget.Inputs.corpus, self.corpus)
# use widgets default language English
self.assertEqual(self.widget.DEFAULT_LANGUAGE, self.widget.language)
result = self.get_output(self.widget.Outputs.corpus)
self.assertEqual(9, len(result))

# language english
self.corpus.attributes["language"] = "en"
self.send_signal(self.widget.Inputs.corpus, self.corpus)
self.assertEqual("English", self.widget.language)
result = self.get_output(self.widget.Outputs.corpus)
self.assertEqual(9, len(result))

# manually set language
simulate.combobox_activate_item(
self.widget.controlArea.findChildren(QComboBox)[0], "French"
)
self.assertEqual("French", self.widget.language)
result = self.get_output(self.widget.Outputs.corpus)
self.assertEqual(9, len(result))

# providing new corpus should reset language
self.send_signal(self.widget.Inputs.corpus, self.corpus)
self.assertEqual("English", self.widget.language)

def test_language_from_settings(self):
self.send_signal(self.widget.Inputs.corpus, self.corpus)
simulate.combobox_activate_item(
self.widget.controlArea.findChildren(QComboBox)[0], "French"
)
self.assertEqual("French", self.widget.language)
settings = self.widget.settingsHandler.pack_data(self.widget)

widget = self.create_widget(OWDocumentEmbedding, stored_settings=settings)
self.send_signal(widget.Inputs.corpus, self.corpus, widget=widget)
self.assertEqual("French", widget.language)


if __name__ == "__main__":
unittest.main()

0 comments on commit 90fc0a5

Please sign in to comment.