Skip to content

Commit

Permalink
Score documents - replace fasttext with sbert embedding
Browse files Browse the repository at this point in the history
  • Loading branch information
PrimozGodec committed Jan 18, 2023
1 parent 41824aa commit c26aaca
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 27 deletions.
60 changes: 38 additions & 22 deletions orangecontrib/text/widgets/owscoredocuments.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

from orangecontrib.text import Corpus
from orangecontrib.text.preprocess import BaseNormalizer, BaseTransformer
from orangecontrib.text.vectorization.document_embedder import DocumentEmbedder
from orangecontrib.text.vectorization.sbert import SBERT
from orangecontrib.text.widgets.utils.words import create_words_table


Expand Down Expand Up @@ -67,25 +67,26 @@ def _embedding_similarity(
) -> np.ndarray:
# make sure there will be only embeddings in X after calling the embedder
corpus = Corpus.from_table(Domain([], metas=corpus.domain.metas), corpus)
emb = DocumentEmbedder()
emb = SBERT()

cb_part = len(corpus) / (len(corpus) + len(words))
documet_embeddings, skipped = emb.transform(
corpus, wrap_callback(callback, 0, cb_part)
)
assert skipped is None
if skipped:
# raise when any embedding failed. Distances could instead be computed
# only for the valid embeddings, but that is not worthwhile: cases where
# only part of the documents fail to embed are extremely rare - usually,
# when a network error happens, embedding of all documents fails
raise ValueError("Some documents not embedded; try to rerun scoring")

# document embedding need corpus - changing list of words to corpus
words_feature = StringVariable("words")
words_c = Corpus.from_numpy(
Domain([], metas=[words_feature]),
np.empty((len(words), 0)),
metas=np.array([[w] for w in words]),
text_features=[words_feature],
language=corpus.language
)
w_emb, _ = emb.transform(words_c, wrap_callback(callback, cb_part, 1 - cb_part))
return cosine_similarity(documet_embeddings.X, w_emb.X)
w_emb = emb.embed_batches(words, batch_size=50)
if any(x is None for x in w_emb):
# raise when some words are not embedded; using only the valid word
# embeddings would produce wrong results
raise ValueError("Some words not embedded; try to rerun scoring")
return cosine_similarity(documet_embeddings.X, np.array(w_emb))


