From b521d15420dc27b4c6bf4ee60631564eefe0cce2 Mon Sep 17 00:00:00 2001
From: Primoz Godec
Date: Tue, 17 Aug 2021 15:14:23 +0200
Subject: [PATCH] Score documents: Selection
---
.../text/widgets/owscoredocuments.py | 214 ++++++++++++++++--
.../widgets/tests/test_owscoredocuments.py | 57 ++++-
2 files changed, 243 insertions(+), 28 deletions(-)
diff --git a/orangecontrib/text/widgets/owscoredocuments.py b/orangecontrib/text/widgets/owscoredocuments.py
index d50b7b5c0..c46650654 100644
--- a/orangecontrib/text/widgets/owscoredocuments.py
+++ b/orangecontrib/text/widgets/owscoredocuments.py
@@ -1,20 +1,38 @@
import re
from collections import Counter
+from contextlib import contextmanager
from inspect import signature
from typing import Callable, List, Tuple, Union
import numpy as np
-from AnyQt.QtCore import QSortFilterProxyModel, Qt
-from AnyQt.QtWidgets import QHeaderView, QLineEdit, QTableView
+from AnyQt.QtCore import (
+ QItemSelection,
+ QItemSelectionModel,
+ QSortFilterProxyModel,
+ Qt,
+ Signal,
+)
+from AnyQt.QtWidgets import (
+ QButtonGroup,
+ QGridLayout,
+ QHeaderView,
+ QLineEdit,
+ QRadioButton,
+ QTableView,
+)
+from pandas import isnull
+from sklearn.metrics.pairwise import cosine_similarity
+
+# todo: uncomment when minimum version of Orange is 3.29.2
+# from orangecanvas.gui.utils import disconnected
+from orangewidget import gui
from Orange.data import ContinuousVariable, Domain, StringVariable, Table
from Orange.util import wrap_callback
+from Orange.widgets.settings import ContextSetting, PerfectDomainContextHandler, Setting
+from Orange.widgets.utils.annotated_data import create_annotated_table
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin, TaskState
from Orange.widgets.utils.itemmodels import PyTableModel, TableModel
from Orange.widgets.widget import Input, Msg, Output, OWWidget
-from orangewidget import gui
-from orangewidget.settings import Setting
-from pandas import isnull
-from sklearn.metrics.pairwise import cosine_similarity
from orangecontrib.text import Corpus
from orangecontrib.text.preprocess import BaseNormalizer, BaseTransformer
@@ -24,6 +42,16 @@
)
+# todo: remove when minimum version of Orange is 3.29.2
+@contextmanager
+def disconnected(signal, slot, type=Qt.UniqueConnection):
+ signal.disconnect(slot)
+ try:
+ yield
+ finally:
+ signal.connect(slot, type)
+
+
def _word_frequency(corpus: Corpus, words: List[str], callback: Callable) -> np.ndarray:
res = []
tokens = corpus.tokens
@@ -183,12 +211,20 @@ def callback(i: float) -> None:
state.set_partial_result((sm, aggregation, scs))
+class SelectionMethods:
+ NONE, ALL, MANUAL, N_BEST = range(4)
+ ITEMS = "None", "All", "Manual", "Top documents"
+
+
class ScoreDocumentsTableView(QTableView):
+ pressedAny = Signal()
+
def __init__(self):
super().__init__(
sortingEnabled=True,
editTriggers=QTableView.NoEditTriggers,
- selectionMode=QTableView.NoSelection,
+ selectionMode=QTableView.ExtendedSelection,
+ selectionBehavior=QTableView.SelectRows,
cornerButtonEnabled=False,
)
self.setItemDelegate(gui.ColoredBarItemDelegate(self))
@@ -215,6 +251,10 @@ def update_column_widths(self) -> None:
# document title column is one that stretch
header.setSectionResizeMode(0, QHeaderView.Stretch)
+ def mousePressEvent(self, event):
+ super().mousePressEvent(event)
+ self.pressedAny.emit()
+
class ScoreDocumentsProxyModel(QSortFilterProxyModel):
@staticmethod
@@ -253,23 +293,31 @@ class OWScoreDocuments(OWWidget, ConcurrentWidgetMixin):
icon = "icons/ScoreDocuments.svg"
priority = 500
+ buttons_area_orientation = Qt.Vertical
+
# default order - table sorted in input order
DEFAULT_SORTING = (-1, Qt.AscendingOrder)
+ settingsHandler = PerfectDomainContextHandler()
auto_commit: bool = Setting(True)
aggregation: int = Setting(0)
word_frequency: bool = Setting(True)
word_appearance: bool = Setting(False)
embedding_similarity: bool = Setting(False)
+ embedding_language: int = Setting(0)
+
sort_column_order: Tuple[int, int] = Setting(DEFAULT_SORTING)
- embedding_language = Setting(0)
+ selected_rows: List[int] = ContextSetting([], schema_only=True)
+ sel_method: int = ContextSetting(SelectionMethods.N_BEST)
+ n_selected: int = ContextSetting(3)
class Inputs:
corpus = Input("Corpus", Corpus)
words = Input("Words", Table)
class Outputs:
+ selected_documents = Output("Selected documents", Corpus, default=True)
corpus = Output("Corpus", Corpus)
class Warning(OWWidget.Warning):
@@ -322,6 +370,33 @@ def _setup_control_area(self) -> None:
)
gui.rubber(self.controlArea)
+
+ # select words box
+ box = gui.vBox(self.buttonsArea, "Select Documents")
+ grid = QGridLayout()
+ grid.setContentsMargins(0, 0, 0, 0)
+
+ self._sel_method_buttons = QButtonGroup()
+ for method, label in enumerate(SelectionMethods.ITEMS):
+ button = QRadioButton(label)
+ button.setChecked(method == self.sel_method)
+ grid.addWidget(button, method, 0)
+ self._sel_method_buttons.addButton(button, method)
+ self._sel_method_buttons.buttonClicked[int].connect(self.__set_selection_method)
+
+ spin = gui.spin(
+ box,
+ self,
+ "n_selected",
+ 1,
+ 999,
+ addToLayout=False,
+ callback=lambda: self.__set_selection_method(SelectionMethods.N_BEST),
+ )
+ grid.addWidget(spin, 3, 1)
+ box.layout().addLayout(grid)
+
+ # autocommit
gui.auto_send(self.buttonsArea, self, "auto_commit")
def _setup_main_area(self) -> None:
@@ -333,7 +408,11 @@ def _setup_main_area(self) -> None:
self.model = model = ScoreDocumentsTableModel(parent=self)
model.setHorizontalHeaderLabels(["Document"])
+ def select_manual():
+ self.__set_selection_method(SelectionMethods.MANUAL)
+
self.view = view = ScoreDocumentsTableView()
+ view.pressedAny.connect(select_manual)
self.mainArea.layout().addWidget(view)
# by default data are sorted in the Table order
header = self.view.horizontalHeader()
@@ -344,6 +423,7 @@ def _setup_main_area(self) -> None:
proxy_model.setFilterCaseSensitivity(False)
view.setModel(proxy_model)
view.model().setSourceModel(self.model)
+ self.view.selectionModel().selectionChanged.connect(self.__on_selection_change)
def __on_filter_changed(self) -> None:
model = self.view.model()
@@ -352,15 +432,39 @@ def __on_filter_changed(self) -> None:
def __on_horizontal_header_clicked(self, index: int):
header = self.view.horizontalHeader()
self.sort_column_order = (index, header.sortIndicatorOrder())
+ self._select_rows()
+ # when sorting change output table must consider the new order
+ # call explicitly since selection in table is not changed
+ if (
+ self.sel_method == SelectionMethods.MANUAL
+ and self.selected_rows
+ or self.sel_method == SelectionMethods.ALL
+ ):
+ # retrieve selection in new order
+ self.selected_rows = self.get_selected_indices()
+ self._send_output()
+
+ def __on_selection_change(self):
+ self.selected_rows = self.get_selected_indices()
+ self._send_output()
+
+ def __set_selection_method(self, method: int):
+ self.sel_method = method
+ self._sel_method_buttons.button(method).setChecked(True)
+ self._select_rows()
@Inputs.corpus
def set_data(self, corpus: Corpus) -> None:
+ self.closeContext()
self.Warning.corpus_not_normalized.clear()
if corpus is not None:
self.Warning.missing_corpus.clear()
if not self._is_corpus_normalized(corpus):
self.Warning.corpus_not_normalized()
self.corpus = corpus
+ self.selected_rows = []
+ self.openContext(corpus)
+ self._sel_method_buttons.button(self.sel_method).setChecked(True)
self._clear_and_run()
@staticmethod
@@ -416,10 +520,16 @@ def _gather_scores(self) -> Tuple[np.ndarray, List[str]]:
labels = [SCORING_METHODS[m][0] for m in methods]
return scores, labels
- def _send_output(self, scores: np.ndarray, labels: List[str]) -> None:
+ def _send_output(self) -> None:
"""
Create corpus with scores and output it
"""
+ if self.corpus is None:
+ self.Outputs.corpus.send(None)
+ self.Outputs.selected_documents.send(None)
+ return
+
+ scores, labels = self._gather_scores()
if labels:
d = self.corpus.domain
domain = Domain(
@@ -427,39 +537,54 @@ def _send_output(self, scores: np.ndarray, labels: List[str]) -> None:
d.class_var,
metas=d.metas + tuple(ContinuousVariable(l) for l in labels),
)
- corpus = Corpus(
+ out_corpus = Corpus(
domain,
self.corpus.X,
self.corpus.Y,
np.hstack([self.corpus.metas, scores]),
)
- Corpus.retain_preprocessing(self.corpus, corpus)
- self.Outputs.corpus.send(corpus)
- elif self.corpus is not None:
- self.Outputs.corpus.send(self.corpus)
+ Corpus.retain_preprocessing(self.corpus, out_corpus)
else:
- self.Outputs.corpus.send(None)
+ out_corpus = self.corpus
- def _fill_table(self, scores: np.ndarray, labels: List[str]) -> None:
+ self.Outputs.corpus.send(create_annotated_table(out_corpus, self.selected_rows))
+ self.Outputs.selected_documents.send(
+ out_corpus[self.selected_rows] if self.selected_rows else None
+ )
+
+ def _fill_table(self) -> None:
"""
Fill the table in the widget with scores and document names
"""
if self.corpus is None:
self.model.clear()
return
+ scores, labels = self._gather_scores()
labels = ["Document"] + labels
titles = self.corpus.titles.tolist()
+
+ # clearing selection and sorting to prevent SEGFAULT on model.wrap
+ self.view.horizontalHeader().setSortIndicator(-1, Qt.AscendingOrder)
+ with disconnected(
+ self.view.selectionModel().selectionChanged, self.__on_selection_change
+ ):
+ self.view.clearSelection()
+
self.model.wrap([[c] + s for c, s in zip(titles, scores.tolist())])
self.model.setHorizontalHeaderLabels(labels)
self.view.update_column_widths()
+ if self.model.columnCount() > self.sort_column_order[0]:
+ # if not enough columns do not apply sorting from settings since
+ # sorting can besaved for score column while scores are still computing
+ # tables is filled before scores are computed with document names
+ self.view.horizontalHeader().setSortIndicator(*self.sort_column_order)
- self.view.horizontalHeader().setSortIndicator(*self.sort_column_order)
+ self._select_rows()
def _fill_and_output(self) -> None:
"""Fill the table in the widget and send the output"""
- scores, labels = self._gather_scores()
- self._fill_table(scores, labels)
- self._send_output(scores, labels)
+ self._fill_table()
+ self._send_output()
def _clear_and_run(self) -> None:
"""Clear cached scores and commit"""
@@ -500,14 +625,12 @@ def commit(self) -> None:
self._fill_and_output()
def on_done(self, _: None) -> None:
- scores, labels = self._gather_scores()
- self._send_output(scores, labels)
+ self._send_output()
def on_partial_result(self, result: Tuple[str, str, np.ndarray]) -> None:
sc_method, aggregation, scores = result
self.scores[(sc_method, aggregation)] = scores
- scores, labels = self._gather_scores()
- self._fill_table(scores, labels)
+ self._fill_table()
def on_exception(self, ex: Exception) -> None:
self.Error.custom_err(ex)
@@ -543,6 +666,49 @@ def _is_corpus_normalized(corpus: Corpus) -> bool:
for pp in corpus.used_preprocessor.preprocessors
)
+ def get_selected_indices(self) -> List[int]:
+ # get indices in table's order - that the selected output table have same order
+ selected_rows = sorted(
+ self.view.selectionModel().selectedRows(), key=lambda idx: idx.row()
+ )
+ return [self.view.model().mapToSource(r).row() for r in selected_rows]
+
+ def _select_rows(self):
+ proxy_model = self.view.model()
+ n_rows, n_columns = proxy_model.rowCount(), proxy_model.columnCount()
+ if self.sel_method == SelectionMethods.NONE:
+ selection = QItemSelection()
+ elif self.sel_method == SelectionMethods.ALL:
+ selection = QItemSelection(
+ proxy_model.index(0, 0), proxy_model.index(n_rows - 1, n_columns - 1)
+ )
+ elif self.sel_method == SelectionMethods.MANUAL:
+ selection = QItemSelection()
+ new_sel = []
+ for row in self.selected_rows:
+ if row < n_rows:
+ new_sel.append(row)
+ _selection = QItemSelection(
+ self.model.index(row, 0), self.model.index(row, n_columns - 1)
+ )
+ selection.merge(
+ proxy_model.mapSelectionFromSource(_selection),
+ QItemSelectionModel.Select,
+ )
+ # selected rows must be updated when the same dataset with less rows
+ # appear at the input - it is not handled by selectionChanged
+ # in cases when all selected rows missing in new table
+ self.selected_rows = new_sel
+ elif self.sel_method == SelectionMethods.N_BEST:
+ n_sel = min(self.n_selected, n_rows)
+ selection = QItemSelection(
+ proxy_model.index(0, 0), proxy_model.index(n_sel - 1, n_columns - 1)
+ )
+ else:
+ raise NotImplementedError
+
+ self.view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect)
+
if __name__ == "__main__":
from orangewidget.utils.widgetpreview import WidgetPreview
diff --git a/orangecontrib/text/widgets/tests/test_owscoredocuments.py b/orangecontrib/text/widgets/tests/test_owscoredocuments.py
index 9d8d4fc20..bbd958b63 100644
--- a/orangecontrib/text/widgets/tests/test_owscoredocuments.py
+++ b/orangecontrib/text/widgets/tests/test_owscoredocuments.py
@@ -4,7 +4,8 @@
from unittest.mock import patch
import numpy as np
-from AnyQt.QtCore import Qt
+from AnyQt.QtCore import QItemSelectionModel, Qt
+
from Orange.data import ContinuousVariable, Domain, StringVariable, Table
from Orange.misc.collections import natural_sorted
from Orange.util import dummy_callback
@@ -15,6 +16,7 @@
from orangecontrib.text.vectorization.document_embedder import DocumentEmbedder
from orangecontrib.text.widgets.owscoredocuments import (
OWScoreDocuments,
+ SelectionMethods,
_preprocess_words,
)
@@ -74,8 +76,9 @@ def test_set_data(self):
output = self.get_output(self.widget.Outputs.corpus)
self.assertTupleEqual(output.domain.variables, self.corpus.domain.variables)
- self.assertTupleEqual(output.domain.metas[:-1], self.corpus.domain.metas)
- self.assertEqual(str(output.domain.metas[-1]), "Word count")
+ self.assertTupleEqual(output.domain.metas[:1], self.corpus.domain.metas)
+ self.assertEqual(str(output.domain.metas[1]), "Word count")
+ self.assertEqual(str(output.domain.metas[2]), "Selected")
self.assertEqual(len(output), len(self.corpus))
def test_corpus_not_normalized(self):
@@ -352,11 +355,14 @@ def test_sort_setting(self):
"""
view = self.widget.view
model = view.model()
- self.widget.sort_column_order = (1, Qt.DescendingOrder)
self.send_signal(self.widget.Inputs.corpus, self.corpus)
self.send_signal(self.widget.Inputs.words, self.words)
self.wait_until_finished()
+ self.widget.sort_column_order = (1, Qt.DescendingOrder)
+ self.widget._fill_table()
+ self.wait_until_finished()
+
header = self.widget.view.horizontalHeader()
current_sorting = (header.sortIndicatorSection(), header.sortIndicatorOrder())
data = [model.data(model.index(i, 1)) for i in range(model.rowCount())]
@@ -385,6 +391,49 @@ def test_sort_setting(self):
self.assertTupleEqual((1, Qt.DescendingOrder), current_sorting)
self.assertListEqual(sorted(data, reverse=True), data)
+ def test_selection_none(self):
+ self.send_signal(self.widget.Inputs.corpus, self.corpus)
+ self.widget._sel_method_buttons.button(SelectionMethods.NONE).click()
+
+ output = self.get_output(self.widget.Outputs.selected_documents)
+ self.assertIsNone(output)
+
+ def tests_selection_all(self):
+ self.send_signal(self.widget.Inputs.corpus, self.corpus)
+ self.widget._sel_method_buttons.button(SelectionMethods.ALL).click()
+
+ output = self.get_output(self.widget.Outputs.selected_documents)
+ self.assertEqual(len(self.corpus), len(output))
+
+ def test_selection_manual(self):
+ self.send_signal(self.widget.Inputs.corpus, self.corpus)
+ self.widget._sel_method_buttons.button(SelectionMethods.MANUAL).click()
+
+ mode = QItemSelectionModel.Rows | QItemSelectionModel.Select
+ view = self.widget.view
+ view.clearSelection()
+ model = view.model()
+ view.selectionModel().select(model.index(2, 0), mode)
+ view.selectionModel().select(model.index(3, 0), mode)
+
+ output = self.get_output(self.widget.Outputs.selected_documents)
+ self.assertListEqual(["Document 3", "Document 4"], output.titles.tolist())
+
+ def test_selection_n_best(self):
+ self.send_signal(self.widget.Inputs.corpus, self.corpus)
+ self.widget._sel_method_buttons.button(SelectionMethods.N_BEST).click()
+
+ output = self.get_output(self.widget.Outputs.selected_documents)
+ self.assertListEqual(
+ [f"Document {i}" for i in range(1, 4)], output.titles.tolist()
+ )
+
+ self.widget.controls.n_selected.setValue(5)
+ output = self.get_output(self.widget.Outputs.selected_documents)
+ self.assertListEqual(
+ [f"Document {i}" for i in range(1, 6)], output.titles.tolist()
+ )
+
if __name__ == "__main__":
unittest.main()