Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FIX] Keywords - Handle connection error #992

Merged
merged 2 commits into from
Aug 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion orangecontrib/text/keywords/mbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional, Callable, Tuple, List, Iterable

import numpy as np
from Orange.misc.utils.embedder_utils import EmbeddingConnectionError
from nltk import ngrams
from Orange.misc.server_embedder import ServerEmbedderCommunicator
from Orange.util import dummy_callback
Expand Down Expand Up @@ -68,7 +69,10 @@ def mbert_keywords(
server_url="https://api.garaza.io",
embedder_type="text",
)
keywords = emb.embedd_data(documents, callback=progress_callback)
try:
keywords = emb.embedd_data(documents, callback=progress_callback)
except EmbeddingConnectionError:
keywords = [None] * len(documents)
processed_kws = []
for kws in keywords:
if kws is not None:
Expand Down
9 changes: 9 additions & 0 deletions orangecontrib/text/tests/test_keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from Orange.data import Domain, StringVariable
from Orange.misc.utils.embedder_utils import EmbeddingConnectionError

from orangecontrib.text import Corpus
from orangecontrib.text.keywords import (
Expand Down Expand Up @@ -180,6 +181,14 @@ def test_mbert_keywords(self, _):
]
self.assertListEqual(expected, res)

@patch(
    "orangecontrib.text.keywords.mbert._BertServerCommunicator.embedd_data",
    side_effect=EmbeddingConnectionError,
)
def test_mbert_keywords_fail(self, _):
    """Every document maps to ``None`` when the embedder raises a connection error."""
    documents = ["Text 1", "Text 2"]
    extracted = mbert_keywords(documents, max_len=3)
    # one None placeholder per input document
    self.assertListEqual([None, None], extracted)


