Skip to content

Commit

Permalink
Merge pull request #992 from PrimozGodec/keywords-connection-fail
Browse files Browse the repository at this point in the history
[FIX] Keywords - Handle connection error
  • Loading branch information
janezd authored Aug 20, 2023
2 parents 57df491 + 9867d3c commit a32bf17
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 8 deletions.
6 changes: 5 additions & 1 deletion orangecontrib/text/keywords/mbert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional, Callable, Tuple, List, Iterable

import numpy as np
from Orange.misc.utils.embedder_utils import EmbeddingConnectionError
from nltk import ngrams
from Orange.misc.server_embedder import ServerEmbedderCommunicator
from Orange.util import dummy_callback
Expand Down Expand Up @@ -68,7 +69,10 @@ def mbert_keywords(
server_url="https://api.garaza.io",
embedder_type="text",
)
keywords = emb.embedd_data(documents, callback=progress_callback)
try:
keywords = emb.embedd_data(documents, callback=progress_callback)
except EmbeddingConnectionError:
keywords = [None] * len(documents)
processed_kws = []
for kws in keywords:
if kws is not None:
Expand Down
9 changes: 9 additions & 0 deletions orangecontrib/text/tests/test_keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from Orange.data import Domain, StringVariable
from Orange.misc.utils.embedder_utils import EmbeddingConnectionError

from orangecontrib.text import Corpus
from orangecontrib.text.keywords import (
Expand Down Expand Up @@ -180,6 +181,14 @@ def test_mbert_keywords(self, _):
]
self.assertListEqual(expected, res)

@patch(
    "orangecontrib.text.keywords.mbert._BertServerCommunicator.embedd_data",
    side_effect=EmbeddingConnectionError,
)
def test_mbert_keywords_fail(self, _):
    """Every document maps to None when the embedding server is unreachable."""
    documents = ["Text 1", "Text 2"]
    result = mbert_keywords(documents, max_len=3)
    self.assertListEqual([None] * len(documents), result)


