Skip to content

Commit

Permalink
Keywords: Add 'Embedding' scoring method
Browse files Browse the repository at this point in the history
  • Loading branch information
VesnaT committed Jun 1, 2021
1 parent 5461fa1 commit e398b72
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 6 deletions.
9 changes: 6 additions & 3 deletions orangecontrib/text/keywords/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from Orange.util import dummy_callback

from orangecontrib.text.keywords.rake import Rake
from orangecontrib.text.keywords.embedding import embedding_keywords, EMBEDDING_LANGUAGE_MAPPING
from orangecontrib.text.keywords.embedding import embedding_keywords, \
EMBEDDING_LANGUAGE_MAPPING
from orangecontrib.text.preprocess import StopwordsFilter

# all available languages for RAKE
Expand Down Expand Up @@ -175,8 +176,10 @@ class ScoringMethods:
Scoring methods enum.
"""
TF_IDF, RAKE, YAKE, EMBEDDING = "TF-IDF", "Rake", "YAKE!", "Embedding"
ITEMS = list(zip((TF_IDF, YAKE, RAKE),
(tfidf_keywords, yake_keywords, rake_keywords)))
ITEMS = list(zip(
(TF_IDF, YAKE, RAKE, EMBEDDING),
(tfidf_keywords, yake_keywords, rake_keywords, embedding_keywords)
))

TOKEN_METHODS = TF_IDF, EMBEDDING
DOCUMENT_METHODS = RAKE, YAKE
Expand Down
19 changes: 18 additions & 1 deletion orangecontrib/text/widgets/owkeywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@

from orangecontrib.text import Corpus
from orangecontrib.text.keywords import ScoringMethods, AggregationMethods, \
YAKE_LANGUAGE_MAPPING, RAKE_LANGUAGES
YAKE_LANGUAGE_MAPPING, RAKE_LANGUAGES, EMBEDDING_LANGUAGE_MAPPING
from orangecontrib.text.preprocess import BaseNormalizer

WORDS_COLUMN_NAME = "Words"
YAKE_LANGUAGES = list(YAKE_LANGUAGE_MAPPING.keys())
EMBEDDING_LANGUAGES = list(EMBEDDING_LANGUAGE_MAPPING.keys())


class Results(SimpleNamespace):
Expand Down Expand Up @@ -181,6 +182,7 @@ class OWKeywords(OWWidget, ConcurrentWidgetMixin):
selected_scoring_methods: Set[str] = Setting({ScoringMethods.TF_IDF})
yake_lang_index: int = Setting(YAKE_LANGUAGES.index("English"))
rake_lang_index: int = Setting(RAKE_LANGUAGES.index("English"))
embedding_lang_index: int = Setting(EMBEDDING_LANGUAGES.index("English"))
agg_method: int = Setting(AggregationMethods.MEAN)
sel_method: int = ContextSetting(SelectionMethods.N_BEST)
n_selected: int = ContextSetting(3)
Expand Down Expand Up @@ -219,6 +221,10 @@ def _setup_gui(self):
self.controlArea, self, "rake_lang_index", items=RAKE_LANGUAGES,
callback=self.__on_rake_lang_changed
)
embedding_cb = gui.comboBox(
self.controlArea, self, "embedding_lang_index",
items=EMBEDDING_LANGUAGES, callback=self.__on_emb_lang_changed
)

for i, (method_name, _) in enumerate(ScoringMethods.ITEMS):
check_box = QCheckBox(method_name, self)
Expand All @@ -232,6 +238,8 @@ def _setup_gui(self):
box.layout().addWidget(yake_cb, i, 1)
if method_name == ScoringMethods.RAKE:
box.layout().addWidget(rake_cb, i, 1)
if method_name == ScoringMethods.EMBEDDING:
box.layout().addWidget(embedding_cb, i, 1)

box = gui.vBox(self.controlArea, "Aggregation")
gui.comboBox(
Expand Down Expand Up @@ -308,6 +316,12 @@ def __on_rake_lang_changed(self):
del self.__cached_keywords[ScoringMethods.RAKE]
self.update_scores()

def __on_emb_lang_changed(self):
if ScoringMethods.EMBEDDING in self.selected_scoring_methods:
if ScoringMethods.EMBEDDING in self.__cached_keywords:
del self.__cached_keywords[ScoringMethods.EMBEDDING]
self.update_scores()

def __on_filter_changed(self):
model = self.view.model()
model.setFilterFixedString(self.__filter_line_edit.text().strip())
Expand Down Expand Up @@ -369,6 +383,9 @@ def update_scores(self):
"language": RAKE_LANGUAGES[self.rake_lang_index],
"max_len": self.corpus.ngram_range[1] if self.corpus else 1
},
ScoringMethods.EMBEDDING: {
"language": EMBEDDING_LANGUAGES[self.embedding_lang_index],
},
}
self.start(run, self.corpus, self.words, self.__cached_keywords,
self.selected_scoring_methods, kwargs, self.agg_method)
Expand Down
36 changes: 34 additions & 2 deletions orangecontrib/text/widgets/tests/test_owkeywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import numpy as np

from Orange.data import StringVariable, Table, Domain
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.tests.base import WidgetTest, simulate

from orangecontrib.text import Corpus
from orangecontrib.text.keywords import tfidf_keywords, yake_keywords, \
rake_keywords, embedding_keywords
from orangecontrib.text.preprocess import *
from orangecontrib.text.widgets.owkeywords import OWKeywords, run, \
AggregationMethods, ScoringMethods
Expand Down Expand Up @@ -126,7 +128,7 @@ def assertNanEqual(self, table1, table2):
self.assertEqual(x1, x2)


class TestOWWordList(WidgetTest):
class TestOWKeywords(WidgetTest):
def setUp(self):
self.widget = self.create_widget(OWKeywords)
self.corpus = Corpus.from_file("deerwester")
Expand Down Expand Up @@ -179,6 +181,36 @@ def test_sort_nans_asc(self):
self.assertListEqual(list(output.metas[:, 0]),
["System", "Widths", "opinion"])

def test_scoring_methods(self):
methods = [("TF-IDF", Mock(wraps=tfidf_keywords)),
("YAKE!", Mock(wraps=yake_keywords)),
("Rake", Mock(wraps=rake_keywords)),
("Embedding", Mock(wraps=embedding_keywords))]
with patch.object(ScoringMethods, "ITEMS", methods) as m:
scores = {"TF-IDF", "YAKE!", "Rake", "Embedding"}
settings = {"selected_scoring_methods": scores}
widget = self.create_widget(OWKeywords, stored_settings=settings)

cb = widget.controls.yake_lang_index
simulate.combobox_activate_item(cb, "Arabic")
cb = widget.controls.rake_lang_index
simulate.combobox_activate_item(cb, "Finnish")
cb = widget.controls.embedding_lang_index
simulate.combobox_activate_item(cb, "Kazakh")

self.send_signal(widget.Inputs.corpus, self.corpus, widget=widget)
self.wait_until_finished(widget=widget, timeout=10000)
out = self.get_output(widget.Outputs.words, widget=widget)
self.assertEqual(scores, {a.name for a in out.domain.attributes})

m[0][1].assert_called_once()
m[1][1].assert_called_once()
m[2][1].assert_called_once()
m[3][1].assert_called_once()
self.assertEqual(m[1][1].call_args[1]["language"], "Arabic")
self.assertEqual(m[2][1].call_args[1]["language"], "Finnish")
self.assertEqual(m[3][1].call_args[1]["language"], "Kazakh")

def test_send_report(self):
self.send_signal(self.widget.Inputs.corpus, self.corpus)
self.wait_until_finished()
Expand Down

0 comments on commit e398b72

Please sign in to comment.