diff --git a/orangecontrib/text/keywords/__init__.py b/orangecontrib/text/keywords/__init__.py index 56bba83e6..e5db373f1 100644 --- a/orangecontrib/text/keywords/__init__.py +++ b/orangecontrib/text/keywords/__init__.py @@ -13,7 +13,8 @@ from Orange.util import dummy_callback from orangecontrib.text.keywords.rake import Rake -from orangecontrib.text.keywords.embedding import embedding_keywords, EMBEDDING_LANGUAGE_MAPPING +from orangecontrib.text.keywords.embedding import embedding_keywords, \ + EMBEDDING_LANGUAGE_MAPPING from orangecontrib.text.preprocess import StopwordsFilter # all available languages for RAKE @@ -175,8 +176,10 @@ class ScoringMethods: Scoring methods enum. """ TF_IDF, RAKE, YAKE, EMBEDDING = "TF-IDF", "Rake", "YAKE!", "Embedding" - ITEMS = list(zip((TF_IDF, YAKE, RAKE), - (tfidf_keywords, yake_keywords, rake_keywords))) + ITEMS = list(zip( + (TF_IDF, YAKE, RAKE, EMBEDDING), + (tfidf_keywords, yake_keywords, rake_keywords, embedding_keywords) + )) TOKEN_METHODS = TF_IDF, EMBEDDING DOCUMENT_METHODS = RAKE, YAKE diff --git a/orangecontrib/text/widgets/owkeywords.py b/orangecontrib/text/widgets/owkeywords.py index 50b1a42c3..b16f56910 100644 --- a/orangecontrib/text/widgets/owkeywords.py +++ b/orangecontrib/text/widgets/owkeywords.py @@ -21,11 +21,12 @@ from orangecontrib.text import Corpus from orangecontrib.text.keywords import ScoringMethods, AggregationMethods, \ - YAKE_LANGUAGE_MAPPING, RAKE_LANGUAGES + YAKE_LANGUAGE_MAPPING, RAKE_LANGUAGES, EMBEDDING_LANGUAGE_MAPPING from orangecontrib.text.preprocess import BaseNormalizer WORDS_COLUMN_NAME = "Words" YAKE_LANGUAGES = list(YAKE_LANGUAGE_MAPPING.keys()) +EMBEDDING_LANGUAGES = list(EMBEDDING_LANGUAGE_MAPPING.keys()) class Results(SimpleNamespace): @@ -181,6 +182,7 @@ class OWKeywords(OWWidget, ConcurrentWidgetMixin): selected_scoring_methods: Set[str] = Setting({ScoringMethods.TF_IDF}) yake_lang_index: int = Setting(YAKE_LANGUAGES.index("English")) rake_lang_index: int = Setting(RAKE_LANGUAGES.index("English")) + embedding_lang_index: int = Setting(EMBEDDING_LANGUAGES.index("English")) agg_method: int = Setting(AggregationMethods.MEAN) sel_method: int = ContextSetting(SelectionMethods.N_BEST) n_selected: int = ContextSetting(3) @@ -219,6 +221,10 @@ def _setup_gui(self): self.controlArea, self, "rake_lang_index", items=RAKE_LANGUAGES, callback=self.__on_rake_lang_changed ) + embedding_cb = gui.comboBox( + self.controlArea, self, "embedding_lang_index", + items=EMBEDDING_LANGUAGES, callback=self.__on_emb_lang_changed + ) for i, (method_name, _) in enumerate(ScoringMethods.ITEMS): check_box = QCheckBox(method_name, self) @@ -232,6 +238,8 @@ def _setup_gui(self): box.layout().addWidget(yake_cb, i, 1) if method_name == ScoringMethods.RAKE: box.layout().addWidget(rake_cb, i, 1) + if method_name == ScoringMethods.EMBEDDING: + box.layout().addWidget(embedding_cb, i, 1) box = gui.vBox(self.controlArea, "Aggregation") gui.comboBox( @@ -308,6 +316,12 @@ def __on_rake_lang_changed(self): del self.__cached_keywords[ScoringMethods.RAKE] self.update_scores() + def __on_emb_lang_changed(self): + if ScoringMethods.EMBEDDING in self.selected_scoring_methods: + if ScoringMethods.EMBEDDING in self.__cached_keywords: + del self.__cached_keywords[ScoringMethods.EMBEDDING] + self.update_scores() + def __on_filter_changed(self): model = self.view.model() model.setFilterFixedString(self.__filter_line_edit.text().strip()) @@ -369,6 +383,9 @@ def update_scores(self): "language": RAKE_LANGUAGES[self.rake_lang_index], "max_len": self.corpus.ngram_range[1] if self.corpus else 1 }, + ScoringMethods.EMBEDDING: { + "language": EMBEDDING_LANGUAGES[self.embedding_lang_index], + }, } self.start(run, self.corpus, self.words, self.__cached_keywords, self.selected_scoring_methods, kwargs, self.agg_method) diff --git a/orangecontrib/text/widgets/tests/test_owkeywords.py b/orangecontrib/text/widgets/tests/test_owkeywords.py index e9e31198c..23408a442 100644 --- a/orangecontrib/text/widgets/tests/test_owkeywords.py +++ b/orangecontrib/text/widgets/tests/test_owkeywords.py @@ -6,9 +6,11 @@ import numpy as np from Orange.data import StringVariable, Table, Domain -from Orange.widgets.tests.base import WidgetTest +from Orange.widgets.tests.base import WidgetTest, simulate from orangecontrib.text import Corpus +from orangecontrib.text.keywords import tfidf_keywords, yake_keywords, \ + rake_keywords, embedding_keywords from orangecontrib.text.preprocess import * from orangecontrib.text.widgets.owkeywords import OWKeywords, run, \ AggregationMethods, ScoringMethods @@ -126,7 +128,7 @@ def assertNanEqual(self, table1, table2): self.assertEqual(x1, x2) -class TestOWWordList(WidgetTest): +class TestOWKeywords(WidgetTest): def setUp(self): self.widget = self.create_widget(OWKeywords) self.corpus = Corpus.from_file("deerwester") @@ -179,6 +181,36 @@ def test_sort_nans_asc(self): self.assertListEqual(list(output.metas[:, 0]), ["System", "Widths", "opinion"]) + def test_scoring_methods(self): + methods = [("TF-IDF", Mock(wraps=tfidf_keywords)), + ("YAKE!", Mock(wraps=yake_keywords)), + ("Rake", Mock(wraps=rake_keywords)), + ("Embedding", Mock(wraps=embedding_keywords))] + with patch.object(ScoringMethods, "ITEMS", methods) as m: + scores = {"TF-IDF", "YAKE!", "Rake", "Embedding"} + settings = {"selected_scoring_methods": scores} + widget = self.create_widget(OWKeywords, stored_settings=settings) + + cb = widget.controls.yake_lang_index + simulate.combobox_activate_item(cb, "Arabic") + cb = widget.controls.rake_lang_index + simulate.combobox_activate_item(cb, "Finnish") + cb = widget.controls.embedding_lang_index + simulate.combobox_activate_item(cb, "Kazakh") + + self.send_signal(widget.Inputs.corpus, self.corpus, widget=widget) + self.wait_until_finished(widget=widget, timeout=10000) + out = self.get_output(widget.Outputs.words, widget=widget) + self.assertEqual(scores, {a.name for a in out.domain.attributes}) + + m[0][1].assert_called_once() + m[1][1].assert_called_once() + m[2][1].assert_called_once() + m[3][1].assert_called_once() + self.assertEqual(m[1][1].call_args[1]["language"], "Arabic") + self.assertEqual(m[2][1].call_args[1]["language"], "Finnish") + self.assertEqual(m[3][1].call_args[1]["language"], "Kazakh") + def test_send_report(self): self.send_signal(self.widget.Inputs.corpus, self.corpus) self.wait_until_finished()