diff --git a/orangecontrib/text/widgets/owscoredocuments.py b/orangecontrib/text/widgets/owscoredocuments.py index b657d22db..b9cbe4301 100644 --- a/orangecontrib/text/widgets/owscoredocuments.py +++ b/orangecontrib/text/widgets/owscoredocuments.py @@ -4,6 +4,7 @@ from typing import List, Callable, Tuple, Union import numpy as np +from Orange.widgets.utils.annotated_data import create_annotated_table from pandas import isnull from Orange.data import ( Table, @@ -19,7 +20,7 @@ from orangewidget.settings import Setting from Orange.widgets.utils.itemmodels import PyTableModel, TableModel from AnyQt.QtWidgets import QTableView, QLineEdit, QHeaderView -from AnyQt.QtCore import Qt, QSortFilterProxyModel +from AnyQt.QtCore import Qt, QSortFilterProxyModel, QItemSelection from sklearn.metrics.pairwise import cosine_similarity from orangecontrib.text import Corpus @@ -189,7 +190,8 @@ def __init__(self): super().__init__( sortingEnabled=True, editTriggers=QTableView.NoEditTriggers, - selectionMode=QTableView.NoSelection, + selectionMode=QTableView.ExtendedSelection, + selectionBehavior=QTableView.SelectRows, cornerButtonEnabled=False, ) self.setItemDelegate(gui.ColoredBarItemDelegate(self)) @@ -273,6 +275,7 @@ class Inputs: words = Input("Words", Table) class Outputs: + selected_documents = Output("Selected documents", Corpus, default=True) corpus = Output("Corpus", Corpus) class Warning(OWWidget.Warning): @@ -347,6 +350,7 @@ def _setup_main_area(self) -> None: proxy_model.setFilterCaseSensitivity(False) view.setModel(proxy_model) view.model().setSourceModel(self.model) + self.view.selectionModel().selectionChanged.connect(self._send_output) def __on_filter_changed(self) -> None: model = self.view.model() @@ -426,10 +430,16 @@ def _gather_scores(self) -> Tuple[np.ndarray, List[str]]: labels = [SCORING_METHODS[m][0] for m in methods] return scores, labels - def _send_output(self, scores: np.ndarray, labels: List[str]) -> None: + def _send_output(self) -> None: """ Create corpus with scores and output it """ + if self.corpus is None: + self.Outputs.corpus.send(None) + self.Outputs.selected_documents.send(None) + return + + scores, labels = self._gather_scores() if labels: d = self.corpus.domain domain = Domain( @@ -437,39 +447,42 @@ def _send_output(self, scores: np.ndarray, labels: List[str]) -> None: d.class_var, metas=d.metas + tuple(ContinuousVariable(l) for l in labels), ) - corpus = Corpus( + out_corpus = Corpus( domain, self.corpus.X, self.corpus.Y, np.hstack([self.corpus.metas, scores]), ) - Corpus.retain_preprocessing(self.corpus, corpus) - self.Outputs.corpus.send(corpus) - elif self.corpus is not None: - self.Outputs.corpus.send(self.corpus) + Corpus.retain_preprocessing(self.corpus, out_corpus) else: - self.Outputs.corpus.send(None) + out_corpus = self.corpus - def _fill_table(self, scores: np.ndarray, labels: List[str]) -> None: + selected_indices = self.get_selected_indices() + self.Outputs.corpus.send(create_annotated_table(out_corpus, selected_indices)) + self.Outputs.selected_documents.send(out_corpus[selected_indices] if selected_indices else None) + + def _fill_table(self) -> None: """ Fill the table in the widget with scores and document names """ if self.corpus is None: self.model.clear() return + scores, labels = self._gather_scores() labels = ["Document"] + labels titles = self.corpus.titles.tolist() + + selected_indices = [i.row() for i in self.view.selectionModel().selectedRows()] self.model.wrap([[c] + s for c, s in zip(titles, scores.tolist())]) self.model.setHorizontalHeaderLabels(labels) self.view.update_column_widths() - self.view.horizontalHeader().setSortIndicator(*self.sort_column_order) + self._select_rows(selected_indices) def _fill_and_output(self) -> None: """ Fill the table in the widget and send the output """ - scores, labels = self._gather_scores() - self._fill_table(scores, labels) - self._send_output(scores, labels) + self._fill_table() + self._send_output() def _clear_and_run(self) -> None: """ Clear cached scores and commit """ @@ -512,14 +525,12 @@ def commit(self) -> None: self._fill_and_output() def on_done(self, _: None) -> None: - scores, labels = self._gather_scores() - self._send_output(scores, labels) + self._send_output() def on_partial_result(self, result: Tuple[str, str, np.ndarray]) -> None: sc_method, aggregation, scores = result self.scores[(sc_method, aggregation)] = scores - scores, labels = self._gather_scores() - self._fill_table(scores, labels) + self._fill_table() def on_exception(self, ex: Exception) -> None: self.Error.unknown_err(ex) @@ -554,6 +565,16 @@ def _is_corpus_normalized(corpus: Corpus) -> bool: for pp in corpus.used_preprocessor.preprocessors ) + def get_selected_indices(self): + selected_rows = self.view.selectionModel().selectedRows() + return [self.view.model().mapToSource(r).row() for r in selected_rows] + + def _select_rows(self, selected_indices): + self.view.setSelectionMode(QTableView.MultiSelection) + for row in selected_indices: + self.view.selectRow(row) + self.view.setSelectionMode(QTableView.ExtendedSelection) + if __name__ == "__main__": from orangewidget.utils.widgetpreview import WidgetPreview diff --git a/orangecontrib/text/widgets/tests/test_owscoredocuments.py b/orangecontrib/text/widgets/tests/test_owscoredocuments.py index 69b31e090..f8b7b62f6 100644 --- a/orangecontrib/text/widgets/tests/test_owscoredocuments.py +++ b/orangecontrib/text/widgets/tests/test_owscoredocuments.py @@ -4,7 +4,7 @@ from unittest.mock import patch import numpy as np -from AnyQt.QtCore import Qt +from AnyQt.QtCore import Qt, QItemSelectionModel from Orange.widgets.tests.base import WidgetTest from Orange.data import Table, StringVariable, Domain, ContinuousVariable from Orange.widgets.tests.utils import simulate @@ -75,9 +75,10 @@ def test_set_data(self): output.domain.variables, self.corpus.domain.variables ) self.assertTupleEqual( - output.domain.metas[:-1], self.corpus.domain.metas + output.domain.metas[:1], self.corpus.domain.metas ) - self.assertEqual(str(output.domain.metas[-1]), "Word count") + self.assertEqual(str(output.domain.metas[1]), "Word count") + self.assertEqual(str(output.domain.metas[2]), "Selected") self.assertEqual(len(output), len(self.corpus)) def test_corpus_not_normalized(self): @@ -336,6 +337,56 @@ def test_sort_setting(self): self.assertTupleEqual((1, Qt.DescendingOrder), current_sorting) self.assertListEqual(sorted(data, reverse=True), data) + def test_selection(self): + self.send_signal(self.widget.Inputs.corpus, self.corpus) + + mode = QItemSelectionModel.Rows | QItemSelectionModel.Select + view = self.widget.view + model = view.model() + view.selectionModel().select(model.index(2, 0), mode) + view.selectionModel().select(model.index(3, 0), mode) + + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertListEqual(["Document 3", "Document 4"], output.titles.tolist()) + + # change table order and check if correct selection at the output + view.sortByColumn(0, Qt.DescendingOrder) + view.clearSelection() + view.selectionModel().select(model.index(0, 0), mode) + view.selectionModel().select(model.index(1, 0), mode) + + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertListEqual(["Document 140", "Document 139"], output.titles.tolist()) + view.clearSelection() + + def test_selection_socres(self): + # repeat with scores + self.send_signal(self.widget.Inputs.corpus, self.corpus) + words = self.create_words_table(["house", "doctor", "boy", "way", "Rum"]) + self.send_signal(self.widget.Inputs.words, words) + self.wait_until_finished() + + mode = QItemSelectionModel.Rows | QItemSelectionModel.Select + view = self.widget.view + model = view.model() + + model = self.widget.model + view = self.widget.view + view.selectionModel().select(model.index(2, 0), mode) + view.selectionModel().select(model.index(3, 0), mode) + + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertListEqual(["Document 3", "Document 4"], output.titles.tolist()) + + # change table order and check if correct selection at the output + view.sortByColumn(0, Qt.DescendingOrder) + view.clearSelection() + view.selectionModel().select(model.index(0, 0), mode) + view.selectionModel().select(model.index(1, 0), mode) + + output = self.get_output(self.widget.Outputs.selected_documents) + self.assertListEqual(["Document 140", "Document 139"], output.titles.tolist()) + if __name__ == "__main__": unittest.main()