diff --git a/orangecontrib/text/widgets/owdocumentembedding.py b/orangecontrib/text/widgets/owdocumentembedding.py index 21f1cf381..a88b8674f 100644 --- a/orangecontrib/text/widgets/owdocumentembedding.py +++ b/orangecontrib/text/widgets/owdocumentembedding.py @@ -34,14 +34,14 @@ def _transform(self, callback): class OWDocumentEmbedding(OWBaseVectorizer): name = "Document Embedding" description = "Document embedding using pretrained models." - keywords = ["embedding", "document embedding", "text"] + keywords = ["embedding", "document embedding", "text", "fasttext", "bert", "sbert"] icon = "icons/TextEmbedding.svg" priority = 300 buttons_area_orientation = Qt.Vertical settings_version = 2 - Methods = [DocumentEmbedder, SBERT] + Methods = [SBERT, DocumentEmbedder] class Outputs(OWBaseVectorizer.Outputs): skipped = Output("Skipped documents", Corpus) @@ -74,9 +74,10 @@ def create_configuration_layout(self): rbtns = gui.radioButtons(None, self, "method", callback=self.on_change) layout.addWidget(rbtns) + gui.appendRadioButton(rbtns, "Multilingual SBERT") gui.appendRadioButton(rbtns, "fastText:") ibox = gui.indentedBox(rbtns) - gui.comboBox( + self.language_cb = gui.comboBox( ibox, self, "language", @@ -85,8 +86,9 @@ def create_configuration_layout(self): sendSelectedValue=True, # value is actual string not index orientation=Qt.Horizontal, callback=self.on_change, + searchable=True, ) - gui.comboBox( + self.aggregator_cb = gui.comboBox( ibox, self, "aggregator", @@ -95,17 +97,20 @@ def create_configuration_layout(self): sendSelectedValue=True, # value is actual string not index orientation=Qt.Horizontal, callback=self.on_change, + searchable=True, ) - gui.appendRadioButton(rbtns, "Multilingual SBERT:") return layout def update_method(self): + disabled = self.method == 0 + self.aggregator_cb.setDisabled(disabled) + self.language_cb.setDisabled(disabled) self.vectorizer = EmbeddingVectorizer(self.init_method(), self.corpus) def init_method(self): params = 
dict(language=LANGS_TO_ISO[self.language], aggregator=self.aggregator) - kwargs = (params, {})[self.method] + kwargs = ({}, params)[self.method] return self.Methods[self.method](**kwargs) @gui.deferred diff --git a/orangecontrib/text/widgets/tests/test_owdocumentembedding.py b/orangecontrib/text/widgets/tests/test_owdocumentembedding.py index 4346d26d4..5aab010aa 100644 --- a/orangecontrib/text/widgets/tests/test_owdocumentembedding.py +++ b/orangecontrib/text/widgets/tests/test_owdocumentembedding.py @@ -27,6 +27,13 @@ def setUp(self): self.corpus = Corpus.from_file('deerwester') self.larger_corpus = Corpus.from_file('book-excerpts') + # test on fastText, except for tests that change the setting + self.widget.findChildren(QRadioButton)[1].click() + self.widget.vectorizer.method.clear_cache() + + def tearDown(self): + self.widget.vectorizer.method.clear_cache() + def test_input(self): set_data = self.widget.set_data = Mock() self.send_signal("Corpus", None) @@ -42,7 +49,6 @@ def test_output(self): self.assertIsNone(self.get_output(self.widget.Outputs.corpus)) self.send_signal("Corpus", self.corpus) - self.wait_until_finished() result = self.get_output(self.widget.Outputs.corpus) self.assertIsNotNone(result) self.assertIsInstance(result, Corpus) @@ -113,7 +119,7 @@ def test_skipped_documents(self): @patch(PATCH_METHOD, make_dummy_post(SBERT_RESPONSE)) def test_sbert(self): - self.widget.findChildren(QRadioButton)[1].click() + self.widget.findChildren(QRadioButton)[0].click() self.widget.vectorizer.method.clear_cache() self.send_signal("Corpus", self.corpus) diff --git a/orangecontrib/text/widgets/utils/owbasevectorizer.py b/orangecontrib/text/widgets/utils/owbasevectorizer.py index 0434250bd..5ec8af80d 100644 --- a/orangecontrib/text/widgets/utils/owbasevectorizer.py +++ b/orangecontrib/text/widgets/utils/owbasevectorizer.py @@ -70,6 +70,7 @@ def __init__(self): box = QGroupBox(title="Options") box.setLayout(self.create_configuration_layout()) + 
box.layout().setContentsMargins(4, 4, 4, 4) # same as other widgets self.controlArea.layout().addWidget(box) output_layout = gui.hBox(self.controlArea)