@patch(
"orangecontrib.text.keywords.mbert._BertServerCommunicator._send_request",
Expand Down
31 changes: 26 additions & 5 deletions orangecontrib/text/widgets/owkeywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
WORDS_COLUMN_NAME

# Language names available for YAKE, taken from the keys of the mapping.
YAKE_LANGUAGES = [*YAKE_LANGUAGE_MAPPING.keys()]
# Warning text reported when mBERT yields no keywords for some documents.
CONNECTION_WARNING = (
    f"{ScoringMethods.MBERT} could not extract keywords "
    "from some documents due to connection error. "
    "Please rerun keyword extraction."
)


class Results(SimpleNamespace):
Expand All @@ -37,6 +41,8 @@ class Results(SimpleNamespace):
labels: List[str] = []
# all calculated keywords {method: [[(word1, score1), ...]]}
all_keywords: Dict[str, List[List[Tuple[str, float]]]] = {}
# warnings happening during keyword extraction process
warnings: List[str] = []


def run(
Expand All @@ -48,7 +54,7 @@ def run(
agg_method: int,
state: TaskState
) -> Results:
results = Results(scores=[], labels=[], all_keywords={})
results = Results(scores=[], labels=[], all_keywords={}, warnings=[])
if not corpus:
return results

Expand All @@ -70,7 +76,8 @@ def callback(i: float, status=""):
step = 1 / len(scoring_methods)
for method_name, func in ScoringMethods.ITEMS:
if method_name in scoring_methods:
if method_name not in results.all_keywords:
keywords = results.all_keywords.get(method_name)
if keywords is None:
i = len(results.labels)
cb = wrap_callback(callback, start=i * step,
end=(i + 1) * step)
Expand All @@ -79,10 +86,20 @@ def callback(i: float, status=""):
kw = {"progress_callback": cb}
kw.update(scoring_methods_kwargs.get(method_name, {}))

keywords = func(corpus if needs_tokens else documents, **kw)
results.all_keywords[method_name] = keywords
kws = func(corpus if needs_tokens else documents, **kw)
# None means that embedding completely failed on document
# currently it only happens with mbert when connection fails
keywords = [kws for kws in kws if kws is not None]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comment: I'd write `for kw in kws`, because each item is singular, and also because somebody who doesn't know that this was fixed in Python 3 might fear that `kws` would leak out of the list comprehension.

I'm merging as it is; if you want, just push a commit with this change into master. :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was actually my mistake (probably introduced while making some changes). I agree with using the singular. I will fix it on master.

# don't store keywords in all_keywords if any were not computed
# due to connection issues; storing them would prevent the
# missing keywords from being recomputed on the next run.
# mbert's existing keywords are cached in the embedding cache, so
# only the missing ones will be recomputed
if len(kws) > len(keywords) and method_name == ScoringMethods.MBERT:
results.warnings.append(CONNECTION_WARNING)
else:
results.all_keywords[method_name] = keywords

keywords = results.all_keywords[method_name]
scores[method_name] = \
dict(AggregationMethods.aggregate(keywords, agg_method))

Expand Down Expand Up @@ -210,6 +227,7 @@ class Outputs:

class Warning(OWWidget.Warning):
# Warning messages this widget can display.
no_words_column = Msg("Input is missing 'Words' column.")
# Template with a single placeholder; the full text is supplied when the
# warning is raised.
extraction_warnings = Msg("{}")

def __init__(self):
OWWidget.__init__(self)
Expand Down Expand Up @@ -376,6 +394,7 @@ def handleNewSignals(self):
self.update_scores()

def update_scores(self):
self.Warning.extraction_warnings.clear()
kwargs = {
ScoringMethods.YAKE: {
"language": YAKE_LANGUAGES[self.yake_lang_index],
Expand Down Expand Up @@ -441,6 +460,8 @@ def on_done(self, results: Results):
self._select_rows()
else:
self.__on_selection_changed()
if results.warnings:
self.Warning.extraction_warnings("\n".join(results.warnings))

def _apply_sorting(self):
if self.model.columnCount() <= self.sort_column_order[0]:
Expand Down
63 changes: 61 additions & 2 deletions orangecontrib/text/widgets/tests/test_owkeywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@
from orangecontrib.text.keywords import tfidf_keywords, yake_keywords, \
rake_keywords
from orangecontrib.text.preprocess import *
from orangecontrib.text.widgets.owkeywords import OWKeywords, run, \
AggregationMethods, ScoringMethods, SelectionMethods
from orangecontrib.text.widgets.owkeywords import (
OWKeywords,
run,
AggregationMethods,
ScoringMethods,
SelectionMethods,
CONNECTION_WARNING,
)
from orangecontrib.text.widgets.utils.words import create_words_table


Expand Down Expand Up @@ -111,6 +117,27 @@ def test_run_interrupt(self):
{ScoringMethods.TF_IDF}, {},
AggregationMethods.MEAN, state)

def test_run_mbert_fail(self):
    """Check ``run`` when mBERT returns ``None`` for some or for all documents."""
    aggregation = AggregationMethods.MEAN
    methods = {ScoringMethods.MBERT}

    # partial failure: the second document has no keywords
    partial = [[("keyword1", 10), ("keyword2", 2)], None, [("keyword1", 5)]]
    mocked_items = [("mBERT", Mock(return_value=partial))]
    with patch.object(ScoringMethods, "ITEMS", mocked_items):
        results = run(
            self.corpus[:3], None, {}, methods, {}, aggregation, self.state
        )
    self.assertListEqual([["keyword1", 7.5], ["keyword2", 1]], results.scores)
    self.assertListEqual(["mBERT"], results.labels)
    # incomplete keywords are not stored in all_keywords
    self.assertDictEqual({}, results.all_keywords)
    self.assertListEqual([CONNECTION_WARNING], results.warnings)

    # complete failure: no document has keywords
    nothing = [None] * 3
    mocked_items = [("mBERT", Mock(return_value=nothing))]
    with patch.object(ScoringMethods, "ITEMS", mocked_items):
        results = run(
            self.corpus[:3], None, {}, methods, {}, aggregation, self.state
        )
    self.assertListEqual([], results.scores)
    self.assertListEqual(["mBERT"], results.labels)
    # again nothing is stored in all_keywords
    self.assertDictEqual({}, results.all_keywords)
    self.assertListEqual([CONNECTION_WARNING], results.warnings)

def assertNanEqual(self, table1, table2):
for list1, list2 in zip(table1, table2):
for x1, x2 in zip(list1, list2):
Expand Down Expand Up @@ -274,6 +301,38 @@ def test_selection_n_best(self):
output = self.get_output(self.widget.Outputs.words)
self.assertEqual(5, len(output))

def test_connection_error(self):
# Widget-level check of the connection warning in three scenarios:
# partial failure, complete failure, and a fully successful run.
self.widget.controlArea.findChildren(QCheckBox)[0].click() # unselect tfidf
# NOTE(review): the assertions below expect mBERT output, so this click
# presumably selects mbert rather than unselecting it; confirm against
# the widget's default settings.
self.widget.controlArea.findChildren(QCheckBox)[3].click() # toggle mbert
# partial failure: the second document has no keywords (None)
res = [[("keyword1", 10), ("keyword2", 2)], None, [("keyword1", 5)]]
with patch.object(ScoringMethods, "ITEMS", [("mBERT", Mock(return_value=res))]):
self.send_signal(self.widget.Inputs.corpus, self.corpus)
output = self.get_output(self.widget.Outputs.words)
self.assertEqual(len(output), 2)
np.testing.assert_array_equal(output.metas, [["keyword1"], ["keyword2"]])
np.testing.assert_array_equal(output.X, [[7.5], [1]])
self.assertTrue(self.widget.Warning.extraction_warnings.is_shown())
self.assertEqual(
CONNECTION_WARNING, str(self.widget.Warning.extraction_warnings)
)

res = [None] * 3 # all failed
with patch.object(ScoringMethods, "ITEMS", [("mBERT", Mock(return_value=res))]):
self.send_signal(self.widget.Inputs.corpus, self.corpus)
self.assertIsNone(self.get_output(self.widget.Outputs.words))
self.assertTrue(self.widget.Warning.extraction_warnings.is_shown())
self.assertEqual(
CONNECTION_WARNING, str(self.widget.Warning.extraction_warnings)
)

# no failures: the warning must no longer be shown
res = [[("keyword1", 10), ("keyword2", 2)], [("keyword1", 5)]]
with patch.object(ScoringMethods, "ITEMS", [("mBERT", Mock(return_value=res))]):
self.send_signal(self.widget.Inputs.corpus, self.corpus)
output = self.get_output(self.widget.Outputs.words)
np.testing.assert_array_equal(output.metas, [["keyword1"], ["keyword2"]])
np.testing.assert_array_equal(output.X, [[7.5], [1]])
self.assertFalse(self.widget.Warning.extraction_warnings.is_shown())


# Run the tests in this module when it is executed directly as a script.
if __name__ == "__main__":
unittest.main()