SCORING_METHODS = {
Expand Down Expand Up @@ -192,12 +193,16 @@ def callback(i: float) -> None:
scoring_method = SCORING_METHODS[sm][1]
sig = signature(scoring_method)
add_params = {k: v for k, v in additional_params.items() if k in sig.parameters}
scs = scoring_method(
corpus,
words,
wrap_callback(callback, start=(i + 1) * cb_part, end=(i + 2) * cb_part),
**add_params
)
try:
scs = scoring_method(
corpus,
words,
wrap_callback(callback, start=(i + 1) * cb_part, end=(i + 2) * cb_part),
**add_params
)
except ValueError as ex:
state.set_partial_result((sm, aggregation, str(ex)))
continue
scs = AGGREGATIONS[aggregation](scs, axis=1)
state.set_partial_result((sm, aggregation, scs))

Expand Down Expand Up @@ -343,6 +348,7 @@ class Outputs:

class Warning(OWWidget.Warning):
    # Fixed-text warning asking the user to normalize the corpus
    # (presumably raised when the input corpus is not normalized - the
    # triggering code is not visible here; confirm against the caller).
    corpus_not_normalized = Msg("Use Preprocess Text to normalize corpus.")
    # Free-form warning: the whole text is supplied by the caller.
    # on_partial_result uses it to report a scoring method that failed.
    scoring_warning = Msg("{}")

class Error(OWWidget.Error):
custom_err = Msg("{}")
Expand Down Expand Up @@ -617,6 +623,7 @@ def __setting_changed(self) -> None:
@gui.deferred
def commit(self) -> None:
self.Error.custom_err.clear()
self.Warning.scoring_warning.clear()
self.cancel()
if self.corpus is not None and self.words is not None:
scorers = self._get_active_scorers()
Expand All @@ -640,10 +647,19 @@ def commit(self) -> None:
def on_done(self, _: None) -> None:
    """Send the widget output; the (unused) result argument is ignored."""
    self._send_output()

def on_partial_result(
    self, result: Tuple[str, str, Union[np.ndarray, str]]
) -> None:
    """Handle the outcome of one (scoring method, aggregation) pair.

    Parameters
    ----------
    result
        Tuple ``(scoring method key, aggregation key, payload)``. The
        payload is the array of aggregated scores when scoring succeeded,
        or the error message (``str``) when the scoring method raised
        ``ValueError``.
    """
    sc_method, aggregation, scores = result
    # Branch on the payload type *before* touching self.scores so that an
    # error message is never stored (and rendered) as if it were a score.
    if isinstance(scores, str):
        # scoring failed; the error text is carried in ``scores``
        self.Warning.scoring_warning(
            f"{SCORING_METHODS[sc_method][0]} failed: {scores}"
        )
    else:
        # scoring successful
        self.scores[(sc_method, aggregation)] = scores
        self._fill_table()

def on_exception(self, ex: Exception) -> None:
self.Error.custom_err(ex)
Expand Down
116 changes: 111 additions & 5 deletions orangecontrib/text/widgets/tests/test_owscoredocuments.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from Orange.widgets.tests.utils import simulate

from orangecontrib.text import Corpus, preprocess
from orangecontrib.text.vectorization.document_embedder import _ServerEmbedder
from orangecontrib.text.vectorization.sbert import SBERT
from orangecontrib.text.widgets.owscoredocuments import (
OWScoreDocuments,
SelectionMethods,
Expand All @@ -22,8 +22,12 @@
from orangecontrib.text.widgets.utils.words import create_words_table


def embedding_mock(_, data, callback=None):
return np.ones((len(data), 10))
def embedding_mock(_, data, batch_size=None, callback=None):
    """Stand-in for the patched SBERT embedding methods.

    The receiver (``_``), ``batch_size`` and ``callback`` are ignored.
    Returns one 10-element vector of ones per item of ``data``, as a plain
    nested list.
    """
    n_items = len(data)
    unit_vectors = np.ones((n_items, 10))
    return unit_vectors.tolist()


def embedding_mock_none(_, data, batch_size=None, callback=None):
    """Stand-in embedder whose last item fails to embed.

    The receiver (``_``), ``batch_size`` and ``callback`` are ignored.
    Returns a list with one entry per item of ``data``: 10-element vectors
    of ones for all but the last item, and ``None`` for the last one.
    Empty ``data`` yields an empty list.
    """
    # len() rather than truthiness: ``data`` may be an array-like whose
    # truth value is ambiguous. Without this guard ``len(data) - 1`` would
    # be a negative dimension and np.ones would raise.
    if len(data) == 0:
        return []
    return np.ones((len(data) - 1, 10)).tolist() + [None]


class TestOWScoreDocuments(WidgetTest):
Expand Down Expand Up @@ -117,7 +121,8 @@ def test_guess_word_attribute(self):
self.send_signal(self.widget.Inputs.words, None)
self.assertIsNone(self.widget.words)

@patch.object(_ServerEmbedder, "embedd_data", new=embedding_mock)
@patch.object(SBERT, "embed_batches", new=embedding_mock)
@patch.object(SBERT, "__call__", new=embedding_mock)
def test_change_scorer(self):
model = self.widget.model
self.send_signal(self.widget.Inputs.corpus, self.corpus)
Expand Down Expand Up @@ -230,7 +235,8 @@ def test_word_appearance(self):
self.assertTrue(all(isinstance(s, float) for s in scores))
self.assertListEqual(scores, [0, 0])

@patch.object(_ServerEmbedder, "embedd_data", new=embedding_mock)
@patch.object(SBERT, "embed_batches", new=embedding_mock)
@patch.object(SBERT, "__call__", new=embedding_mock)
def test_embedding_similarity(self):
corpus = self.create_corpus(
[
Expand Down Expand Up @@ -453,6 +459,106 @@ def test_titles_no_newline(self):
"The Little Match-Seller test", self.widget.view.model().index(0, 0).data()
)

@patch.object(SBERT, "embed_batches", new=embedding_mock)
@patch.object(SBERT, "__call__")
def test_warning_unsuccessful_scoring(self, emb_mock):
    """Test when embedding for at least one document is not successful"""
    # mocked SBERT.__call__ result: the last entry is None (a failed embedding)
    emb_mock.return_value = np.ones((len(self.corpus) - 1, 10)).tolist() + [None]

    model = self.widget.model
    self.send_signal(self.widget.Inputs.corpus, self.corpus)
    self.send_signal(self.widget.Inputs.words, self.words)
    self.wait_until_finished()

    # scoring fails
    self.widget.controls.embedding_similarity.click()
    self.wait_until_finished()
    self.assertEqual(2, model.columnCount())  # name and word count, no similarity
    self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
    self.assertTrue(self.widget.Warning.scoring_warning.is_shown())
    self.assertEqual(
        "Similarity failed: Some documents not embedded; try to rerun scoring",
        str(self.widget.Warning.scoring_warning),
    )

    # rerun without the failing scorer (the second click presumably
    # deselects it - confirm against the widget); the warning must be cleared
    self.widget.controls.embedding_similarity.click()
    self.wait_until_finished()
    self.assertEqual(2, model.columnCount())
    self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
    self.assertFalse(self.widget.Warning.scoring_warning.is_shown())

    # run failing scoring again
    self.widget.controls.embedding_similarity.click()
    self.wait_until_finished()
    self.assertEqual(2, model.columnCount())  # name and word count, no similarity
    self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
    self.assertTrue(self.widget.Warning.scoring_warning.is_shown())
    self.assertEqual(
        "Similarity failed: Some documents not embedded; try to rerun scoring",
        str(self.widget.Warning.scoring_warning),
    )

    # run scoring again, this time it does not fail; the warning should disappear
    emb_mock.return_value = np.ones((len(self.corpus), 10)).tolist()
    self.widget.controls.embedding_similarity.click()
    self.widget.controls.embedding_similarity.click()
    self.wait_until_finished()
    self.assertEqual(3, model.columnCount())  # name, word count and similarity
    self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
    self.assertEqual(model.headerData(2, Qt.Horizontal), "Similarity")
    self.assertFalse(self.widget.Warning.scoring_warning.is_shown())

@patch.object(SBERT, "embed_batches")
@patch.object(SBERT, "__call__", new=embedding_mock)
def test_warning_unsuccessful_scoring_words(self, emb_mock):
    """Test when words embedding for at least one word is not successful"""
    # mocked SBERT.embed_batches result: one entry per word, the last one
    # None (a failed embedding) - same shape convention as the document
    # variant of this test
    emb_mock.return_value = np.ones((len(self.words) - 1, 10)).tolist() + [None]

    model = self.widget.model
    self.send_signal(self.widget.Inputs.corpus, self.corpus)
    self.send_signal(self.widget.Inputs.words, self.words)
    self.wait_until_finished()

    # scoring fails
    self.widget.controls.embedding_similarity.click()
    self.wait_until_finished()
    self.assertEqual(2, model.columnCount())  # name and word count, no similarity
    self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
    self.assertTrue(self.widget.Warning.scoring_warning.is_shown())
    self.assertEqual(
        "Similarity failed: Some words not embedded; try to rerun scoring",
        str(self.widget.Warning.scoring_warning),
    )

    # rerun without the failing scorer (the second click presumably
    # deselects it - confirm against the widget); the warning must be cleared
    self.widget.controls.embedding_similarity.click()
    self.wait_until_finished()
    self.assertEqual(2, model.columnCount())
    self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
    self.assertFalse(self.widget.Warning.scoring_warning.is_shown())

    # run failing scoring again
    self.widget.controls.embedding_similarity.click()
    self.wait_until_finished()
    self.assertEqual(2, model.columnCount())  # name and word count, no similarity
    self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
    self.assertTrue(self.widget.Warning.scoring_warning.is_shown())
    self.assertEqual(
        "Similarity failed: Some words not embedded; try to rerun scoring",
        str(self.widget.Warning.scoring_warning),
    )

    # run scoring again, this time it does not fail; the warning should disappear
    emb_mock.return_value = np.ones((len(self.words), 10)).tolist()
    self.widget.controls.embedding_similarity.click()
    self.widget.controls.embedding_similarity.click()
    self.wait_until_finished()
    self.assertEqual(3, model.columnCount())  # name, word count and similarity
    self.assertEqual(model.headerData(1, Qt.Horizontal), "Word count")
    self.assertEqual(model.headerData(2, Qt.Horizontal), "Similarity")
    self.assertFalse(self.widget.Warning.scoring_warning.is_shown())


if __name__ == "__main__":
unittest.main()

0 comments on commit c26aaca

Please sign in to comment.