@patch(
"orangecontrib.text.keywords.mbert._BertServerCommunicator._send_request",
Expand Down
31 changes: 26 additions & 5 deletions orangecontrib/text/widgets/owkeywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
WORDS_COLUMN_NAME

YAKE_LANGUAGES = list(YAKE_LANGUAGE_MAPPING.keys())
# User-facing warning emitted when the mBERT scorer fails to embed some
# documents (the scorer marks them with None); collected in Results.warnings
# and shown by the widget so the user knows to rerun the extraction.
CONNECTION_WARNING = (
    f"{ScoringMethods.MBERT} could not extract keywords from some "
    "documents due to connection error. Please rerun keyword extraction."
)


class Results(SimpleNamespace):
Expand All @@ -37,6 +41,8 @@ class Results(SimpleNamespace):
labels: List[str] = []
# all calculated keywords {method: [[(word1, score1), ...]]}
all_keywords: Dict[str, List[List[Tuple[str, float]]]] = {}
# warnings happening during keyword extraction process
warnings: List[str] = []


def run(
Expand All @@ -48,7 +54,7 @@ def run(
agg_method: int,
state: TaskState
) -> Results:
results = Results(scores=[], labels=[], all_keywords={})
results = Results(scores=[], labels=[], all_keywords={}, warnings=[])
if not corpus:
return results

Expand All @@ -70,7 +76,8 @@ def callback(i: float, status=""):
step = 1 / len(scoring_methods)
for method_name, func in ScoringMethods.ITEMS:
if method_name in scoring_methods:
if method_name not in results.all_keywords:
keywords = results.all_keywords.get(method_name)
if keywords is None:
i = len(results.labels)
cb = wrap_callback(callback, start=i * step,
end=(i + 1) * step)
Expand All @@ -79,10 +86,20 @@ def callback(i: float, status=""):
kw = {"progress_callback": cb}
kw.update(scoring_methods_kwargs.get(method_name, {}))

keywords = func(corpus if needs_tokens else documents, **kw)
results.all_keywords[method_name] = keywords
kws = func(corpus if needs_tokens else documents, **kw)
# None means that embedding completely failed on document
# currently it only happens with mbert when connection fails
keywords = [kws for kws in kws if kws is not None]
# don't store keywords to all_keywords if any were not computed
# due to connection issues; storing them would cause that
# missing keywords would not be recomputed on next run
# mbert's existing keywords are cached in embedding cache
# only missing will be recomputed
if len(kws) > len(keywords) and method_name == ScoringMethods.MBERT:
results.warnings.append(CONNECTION_WARNING)
else:
results.all_keywords[method_name] = keywords

keywords = results.all_keywords[method_name]
scores[method_name] = \
dict(AggregationMethods.aggregate(keywords, agg_method))

Expand Down Expand Up @@ -210,6 +227,7 @@ class Outputs:

class Warning(OWWidget.Warning):
    # shown when the input table has no 'Words' column to read words from
    no_words_column = Msg("Input is missing 'Words' column.")
    # placeholder message: the actual text (e.g. the connection warning
    # collected in Results.warnings) is supplied as a format argument
    extraction_warnings = Msg("{}")

def __init__(self):
OWWidget.__init__(self)
Expand Down Expand Up @@ -376,6 +394,7 @@ def handleNewSignals(self):
self.update_scores()

def update_scores(self):
self.Warning.extraction_warnings.clear()
kwargs = {
ScoringMethods.YAKE: {
"language": YAKE_LANGUAGES[self.yake_lang_index],
Expand Down Expand Up @@ -441,6 +460,8 @@ def on_done(self, results: Results):
self._select_rows()
else:
self.__on_selection_changed()
if results.warnings:
self.Warning.extraction_warnings("\n".join(results.warnings))

def _apply_sorting(self):
if self.model.columnCount() <= self.sort_column_order[0]:
Expand Down
63 changes: 61 additions & 2 deletions orangecontrib/text/widgets/tests/test_owkeywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@
from orangecontrib.text.keywords import tfidf_keywords, yake_keywords, \
rake_keywords
from orangecontrib.text.preprocess import *
from orangecontrib.text.widgets.owkeywords import OWKeywords, run, \
AggregationMethods, ScoringMethods, SelectionMethods
from orangecontrib.text.widgets.owkeywords import (
OWKeywords,
run,
AggregationMethods,
ScoringMethods,
SelectionMethods,
CONNECTION_WARNING,
)
from orangecontrib.text.widgets.utils.words import create_words_table


Expand Down Expand Up @@ -111,6 +117,27 @@ def test_run_interrupt(self):
{ScoringMethods.TF_IDF}, {},
AggregationMethods.MEAN, state)

def test_run_mbert_fail(self):
    """mBERT run degrades gracefully when embedding hits connection errors.

    A ``None`` entry in the per-document keyword list marks a document
    whose embedding failed.  The run must still aggregate the documents
    that succeeded, report a connection warning, and keep ``all_keywords``
    empty so the failed documents get recomputed on the next run.
    """
    def run_with_mbert_result(mocked_keywords):
        # run() with ScoringMethods.ITEMS replaced by a single mBERT
        # entry returning the prepared (partially failed) keyword lists
        items = [("mBERT", Mock(return_value=mocked_keywords))]
        with patch.object(ScoringMethods, "ITEMS", items):
            return run(
                self.corpus[:3], None, {}, {ScoringMethods.MBERT}, {},
                AggregationMethods.MEAN, self.state,
            )

    # partial failure: the middle document could not be embedded
    results = run_with_mbert_result(
        [[("keyword1", 10), ("keyword2", 2)], None, [("keyword1", 5)]]
    )
    self.assertListEqual([["keyword1", 7.5], ["keyword2", 1]], results.scores)
    self.assertListEqual(["mBERT"], results.labels)
    # incomplete keyword sets are not cached, so they are recomputed later
    self.assertDictEqual({}, results.all_keywords)
    self.assertListEqual([CONNECTION_WARNING], results.warnings)

    # complete failure: no document could be embedded at all
    results = run_with_mbert_result([None] * 3)
    self.assertListEqual([], results.scores)
    self.assertListEqual(["mBERT"], results.labels)
    self.assertDictEqual({}, results.all_keywords)
    self.assertListEqual([CONNECTION_WARNING], results.warnings)

def assertNanEqual(self, table1, table2):
for list1, list2 in zip(table1, table2):
for x1, x2 in zip(list1, list2):
Expand Down Expand Up @@ -274,6 +301,38 @@ def test_selection_n_best(self):
output = self.get_output(self.widget.Outputs.words)
self.assertEqual(5, len(output))

def test_connection_error(self):
    """Connection warning is shown on mBERT failures and cleared on success."""
    widget = self.widget
    checkboxes = widget.controlArea.findChildren(QCheckBox)
    checkboxes[0].click()  # toggle the tfidf scoring checkbox
    checkboxes[3].click()  # toggle the mbert scoring checkbox

    def send_and_get_output(keyword_lists):
        # drive the widget while mBERT is mocked to return keyword_lists;
        # a None entry simulates a document that failed to embed
        items = [("mBERT", Mock(return_value=keyword_lists))]
        with patch.object(ScoringMethods, "ITEMS", items):
            self.send_signal(widget.Inputs.corpus, self.corpus)
            return self.get_output(widget.Outputs.words)

    # one document fails: partial result is still produced, warning shown
    output = send_and_get_output(
        [[("keyword1", 10), ("keyword2", 2)], None, [("keyword1", 5)]]
    )
    self.assertEqual(len(output), 2)
    np.testing.assert_array_equal(output.metas, [["keyword1"], ["keyword2"]])
    np.testing.assert_array_equal(output.X, [[7.5], [1]])
    self.assertTrue(widget.Warning.extraction_warnings.is_shown())
    self.assertEqual(
        CONNECTION_WARNING, str(widget.Warning.extraction_warnings)
    )

    # every document fails: no output, warning shown
    output = send_and_get_output([None] * 3)
    self.assertIsNone(output)
    self.assertTrue(widget.Warning.extraction_warnings.is_shown())
    self.assertEqual(
        CONNECTION_WARNING, str(widget.Warning.extraction_warnings)
    )

    # all documents succeed: warning is cleared again
    output = send_and_get_output(
        [[("keyword1", 10), ("keyword2", 2)], [("keyword1", 5)]]
    )
    np.testing.assert_array_equal(output.metas, [["keyword1"], ["keyword2"]])
    np.testing.assert_array_equal(output.X, [[7.5], [1]])
    self.assertFalse(widget.Warning.extraction_warnings.is_shown())


# Allow running this test module directly with `python test_owkeywords.py`.
if __name__ == "__main__":
    unittest.main()

0 comments on commit a32bf17

Please sign in to comment.